diff --git a/src/net/tools/naive/naive_proxy_bin.cc b/src/net/tools/naive/naive_proxy_bin.cc index a3edcc073e..499034cb89 100644 --- a/src/net/tools/naive/naive_proxy_bin.cc +++ b/src/net/tools/naive/naive_proxy_bin.cc @@ -16,6 +16,7 @@ #include "base/json/json_writer.h" #include "base/logging.h" #include "base/macros.h" +#include "base/rand_util.h" #include "base/run_loop.h" #include "base/strings/string16.h" #include "base/strings/string_number_conversions.h" @@ -37,6 +38,8 @@ #include "net/http/http_auth.h" #include "net/http/http_auth_cache.h" #include "net/http/http_network_session.h" +#include "net/http/http_request_headers.h" +#include "net/http/http_response_headers.h" #include "net/http/http_transaction_factory.h" #include "net/log/file_net_log_observer.h" #include "net/log/net_log.h" @@ -80,6 +83,7 @@ struct CommandLine { std::string listen; std::string proxy; bool padding; + std::string extra_headers; std::string host_resolver_rules; std::string resolver_range; bool no_log; @@ -93,6 +97,7 @@ struct Params { std::string listen_addr; int listen_port; bool use_padding; + net::HttpRequestHeaders extra_headers; std::string proxy_url; base::string16 proxy_user; base::string16 proxy_pass; @@ -118,79 +123,6 @@ std::unique_ptr GetConstants( return constants_dict; } -std::unique_ptr BuildCertURLRequestContext( - net::NetLog* net_log) { - net::URLRequestContextBuilder builder; - - builder.DisableHttpCache(); - builder.set_net_log(net_log); - - net::ProxyConfig proxy_config; - auto proxy_service = - net::ConfiguredProxyResolutionService::CreateWithoutProxyResolver( - std::make_unique( - net::ProxyConfigWithAnnotation(proxy_config, kTrafficAnnotation)), - net_log); - proxy_service->ForceReloadProxyConfig(); - builder.set_proxy_resolution_service(std::move(proxy_service)); - - return builder.Build(); -} - -// Builds a URLRequestContext assuming there's only a single loop. -std::unique_ptr BuildURLRequestContext( - const Params& params, - scoped_refptr cert_net_fetcher, - net::NetLog* net_log) { - net::URLRequestContextBuilder builder; - - builder.DisableHttpCache(); - builder.set_net_log(net_log); - - net::ProxyConfig proxy_config; - proxy_config.proxy_rules().ParseFromString(params.proxy_url); - LOG(INFO) << "Proxying via " << params.proxy_url; - auto proxy_service = - net::ConfiguredProxyResolutionService::CreateWithoutProxyResolver( - std::make_unique( - net::ProxyConfigWithAnnotation(proxy_config, kTrafficAnnotation)), - net_log); - proxy_service->ForceReloadProxyConfig(); - builder.set_proxy_resolution_service(std::move(proxy_service)); - - if (!params.host_resolver_rules.empty()) { - builder.set_host_mapping_rules(params.host_resolver_rules); - } - - builder.SetCertVerifier( - net::CertVerifier::CreateDefault(std::move(cert_net_fetcher))); - - auto context = builder.Build(); - - if (!params.proxy_url.empty() && !params.proxy_user.empty() && - !params.proxy_pass.empty()) { - auto* session = context->http_transaction_factory()->GetSession(); - auto* auth_cache = session->http_auth_cache(); - std::string proxy_url = params.proxy_url; - if (proxy_url.compare(0, 7, "quic://") == 0) { - proxy_url.replace(0, 4, "https"); - auto* quic = context->quic_context()->params(); - const auto& versions = quic::SupportedVersions(); - quic->supported_versions.assign(versions.begin(), versions.end()); - quic->origins_to_force_quic_on.insert( - net::HostPortPair::FromURL(GURL(proxy_url))); - } - GURL auth_origin(proxy_url); - net::AuthCredentials credentials(params.proxy_user, params.proxy_pass); - auth_cache->Add(auth_origin, net::HttpAuth::AUTH_PROXY, - /*realm=*/std::string(), net::HttpAuth::AUTH_SCHEME_BASIC, - net::NetworkIsolationKey(), /*challenge=*/"Basic", - credentials, /*path=*/"/"); - } - - return context; -} - void GetCommandLine(const base::CommandLine& proc, CommandLine* cmdline) { if (proc.HasSwitch("h") || proc.HasSwitch("help")) { std::cout << "Usage: naive { OPTIONS | config.json }\n" @@ -204,6 +136,7 @@ void GetCommandLine(const base::CommandLine& proc, CommandLine* cmdline) { "--proxy=://[:@][:]\n" " proto: https, quic\n" "--padding Use padding\n" + "--extra-headers=... Extra headers split by CRLF\n" "--host-resolver-rules=... Resolver rules\n" "--resolver-range=... Redirect resolver range\n" "--log[=] Log to stderr, or file\n" @@ -221,6 +154,7 @@ void GetCommandLine(const base::CommandLine& proc, CommandLine* cmdline) { cmdline->listen = proc.GetSwitchValueASCII("listen"); cmdline->proxy = proc.GetSwitchValueASCII("proxy"); cmdline->padding = proc.HasSwitch("padding"); + cmdline->extra_headers = proc.GetSwitchValueASCII("extra-headers"); cmdline->host_resolver_rules = proc.GetSwitchValueASCII("host-resolver-rules"); cmdline->resolver_range = proc.GetSwitchValueASCII("resolver-range"); @@ -245,36 +179,41 @@ void GetCommandLineFromConfig(const base::FilePath& config_path, std::cerr << "Invalid config format" << std::endl; exit(EXIT_FAILURE); } - if (value->FindKeyOfType("listen", base::Value::Type::STRING)) { - cmdline->listen = value->FindKey("listen")->GetString(); + const auto* listen = value->FindStringKey("listen"); + if (listen) { + cmdline->listen = *listen; } - if (value->FindKeyOfType("proxy", base::Value::Type::STRING)) { - cmdline->proxy = value->FindKey("proxy")->GetString(); + const auto* proxy = value->FindStringKey("proxy"); + if (proxy) { + cmdline->proxy = *proxy; } - cmdline->padding = false; - if (value->FindKeyOfType("padding", base::Value::Type::BOOLEAN)) { - cmdline->padding = value->FindKey("padding")->GetBool(); + cmdline->padding = value->FindBoolKey("padding").value_or(false); + const auto* extra_headers = value->FindStringKey("extra-headers"); + if (extra_headers) { + cmdline->extra_headers = *extra_headers; } - if (value->FindKeyOfType("host-resolver-rules", base::Value::Type::STRING)) { - cmdline->host_resolver_rules = - value->FindKey("host-resolver-rules")->GetString(); + const auto* host_resolver_rules = value->FindStringKey("host-resolver-rules"); + if (host_resolver_rules) { + cmdline->host_resolver_rules = *host_resolver_rules; } - if (value->FindKeyOfType("resolver-range", base::Value::Type::STRING)) { - cmdline->resolver_range = value->FindKey("resolver-range")->GetString(); + const auto* resolver_range = value->FindStringKey("resolver-range"); + if (resolver_range) { + cmdline->resolver_range = *resolver_range; } cmdline->no_log = true; - if (value->FindKeyOfType("log", base::Value::Type::STRING)) { + const auto* log = value->FindStringKey("log"); + if (log) { cmdline->no_log = false; - cmdline->log = - base::FilePath::FromUTF8Unsafe(value->FindKey("log")->GetString()); + cmdline->log = base::FilePath::FromUTF8Unsafe(*log); } - if (value->FindKeyOfType("log-net-log", base::Value::Type::STRING)) { - cmdline->log_net_log = base::FilePath::FromUTF8Unsafe( - value->FindKey("log-net-log")->GetString()); + const auto* log_net_log = value->FindStringKey("log-net-log"); + if (log_net_log) { + cmdline->log_net_log = base::FilePath::FromUTF8Unsafe(*log_net_log); } - if (value->FindKeyOfType("ssl-key-log-file", base::Value::Type::STRING)) { - cmdline->ssl_key_log_file = base::FilePath::FromUTF8Unsafe( - value->FindKey("ssl-key-log-file")->GetString()); + const auto* ssl_key_log_file = value->FindStringKey("ssl-key-log-file"); + if (ssl_key_log_file) { + cmdline->ssl_key_log_file = + base::FilePath::FromUTF8Unsafe(*ssl_key_log_file); } } @@ -347,6 +286,8 @@ bool ParseCommandLine(const CommandLine& cmdline, Params* params) { params->use_padding = cmdline.padding; + params->extra_headers.AddHeadersFromString(cmdline.extra_headers); + params->host_resolver_rules = cmdline.host_resolver_rules; if (params->protocol == net::NaiveConnection::kRedir) { @@ -424,7 +365,110 @@ class PrintingLogObserver : public NetLog::ThreadSafeObserver { private: DISALLOW_COPY_AND_ASSIGN(PrintingLogObserver); }; +} // namespace +class ProxyInfo; +class ProxyServer; + +namespace { +class NaiveProxyDelegate : public ProxyDelegate { + public: + NaiveProxyDelegate(const Params& params) : params_(params) {} + void OnResolveProxy(const GURL& url, + const std::string& method, + const ProxyRetryInfoMap& proxy_retry_info, + ProxyInfo* result) override {} + void OnFallback(const ProxyServer& bad_proxy, int net_error) override {} + + void OnBeforeTunnelRequest(const ProxyServer& proxy_server, + HttpRequestHeaders* extra_headers) override { + extra_headers->SetHeader("Padding", + std::string(base::RandInt(16, 32), '.')); + extra_headers->MergeFrom(params_.extra_headers); + } + + Error OnTunnelHeadersReceived( + const ProxyServer& proxy_server, + const HttpResponseHeaders& response_headers) override { + return OK; + } + + private: + const Params& params_; +}; + +std::unique_ptr BuildCertURLRequestContext(NetLog* net_log) { + URLRequestContextBuilder builder; + + builder.DisableHttpCache(); + builder.set_net_log(net_log); + + ProxyConfig proxy_config; + auto proxy_service = + ConfiguredProxyResolutionService::CreateWithoutProxyResolver( + std::make_unique( + ProxyConfigWithAnnotation(proxy_config, kTrafficAnnotation)), + net_log); + proxy_service->ForceReloadProxyConfig(); + builder.set_proxy_resolution_service(std::move(proxy_service)); + + return builder.Build(); +} + +// Builds a URLRequestContext assuming there's only a single loop. +std::unique_ptr BuildURLRequestContext( + const Params& params, + scoped_refptr cert_net_fetcher, + NetLog* net_log) { + URLRequestContextBuilder builder; + + builder.DisableHttpCache(); + builder.set_net_log(net_log); + + ProxyConfig proxy_config; + proxy_config.proxy_rules().ParseFromString(params.proxy_url); + LOG(INFO) << "Proxying via " << params.proxy_url; + auto proxy_service = + ConfiguredProxyResolutionService::CreateWithoutProxyResolver( + std::make_unique( + ProxyConfigWithAnnotation(proxy_config, kTrafficAnnotation)), + net_log); + proxy_service->ForceReloadProxyConfig(); + builder.set_proxy_resolution_service(std::move(proxy_service)); + + if (!params.host_resolver_rules.empty()) { + builder.set_host_mapping_rules(params.host_resolver_rules); + } + + builder.SetCertVerifier( + CertVerifier::CreateDefault(std::move(cert_net_fetcher))); + + builder.set_proxy_delegate(std::make_unique(params)); + + auto context = builder.Build(); + + if (!params.proxy_url.empty() && !params.proxy_user.empty() && + !params.proxy_pass.empty()) { + auto* session = context->http_transaction_factory()->GetSession(); + auto* auth_cache = session->http_auth_cache(); + std::string proxy_url = params.proxy_url; + if (proxy_url.compare(0, 7, "quic://") == 0) { + proxy_url.replace(0, 4, "https"); + auto* quic = context->quic_context()->params(); + const auto& versions = quic::SupportedVersions(); + quic->supported_versions.assign(versions.begin(), versions.end()); + quic->origins_to_force_quic_on.insert( + net::HostPortPair::FromURL(GURL(proxy_url))); + } + GURL auth_origin(proxy_url); + AuthCredentials credentials(params.proxy_user, params.proxy_pass); + auth_cache->Add(auth_origin, HttpAuth::AUTH_PROXY, + /*realm=*/{}, HttpAuth::AUTH_SCHEME_BASIC, {}, + /*challenge=*/"Basic", credentials, /*path=*/"/"); + } + + return context; +} } // namespace } // namespace net @@ -497,14 +541,14 @@ int main(int argc, char* argv[]) { net::NetLogCaptureMode::kDefault); } - auto cert_context = BuildCertURLRequestContext(net_log); + auto cert_context = net::BuildCertURLRequestContext(net_log); scoped_refptr cert_net_fetcher; #if defined(OS_LINUX) || defined(OS_MAC) cert_net_fetcher = base::MakeRefCounted(); cert_net_fetcher->SetURLRequestContext(cert_context.get()); #endif auto context = - BuildURLRequestContext(params, std::move(cert_net_fetcher), net_log); + net::BuildURLRequestContext(params, std::move(cert_net_fetcher), net_log); auto* session = context->http_transaction_factory()->GetSession(); auto listen_socket =