diff --git a/src/net/tools/naive/naive_connection.cc b/src/net/tools/naive/naive_connection.cc index 339990beb0..a2e51c9fc2 100644 --- a/src/net/tools/naive/naive_connection.cc +++ b/src/net/tools/naive/naive_connection.cc @@ -55,6 +55,7 @@ NaiveConnection::NaiveConnection( const SSLConfig& proxy_ssl_config, RedirectResolver* resolver, HttpNetworkSession* session, + const NetworkIsolationKey& network_isolation_key, const NetLogWithSource& net_log, std::unique_ptr accepted_socket, const NetworkTrafficAnnotationTag& traffic_annotation) @@ -66,6 +67,7 @@ NaiveConnection::NaiveConnection( proxy_ssl_config_(proxy_ssl_config), resolver_(resolver), session_(session), + network_isolation_key_(network_isolation_key), net_log_(net_log), next_state_(STATE_NONE), client_socket_(std::move(accepted_socket)), @@ -239,8 +241,9 @@ int NaiveConnection::DoConnectServer() { // Ignores socket limit set by socket pool for this type of socket. return InitSocketHandleForRawConnect2( origin, session_, LOAD_IGNORE_LIMITS, MAXIMUM_PRIORITY, proxy_info_, - server_ssl_config_, proxy_ssl_config_, PRIVACY_MODE_DISABLED, net_log_, - server_socket_handle_.get(), io_callback_); + server_ssl_config_, proxy_ssl_config_, PRIVACY_MODE_DISABLED, + network_isolation_key_, net_log_, server_socket_handle_.get(), + io_callback_); } int NaiveConnection::DoConnectServerComplete(int result) { diff --git a/src/net/tools/naive/naive_connection.h b/src/net/tools/naive/naive_connection.h index ef9c7f8f11..d4a26e8641 100644 --- a/src/net/tools/naive/naive_connection.h +++ b/src/net/tools/naive/naive_connection.h @@ -28,6 +28,7 @@ class StreamSocket; struct NetworkTrafficAnnotationTag; struct SSLConfig; class RedirectResolver; +class NetworkIsolationKey; class NaiveConnection { public: @@ -55,6 +56,7 @@ class NaiveConnection { const SSLConfig& proxy_ssl_config, RedirectResolver* resolver, HttpNetworkSession* session, + const NetworkIsolationKey& network_isolation_key, const NetLogWithSource& net_log, std::unique_ptr accepted_socket, const NetworkTrafficAnnotationTag& traffic_annotation); @@ -107,6 +109,7 @@ class NaiveConnection { const SSLConfig& proxy_ssl_config_; RedirectResolver* resolver_; HttpNetworkSession* session_; + const NetworkIsolationKey& network_isolation_key_; const NetLogWithSource& net_log_; CompletionRepeatingCallback io_callback_; diff --git a/src/net/tools/naive/naive_proxy.cc b/src/net/tools/naive/naive_proxy.cc index aeef545e0d..097d357a15 100644 --- a/src/net/tools/naive/naive_proxy.cc +++ b/src/net/tools/naive/naive_proxy.cc @@ -5,6 +5,7 @@ #include "net/tools/naive/naive_proxy.h" +#include #include #include "base/bind.h" @@ -28,12 +29,14 @@ namespace net { NaiveProxy::NaiveProxy(std::unique_ptr listen_socket, NaiveConnection::Protocol protocol, bool use_padding, + int concurrency, RedirectResolver* resolver, HttpNetworkSession* session, const NetworkTrafficAnnotationTag& traffic_annotation) : listen_socket_(std::move(listen_socket)), protocol_(protocol), use_padding_(use_padding), + concurrency_(std::min(4, std::max(1, concurrency))), resolver_(resolver), session_(session), net_log_( @@ -54,6 +57,10 @@ NaiveProxy::NaiveProxy(std::unique_ptr listen_socket, session_->GetSSLConfig(&server_ssl_config_, &proxy_ssl_config_); proxy_ssl_config_.disable_cert_verification_network_fetches = true; + for (int i = 0; i < concurrency_; i++) { + network_isolation_keys_.push_back(NetworkIsolationKey::CreateTransient()); + } + DCHECK(listen_socket_); // Start accepting connections in next run loop in case when delegate is not // ready to get callbacks. @@ -110,9 +117,11 @@ void NaiveProxy::DoConnect() { if (!use_padding_) { pad_direction = NaiveConnection::kNone; } + last_id_++; + const auto& nik = network_isolation_keys_[last_id_ % concurrency_]; auto connection_ptr = std::make_unique( - ++last_id_, protocol_, pad_direction, proxy_info_, server_ssl_config_, - proxy_ssl_config_, resolver_, session_, net_log_, std::move(socket), + last_id_, protocol_, pad_direction, proxy_info_, server_ssl_config_, + proxy_ssl_config_, resolver_, session_, nik, net_log_, std::move(socket), traffic_annotation_); auto* connection = connection_ptr.get(); connection_by_id_[connection->id()] = std::move(connection_ptr); diff --git a/src/net/tools/naive/naive_proxy.h b/src/net/tools/naive/naive_proxy.h index fbae124249..de7535483f 100644 --- a/src/net/tools/naive/naive_proxy.h +++ b/src/net/tools/naive/naive_proxy.h @@ -8,10 +8,12 @@ #include #include +#include #include "base/macros.h" #include "base/memory/weak_ptr.h" #include "net/base/completion_repeating_callback.h" +#include "net/base/network_isolation_key.h" #include "net/log/net_log_with_source.h" #include "net/proxy_resolution/proxy_info.h" #include "net/ssl/ssl_config.h" @@ -32,6 +34,7 @@ class NaiveProxy { NaiveProxy(std::unique_ptr server_socket, NaiveConnection::Protocol protocol, bool use_padding, + int concurrency, RedirectResolver* resolver, HttpNetworkSession* session, const NetworkTrafficAnnotationTag& traffic_annotation); @@ -57,6 +60,7 @@ class NaiveProxy { std::unique_ptr listen_socket_; NaiveConnection::Protocol protocol_; bool use_padding_; + int concurrency_; ProxyInfo proxy_info_; SSLConfig server_ssl_config_; SSLConfig proxy_ssl_config_; @@ -68,6 +72,8 @@ class NaiveProxy { std::unique_ptr accepted_socket_; + std::vector network_isolation_keys_; + std::map> connection_by_id_; const NetworkTrafficAnnotationTag& traffic_annotation_; diff --git a/src/net/tools/naive/naive_proxy_bin.cc b/src/net/tools/naive/naive_proxy_bin.cc index d4b071df3b..30414deacf 100644 --- a/src/net/tools/naive/naive_proxy_bin.cc +++ b/src/net/tools/naive/naive_proxy_bin.cc @@ -11,6 +11,7 @@ #include "base/at_exit.h" #include "base/command_line.h" +#include "base/feature_list.h" #include "base/files/file_path.h" #include "base/json/json_file_value_serializer.h" #include "base/json/json_writer.h" @@ -80,6 +81,7 @@ struct CommandLine { std::string listen; std::string proxy; bool padding; + std::string concurrency; std::string extra_headers; std::string host_resolver_rules; std::string resolver_range; @@ -94,6 +96,7 @@ struct Params { std::string listen_addr; int listen_port; bool use_padding; + int concurrency; net::HttpRequestHeaders extra_headers; std::string proxy_url; std::string proxy_user; @@ -141,6 +144,7 @@ void GetCommandLine(const base::CommandLine& proc, CommandLine* cmdline) { "--proxy=://[:@][:]\n" " proto: https, quic\n" "--padding Use padding\n" + "--concurrency= Use N connections, less secure\n" "--extra-headers=... Extra headers split by CRLF\n" "--host-resolver-rules=... Resolver rules\n" "--resolver-range=... Redirect resolver range\n" @@ -159,6 +163,7 @@ void GetCommandLine(const base::CommandLine& proc, CommandLine* cmdline) { cmdline->listen = proc.GetSwitchValueASCII("listen"); cmdline->proxy = proc.GetSwitchValueASCII("proxy"); cmdline->padding = proc.HasSwitch("padding"); + cmdline->concurrency = proc.GetSwitchValueASCII("concurrency"); cmdline->extra_headers = proc.GetSwitchValueASCII("extra-headers"); cmdline->host_resolver_rules = proc.GetSwitchValueASCII("host-resolver-rules"); @@ -193,6 +198,10 @@ void GetCommandLineFromConfig(const base::FilePath& config_path, cmdline->proxy = *proxy; } cmdline->padding = value->FindBoolKey("padding").value_or(false); + const auto* concurrency = value->FindStringKey("concurrency"); + if (concurrency) { + cmdline->concurrency = *concurrency; + } const auto* extra_headers = value->FindStringKey("extra_headers"); if (extra_headers) { cmdline->extra_headers = *extra_headers; @@ -292,6 +301,16 @@ bool ParseCommandLine(const CommandLine& cmdline, Params* params) { params->use_padding = cmdline.padding; + if (!cmdline.concurrency.empty()) { + if (!base::StringToInt(cmdline.concurrency, ¶ms->concurrency) || + params->concurrency < 1 || params->concurrency > 4) { + std::cerr << "Invalid concurrency" << std::endl; + return false; + } + } else { + params->concurrency = 1; + } + params->extra_headers.AddHeadersFromString(cmdline.extra_headers); params->host_resolver_rules = cmdline.host_resolver_rules; @@ -475,6 +494,8 @@ std::unique_ptr BuildURLRequestContext( } // namespace net int main(int argc, char* argv[]) { + base::FeatureList::InitializeInstance( + "PartitionConnectionsByNetworkIsolationKey", std::string()); base::SingleThreadTaskExecutor io_task_executor(base::MessagePumpType::IO); base::ThreadPoolInstance::CreateAndStartWithDefaultParams("naive"); base::AtExitManager exit_manager; @@ -589,8 +610,8 @@ int main(int argc, char* argv[]) { } net::NaiveProxy naive_proxy(std::move(listen_socket), params.protocol, - params.use_padding, resolver.get(), session, - kTrafficAnnotation); + params.use_padding, params.concurrency, + resolver.get(), session, kTrafficAnnotation); base::RunLoop().Run();