diff --git a/src/net/BUILD.gn b/src/net/BUILD.gn index 29a3541ac4..a183cd50dc 100644 --- a/src/net/BUILD.gn +++ b/src/net/BUILD.gn @@ -1727,3 +1727,39 @@ static_library("preload_decoder") { ] deps = [ "//base" ] } + +executable("naive") { + sources = [ + "tools/naive/naive_connection.cc", + "tools/naive/naive_connection.h", + "tools/naive/naive_padding_framer.cc", + "tools/naive/naive_padding_framer.h", + "tools/naive/naive_padding_socket.cc", + "tools/naive/naive_padding_socket.h", + "tools/naive/naive_protocol.cc", + "tools/naive/naive_protocol.h", + "tools/naive/naive_proxy.cc", + "tools/naive/naive_proxy.h", + "tools/naive/naive_proxy_bin.cc", + "tools/naive/naive_proxy_delegate.h", + "tools/naive/naive_proxy_delegate.cc", + "tools/naive/http_proxy_server_socket.cc", + "tools/naive/http_proxy_server_socket.h", + "tools/naive/redirect_resolver.h", + "tools/naive/redirect_resolver.cc", + "tools/naive/socks5_server_socket.cc", + "tools/naive/socks5_server_socket.h", + ] + + deps = [ + ":net", + "//base", + "//build/win:default_exe_manifest", + "//components/version_info:version_info", + "//url", + ] + + if (is_apple) { + deps += [ "//base/allocator:early_zone_registration_apple" ] + } +} diff --git a/src/net/log/net_log_event_type_list.h b/src/net/log/net_log_event_type_list.h index 68d2258457..8275b133bb 100644 --- a/src/net/log/net_log_event_type_list.h +++ b/src/net/log/net_log_event_type_list.h @@ -537,6 +537,11 @@ EVENT_TYPE(SOCKS_HOSTNAME_TOO_BIG) EVENT_TYPE(SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING) EVENT_TYPE(SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE) +EVENT_TYPE(SOCKS_NO_REQUESTED_AUTH) +EVENT_TYPE(SOCKS_NO_ACCEPTABLE_AUTH) +EVENT_TYPE(SOCKS_ZERO_LENGTH_DOMAIN) +EVENT_TYPE(SOCKS_UNEXPECTED_COMMAND) + // This event indicates that a bad version number was received in the // proxy server's response. The extra parameters show its value: // { diff --git a/src/net/tools/naive/http_proxy_server_socket.cc b/src/net/tools/naive/http_proxy_server_socket.cc new file mode 100644 index 0000000000..f4afd1ff68 --- /dev/null +++ b/src/net/tools/naive/http_proxy_server_socket.cc @@ -0,0 +1,461 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/http_proxy_server_socket.h" + +#include +#include +#include +#include + +#include "base/functional/bind.h" +#include "base/functional/callback_helpers.h" +#include "base/logging.h" +#include "base/rand_util.h" +#include "base/strings/string_split.h" +#include "base/sys_byteorder.h" +#include "net/base/ip_address.h" +#include "net/base/net_errors.h" +#include "net/base/url_util.h" +#include "net/http/http_request_headers.h" +#include "net/log/net_log.h" +#include "net/third_party/quiche/src/quiche/spdy/core/hpack/hpack_constants.h" +#include "net/tools/naive/naive_protocol.h" +#include "net/tools/naive/naive_proxy_delegate.h" +#include "url/gurl.h" + +namespace net { + +namespace { +constexpr int kBufferSize = 64 * 1024; +constexpr size_t kMaxHeaderSize = 64 * 1024; +constexpr char kResponseHeader[] = "HTTP/1.1 200 OK\r\nPadding: "; +constexpr int kResponseHeaderSize = sizeof(kResponseHeader) - 1; +// A plain 200 is 10 bytes. Expected 48 bytes. "Padding" uses up 7 bytes. +constexpr int kMinPaddingSize = 30; +constexpr int kMaxPaddingSize = kMinPaddingSize + 32; +} // namespace + +HttpProxyServerSocket::HttpProxyServerSocket( + std::unique_ptr transport_socket, + ClientPaddingDetectorDelegate* padding_detector_delegate, + const NetworkTrafficAnnotationTag& traffic_annotation, + const std::vector& supported_padding_types) + : io_callback_(base::BindRepeating(&HttpProxyServerSocket::OnIOComplete, + base::Unretained(this))), + transport_(std::move(transport_socket)), + padding_detector_delegate_(padding_detector_delegate), + next_state_(STATE_NONE), + completed_handshake_(false), + was_ever_used_(false), + header_write_size_(-1), + net_log_(transport_->NetLog()), + traffic_annotation_(traffic_annotation), + supported_padding_types_(supported_padding_types) {} + +HttpProxyServerSocket::~HttpProxyServerSocket() { + Disconnect(); +} + +const HostPortPair& HttpProxyServerSocket::request_endpoint() const { + return request_endpoint_; +} + +int HttpProxyServerSocket::Connect(CompletionOnceCallback callback) { + DCHECK(transport_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + + // If already connected, then just return OK. + if (completed_handshake_) + return OK; + + next_state_ = STATE_HEADER_READ; + buffer_.clear(); + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + user_callback_ = std::move(callback); + } + return rv; +} + +void HttpProxyServerSocket::Disconnect() { + completed_handshake_ = false; + transport_->Disconnect(); + + // Reset other states to make sure they aren't mistakenly used later. + // These are the states initialized by Connect(). + next_state_ = STATE_NONE; + user_callback_.Reset(); +} + +bool HttpProxyServerSocket::IsConnected() const { + return completed_handshake_ && transport_->IsConnected(); +} + +bool HttpProxyServerSocket::IsConnectedAndIdle() const { + return completed_handshake_ && transport_->IsConnectedAndIdle(); +} + +const NetLogWithSource& HttpProxyServerSocket::NetLog() const { + return net_log_; +} + +bool HttpProxyServerSocket::WasEverUsed() const { + return was_ever_used_; +} + +NextProto HttpProxyServerSocket::GetNegotiatedProtocol() const { + if (transport_) { + return transport_->GetNegotiatedProtocol(); + } + NOTREACHED(); + return kProtoUnknown; +} + +bool HttpProxyServerSocket::GetSSLInfo(SSLInfo* ssl_info) { + if (transport_) { + return transport_->GetSSLInfo(ssl_info); + } + NOTREACHED(); + return false; +} + +int64_t HttpProxyServerSocket::GetTotalReceivedBytes() const { + return transport_->GetTotalReceivedBytes(); +} + +void HttpProxyServerSocket::ApplySocketTag(const SocketTag& tag) { + return transport_->ApplySocketTag(tag); +} + +// Read is called by the transport layer above to read. This can only be done +// if the HTTP header is complete. +int HttpProxyServerSocket::Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + DCHECK(callback); + + if (!buffer_.empty()) { + was_ever_used_ = true; + int data_len = buffer_.size(); + if (data_len <= buf_len) { + std::memcpy(buf->data(), buffer_.data(), data_len); + buffer_.clear(); + return data_len; + } else { + std::memcpy(buf->data(), buffer_.data(), buf_len); + buffer_ = buffer_.substr(buf_len); + return buf_len; + } + } + + int rv = transport_->Read( + buf, buf_len, + base::BindOnce(&HttpProxyServerSocket::OnReadWriteComplete, + base::Unretained(this), std::move(callback))); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +// Write is called by the transport layer. This can only be done if the +// HTTP CONNECT request is complete. +int HttpProxyServerSocket::Write( + IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + DCHECK(callback); + + int rv = transport_->Write( + buf, buf_len, + base::BindOnce(&HttpProxyServerSocket::OnReadWriteComplete, + base::Unretained(this), std::move(callback)), + traffic_annotation); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +int HttpProxyServerSocket::SetReceiveBufferSize(int32_t size) { + return transport_->SetReceiveBufferSize(size); +} + +int HttpProxyServerSocket::SetSendBufferSize(int32_t size) { + return transport_->SetSendBufferSize(size); +} + +void HttpProxyServerSocket::DoCallback(int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(user_callback_); + + // Since Run() may result in Read being called, + // clear user_callback_ up front. + std::move(user_callback_).Run(result); +} + +void HttpProxyServerSocket::OnIOComplete(int result) { + DCHECK_NE(STATE_NONE, next_state_); + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) { + DoCallback(rv); + } +} + +void HttpProxyServerSocket::OnReadWriteComplete(CompletionOnceCallback callback, + int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(callback); + + if (result > 0) + was_ever_used_ = true; + std::move(callback).Run(result); +} + +int HttpProxyServerSocket::DoLoop(int last_io_result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_HEADER_READ: + DCHECK_EQ(OK, rv); + rv = DoHeaderRead(); + break; + case STATE_HEADER_READ_COMPLETE: + rv = DoHeaderReadComplete(rv); + break; + case STATE_HEADER_WRITE: + DCHECK_EQ(OK, rv); + rv = DoHeaderWrite(); + break; + case STATE_HEADER_WRITE_COMPLETE: + rv = DoHeaderWriteComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +int HttpProxyServerSocket::DoHeaderRead() { + next_state_ = STATE_HEADER_READ_COMPLETE; + + handshake_buf_ = base::MakeRefCounted(kBufferSize); + return transport_->Read(handshake_buf_.get(), kBufferSize, io_callback_); +} + +std::optional HttpProxyServerSocket::ParsePaddingHeaders( + const HttpRequestHeaders& headers) { + bool has_padding = headers.HasHeader(kPaddingHeader); + std::string padding_type_request; + bool has_padding_type_request = + headers.GetHeader(kPaddingTypeRequestHeader, &padding_type_request); + + if (!has_padding_type_request) { + // Backward compatibility with before kVariant1 when the padding-version + // header does not exist. + if (has_padding) { + return PaddingType::kVariant1; + } else { + return PaddingType::kNone; + } + } + + std::vector padding_type_strs = base::SplitStringPiece( + padding_type_request, ",", base::TRIM_WHITESPACE, base::SPLIT_WANT_ALL); + for (std::string_view padding_type_str : padding_type_strs) { + std::optional padding_type = + ParsePaddingType(padding_type_str); + if (!padding_type.has_value()) { + LOG(ERROR) << "Invalid padding type: " << padding_type_str; + return std::nullopt; + } + if (std::find(supported_padding_types_.begin(), + supported_padding_types_.end(), + *padding_type) != supported_padding_types_.end()) { + return padding_type; + } + } + LOG(ERROR) << "No padding type is supported: " << padding_type_request; + return std::nullopt; +} + +int HttpProxyServerSocket::DoHeaderReadComplete(int result) { + if (result < 0) + return result; + + if (result == 0) { + return ERR_CONNECTION_CLOSED; + } + + buffer_.append(handshake_buf_->data(), result); + if (buffer_.size() > kMaxHeaderSize) { + return ERR_MSG_TOO_BIG; + } + + size_t header_end = buffer_.find("\r\n\r\n"); + if (header_end == std::string::npos) { + next_state_ = STATE_HEADER_READ; + return OK; + } + + size_t first_line_end = buffer_.find("\r\n"); + size_t first_space = buffer_.find(' '); + bool is_http_1_0 = false; + if (first_space == std::string::npos || first_space + 1 >= first_line_end) { + LOG(WARNING) << "Invalid request: " << buffer_.substr(0, first_line_end); + return ERR_INVALID_ARGUMENT; + } + size_t second_space = buffer_.find(' ', first_space + 1); + if (second_space == std::string::npos || second_space >= first_line_end) { + LOG(WARNING) << "Invalid request: " << buffer_.substr(0, first_line_end); + return ERR_INVALID_ARGUMENT; + } + + std::string method = buffer_.substr(0, first_space); + std::string uri = + buffer_.substr(first_space + 1, second_space - (first_space + 1)); + std::string version = + buffer_.substr(second_space + 1, first_line_end - (second_space + 1)); + if (method == HttpRequestHeaders::kConnectMethod) { + request_endpoint_ = HostPortPair::FromString(uri); + } else { + // postprobe endpoint handling + is_http_1_0 = true; + } + + size_t second_line = first_line_end + 2; + HttpRequestHeaders headers; + std::string headers_str; + if (second_line < header_end) { + headers_str = buffer_.substr(second_line, header_end - second_line); + headers.AddHeadersFromString(headers_str); + } + + if (is_http_1_0) { + GURL url(uri); + if (!url.is_valid()) { + LOG(WARNING) << "Invalid URI: " << uri; + return ERR_INVALID_ARGUMENT; + } + + std::string host; + int port; + + std::string host_str; + if (headers.GetHeader(HttpRequestHeaders::kHost, &host_str)) { + if (!ParseHostAndPort(host_str, &host, &port)) { + LOG(WARNING) << "Invalid Host: " << host_str; + return ERR_INVALID_ARGUMENT; + } + if (port == -1) { + port = 80; + } + } else { + if (!url.has_host()) { + LOG(WARNING) << "Missing host: " << uri; + return ERR_INVALID_ARGUMENT; + } + host = url.host(); + port = url.EffectiveIntPort(); + + host_str = url.host(); + if (url.has_port()) { + host_str.append(":").append(url.port()); + } + headers.SetHeader(HttpRequestHeaders::kHost, host_str); + } + // Host is already known. Converts any absolute URI to relative. + uri = url.path(); + if (url.has_query()) { + uri.append("?").append(url.query()); + } + + request_endpoint_.set_host(host); + request_endpoint_.set_port(port); + } + + std::optional padding_type = ParsePaddingHeaders(headers); + if (!padding_type.has_value()) { + return ERR_INVALID_ARGUMENT; + } + padding_detector_delegate_->SetClientPaddingType(*padding_type); + + if (is_http_1_0) { + // Regenerates http header to make sure don't leak them to end servers + HttpRequestHeaders sanitized_headers = headers; + sanitized_headers.RemoveHeader(HttpRequestHeaders::kProxyConnection); + sanitized_headers.RemoveHeader(HttpRequestHeaders::kProxyAuthorization); + std::ostringstream ss; + ss << method << " " << uri << " " << version << "\r\n" + << sanitized_headers.ToString(); + if (buffer_.size() > header_end + 4) { + ss << buffer_.substr(header_end + 4); + } + buffer_ = ss.str(); + // Skips padding write for raw http proxy + completed_handshake_ = true; + next_state_ = STATE_NONE; + return OK; + } + + buffer_ = buffer_.substr(header_end + 4); + + next_state_ = STATE_HEADER_WRITE; + return OK; +} + +int HttpProxyServerSocket::DoHeaderWrite() { + next_state_ = STATE_HEADER_WRITE_COMPLETE; + + // Adds padding. + int padding_size = base::RandInt(kMinPaddingSize, kMaxPaddingSize); + header_write_size_ = kResponseHeaderSize + padding_size + 4; + handshake_buf_ = base::MakeRefCounted(header_write_size_); + char* p = handshake_buf_->data(); + std::memcpy(p, kResponseHeader, kResponseHeaderSize); + FillNonindexHeaderValue(base::RandUint64(), p + kResponseHeaderSize, + padding_size); + std::memcpy(p + kResponseHeaderSize + padding_size, "\r\n\r\n", 4); + + return transport_->Write(handshake_buf_.get(), header_write_size_, + io_callback_, traffic_annotation_); +} + +int HttpProxyServerSocket::DoHeaderWriteComplete(int result) { + if (result < 0) + return result; + + if (result != header_write_size_) { + return ERR_FAILED; + } + + completed_handshake_ = true; + next_state_ = STATE_NONE; + return OK; +} + +int HttpProxyServerSocket::GetPeerAddress(IPEndPoint* address) const { + return transport_->GetPeerAddress(address); +} + +int HttpProxyServerSocket::GetLocalAddress(IPEndPoint* address) const { + return transport_->GetLocalAddress(address); +} + +} // namespace net diff --git a/src/net/tools/naive/http_proxy_server_socket.h b/src/net/tools/naive/http_proxy_server_socket.h new file mode 100644 index 0000000000..22d6ed43ac --- /dev/null +++ b/src/net/tools/naive/http_proxy_server_socket.h @@ -0,0 +1,131 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_HTTP_PROXY_SERVER_SOCKET_H_ +#define NET_TOOLS_NAIVE_HTTP_PROXY_SERVER_SOCKET_H_ + +#include +#include +#include +#include +#include + +#include "base/memory/scoped_refptr.h" +#include "net/base/completion_once_callback.h" +#include "net/base/completion_repeating_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/http/http_request_headers.h" +#include "net/log/net_log_with_source.h" +#include "net/socket/connection_attempts.h" +#include "net/socket/next_proto.h" +#include "net/socket/stream_socket.h" +#include "net/ssl/ssl_info.h" +#include "net/tools/naive/naive_protocol.h" + +namespace net { +struct NetworkTrafficAnnotationTag; +class ClientPaddingDetectorDelegate; + +// This StreamSocket is used to setup a HTTP CONNECT tunnel. +class HttpProxyServerSocket : public StreamSocket { + public: + HttpProxyServerSocket( + std::unique_ptr transport_socket, + ClientPaddingDetectorDelegate* padding_detector_delegate, + const NetworkTrafficAnnotationTag& traffic_annotation, + const std::vector& supported_padding_types); + HttpProxyServerSocket(const HttpProxyServerSocket&) = delete; + HttpProxyServerSocket& operator=(const HttpProxyServerSocket&) = delete; + + // On destruction Disconnect() is called. + ~HttpProxyServerSocket() override; + + const HostPortPair& request_endpoint() const; + + // StreamSocket implementation. + + int Connect(CompletionOnceCallback callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + const NetLogWithSource& NetLog() const override; + bool WasEverUsed() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; + int64_t GetTotalReceivedBytes() const override; + void ApplySocketTag(const SocketTag& tag) override; + + // Socket implementation. + int Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) override; + int Write(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) override; + + int SetReceiveBufferSize(int32_t size) override; + int SetSendBufferSize(int32_t size) override; + + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + + private: + enum State { + STATE_HEADER_READ, + STATE_HEADER_READ_COMPLETE, + STATE_HEADER_WRITE, + STATE_HEADER_WRITE_COMPLETE, + STATE_NONE, + }; + + void DoCallback(int result); + void OnIOComplete(int result); + void OnReadWriteComplete(CompletionOnceCallback callback, int result); + + int DoLoop(int last_io_result); + int DoHeaderWrite(); + int DoHeaderWriteComplete(int result); + int DoHeaderRead(); + int DoHeaderReadComplete(int result); + + std::optional ParsePaddingHeaders( + const HttpRequestHeaders& headers); + + CompletionRepeatingCallback io_callback_; + + // Stores the underlying socket. + std::unique_ptr transport_; + ClientPaddingDetectorDelegate* padding_detector_delegate_; + + State next_state_; + + // Stores the callback to the layer above, called on completing Connect(). + CompletionOnceCallback user_callback_; + + // This IOBuffer is used by the class to read and write + // SOCKS handshake data. The length contains the expected size to + // read or write. + scoped_refptr handshake_buf_; + + std::string buffer_; + bool completed_handshake_; + bool was_ever_used_; + int header_write_size_; + + HostPortPair request_endpoint_; + + NetLogWithSource net_log_; + + // Traffic annotation for socket control. + const NetworkTrafficAnnotationTag& traffic_annotation_; + + std::vector supported_padding_types_; +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_HTTP_PROXY_SERVER_SOCKET_H_ diff --git a/src/net/tools/naive/naive_connection.cc b/src/net/tools/naive/naive_connection.cc new file mode 100644 index 0000000000..f5e4b0fe03 --- /dev/null +++ b/src/net/tools/naive/naive_connection.cc @@ -0,0 +1,474 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/naive_connection.h" + +#include +#include + +#include "base/functional/bind.h" +#include "base/functional/callback_helpers.h" +#include "base/logging.h" +#include "base/rand_util.h" +#include "base/strings/strcat.h" +#include "base/task/single_thread_task_runner.h" +#include "base/time/time.h" +#include "build/build_config.h" +#include "net/base/io_buffer.h" +#include "net/base/load_flags.h" +#include "net/base/net_errors.h" +#include "net/base/privacy_mode.h" +#include "net/base/url_util.h" +#include "net/proxy_resolution/proxy_info.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_manager.h" +#include "net/socket/stream_socket.h" +#include "net/spdy/spdy_session.h" +#include "net/tools/naive/http_proxy_server_socket.h" +#include "net/tools/naive/naive_padding_socket.h" +#include "net/tools/naive/redirect_resolver.h" +#include "net/tools/naive/socks5_server_socket.h" +#include "url/scheme_host_port.h" + +#if BUILDFLAG(IS_LINUX) +#include +#include +#include + +#include "net/base/ip_endpoint.h" +#include "net/base/sockaddr_storage.h" +#include "net/socket/tcp_client_socket.h" +#endif + +namespace net { + +namespace { +constexpr int kBufferSize = 64 * 1024; +} // namespace + +NaiveConnection::NaiveConnection( + unsigned int id, + ClientProtocol protocol, + std::unique_ptr padding_detector_delegate, + const ProxyInfo& proxy_info, + RedirectResolver* resolver, + HttpNetworkSession* session, + const NetworkAnonymizationKey& network_anonymization_key, + const NetLogWithSource& net_log, + std::unique_ptr accepted_socket, + const NetworkTrafficAnnotationTag& traffic_annotation) + : id_(id), + protocol_(protocol), + padding_detector_delegate_(std::move(padding_detector_delegate)), + proxy_info_(proxy_info), + resolver_(resolver), + session_(session), + network_anonymization_key_(network_anonymization_key), + net_log_(net_log), + next_state_(STATE_NONE), + client_socket_(std::move(accepted_socket)), + server_socket_handle_(std::make_unique()), + sockets_{nullptr, nullptr}, + errors_{OK, OK}, + write_pending_{false, false}, + early_pull_pending_(false), + can_push_to_server_(false), + early_pull_result_(ERR_IO_PENDING), + full_duplex_(false), + time_func_(&base::TimeTicks::Now), + traffic_annotation_(traffic_annotation) { + io_callback_ = base::BindRepeating(&NaiveConnection::OnIOComplete, + weak_ptr_factory_.GetWeakPtr()); +} + +NaiveConnection::~NaiveConnection() { + Disconnect(); +} + +int NaiveConnection::Connect(CompletionOnceCallback callback) { + DCHECK(client_socket_); + DCHECK_EQ(next_state_, STATE_NONE); + DCHECK(!connect_callback_); + + if (full_duplex_) + return OK; + + next_state_ = STATE_CONNECT_CLIENT; + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + connect_callback_ = std::move(callback); + } + return rv; +} + +void NaiveConnection::Disconnect() { + full_duplex_ = false; + // Closes server side first because latency is higher. + if (server_socket_handle_->socket()) + server_socket_handle_->socket()->Disconnect(); + client_socket_->Disconnect(); + + next_state_ = STATE_NONE; + connect_callback_.Reset(); + run_callback_.Reset(); +} + +void NaiveConnection::DoCallback(int result) { + DCHECK_NE(result, ERR_IO_PENDING); + DCHECK(connect_callback_); + + // Since Run() may result in Read being called, + // clear connect_callback_ up front. + std::move(connect_callback_).Run(result); +} + +void NaiveConnection::OnIOComplete(int result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) { + DoCallback(rv); + } +} + +int NaiveConnection::DoLoop(int last_io_result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_CONNECT_CLIENT: + DCHECK_EQ(rv, OK); + rv = DoConnectClient(); + break; + case STATE_CONNECT_CLIENT_COMPLETE: + rv = DoConnectClientComplete(rv); + break; + case STATE_CONNECT_SERVER: + DCHECK_EQ(rv, OK); + rv = DoConnectServer(); + break; + case STATE_CONNECT_SERVER_COMPLETE: + rv = DoConnectServerComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +int NaiveConnection::DoConnectClient() { + next_state_ = STATE_CONNECT_CLIENT_COMPLETE; + + return client_socket_->Connect(io_callback_); +} + +int NaiveConnection::DoConnectClientComplete(int result) { + if (result < 0) + return result; + + std::optional client_padding_type = + padding_detector_delegate_->GetClientPaddingType(); + CHECK(client_padding_type.has_value()); + + sockets_[kClient] = std::make_unique( + client_socket_.get(), *client_padding_type, kClient); + + // For proxy client sockets, padding support detection is finished after the + // first server response which means there will be one missed early pull. For + // proxy server sockets (HttpProxyServerSocket), padding support detection is + // done during client connect, so there shouldn't be any missed early pull. + if (!padding_detector_delegate_->GetServerPaddingType().has_value()) { + early_pull_pending_ = false; + early_pull_result_ = 0; + next_state_ = STATE_CONNECT_SERVER; + return OK; + } + + early_pull_pending_ = true; + Pull(kClient, kServer); + if (early_pull_result_ != ERR_IO_PENDING) { + // Pull has completed synchronously. + if (early_pull_result_ <= 0) { + return early_pull_result_ ? early_pull_result_ : ERR_CONNECTION_CLOSED; + } + } + + next_state_ = STATE_CONNECT_SERVER; + return OK; +} + +int NaiveConnection::DoConnectServer() { + next_state_ = STATE_CONNECT_SERVER_COMPLETE; + + HostPortPair origin; + if (protocol_ == ClientProtocol::kSocks5) { + const auto* socket = + static_cast(client_socket_.get()); + origin = socket->request_endpoint(); + } else if (protocol_ == ClientProtocol::kHttp) { + const auto* socket = + static_cast(client_socket_.get()); + origin = socket->request_endpoint(); + } else if (protocol_ == ClientProtocol::kRedir) { +#if BUILDFLAG(IS_LINUX) + const auto* socket = + static_cast(client_socket_.get()); + IPEndPoint peer_endpoint; + int rv; + rv = socket->GetPeerAddress(&peer_endpoint); + if (rv != OK) { + LOG(ERROR) << "Connection " << id_ + << " cannot get peer address: " << ErrorToShortString(rv); + return rv; + } + int sd = socket->SocketDescriptorForTesting(); + SockaddrStorage dst; + if (peer_endpoint.GetFamily() == ADDRESS_FAMILY_IPV4 || + peer_endpoint.address().IsIPv4MappedIPv6()) { + rv = getsockopt(sd, SOL_IP, SO_ORIGINAL_DST, dst.addr, &dst.addr_len); + } else { + rv = getsockopt(sd, SOL_IPV6, SO_ORIGINAL_DST, dst.addr, &dst.addr_len); + } + if (rv == 0) { + IPEndPoint ipe; + if (ipe.FromSockAddr(dst.addr, dst.addr_len)) { + const auto& addr = ipe.address(); + auto name = resolver_->FindNameByAddress(addr); + if (!name.empty()) { + origin = HostPortPair(name, ipe.port()); + } else if (!resolver_->IsInResolvedRange(addr)) { + origin = HostPortPair::FromIPEndPoint(ipe); + } else { + LOG(ERROR) << "Connection " << id_ << " to unresolved name for " + << addr.ToString(); + return ERR_ADDRESS_INVALID; + } + } + } else { + LOG(ERROR) << "Failed to get original destination address"; + return ERR_ADDRESS_INVALID; + } +#else + static_cast(resolver_); +#endif + } + + url::CanonHostInfo host_info; + url::SchemeHostPort endpoint( + "http", CanonicalizeHost(origin.HostForURL(), &host_info), origin.port(), + url::SchemeHostPort::ALREADY_CANONICALIZED); + if (!endpoint.IsValid()) { + LOG(ERROR) << "Connection " << id_ << " to invalid origin " + << origin.ToString(); + return ERR_ADDRESS_INVALID; + } + + LOG(INFO) << "Connection " << id_ << " to " << origin.ToString(); + + // Ignores socket limit set by socket pool for this type of socket. + return InitSocketHandleForHttpRequest( + std::move(endpoint), LOAD_IGNORE_LIMITS, MAXIMUM_PRIORITY, session_, + proxy_info_, {}, PRIVACY_MODE_DISABLED, + network_anonymization_key_, SecureDnsPolicy::kDisable, SocketTag(), + net_log_, server_socket_handle_.get(), io_callback_, + ClientSocketPool::ProxyAuthCallback()); +} + +int NaiveConnection::DoConnectServerComplete(int result) { + if (result < 0) + return result; + + std::optional server_padding_type = + padding_detector_delegate_->GetServerPaddingType(); + CHECK(server_padding_type.has_value()); + + sockets_[kServer] = std::make_unique( + server_socket_handle_->socket(), *server_padding_type, kServer); + + full_duplex_ = true; + next_state_ = STATE_NONE; + return OK; +} + +int NaiveConnection::Run(CompletionOnceCallback callback) { + DCHECK(sockets_[kServer]); + DCHECK_EQ(next_state_, STATE_NONE); + DCHECK(!connect_callback_); + + // The client-side socket may be closed before the server-side + // socket is connected. + if (errors_[kClient] != OK || sockets_[kClient] == nullptr) + return errors_[kClient]; + if (errors_[kServer] != OK) + return errors_[kServer]; + + run_callback_ = std::move(callback); + + bytes_passed_without_yielding_[kClient] = 0; + bytes_passed_without_yielding_[kServer] = 0; + + yield_after_time_[kClient] = + time_func_() + base::Milliseconds(kYieldAfterDurationMilliseconds); + yield_after_time_[kServer] = yield_after_time_[kClient]; + + can_push_to_server_ = true; + // early_pull_result_ == 0 means the early pull was not started because + // padding support was not yet known. + if (!early_pull_pending_ && early_pull_result_ == 0) { + Pull(kClient, kServer); + } else if (!early_pull_pending_) { + DCHECK_GT(early_pull_result_, 0); + Push(kClient, kServer, early_pull_result_); + } + Pull(kServer, kClient); + + return ERR_IO_PENDING; +} + +void NaiveConnection::Pull(Direction from, Direction to) { + if (errors_[kClient] < 0 || errors_[kServer] < 0) + return; + + int read_size = kBufferSize; + read_buffers_[from] = base::MakeRefCounted(kBufferSize); + + DCHECK(sockets_[from]); + int rv = sockets_[from]->Read( + read_buffers_[from].get(), read_size, + base::BindRepeating(&NaiveConnection::OnPullComplete, + weak_ptr_factory_.GetWeakPtr(), from, to)); + + if (from == kClient && early_pull_pending_) + early_pull_result_ = rv; + + if (rv != ERR_IO_PENDING) + OnPullComplete(from, to, rv); +} + +void NaiveConnection::Push(Direction from, Direction to, int size) { + write_buffers_[to] = base::MakeRefCounted( + std::move(read_buffers_[from]), size); + write_pending_[to] = true; + DCHECK(sockets_[to]); + int rv = sockets_[to]->Write( + write_buffers_[to].get(), write_buffers_[to]->BytesRemaining(), + base::BindRepeating(&NaiveConnection::OnPushComplete, + weak_ptr_factory_.GetWeakPtr(), from, to), + traffic_annotation_); + + if (rv != ERR_IO_PENDING) + OnPushComplete(from, to, rv); +} + +void NaiveConnection::Disconnect(Direction side) { + if (sockets_[side]) { + sockets_[side]->Disconnect(); + sockets_[side] = nullptr; + write_pending_[side] = false; + } +} + +bool NaiveConnection::IsConnected(Direction side) { + return sockets_[side] != nullptr; +} + +void NaiveConnection::OnBothDisconnected() { + if (run_callback_) { + int error = OK; + if (errors_[kClient] != ERR_CONNECTION_CLOSED && errors_[kClient] < 0) + error = errors_[kClient]; + if (errors_[kServer] != ERR_CONNECTION_CLOSED && errors_[kClient] < 0) + error = errors_[kServer]; + std::move(run_callback_).Run(error); + } +} + +void NaiveConnection::OnPullError(Direction from, Direction to, int error) { + DCHECK_LT(error, 0); + + errors_[from] = error; + Disconnect(from); + + if (!write_pending_[to]) + Disconnect(to); + + if (!IsConnected(from) && !IsConnected(to)) + OnBothDisconnected(); +} + +void NaiveConnection::OnPushError(Direction from, Direction to, int error) { + DCHECK_LE(error, 0); + DCHECK(!write_pending_[to]); + + if (error < 0) { + errors_[to] = error; + Disconnect(kServer); + Disconnect(kClient); + } else if (!IsConnected(from)) { + Disconnect(to); + } + + if (!IsConnected(from) && !IsConnected(to)) + OnBothDisconnected(); +} + +void NaiveConnection::OnPullComplete(Direction from, Direction to, int result) { + if (from == kClient && early_pull_pending_) { + early_pull_pending_ = false; + early_pull_result_ = result ? result : ERR_CONNECTION_CLOSED; + } + + if (result <= 0) { + OnPullError(from, to, result ? result : ERR_CONNECTION_CLOSED); + return; + } + + if (from == kClient && !can_push_to_server_) + return; + + Push(from, to, result); +} + +void NaiveConnection::OnPushComplete(Direction from, Direction to, int result) { + if (result >= 0 && write_buffers_[to] != nullptr) { + bytes_passed_without_yielding_[from] += result; + write_buffers_[to]->DidConsume(result); + int size = write_buffers_[to]->BytesRemaining(); + if (size > 0) { + int rv = sockets_[to]->Write( + write_buffers_[to].get(), size, + base::BindRepeating(&NaiveConnection::OnPushComplete, + weak_ptr_factory_.GetWeakPtr(), from, to), + traffic_annotation_); + if (rv != ERR_IO_PENDING) + OnPushComplete(from, to, rv); + return; + } + } + + write_pending_[to] = false; + // Checks for termination even if result is OK. + OnPushError(from, to, result >= 0 ? OK : result); + + if (bytes_passed_without_yielding_[from] > kYieldAfterBytesRead || + time_func_() > yield_after_time_[from]) { + bytes_passed_without_yielding_[from] = 0; + yield_after_time_[from] = + time_func_() + base::Milliseconds(kYieldAfterDurationMilliseconds); + base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( + FROM_HERE, + base::BindRepeating(&NaiveConnection::Pull, + weak_ptr_factory_.GetWeakPtr(), from, to)); + } else { + Pull(from, to); + } +} + +} // namespace net diff --git a/src/net/tools/naive/naive_connection.h b/src/net/tools/naive/naive_connection.h new file mode 100644 index 0000000000..83b14fff28 --- /dev/null +++ b/src/net/tools/naive/naive_connection.h @@ -0,0 +1,134 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_NAIVE_CONNECTION_H_ +#define NET_TOOLS_NAIVE_NAIVE_CONNECTION_H_ + +#include +#include + +#include "base/memory/scoped_refptr.h" +#include "base/memory/weak_ptr.h" +#include "base/time/time.h" +#include "net/base/completion_once_callback.h" +#include "net/base/completion_repeating_callback.h" +#include "net/tools/naive/naive_padding_socket.h" +#include "net/tools/naive/naive_protocol.h" +#include "net/tools/naive/naive_proxy_delegate.h" + +namespace net { + +class ClientSocketHandle; +class DrainableIOBuffer; +class HttpNetworkSession; +class IOBuffer; +class NetLogWithSource; +class ProxyInfo; +class StreamSocket; +struct NetworkTrafficAnnotationTag; +struct SSLConfig; +class RedirectResolver; +class NetworkAnonymizationKey; + +class NaiveConnection { + public: + using TimeFunc = base::TimeTicks (*)(); + + NaiveConnection( + unsigned int id, + ClientProtocol protocol, + std::unique_ptr padding_detector_delegate, + const ProxyInfo& proxy_info, + RedirectResolver* resolver, + HttpNetworkSession* session, + const NetworkAnonymizationKey& network_anonymization_key, + const NetLogWithSource& net_log, + std::unique_ptr accepted_socket, + const NetworkTrafficAnnotationTag& traffic_annotation); + ~NaiveConnection(); + NaiveConnection(const NaiveConnection&) = delete; + NaiveConnection& operator=(const NaiveConnection&) = delete; + + unsigned int id() const { return id_; } + int Connect(CompletionOnceCallback callback); + void Disconnect(); + int Run(CompletionOnceCallback callback); + + private: + enum State { + STATE_CONNECT_CLIENT, + STATE_CONNECT_CLIENT_COMPLETE, + STATE_CONNECT_SERVER, + STATE_CONNECT_SERVER_COMPLETE, + STATE_NONE, + }; + + enum PaddingState { + STATE_READ_PAYLOAD_LENGTH_1, + STATE_READ_PAYLOAD_LENGTH_2, + STATE_READ_PADDING_LENGTH, + STATE_READ_PAYLOAD, + STATE_READ_PADDING, + }; + + void DoCallback(int result); + void OnIOComplete(int result); + int DoLoop(int last_io_result); + int DoConnectClient(); + int DoConnectClientComplete(int result); + int DoConnectServer(); + int DoConnectServerComplete(int result); + void Pull(Direction from, Direction to); + void Push(Direction from, Direction to, int size); + void Disconnect(Direction side); + bool IsConnected(Direction side); + void OnBothDisconnected(); + void OnPullError(Direction from, Direction to, int error); + void OnPushError(Direction from, Direction to, int error); + void OnPullComplete(Direction from, Direction to, int result); + void OnPushComplete(Direction from, Direction to, int result); + + unsigned int id_; + ClientProtocol protocol_; + std::unique_ptr padding_detector_delegate_; + const ProxyInfo& proxy_info_; + RedirectResolver* resolver_; + HttpNetworkSession* session_; + const NetworkAnonymizationKey& network_anonymization_key_; + const NetLogWithSource& net_log_; + + CompletionRepeatingCallback io_callback_; + CompletionOnceCallback connect_callback_; + CompletionOnceCallback run_callback_; + + State next_state_; + + std::unique_ptr client_socket_; + std::unique_ptr server_socket_handle_; + + std::unique_ptr sockets_[kNumDirections]; + scoped_refptr read_buffers_[kNumDirections]; + scoped_refptr write_buffers_[kNumDirections]; + int errors_[kNumDirections]; + bool write_pending_[kNumDirections]; + int bytes_passed_without_yielding_[kNumDirections]; + base::TimeTicks yield_after_time_[kNumDirections]; + + bool early_pull_pending_; + bool can_push_to_server_; + int early_pull_result_; + + bool full_duplex_; + + TimeFunc time_func_; + + // Traffic annotation for socket control. + const NetworkTrafficAnnotationTag& traffic_annotation_; + + base::WeakPtrFactory weak_ptr_factory_{this}; +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_NAIVE_CONNECTION_H_ diff --git a/src/net/tools/naive/naive_padding_framer.cc b/src/net/tools/naive/naive_padding_framer.cc new file mode 100644 index 0000000000..1853bada56 --- /dev/null +++ b/src/net/tools/naive/naive_padding_framer.cc @@ -0,0 +1,118 @@ +// Copyright 2023 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#include "net/tools/naive/naive_padding_framer.h" + +#include +#include +#include +#include +#include + +#include "base/check.h" +#include "base/check_op.h" + +namespace net { +NaivePaddingFramer::NaivePaddingFramer(std::optional max_read_frames) + : max_read_frames_(max_read_frames) { + if (max_read_frames.has_value()) { + CHECK_GE(*max_read_frames, 0); + } +} + +int NaivePaddingFramer::Read(const char* padded, + int padded_len, + char* payload_buf, + int payload_buf_capacity) { + // This check guarantees write_ptr does not overflow. + CHECK_GE(payload_buf_capacity, padded_len); + + char* write_ptr = payload_buf; + while (padded_len > 0) { + int copy_size; + switch (state_) { + case ReadState::kPayloadLength1: + if (max_read_frames_.has_value() && + num_read_frames_ >= *max_read_frames_) { + std::memcpy(write_ptr, padded, padded_len); + padded += padded_len; + write_ptr += padded_len; + padded_len = 0; + break; + } + read_payload_length_ = static_cast(padded[0]); + ++padded; + --padded_len; + state_ = ReadState::kPayloadLength2; + break; + case ReadState::kPayloadLength2: + read_payload_length_ = + read_payload_length_ * 256 + static_cast(padded[0]); + ++padded; + --padded_len; + state_ = ReadState::kPaddingLength1; + break; + case ReadState::kPaddingLength1: + read_padding_length_ = static_cast(padded[0]); + ++padded; + --padded_len; + state_ = ReadState::kPayload; + break; + case ReadState::kPayload: + copy_size = std::min(read_payload_length_, padded_len); + read_payload_length_ -= copy_size; + if (read_payload_length_ == 0) { + state_ = ReadState::kPadding; + } + + std::memcpy(write_ptr, padded, copy_size); + padded += copy_size; + write_ptr += copy_size; + padded_len -= copy_size; + break; + case ReadState::kPadding: + copy_size = std::min(read_padding_length_, padded_len); + read_padding_length_ -= copy_size; + if (read_padding_length_ == 0) { + if (num_read_frames_ < std::numeric_limits::max() - 1) { + ++num_read_frames_; + } + state_ = ReadState::kPayloadLength1; + } + + padded += copy_size; + padded_len -= copy_size; + break; + } + } + return write_ptr - payload_buf; +} + +int NaivePaddingFramer::Write(const char* payload_buf, + int payload_buf_len, + int padding_size, + char* padded, + int padded_capacity, + int& payload_consumed_len) { + CHECK_GE(payload_buf_len, 0); + CHECK_LE(padding_size, max_padding_size()); + CHECK_GE(padding_size, 0); + + payload_consumed_len = std::min( + payload_buf_len, padded_capacity - frame_header_size() - padding_size); + int padded_buf_len = + frame_header_size() + payload_consumed_len + padding_size; + + padded[0] = payload_consumed_len / 256; + padded[1] = payload_consumed_len % 256; + padded[2] = padding_size; + std::memcpy(padded + frame_header_size(), payload_buf, payload_consumed_len); + std::memset(padded + frame_header_size() + payload_consumed_len, '\0', + padding_size); + + if (num_written_frames_ < std::numeric_limits::max() - 1) { + ++num_written_frames_; + } + return padded_buf_len; +} +} // namespace net diff --git a/src/net/tools/naive/naive_padding_framer.h b/src/net/tools/naive/naive_padding_framer.h new file mode 100644 index 0000000000..bdb5865276 --- /dev/null +++ b/src/net/tools/naive/naive_padding_framer.h @@ -0,0 +1,76 @@ +// Copyright 2023 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_NAIVE_PADDING_FRAMER_H_ +#define NET_TOOLS_NAIVE_NAIVE_PADDING_FRAMER_H_ + +#include +#include +#include + +namespace net { + +// struct PaddedFrame { +// uint16_t payload_size; // big-endian +// uint8_t padding_size; // big-endian +// uint8_t payload[payload_size]; +// uint8_t zeros[padding_size]; +// }; +class NaivePaddingFramer { + public: + // `max_read_frames`: Assumes the byte stream stops using the padding + // framing after `max_read_frames` frames. If -1, it means + // the byte stream always uses the padding framing. + explicit NaivePaddingFramer(std::optional max_read_frames); + + int max_payload_size() const { return std::numeric_limits::max(); } + + int max_padding_size() const { return std::numeric_limits::max(); } + + int frame_header_size() const { return 3; } + + int num_read_frames() const { return num_read_frames_; } + + int num_written_frames() const { return num_written_frames_; } + + // Reads `padded` for `padded_len` bytes and extracts unpadded payload to + // `payload_buf`. + // Returns the number of payload bytes extracted. + // Returning zero indicates a pure padding instead of EOF. + int Read(const char* padded, + int padded_len, + char* payload_buf, + int payload_buf_capacity); + + // Writes `payload_buf` for up to `payload_buf_len` bytes into `padded`. + // Returns the number of padded bytes written. + // If the padded bytes would exceed `padded_capacity`, the payload is + // truncated to `payload_consumed_len`. + int Write(const char* payload_buf, + int payload_buf_len, + int padding_size, + char* padded, + int padded_capacity, + int& payload_consumed_len); + + private: + enum class ReadState { + kPayloadLength1, + kPayloadLength2, + kPaddingLength1, + kPayload, + kPadding, + }; + + std::optional max_read_frames_; + + ReadState state_ = ReadState::kPayloadLength1; + int read_payload_length_ = 0; + int read_padding_length_ = 0; + int num_read_frames_ = 0; + + int num_written_frames_ = 0; +}; +} // namespace net +#endif // NET_TOOLS_NAIVE_NAIVE_PADDING_FRAMER_H_ diff --git a/src/net/tools/naive/naive_padding_socket.cc b/src/net/tools/naive/naive_padding_socket.cc new file mode 100644 index 0000000000..3dfd2ec14a --- /dev/null +++ b/src/net/tools/naive/naive_padding_socket.cc @@ -0,0 +1,271 @@ +// Copyright 2023 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#include "net/tools/naive/naive_padding_socket.h" + +#include +#include +#include +#include +#include + +#include "base/rand_util.h" +#include "base/task/single_thread_task_runner.h" +#include "net/base/io_buffer.h" + +namespace net { + +namespace { +constexpr int kMaxBufferSize = 64 * 1024; +constexpr int kFirstPaddings = 8; +} // namespace + +NaivePaddingSocket::NaivePaddingSocket(StreamSocket* transport_socket, + PaddingType padding_type, + Direction direction) + : transport_socket_(transport_socket), + padding_type_(padding_type), + direction_(direction), + read_buf_(base::MakeRefCounted(kMaxBufferSize)), + framer_(kFirstPaddings) {} + +NaivePaddingSocket::~NaivePaddingSocket() { + Disconnect(); +} + +void NaivePaddingSocket::Disconnect() { + transport_socket_->Disconnect(); +} + +int NaivePaddingSocket::Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) { + DCHECK(!callback.is_null()); + + switch (padding_type_) { + case PaddingType::kNone: + return ReadNoPadding(buf, buf_len, std::move(callback)); + case PaddingType::kVariant1: + if (framer_.num_read_frames() < kFirstPaddings) { + return ReadPaddingV1(buf, buf_len, std::move(callback)); + } else { + return ReadNoPadding(buf, buf_len, std::move(callback)); + } + default: + NOTREACHED(); + } +} + +int NaivePaddingSocket::ReadNoPadding(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) { + int rv = transport_socket_->Read( + buf, buf_len, + base::BindOnce(&NaivePaddingSocket::OnReadNoPaddingComplete, + base::Unretained(this), std::move(callback))); + return rv; +} + +void NaivePaddingSocket::OnReadNoPaddingComplete( + CompletionOnceCallback callback, + int rv) { + DCHECK_NE(ERR_IO_PENDING, rv); + DCHECK(callback); + + std::move(callback).Run(rv); +} + +int NaivePaddingSocket::ReadPaddingV1(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) { + DCHECK(!callback.is_null()); + DCHECK(read_user_buf_ == nullptr); + + // Truncates user requested buf len if it is too large for the padding + // buffer. + buf_len = std::min(buf_len, kMaxBufferSize); + read_user_buf_ = buf; + read_user_buf_len_ = buf_len; + + int rv = ReadPaddingV1Payload(); + + if (rv == ERR_IO_PENDING) { + read_callback_ = std::move(callback); + return rv; + } + + read_user_buf_ = nullptr; + + return rv; +} + +void NaivePaddingSocket::OnReadPaddingV1Complete(int rv) { + DCHECK_NE(ERR_IO_PENDING, rv); + DCHECK(read_callback_); + DCHECK(read_user_buf_ != nullptr); + + if (rv > 0) { + rv = framer_.Read(read_buf_->data(), rv, read_user_buf_->data(), + read_user_buf_len_); + if (rv == 0) { + rv = ReadPaddingV1Payload(); + if (rv == ERR_IO_PENDING) + return; + } + } + + // Must reset read_user_buf_ before invoking read_callback_, which may reenter + // Read(). + read_user_buf_ = nullptr; + + std::move(read_callback_).Run(rv); +} + +int NaivePaddingSocket::ReadPaddingV1Payload() { + for (;;) { + int rv = transport_socket_->Read( + read_buf_.get(), read_user_buf_len_, + base::BindOnce(&NaivePaddingSocket::OnReadPaddingV1Complete, + base::Unretained(this))); + if (rv <= 0) { + return rv; + } + rv = framer_.Read(read_buf_->data(), rv, read_user_buf_->data(), + read_user_buf_len_); + if (rv > 0) { + return rv; + } + } +} + +int NaivePaddingSocket::Write( + IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) { + switch (padding_type_) { + case PaddingType::kNone: + return WriteNoPadding(buf, buf_len, std::move(callback), + traffic_annotation); + case PaddingType::kVariant1: + if (framer_.num_written_frames() < kFirstPaddings) { + return WritePaddingV1(buf, buf_len, std::move(callback), + traffic_annotation); + } else { + return WriteNoPadding(buf, buf_len, std::move(callback), + traffic_annotation); + } + default: + NOTREACHED(); + } +} + +int NaivePaddingSocket::WriteNoPadding( + IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) { + return transport_socket_->Write( + buf, buf_len, + base::BindOnce(&NaivePaddingSocket::OnWriteNoPaddingComplete, + base::Unretained(this), std::move(callback), + traffic_annotation), + traffic_annotation); +} + +void NaivePaddingSocket::OnWriteNoPaddingComplete( + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation, + int rv) { + DCHECK_NE(ERR_IO_PENDING, rv); + DCHECK(callback); + + std::move(callback).Run(rv); +} + +int NaivePaddingSocket::WritePaddingV1( + IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) { + DCHECK(write_buf_ == nullptr); + + auto padded = base::MakeRefCounted(kMaxBufferSize); + int padding_size; + if (direction_ == kServer) { + if (buf_len < 100) { + padding_size = base::RandInt(framer_.max_padding_size() - buf_len, + framer_.max_padding_size()); + } else { + padding_size = base::RandInt(0, framer_.max_padding_size()); + } + } else { + padding_size = base::RandInt(0, framer_.max_padding_size()); + } + int write_buf_len = + framer_.Write(buf->data(), buf_len, padding_size, padded->data(), + kMaxBufferSize, write_user_payload_len_); + // Using DrainableIOBuffer here because we do not want to + // repeatedly encode the padding frames when short writes happen. + write_buf_ = + base::MakeRefCounted(std::move(padded), write_buf_len); + + int rv = WritePaddingV1Drain(traffic_annotation); + if (rv == ERR_IO_PENDING) { + write_callback_ = std::move(callback); + return rv; + } + + write_buf_ = nullptr; + write_user_payload_len_ = 0; + + return rv; +} + +void NaivePaddingSocket::OnWritePaddingV1Complete( + const NetworkTrafficAnnotationTag& traffic_annotation, + int rv) { + DCHECK_NE(ERR_IO_PENDING, rv); + DCHECK(write_callback_); + DCHECK(write_buf_ != nullptr); + + if (rv > 0) { + write_buf_->DidConsume(rv); + rv = WritePaddingV1Drain(traffic_annotation); + if (rv == ERR_IO_PENDING) + return; + } + + // Must reset these before invoking write_callback_, which may reenter + // Write(). + write_buf_ = nullptr; + write_user_payload_len_ = 0; + + std::move(write_callback_).Run(rv); +} + +int NaivePaddingSocket::WritePaddingV1Drain( + const NetworkTrafficAnnotationTag& traffic_annotation) { + DCHECK(write_buf_ != nullptr); + + while (write_buf_->BytesRemaining() > 0) { + int remaining = write_buf_->BytesRemaining(); + if (direction_ == kServer && write_user_payload_len_ > 400 && + write_user_payload_len_ < 1024) { + remaining = std::min(remaining, base::RandInt(200, 300)); + } + int rv = transport_socket_->Write( + write_buf_.get(), remaining, + base::BindOnce(&NaivePaddingSocket::OnWritePaddingV1Complete, + base::Unretained(this), traffic_annotation), + traffic_annotation); + if (rv <= 0) { + return rv; + } + write_buf_->DidConsume(rv); + } + // Synchronously drained the buffer. + return write_user_payload_len_; +} + +} // namespace net diff --git a/src/net/tools/naive/naive_padding_socket.h b/src/net/tools/naive/naive_padding_socket.h new file mode 100644 index 0000000000..01a5720e32 --- /dev/null +++ b/src/net/tools/naive/naive_padding_socket.h @@ -0,0 +1,105 @@ +// Copyright 2023 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#ifndef NET_TOOLS_NAIVE_NAIVE_PADDING_SOCKET_H_ +#define NET_TOOLS_NAIVE_NAIVE_PADDING_SOCKET_H_ + +#include +#include +#include +#include + +#include "base/memory/scoped_refptr.h" +#include "net/base/address_list.h" +#include "net/base/completion_once_callback.h" +#include "net/base/completion_repeating_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/net_export.h" +#include "net/socket/stream_socket.h" +#include "net/tools/naive/naive_padding_framer.h" +#include "net/tools/naive/naive_protocol.h" +#include "net/traffic_annotation/network_traffic_annotation.h" +#include "url/gurl.h" + +namespace net { + +class NaivePaddingSocket { + public: + NaivePaddingSocket(StreamSocket* transport_socket, + PaddingType padding_type, + Direction direction); + + NaivePaddingSocket(const NaivePaddingSocket&) = delete; + NaivePaddingSocket& operator=(const NaivePaddingSocket&) = delete; + + // On destruction Disconnect() is called. + ~NaivePaddingSocket(); + + void Disconnect(); + + int Read(IOBuffer* buf, int buf_len, CompletionOnceCallback callback); + + int Write(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation); + + private: + int ReadNoPadding(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback); + int WriteNoPadding(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation); + void OnReadNoPaddingComplete(CompletionOnceCallback callback, int rv); + void OnWriteNoPaddingComplete( + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation, + int rv); + + int ReadPaddingV1(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback); + int WritePaddingV1(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation); + void OnReadPaddingV1Complete(int rv); + void OnWritePaddingV1Complete( + const NetworkTrafficAnnotationTag& traffic_annotation, + int rv); + + // Exhausts synchronous reads if it is a pure padding + // so this does not return zero for non-EOF condition. + int ReadPaddingV1Payload(); + + int WritePaddingV1Drain( + const NetworkTrafficAnnotationTag& traffic_annotation); + + // Stores the underlying socket. + // Non-owning because this socket does not take part in the client socket pool + // handling and making it owning the transport socket may interfere badly + // with the client socket pool. + StreamSocket* transport_socket_; + + PaddingType padding_type_; + Direction direction_; + + IOBuffer* read_user_buf_ = nullptr; + int read_user_buf_len_ = 0; + CompletionOnceCallback read_callback_; + scoped_refptr read_buf_; + + int write_user_payload_len_ = 0; + CompletionOnceCallback write_callback_; + scoped_refptr write_buf_; + + NaivePaddingFramer framer_; +}; + +} // namespace net + +#endif // NET_TOOLS_NAIVE_NAIVE_PADDING_SOCKET_H_ diff --git a/src/net/tools/naive/naive_protocol.cc b/src/net/tools/naive/naive_protocol.cc new file mode 100644 index 0000000000..449a1b6fdc --- /dev/null +++ b/src/net/tools/naive/naive_protocol.cc @@ -0,0 +1,57 @@ +// Copyright 2023 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#include "net/tools/naive/naive_protocol.h" + +#include +#include + +#include "base/strings/string_piece.h" + +namespace net { +const char* ToString(ClientProtocol value) { + switch (value) { + case ClientProtocol::kSocks5: + return "socks"; + case ClientProtocol::kHttp: + return "http"; + case ClientProtocol::kRedir: + return "redir"; + default: + return ""; + } +} + +std::optional ParsePaddingType(std::string_view str) { + if (str == "0") { + return PaddingType::kNone; + } else if (str == "1") { + return PaddingType::kVariant1; + } else { + return std::nullopt; + } +} + +const char* ToString(PaddingType value) { + switch (value) { + case PaddingType::kNone: + return "0"; + case PaddingType::kVariant1: + return "1"; + default: + return ""; + } +} + +const char* ToReadableString(PaddingType value) { + switch (value) { + case PaddingType::kNone: + return "None"; + case PaddingType::kVariant1: + return "Variant1"; + default: + return ""; + } +} + +} // namespace net diff --git a/src/net/tools/naive/naive_protocol.h b/src/net/tools/naive/naive_protocol.h new file mode 100644 index 0000000000..1935269e20 --- /dev/null +++ b/src/net/tools/naive/naive_protocol.h @@ -0,0 +1,64 @@ +// Copyright 2020 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#ifndef NET_TOOLS_NAIVE_NAIVE_PROTOCOL_H_ +#define NET_TOOLS_NAIVE_NAIVE_PROTOCOL_H_ + +#include +#include +#include + +namespace net { +enum class ClientProtocol { + kSocks5, + kHttp, + kRedir, +}; + +const char* ToString(ClientProtocol value); + +// Adds padding for traffic from this direction. +// Removes padding for traffic from the opposite direction. +enum Direction { + kClient = 0, + kServer = 1, + kNumDirections = 2, + kNone = 2, +}; + +enum class PaddingType { + // Wire format: "0". + kNone = 0, + + // Pads the first 8 reads and writes with padding bytes of random size + // uniformly distributed in [0, 255]. + // struct PaddedFrame { + // uint8_t original_data_size_high; // original_data_size / 256 + // uint8_t original_data_size_low; // original_data_size % 256 + // uint8_t padding_size; + // uint8_t original_data[original_data_size]; + // uint8_t zeros[padding_size]; + // }; + // Wire format: "1". + kVariant1 = 1, +}; + +// Returns empty if `str` is invalid. +std::optional ParsePaddingType(std::string_view str); + +const char* ToString(PaddingType value); + +const char* ToReadableString(PaddingType value); + +constexpr const char* kPaddingHeader = "padding"; + +// Contains a comma separated list of requested padding types. +// Preferred types come first. +constexpr const char* kPaddingTypeRequestHeader = "padding-type-request"; + +// Contains a single number representing the negotiated padding type. +// Must be one of PaddingType. +constexpr const char* kPaddingTypeReplyHeader = "padding-type-reply"; + +} // namespace net +#endif // NET_TOOLS_NAIVE_NAIVE_PROTOCOL_H_ diff --git a/src/net/tools/naive/naive_proxy.cc b/src/net/tools/naive/naive_proxy.cc new file mode 100644 index 0000000000..0c50b0de2c --- /dev/null +++ b/src/net/tools/naive/naive_proxy.cc @@ -0,0 +1,203 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/naive_proxy.h" + +#include +#include + +#include "base/functional/bind.h" +#include "base/location.h" +#include "base/logging.h" +#include "base/task/single_thread_task_runner.h" +#include "net/base/load_flags.h" +#include "net/base/net_errors.h" +#include "net/http/http_network_session.h" +#include "net/proxy_resolution/configured_proxy_resolution_service.h" +#include "net/proxy_resolution/proxy_config.h" +#include "net/proxy_resolution/proxy_list.h" +#include "net/socket/client_socket_pool_manager.h" +#include "net/socket/server_socket.h" +#include "net/socket/stream_socket.h" +#include "net/tools/naive/http_proxy_server_socket.h" +#include "net/tools/naive/naive_proxy_delegate.h" +#include "net/tools/naive/socks5_server_socket.h" + +namespace net { + +NaiveProxy::NaiveProxy(std::unique_ptr listen_socket, + ClientProtocol protocol, + const std::string& listen_user, + const std::string& listen_pass, + int concurrency, + RedirectResolver* resolver, + HttpNetworkSession* session, + const NetworkTrafficAnnotationTag& traffic_annotation, + const std::vector& supported_padding_types) + : listen_socket_(std::move(listen_socket)), + protocol_(protocol), + listen_user_(listen_user), + listen_pass_(listen_pass), + concurrency_(concurrency), + resolver_(resolver), + session_(session), + net_log_( + NetLogWithSource::Make(session->net_log(), NetLogSourceType::NONE)), + last_id_(0), + traffic_annotation_(traffic_annotation), + supported_padding_types_(supported_padding_types) { + const auto& proxy_config = static_cast( + session_->proxy_resolution_service()) + ->config(); + DCHECK(proxy_config); + const ProxyList& proxy_list = + proxy_config.value().value().proxy_rules().single_proxies; + DCHECK(!proxy_list.IsEmpty()); + proxy_info_.UseProxyList(proxy_list); + proxy_info_.set_traffic_annotation( + net::MutableNetworkTrafficAnnotationTag(traffic_annotation_)); + + for (int i = 0; i < concurrency_; i++) { + network_anonymization_keys_.push_back( + NetworkAnonymizationKey::CreateTransient()); + } + + DCHECK(listen_socket_); + // Start accepting connections in next run loop in case when delegate is not + // ready to get callbacks. + base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( + FROM_HERE, base::BindOnce(&NaiveProxy::DoAcceptLoop, + weak_ptr_factory_.GetWeakPtr())); +} + +NaiveProxy::~NaiveProxy() = default; + +void NaiveProxy::DoAcceptLoop() { + int result; + do { + result = listen_socket_->Accept( + &accepted_socket_, base::BindRepeating(&NaiveProxy::OnAcceptComplete, + weak_ptr_factory_.GetWeakPtr())); + if (result == ERR_IO_PENDING) + return; + HandleAcceptResult(result); + } while (result == OK); +} + +void NaiveProxy::OnAcceptComplete(int result) { + HandleAcceptResult(result); + if (result == OK) + DoAcceptLoop(); +} + +void NaiveProxy::HandleAcceptResult(int result) { + if (result != OK) { + LOG(ERROR) << "Accept error: " << ErrorToShortString(result); + return; + } + DoConnect(); +} + +void NaiveProxy::DoConnect() { + std::unique_ptr socket; + auto* proxy_delegate = + static_cast(session_->context().proxy_delegate); + DCHECK(proxy_delegate); + DCHECK(!proxy_info_.is_empty()); + const ProxyChain& proxy_server = proxy_info_.proxy_chain(); + auto padding_detector_delegate = std::make_unique( + proxy_delegate, proxy_server, protocol_); + + if (protocol_ == ClientProtocol::kSocks5) { + socket = std::make_unique(std::move(accepted_socket_), + listen_user_, listen_pass_, + traffic_annotation_); + } else if (protocol_ == ClientProtocol::kHttp) { + socket = std::make_unique( + std::move(accepted_socket_), padding_detector_delegate.get(), + traffic_annotation_, supported_padding_types_); + } else if (protocol_ == ClientProtocol::kRedir) { + socket = std::move(accepted_socket_); + } else { + return; + } + + last_id_++; + int tunnel_session_id = last_id_ % concurrency_; + const auto& nak = network_anonymization_keys_[tunnel_session_id]; + auto connection_ptr = std::make_unique( + last_id_, protocol_, std::move(padding_detector_delegate), proxy_info_, + resolver_, session_, nak, net_log_, std::move(socket), + traffic_annotation_); + auto* connection = connection_ptr.get(); + connection_by_id_[connection->id()] = std::move(connection_ptr); + int result = connection->Connect( + base::BindRepeating(&NaiveProxy::OnConnectComplete, + weak_ptr_factory_.GetWeakPtr(), connection->id())); + if (result == ERR_IO_PENDING) + return; + HandleConnectResult(connection, result); +} + +void NaiveProxy::OnConnectComplete(unsigned int connection_id, int result) { + auto* connection = FindConnection(connection_id); + if (!connection) + return; + HandleConnectResult(connection, result); +} + +void NaiveProxy::HandleConnectResult(NaiveConnection* connection, int result) { + if (result != OK) { + Close(connection->id(), result); + return; + } + DoRun(connection); +} + +void NaiveProxy::DoRun(NaiveConnection* connection) { + int result = connection->Run( + base::BindRepeating(&NaiveProxy::OnRunComplete, + weak_ptr_factory_.GetWeakPtr(), connection->id())); + if (result == ERR_IO_PENDING) + return; + HandleRunResult(connection, result); +} + +void NaiveProxy::OnRunComplete(unsigned int connection_id, int result) { + auto* connection = FindConnection(connection_id); + if (!connection) + return; + HandleRunResult(connection, result); +} + +void NaiveProxy::HandleRunResult(NaiveConnection* connection, int result) { + Close(connection->id(), result); +} + +void NaiveProxy::Close(unsigned int connection_id, int reason) { + auto it = connection_by_id_.find(connection_id); + if (it == connection_by_id_.end()) + return; + + LOG(INFO) << "Connection " << connection_id + << " closed: " << ErrorToShortString(reason); + + // The call stack might have callbacks which still have the pointer of + // connection. Instead of referencing connection with ID all the time, + // destroys the connection in next run loop to make sure any pending + // callbacks in the call stack return. + base::SingleThreadTaskRunner::GetCurrentDefault()->DeleteSoon( + FROM_HERE, std::move(it->second)); + connection_by_id_.erase(it); +} + +NaiveConnection* NaiveProxy::FindConnection(unsigned int connection_id) { + auto it = connection_by_id_.find(connection_id); + if (it == connection_by_id_.end()) + return nullptr; + return it->second.get(); +} + +} // namespace net diff --git a/src/net/tools/naive/naive_proxy.h b/src/net/tools/naive/naive_proxy.h new file mode 100644 index 0000000000..626ac7e950 --- /dev/null +++ b/src/net/tools/naive/naive_proxy.h @@ -0,0 +1,90 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#ifndef NET_TOOLS_NAIVE_NAIVE_PROXY_H_ +#define NET_TOOLS_NAIVE_NAIVE_PROXY_H_ + +#include +#include +#include +#include + +#include "base/memory/weak_ptr.h" +#include "net/base/completion_repeating_callback.h" +#include "net/base/network_isolation_key.h" +#include "net/log/net_log_with_source.h" +#include "net/proxy_resolution/proxy_info.h" +#include "net/ssl/ssl_config.h" +#include "net/tools/naive/naive_connection.h" +#include "net/tools/naive/naive_protocol.h" + +namespace net { + +class ClientSocketHandle; +class HttpNetworkSession; +class NaiveConnection; +class ServerSocket; +class StreamSocket; +struct NetworkTrafficAnnotationTag; +class RedirectResolver; + +class NaiveProxy { + public: + NaiveProxy(std::unique_ptr server_socket, + ClientProtocol protocol, + const std::string& listen_user, + const std::string& listen_pass, + int concurrency, + RedirectResolver* resolver, + HttpNetworkSession* session, + const NetworkTrafficAnnotationTag& traffic_annotation, + const std::vector& supported_padding_types); + ~NaiveProxy(); + NaiveProxy(const NaiveProxy&) = delete; + NaiveProxy& operator=(const NaiveProxy&) = delete; + + private: + void DoAcceptLoop(); + void OnAcceptComplete(int result); + void HandleAcceptResult(int result); + + void DoConnect(); + void OnConnectComplete(unsigned int connection_id, int result); + void HandleConnectResult(NaiveConnection* connection, int result); + + void DoRun(NaiveConnection* connection); + void OnRunComplete(unsigned int connection_id, int result); + void HandleRunResult(NaiveConnection* connection, int result); + + void Close(unsigned int connection_id, int reason); + + NaiveConnection* FindConnection(unsigned int connection_id); + + std::unique_ptr listen_socket_; + ClientProtocol protocol_; + std::string listen_user_; + std::string listen_pass_; + int concurrency_; + ProxyInfo proxy_info_; + RedirectResolver* resolver_; + HttpNetworkSession* session_; + NetLogWithSource net_log_; + + unsigned int last_id_; + + std::unique_ptr accepted_socket_; + + std::vector network_anonymization_keys_; + + std::map> connection_by_id_; + + const NetworkTrafficAnnotationTag& traffic_annotation_; + + std::vector supported_padding_types_; + + base::WeakPtrFactory weak_ptr_factory_{this}; +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_NAIVE_PROXY_H_ diff --git a/src/net/tools/naive/naive_proxy_bin.cc b/src/net/tools/naive/naive_proxy_bin.cc new file mode 100644 index 0000000000..879e2c21b5 --- /dev/null +++ b/src/net/tools/naive/naive_proxy_bin.cc @@ -0,0 +1,726 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include +#include +#include +#include + +#include "base/allocator/allocator_check.h" +#include "base/allocator/partition_alloc_support.h" +#include "base/allocator/partition_allocator/src/partition_alloc/shim/allocator_shim.h" +#include "base/at_exit.h" +#include "base/check.h" +#include "base/command_line.h" +#include "base/feature_list.h" +#include "base/files/file_path.h" +#include "base/json/json_file_value_serializer.h" +#include "base/json/json_writer.h" +#include "base/logging.h" +#include "base/process/memory.h" +#include "base/rand_util.h" +#include "base/run_loop.h" +#include "base/strings/escape.h" +#include "base/strings/string_number_conversions.h" +#include "base/strings/stringprintf.h" +#include "base/strings/utf_string_conversions.h" +#include "base/system/sys_info.h" +#include "base/task/single_thread_task_executor.h" +#include "base/task/thread_pool/thread_pool_instance.h" +#include "base/values.h" +#include "build/build_config.h" +#include "components/version_info/version_info.h" +#include "net/base/auth.h" +#include "net/base/network_isolation_key.h" +#include "net/base/url_util.h" +#include "net/cert/cert_verifier.h" +#include "net/cert_net/cert_net_fetcher_url_request.h" +#include "net/dns/host_resolver.h" +#include "net/dns/mapped_host_resolver.h" +#include "net/http/http_auth.h" +#include "net/http/http_auth_cache.h" +#include "net/http/http_network_session.h" +#include "net/http/http_request_headers.h" +#include "net/http/http_transaction_factory.h" +#include "net/log/file_net_log_observer.h" +#include "net/log/net_log.h" +#include "net/log/net_log_capture_mode.h" +#include "net/log/net_log_entry.h" +#include "net/log/net_log_event_type.h" +#include "net/log/net_log_source.h" +#include "net/log/net_log_util.h" +#include "net/proxy_resolution/configured_proxy_resolution_service.h" +#include "net/proxy_resolution/proxy_config.h" +#include "net/proxy_resolution/proxy_config_service_fixed.h" +#include "net/proxy_resolution/proxy_config_with_annotation.h" +#include "net/socket/client_socket_pool_manager.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/tcp_server_socket.h" +#include "net/socket/udp_server_socket.h" +#include "net/ssl/ssl_key_logger_impl.h" +#include "net/third_party/quiche/src/quiche/quic/core/quic_versions.h" +#include "net/tools/naive/naive_protocol.h" +#include "net/tools/naive/naive_proxy.h" +#include "net/tools/naive/naive_proxy_delegate.h" +#include "net/tools/naive/redirect_resolver.h" +#include "net/traffic_annotation/network_traffic_annotation.h" +#include "net/url_request/url_request_context.h" +#include "net/url_request/url_request_context_builder.h" +#include "url/gurl.h" +#include "url/scheme_host_port.h" +#include "url/url_util.h" + +#if BUILDFLAG(IS_APPLE) +#include "base/allocator/early_zone_registration_apple.h" +#include "base/apple/scoped_nsautorelease_pool.h" +#endif + +namespace { + +constexpr int kListenBackLog = 512; +constexpr int kDefaultMaxSocketsPerPool = 256; +constexpr int kDefaultMaxSocketsPerGroup = 255; +constexpr int kExpectedMaxUsers = 8; +constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation = + net::DefineNetworkTrafficAnnotation("naive", ""); + +struct CommandLine { + std::vector listens; + std::string proxy; + std::string concurrency; + std::string extra_headers; + std::string host_resolver_rules; + std::string resolver_range; + bool no_log; + base::FilePath log; + base::FilePath log_net_log; + base::FilePath ssl_key_log_file; +}; + +struct ListenParams { + net::ClientProtocol protocol; + std::string listen_user; + std::string listen_pass; + std::string listen_addr; + int listen_port; +}; + +struct Params { + std::vector listens; + int concurrency; + net::HttpRequestHeaders extra_headers; + std::string proxy_url; + std::u16string proxy_user; + std::u16string proxy_pass; + std::string host_resolver_rules; + net::IPAddress resolver_range; + size_t resolver_prefix; + logging::LoggingSettings log_settings; + base::FilePath net_log_path; + base::FilePath ssl_key_path; +}; + +std::unique_ptr GetConstants() { + base::Value::Dict constants_dict = net::GetNetConstants(); + base::Value::Dict dict; + std::string os_type = base::StringPrintf( + "%s: %s (%s)", base::SysInfo::OperatingSystemName().c_str(), + base::SysInfo::OperatingSystemVersion().c_str(), + base::SysInfo::OperatingSystemArchitecture().c_str()); + dict.Set("os_type", os_type); + constants_dict.Set("clientInfo", std::move(dict)); + return std::make_unique(std::move(constants_dict)); +} + +class MultipleListenCollector : public base::DuplicateSwitchHandler { + public: + void ResolveDuplicate(std::string_view key, + base::CommandLine::StringPieceType new_value, + base::CommandLine::StringType& out_value) override { + out_value = new_value; + if (key == "listen") { +#if BUILDFLAG(IS_WIN) + all_values_.push_back(base::WideToUTF8(new_value)); +#else + all_values_.push_back(std::string(new_value)); +#endif + } + } + + const std::vector& GetAllValues() const { + return all_values_; + } + + private: + std::vector all_values_; +}; + +void GetCommandLine(const base::CommandLine& proc, + CommandLine* cmdline, + MultipleListenCollector& multiple_listens) { + if (proc.HasSwitch("h") || proc.HasSwitch("help")) { + std::cout << "Usage: naive { OPTIONS | config.json }\n" + "\n" + "Options:\n" + "-h, --help Show this message\n" + "--version Print version\n" + "--listen=://[addr][:port] [--listen=...]\n" + " proto: socks, http\n" + " redir (Linux only)\n" + "--proxy=://[:@][:]\n" + " proto: https, quic\n" + "--insecure-concurrency= Use N connections, insecure\n" + "--extra-headers=... Extra headers split by CRLF\n" + "--host-resolver-rules=... Resolver rules\n" + "--resolver-range=... Redirect resolver range\n" + "--log[=] Log to stderr, or file\n" + "--log-net-log= Save NetLog\n" + "--ssl-key-log-file= Save SSL keys for Wireshark\n" + << std::endl; + exit(EXIT_SUCCESS); + } + + if (proc.HasSwitch("version")) { + std::cout << "naive " << version_info::GetVersionNumber() << std::endl; + exit(EXIT_SUCCESS); + } + + cmdline->listens = multiple_listens.GetAllValues(); + cmdline->proxy = proc.GetSwitchValueASCII("proxy"); + cmdline->concurrency = proc.GetSwitchValueASCII("insecure-concurrency"); + cmdline->extra_headers = proc.GetSwitchValueASCII("extra-headers"); + cmdline->host_resolver_rules = + proc.GetSwitchValueASCII("host-resolver-rules"); + cmdline->resolver_range = proc.GetSwitchValueASCII("resolver-range"); + cmdline->no_log = !proc.HasSwitch("log"); + cmdline->log = proc.GetSwitchValuePath("log"); + cmdline->log_net_log = proc.GetSwitchValuePath("log-net-log"); + cmdline->ssl_key_log_file = proc.GetSwitchValuePath("ssl-key-log-file"); +} + +void GetCommandLineFromConfig(const base::FilePath& config_path, + CommandLine* cmdline) { + JSONFileValueDeserializer reader(config_path); + int error_code; + std::string error_message; + std::unique_ptr value = + reader.Deserialize(&error_code, &error_message); + if (value == nullptr) { + std::cerr << "Error reading " << config_path << ": (" << error_code << ") " + << error_message << std::endl; + exit(EXIT_FAILURE); + } + base::Value::Dict* value_dict = value->GetIfDict(); + if (value_dict == nullptr) { + std::cerr << "Invalid config format" << std::endl; + exit(EXIT_FAILURE); + } + const std::string* listen = value_dict->FindString("listen"); + if (listen != nullptr) { + cmdline->listens = {*listen}; + } else { + const base::Value::List* listen_list = value_dict->FindList("listen"); + if (listen_list != nullptr) { + for (const auto& listen_element : *listen_list) { + const std::string* listen_elemet_str = listen_element.GetIfString(); + if (listen_elemet_str == nullptr) { + std::cerr << "Invalid listen element" << std::endl; + exit(EXIT_FAILURE); + } + cmdline->listens.push_back(*listen_elemet_str); + } + } + } + const std::string* proxy = value_dict->FindString("proxy"); + if (proxy) { + cmdline->proxy = *proxy; + } + const std::string* concurrency = + value_dict->FindString("insecure-concurrency"); + if (concurrency) { + cmdline->concurrency = *concurrency; + } + const std::string* extra_headers = value_dict->FindString("extra-headers"); + if (extra_headers) { + cmdline->extra_headers = *extra_headers; + } + const std::string* host_resolver_rules = + value_dict->FindString("host-resolver-rules"); + if (host_resolver_rules) { + cmdline->host_resolver_rules = *host_resolver_rules; + } + const std::string* resolver_range = value_dict->FindString("resolver-range"); + if (resolver_range) { + cmdline->resolver_range = *resolver_range; + } + cmdline->no_log = true; + const std::string* log = value_dict->FindString("log"); + if (log) { + cmdline->no_log = false; + cmdline->log = base::FilePath::FromUTF8Unsafe(*log); + } + const std::string* log_net_log = value_dict->FindString("log-net-log"); + if (log_net_log) { + cmdline->log_net_log = base::FilePath::FromUTF8Unsafe(*log_net_log); + } + const std::string* ssl_key_log_file = + value_dict->FindString("ssl-key-log-file"); + if (ssl_key_log_file) { + cmdline->ssl_key_log_file = + base::FilePath::FromUTF8Unsafe(*ssl_key_log_file); + } +} + +bool ParseListenParams(const std::string& listen_str, + ListenParams& listen_params) { + GURL url(listen_str); + if (url.scheme() == "socks") { + listen_params.protocol = net::ClientProtocol::kSocks5; + } else if (url.scheme() == "http") { + listen_params.protocol = net::ClientProtocol::kHttp; + } else if (url.scheme() == "redir") { +#if BUILDFLAG(IS_LINUX) + listen_params.protocol = net::ClientProtocol::kRedir; +#else + std::cerr << "Redir protocol only supports Linux." << std::endl; + return false; +#endif + } else { + std::cerr << "Invalid scheme in --listen" << std::endl; + return false; + } + if (!url.username().empty()) { + listen_params.listen_user = + base::UnescapeBinaryURLComponent(url.username()); + } + if (!url.password().empty()) { + listen_params.listen_pass = + base::UnescapeBinaryURLComponent(url.password()); + } + if (!url.host().empty()) { + listen_params.listen_addr = url.HostNoBrackets(); + } else { + listen_params.listen_addr = "0.0.0.0"; + } + int port = url.EffectiveIntPort(); + if (port == url::PORT_INVALID) { + std::cerr << "Invalid port in --listen" << std::endl; + return false; + } else if (port == url::PORT_UNSPECIFIED) { + port = 1080; + } + listen_params.listen_port = port; + return true; +} + +bool ParseCommandLine(const CommandLine& cmdline, Params* params) { + url::AddStandardScheme("socks", + url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION); + url::AddStandardScheme("redir", url::SCHEME_WITH_HOST_AND_PORT); + + bool any_redir_protocol = false; + if (!cmdline.listens.empty()) { + for (const std::string& listen : cmdline.listens) { + ListenParams listen_params; + if (!ParseListenParams(listen, listen_params)) { + std::cerr << "Invalid listen: " << listen << std::endl; + return false; + } + if (listen_params.protocol == net::ClientProtocol::kRedir) { + any_redir_protocol = true; + } + params->listens.push_back(listen_params); + } + } else { + ListenParams default_listen = { + .protocol = net::ClientProtocol::kSocks5, + .listen_addr = "0.0.0.0", + .listen_port = 1080, + }; + params->listens = {default_listen}; + } + + params->proxy_url = "direct://"; + GURL url(cmdline.proxy); + GURL::Replacements remove_auth; + remove_auth.ClearUsername(); + remove_auth.ClearPassword(); + GURL url_no_auth = url.ReplaceComponents(remove_auth); + if (!cmdline.proxy.empty()) { + params->proxy_url = url_no_auth.GetWithEmptyPath().spec(); + if (params->proxy_url.empty()) { + std::cerr << "Invalid proxy URL" << std::endl; + return false; + } else if (params->proxy_url.back() == '/') { + params->proxy_url.pop_back(); + } + net::GetIdentityFromURL(url, ¶ms->proxy_user, ¶ms->proxy_pass); + } + + if (!cmdline.concurrency.empty()) { + if (!base::StringToInt(cmdline.concurrency, ¶ms->concurrency) || + params->concurrency < 1) { + std::cerr << "Invalid concurrency" << std::endl; + return false; + } + } else { + params->concurrency = 1; + } + + params->extra_headers.AddHeadersFromString(cmdline.extra_headers); + + params->host_resolver_rules = cmdline.host_resolver_rules; + + if (any_redir_protocol) { + std::string range = "100.64.0.0/10"; + if (!cmdline.resolver_range.empty()) + range = cmdline.resolver_range; + + if (!net::ParseCIDRBlock(range, ¶ms->resolver_range, + ¶ms->resolver_prefix)) { + std::cerr << "Invalid resolver range" << std::endl; + return false; + } + if (params->resolver_range.IsIPv6()) { + std::cerr << "IPv6 resolver range not supported" << std::endl; + return false; + } + } + + if (!cmdline.no_log) { + if (!cmdline.log.empty()) { + params->log_settings.logging_dest = logging::LOG_TO_FILE; + params->log_settings.log_file_path = cmdline.log.value().c_str(); + } else { + params->log_settings.logging_dest = logging::LOG_TO_STDERR; + } + } else { + params->log_settings.logging_dest = logging::LOG_NONE; + } + + params->net_log_path = cmdline.log_net_log; + params->ssl_key_path = cmdline.ssl_key_log_file; + + return true; +} +} // namespace + +namespace net { +namespace { +// NetLog::ThreadSafeObserver implementation that simply prints events +// to the logs. +class PrintingLogObserver : public NetLog::ThreadSafeObserver { + public: + PrintingLogObserver() = default; + PrintingLogObserver(const PrintingLogObserver&) = delete; + PrintingLogObserver& operator=(const PrintingLogObserver&) = delete; + + ~PrintingLogObserver() override { + // This is guaranteed to be safe as this program is single threaded. + net_log()->RemoveObserver(this); + } + + // NetLog::ThreadSafeObserver implementation: + void OnAddEntry(const NetLogEntry& entry) override { + switch (entry.type) { + case NetLogEventType::SOCKET_POOL_STALLED_MAX_SOCKETS: + case NetLogEventType::SOCKET_POOL_STALLED_MAX_SOCKETS_PER_GROUP: + case NetLogEventType::HTTP2_SESSION_STREAM_STALLED_BY_SESSION_SEND_WINDOW: + case NetLogEventType::HTTP2_SESSION_STREAM_STALLED_BY_STREAM_SEND_WINDOW: + case NetLogEventType::HTTP2_SESSION_STALLED_MAX_STREAMS: + case NetLogEventType::HTTP2_STREAM_FLOW_CONTROL_UNSTALLED: + break; + default: + return; + } + const char* source_type = NetLog::SourceTypeToString(entry.source.type); + const char* event_type = NetLogEventTypeToString(entry.type); + const char* event_phase = NetLog::EventPhaseToString(entry.phase); + base::Value params(entry.ToDict()); + std::string params_str; + base::JSONWriter::Write(params, ¶ms_str); + params_str.insert(0, ": "); + + VLOG(1) << source_type << "(" << entry.source.id << "): " << event_type + << ": " << event_phase << params_str; + } +}; +} // namespace + +namespace { +std::unique_ptr BuildCertURLRequestContext(NetLog* net_log) { + URLRequestContextBuilder builder; + + builder.DisableHttpCache(); + builder.set_net_log(net_log); + + ProxyConfig proxy_config; + auto proxy_service = + ConfiguredProxyResolutionService::CreateWithoutProxyResolver( + std::make_unique( + ProxyConfigWithAnnotation(proxy_config, kTrafficAnnotation)), + net_log); + proxy_service->ForceReloadProxyConfig(); + builder.set_proxy_resolution_service(std::move(proxy_service)); + + return builder.Build(); +} + +// Builds a URLRequestContext assuming there's only a single loop. +std::unique_ptr BuildURLRequestContext( + const Params& params, + scoped_refptr cert_net_fetcher, + NetLog* net_log) { + URLRequestContextBuilder builder; + + builder.DisableHttpCache(); + builder.set_net_log(net_log); + + ProxyConfig proxy_config; + proxy_config.proxy_rules().ParseFromString(params.proxy_url); + LOG(INFO) << "Proxying via " << params.proxy_url; + auto proxy_service = + ConfiguredProxyResolutionService::CreateWithoutProxyResolver( + std::make_unique( + ProxyConfigWithAnnotation(proxy_config, kTrafficAnnotation)), + net_log); + proxy_service->ForceReloadProxyConfig(); + builder.set_proxy_resolution_service(std::move(proxy_service)); + + if (!params.host_resolver_rules.empty()) { + builder.set_host_mapping_rules(params.host_resolver_rules); + } + + builder.SetCertVerifier( + CertVerifier::CreateDefault(std::move(cert_net_fetcher))); + + builder.set_proxy_delegate(std::make_unique( + params.extra_headers, + std::vector{PaddingType::kVariant1, PaddingType::kNone})); + + auto context = builder.Build(); + + if (!params.proxy_url.empty() && !params.proxy_user.empty() && + !params.proxy_pass.empty()) { + auto* session = context->http_transaction_factory()->GetSession(); + auto* auth_cache = session->http_auth_cache(); + std::string proxy_url = params.proxy_url; + GURL proxy_gurl(proxy_url); + if (proxy_url.compare(0, 7, "quic://") == 0) { + proxy_url.replace(0, 4, "https"); + proxy_gurl = GURL(proxy_url); + auto* quic = context->quic_context()->params(); + quic->supported_versions = {quic::ParsedQuicVersion::RFCv1()}; + quic->origins_to_force_quic_on.insert( + net::HostPortPair::FromURL(proxy_gurl)); + } + url::SchemeHostPort auth_origin(proxy_gurl); + AuthCredentials credentials(params.proxy_user, params.proxy_pass); + auth_cache->Add(auth_origin, HttpAuth::AUTH_PROXY, + /*realm=*/{}, HttpAuth::AUTH_SCHEME_BASIC, {}, + /*challenge=*/"Basic", credentials, /*path=*/"/"); + } + + return context; +} +} // namespace +} // namespace net + +int main(int argc, char* argv[]) { + // chrome/app/chrome_exe_main_mac.cc: main() +#if BUILDFLAG(IS_APPLE) + partition_alloc::EarlyMallocZoneRegistration(); +#endif + + // content/app/content_main.cc: RunContentProcess() +#if BUILDFLAG(IS_APPLE) + base::apple::ScopedNSAutoreleasePool pool; +#endif + + // content/app/content_main.cc: RunContentProcess() +#if BUILDFLAG(IS_APPLE) && BUILDFLAG(USE_ALLOCATOR_SHIM) + // The static initializer function for initializing PartitionAlloc + // InitializeDefaultMallocZoneWithPartitionAlloc() would be removed by the + // linker if allocator_shim.o is not referenced by the following call, + // resulting in undefined behavior of accessing uninitialized TLS + // data in PurgeCurrentThread() when PA is enabled. + allocator_shim::InitializeAllocatorShim(); +#endif + + // content/app/content_main.cc: RunContentProcess() + base::EnableTerminationOnOutOfMemory(); + + auto multiple_listens = std::make_unique(); + MultipleListenCollector& multiple_listens_ref = *multiple_listens; + base::CommandLine::SetDuplicateSwitchHandler(std::move(multiple_listens)); + + // content/app/content_main.cc: RunContentProcess() + base::CommandLine::Init(argc, argv); + + // content/app/content_main.cc: RunContentProcess() + base::EnableTerminationOnHeapCorruption(); + + // content/app/content_main.cc: RunContentProcess() + // content/app/content_main_runner_impl.cc: Initialize() + base::AtExitManager exit_manager; + std::string process_type = ""; + base::allocator::PartitionAllocSupport::Get()->ReconfigureEarlyish( + process_type); + + // content/app/content_main.cc: RunContentProcess() + // content/app/content_main_runner_impl.cc: Initialize() + // If we are on a platform where the default allocator is overridden (e.g. + // with PartitionAlloc on most platforms) smoke-tests that the overriding + // logic is working correctly. If not causes a hard crash, as its unexpected + // absence has security implications. + CHECK(base::allocator::IsAllocatorInitialized()); + + // content/app/content_main.cc: RunContentProcess() + // content/app/content_main_runner_impl.cc: Run() + base::FeatureList::InitInstance("PartitionConnectionsByNetworkIsolationKey", + std::string()); + + base::allocator::PartitionAllocSupport::Get() + ->ReconfigureAfterFeatureListInit(/*process_type=*/""); + + base::SingleThreadTaskExecutor io_task_executor(base::MessagePumpType::IO); + base::ThreadPoolInstance::CreateAndStartWithDefaultParams("naive"); + + base::allocator::PartitionAllocSupport::Get()->ReconfigureAfterTaskRunnerInit( + process_type); + + url::AddStandardScheme("quic", + url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION); + net::ClientSocketPoolManager::set_max_sockets_per_pool( + net::HttpNetworkSession::NORMAL_SOCKET_POOL, + kDefaultMaxSocketsPerPool * kExpectedMaxUsers); + net::ClientSocketPoolManager::set_max_sockets_per_proxy_chain( + net::HttpNetworkSession::NORMAL_SOCKET_POOL, + kDefaultMaxSocketsPerPool * kExpectedMaxUsers); + net::ClientSocketPoolManager::set_max_sockets_per_group( + net::HttpNetworkSession::NORMAL_SOCKET_POOL, + kDefaultMaxSocketsPerGroup * kExpectedMaxUsers); + + CommandLine cmdline; + + Params params; + const auto& proc = *base::CommandLine::ForCurrentProcess(); + const auto& args = proc.GetArgs(); + if (args.empty()) { + if (proc.argv().size() >= 2) { + GetCommandLine(proc, &cmdline, multiple_listens_ref); + } else { + auto path = base::FilePath::FromUTF8Unsafe("config.json"); + GetCommandLineFromConfig(path, &cmdline); + } + } else { + base::FilePath path(args[0]); + GetCommandLineFromConfig(path, &cmdline); + } + if (!ParseCommandLine(cmdline, ¶ms)) { + return EXIT_FAILURE; + } + CHECK(logging::InitLogging(params.log_settings)); + + if (!params.ssl_key_path.empty()) { + net::SSLClientSocket::SetSSLKeyLogger( + std::make_unique(params.ssl_key_path)); + } + + // The declaration order for net_log and printing_log_observer is + // important. The destructor of PrintingLogObserver removes itself + // from net_log, so net_log must be available for entire lifetime of + // printing_log_observer. + net::NetLog* net_log = net::NetLog::Get(); + std::unique_ptr observer; + if (!params.net_log_path.empty()) { + observer = net::FileNetLogObserver::CreateUnbounded( + params.net_log_path, net::NetLogCaptureMode::kDefault, GetConstants()); + observer->StartObserving(net_log); + } + + // Avoids net log overhead if verbose logging is disabled. + std::unique_ptr printing_log_observer; + if (params.log_settings.logging_dest != logging::LOG_NONE && VLOG_IS_ON(1)) { + printing_log_observer = std::make_unique(); + net_log->AddObserver(printing_log_observer.get(), + net::NetLogCaptureMode::kDefault); + } + + auto cert_context = net::BuildCertURLRequestContext(net_log); + scoped_refptr cert_net_fetcher; + // The builtin verifier is supported but not enabled by default on Mac, + // falling back to CreateSystemVerifyProc() which drops the net fetcher, + // causing a DCHECK in ~CertNetFetcherURLRequest(). + // See CertVerifier::CreateDefaultWithoutCaching() and + // CertVerifyProc::CreateSystemVerifyProc() for the build flags. +#if BUILDFLAG(CHROME_ROOT_STORE_SUPPORTED) || BUILDFLAG(IS_FUCHSIA) || \ + BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS) || BUILDFLAG(IS_ANDROID) + cert_net_fetcher = base::MakeRefCounted(); + cert_net_fetcher->SetURLRequestContext(cert_context.get()); +#endif + auto context = + net::BuildURLRequestContext(params, std::move(cert_net_fetcher), net_log); + auto* session = context->http_transaction_factory()->GetSession(); + + std::vector> naive_proxies; + std::unique_ptr resolver; + + for (const ListenParams& listen_params : params.listens) { + auto listen_socket = + std::make_unique(net_log, net::NetLogSource()); + + int result = listen_socket->ListenWithAddressAndPort( + listen_params.listen_addr, listen_params.listen_port, kListenBackLog); + if (result != net::OK) { + LOG(ERROR) << "Failed to listen on " + << net::ToString(listen_params.protocol) << "://" + << listen_params.listen_addr << " " + << listen_params.listen_port << ": " + << net::ErrorToShortString(result); + return EXIT_FAILURE; + } + LOG(INFO) << "Listening on " << net::ToString(listen_params.protocol) + << "://" << listen_params.listen_addr << ":" + << listen_params.listen_port; + + if (resolver == nullptr && + listen_params.protocol == net::ClientProtocol::kRedir) { + auto resolver_socket = + std::make_unique(net_log, net::NetLogSource()); + resolver_socket->AllowAddressReuse(); + net::IPAddress listen_addr; + if (!listen_addr.AssignFromIPLiteral(listen_params.listen_addr)) { + LOG(ERROR) << "Failed to open resolver: " << listen_params.listen_addr; + return EXIT_FAILURE; + } + + result = resolver_socket->Listen( + net::IPEndPoint(listen_addr, listen_params.listen_port)); + if (result != net::OK) { + LOG(ERROR) << "Failed to open resolver: " + << net::ErrorToShortString(result); + return EXIT_FAILURE; + } + + resolver = std::make_unique( + std::move(resolver_socket), params.resolver_range, + params.resolver_prefix); + } + + auto naive_proxy = std::make_unique( + std::move(listen_socket), listen_params.protocol, + listen_params.listen_user, listen_params.listen_pass, + params.concurrency, resolver.get(), session, kTrafficAnnotation, + std::vector{net::PaddingType::kVariant1, + net::PaddingType::kNone}); + naive_proxies.push_back(std::move(naive_proxy)); + } + + base::RunLoop().Run(); + + return EXIT_SUCCESS; +} diff --git a/src/net/tools/naive/naive_proxy_delegate.cc b/src/net/tools/naive/naive_proxy_delegate.cc new file mode 100644 index 0000000000..11f3d76b2f --- /dev/null +++ b/src/net/tools/naive/naive_proxy_delegate.cc @@ -0,0 +1,191 @@ +// Copyright 2020 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#include "net/tools/naive/naive_proxy_delegate.h" + +#include +#include +#include + +#include "base/logging.h" +#include "base/rand_util.h" +#include "base/strings/string_util.h" +#include "net/base/proxy_string_util.h" +#include "net/http/http_request_headers.h" +#include "net/http/http_response_headers.h" +#include "net/third_party/quiche/src/quiche/spdy/core/hpack/hpack_constants.h" + +namespace net { +namespace { +bool g_nonindex_codes_initialized; +uint8_t g_nonindex_codes[17]; +} // namespace + +void InitializeNonindexCodes() { + if (g_nonindex_codes_initialized) + return; + g_nonindex_codes_initialized = true; + unsigned i = 0; + for (const auto& symbol : spdy::HpackHuffmanCodeVector()) { + if (symbol.id >= 0x20 && symbol.id <= 0x7f && symbol.length >= 8) { + g_nonindex_codes[i++] = symbol.id; + if (i >= sizeof(g_nonindex_codes)) + break; + } + } + CHECK(i == sizeof(g_nonindex_codes)); +} + +void FillNonindexHeaderValue(uint64_t unique_bits, char* buf, int len) { + DCHECK(g_nonindex_codes_initialized); + int first = len < 16 ? len : 16; + for (int i = 0; i < first; i++) { + buf[i] = g_nonindex_codes[unique_bits & 0b1111]; + unique_bits >>= 4; + } + for (int i = first; i < len; i++) { + buf[i] = g_nonindex_codes[16]; + } +} + +NaiveProxyDelegate::NaiveProxyDelegate( + const HttpRequestHeaders& extra_headers, + const std::vector& supported_padding_types) + : extra_headers_(extra_headers) { + InitializeNonindexCodes(); + + std::vector padding_type_strs; + for (PaddingType padding_type : supported_padding_types) { + padding_type_strs.push_back(ToString(padding_type)); + } + extra_headers_.SetHeader(kPaddingTypeRequestHeader, + base::JoinString(padding_type_strs, ", ")); +} + +NaiveProxyDelegate::~NaiveProxyDelegate() = default; + +void NaiveProxyDelegate::OnBeforeTunnelRequest( + const ProxyChain& proxy_chain, + size_t chain_index, + HttpRequestHeaders* extra_headers) { + // Not possible to negotiate padding capability given the underlying + // protocols. + if (proxy_chain.is_direct()) + return; + CHECK_EQ(proxy_chain.length(), 1u) << "Multi-hop proxy not supported"; + if (proxy_chain.GetProxyServer(chain_index).is_socks()) + return; + + // Sends client-side padding header regardless of server support + std::string padding(base::RandInt(16, 32), '~'); + FillNonindexHeaderValue(base::RandUint64(), &padding[0], padding.size()); + extra_headers->SetHeader(kPaddingHeader, padding); + + // Enables Fast Open in H2/H3 proxy client socket once the state of server + // padding support is known. + if (padding_type_by_server_[proxy_chain].has_value()) { + extra_headers->SetHeader("fastopen", "1"); + } + extra_headers->MergeFrom(extra_headers_); +} + +std::optional NaiveProxyDelegate::ParsePaddingHeaders( + const HttpResponseHeaders& headers) { + bool has_padding = headers.HasHeader(kPaddingHeader); + std::string padding_type_reply; + bool has_padding_type_reply = + headers.GetNormalizedHeader(kPaddingTypeReplyHeader, &padding_type_reply); + + if (!has_padding_type_reply) { + // Backward compatibility with before kVariant1 when the padding-version + // header does not exist. + if (has_padding) { + return PaddingType::kVariant1; + } else { + return PaddingType::kNone; + } + } + std::optional padding_type = + ParsePaddingType(padding_type_reply); + if (!padding_type.has_value()) { + LOG(ERROR) << "Received invalid padding type: " << padding_type_reply; + } + return padding_type; +} + +Error NaiveProxyDelegate::OnTunnelHeadersReceived( + const ProxyChain& proxy_chain, + size_t chain_index, + const HttpResponseHeaders& response_headers) { + // Not possible to negotiate padding capability given the underlying + // protocols. + if (proxy_chain.is_direct()) + return OK; + CHECK_EQ(proxy_chain.length(), 1u) << "Multi-hop proxy not supported"; + if (proxy_chain.GetProxyServer(chain_index).is_socks()) + return OK; + + // Detects server padding support, even if it changes dynamically. + std::optional new_padding_type = + ParsePaddingHeaders(response_headers); + if (!new_padding_type.has_value()) { + return ERR_INVALID_RESPONSE; + } + std::optional& padding_type = + padding_type_by_server_[proxy_chain]; + if (!padding_type.has_value() || padding_type != new_padding_type) { + LOG(INFO) << proxy_chain.ToDebugString() << " negotiated padding type: " + << ToReadableString(*new_padding_type); + padding_type = new_padding_type; + } + return OK; +} + +std::optional NaiveProxyDelegate::GetProxyServerPaddingType( + const ProxyChain& proxy_chain) { + // Not possible to negotiate padding capability given the underlying + // protocols. + if (proxy_chain.is_direct()) + return PaddingType::kNone; + CHECK_EQ(proxy_chain.length(), 1u) << "Multi-hop proxy not supported"; + if (proxy_chain.GetProxyServer(0).is_socks()) + return PaddingType::kNone; + + return padding_type_by_server_[proxy_chain]; +} + +PaddingDetectorDelegate::PaddingDetectorDelegate( + NaiveProxyDelegate* naive_proxy_delegate, + const ProxyChain& proxy_chain, + ClientProtocol client_protocol) + : naive_proxy_delegate_(naive_proxy_delegate), + proxy_chain_(proxy_chain), + client_protocol_(client_protocol) {} + +PaddingDetectorDelegate::~PaddingDetectorDelegate() = default; + +void PaddingDetectorDelegate::SetClientPaddingType(PaddingType padding_type) { + detected_client_padding_type_ = padding_type; +} + +std::optional PaddingDetectorDelegate::GetClientPaddingType() { + // Not possible to negotiate padding capability given the underlying + // protocols. + if (client_protocol_ == ClientProtocol::kSocks5) { + return PaddingType::kNone; + } else if (client_protocol_ == ClientProtocol::kRedir) { + return PaddingType::kNone; + } + + return detected_client_padding_type_; +} + +std::optional PaddingDetectorDelegate::GetServerPaddingType() { + if (cached_server_padding_type_.has_value()) + return cached_server_padding_type_; + cached_server_padding_type_ = + naive_proxy_delegate_->GetProxyServerPaddingType(proxy_chain_); + return cached_server_padding_type_; +} + +} // namespace net diff --git a/src/net/tools/naive/naive_proxy_delegate.h b/src/net/tools/naive/naive_proxy_delegate.h new file mode 100644 index 0000000000..c7e7c64350 --- /dev/null +++ b/src/net/tools/naive/naive_proxy_delegate.h @@ -0,0 +1,100 @@ +// Copyright 2020 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#ifndef NET_TOOLS_NAIVE_NAIVE_PROXY_DELEGATE_H_ +#define NET_TOOLS_NAIVE_NAIVE_PROXY_DELEGATE_H_ + +#include +#include +#include +#include +#include + +#include "base/strings/string_piece.h" +#include "net/base/net_errors.h" +#include "net/base/proxy_chain.h" +#include "net/base/proxy_delegate.h" +#include "net/http/http_request_headers.h" +#include "net/proxy_resolution/proxy_retry_info.h" +#include "net/tools/naive/naive_protocol.h" +#include "url/gurl.h" + +namespace net { + +void InitializeNonindexCodes(); +// |unique_bits| SHOULD have relatively unique values. +void FillNonindexHeaderValue(uint64_t unique_bits, char* buf, int len); + +class ProxyInfo; + +class NaiveProxyDelegate : public ProxyDelegate { + public: + NaiveProxyDelegate(const HttpRequestHeaders& extra_headers, + const std::vector& supported_padding_types); + ~NaiveProxyDelegate() override; + + void OnResolveProxy(const GURL& url, + const NetworkAnonymizationKey& network_anonymization_key, + const std::string& method, + const ProxyRetryInfoMap& proxy_retry_info, + ProxyInfo* result) override {} + void OnFallback(const ProxyChain& bad_proxy, int net_error) override {} + + // This only affects h2 proxy client socket. + void OnBeforeTunnelRequest(const ProxyChain& proxy_chain, + size_t chain_index, + HttpRequestHeaders* extra_headers) override; + + Error OnTunnelHeadersReceived( + const ProxyChain& proxy_chain, + size_t chain_index, + const HttpResponseHeaders& response_headers) override; + + void SetProxyResolutionService( + ProxyResolutionService* proxy_resolution_service) override {} + + // Returns empty if the padding type has not been negotiated. + std::optional GetProxyServerPaddingType( + const ProxyChain& proxy_chain); + + private: + std::optional ParsePaddingHeaders( + const HttpResponseHeaders& headers); + + HttpRequestHeaders extra_headers_; + + // Empty value means padding type has not been negotiated. + std::map> padding_type_by_server_; +}; + +class ClientPaddingDetectorDelegate { + public: + virtual ~ClientPaddingDetectorDelegate() = default; + + virtual void SetClientPaddingType(PaddingType padding_type) = 0; +}; + +class PaddingDetectorDelegate : public ClientPaddingDetectorDelegate { + public: + PaddingDetectorDelegate(NaiveProxyDelegate* naive_proxy_delegate, + const ProxyChain& proxy_chain, + ClientProtocol client_protocol); + ~PaddingDetectorDelegate() override; + + std::optional GetClientPaddingType(); + std::optional GetServerPaddingType(); + void SetClientPaddingType(PaddingType padding_type) override; + + private: + NaiveProxyDelegate* naive_proxy_delegate_; + const ProxyChain& proxy_chain_; + ClientProtocol client_protocol_; + + std::optional detected_client_padding_type_; + // The result is only cached during one connection, so it's still dynamically + // updated in the following connections after server changes support. + std::optional cached_server_padding_type_; +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_NAIVE_PROXY_DELEGATE_H_ diff --git a/src/net/tools/naive/redirect_resolver.cc b/src/net/tools/naive/redirect_resolver.cc new file mode 100644 index 0000000000..6a24deda61 --- /dev/null +++ b/src/net/tools/naive/redirect_resolver.cc @@ -0,0 +1,243 @@ +// Copyright 2019 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/redirect_resolver.h" + +#include +#include +#include + +#include "base/logging.h" +#include "base/task/single_thread_task_runner.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/url_util.h" +#include "net/dns/dns_names_util.h" +#include "net/dns/dns_query.h" +#include "net/dns/dns_response.h" +#include "net/dns/dns_util.h" +#include "net/socket/datagram_server_socket.h" +#include "third_party/abseil-cpp/absl/types/optional.h" + +namespace { +constexpr int kUdpReadBufferSize = 1024; +constexpr int kResolutionTtl = 60; +constexpr int kResolutionRecycleTime = 60 * 5; + +std::string PackedIPv4ToString(uint32_t addr) { + return net::IPAddress(addr >> 24, addr >> 16, addr >> 8, addr).ToString(); +} +} // namespace + +namespace net { + +Resolution::Resolution() = default; + +Resolution::~Resolution() = default; + +RedirectResolver::RedirectResolver(std::unique_ptr socket, + const IPAddress& range, + size_t prefix) + : socket_(std::move(socket)), + range_(range), + prefix_(prefix), + offset_(0), + buffer_(base::MakeRefCounted(kUdpReadBufferSize)) { + DCHECK(socket_); + // Start accepting connections in next run loop in case when delegate is not + // ready to get callbacks. + base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( + FROM_HERE, base::BindOnce(&RedirectResolver::DoRead, + weak_ptr_factory_.GetWeakPtr())); +} + +RedirectResolver::~RedirectResolver() = default; + +void RedirectResolver::DoRead() { + for (;;) { + int rv = socket_->RecvFrom( + buffer_.get(), kUdpReadBufferSize, &recv_address_, + base::BindOnce(&RedirectResolver::OnRecv, base::Unretained(this))); + if (rv == ERR_IO_PENDING) + return; + rv = HandleReadResult(rv); + if (rv == ERR_IO_PENDING) + return; + if (rv < 0) { + LOG(INFO) << "DoRead: ignoring error " << ErrorToShortString(rv); + } + } +} + +void RedirectResolver::OnRecv(int result) { + int rv; + rv = HandleReadResult(result); + if (rv == ERR_IO_PENDING) + return; + if (rv < 0) { + LOG(INFO) << "OnRecv: ignoring error " << ErrorToShortString(rv); + } + + DoRead(); +} + +void RedirectResolver::OnSend(int result) { + if (result < 0) { + LOG(INFO) << "OnSend: ignoring error " << ErrorToShortString(result); + } + + DoRead(); +} + +int RedirectResolver::HandleReadResult(int result) { + if (result < 0) + return result; + + DnsQuery query(buffer_.get()); + if (!query.Parse(result)) { + LOG(INFO) << "Malformed DNS query from " << recv_address_.ToString(); + return ERR_INVALID_ARGUMENT; + } + + auto name_or = dns_names_util::NetworkToDottedName(query.qname()); + DnsResponse response; + absl::optional query_opt; + query_opt.emplace(query.id(), query.qname(), query.qtype()); + if (!name_or || !IsCanonicalizedHostCompliant(name_or.value())) { + response = + DnsResponse(query.id(), /*is_authoritative=*/false, /*answers=*/{}, + /*authority_records=*/{}, /*additional_records=*/{}, + query_opt, dns_protocol::kRcodeFORMERR); + } else if (query.qtype() != dns_protocol::kTypeA) { + response = + DnsResponse(query.id(), /*is_authoritative=*/false, /*answers=*/{}, + /*authority_records=*/{}, /*additional_records=*/{}, + query_opt, dns_protocol::kRcodeNOTIMP); + } else { + Resolution res; + + const auto& name = name_or.value(); + + auto by_name_lookup = resolution_by_name_.emplace(name, resolutions_.end()); + auto by_name = by_name_lookup.first; + bool has_name = !by_name_lookup.second; + if (has_name) { + auto res_it = by_name->second; + auto by_addr = res_it->by_addr; + uint32_t addr = res_it->addr; + + resolutions_.erase(res_it); + resolutions_.emplace_back(); + res_it = std::prev(resolutions_.end()); + + by_name->second = res_it; + by_addr->second = res_it; + res_it->addr = addr; + res_it->name = name; + res_it->time = base::TimeTicks::Now(); + res_it->by_name = by_name; + res_it->by_addr = by_addr; + } else { + uint32_t addr = (range_.bytes()[0] << 24) | (range_.bytes()[1] << 16) | + (range_.bytes()[2] << 8) | range_.bytes()[3]; + uint32_t subnet = ~0U >> prefix_; + addr &= ~subnet; + addr += offset_; + offset_ = (offset_ + 1) & subnet; + + auto by_addr_lookup = + resolution_by_addr_.emplace(addr, resolutions_.end()); + auto by_addr = by_addr_lookup.first; + bool has_addr = !by_addr_lookup.second; + if (has_addr) { + // Too few available addresses. Overwrites old one. + auto res_it = by_addr->second; + + LOG(INFO) << "Overwrite " << res_it->name << " " + << PackedIPv4ToString(res_it->addr) << " with " << name << " " + << PackedIPv4ToString(addr); + resolution_by_name_.erase(res_it->by_name); + resolutions_.erase(res_it); + resolutions_.emplace_back(); + res_it = std::prev(resolutions_.end()); + + by_name->second = res_it; + by_addr->second = res_it; + res_it->addr = addr; + res_it->name = name; + res_it->time = base::TimeTicks::Now(); + res_it->by_name = by_name; + res_it->by_addr = by_addr; + } else { + LOG(INFO) << "Add " << name << " " << PackedIPv4ToString(addr); + resolutions_.emplace_back(); + auto res_it = std::prev(resolutions_.end()); + + by_name->second = res_it; + by_addr->second = res_it; + res_it->addr = addr; + res_it->name = name; + res_it->time = base::TimeTicks::Now(); + res_it->by_name = by_name; + res_it->by_addr = by_addr; + + // Collects garbage. + auto now = base::TimeTicks::Now(); + for (auto it = resolutions_.begin(); + it != resolutions_.end() && + (now - it->time).InSeconds() > kResolutionRecycleTime;) { + auto next = std::next(it); + LOG(INFO) << "Drop " << it->name << " " + << PackedIPv4ToString(it->addr); + resolution_by_name_.erase(it->by_name); + resolution_by_addr_.erase(it->by_addr); + resolutions_.erase(it); + it = next; + } + } + } + + DnsResourceRecord record; + record.name = name; + record.type = dns_protocol::kTypeA; + record.klass = dns_protocol::kClassIN; + record.ttl = kResolutionTtl; + uint32_t addr = by_name->second->addr; + record.SetOwnedRdata(IPAddressToPackedString( + IPAddress(addr >> 24, addr >> 16, addr >> 8, addr))); + response = DnsResponse(query.id(), /*is_authoritative=*/false, + /*answers=*/{std::move(record)}, + /*authority_records=*/{}, /*additional_records=*/{}, + query_opt); + } + int size = response.io_buffer_size(); + if (size > buffer_->size() || !response.io_buffer()) { + return ERR_NO_BUFFER_SPACE; + } + std::memcpy(buffer_->data(), response.io_buffer()->data(), size); + + return socket_->SendTo( + buffer_.get(), size, recv_address_, + base::BindOnce(&RedirectResolver::OnSend, base::Unretained(this))); +} + +bool RedirectResolver::IsInResolvedRange(const IPAddress& address) const { + if (!address.IsIPv4()) + return false; + return IPAddressMatchesPrefix(address, range_, prefix_); +} + +std::string RedirectResolver::FindNameByAddress( + const IPAddress& address) const { + if (!address.IsIPv4()) + return {}; + uint32_t addr = (address.bytes()[0] << 24) | (address.bytes()[1] << 16) | + (address.bytes()[2] << 8) | address.bytes()[3]; + auto by_addr = resolution_by_addr_.find(addr); + if (by_addr == resolution_by_addr_.end()) + return {}; + return by_addr->second->name; +} + +} // namespace net diff --git a/src/net/tools/naive/redirect_resolver.h b/src/net/tools/naive/redirect_resolver.h new file mode 100644 index 0000000000..3b78c4111e --- /dev/null +++ b/src/net/tools/naive/redirect_resolver.h @@ -0,0 +1,69 @@ +// Copyright 2019 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_REDIRECT_RESOLVER_H_ +#define NET_TOOLS_NAIVE_REDIRECT_RESOLVER_H_ + +#include +#include +#include +#include +#include + +#include "base/memory/ref_counted.h" +#include "base/memory/weak_ptr.h" +#include "base/time/time.h" +#include "net/base/ip_address.h" +#include "net/base/ip_endpoint.h" + +namespace net { + +class DatagramServerSocket; +class IOBufferWithSize; + +struct Resolution { + Resolution(); + ~Resolution(); + + uint32_t addr; + std::string name; + base::TimeTicks time; + std::map::iterator>::iterator by_name; + std::map::iterator>::iterator by_addr; +}; + +class RedirectResolver { + public: + RedirectResolver(std::unique_ptr socket, + const IPAddress& range, + size_t prefix); + ~RedirectResolver(); + RedirectResolver(const RedirectResolver&) = delete; + RedirectResolver& operator=(const RedirectResolver&) = delete; + + bool IsInResolvedRange(const IPAddress& address) const; + std::string FindNameByAddress(const IPAddress& address) const; + + private: + void DoRead(); + void OnRecv(int result); + void OnSend(int result); + int HandleReadResult(int result); + + std::unique_ptr socket_; + IPAddress range_; + size_t prefix_; + uint32_t offset_; + scoped_refptr buffer_; + IPEndPoint recv_address_; + + std::map::iterator> resolution_by_name_; + std::map::iterator> resolution_by_addr_; + std::list resolutions_; + + base::WeakPtrFactory weak_ptr_factory_{this}; +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_REDIRECT_RESOLVER_H_ diff --git a/src/net/tools/naive/socks5_server_socket.cc b/src/net/tools/naive/socks5_server_socket.cc new file mode 100644 index 0000000000..0be1373630 --- /dev/null +++ b/src/net/tools/naive/socks5_server_socket.cc @@ -0,0 +1,660 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/socks5_server_socket.h" + +#include +#include + +#include "base/functional/bind.h" +#include "base/functional/callback_helpers.h" +#include "base/logging.h" +#include "base/sys_byteorder.h" +#include "net/base/ip_address.h" +#include "net/base/net_errors.h" +#include "net/base/sys_addrinfo.h" +#include "net/log/net_log.h" +#include "net/log/net_log_event_type.h" + +namespace net { + +enum SocksCommandType { + kCommandConnect = 0x01, + kCommandBind = 0x02, + kCommandUDPAssociate = 0x03, +}; + +static constexpr unsigned int kGreetReadHeaderSize = 2; +static constexpr unsigned int kAuthReadHeaderSize = 2; +static constexpr unsigned int kReadHeaderSize = 5; +static constexpr char kSOCKS5Version = '\x05'; +static constexpr char kSOCKS5Reserved = '\x00'; +static constexpr char kAuthMethodNone = '\x00'; +static constexpr char kAuthMethodUserPass = '\x02'; +static constexpr char kAuthMethodNoAcceptable = '\xff'; +static constexpr char kSubnegotiationVersion = '\x01'; +static constexpr char kAuthStatusSuccess = '\x00'; +static constexpr char kAuthStatusFailure = '\xff'; +static constexpr char kReplySuccess = '\x00'; +static constexpr char kReplyCommandNotSupported = '\x07'; + +static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4"); +static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6"); + +Socks5ServerSocket::Socks5ServerSocket( + std::unique_ptr transport_socket, + const std::string& user, + const std::string& pass, + const NetworkTrafficAnnotationTag& traffic_annotation) + : io_callback_(base::BindRepeating(&Socks5ServerSocket::OnIOComplete, + base::Unretained(this))), + transport_(std::move(transport_socket)), + next_state_(STATE_NONE), + completed_handshake_(false), + bytes_sent_(0), + was_ever_used_(false), + user_(user), + pass_(pass), + net_log_(transport_->NetLog()), + traffic_annotation_(traffic_annotation) {} + +Socks5ServerSocket::~Socks5ServerSocket() { + Disconnect(); +} + +const HostPortPair& Socks5ServerSocket::request_endpoint() const { + return request_endpoint_; +} + +int Socks5ServerSocket::Connect(CompletionOnceCallback callback) { + DCHECK(transport_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + + // If already connected, then just return OK. + if (completed_handshake_) + return OK; + + net_log_.BeginEvent(NetLogEventType::SOCKS5_CONNECT); + + next_state_ = STATE_GREET_READ; + buffer_.clear(); + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + user_callback_ = std::move(callback); + } else { + net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_CONNECT, rv); + } + return rv; +} + +void Socks5ServerSocket::Disconnect() { + completed_handshake_ = false; + transport_->Disconnect(); + + // Reset other states to make sure they aren't mistakenly used later. + // These are the states initialized by Connect(). + next_state_ = STATE_NONE; + user_callback_.Reset(); +} + +bool Socks5ServerSocket::IsConnected() const { + return completed_handshake_ && transport_->IsConnected(); +} + +bool Socks5ServerSocket::IsConnectedAndIdle() const { + return completed_handshake_ && transport_->IsConnectedAndIdle(); +} + +const NetLogWithSource& Socks5ServerSocket::NetLog() const { + return net_log_; +} + +bool Socks5ServerSocket::WasEverUsed() const { + return was_ever_used_; +} + +NextProto Socks5ServerSocket::GetNegotiatedProtocol() const { + if (transport_) { + return transport_->GetNegotiatedProtocol(); + } + NOTREACHED(); + return kProtoUnknown; +} + +bool Socks5ServerSocket::GetSSLInfo(SSLInfo* ssl_info) { + if (transport_) { + return transport_->GetSSLInfo(ssl_info); + } + NOTREACHED(); + return false; +} + +int64_t Socks5ServerSocket::GetTotalReceivedBytes() const { + return transport_->GetTotalReceivedBytes(); +} + +void Socks5ServerSocket::ApplySocketTag(const SocketTag& tag) { + return transport_->ApplySocketTag(tag); +} + +// Read is called by the transport layer above to read. This can only be done +// if the SOCKS handshake is complete. +int Socks5ServerSocket::Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + DCHECK(callback); + + int rv = transport_->Read( + buf, buf_len, + base::BindOnce(&Socks5ServerSocket::OnReadWriteComplete, + base::Unretained(this), std::move(callback))); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +// Write is called by the transport layer. This can only be done if the +// SOCKS handshake is complete. +int Socks5ServerSocket::Write( + IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + DCHECK(callback); + + int rv = transport_->Write( + buf, buf_len, + base::BindOnce(&Socks5ServerSocket::OnReadWriteComplete, + base::Unretained(this), std::move(callback)), + traffic_annotation); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +int Socks5ServerSocket::SetReceiveBufferSize(int32_t size) { + return transport_->SetReceiveBufferSize(size); +} + +int Socks5ServerSocket::SetSendBufferSize(int32_t size) { + return transport_->SetSendBufferSize(size); +} + +void Socks5ServerSocket::DoCallback(int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(user_callback_); + + // Since Run() may result in Read being called, + // clear user_callback_ up front. + std::move(user_callback_).Run(result); +} + +void Socks5ServerSocket::OnIOComplete(int result) { + DCHECK_NE(STATE_NONE, next_state_); + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) { + net_log_.EndEvent(NetLogEventType::SOCKS5_CONNECT); + DoCallback(rv); + } +} + +void Socks5ServerSocket::OnReadWriteComplete(CompletionOnceCallback callback, + int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(callback); + + if (result > 0) + was_ever_used_ = true; + std::move(callback).Run(result); +} + +int Socks5ServerSocket::DoLoop(int last_io_result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_GREET_READ: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_READ); + rv = DoGreetRead(); + break; + case STATE_GREET_READ_COMPLETE: + rv = DoGreetReadComplete(rv); + net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_READ, + rv); + break; + case STATE_GREET_WRITE: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_WRITE); + rv = DoGreetWrite(); + break; + case STATE_GREET_WRITE_COMPLETE: + rv = DoGreetWriteComplete(rv); + net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_WRITE, + rv); + break; + case STATE_AUTH_READ: + DCHECK_EQ(OK, rv); + rv = DoAuthRead(); + break; + case STATE_AUTH_READ_COMPLETE: + rv = DoAuthReadComplete(rv); + break; + case STATE_AUTH_WRITE: + DCHECK_EQ(OK, rv); + rv = DoAuthWrite(); + break; + case STATE_AUTH_WRITE_COMPLETE: + rv = DoAuthWriteComplete(rv); + break; + case STATE_HANDSHAKE_READ: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_READ); + rv = DoHandshakeRead(); + break; + case STATE_HANDSHAKE_READ_COMPLETE: + rv = DoHandshakeReadComplete(rv); + net_log_.EndEventWithNetErrorCode( + NetLogEventType::SOCKS5_HANDSHAKE_READ, rv); + break; + case STATE_HANDSHAKE_WRITE: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_WRITE); + rv = DoHandshakeWrite(); + break; + case STATE_HANDSHAKE_WRITE_COMPLETE: + rv = DoHandshakeWriteComplete(rv); + net_log_.EndEventWithNetErrorCode( + NetLogEventType::SOCKS5_HANDSHAKE_WRITE, rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +int Socks5ServerSocket::DoGreetRead() { + next_state_ = STATE_GREET_READ_COMPLETE; + + if (buffer_.empty()) { + read_header_size_ = kGreetReadHeaderSize; + } + + int handshake_buf_len = read_header_size_ - buffer_.size(); + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = base::MakeRefCounted(handshake_buf_len); + return transport_->Read(handshake_buf_.get(), handshake_buf_len, + io_callback_); +} + +int Socks5ServerSocket::DoGreetReadComplete(int result) { + if (result < 0) + return result; + + if (result == 0) { + net_log_.AddEvent( + NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING); + return ERR_SOCKS_CONNECTION_FAILED; + } + + buffer_.append(handshake_buf_->data(), result); + + // When the first few bytes are read, check how many more are required + // and accordingly increase them + if (buffer_.size() == kGreetReadHeaderSize) { + if (buffer_[0] != kSOCKS5Version) { + net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION, + "version", buffer_[0]); + return ERR_SOCKS_CONNECTION_FAILED; + } + int nmethods = buffer_[1]; + if (nmethods == 0) { + net_log_.AddEvent(NetLogEventType::SOCKS_NO_REQUESTED_AUTH); + return ERR_SOCKS_CONNECTION_FAILED; + } + + read_header_size_ += nmethods; + next_state_ = STATE_GREET_READ; + return OK; + } + + if (buffer_.size() == read_header_size_) { + int nmethods = buffer_[1]; + char expected_method = kAuthMethodNone; + if (!user_.empty() || !pass_.empty()) { + expected_method = kAuthMethodUserPass; + } + void* match = + std::memchr(&buffer_[kGreetReadHeaderSize], expected_method, nmethods); + if (match) { + auth_method_ = expected_method; + } else { + auth_method_ = kAuthMethodNoAcceptable; + } + buffer_.clear(); + next_state_ = STATE_GREET_WRITE; + return OK; + } + + next_state_ = STATE_GREET_READ; + return OK; +} + +int Socks5ServerSocket::DoGreetWrite() { + if (buffer_.empty()) { + const char write_data[] = {kSOCKS5Version, auth_method_}; + buffer_ = std::string(write_data, std::size(write_data)); + bytes_sent_ = 0; + } + + next_state_ = STATE_GREET_WRITE_COMPLETE; + int handshake_buf_len = buffer_.size() - bytes_sent_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = base::MakeRefCounted(handshake_buf_len); + std::memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_], + handshake_buf_len); + return transport_->Write(handshake_buf_.get(), handshake_buf_len, + io_callback_, traffic_annotation_); +} + +int Socks5ServerSocket::DoGreetWriteComplete(int result) { + if (result < 0) + return result; + + bytes_sent_ += result; + if (bytes_sent_ == buffer_.size()) { + buffer_.clear(); + if (auth_method_ == kAuthMethodNone) { + next_state_ = STATE_HANDSHAKE_READ; + } else if (auth_method_ == kAuthMethodUserPass) { + next_state_ = STATE_AUTH_READ; + } else { + net_log_.AddEvent(NetLogEventType::SOCKS_NO_ACCEPTABLE_AUTH); + return ERR_SOCKS_CONNECTION_FAILED; + } + } else { + next_state_ = STATE_GREET_WRITE; + } + return OK; +} + +int Socks5ServerSocket::DoAuthRead() { + next_state_ = STATE_AUTH_READ_COMPLETE; + + if (buffer_.empty()) { + read_header_size_ = kAuthReadHeaderSize; + } + + int handshake_buf_len = read_header_size_ - buffer_.size(); + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = base::MakeRefCounted(handshake_buf_len); + return transport_->Read(handshake_buf_.get(), handshake_buf_len, + io_callback_); +} + +int Socks5ServerSocket::DoAuthReadComplete(int result) { + if (result < 0) + return result; + + if (result == 0) { + return ERR_SOCKS_CONNECTION_FAILED; + } + + buffer_.append(handshake_buf_->data(), result); + + // When the first few bytes are read, check how many more are required + // and accordingly increase them + if (buffer_.size() == kAuthReadHeaderSize) { + if (buffer_[0] != kSubnegotiationVersion) { + net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION, + "version", buffer_[0]); + return ERR_SOCKS_CONNECTION_FAILED; + } + int username_len = buffer_[1]; + read_header_size_ += username_len + 1; + next_state_ = STATE_AUTH_READ; + return OK; + } + + if (buffer_.size() == read_header_size_) { + int username_len = buffer_[1]; + int password_len = buffer_[kAuthReadHeaderSize + username_len]; + size_t password_offset = kAuthReadHeaderSize + username_len + 1; + if (buffer_.size() == password_offset && password_len != 0) { + read_header_size_ += password_len; + next_state_ = STATE_AUTH_READ; + return OK; + } + + if (buffer_.compare(kAuthReadHeaderSize, username_len, user_) == 0 && + buffer_.compare(password_offset, password_len, pass_) == 0) { + auth_status_ = kAuthStatusSuccess; + } else { + auth_status_ = kAuthStatusFailure; + } + buffer_.clear(); + next_state_ = STATE_AUTH_WRITE; + return OK; + } + + next_state_ = STATE_AUTH_READ; + return OK; +} + +int Socks5ServerSocket::DoAuthWrite() { + if (buffer_.empty()) { + const char write_data[] = {kSubnegotiationVersion, auth_status_}; + buffer_ = std::string(write_data, std::size(write_data)); + bytes_sent_ = 0; + } + + next_state_ = STATE_AUTH_WRITE_COMPLETE; + int handshake_buf_len = buffer_.size() - bytes_sent_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = base::MakeRefCounted(handshake_buf_len); + std::memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_], + handshake_buf_len); + return transport_->Write(handshake_buf_.get(), handshake_buf_len, + io_callback_, traffic_annotation_); +} + +int Socks5ServerSocket::DoAuthWriteComplete(int result) { + if (result < 0) + return result; + + bytes_sent_ += result; + if (bytes_sent_ == buffer_.size()) { + buffer_.clear(); + if (auth_status_ == kAuthStatusSuccess) { + next_state_ = STATE_HANDSHAKE_READ; + } else { + return ERR_SOCKS_CONNECTION_FAILED; + } + } else { + next_state_ = STATE_AUTH_WRITE; + } + return OK; +} + +int Socks5ServerSocket::DoHandshakeRead() { + next_state_ = STATE_HANDSHAKE_READ_COMPLETE; + + if (buffer_.empty()) { + read_header_size_ = kReadHeaderSize; + } + + int handshake_buf_len = read_header_size_ - buffer_.size(); + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = base::MakeRefCounted(handshake_buf_len); + return transport_->Read(handshake_buf_.get(), handshake_buf_len, + io_callback_); +} + +int Socks5ServerSocket::DoHandshakeReadComplete(int result) { + if (result < 0) + return result; + + // The underlying socket closed unexpectedly. + if (result == 0) { + net_log_.AddEvent( + NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE); + return ERR_SOCKS_CONNECTION_FAILED; + } + + buffer_.append(handshake_buf_->data(), result); + + // When the first few bytes are read, check how many more are required + // and accordingly increase them + if (buffer_.size() == kReadHeaderSize) { + if (buffer_[0] != kSOCKS5Version || buffer_[2] != kSOCKS5Reserved) { + net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION, + "version", buffer_[0]); + return ERR_SOCKS_CONNECTION_FAILED; + } + SocksCommandType command = static_cast(buffer_[1]); + if (command == kCommandConnect) { + // The proxy replies with success immediately without first connecting + // to the requested endpoint. + reply_ = kReplySuccess; + } else if (command == kCommandBind || command == kCommandUDPAssociate) { + reply_ = kReplyCommandNotSupported; + } else { + net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_COMMAND, + "commmand", buffer_[1]); + return ERR_SOCKS_CONNECTION_FAILED; + } + + // We check the type of IP/Domain the server returns and accordingly + // increase the size of the request. For domains, we need to read the + // size of the domain, so the initial request size is upto the domain + // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is + // read, we substract 1 byte from the additional request size. + address_type_ = static_cast(buffer_[3]); + if (address_type_ == kEndPointDomain) { + address_size_ = static_cast(buffer_[4]); + if (address_size_ == 0) { + net_log_.AddEvent(NetLogEventType::SOCKS_ZERO_LENGTH_DOMAIN); + return ERR_SOCKS_CONNECTION_FAILED; + } + } else if (address_type_ == kEndPointResolvedIPv4) { + address_size_ = sizeof(struct in_addr); + --read_header_size_; + } else if (address_type_ == kEndPointResolvedIPv6) { + address_size_ = sizeof(struct in6_addr); + --read_header_size_; + } else { + // Aborts connection on unspecified address type. + net_log_.AddEventWithIntParams( + NetLogEventType::SOCKS_UNKNOWN_ADDRESS_TYPE, "address_type", + buffer_[3]); + return ERR_SOCKS_CONNECTION_FAILED; + } + + read_header_size_ += address_size_ + sizeof(uint16_t); + next_state_ = STATE_HANDSHAKE_READ; + return OK; + } + + // When the final bytes are read, setup handshake. + if (buffer_.size() == read_header_size_) { + size_t port_start = read_header_size_ - sizeof(uint16_t); + uint16_t port_net; + std::memcpy(&port_net, &buffer_[port_start], sizeof(uint16_t)); + uint16_t port_host = base::NetToHost16(port_net); + + size_t address_start = port_start - address_size_; + if (address_type_ == kEndPointDomain) { + std::string domain(&buffer_[address_start], address_size_); + request_endpoint_ = HostPortPair(domain, port_host); + } else { + IPAddress ip_addr( + reinterpret_cast(&buffer_[address_start]), + address_size_); + IPEndPoint endpoint(ip_addr, port_host); + request_endpoint_ = HostPortPair::FromIPEndPoint(endpoint); + } + buffer_.clear(); + next_state_ = STATE_HANDSHAKE_WRITE; + return OK; + } + + next_state_ = STATE_HANDSHAKE_READ; + return OK; +} + +// Writes the SOCKS handshake data to the underlying socket connection. +int Socks5ServerSocket::DoHandshakeWrite() { + next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; + + if (buffer_.empty()) { + const char write_data[] = { + // clang-format off + kSOCKS5Version, + reply_, + kSOCKS5Reserved, + kEndPointResolvedIPv4, + 0x00, 0x00, 0x00, 0x00, // BND.ADDR + 0x00, 0x00, // BND.PORT + // clang-format on + }; + buffer_ = std::string(write_data, std::size(write_data)); + bytes_sent_ = 0; + } + + int handshake_buf_len = buffer_.size() - bytes_sent_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = base::MakeRefCounted(handshake_buf_len); + std::memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], handshake_buf_len); + return transport_->Write(handshake_buf_.get(), handshake_buf_len, + io_callback_, traffic_annotation_); +} + +int Socks5ServerSocket::DoHandshakeWriteComplete(int result) { + if (result < 0) + return result; + + // We ignore the case when result is 0, since the underlying Write + // may return spurious writes while waiting on the socket. + + bytes_sent_ += result; + if (bytes_sent_ == buffer_.size()) { + buffer_.clear(); + if (reply_ == kReplySuccess) { + completed_handshake_ = true; + next_state_ = STATE_NONE; + } else { + net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_SERVER_ERROR, + "error_code", reply_); + return ERR_SOCKS_CONNECTION_FAILED; + } + } else { + next_state_ = STATE_HANDSHAKE_WRITE; + } + + return OK; +} + +int Socks5ServerSocket::GetPeerAddress(IPEndPoint* address) const { + return transport_->GetPeerAddress(address); +} + +int Socks5ServerSocket::GetLocalAddress(IPEndPoint* address) const { + return transport_->GetLocalAddress(address); +} + +} // namespace net diff --git a/src/net/tools/naive/socks5_server_socket.h b/src/net/tools/naive/socks5_server_socket.h new file mode 100644 index 0000000000..c17e74da19 --- /dev/null +++ b/src/net/tools/naive/socks5_server_socket.h @@ -0,0 +1,165 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_ +#define NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_ + +#include +#include +#include +#include + +#include "base/memory/scoped_refptr.h" +#include "net/base/completion_once_callback.h" +#include "net/base/completion_repeating_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/log/net_log_with_source.h" +#include "net/socket/connection_attempts.h" +#include "net/socket/next_proto.h" +#include "net/socket/stream_socket.h" +#include "net/ssl/ssl_info.h" + +namespace net { +struct NetworkTrafficAnnotationTag; + +// This StreamSocket is used to setup a SOCKSv5 handshake with a socks client. +// Currently no SOCKSv5 authentication is supported. +class Socks5ServerSocket : public StreamSocket { + public: + Socks5ServerSocket(std::unique_ptr transport_socket, + const std::string& user, + const std::string& pass, + const NetworkTrafficAnnotationTag& traffic_annotation); + + // On destruction Disconnect() is called. + ~Socks5ServerSocket() override; + + Socks5ServerSocket(const Socks5ServerSocket&) = delete; + Socks5ServerSocket& operator=(const Socks5ServerSocket&) = delete; + + const HostPortPair& request_endpoint() const; + + // StreamSocket implementation. + + // Does the SOCKS handshake and completes the protocol. + int Connect(CompletionOnceCallback callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + const NetLogWithSource& NetLog() const override; + bool WasEverUsed() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; + int64_t GetTotalReceivedBytes() const override; + void ApplySocketTag(const SocketTag& tag) override; + + // Socket implementation. + int Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) override; + int Write(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) override; + + int SetReceiveBufferSize(int32_t size) override; + int SetSendBufferSize(int32_t size) override; + + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + + private: + enum State { + STATE_GREET_READ, + STATE_GREET_READ_COMPLETE, + STATE_GREET_WRITE, + STATE_GREET_WRITE_COMPLETE, + STATE_AUTH_READ, + STATE_AUTH_READ_COMPLETE, + STATE_AUTH_WRITE, + STATE_AUTH_WRITE_COMPLETE, + STATE_HANDSHAKE_WRITE, + STATE_HANDSHAKE_WRITE_COMPLETE, + STATE_HANDSHAKE_READ, + STATE_HANDSHAKE_READ_COMPLETE, + STATE_NONE, + }; + + // Addressing type that can be specified in requests or responses. + enum SocksEndPointAddressType { + kEndPointDomain = 0x03, + kEndPointResolvedIPv4 = 0x01, + kEndPointResolvedIPv6 = 0x04, + }; + + void DoCallback(int result); + void OnIOComplete(int result); + void OnReadWriteComplete(CompletionOnceCallback callback, int result); + + int DoLoop(int last_io_result); + int DoGreetRead(); + int DoGreetReadComplete(int result); + int DoGreetWrite(); + int DoGreetWriteComplete(int result); + int DoAuthRead(); + int DoAuthReadComplete(int result); + int DoAuthWrite(); + int DoAuthWriteComplete(int result); + int DoHandshakeRead(); + int DoHandshakeReadComplete(int result); + int DoHandshakeWrite(); + int DoHandshakeWriteComplete(int result); + + CompletionRepeatingCallback io_callback_; + + // Stores the underlying socket. + std::unique_ptr transport_; + + State next_state_; + + // Stores the callback to the layer above, called on completing Connect(). + CompletionOnceCallback user_callback_; + + // This IOBuffer is used by the class to read and write + // SOCKS handshake data. The length contains the expected size to + // read or write. + scoped_refptr handshake_buf_; + + // While writing, this buffer stores the complete write handshake data. + // While reading, it stores the handshake information received so far. + std::string buffer_; + + // This becomes true when the SOCKS handshake has completed and the + // overlying connection is free to communicate. + bool completed_handshake_; + + // Contains the bytes sent by the SOCKS handshake. + size_t bytes_sent_; + + size_t read_header_size_; + + bool was_ever_used_; + + SocksEndPointAddressType address_type_; + int address_size_; + + std::string user_; + std::string pass_; + char auth_method_; + char auth_status_; + char reply_; + + HostPortPair request_endpoint_; + + NetLogWithSource net_log_; + + // Traffic annotation for socket control. + const NetworkTrafficAnnotationTag& traffic_annotation_; +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_