diff --git a/src/net/BUILD.gn b/src/net/BUILD.gn index 111980f51b..72f6f601f0 100644 --- a/src/net/BUILD.gn +++ b/src/net/BUILD.gn @@ -1753,3 +1753,35 @@ static_library("preload_decoder") { ] deps = [ "//base" ] } + +executable("naive") { + sources = [ + "tools/naive/naive_connection.cc", + "tools/naive/naive_connection.h", + "tools/naive/naive_proxy.cc", + "tools/naive/naive_proxy.h", + "tools/naive/naive_proxy_bin.cc", + "tools/naive/naive_proxy_delegate.h", + "tools/naive/naive_proxy_delegate.cc", + "tools/naive/http_proxy_socket.cc", + "tools/naive/http_proxy_socket.h", + "tools/naive/redirect_resolver.h", + "tools/naive/redirect_resolver.cc", + "tools/naive/socks5_server_socket.cc", + "tools/naive/socks5_server_socket.h", + "tools/naive/partition_alloc_support.cc", + "tools/naive/partition_alloc_support.h", + ] + + deps = [ + ":net", + "//base", + "//build/win:default_exe_manifest", + "//components/version_info:version_info", + "//url", + ] + + if (is_apple) { + deps += [ "//base/allocator:early_zone_registration_mac" ] + } +} diff --git a/src/net/log/net_log_event_type_list.h b/src/net/log/net_log_event_type_list.h index 77dbfe9bd6..c69db6e82d 100644 --- a/src/net/log/net_log_event_type_list.h +++ b/src/net/log/net_log_event_type_list.h @@ -488,6 +488,11 @@ EVENT_TYPE(SOCKS_HOSTNAME_TOO_BIG) EVENT_TYPE(SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING) EVENT_TYPE(SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE) +EVENT_TYPE(SOCKS_NO_REQUESTED_AUTH) +EVENT_TYPE(SOCKS_NO_ACCEPTABLE_AUTH) +EVENT_TYPE(SOCKS_ZERO_LENGTH_DOMAIN) +EVENT_TYPE(SOCKS_UNEXPECTED_COMMAND) + // This event indicates that a bad version number was received in the // proxy server's response. The extra parameters show its value: // { diff --git a/src/net/tools/naive/http_proxy_socket.cc b/src/net/tools/naive/http_proxy_socket.cc new file mode 100644 index 0000000000..4503a0476b --- /dev/null +++ b/src/net/tools/naive/http_proxy_socket.cc @@ -0,0 +1,356 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/http_proxy_socket.h" + +#include +#include + +#include "base/functional/bind.h" +#include "base/functional/callback_helpers.h" +#include "base/logging.h" +#include "base/rand_util.h" +#include "base/sys_byteorder.h" +#include "net/base/ip_address.h" +#include "net/base/net_errors.h" +#include "net/http/http_request_headers.h" +#include "net/log/net_log.h" +#include "net/third_party/quiche/src/quiche/spdy/core/hpack/hpack_constants.h" +#include "net/tools/naive/naive_proxy_delegate.h" + +namespace net { + +namespace { +constexpr int kBufferSize = 64 * 1024; +constexpr size_t kMaxHeaderSize = 64 * 1024; +constexpr char kResponseHeader[] = "HTTP/1.1 200 OK\r\nPadding: "; +constexpr int kResponseHeaderSize = sizeof(kResponseHeader) - 1; +// A plain 200 is 10 bytes. Expected 48 bytes. "Padding" uses up 7 bytes. +constexpr int kMinPaddingSize = 30; +constexpr int kMaxPaddingSize = kMinPaddingSize + 32; +} // namespace + +HttpProxySocket::HttpProxySocket( + std::unique_ptr transport_socket, + ClientPaddingDetectorDelegate* padding_detector_delegate, + const NetworkTrafficAnnotationTag& traffic_annotation) + : io_callback_(base::BindRepeating(&HttpProxySocket::OnIOComplete, + base::Unretained(this))), + transport_(std::move(transport_socket)), + padding_detector_delegate_(padding_detector_delegate), + next_state_(STATE_NONE), + completed_handshake_(false), + was_ever_used_(false), + header_write_size_(-1), + net_log_(transport_->NetLog()), + traffic_annotation_(traffic_annotation) {} + +HttpProxySocket::~HttpProxySocket() { + Disconnect(); +} + +const HostPortPair& HttpProxySocket::request_endpoint() const { + return request_endpoint_; +} + +int HttpProxySocket::Connect(CompletionOnceCallback callback) { + DCHECK(transport_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + + // If already connected, then just return OK. + if (completed_handshake_) + return OK; + + next_state_ = STATE_HEADER_READ; + buffer_.clear(); + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + user_callback_ = std::move(callback); + } + return rv; +} + +void HttpProxySocket::Disconnect() { + completed_handshake_ = false; + transport_->Disconnect(); + + // Reset other states to make sure they aren't mistakenly used later. + // These are the states initialized by Connect(). + next_state_ = STATE_NONE; + user_callback_.Reset(); +} + +bool HttpProxySocket::IsConnected() const { + return completed_handshake_ && transport_->IsConnected(); +} + +bool HttpProxySocket::IsConnectedAndIdle() const { + return completed_handshake_ && transport_->IsConnectedAndIdle(); +} + +const NetLogWithSource& HttpProxySocket::NetLog() const { + return net_log_; +} + +bool HttpProxySocket::WasEverUsed() const { + return was_ever_used_; +} + +bool HttpProxySocket::WasAlpnNegotiated() const { + if (transport_) { + return transport_->WasAlpnNegotiated(); + } + NOTREACHED(); + return false; +} + +NextProto HttpProxySocket::GetNegotiatedProtocol() const { + if (transport_) { + return transport_->GetNegotiatedProtocol(); + } + NOTREACHED(); + return kProtoUnknown; +} + +bool HttpProxySocket::GetSSLInfo(SSLInfo* ssl_info) { + if (transport_) { + return transport_->GetSSLInfo(ssl_info); + } + NOTREACHED(); + return false; +} + +int64_t HttpProxySocket::GetTotalReceivedBytes() const { + return transport_->GetTotalReceivedBytes(); +} + +void HttpProxySocket::ApplySocketTag(const SocketTag& tag) { + return transport_->ApplySocketTag(tag); +} + +// Read is called by the transport layer above to read. This can only be done +// if the HTTP header is complete. +int HttpProxySocket::Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + DCHECK(callback); + + if (!buffer_.empty()) { + was_ever_used_ = true; + int data_len = buffer_.size(); + if (data_len <= buf_len) { + std::memcpy(buf->data(), buffer_.data(), data_len); + buffer_.clear(); + return data_len; + } else { + std::memcpy(buf->data(), buffer_.data(), buf_len); + buffer_ = buffer_.substr(buf_len); + return buf_len; + } + } + + int rv = transport_->Read( + buf, buf_len, + base::BindOnce(&HttpProxySocket::OnReadWriteComplete, + base::Unretained(this), std::move(callback))); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +// Write is called by the transport layer. This can only be done if the +// SOCKS handshake is complete. +int HttpProxySocket::Write( + IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + DCHECK(callback); + + int rv = transport_->Write( + buf, buf_len, + base::BindOnce(&HttpProxySocket::OnReadWriteComplete, + base::Unretained(this), std::move(callback)), + traffic_annotation); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +int HttpProxySocket::SetReceiveBufferSize(int32_t size) { + return transport_->SetReceiveBufferSize(size); +} + +int HttpProxySocket::SetSendBufferSize(int32_t size) { + return transport_->SetSendBufferSize(size); +} + +void HttpProxySocket::DoCallback(int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(user_callback_); + + // Since Run() may result in Read being called, + // clear user_callback_ up front. + std::move(user_callback_).Run(result); +} + +void HttpProxySocket::OnIOComplete(int result) { + DCHECK_NE(STATE_NONE, next_state_); + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) { + DoCallback(rv); + } +} + +void HttpProxySocket::OnReadWriteComplete(CompletionOnceCallback callback, + int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(callback); + + if (result > 0) + was_ever_used_ = true; + std::move(callback).Run(result); +} + +int HttpProxySocket::DoLoop(int last_io_result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_HEADER_READ: + DCHECK_EQ(OK, rv); + rv = DoHeaderRead(); + break; + case STATE_HEADER_READ_COMPLETE: + rv = DoHeaderReadComplete(rv); + break; + case STATE_HEADER_WRITE: + DCHECK_EQ(OK, rv); + rv = DoHeaderWrite(); + break; + case STATE_HEADER_WRITE_COMPLETE: + rv = DoHeaderWriteComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +int HttpProxySocket::DoHeaderRead() { + next_state_ = STATE_HEADER_READ_COMPLETE; + + handshake_buf_ = base::MakeRefCounted(kBufferSize); + return transport_->Read(handshake_buf_.get(), kBufferSize, io_callback_); +} + +int HttpProxySocket::DoHeaderReadComplete(int result) { + if (result < 0) + return result; + + if (result == 0) { + return ERR_CONNECTION_CLOSED; + } + + buffer_.append(handshake_buf_->data(), result); + if (buffer_.size() > kMaxHeaderSize) { + return ERR_MSG_TOO_BIG; + } + + auto header_end = buffer_.find("\r\n\r\n"); + if (header_end == std::string::npos) { + next_state_ = STATE_HEADER_READ; + return OK; + } + + // HttpProxyClientSocket uses CONNECT for all endpoints. + auto first_line_end = buffer_.find("\r\n"); + auto first_space = buffer_.find(' '); + if (first_space == std::string::npos || first_space + 1 >= first_line_end) { + return ERR_INVALID_ARGUMENT; + } + if (buffer_.compare(0, first_space, "CONNECT") != 0) { + return ERR_INVALID_ARGUMENT; + } + auto second_space = buffer_.find(' ', first_space + 1); + if (second_space == std::string::npos || second_space >= first_line_end) { + return ERR_INVALID_ARGUMENT; + } + request_endpoint_ = HostPortPair::FromString( + buffer_.substr(first_space + 1, second_space - (first_space + 1))); + + auto second_line = first_line_end + 2; + HttpRequestHeaders headers; + std::string headers_str; + if (second_line < header_end) { + headers_str = buffer_.substr(second_line, header_end - second_line); + headers.AddHeadersFromString(headers_str); + } + if (headers.HasHeader("padding")) { + padding_detector_delegate_->SetClientPaddingSupport( + PaddingSupport::kCapable); + } else { + padding_detector_delegate_->SetClientPaddingSupport( + PaddingSupport::kIncapable); + } + + buffer_ = buffer_.substr(header_end + 4); + + next_state_ = STATE_HEADER_WRITE; + return OK; +} + +int HttpProxySocket::DoHeaderWrite() { + next_state_ = STATE_HEADER_WRITE_COMPLETE; + + // Adds padding. + int padding_size = base::RandInt(kMinPaddingSize, kMaxPaddingSize); + header_write_size_ = kResponseHeaderSize + padding_size + 4; + handshake_buf_ = base::MakeRefCounted(header_write_size_); + char* p = handshake_buf_->data(); + std::memcpy(p, kResponseHeader, kResponseHeaderSize); + FillNonindexHeaderValue(base::RandUint64(), p + kResponseHeaderSize, + padding_size); + std::memcpy(p + kResponseHeaderSize + padding_size, "\r\n\r\n", 4); + + return transport_->Write(handshake_buf_.get(), header_write_size_, + io_callback_, traffic_annotation_); +} + +int HttpProxySocket::DoHeaderWriteComplete(int result) { + if (result < 0) + return result; + + if (result != header_write_size_) { + return ERR_FAILED; + } + + completed_handshake_ = true; + next_state_ = STATE_NONE; + return OK; +} + +int HttpProxySocket::GetPeerAddress(IPEndPoint* address) const { + return transport_->GetPeerAddress(address); +} + +int HttpProxySocket::GetLocalAddress(IPEndPoint* address) const { + return transport_->GetLocalAddress(address); +} + +} // namespace net diff --git a/src/net/tools/naive/http_proxy_socket.h b/src/net/tools/naive/http_proxy_socket.h new file mode 100644 index 0000000000..aa836b41fb --- /dev/null +++ b/src/net/tools/naive/http_proxy_socket.h @@ -0,0 +1,122 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_HTTP_PROXY_SOCKET_H_ +#define NET_TOOLS_NAIVE_HTTP_PROXY_SOCKET_H_ + +#include +#include +#include +#include + +#include "base/memory/scoped_refptr.h" +#include "net/base/completion_once_callback.h" +#include "net/base/completion_repeating_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/log/net_log_with_source.h" +#include "net/socket/connection_attempts.h" +#include "net/socket/next_proto.h" +#include "net/socket/stream_socket.h" +#include "net/ssl/ssl_info.h" + +namespace net { +struct NetworkTrafficAnnotationTag; +class ClientPaddingDetectorDelegate; + +// This StreamSocket is used to setup a HTTP CONNECT tunnel. +class HttpProxySocket : public StreamSocket { + public: + HttpProxySocket(std::unique_ptr transport_socket, + ClientPaddingDetectorDelegate* padding_detector_delegate, + const NetworkTrafficAnnotationTag& traffic_annotation); + HttpProxySocket(const HttpProxySocket&) = delete; + HttpProxySocket& operator=(const HttpProxySocket&) = delete; + + // On destruction Disconnect() is called. + ~HttpProxySocket() override; + + const HostPortPair& request_endpoint() const; + + // StreamSocket implementation. + + int Connect(CompletionOnceCallback callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + const NetLogWithSource& NetLog() const override; + bool WasEverUsed() const override; + bool WasAlpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; + int64_t GetTotalReceivedBytes() const override; + void ApplySocketTag(const SocketTag& tag) override; + + // Socket implementation. + int Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) override; + int Write(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) override; + + int SetReceiveBufferSize(int32_t size) override; + int SetSendBufferSize(int32_t size) override; + + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + + private: + enum State { + STATE_HEADER_READ, + STATE_HEADER_READ_COMPLETE, + STATE_HEADER_WRITE, + STATE_HEADER_WRITE_COMPLETE, + STATE_NONE, + }; + + void DoCallback(int result); + void OnIOComplete(int result); + void OnReadWriteComplete(CompletionOnceCallback callback, int result); + + int DoLoop(int last_io_result); + int DoHeaderWrite(); + int DoHeaderWriteComplete(int result); + int DoHeaderRead(); + int DoHeaderReadComplete(int result); + + CompletionRepeatingCallback io_callback_; + + // Stores the underlying socket. + std::unique_ptr transport_; + ClientPaddingDetectorDelegate* padding_detector_delegate_; + + State next_state_; + + // Stores the callback to the layer above, called on completing Connect(). + CompletionOnceCallback user_callback_; + + // This IOBuffer is used by the class to read and write + // SOCKS handshake data. The length contains the expected size to + // read or write. + scoped_refptr handshake_buf_; + + std::string buffer_; + bool completed_handshake_; + bool was_ever_used_; + int header_write_size_; + + HostPortPair request_endpoint_; + + NetLogWithSource net_log_; + + // Traffic annotation for socket control. + const NetworkTrafficAnnotationTag& traffic_annotation_; +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_HTTP_PROXY_SOCKET_H_ diff --git a/src/net/tools/naive/naive_connection.cc b/src/net/tools/naive/naive_connection.cc new file mode 100644 index 0000000000..149445401e --- /dev/null +++ b/src/net/tools/naive/naive_connection.cc @@ -0,0 +1,573 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/naive_connection.h" + +#include +#include + +#include "base/functional/bind.h" +#include "base/functional/callback_helpers.h" +#include "base/logging.h" +#include "base/rand_util.h" +#include "base/strings/strcat.h" +#include "base/task/single_thread_task_runner.h" +#include "base/time/time.h" +#include "build/build_config.h" +#include "net/base/io_buffer.h" +#include "net/base/load_flags.h" +#include "net/base/net_errors.h" +#include "net/base/privacy_mode.h" +#include "net/base/url_util.h" +#include "net/proxy_resolution/proxy_info.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_manager.h" +#include "net/socket/stream_socket.h" +#include "net/spdy/spdy_session.h" +#include "net/tools/naive/http_proxy_socket.h" +#include "net/tools/naive/redirect_resolver.h" +#include "net/tools/naive/socks5_server_socket.h" +#include "url/scheme_host_port.h" + +#if BUILDFLAG(IS_LINUX) +#include +#include +#include + +#include "net/base/ip_endpoint.h" +#include "net/base/sockaddr_storage.h" +#include "net/socket/tcp_client_socket.h" +#endif + +namespace net { + +namespace { +constexpr int kBufferSize = 64 * 1024; +constexpr int kFirstPaddings = 8; +constexpr int kPaddingHeaderSize = 3; +constexpr int kMaxPaddingSize = 255; +} // namespace + +NaiveConnection::NaiveConnection( + unsigned int id, + ClientProtocol protocol, + std::unique_ptr padding_detector_delegate, + const ProxyInfo& proxy_info, + const SSLConfig& server_ssl_config, + const SSLConfig& proxy_ssl_config, + RedirectResolver* resolver, + HttpNetworkSession* session, + const NetworkAnonymizationKey& network_anonymization_key, + const NetLogWithSource& net_log, + std::unique_ptr accepted_socket, + const NetworkTrafficAnnotationTag& traffic_annotation) + : id_(id), + protocol_(protocol), + padding_detector_delegate_(std::move(padding_detector_delegate)), + proxy_info_(proxy_info), + server_ssl_config_(server_ssl_config), + proxy_ssl_config_(proxy_ssl_config), + resolver_(resolver), + session_(session), + network_anonymization_key_(network_anonymization_key), + net_log_(net_log), + next_state_(STATE_NONE), + client_socket_(std::move(accepted_socket)), + server_socket_handle_(std::make_unique()), + sockets_{client_socket_.get(), nullptr}, + errors_{OK, OK}, + write_pending_{false, false}, + early_pull_pending_(false), + can_push_to_server_(false), + early_pull_result_(ERR_IO_PENDING), + num_paddings_{0, 0}, + read_padding_state_(STATE_READ_PAYLOAD_LENGTH_1), + full_duplex_(false), + time_func_(&base::TimeTicks::Now), + traffic_annotation_(traffic_annotation) { + io_callback_ = base::BindRepeating(&NaiveConnection::OnIOComplete, + weak_ptr_factory_.GetWeakPtr()); +} + +NaiveConnection::~NaiveConnection() { + Disconnect(); +} + +int NaiveConnection::Connect(CompletionOnceCallback callback) { + DCHECK(client_socket_); + DCHECK_EQ(next_state_, STATE_NONE); + DCHECK(!connect_callback_); + + if (full_duplex_) + return OK; + + next_state_ = STATE_CONNECT_CLIENT; + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + connect_callback_ = std::move(callback); + } + return rv; +} + +void NaiveConnection::Disconnect() { + full_duplex_ = false; + // Closes server side first because latency is higher. + if (server_socket_handle_->socket()) + server_socket_handle_->socket()->Disconnect(); + client_socket_->Disconnect(); + + next_state_ = STATE_NONE; + connect_callback_.Reset(); + run_callback_.Reset(); +} + +void NaiveConnection::DoCallback(int result) { + DCHECK_NE(result, ERR_IO_PENDING); + DCHECK(connect_callback_); + + // Since Run() may result in Read being called, + // clear connect_callback_ up front. + std::move(connect_callback_).Run(result); +} + +void NaiveConnection::OnIOComplete(int result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) { + DoCallback(rv); + } +} + +int NaiveConnection::DoLoop(int last_io_result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_CONNECT_CLIENT: + DCHECK_EQ(rv, OK); + rv = DoConnectClient(); + break; + case STATE_CONNECT_CLIENT_COMPLETE: + rv = DoConnectClientComplete(rv); + break; + case STATE_CONNECT_SERVER: + DCHECK_EQ(rv, OK); + rv = DoConnectServer(); + break; + case STATE_CONNECT_SERVER_COMPLETE: + rv = DoConnectServerComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +int NaiveConnection::DoConnectClient() { + next_state_ = STATE_CONNECT_CLIENT_COMPLETE; + + return client_socket_->Connect(io_callback_); +} + +int NaiveConnection::DoConnectClientComplete(int result) { + if (result < 0) + return result; + + // For proxy client sockets, padding support detection is finished after the + // first server response which means there will be one missed early pull. For + // proxy server sockets (HttpProxySocket), padding support detection is + // done during client connect, so there shouldn't be any missed early pull. + if (!padding_detector_delegate_->IsPaddingSupportKnown()) { + early_pull_pending_ = false; + early_pull_result_ = 0; + next_state_ = STATE_CONNECT_SERVER; + return OK; + } + + early_pull_pending_ = true; + Pull(kClient, kServer); + if (early_pull_result_ != ERR_IO_PENDING) { + // Pull has completed synchronously. + if (early_pull_result_ <= 0) { + return early_pull_result_ ? early_pull_result_ : ERR_CONNECTION_CLOSED; + } + } + + next_state_ = STATE_CONNECT_SERVER; + return OK; +} + +int NaiveConnection::DoConnectServer() { + next_state_ = STATE_CONNECT_SERVER_COMPLETE; + + HostPortPair origin; + if (protocol_ == ClientProtocol::kSocks5) { + const auto* socket = + static_cast(client_socket_.get()); + origin = socket->request_endpoint(); + } else if (protocol_ == ClientProtocol::kHttp) { + const auto* socket = + static_cast(client_socket_.get()); + origin = socket->request_endpoint(); + } else if (protocol_ == ClientProtocol::kRedir) { +#if BUILDFLAG(IS_LINUX) + const auto* socket = + static_cast(client_socket_.get()); + IPEndPoint peer_endpoint; + int rv; + rv = socket->GetPeerAddress(&peer_endpoint); + if (rv != OK) { + LOG(ERROR) << "Connection " << id_ << " cannot get peer address"; + return rv; + } + int sd = socket->SocketDescriptorForTesting(); + SockaddrStorage dst; + if (peer_endpoint.GetFamily() == ADDRESS_FAMILY_IPV4 || + peer_endpoint.address().IsIPv4MappedIPv6()) { + rv = getsockopt(sd, SOL_IP, SO_ORIGINAL_DST, dst.addr, &dst.addr_len); + } else { + rv = getsockopt(sd, SOL_IPV6, SO_ORIGINAL_DST, dst.addr, &dst.addr_len); + } + if (rv == 0) { + IPEndPoint ipe; + if (ipe.FromSockAddr(dst.addr, dst.addr_len)) { + const auto& addr = ipe.address(); + auto name = resolver_->FindNameByAddress(addr); + if (!name.empty()) { + origin = HostPortPair(name, ipe.port()); + } else if (!resolver_->IsInResolvedRange(addr)) { + origin = HostPortPair::FromIPEndPoint(ipe); + } else { + LOG(ERROR) << "Connection " << id_ << " to unresolved name for " + << addr.ToString(); + return ERR_ADDRESS_INVALID; + } + } + } else { + LOG(ERROR) << "Failed to get original destination address"; + return ERR_ADDRESS_INVALID; + } +#else + static_cast(resolver_); +#endif + } + + url::CanonHostInfo host_info; + url::SchemeHostPort endpoint( + "http", CanonicalizeHost(origin.HostForURL(), &host_info), origin.port(), + url::SchemeHostPort::ALREADY_CANONICALIZED); + if (!endpoint.IsValid()) { + LOG(ERROR) << "Connection " << id_ << " to invalid origin " + << origin.ToString(); + return ERR_ADDRESS_INVALID; + } + + LOG(INFO) << "Connection " << id_ << " to " << origin.ToString(); + + // Ignores socket limit set by socket pool for this type of socket. + return InitSocketHandleForRawConnect2( + std::move(endpoint), LOAD_IGNORE_LIMITS, MAXIMUM_PRIORITY, session_, + proxy_info_, server_ssl_config_, proxy_ssl_config_, PRIVACY_MODE_DISABLED, + network_anonymization_key_, net_log_, server_socket_handle_.get(), + io_callback_); +} + +int NaiveConnection::DoConnectServerComplete(int result) { + if (result < 0) + return result; + + DCHECK(server_socket_handle_->socket()); + sockets_[kServer] = server_socket_handle_->socket(); + + full_duplex_ = true; + next_state_ = STATE_NONE; + return OK; +} + +int NaiveConnection::Run(CompletionOnceCallback callback) { + DCHECK(sockets_[kClient]); + DCHECK(sockets_[kServer]); + DCHECK_EQ(next_state_, STATE_NONE); + DCHECK(!connect_callback_); + + if (errors_[kClient] != OK) + return errors_[kClient]; + if (errors_[kServer] != OK) + return errors_[kServer]; + + run_callback_ = std::move(callback); + + bytes_passed_without_yielding_[kClient] = 0; + bytes_passed_without_yielding_[kServer] = 0; + + yield_after_time_[kClient] = + time_func_() + base::Milliseconds(kYieldAfterDurationMilliseconds); + yield_after_time_[kServer] = yield_after_time_[kClient]; + + can_push_to_server_ = true; + // early_pull_result_ == 0 means the early pull was not started because + // padding support was not yet known. + if (!early_pull_pending_ && early_pull_result_ == 0) { + Pull(kClient, kServer); + } else if (!early_pull_pending_) { + DCHECK_GT(early_pull_result_, 0); + Push(kClient, kServer, early_pull_result_); + } + Pull(kServer, kClient); + + return ERR_IO_PENDING; +} + +void NaiveConnection::Pull(Direction from, Direction to) { + if (errors_[kClient] < 0 || errors_[kServer] < 0) + return; + + int read_size = kBufferSize; + auto padding_direction = padding_detector_delegate_->GetPaddingDirection(); + if (from == padding_direction && num_paddings_[from] < kFirstPaddings) { + auto buffer = base::MakeRefCounted(); + buffer->SetCapacity(kBufferSize); + buffer->set_offset(kPaddingHeaderSize); + read_buffers_[from] = buffer; + read_size = kBufferSize - kPaddingHeaderSize - kMaxPaddingSize; + } else { + read_buffers_[from] = base::MakeRefCounted(kBufferSize); + } + + DCHECK(sockets_[from]); + int rv = sockets_[from]->Read( + read_buffers_[from].get(), read_size, + base::BindRepeating(&NaiveConnection::OnPullComplete, + weak_ptr_factory_.GetWeakPtr(), from, to)); + + if (from == kClient && early_pull_pending_) + early_pull_result_ = rv; + + if (rv != ERR_IO_PENDING) + OnPullComplete(from, to, rv); +} + +void NaiveConnection::Push(Direction from, Direction to, int size) { + int write_size = size; + int write_offset = 0; + auto padding_direction = padding_detector_delegate_->GetPaddingDirection(); + if (from == padding_direction && num_paddings_[from] < kFirstPaddings) { + // Adds padding. + ++num_paddings_[from]; + int padding_size = base::RandInt(0, kMaxPaddingSize); + auto* buffer = static_cast(read_buffers_[from].get()); + buffer->set_offset(0); + uint8_t* p = reinterpret_cast(buffer->data()); + p[0] = size / 256; + p[1] = size % 256; + p[2] = padding_size; + std::memset(p + kPaddingHeaderSize + size, 0, padding_size); + write_size = kPaddingHeaderSize + size + padding_size; + } else if (to == padding_direction && num_paddings_[from] < kFirstPaddings) { + // Removes padding. + const char* p = read_buffers_[from]->data(); + bool trivial_padding = false; + if (read_padding_state_ == STATE_READ_PAYLOAD_LENGTH_1 && + size >= kPaddingHeaderSize) { + int payload_size = + static_cast(p[0]) * 256 + static_cast(p[1]); + int padding_size = static_cast(p[2]); + if (size == kPaddingHeaderSize + payload_size + padding_size) { + write_size = payload_size; + write_offset = kPaddingHeaderSize; + ++num_paddings_[from]; + trivial_padding = true; + } + } + if (!trivial_padding) { + auto unpadded_buffer = base::MakeRefCounted(kBufferSize); + char* unpadded_ptr = unpadded_buffer->data(); + for (int i = 0; i < size;) { + if (num_paddings_[from] >= kFirstPaddings && + read_padding_state_ == STATE_READ_PAYLOAD_LENGTH_1) { + std::memcpy(unpadded_ptr, p + i, size - i); + unpadded_ptr += size - i; + break; + } + int copy_size; + switch (read_padding_state_) { + case STATE_READ_PAYLOAD_LENGTH_1: + payload_length_ = static_cast(p[i]); + ++i; + read_padding_state_ = STATE_READ_PAYLOAD_LENGTH_2; + break; + case STATE_READ_PAYLOAD_LENGTH_2: + payload_length_ = + payload_length_ * 256 + static_cast(p[i]); + ++i; + read_padding_state_ = STATE_READ_PADDING_LENGTH; + break; + case STATE_READ_PADDING_LENGTH: + padding_length_ = static_cast(p[i]); + ++i; + read_padding_state_ = STATE_READ_PAYLOAD; + break; + case STATE_READ_PAYLOAD: + if (payload_length_ <= size - i) { + copy_size = payload_length_; + read_padding_state_ = STATE_READ_PADDING; + } else { + copy_size = size - i; + } + std::memcpy(unpadded_ptr, p + i, copy_size); + unpadded_ptr += copy_size; + i += copy_size; + payload_length_ -= copy_size; + break; + case STATE_READ_PADDING: + if (padding_length_ <= size - i) { + copy_size = padding_length_; + read_padding_state_ = STATE_READ_PAYLOAD_LENGTH_1; + ++num_paddings_[from]; + } else { + copy_size = size - i; + } + i += copy_size; + padding_length_ -= copy_size; + break; + } + } + write_size = unpadded_ptr - unpadded_buffer->data(); + read_buffers_[from] = unpadded_buffer; + } + if (write_size == 0) { + OnPushComplete(from, to, OK); + return; + } + } + + write_buffers_[to] = base::MakeRefCounted( + std::move(read_buffers_[from]), write_offset + write_size); + if (write_offset) { + write_buffers_[to]->DidConsume(write_offset); + } + write_pending_[to] = true; + DCHECK(sockets_[to]); + int rv = sockets_[to]->Write( + write_buffers_[to].get(), write_size, + base::BindRepeating(&NaiveConnection::OnPushComplete, + weak_ptr_factory_.GetWeakPtr(), from, to), + traffic_annotation_); + + if (rv != ERR_IO_PENDING) + OnPushComplete(from, to, rv); +} + +void NaiveConnection::Disconnect(Direction side) { + if (sockets_[side]) { + sockets_[side]->Disconnect(); + sockets_[side] = nullptr; + write_pending_[side] = false; + } +} + +bool NaiveConnection::IsConnected(Direction side) { + return sockets_[side]; +} + +void NaiveConnection::OnBothDisconnected() { + if (run_callback_) { + int error = OK; + if (errors_[kClient] != ERR_CONNECTION_CLOSED && errors_[kClient] < 0) + error = errors_[kClient]; + if (errors_[kServer] != ERR_CONNECTION_CLOSED && errors_[kClient] < 0) + error = errors_[kServer]; + std::move(run_callback_).Run(error); + } +} + +void NaiveConnection::OnPullError(Direction from, Direction to, int error) { + DCHECK_LT(error, 0); + + errors_[from] = error; + Disconnect(from); + + if (!write_pending_[to]) + Disconnect(to); + + if (!IsConnected(from) && !IsConnected(to)) + OnBothDisconnected(); +} + +void NaiveConnection::OnPushError(Direction from, Direction to, int error) { + DCHECK_LE(error, 0); + DCHECK(!write_pending_[to]); + + if (error < 0) { + errors_[to] = error; + Disconnect(kServer); + Disconnect(kClient); + } else if (!IsConnected(from)) { + Disconnect(to); + } + + if (!IsConnected(from) && !IsConnected(to)) + OnBothDisconnected(); +} + +void NaiveConnection::OnPullComplete(Direction from, Direction to, int result) { + if (from == kClient && early_pull_pending_) { + early_pull_pending_ = false; + early_pull_result_ = result ? result : ERR_CONNECTION_CLOSED; + } + + if (result <= 0) { + OnPullError(from, to, result ? result : ERR_CONNECTION_CLOSED); + return; + } + + if (from == kClient && !can_push_to_server_) + return; + + Push(from, to, result); +} + +void NaiveConnection::OnPushComplete(Direction from, Direction to, int result) { + if (result >= 0 && write_buffers_[to] != nullptr) { + bytes_passed_without_yielding_[from] += result; + write_buffers_[to]->DidConsume(result); + int size = write_buffers_[to]->BytesRemaining(); + if (size > 0) { + int rv = sockets_[to]->Write( + write_buffers_[to].get(), size, + base::BindRepeating(&NaiveConnection::OnPushComplete, + weak_ptr_factory_.GetWeakPtr(), from, to), + traffic_annotation_); + if (rv != ERR_IO_PENDING) + OnPushComplete(from, to, rv); + return; + } + } + + write_pending_[to] = false; + // Checks for termination even if result is OK. + OnPushError(from, to, result >= 0 ? OK : result); + + if (bytes_passed_without_yielding_[from] > kYieldAfterBytesRead || + time_func_() > yield_after_time_[from]) { + bytes_passed_without_yielding_[from] = 0; + yield_after_time_[from] = + time_func_() + base::Milliseconds(kYieldAfterDurationMilliseconds); + base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( + FROM_HERE, + base::BindRepeating(&NaiveConnection::Pull, + weak_ptr_factory_.GetWeakPtr(), from, to)); + } else { + Pull(from, to); + } +} + +} // namespace net diff --git a/src/net/tools/naive/naive_connection.h b/src/net/tools/naive/naive_connection.h new file mode 100644 index 0000000000..36f1ff4c06 --- /dev/null +++ b/src/net/tools/naive/naive_connection.h @@ -0,0 +1,142 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_NAIVE_CONNECTION_H_ +#define NET_TOOLS_NAIVE_NAIVE_CONNECTION_H_ + +#include +#include + +#include "base/memory/scoped_refptr.h" +#include "base/memory/weak_ptr.h" +#include "base/time/time.h" +#include "net/base/completion_once_callback.h" +#include "net/base/completion_repeating_callback.h" +#include "net/tools/naive/naive_protocol.h" +#include "net/tools/naive/naive_proxy_delegate.h" + +namespace net { + +class ClientSocketHandle; +class DrainableIOBuffer; +class HttpNetworkSession; +class IOBuffer; +class NetLogWithSource; +class ProxyInfo; +class StreamSocket; +struct NetworkTrafficAnnotationTag; +struct SSLConfig; +class RedirectResolver; +class NetworkAnonymizationKey; + +class NaiveConnection { + public: + using TimeFunc = base::TimeTicks (*)(); + + NaiveConnection( + unsigned int id, + ClientProtocol protocol, + std::unique_ptr padding_detector_delegate, + const ProxyInfo& proxy_info, + const SSLConfig& server_ssl_config, + const SSLConfig& proxy_ssl_config, + RedirectResolver* resolver, + HttpNetworkSession* session, + const NetworkAnonymizationKey& network_anonymization_key, + const NetLogWithSource& net_log, + std::unique_ptr accepted_socket, + const NetworkTrafficAnnotationTag& traffic_annotation); + ~NaiveConnection(); + NaiveConnection(const NaiveConnection&) = delete; + NaiveConnection& operator=(const NaiveConnection&) = delete; + + unsigned int id() const { return id_; } + int Connect(CompletionOnceCallback callback); + void Disconnect(); + int Run(CompletionOnceCallback callback); + + private: + enum State { + STATE_CONNECT_CLIENT, + STATE_CONNECT_CLIENT_COMPLETE, + STATE_CONNECT_SERVER, + STATE_CONNECT_SERVER_COMPLETE, + STATE_NONE, + }; + + enum PaddingState { + STATE_READ_PAYLOAD_LENGTH_1, + STATE_READ_PAYLOAD_LENGTH_2, + STATE_READ_PADDING_LENGTH, + STATE_READ_PAYLOAD, + STATE_READ_PADDING, + }; + + void DoCallback(int result); + void OnIOComplete(int result); + int DoLoop(int last_io_result); + int DoConnectClient(); + int DoConnectClientComplete(int result); + int DoConnectServer(); + int DoConnectServerComplete(int result); + void Pull(Direction from, Direction to); + void Push(Direction from, Direction to, int size); + void Disconnect(Direction side); + bool IsConnected(Direction side); + void OnBothDisconnected(); + void OnPullError(Direction from, Direction to, int error); + void OnPushError(Direction from, Direction to, int error); + void OnPullComplete(Direction from, Direction to, int result); + void OnPushComplete(Direction from, Direction to, int result); + + unsigned int id_; + ClientProtocol protocol_; + std::unique_ptr padding_detector_delegate_; + const ProxyInfo& proxy_info_; + const SSLConfig& server_ssl_config_; + const SSLConfig& proxy_ssl_config_; + RedirectResolver* resolver_; + HttpNetworkSession* session_; + const NetworkAnonymizationKey& network_anonymization_key_; + const NetLogWithSource& net_log_; + + CompletionRepeatingCallback io_callback_; + CompletionOnceCallback connect_callback_; + CompletionOnceCallback run_callback_; + + State next_state_; + + std::unique_ptr client_socket_; + std::unique_ptr server_socket_handle_; + + StreamSocket* sockets_[kNumDirections]; + scoped_refptr read_buffers_[kNumDirections]; + scoped_refptr write_buffers_[kNumDirections]; + int errors_[kNumDirections]; + bool write_pending_[kNumDirections]; + int bytes_passed_without_yielding_[kNumDirections]; + base::TimeTicks yield_after_time_[kNumDirections]; + + bool early_pull_pending_; + bool can_push_to_server_; + int early_pull_result_; + + int num_paddings_[kNumDirections]; + PaddingState read_padding_state_; + int payload_length_; + int padding_length_; + + bool full_duplex_; + + TimeFunc time_func_; + + // Traffic annotation for socket control. + const NetworkTrafficAnnotationTag& traffic_annotation_; + + base::WeakPtrFactory weak_ptr_factory_{this}; +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_NAIVE_CONNECTION_H_ diff --git a/src/net/tools/naive/naive_protocol.h b/src/net/tools/naive/naive_protocol.h new file mode 100644 index 0000000000..e3e617fbfa --- /dev/null +++ b/src/net/tools/naive/naive_protocol.h @@ -0,0 +1,24 @@ +// Copyright 2020 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#ifndef NET_TOOLS_NAIVE_NAIVE_PROTOCOL_H_ +#define NET_TOOLS_NAIVE_NAIVE_PROTOCOL_H_ + +namespace net { +enum class ClientProtocol { + kSocks5, + kHttp, + kRedir, +}; + +// Adds padding for traffic from this direction. +// Removes padding for traffic from the opposite direction. +enum Direction { + kClient = 0, + kServer = 1, + kNumDirections = 2, + kNone = 2, +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_NAIVE_PROTOCOL_H_ diff --git a/src/net/tools/naive/naive_proxy.cc b/src/net/tools/naive/naive_proxy.cc new file mode 100644 index 0000000000..554af46bf8 --- /dev/null +++ b/src/net/tools/naive/naive_proxy.cc @@ -0,0 +1,212 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/naive_proxy.h" + +#include + +#include "base/functional/bind.h" +#include "base/location.h" +#include "base/logging.h" +#include "base/task/single_thread_task_runner.h" +#include "net/base/load_flags.h" +#include "net/base/net_errors.h" +#include "net/http/http_network_session.h" +#include "net/proxy_resolution/configured_proxy_resolution_service.h" +#include "net/proxy_resolution/proxy_config.h" +#include "net/proxy_resolution/proxy_list.h" +#include "net/socket/client_socket_pool_manager.h" +#include "net/socket/server_socket.h" +#include "net/socket/stream_socket.h" +#include "net/tools/naive/http_proxy_socket.h" +#include "net/tools/naive/naive_proxy_delegate.h" +#include "net/tools/naive/socks5_server_socket.h" + +namespace net { + +NaiveProxy::NaiveProxy(std::unique_ptr listen_socket, + ClientProtocol protocol, + const std::string& listen_user, + const std::string& listen_pass, + int concurrency, + RedirectResolver* resolver, + HttpNetworkSession* session, + const NetworkTrafficAnnotationTag& traffic_annotation) + : listen_socket_(std::move(listen_socket)), + protocol_(protocol), + listen_user_(listen_user), + listen_pass_(listen_pass), + concurrency_(concurrency), + resolver_(resolver), + session_(session), + net_log_( + NetLogWithSource::Make(session->net_log(), NetLogSourceType::NONE)), + last_id_(0), + traffic_annotation_(traffic_annotation) { + const auto& proxy_config = static_cast( + session_->proxy_resolution_service()) + ->config(); + DCHECK(proxy_config); + const ProxyList& proxy_list = + proxy_config.value().value().proxy_rules().single_proxies; + DCHECK(!proxy_list.IsEmpty()); + proxy_info_.UseProxyList(proxy_list); + proxy_info_.set_traffic_annotation( + net::MutableNetworkTrafficAnnotationTag(traffic_annotation_)); + + // See HttpStreamFactory::Job::DoInitConnectionImpl() + proxy_ssl_config_.disable_cert_verification_network_fetches = true; + server_ssl_config_.alpn_protos = session_->GetAlpnProtos(); + proxy_ssl_config_.alpn_protos = session_->GetAlpnProtos(); + server_ssl_config_.application_settings = session_->GetApplicationSettings(); + proxy_ssl_config_.application_settings = session_->GetApplicationSettings(); + server_ssl_config_.ignore_certificate_errors = + session_->params().ignore_certificate_errors; + proxy_ssl_config_.ignore_certificate_errors = + session_->params().ignore_certificate_errors; + // TODO(https://crbug.com/964642): Also enable 0-RTT for TLS proxies. + server_ssl_config_.early_data_enabled = session_->params().enable_early_data; + + for (int i = 0; i < concurrency_; i++) { + network_anonymization_keys_.push_back( + NetworkAnonymizationKey::CreateTransient()); + } + + DCHECK(listen_socket_); + // Start accepting connections in next run loop in case when delegate is not + // ready to get callbacks. + base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( + FROM_HERE, base::BindOnce(&NaiveProxy::DoAcceptLoop, + weak_ptr_factory_.GetWeakPtr())); +} + +NaiveProxy::~NaiveProxy() = default; + +void NaiveProxy::DoAcceptLoop() { + int result; + do { + result = listen_socket_->Accept( + &accepted_socket_, base::BindRepeating(&NaiveProxy::OnAcceptComplete, + weak_ptr_factory_.GetWeakPtr())); + if (result == ERR_IO_PENDING) + return; + HandleAcceptResult(result); + } while (result == OK); +} + +void NaiveProxy::OnAcceptComplete(int result) { + HandleAcceptResult(result); + if (result == OK) + DoAcceptLoop(); +} + +void NaiveProxy::HandleAcceptResult(int result) { + if (result != OK) { + LOG(ERROR) << "Accept error: rv=" << result; + return; + } + DoConnect(); +} + +void NaiveProxy::DoConnect() { + std::unique_ptr socket; + auto* proxy_delegate = + static_cast(session_->context().proxy_delegate); + DCHECK(proxy_delegate); + DCHECK(!proxy_info_.is_empty()); + const auto& proxy_server = proxy_info_.proxy_server(); + auto padding_detector_delegate = std::make_unique( + proxy_delegate, proxy_server, protocol_); + + if (protocol_ == ClientProtocol::kSocks5) { + socket = std::make_unique(std::move(accepted_socket_), + listen_user_, listen_pass_, + traffic_annotation_); + } else if (protocol_ == ClientProtocol::kHttp) { + socket = std::make_unique(std::move(accepted_socket_), + padding_detector_delegate.get(), + traffic_annotation_); + } else if (protocol_ == ClientProtocol::kRedir) { + socket = std::move(accepted_socket_); + } else { + return; + } + + last_id_++; + const auto& nak = network_anonymization_keys_[last_id_ % concurrency_]; + auto connection_ptr = std::make_unique( + last_id_, protocol_, std::move(padding_detector_delegate), proxy_info_, + server_ssl_config_, proxy_ssl_config_, resolver_, session_, nak, net_log_, + std::move(socket), traffic_annotation_); + auto* connection = connection_ptr.get(); + connection_by_id_[connection->id()] = std::move(connection_ptr); + int result = connection->Connect( + base::BindRepeating(&NaiveProxy::OnConnectComplete, + weak_ptr_factory_.GetWeakPtr(), connection->id())); + if (result == ERR_IO_PENDING) + return; + HandleConnectResult(connection, result); +} + +void NaiveProxy::OnConnectComplete(unsigned int connection_id, int result) { + auto* connection = FindConnection(connection_id); + if (!connection) + return; + HandleConnectResult(connection, result); +} + +void NaiveProxy::HandleConnectResult(NaiveConnection* connection, int result) { + if (result != OK) { + Close(connection->id(), result); + return; + } + DoRun(connection); +} + +void NaiveProxy::DoRun(NaiveConnection* connection) { + int result = connection->Run( + base::BindRepeating(&NaiveProxy::OnRunComplete, + weak_ptr_factory_.GetWeakPtr(), connection->id())); + if (result == ERR_IO_PENDING) + return; + HandleRunResult(connection, result); +} + +void NaiveProxy::OnRunComplete(unsigned int connection_id, int result) { + auto* connection = FindConnection(connection_id); + if (!connection) + return; + HandleRunResult(connection, result); +} + +void NaiveProxy::HandleRunResult(NaiveConnection* connection, int result) { + Close(connection->id(), result); +} + +void NaiveProxy::Close(unsigned int connection_id, int reason) { + auto it = connection_by_id_.find(connection_id); + if (it == connection_by_id_.end()) + return; + + LOG(INFO) << "Connection " << connection_id + << " closed: " << ErrorToShortString(reason); + + // The call stack might have callbacks which still have the pointer of + // connection. Instead of referencing connection with ID all the time, + // destroys the connection in next run loop to make sure any pending + // callbacks in the call stack return. + base::SingleThreadTaskRunner::GetCurrentDefault()->DeleteSoon( + FROM_HERE, std::move(it->second)); + connection_by_id_.erase(it); +} + +NaiveConnection* NaiveProxy::FindConnection(unsigned int connection_id) { + auto it = connection_by_id_.find(connection_id); + if (it == connection_by_id_.end()) + return nullptr; + return it->second.get(); +} + +} // namespace net diff --git a/src/net/tools/naive/naive_proxy.h b/src/net/tools/naive/naive_proxy.h new file mode 100644 index 0000000000..7f6f407d66 --- /dev/null +++ b/src/net/tools/naive/naive_proxy.h @@ -0,0 +1,89 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_NAIVE_PROXY_H_ +#define NET_TOOLS_NAIVE_NAIVE_PROXY_H_ + +#include +#include +#include + +#include "base/memory/weak_ptr.h" +#include "net/base/completion_repeating_callback.h" +#include "net/base/network_isolation_key.h" +#include "net/log/net_log_with_source.h" +#include "net/proxy_resolution/proxy_info.h" +#include "net/ssl/ssl_config.h" +#include "net/tools/naive/naive_connection.h" +#include "net/tools/naive/naive_protocol.h" + +namespace net { + +class ClientSocketHandle; +class HttpNetworkSession; +class NaiveConnection; +class ServerSocket; +class StreamSocket; +struct NetworkTrafficAnnotationTag; +class RedirectResolver; + +class NaiveProxy { + public: + NaiveProxy(std::unique_ptr server_socket, + ClientProtocol protocol, + const std::string& listen_user, + const std::string& listen_pass, + int concurrency, + RedirectResolver* resolver, + HttpNetworkSession* session, + const NetworkTrafficAnnotationTag& traffic_annotation); + ~NaiveProxy(); + NaiveProxy(const NaiveProxy&) = delete; + NaiveProxy& operator=(const NaiveProxy&) = delete; + + private: + void DoAcceptLoop(); + void OnAcceptComplete(int result); + void HandleAcceptResult(int result); + + void DoConnect(); + void OnConnectComplete(unsigned int connection_id, int result); + void HandleConnectResult(NaiveConnection* connection, int result); + + void DoRun(NaiveConnection* connection); + void OnRunComplete(unsigned int connection_id, int result); + void HandleRunResult(NaiveConnection* connection, int result); + + void Close(unsigned int connection_id, int reason); + + NaiveConnection* FindConnection(unsigned int connection_id); + + std::unique_ptr listen_socket_; + ClientProtocol protocol_; + std::string listen_user_; + std::string listen_pass_; + int concurrency_; + ProxyInfo proxy_info_; + SSLConfig server_ssl_config_; + SSLConfig proxy_ssl_config_; + RedirectResolver* resolver_; + HttpNetworkSession* session_; + NetLogWithSource net_log_; + + unsigned int last_id_; + + std::unique_ptr accepted_socket_; + + std::vector network_anonymization_keys_; + + std::map> connection_by_id_; + + const NetworkTrafficAnnotationTag& traffic_annotation_; + + base::WeakPtrFactory weak_ptr_factory_{this}; +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_NAIVE_PROXY_H_ diff --git a/src/net/tools/naive/naive_proxy_bin.cc b/src/net/tools/naive/naive_proxy_bin.cc new file mode 100644 index 0000000000..c2c06204a9 --- /dev/null +++ b/src/net/tools/naive/naive_proxy_bin.cc @@ -0,0 +1,595 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include +#include +#include +#include + +#include "base/at_exit.h" +#include "base/command_line.h" +#include "base/feature_list.h" +#include "base/files/file_path.h" +#include "base/json/json_file_value_serializer.h" +#include "base/json/json_writer.h" +#include "base/logging.h" +#include "base/rand_util.h" +#include "base/run_loop.h" +#include "base/strings/escape.h" +#include "base/strings/string_number_conversions.h" +#include "base/strings/stringprintf.h" +#include "base/strings/utf_string_conversions.h" +#include "base/system/sys_info.h" +#include "base/task/single_thread_task_executor.h" +#include "base/task/thread_pool/thread_pool_instance.h" +#include "base/values.h" +#include "build/build_config.h" +#include "components/version_info/version_info.h" +#include "net/base/auth.h" +#include "net/base/network_isolation_key.h" +#include "net/base/url_util.h" +#include "net/cert/cert_verifier.h" +#include "net/cert_net/cert_net_fetcher_url_request.h" +#include "net/dns/host_resolver.h" +#include "net/dns/mapped_host_resolver.h" +#include "net/http/http_auth.h" +#include "net/http/http_auth_cache.h" +#include "net/http/http_network_session.h" +#include "net/http/http_request_headers.h" +#include "net/http/http_transaction_factory.h" +#include "net/log/file_net_log_observer.h" +#include "net/log/net_log.h" +#include "net/log/net_log_capture_mode.h" +#include "net/log/net_log_entry.h" +#include "net/log/net_log_event_type.h" +#include "net/log/net_log_source.h" +#include "net/log/net_log_util.h" +#include "net/proxy_resolution/configured_proxy_resolution_service.h" +#include "net/proxy_resolution/proxy_config.h" +#include "net/proxy_resolution/proxy_config_service_fixed.h" +#include "net/proxy_resolution/proxy_config_with_annotation.h" +#include "net/socket/client_socket_pool_manager.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/tcp_server_socket.h" +#include "net/socket/udp_server_socket.h" +#include "net/ssl/ssl_key_logger_impl.h" +#include "net/third_party/quiche/src/quiche/quic/core/quic_versions.h" +#include "net/tools/naive/naive_protocol.h" +#include "net/tools/naive/naive_proxy.h" +#include "net/tools/naive/naive_proxy_delegate.h" +#include "net/tools/naive/partition_alloc_support.h" +#include "net/tools/naive/redirect_resolver.h" +#include "net/traffic_annotation/network_traffic_annotation.h" +#include "net/url_request/url_request_context.h" +#include "net/url_request/url_request_context_builder.h" +#include "url/gurl.h" +#include "url/scheme_host_port.h" +#include "url/url_util.h" + +#if BUILDFLAG(IS_APPLE) +#include "base/mac/scoped_nsautorelease_pool.h" +#endif + +namespace { + +constexpr int kListenBackLog = 512; +constexpr int kDefaultMaxSocketsPerPool = 256; +constexpr int kDefaultMaxSocketsPerGroup = 255; +constexpr int kExpectedMaxUsers = 8; +constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation = + net::DefineNetworkTrafficAnnotation("naive", ""); + +struct CommandLine { + std::string listen; + std::string proxy; + std::string concurrency; + std::string extra_headers; + std::string host_resolver_rules; + std::string resolver_range; + bool no_log; + base::FilePath log; + base::FilePath log_net_log; + base::FilePath ssl_key_log_file; +}; + +struct Params { + net::ClientProtocol protocol; + std::string listen_user; + std::string listen_pass; + std::string listen_addr; + int listen_port; + int concurrency; + net::HttpRequestHeaders extra_headers; + std::string proxy_url; + std::u16string proxy_user; + std::u16string proxy_pass; + std::string host_resolver_rules; + net::IPAddress resolver_range; + size_t resolver_prefix; + logging::LoggingSettings log_settings; + base::FilePath net_log_path; + base::FilePath ssl_key_path; +}; + +std::unique_ptr GetConstants() { + base::Value::Dict constants_dict = net::GetNetConstants(); + base::Value::Dict dict; + std::string os_type = base::StringPrintf( + "%s: %s (%s)", base::SysInfo::OperatingSystemName().c_str(), + base::SysInfo::OperatingSystemVersion().c_str(), + base::SysInfo::OperatingSystemArchitecture().c_str()); + dict.Set("os_type", os_type); + constants_dict.Set("clientInfo", std::move(dict)); + return std::make_unique(std::move(constants_dict)); +} + +void GetCommandLine(const base::CommandLine& proc, CommandLine* cmdline) { + if (proc.HasSwitch("h") || proc.HasSwitch("help")) { + std::cout << "Usage: naive { OPTIONS | config.json }\n" + "\n" + "Options:\n" + "-h, --help Show this message\n" + "--version Print version\n" + "--listen=://[addr][:port]\n" + " proto: socks, http\n" + " redir (Linux only)\n" + "--proxy=://[:@][:]\n" + " proto: https, quic\n" + "--insecure-concurrency= Use N connections, insecure\n" + "--extra-headers=... Extra headers split by CRLF\n" + "--host-resolver-rules=... Resolver rules\n" + "--resolver-range=... Redirect resolver range\n" + "--log[=] Log to stderr, or file\n" + "--log-net-log= Save NetLog\n" + "--ssl-key-log-file= Save SSL keys for Wireshark\n" + << std::endl; + exit(EXIT_SUCCESS); + } + + if (proc.HasSwitch("version")) { + std::cout << "naive " << version_info::GetVersionNumber() << std::endl; + exit(EXIT_SUCCESS); + } + + cmdline->listen = proc.GetSwitchValueASCII("listen"); + cmdline->proxy = proc.GetSwitchValueASCII("proxy"); + cmdline->concurrency = proc.GetSwitchValueASCII("insecure-concurrency"); + cmdline->extra_headers = proc.GetSwitchValueASCII("extra-headers"); + cmdline->host_resolver_rules = + proc.GetSwitchValueASCII("host-resolver-rules"); + cmdline->resolver_range = proc.GetSwitchValueASCII("resolver-range"); + cmdline->no_log = !proc.HasSwitch("log"); + cmdline->log = proc.GetSwitchValuePath("log"); + cmdline->log_net_log = proc.GetSwitchValuePath("log-net-log"); + cmdline->ssl_key_log_file = proc.GetSwitchValuePath("ssl-key-log-file"); +} + +void GetCommandLineFromConfig(const base::FilePath& config_path, + CommandLine* cmdline) { + JSONFileValueDeserializer reader(config_path); + int error_code; + std::string error_message; + auto value = reader.Deserialize(&error_code, &error_message); + if (value == nullptr) { + std::cerr << "Error reading " << config_path << ": (" << error_code << ") " + << error_message << std::endl; + exit(EXIT_FAILURE); + } + if (!value->is_dict()) { + std::cerr << "Invalid config format" << std::endl; + exit(EXIT_FAILURE); + } + const auto* listen = value->FindStringKey("listen"); + if (listen) { + cmdline->listen = *listen; + } + const auto* proxy = value->FindStringKey("proxy"); + if (proxy) { + cmdline->proxy = *proxy; + } + const auto* concurrency = value->FindStringKey("insecure-concurrency"); + if (concurrency) { + cmdline->concurrency = *concurrency; + } + const auto* extra_headers = value->FindStringKey("extra-headers"); + if (extra_headers) { + cmdline->extra_headers = *extra_headers; + } + const auto* host_resolver_rules = value->FindStringKey("host-resolver-rules"); + if (host_resolver_rules) { + cmdline->host_resolver_rules = *host_resolver_rules; + } + const auto* resolver_range = value->FindStringKey("resolver-range"); + if (resolver_range) { + cmdline->resolver_range = *resolver_range; + } + cmdline->no_log = true; + const auto* log = value->FindStringKey("log"); + if (log) { + cmdline->no_log = false; + cmdline->log = base::FilePath::FromUTF8Unsafe(*log); + } + const auto* log_net_log = value->FindStringKey("log-net-log"); + if (log_net_log) { + cmdline->log_net_log = base::FilePath::FromUTF8Unsafe(*log_net_log); + } + const auto* ssl_key_log_file = value->FindStringKey("ssl-key-log-file"); + if (ssl_key_log_file) { + cmdline->ssl_key_log_file = + base::FilePath::FromUTF8Unsafe(*ssl_key_log_file); + } +} + +bool ParseCommandLine(const CommandLine& cmdline, Params* params) { + params->protocol = net::ClientProtocol::kSocks5; + params->listen_addr = "0.0.0.0"; + params->listen_port = 1080; + url::AddStandardScheme("socks", + url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION); + url::AddStandardScheme("redir", url::SCHEME_WITH_HOST_AND_PORT); + if (!cmdline.listen.empty()) { + GURL url(cmdline.listen); + if (url.scheme() == "socks") { + params->protocol = net::ClientProtocol::kSocks5; + params->listen_port = 1080; + } else if (url.scheme() == "http") { + params->protocol = net::ClientProtocol::kHttp; + params->listen_port = 8080; + } else if (url.scheme() == "redir") { +#if BUILDFLAG(IS_LINUX) + params->protocol = net::ClientProtocol::kRedir; + params->listen_port = 1080; +#else + std::cerr << "Redir protocol only supports Linux." << std::endl; + return false; +#endif + } else { + std::cerr << "Invalid scheme in --listen" << std::endl; + return false; + } + if (!url.username().empty()) { + params->listen_user = base::UnescapeBinaryURLComponent(url.username()); + } + if (!url.password().empty()) { + params->listen_pass = base::UnescapeBinaryURLComponent(url.password()); + } + if (!url.host().empty()) { + params->listen_addr = url.HostNoBrackets(); + } + if (!url.port().empty()) { + if (!base::StringToInt(url.port(), ¶ms->listen_port)) { + std::cerr << "Invalid port in --listen" << std::endl; + return false; + } + if (params->listen_port <= 0 || + params->listen_port > std::numeric_limits::max()) { + std::cerr << "Invalid port in --listen" << std::endl; + return false; + } + } + } + + params->proxy_url = "direct://"; + GURL url(cmdline.proxy); + GURL::Replacements remove_auth; + remove_auth.ClearUsername(); + remove_auth.ClearPassword(); + GURL url_no_auth = url.ReplaceComponents(remove_auth); + if (!cmdline.proxy.empty()) { + params->proxy_url = url_no_auth.GetWithEmptyPath().spec(); + if (params->proxy_url.empty()) { + std::cerr << "Invalid proxy URL" << std::endl; + return false; + } else if (params->proxy_url.back() == '/') { + params->proxy_url.pop_back(); + } + net::GetIdentityFromURL(url, ¶ms->proxy_user, ¶ms->proxy_pass); + } + + if (!cmdline.concurrency.empty()) { + if (!base::StringToInt(cmdline.concurrency, ¶ms->concurrency) || + params->concurrency < 1) { + std::cerr << "Invalid concurrency" << std::endl; + return false; + } + } else { + params->concurrency = 1; + } + + params->extra_headers.AddHeadersFromString(cmdline.extra_headers); + + params->host_resolver_rules = cmdline.host_resolver_rules; + + if (params->protocol == net::ClientProtocol::kRedir) { + std::string range = "100.64.0.0/10"; + if (!cmdline.resolver_range.empty()) + range = cmdline.resolver_range; + + if (!net::ParseCIDRBlock(range, ¶ms->resolver_range, + ¶ms->resolver_prefix)) { + std::cerr << "Invalid resolver range" << std::endl; + return false; + } + if (params->resolver_range.IsIPv6()) { + std::cerr << "IPv6 resolver range not supported" << std::endl; + return false; + } + } + + if (!cmdline.no_log) { + if (!cmdline.log.empty()) { + params->log_settings.logging_dest = logging::LOG_TO_FILE; + params->log_settings.log_file_path = cmdline.log.value().c_str(); + } else { + params->log_settings.logging_dest = logging::LOG_TO_STDERR; + } + } else { + params->log_settings.logging_dest = logging::LOG_NONE; + } + + params->net_log_path = cmdline.log_net_log; + params->ssl_key_path = cmdline.ssl_key_log_file; + + return true; +} +} // namespace + +namespace net { +namespace { +// NetLog::ThreadSafeObserver implementation that simply prints events +// to the logs. +class PrintingLogObserver : public NetLog::ThreadSafeObserver { + public: + PrintingLogObserver() = default; + PrintingLogObserver(const PrintingLogObserver&) = delete; + PrintingLogObserver& operator=(const PrintingLogObserver&) = delete; + + ~PrintingLogObserver() override { + // This is guaranteed to be safe as this program is single threaded. + net_log()->RemoveObserver(this); + } + + // NetLog::ThreadSafeObserver implementation: + void OnAddEntry(const NetLogEntry& entry) override { + switch (entry.type) { + case NetLogEventType::SOCKET_POOL_STALLED_MAX_SOCKETS: + case NetLogEventType::SOCKET_POOL_STALLED_MAX_SOCKETS_PER_GROUP: + case NetLogEventType::HTTP2_SESSION_STREAM_STALLED_BY_SESSION_SEND_WINDOW: + case NetLogEventType::HTTP2_SESSION_STREAM_STALLED_BY_STREAM_SEND_WINDOW: + case NetLogEventType::HTTP2_SESSION_STALLED_MAX_STREAMS: + case NetLogEventType::HTTP2_STREAM_FLOW_CONTROL_UNSTALLED: + break; + default: + return; + } + const char* source_type = NetLog::SourceTypeToString(entry.source.type); + const char* event_type = NetLogEventTypeToString(entry.type); + const char* event_phase = NetLog::EventPhaseToString(entry.phase); + base::Value params(entry.ToValue()); + std::string params_str; + base::JSONWriter::Write(params, ¶ms_str); + params_str.insert(0, ": "); + + VLOG(1) << source_type << "(" << entry.source.id << "): " << event_type + << ": " << event_phase << params_str; + } +}; +} // namespace + +namespace { +std::unique_ptr BuildCertURLRequestContext(NetLog* net_log) { + URLRequestContextBuilder builder; + + builder.DisableHttpCache(); + builder.set_net_log(net_log); + + ProxyConfig proxy_config; + auto proxy_service = + ConfiguredProxyResolutionService::CreateWithoutProxyResolver( + std::make_unique( + ProxyConfigWithAnnotation(proxy_config, kTrafficAnnotation)), + net_log); + proxy_service->ForceReloadProxyConfig(); + builder.set_proxy_resolution_service(std::move(proxy_service)); + + return builder.Build(); +} + +// Builds a URLRequestContext assuming there's only a single loop. +std::unique_ptr BuildURLRequestContext( + const Params& params, + scoped_refptr cert_net_fetcher, + NetLog* net_log) { + URLRequestContextBuilder builder; + + builder.DisableHttpCache(); + builder.set_net_log(net_log); + + ProxyConfig proxy_config; + proxy_config.proxy_rules().ParseFromString(params.proxy_url); + LOG(INFO) << "Proxying via " << params.proxy_url; + auto proxy_service = + ConfiguredProxyResolutionService::CreateWithoutProxyResolver( + std::make_unique( + ProxyConfigWithAnnotation(proxy_config, kTrafficAnnotation)), + net_log); + proxy_service->ForceReloadProxyConfig(); + builder.set_proxy_resolution_service(std::move(proxy_service)); + + if (!params.host_resolver_rules.empty()) { + builder.set_host_mapping_rules(params.host_resolver_rules); + } + + builder.SetCertVerifier( + CertVerifier::CreateDefault(std::move(cert_net_fetcher))); + + builder.set_proxy_delegate( + std::make_unique(params.extra_headers)); + + auto context = builder.Build(); + + if (!params.proxy_url.empty() && !params.proxy_user.empty() && + !params.proxy_pass.empty()) { + auto* session = context->http_transaction_factory()->GetSession(); + auto* auth_cache = session->http_auth_cache(); + std::string proxy_url = params.proxy_url; + GURL proxy_gurl(proxy_url); + if (proxy_url.compare(0, 7, "quic://") == 0) { + proxy_url.replace(0, 4, "https"); + proxy_gurl = GURL(proxy_url); + auto* quic = context->quic_context()->params(); + quic->supported_versions = {quic::ParsedQuicVersion::RFCv1()}; + quic->origins_to_force_quic_on.insert( + net::HostPortPair::FromURL(proxy_gurl)); + } + url::SchemeHostPort auth_origin(proxy_gurl); + AuthCredentials credentials(params.proxy_user, params.proxy_pass); + auth_cache->Add(auth_origin, HttpAuth::AUTH_PROXY, + /*realm=*/{}, HttpAuth::AUTH_SCHEME_BASIC, {}, + /*challenge=*/"Basic", credentials, /*path=*/"/"); + } + + return context; +} +} // namespace +} // namespace net + +int main(int argc, char* argv[]) { + naive_partition_alloc_support::ReconfigureEarly(); + + url::AddStandardScheme("quic", + url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION); + base::FeatureList::InitializeInstance( + "PartitionConnectionsByNetworkIsolationKey", std::string()); + net::ClientSocketPoolManager::set_max_sockets_per_pool( + net::HttpNetworkSession::NORMAL_SOCKET_POOL, + kDefaultMaxSocketsPerPool * kExpectedMaxUsers); + net::ClientSocketPoolManager::set_max_sockets_per_proxy_server( + net::HttpNetworkSession::NORMAL_SOCKET_POOL, + kDefaultMaxSocketsPerPool * kExpectedMaxUsers); + net::ClientSocketPoolManager::set_max_sockets_per_group( + net::HttpNetworkSession::NORMAL_SOCKET_POOL, + kDefaultMaxSocketsPerGroup * kExpectedMaxUsers); + + naive_partition_alloc_support::ReconfigureAfterFeatureListInit(); + +#if BUILDFLAG(IS_APPLE) + base::mac::ScopedNSAutoreleasePool pool; +#endif + + base::AtExitManager exit_manager; + base::CommandLine::Init(argc, argv); + + CommandLine cmdline; + Params params; + const auto& proc = *base::CommandLine::ForCurrentProcess(); + const auto& args = proc.GetArgs(); + if (args.empty()) { + if (proc.argv().size() >= 2) { + GetCommandLine(proc, &cmdline); + } else { + auto path = base::FilePath::FromUTF8Unsafe("config.json"); + GetCommandLineFromConfig(path, &cmdline); + } + } else { + base::FilePath path(args[0]); + GetCommandLineFromConfig(path, &cmdline); + } + if (!ParseCommandLine(cmdline, ¶ms)) { + return EXIT_FAILURE; + } + CHECK(logging::InitLogging(params.log_settings)); + + base::SingleThreadTaskExecutor io_task_executor(base::MessagePumpType::IO); + base::ThreadPoolInstance::CreateAndStartWithDefaultParams("naive"); + + naive_partition_alloc_support::ReconfigureAfterTaskRunnerInit(); + + if (!params.ssl_key_path.empty()) { + net::SSLClientSocket::SetSSLKeyLogger( + std::make_unique(params.ssl_key_path)); + } + + // The declaration order for net_log and printing_log_observer is + // important. The destructor of PrintingLogObserver removes itself + // from net_log, so net_log must be available for entire lifetime of + // printing_log_observer. + net::NetLog* net_log = net::NetLog::Get(); + std::unique_ptr observer; + if (!params.net_log_path.empty()) { + observer = net::FileNetLogObserver::CreateUnbounded( + params.net_log_path, net::NetLogCaptureMode::kDefault, GetConstants()); + observer->StartObserving(net_log); + } + + // Avoids net log overhead if verbose logging is disabled. + std::unique_ptr printing_log_observer; + if (params.log_settings.logging_dest != logging::LOG_NONE && VLOG_IS_ON(1)) { + printing_log_observer = std::make_unique(); + net_log->AddObserver(printing_log_observer.get(), + net::NetLogCaptureMode::kDefault); + } + + auto cert_context = net::BuildCertURLRequestContext(net_log); + scoped_refptr cert_net_fetcher; + // The builtin verifier is supported but not enabled by default on Mac, + // falling back to CreateSystemVerifyProc() which drops the net fetcher, + // causing a DCHECK in ~CertNetFetcherURLRequest(). + // See CertVerifier::CreateDefaultWithoutCaching() and + // CertVerifyProc::CreateSystemVerifyProc() for the build flags. +#if BUILDFLAG(CHROME_ROOT_STORE_SUPPORTED) || BUILDFLAG(IS_FUCHSIA) || \ + BUILDFLAG(IS_LINUX) || BUILDFLAG(IS_CHROMEOS) || BUILDFLAG(IS_ANDROID) + cert_net_fetcher = base::MakeRefCounted(); + cert_net_fetcher->SetURLRequestContext(cert_context.get()); +#endif + auto context = + net::BuildURLRequestContext(params, std::move(cert_net_fetcher), net_log); + auto* session = context->http_transaction_factory()->GetSession(); + + auto listen_socket = + std::make_unique(net_log, net::NetLogSource()); + + int result = listen_socket->ListenWithAddressAndPort( + params.listen_addr, params.listen_port, kListenBackLog); + if (result != net::OK) { + LOG(ERROR) << "Failed to listen: " << result; + return EXIT_FAILURE; + } + LOG(INFO) << "Listening on " << params.listen_addr << ":" + << params.listen_port; + + std::unique_ptr resolver; + if (params.protocol == net::ClientProtocol::kRedir) { + auto resolver_socket = + std::make_unique(net_log, net::NetLogSource()); + resolver_socket->AllowAddressReuse(); + net::IPAddress listen_addr; + if (!listen_addr.AssignFromIPLiteral(params.listen_addr)) { + LOG(ERROR) << "Failed to open resolver: " << net::ERR_ADDRESS_INVALID; + return EXIT_FAILURE; + } + + result = resolver_socket->Listen( + net::IPEndPoint(listen_addr, params.listen_port)); + if (result != net::OK) { + LOG(ERROR) << "Failed to open resolver: " << result; + return EXIT_FAILURE; + } + + resolver = std::make_unique( + std::move(resolver_socket), params.resolver_range, + params.resolver_prefix); + } + + net::NaiveProxy naive_proxy(std::move(listen_socket), params.protocol, + params.listen_user, params.listen_pass, + params.concurrency, resolver.get(), session, + kTrafficAnnotation); + + base::RunLoop().Run(); + + return EXIT_SUCCESS; +} diff --git a/src/net/tools/naive/naive_proxy_delegate.cc b/src/net/tools/naive/naive_proxy_delegate.cc new file mode 100644 index 0000000000..d18bf65e88 --- /dev/null +++ b/src/net/tools/naive/naive_proxy_delegate.cc @@ -0,0 +1,159 @@ +// Copyright 2020 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#include "net/tools/naive/naive_proxy_delegate.h" + +#include + +#include "base/logging.h" +#include "base/rand_util.h" +#include "net/base/proxy_string_util.h" +#include "net/http/http_request_headers.h" +#include "net/http/http_response_headers.h" +#include "net/third_party/quiche/src/quiche/spdy/core/hpack/hpack_constants.h" + +namespace net { +namespace { +bool g_nonindex_codes_initialized; +uint8_t g_nonindex_codes[17]; +} // namespace + +void InitializeNonindexCodes() { + if (g_nonindex_codes_initialized) + return; + g_nonindex_codes_initialized = true; + unsigned i = 0; + for (const auto& symbol : spdy::HpackHuffmanCodeVector()) { + if (symbol.id >= 0x20 && symbol.id <= 0x7f && symbol.length >= 8) { + g_nonindex_codes[i++] = symbol.id; + if (i >= sizeof(g_nonindex_codes)) + break; + } + } + CHECK(i == sizeof(g_nonindex_codes)); +} + +void FillNonindexHeaderValue(uint64_t unique_bits, char* buf, int len) { + DCHECK(g_nonindex_codes_initialized); + int first = len < 16 ? len : 16; + for (int i = 0; i < first; i++) { + buf[i] = g_nonindex_codes[unique_bits & 0b1111]; + unique_bits >>= 4; + } + for (int i = first; i < len; i++) { + buf[i] = g_nonindex_codes[16]; + } +} + +NaiveProxyDelegate::NaiveProxyDelegate(const HttpRequestHeaders& extra_headers) + : extra_headers_(extra_headers) { + InitializeNonindexCodes(); +} + +NaiveProxyDelegate::~NaiveProxyDelegate() = default; + +void NaiveProxyDelegate::OnBeforeTunnelRequest( + const ProxyServer& proxy_server, + HttpRequestHeaders* extra_headers) { + if (proxy_server.is_direct() || proxy_server.is_socks()) + return; + + // Sends client-side padding header regardless of server support + std::string padding(base::RandInt(16, 32), '~'); + FillNonindexHeaderValue(base::RandUint64(), &padding[0], padding.size()); + extra_headers->SetHeader("padding", padding); + + // Enables Fast Open in H2/H3 proxy client socket once the state of server + // padding support is known. + if (padding_state_by_server_[proxy_server] != PaddingSupport::kUnknown) { + extra_headers->SetHeader("fastopen", "1"); + } + extra_headers->MergeFrom(extra_headers_); +} + +Error NaiveProxyDelegate::OnTunnelHeadersReceived( + const ProxyServer& proxy_server, + const HttpResponseHeaders& response_headers) { + if (proxy_server.is_direct() || proxy_server.is_socks()) + return OK; + + // Detects server padding support, even if it changes dynamically. + bool padding = response_headers.HasHeader("padding"); + auto new_state = + padding ? PaddingSupport::kCapable : PaddingSupport::kIncapable; + auto& padding_state = padding_state_by_server_[proxy_server]; + if (padding_state == PaddingSupport::kUnknown || padding_state != new_state) { + LOG(INFO) << "Padding capability of " << ProxyServerToProxyUri(proxy_server) + << (padding ? " detected" : " undetected"); + } + padding_state = new_state; + return OK; +} + +PaddingSupport NaiveProxyDelegate::GetProxyServerPaddingSupport( + const ProxyServer& proxy_server) { + // Not possible to detect padding capability given underlying protocol. + if (proxy_server.is_direct() || proxy_server.is_socks()) + return PaddingSupport::kIncapable; + + return padding_state_by_server_[proxy_server]; +} + +PaddingDetectorDelegate::PaddingDetectorDelegate( + NaiveProxyDelegate* naive_proxy_delegate, + const ProxyServer& proxy_server, + ClientProtocol client_protocol) + : naive_proxy_delegate_(naive_proxy_delegate), + proxy_server_(proxy_server), + client_protocol_(client_protocol), + detected_client_padding_support_(PaddingSupport::kUnknown), + cached_server_padding_support_(PaddingSupport::kUnknown) {} + +PaddingDetectorDelegate::~PaddingDetectorDelegate() = default; + +bool PaddingDetectorDelegate::IsPaddingSupportKnown() { + auto c = GetClientPaddingSupport(); + auto s = GetServerPaddingSupport(); + return c != PaddingSupport::kUnknown && s != PaddingSupport::kUnknown; +} + +Direction PaddingDetectorDelegate::GetPaddingDirection() { + auto c = GetClientPaddingSupport(); + auto s = GetServerPaddingSupport(); + // Padding support must be already detected at this point. + CHECK_NE(c, PaddingSupport::kUnknown); + CHECK_NE(s, PaddingSupport::kUnknown); + if (c == PaddingSupport::kCapable && s == PaddingSupport::kIncapable) { + return kServer; + } + if (c == PaddingSupport::kIncapable && s == PaddingSupport::kCapable) { + return kClient; + } + return kNone; +} + +void PaddingDetectorDelegate::SetClientPaddingSupport( + PaddingSupport padding_support) { + detected_client_padding_support_ = padding_support; +} + +PaddingSupport PaddingDetectorDelegate::GetClientPaddingSupport() { + // Not possible to detect padding capability given underlying protocol. + if (client_protocol_ == ClientProtocol::kSocks5) { + return PaddingSupport::kIncapable; + } else if (client_protocol_ == ClientProtocol::kRedir) { + return PaddingSupport::kIncapable; + } + + return detected_client_padding_support_; +} + +PaddingSupport PaddingDetectorDelegate::GetServerPaddingSupport() { + if (cached_server_padding_support_ != PaddingSupport::kUnknown) + return cached_server_padding_support_; + cached_server_padding_support_ = + naive_proxy_delegate_->GetProxyServerPaddingSupport(proxy_server_); + return cached_server_padding_support_; +} + +} // namespace net diff --git a/src/net/tools/naive/naive_proxy_delegate.h b/src/net/tools/naive/naive_proxy_delegate.h new file mode 100644 index 0000000000..43cdeeadc2 --- /dev/null +++ b/src/net/tools/naive/naive_proxy_delegate.h @@ -0,0 +1,94 @@ +// Copyright 2020 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#ifndef NET_TOOLS_NAIVE_NAIVE_PROXY_DELEGATE_H_ +#define NET_TOOLS_NAIVE_NAIVE_PROXY_DELEGATE_H_ + +#include +#include +#include + +#include "base/strings/string_piece.h" +#include "net/base/net_errors.h" +#include "net/base/proxy_delegate.h" +#include "net/base/proxy_server.h" +#include "net/proxy_resolution/proxy_retry_info.h" +#include "net/tools/naive/naive_protocol.h" +#include "url/gurl.h" + +namespace net { + +void InitializeNonindexCodes(); +// |unique_bits| SHOULD have relatively unique values. +void FillNonindexHeaderValue(uint64_t unique_bits, char* buf, int len); + +class ProxyInfo; +class HttpRequestHeaders; +class HttpResponseHeaders; + +enum class PaddingSupport { + kUnknown = 0, + kCapable, + kIncapable, +}; + +class NaiveProxyDelegate : public ProxyDelegate { + public: + explicit NaiveProxyDelegate(const HttpRequestHeaders& extra_headers); + ~NaiveProxyDelegate() override; + + void OnResolveProxy(const GURL& url, + const std::string& method, + const ProxyRetryInfoMap& proxy_retry_info, + ProxyInfo* result) override {} + void OnFallback(const ProxyServer& bad_proxy, int net_error) override {} + + // This only affects h2 proxy client socket. + void OnBeforeTunnelRequest(const ProxyServer& proxy_server, + HttpRequestHeaders* extra_headers) override; + + Error OnTunnelHeadersReceived( + const ProxyServer& proxy_server, + const HttpResponseHeaders& response_headers) override; + + PaddingSupport GetProxyServerPaddingSupport(const ProxyServer& proxy_server); + + private: + const HttpRequestHeaders& extra_headers_; + std::map padding_state_by_server_; +}; + +class ClientPaddingDetectorDelegate { + public: + virtual ~ClientPaddingDetectorDelegate() = default; + + virtual void SetClientPaddingSupport(PaddingSupport padding_support) = 0; +}; + +class PaddingDetectorDelegate : public ClientPaddingDetectorDelegate { + public: + PaddingDetectorDelegate(NaiveProxyDelegate* naive_proxy_delegate, + const ProxyServer& proxy_server, + ClientProtocol client_protocol); + ~PaddingDetectorDelegate() override; + + bool IsPaddingSupportKnown(); + Direction GetPaddingDirection(); + void SetClientPaddingSupport(PaddingSupport padding_support) override; + + private: + PaddingSupport GetClientPaddingSupport(); + PaddingSupport GetServerPaddingSupport(); + + NaiveProxyDelegate* naive_proxy_delegate_; + const ProxyServer& proxy_server_; + ClientProtocol client_protocol_; + + PaddingSupport detected_client_padding_support_; + // The result is only cached during one connection, so it's still dynamically + // updated in the following connections after server changes support. + PaddingSupport cached_server_padding_support_; +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_NAIVE_PROXY_DELEGATE_H_ diff --git a/src/net/tools/naive/partition_alloc_support.cc b/src/net/tools/naive/partition_alloc_support.cc new file mode 100644 index 0000000000..fb145c6787 --- /dev/null +++ b/src/net/tools/naive/partition_alloc_support.cc @@ -0,0 +1,173 @@ +// Copyright 2021 The Chromium Authors +// Copyright 2022 klzgrad . +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/partition_alloc_support.h" + +#include + +#include "base/allocator/allocator_check.h" +#include "base/allocator/buildflags.h" +#include "base/allocator/partition_alloc_features.h" +#include "base/allocator/partition_alloc_support.h" +#include "base/allocator/partition_allocator/partition_alloc_config.h" +#include "base/allocator/partition_allocator/shim/allocator_shim.h" +#include "base/allocator/partition_allocator/shim/allocator_shim_default_dispatch_to_partition_alloc.h" +#include "base/allocator/partition_allocator/thread_cache.h" +#include "base/feature_list.h" +#include "base/process/memory.h" +#include "base/time/time.h" +#include "build/build_config.h" + +#if BUILDFLAG(IS_ANDROID) +#include "base/system/sys_info.h" +#endif + +#if BUILDFLAG(USE_PARTITION_ALLOC_AS_MALLOC) +#include "base/allocator/partition_allocator/memory_reclaimer.h" +#include "base/task/single_thread_task_runner.h" +#endif + +#if BUILDFLAG(IS_APPLE) +#include "base/allocator/early_zone_registration_mac.h" +#endif + +namespace naive_partition_alloc_support { + +void ReconfigureEarly() { + // chrome/app/chrome_exe_main_mac.cc: main() +#if BUILDFLAG(IS_APPLE) + partition_alloc::EarlyMallocZoneRegistration(); +#endif + + // content/app/content_main.cc: ChromeMain() +#if BUILDFLAG(IS_WIN) +#if BUILDFLAG(USE_ALLOCATOR_SHIM) && BUILDFLAG(USE_PARTITION_ALLOC_AS_MALLOC) + // Call this early on in order to configure heap workarounds. This must be + // called from chrome.dll. This may be a NOP on some platforms. + allocator_shim::ConfigurePartitionAlloc(); +#endif +#endif // BUILDFLAG(IS_WIN) + + // content/app/content_main.cc: RunContentProcess() +#if BUILDFLAG(IS_APPLE) && BUILDFLAG(USE_ALLOCATOR_SHIM) + // The static initializer function for initializing PartitionAlloc + // InitializeDefaultMallocZoneWithPartitionAlloc() would be removed by the + // linker if allocator_shim.o is not referenced by the following call, + // resulting in undefined behavior of accessing uninitialized TLS + // data in PurgeCurrentThread() when PA is enabled. + allocator_shim::InitializeAllocatorShim(); +#endif + + // content/app/content_main.cc: RunContentProcess() + base::EnableTerminationOnOutOfMemory(); + + // content/app/content_main.cc: RunContentProcess() + base::EnableTerminationOnHeapCorruption(); + + // content/app/content_main.cc: RunContentProcess() + // content/app/content_main_runner_impl.cc: Initialize() + // ReconfigureEarlyish(): + // These initializations are only relevant for PartitionAlloc-Everywhere + // builds. +#if BUILDFLAG(USE_PARTITION_ALLOC_AS_MALLOC) + allocator_shim::EnablePartitionAllocMemoryReclaimer(); +#endif // BUILDFLAG(USE_PARTITION_ALLOC_AS_MALLOC) + + // content/app/content_main.cc: RunContentProcess() + // content/app/content_main_runner_impl.cc: Initialize() + // If we are on a platform where the default allocator is overridden (e.g. + // with PartitionAlloc on most platforms) smoke-tests that the overriding + // logic is working correctly. If not causes a hard crash, as its unexpected + // absence has security implications. + CHECK(base::allocator::IsAllocatorInitialized()); +} + +void ReconfigureAfterFeatureListInit() { + // TODO(bartekn): Switch to DCHECK once confirmed there are no issues. + CHECK(base::FeatureList::GetInstance()); + + // Does not use any of the security features yet. + [[maybe_unused]] bool enable_brp = false; + [[maybe_unused]] bool enable_brp_zapping = false; + [[maybe_unused]] bool split_main_partition = false; + [[maybe_unused]] bool use_dedicated_aligned_partition = false; + [[maybe_unused]] bool add_dummy_ref_count = false; + [[maybe_unused]] bool process_affected_by_brp_flag = false; + +#if BUILDFLAG(USE_PARTITION_ALLOC_AS_MALLOC) + allocator_shim::ConfigurePartitions( + allocator_shim::EnableBrp(enable_brp), + allocator_shim::EnableBrpZapping(enable_brp_zapping), + allocator_shim::SplitMainPartition(split_main_partition), + allocator_shim::UseDedicatedAlignedPartition( + use_dedicated_aligned_partition), + allocator_shim::AddDummyRefCount(add_dummy_ref_count), + allocator_shim::AlternateBucketDistribution( + base::features::kPartitionAllocAlternateBucketDistributionParam + .Get())); +#endif // BUILDFLAG(USE_PARTITION_ALLOC_AS_MALLOC) + +#if BUILDFLAG(USE_PARTITION_ALLOC_AS_MALLOC) + allocator_shim::internal::PartitionAllocMalloc::Allocator() + ->EnableThreadCacheIfSupported(); + + if (base::FeatureList::IsEnabled( + base::features::kPartitionAllocLargeEmptySlotSpanRing)) { + allocator_shim::internal::PartitionAllocMalloc::Allocator() + ->EnableLargeEmptySlotSpanRing(); + allocator_shim::internal::PartitionAllocMalloc::AlignedAllocator() + ->EnableLargeEmptySlotSpanRing(); + } +#endif // BUILDFLAG(USE_PARTITION_ALLOC_AS_MALLOC) +} + +void ReconfigureAfterTaskRunnerInit() { +#if defined(PA_THREAD_CACHE_SUPPORTED) && \ + BUILDFLAG(USE_PARTITION_ALLOC_AS_MALLOC) + base::allocator::StartThreadCachePeriodicPurge(); + +#if BUILDFLAG(IS_ANDROID) + // Lower thread cache limits to avoid stranding too much memory in the caches. + if (base::SysInfo::IsLowEndDevice()) { + ::partition_alloc::ThreadCacheRegistry::Instance().SetThreadCacheMultiplier( + ::partition_alloc::ThreadCache::kDefaultMultiplier / 2.); + } +#endif // BUILDFLAG(IS_ANDROID) + + // Renderer processes are more performance-sensitive, increase thread cache + // limits. + if (/*is_performance_sensitive=*/true && + base::FeatureList::IsEnabled( + base::features::kPartitionAllocLargeThreadCacheSize)) { + size_t largest_cached_size = + ::partition_alloc::ThreadCacheLimits::kLargeSizeThreshold; + +#if BUILDFLAG(IS_ANDROID) && defined(ARCH_CPU_32_BITS) + // Devices almost always report less physical memory than what they actually + // have, so anything above 3GiB will catch 4GiB and above. + if (base::SysInfo::AmountOfPhysicalMemoryMB() <= 3500) + largest_cached_size = + ::partition_alloc::ThreadCacheLimits::kDefaultSizeThreshold; +#endif // BUILDFLAG(IS_ANDROID) && !defined(ARCH_CPU_64_BITS) + + ::partition_alloc::ThreadCache::SetLargestCachedSize(largest_cached_size); + } + +#endif // defined(PA_THREAD_CACHE_SUPPORTED) && + // BUILDFLAG(USE_PARTITION_ALLOC_AS_MALLOC) + +#if BUILDFLAG(USE_PARTITION_ALLOC_AS_MALLOC) + base::allocator::StartMemoryReclaimer( + base::SingleThreadTaskRunner::GetCurrentDefault()); +#endif + + if (base::FeatureList::IsEnabled( + base::features::kPartitionAllocSortActiveSlotSpans)) { + partition_alloc::PartitionRoot< + partition_alloc::internal::ThreadSafe>::EnableSortActiveSlotSpans(); + } +} + +} // namespace naive_partition_alloc_support diff --git a/src/net/tools/naive/partition_alloc_support.h b/src/net/tools/naive/partition_alloc_support.h new file mode 100644 index 0000000000..515ea3bf84 --- /dev/null +++ b/src/net/tools/naive/partition_alloc_support.h @@ -0,0 +1,17 @@ +// Copyright 2021 The Chromium Authors +// Copyright 2022 klzgrad . +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_PARTITION_ALLOC_SUPPORT_H_ +#define NET_TOOLS_NAIVE_PARTITION_ALLOC_SUPPORT_H_ + +namespace naive_partition_alloc_support { + +void ReconfigureEarly(); +void ReconfigureAfterFeatureListInit(); +void ReconfigureAfterTaskRunnerInit(); + +} // namespace naive_partition_alloc_support + +#endif // NET_TOOLS_NAIVE_PARTITION_ALLOC_SUPPORT_H_ diff --git a/src/net/tools/naive/redirect_resolver.cc b/src/net/tools/naive/redirect_resolver.cc new file mode 100644 index 0000000000..e294d1d54c --- /dev/null +++ b/src/net/tools/naive/redirect_resolver.cc @@ -0,0 +1,243 @@ +// Copyright 2019 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/redirect_resolver.h" + +#include +#include +#include + +#include "base/logging.h" +#include "base/task/single_thread_task_runner.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/url_util.h" +#include "net/dns/dns_names_util.h" +#include "net/dns/dns_query.h" +#include "net/dns/dns_response.h" +#include "net/dns/dns_util.h" +#include "net/socket/datagram_server_socket.h" +#include "third_party/abseil-cpp/absl/types/optional.h" + +namespace { +constexpr int kUdpReadBufferSize = 1024; +constexpr int kResolutionTtl = 60; +constexpr int kResolutionRecycleTime = 60 * 5; + +std::string PackedIPv4ToString(uint32_t addr) { + return net::IPAddress(addr >> 24, addr >> 16, addr >> 8, addr).ToString(); +} +} // namespace + +namespace net { + +Resolution::Resolution() = default; + +Resolution::~Resolution() = default; + +RedirectResolver::RedirectResolver(std::unique_ptr socket, + const IPAddress& range, + size_t prefix) + : socket_(std::move(socket)), + range_(range), + prefix_(prefix), + offset_(0), + buffer_(base::MakeRefCounted(kUdpReadBufferSize)) { + DCHECK(socket_); + // Start accepting connections in next run loop in case when delegate is not + // ready to get callbacks. + base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( + FROM_HERE, base::BindOnce(&RedirectResolver::DoRead, + weak_ptr_factory_.GetWeakPtr())); +} + +RedirectResolver::~RedirectResolver() = default; + +void RedirectResolver::DoRead() { + for (;;) { + int rv = socket_->RecvFrom( + buffer_.get(), kUdpReadBufferSize, &recv_address_, + base::BindOnce(&RedirectResolver::OnRecv, base::Unretained(this))); + if (rv == ERR_IO_PENDING) + return; + rv = HandleReadResult(rv); + if (rv == ERR_IO_PENDING) + return; + if (rv < 0) { + LOG(INFO) << "DoRead: ignoring error " << rv; + } + } +} + +void RedirectResolver::OnRecv(int result) { + int rv; + rv = HandleReadResult(result); + if (rv == ERR_IO_PENDING) + return; + if (rv < 0) { + LOG(INFO) << "OnRecv: ignoring error " << result; + } + + DoRead(); +} + +void RedirectResolver::OnSend(int result) { + if (result < 0) { + LOG(INFO) << "OnSend: ignoring error " << result; + } + + DoRead(); +} + +int RedirectResolver::HandleReadResult(int result) { + if (result < 0) + return result; + + DnsQuery query(buffer_.get()); + if (!query.Parse(result)) { + LOG(INFO) << "Malformed DNS query from " << recv_address_.ToString(); + return ERR_INVALID_ARGUMENT; + } + + auto name_or = dns_names_util::NetworkToDottedName(query.qname()); + DnsResponse response; + absl::optional query_opt; + query_opt.emplace(query.id(), query.qname(), query.qtype()); + if (!name_or || !IsCanonicalizedHostCompliant(name_or.value())) { + response = + DnsResponse(query.id(), /*is_authoritative=*/false, /*answers=*/{}, + /*authority_records=*/{}, /*additional_records=*/{}, + query_opt, dns_protocol::kRcodeFORMERR); + } else if (query.qtype() != dns_protocol::kTypeA) { + response = + DnsResponse(query.id(), /*is_authoritative=*/false, /*answers=*/{}, + /*authority_records=*/{}, /*additional_records=*/{}, + query_opt, dns_protocol::kRcodeNOTIMP); + } else { + Resolution res; + + const auto& name = name_or.value(); + + auto by_name_lookup = resolution_by_name_.emplace(name, resolutions_.end()); + auto by_name = by_name_lookup.first; + bool has_name = !by_name_lookup.second; + if (has_name) { + auto res_it = by_name->second; + auto by_addr = res_it->by_addr; + uint32_t addr = res_it->addr; + + resolutions_.erase(res_it); + resolutions_.emplace_back(); + res_it = std::prev(resolutions_.end()); + + by_name->second = res_it; + by_addr->second = res_it; + res_it->addr = addr; + res_it->name = name; + res_it->time = base::TimeTicks::Now(); + res_it->by_name = by_name; + res_it->by_addr = by_addr; + } else { + uint32_t addr = (range_.bytes()[0] << 24) | (range_.bytes()[1] << 16) | + (range_.bytes()[2] << 8) | range_.bytes()[3]; + uint32_t subnet = ~0U >> prefix_; + addr &= ~subnet; + addr += offset_; + offset_ = (offset_ + 1) & subnet; + + auto by_addr_lookup = + resolution_by_addr_.emplace(addr, resolutions_.end()); + auto by_addr = by_addr_lookup.first; + bool has_addr = !by_addr_lookup.second; + if (has_addr) { + // Too few available addresses. Overwrites old one. + auto res_it = by_addr->second; + + LOG(INFO) << "Overwrite " << res_it->name << " " + << PackedIPv4ToString(res_it->addr) << " with " << name << " " + << PackedIPv4ToString(addr); + resolution_by_name_.erase(res_it->by_name); + resolutions_.erase(res_it); + resolutions_.emplace_back(); + res_it = std::prev(resolutions_.end()); + + by_name->second = res_it; + by_addr->second = res_it; + res_it->addr = addr; + res_it->name = name; + res_it->time = base::TimeTicks::Now(); + res_it->by_name = by_name; + res_it->by_addr = by_addr; + } else { + LOG(INFO) << "Add " << name << " " << PackedIPv4ToString(addr); + resolutions_.emplace_back(); + auto res_it = std::prev(resolutions_.end()); + + by_name->second = res_it; + by_addr->second = res_it; + res_it->addr = addr; + res_it->name = name; + res_it->time = base::TimeTicks::Now(); + res_it->by_name = by_name; + res_it->by_addr = by_addr; + + // Collects garbage. + auto now = base::TimeTicks::Now(); + for (auto it = resolutions_.begin(); + it != resolutions_.end() && + (now - it->time).InSeconds() > kResolutionRecycleTime;) { + auto next = std::next(it); + LOG(INFO) << "Drop " << it->name << " " + << PackedIPv4ToString(it->addr); + resolution_by_name_.erase(it->by_name); + resolution_by_addr_.erase(it->by_addr); + resolutions_.erase(it); + it = next; + } + } + } + + DnsResourceRecord record; + record.name = name; + record.type = dns_protocol::kTypeA; + record.klass = dns_protocol::kClassIN; + record.ttl = kResolutionTtl; + uint32_t addr = by_name->second->addr; + record.SetOwnedRdata(IPAddressToPackedString( + IPAddress(addr >> 24, addr >> 16, addr >> 8, addr))); + response = DnsResponse(query.id(), /*is_authoritative=*/false, + /*answers=*/{std::move(record)}, + /*authority_records=*/{}, /*additional_records=*/{}, + query_opt); + } + int size = response.io_buffer_size(); + if (size > buffer_->size() || !response.io_buffer()) { + return ERR_NO_BUFFER_SPACE; + } + std::memcpy(buffer_->data(), response.io_buffer()->data(), size); + + return socket_->SendTo( + buffer_.get(), size, recv_address_, + base::BindOnce(&RedirectResolver::OnSend, base::Unretained(this))); +} + +bool RedirectResolver::IsInResolvedRange(const IPAddress& address) const { + if (!address.IsIPv4()) + return false; + return IPAddressMatchesPrefix(address, range_, prefix_); +} + +std::string RedirectResolver::FindNameByAddress( + const IPAddress& address) const { + if (!address.IsIPv4()) + return {}; + uint32_t addr = (address.bytes()[0] << 24) | (address.bytes()[1] << 16) | + (address.bytes()[2] << 8) | address.bytes()[3]; + auto by_addr = resolution_by_addr_.find(addr); + if (by_addr == resolution_by_addr_.end()) + return {}; + return by_addr->second->name; +} + +} // namespace net diff --git a/src/net/tools/naive/redirect_resolver.h b/src/net/tools/naive/redirect_resolver.h new file mode 100644 index 0000000000..3b78c4111e --- /dev/null +++ b/src/net/tools/naive/redirect_resolver.h @@ -0,0 +1,69 @@ +// Copyright 2019 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_REDIRECT_RESOLVER_H_ +#define NET_TOOLS_NAIVE_REDIRECT_RESOLVER_H_ + +#include +#include +#include +#include +#include + +#include "base/memory/ref_counted.h" +#include "base/memory/weak_ptr.h" +#include "base/time/time.h" +#include "net/base/ip_address.h" +#include "net/base/ip_endpoint.h" + +namespace net { + +class DatagramServerSocket; +class IOBufferWithSize; + +struct Resolution { + Resolution(); + ~Resolution(); + + uint32_t addr; + std::string name; + base::TimeTicks time; + std::map::iterator>::iterator by_name; + std::map::iterator>::iterator by_addr; +}; + +class RedirectResolver { + public: + RedirectResolver(std::unique_ptr socket, + const IPAddress& range, + size_t prefix); + ~RedirectResolver(); + RedirectResolver(const RedirectResolver&) = delete; + RedirectResolver& operator=(const RedirectResolver&) = delete; + + bool IsInResolvedRange(const IPAddress& address) const; + std::string FindNameByAddress(const IPAddress& address) const; + + private: + void DoRead(); + void OnRecv(int result); + void OnSend(int result); + int HandleReadResult(int result); + + std::unique_ptr socket_; + IPAddress range_; + size_t prefix_; + uint32_t offset_; + scoped_refptr buffer_; + IPEndPoint recv_address_; + + std::map::iterator> resolution_by_name_; + std::map::iterator> resolution_by_addr_; + std::list resolutions_; + + base::WeakPtrFactory weak_ptr_factory_{this}; +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_REDIRECT_RESOLVER_H_ diff --git a/src/net/tools/naive/socks5_server_socket.cc b/src/net/tools/naive/socks5_server_socket.cc new file mode 100644 index 0000000000..02267fc63c --- /dev/null +++ b/src/net/tools/naive/socks5_server_socket.cc @@ -0,0 +1,669 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/socks5_server_socket.h" + +#include +#include + +#include "base/cxx17_backports.h" +#include "base/functional/bind.h" +#include "base/functional/callback_helpers.h" +#include "base/logging.h" +#include "base/sys_byteorder.h" +#include "net/base/ip_address.h" +#include "net/base/net_errors.h" +#include "net/base/sys_addrinfo.h" +#include "net/log/net_log.h" +#include "net/log/net_log_event_type.h" + +namespace net { + +enum SocksCommandType { + kCommandConnect = 0x01, + kCommandBind = 0x02, + kCommandUDPAssociate = 0x03, +}; + +static constexpr unsigned int kGreetReadHeaderSize = 2; +static constexpr unsigned int kAuthReadHeaderSize = 2; +static constexpr unsigned int kReadHeaderSize = 5; +static constexpr char kSOCKS5Version = '\x05'; +static constexpr char kSOCKS5Reserved = '\x00'; +static constexpr char kAuthMethodNone = '\x00'; +static constexpr char kAuthMethodUserPass = '\x02'; +static constexpr char kAuthMethodNoAcceptable = '\xff'; +static constexpr char kSubnegotiationVersion = '\x01'; +static constexpr char kAuthStatusSuccess = '\x00'; +static constexpr char kAuthStatusFailure = '\xff'; +static constexpr char kReplySuccess = '\x00'; +static constexpr char kReplyCommandNotSupported = '\x07'; + +static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4"); +static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6"); + +Socks5ServerSocket::Socks5ServerSocket( + std::unique_ptr transport_socket, + const std::string& user, + const std::string& pass, + const NetworkTrafficAnnotationTag& traffic_annotation) + : io_callback_(base::BindRepeating(&Socks5ServerSocket::OnIOComplete, + base::Unretained(this))), + transport_(std::move(transport_socket)), + next_state_(STATE_NONE), + completed_handshake_(false), + bytes_sent_(0), + was_ever_used_(false), + user_(user), + pass_(pass), + net_log_(transport_->NetLog()), + traffic_annotation_(traffic_annotation) {} + +Socks5ServerSocket::~Socks5ServerSocket() { + Disconnect(); +} + +const HostPortPair& Socks5ServerSocket::request_endpoint() const { + return request_endpoint_; +} + +int Socks5ServerSocket::Connect(CompletionOnceCallback callback) { + DCHECK(transport_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + + // If already connected, then just return OK. + if (completed_handshake_) + return OK; + + net_log_.BeginEvent(NetLogEventType::SOCKS5_CONNECT); + + next_state_ = STATE_GREET_READ; + buffer_.clear(); + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + user_callback_ = std::move(callback); + } else { + net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_CONNECT, rv); + } + return rv; +} + +void Socks5ServerSocket::Disconnect() { + completed_handshake_ = false; + transport_->Disconnect(); + + // Reset other states to make sure they aren't mistakenly used later. + // These are the states initialized by Connect(). + next_state_ = STATE_NONE; + user_callback_.Reset(); +} + +bool Socks5ServerSocket::IsConnected() const { + return completed_handshake_ && transport_->IsConnected(); +} + +bool Socks5ServerSocket::IsConnectedAndIdle() const { + return completed_handshake_ && transport_->IsConnectedAndIdle(); +} + +const NetLogWithSource& Socks5ServerSocket::NetLog() const { + return net_log_; +} + +bool Socks5ServerSocket::WasEverUsed() const { + return was_ever_used_; +} + +bool Socks5ServerSocket::WasAlpnNegotiated() const { + if (transport_) { + return transport_->WasAlpnNegotiated(); + } + NOTREACHED(); + return false; +} + +NextProto Socks5ServerSocket::GetNegotiatedProtocol() const { + if (transport_) { + return transport_->GetNegotiatedProtocol(); + } + NOTREACHED(); + return kProtoUnknown; +} + +bool Socks5ServerSocket::GetSSLInfo(SSLInfo* ssl_info) { + if (transport_) { + return transport_->GetSSLInfo(ssl_info); + } + NOTREACHED(); + return false; +} + +int64_t Socks5ServerSocket::GetTotalReceivedBytes() const { + return transport_->GetTotalReceivedBytes(); +} + +void Socks5ServerSocket::ApplySocketTag(const SocketTag& tag) { + return transport_->ApplySocketTag(tag); +} + +// Read is called by the transport layer above to read. This can only be done +// if the SOCKS handshake is complete. +int Socks5ServerSocket::Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + DCHECK(callback); + + int rv = transport_->Read( + buf, buf_len, + base::BindOnce(&Socks5ServerSocket::OnReadWriteComplete, + base::Unretained(this), std::move(callback))); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +// Write is called by the transport layer. This can only be done if the +// SOCKS handshake is complete. +int Socks5ServerSocket::Write( + IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + DCHECK(callback); + + int rv = transport_->Write( + buf, buf_len, + base::BindOnce(&Socks5ServerSocket::OnReadWriteComplete, + base::Unretained(this), std::move(callback)), + traffic_annotation); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +int Socks5ServerSocket::SetReceiveBufferSize(int32_t size) { + return transport_->SetReceiveBufferSize(size); +} + +int Socks5ServerSocket::SetSendBufferSize(int32_t size) { + return transport_->SetSendBufferSize(size); +} + +void Socks5ServerSocket::DoCallback(int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(user_callback_); + + // Since Run() may result in Read being called, + // clear user_callback_ up front. + std::move(user_callback_).Run(result); +} + +void Socks5ServerSocket::OnIOComplete(int result) { + DCHECK_NE(STATE_NONE, next_state_); + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) { + net_log_.EndEvent(NetLogEventType::SOCKS5_CONNECT); + DoCallback(rv); + } +} + +void Socks5ServerSocket::OnReadWriteComplete(CompletionOnceCallback callback, + int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(callback); + + if (result > 0) + was_ever_used_ = true; + std::move(callback).Run(result); +} + +int Socks5ServerSocket::DoLoop(int last_io_result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_GREET_READ: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_READ); + rv = DoGreetRead(); + break; + case STATE_GREET_READ_COMPLETE: + rv = DoGreetReadComplete(rv); + net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_READ, + rv); + break; + case STATE_GREET_WRITE: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_WRITE); + rv = DoGreetWrite(); + break; + case STATE_GREET_WRITE_COMPLETE: + rv = DoGreetWriteComplete(rv); + net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_WRITE, + rv); + break; + case STATE_AUTH_READ: + DCHECK_EQ(OK, rv); + rv = DoAuthRead(); + break; + case STATE_AUTH_READ_COMPLETE: + rv = DoAuthReadComplete(rv); + break; + case STATE_AUTH_WRITE: + DCHECK_EQ(OK, rv); + rv = DoAuthWrite(); + break; + case STATE_AUTH_WRITE_COMPLETE: + rv = DoAuthWriteComplete(rv); + break; + case STATE_HANDSHAKE_READ: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_READ); + rv = DoHandshakeRead(); + break; + case STATE_HANDSHAKE_READ_COMPLETE: + rv = DoHandshakeReadComplete(rv); + net_log_.EndEventWithNetErrorCode( + NetLogEventType::SOCKS5_HANDSHAKE_READ, rv); + break; + case STATE_HANDSHAKE_WRITE: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_WRITE); + rv = DoHandshakeWrite(); + break; + case STATE_HANDSHAKE_WRITE_COMPLETE: + rv = DoHandshakeWriteComplete(rv); + net_log_.EndEventWithNetErrorCode( + NetLogEventType::SOCKS5_HANDSHAKE_WRITE, rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +int Socks5ServerSocket::DoGreetRead() { + next_state_ = STATE_GREET_READ_COMPLETE; + + if (buffer_.empty()) { + read_header_size_ = kGreetReadHeaderSize; + } + + int handshake_buf_len = read_header_size_ - buffer_.size(); + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = base::MakeRefCounted(handshake_buf_len); + return transport_->Read(handshake_buf_.get(), handshake_buf_len, + io_callback_); +} + +int Socks5ServerSocket::DoGreetReadComplete(int result) { + if (result < 0) + return result; + + if (result == 0) { + net_log_.AddEvent( + NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING); + return ERR_SOCKS_CONNECTION_FAILED; + } + + buffer_.append(handshake_buf_->data(), result); + + // When the first few bytes are read, check how many more are required + // and accordingly increase them + if (buffer_.size() == kGreetReadHeaderSize) { + if (buffer_[0] != kSOCKS5Version) { + net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION, + "version", buffer_[0]); + return ERR_SOCKS_CONNECTION_FAILED; + } + int nmethods = buffer_[1]; + if (nmethods == 0) { + net_log_.AddEvent(NetLogEventType::SOCKS_NO_REQUESTED_AUTH); + return ERR_SOCKS_CONNECTION_FAILED; + } + + read_header_size_ += nmethods; + next_state_ = STATE_GREET_READ; + return OK; + } + + if (buffer_.size() == read_header_size_) { + int nmethods = buffer_[1]; + char expected_method = kAuthMethodNone; + if (!user_.empty() || !pass_.empty()) { + expected_method = kAuthMethodUserPass; + } + void* match = + std::memchr(&buffer_[kGreetReadHeaderSize], expected_method, nmethods); + if (match) { + auth_method_ = expected_method; + } else { + auth_method_ = kAuthMethodNoAcceptable; + } + buffer_.clear(); + next_state_ = STATE_GREET_WRITE; + return OK; + } + + next_state_ = STATE_GREET_READ; + return OK; +} + +int Socks5ServerSocket::DoGreetWrite() { + if (buffer_.empty()) { + const char write_data[] = {kSOCKS5Version, auth_method_}; + buffer_ = std::string(write_data, std::size(write_data)); + bytes_sent_ = 0; + } + + next_state_ = STATE_GREET_WRITE_COMPLETE; + int handshake_buf_len = buffer_.size() - bytes_sent_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = base::MakeRefCounted(handshake_buf_len); + std::memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_], + handshake_buf_len); + return transport_->Write(handshake_buf_.get(), handshake_buf_len, + io_callback_, traffic_annotation_); +} + +int Socks5ServerSocket::DoGreetWriteComplete(int result) { + if (result < 0) + return result; + + bytes_sent_ += result; + if (bytes_sent_ == buffer_.size()) { + buffer_.clear(); + if (auth_method_ == kAuthMethodNone) { + next_state_ = STATE_HANDSHAKE_READ; + } else if (auth_method_ == kAuthMethodUserPass) { + next_state_ = STATE_AUTH_READ; + } else { + net_log_.AddEvent(NetLogEventType::SOCKS_NO_ACCEPTABLE_AUTH); + return ERR_SOCKS_CONNECTION_FAILED; + } + } else { + next_state_ = STATE_GREET_WRITE; + } + return OK; +} + +int Socks5ServerSocket::DoAuthRead() { + next_state_ = STATE_AUTH_READ_COMPLETE; + + if (buffer_.empty()) { + read_header_size_ = kAuthReadHeaderSize; + } + + int handshake_buf_len = read_header_size_ - buffer_.size(); + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = base::MakeRefCounted(handshake_buf_len); + return transport_->Read(handshake_buf_.get(), handshake_buf_len, + io_callback_); +} + +int Socks5ServerSocket::DoAuthReadComplete(int result) { + if (result < 0) + return result; + + if (result == 0) { + return ERR_SOCKS_CONNECTION_FAILED; + } + + buffer_.append(handshake_buf_->data(), result); + + // When the first few bytes are read, check how many more are required + // and accordingly increase them + if (buffer_.size() == kAuthReadHeaderSize) { + if (buffer_[0] != kSubnegotiationVersion) { + net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION, + "version", buffer_[0]); + return ERR_SOCKS_CONNECTION_FAILED; + } + int username_len = buffer_[1]; + read_header_size_ += username_len + 1; + next_state_ = STATE_AUTH_READ; + return OK; + } + + if (buffer_.size() == read_header_size_) { + int username_len = buffer_[1]; + int password_len = buffer_[kAuthReadHeaderSize + username_len]; + size_t password_offset = kAuthReadHeaderSize + username_len + 1; + if (buffer_.size() == password_offset && password_len != 0) { + read_header_size_ += password_len; + next_state_ = STATE_AUTH_READ; + return OK; + } + + if (buffer_.compare(kAuthReadHeaderSize, username_len, user_) == 0 && + buffer_.compare(password_offset, password_len, pass_) == 0) { + auth_status_ = kAuthStatusSuccess; + } else { + auth_status_ = kAuthStatusFailure; + } + buffer_.clear(); + next_state_ = STATE_AUTH_WRITE; + return OK; + } + + next_state_ = STATE_AUTH_READ; + return OK; +} + +int Socks5ServerSocket::DoAuthWrite() { + if (buffer_.empty()) { + const char write_data[] = {kSubnegotiationVersion, auth_status_}; + buffer_ = std::string(write_data, std::size(write_data)); + bytes_sent_ = 0; + } + + next_state_ = STATE_AUTH_WRITE_COMPLETE; + int handshake_buf_len = buffer_.size() - bytes_sent_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = base::MakeRefCounted(handshake_buf_len); + std::memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_], + handshake_buf_len); + return transport_->Write(handshake_buf_.get(), handshake_buf_len, + io_callback_, traffic_annotation_); +} + +int Socks5ServerSocket::DoAuthWriteComplete(int result) { + if (result < 0) + return result; + + bytes_sent_ += result; + if (bytes_sent_ == buffer_.size()) { + buffer_.clear(); + if (auth_status_ == kAuthStatusSuccess) { + next_state_ = STATE_HANDSHAKE_READ; + } else { + return ERR_SOCKS_CONNECTION_FAILED; + } + } else { + next_state_ = STATE_AUTH_WRITE; + } + return OK; +} + +int Socks5ServerSocket::DoHandshakeRead() { + next_state_ = STATE_HANDSHAKE_READ_COMPLETE; + + if (buffer_.empty()) { + read_header_size_ = kReadHeaderSize; + } + + int handshake_buf_len = read_header_size_ - buffer_.size(); + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = base::MakeRefCounted(handshake_buf_len); + return transport_->Read(handshake_buf_.get(), handshake_buf_len, + io_callback_); +} + +int Socks5ServerSocket::DoHandshakeReadComplete(int result) { + if (result < 0) + return result; + + // The underlying socket closed unexpectedly. + if (result == 0) { + net_log_.AddEvent( + NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE); + return ERR_SOCKS_CONNECTION_FAILED; + } + + buffer_.append(handshake_buf_->data(), result); + + // When the first few bytes are read, check how many more are required + // and accordingly increase them + if (buffer_.size() == kReadHeaderSize) { + if (buffer_[0] != kSOCKS5Version || buffer_[2] != kSOCKS5Reserved) { + net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION, + "version", buffer_[0]); + return ERR_SOCKS_CONNECTION_FAILED; + } + SocksCommandType command = static_cast(buffer_[1]); + if (command == kCommandConnect) { + // The proxy replies with success immediately without first connecting + // to the requested endpoint. + reply_ = kReplySuccess; + } else if (command == kCommandBind || command == kCommandUDPAssociate) { + reply_ = kReplyCommandNotSupported; + } else { + net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_COMMAND, + "commmand", buffer_[1]); + return ERR_SOCKS_CONNECTION_FAILED; + } + + // We check the type of IP/Domain the server returns and accordingly + // increase the size of the request. For domains, we need to read the + // size of the domain, so the initial request size is upto the domain + // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is + // read, we substract 1 byte from the additional request size. + address_type_ = static_cast(buffer_[3]); + if (address_type_ == kEndPointDomain) { + address_size_ = static_cast(buffer_[4]); + if (address_size_ == 0) { + net_log_.AddEvent(NetLogEventType::SOCKS_ZERO_LENGTH_DOMAIN); + return ERR_SOCKS_CONNECTION_FAILED; + } + } else if (address_type_ == kEndPointResolvedIPv4) { + address_size_ = sizeof(struct in_addr); + --read_header_size_; + } else if (address_type_ == kEndPointResolvedIPv6) { + address_size_ = sizeof(struct in6_addr); + --read_header_size_; + } else { + // Aborts connection on unspecified address type. + net_log_.AddEventWithIntParams( + NetLogEventType::SOCKS_UNKNOWN_ADDRESS_TYPE, "address_type", + buffer_[3]); + return ERR_SOCKS_CONNECTION_FAILED; + } + + read_header_size_ += address_size_ + sizeof(uint16_t); + next_state_ = STATE_HANDSHAKE_READ; + return OK; + } + + // When the final bytes are read, setup handshake. + if (buffer_.size() == read_header_size_) { + size_t port_start = read_header_size_ - sizeof(uint16_t); + uint16_t port_net; + std::memcpy(&port_net, &buffer_[port_start], sizeof(uint16_t)); + uint16_t port_host = base::NetToHost16(port_net); + + size_t address_start = port_start - address_size_; + if (address_type_ == kEndPointDomain) { + std::string domain(&buffer_[address_start], address_size_); + request_endpoint_ = HostPortPair(domain, port_host); + } else { + IPAddress ip_addr( + reinterpret_cast(&buffer_[address_start]), + address_size_); + IPEndPoint endpoint(ip_addr, port_host); + request_endpoint_ = HostPortPair::FromIPEndPoint(endpoint); + } + buffer_.clear(); + next_state_ = STATE_HANDSHAKE_WRITE; + return OK; + } + + next_state_ = STATE_HANDSHAKE_READ; + return OK; +} + +// Writes the SOCKS handshake data to the underlying socket connection. +int Socks5ServerSocket::DoHandshakeWrite() { + next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; + + if (buffer_.empty()) { + const char write_data[] = { + // clang-format off + kSOCKS5Version, + reply_, + kSOCKS5Reserved, + kEndPointResolvedIPv4, + 0x00, 0x00, 0x00, 0x00, // BND.ADDR + 0x00, 0x00, // BND.PORT + // clang-format on + }; + buffer_ = std::string(write_data, std::size(write_data)); + bytes_sent_ = 0; + } + + int handshake_buf_len = buffer_.size() - bytes_sent_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = base::MakeRefCounted(handshake_buf_len); + std::memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], handshake_buf_len); + return transport_->Write(handshake_buf_.get(), handshake_buf_len, + io_callback_, traffic_annotation_); +} + +int Socks5ServerSocket::DoHandshakeWriteComplete(int result) { + if (result < 0) + return result; + + // We ignore the case when result is 0, since the underlying Write + // may return spurious writes while waiting on the socket. + + bytes_sent_ += result; + if (bytes_sent_ == buffer_.size()) { + buffer_.clear(); + if (reply_ == kReplySuccess) { + completed_handshake_ = true; + next_state_ = STATE_NONE; + } else { + net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_SERVER_ERROR, + "error_code", reply_); + return ERR_SOCKS_CONNECTION_FAILED; + } + } else { + next_state_ = STATE_HANDSHAKE_WRITE; + } + + return OK; +} + +int Socks5ServerSocket::GetPeerAddress(IPEndPoint* address) const { + return transport_->GetPeerAddress(address); +} + +int Socks5ServerSocket::GetLocalAddress(IPEndPoint* address) const { + return transport_->GetLocalAddress(address); +} + +} // namespace net diff --git a/src/net/tools/naive/socks5_server_socket.h b/src/net/tools/naive/socks5_server_socket.h new file mode 100644 index 0000000000..41d1fceb0a --- /dev/null +++ b/src/net/tools/naive/socks5_server_socket.h @@ -0,0 +1,166 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_ +#define NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_ + +#include +#include +#include +#include + +#include "base/memory/scoped_refptr.h" +#include "net/base/completion_once_callback.h" +#include "net/base/completion_repeating_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/log/net_log_with_source.h" +#include "net/socket/connection_attempts.h" +#include "net/socket/next_proto.h" +#include "net/socket/stream_socket.h" +#include "net/ssl/ssl_info.h" + +namespace net { +struct NetworkTrafficAnnotationTag; + +// This StreamSocket is used to setup a SOCKSv5 handshake with a socks client. +// Currently no SOCKSv5 authentication is supported. +class Socks5ServerSocket : public StreamSocket { + public: + Socks5ServerSocket(std::unique_ptr transport_socket, + const std::string& user, + const std::string& pass, + const NetworkTrafficAnnotationTag& traffic_annotation); + + // On destruction Disconnect() is called. + ~Socks5ServerSocket() override; + + Socks5ServerSocket(const Socks5ServerSocket&) = delete; + Socks5ServerSocket& operator=(const Socks5ServerSocket&) = delete; + + const HostPortPair& request_endpoint() const; + + // StreamSocket implementation. + + // Does the SOCKS handshake and completes the protocol. + int Connect(CompletionOnceCallback callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + const NetLogWithSource& NetLog() const override; + bool WasEverUsed() const override; + bool WasAlpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; + int64_t GetTotalReceivedBytes() const override; + void ApplySocketTag(const SocketTag& tag) override; + + // Socket implementation. + int Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) override; + int Write(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) override; + + int SetReceiveBufferSize(int32_t size) override; + int SetSendBufferSize(int32_t size) override; + + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + + private: + enum State { + STATE_GREET_READ, + STATE_GREET_READ_COMPLETE, + STATE_GREET_WRITE, + STATE_GREET_WRITE_COMPLETE, + STATE_AUTH_READ, + STATE_AUTH_READ_COMPLETE, + STATE_AUTH_WRITE, + STATE_AUTH_WRITE_COMPLETE, + STATE_HANDSHAKE_WRITE, + STATE_HANDSHAKE_WRITE_COMPLETE, + STATE_HANDSHAKE_READ, + STATE_HANDSHAKE_READ_COMPLETE, + STATE_NONE, + }; + + // Addressing type that can be specified in requests or responses. + enum SocksEndPointAddressType { + kEndPointDomain = 0x03, + kEndPointResolvedIPv4 = 0x01, + kEndPointResolvedIPv6 = 0x04, + }; + + void DoCallback(int result); + void OnIOComplete(int result); + void OnReadWriteComplete(CompletionOnceCallback callback, int result); + + int DoLoop(int last_io_result); + int DoGreetRead(); + int DoGreetReadComplete(int result); + int DoGreetWrite(); + int DoGreetWriteComplete(int result); + int DoAuthRead(); + int DoAuthReadComplete(int result); + int DoAuthWrite(); + int DoAuthWriteComplete(int result); + int DoHandshakeRead(); + int DoHandshakeReadComplete(int result); + int DoHandshakeWrite(); + int DoHandshakeWriteComplete(int result); + + CompletionRepeatingCallback io_callback_; + + // Stores the underlying socket. + std::unique_ptr transport_; + + State next_state_; + + // Stores the callback to the layer above, called on completing Connect(). + CompletionOnceCallback user_callback_; + + // This IOBuffer is used by the class to read and write + // SOCKS handshake data. The length contains the expected size to + // read or write. + scoped_refptr handshake_buf_; + + // While writing, this buffer stores the complete write handshake data. + // While reading, it stores the handshake information received so far. + std::string buffer_; + + // This becomes true when the SOCKS handshake has completed and the + // overlying connection is free to communicate. + bool completed_handshake_; + + // Contains the bytes sent by the SOCKS handshake. + size_t bytes_sent_; + + size_t read_header_size_; + + bool was_ever_used_; + + SocksEndPointAddressType address_type_; + int address_size_; + + std::string user_; + std::string pass_; + char auth_method_; + char auth_status_; + char reply_; + + HostPortPair request_endpoint_; + + NetLogWithSource net_log_; + + // Traffic annotation for socket control. + const NetworkTrafficAnnotationTag& traffic_annotation_; +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_