From 7a115772effa573374f8533322517b8005ca751a Mon Sep 17 00:00:00 2001 From: klzgrad Date: Sat, 8 Dec 2018 00:51:40 -0500 Subject: [PATCH] Add server implementation and tunnel padding --- net/tools/naive/http_proxy_socket.cc | 327 ++++++++++++++++++++++++ net/tools/naive/http_proxy_socket.h | 123 +++++++++ net/tools/naive/naive_connection.cc | 127 ++++++++- net/tools/naive/naive_connection.h | 27 +- net/tools/naive/naive_proxy.cc | 30 ++- net/tools/naive/naive_proxy_bin.cc | 31 ++- net/tools/naive/socks5_server_socket.cc | 10 +- net/tools/naive/socks5_server_socket.h | 3 +- 8 files changed, 639 insertions(+), 39 deletions(-) create mode 100644 net/tools/naive/http_proxy_socket.cc create mode 100644 net/tools/naive/http_proxy_socket.h diff --git a/net/tools/naive/http_proxy_socket.cc b/net/tools/naive/http_proxy_socket.cc new file mode 100644 index 0000000000..fd6e6a055d --- /dev/null +++ b/net/tools/naive/http_proxy_socket.cc @@ -0,0 +1,327 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/http_proxy_socket.h" + +#include +#include + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/callback_helpers.h" +#include "base/logging.h" +#include "base/rand_util.h" +#include "base/sys_byteorder.h" +#include "net/base/ip_address.h" +#include "net/base/net_errors.h" +#include "net/log/net_log.h" + +namespace net { + +namespace { +constexpr int kBufferSize = 64 * 1024; +constexpr size_t kMaxHeaderSize = 64 * 1024; +constexpr char kResponseHeader[] = "HTTP/1.1 200 OK\r\nPadding: "; +constexpr int kResponseHeaderSize = sizeof(kResponseHeader) - 1; +// A plain 200 is 10 bytes. Expected 48 bytes. "Padding" uses up 7 bytes. +constexpr int kMinPaddingSize = 30; +constexpr int kMaxPaddingSize = kMinPaddingSize + 32; +} // namespace + +HttpProxySocket::HttpProxySocket( + std::unique_ptr transport_socket, + const NetworkTrafficAnnotationTag& traffic_annotation) + : io_callback_(base::BindRepeating(&HttpProxySocket::OnIOComplete, + base::Unretained(this))), + transport_(std::move(transport_socket)), + next_state_(STATE_NONE), + completed_handshake_(false), + was_ever_used_(false), + header_write_size_(-1), + net_log_(transport_->NetLog()), + traffic_annotation_(traffic_annotation) {} + +HttpProxySocket::~HttpProxySocket() { + Disconnect(); +} + +const HostPortPair& HttpProxySocket::request_endpoint() const { + return request_endpoint_; +} + +int HttpProxySocket::Connect(CompletionOnceCallback callback) { + DCHECK(transport_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + + // If already connected, then just return OK. + if (completed_handshake_) + return OK; + + next_state_ = STATE_HEADER_READ; + buffer_.clear(); + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + user_callback_ = std::move(callback); + } + return rv; +} + +void HttpProxySocket::Disconnect() { + completed_handshake_ = false; + transport_->Disconnect(); + + // Reset other states to make sure they aren't mistakenly used later. + // These are the states initialized by Connect(). + next_state_ = STATE_NONE; + user_callback_.Reset(); +} + +bool HttpProxySocket::IsConnected() const { + return completed_handshake_ && transport_->IsConnected(); +} + +bool HttpProxySocket::IsConnectedAndIdle() const { + return completed_handshake_ && transport_->IsConnectedAndIdle(); +} + +const NetLogWithSource& HttpProxySocket::NetLog() const { + return net_log_; +} + +bool HttpProxySocket::WasEverUsed() const { + return was_ever_used_; +} + +bool HttpProxySocket::WasAlpnNegotiated() const { + if (transport_) { + return transport_->WasAlpnNegotiated(); + } + NOTREACHED(); + return false; +} + +NextProto HttpProxySocket::GetNegotiatedProtocol() const { + if (transport_) { + return transport_->GetNegotiatedProtocol(); + } + NOTREACHED(); + return kProtoUnknown; +} + +bool HttpProxySocket::GetSSLInfo(SSLInfo* ssl_info) { + if (transport_) { + return transport_->GetSSLInfo(ssl_info); + } + NOTREACHED(); + return false; +} + +void HttpProxySocket::GetConnectionAttempts(ConnectionAttempts* out) const { + out->clear(); +} + +int64_t HttpProxySocket::GetTotalReceivedBytes() const { + return transport_->GetTotalReceivedBytes(); +} + +void HttpProxySocket::ApplySocketTag(const SocketTag& tag) { + return transport_->ApplySocketTag(tag); +} + +// Read is called by the transport layer above to read. This can only be done +// if the HTTP header is complete. +int HttpProxySocket::Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + DCHECK(callback); + + int rv = transport_->Read( + buf, buf_len, + base::BindOnce(&HttpProxySocket::OnReadWriteComplete, + base::Unretained(this), std::move(callback))); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +// Write is called by the transport layer. This can only be done if the +// SOCKS handshake is complete. +int HttpProxySocket::Write( + IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + DCHECK(callback); + + int rv = transport_->Write( + buf, buf_len, + base::BindOnce(&HttpProxySocket::OnReadWriteComplete, + base::Unretained(this), std::move(callback)), + traffic_annotation); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +int HttpProxySocket::SetReceiveBufferSize(int32_t size) { + return transport_->SetReceiveBufferSize(size); +} + +int HttpProxySocket::SetSendBufferSize(int32_t size) { + return transport_->SetSendBufferSize(size); +} + +void HttpProxySocket::DoCallback(int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(user_callback_); + + // Since Run() may result in Read being called, + // clear user_callback_ up front. + std::move(user_callback_).Run(result); +} + +void HttpProxySocket::OnIOComplete(int result) { + DCHECK_NE(STATE_NONE, next_state_); + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) { + DoCallback(rv); + } +} + +void HttpProxySocket::OnReadWriteComplete(CompletionOnceCallback callback, + int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(callback); + + if (result > 0) + was_ever_used_ = true; + std::move(callback).Run(result); +} + +int HttpProxySocket::DoLoop(int last_io_result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_HEADER_READ: + DCHECK_EQ(OK, rv); + rv = DoHeaderRead(); + break; + case STATE_HEADER_READ_COMPLETE: + rv = DoHeaderReadComplete(rv); + break; + case STATE_HEADER_WRITE: + DCHECK_EQ(OK, rv); + rv = DoHeaderWrite(); + break; + case STATE_HEADER_WRITE_COMPLETE: + rv = DoHeaderWriteComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +int HttpProxySocket::DoHeaderRead() { + next_state_ = STATE_HEADER_READ_COMPLETE; + + handshake_buf_ = new IOBuffer(kBufferSize); + return transport_->Read(handshake_buf_.get(), kBufferSize, io_callback_); +} + +int HttpProxySocket::DoHeaderReadComplete(int result) { + if (result < 0) + return result; + + if (result == 0) { + return ERR_CONNECTION_CLOSED; + } + + buffer_.append(handshake_buf_->data(), result); + if (buffer_.size() > kMaxHeaderSize) { + return ERR_MSG_TOO_BIG; + } + + auto header_end = buffer_.find("\r\n\r\n"); + if (header_end == std::string::npos) { + next_state_ = STATE_HEADER_READ; + return OK; + } + if (header_end + 4 != buffer_.size()) { + return ERR_INVALID_ARGUMENT; + } + + // HttpProxyClientSocket uses CONNECT for all endpoints. + auto first_line_end = buffer_.find("\r\n"); + auto first_space = buffer_.find(' '); + if (first_space == std::string::npos || first_space + 1 >= first_line_end) { + return ERR_INVALID_ARGUMENT; + } + if (buffer_.compare(0, first_space, "CONNECT") != 0) { + return ERR_INVALID_ARGUMENT; + } + auto second_space = buffer_.find(' ', first_space + 1); + if (second_space == std::string::npos || second_space >= first_line_end) { + return ERR_INVALID_ARGUMENT; + } + request_endpoint_ = HostPortPair::FromString( + buffer_.substr(first_space + 1, second_space - (first_space + 1))); + + next_state_ = STATE_HEADER_WRITE; + return OK; +} + +int HttpProxySocket::DoHeaderWrite() { + next_state_ = STATE_HEADER_WRITE_COMPLETE; + + // Adds padding. + int padding_size = base::RandInt(kMinPaddingSize, kMaxPaddingSize); + header_write_size_ = kResponseHeaderSize + padding_size + 4; + handshake_buf_ = new IOBuffer(header_write_size_); + char* p = handshake_buf_->data(); + std::memcpy(p, kResponseHeader, kResponseHeaderSize); + std::memset(p + kResponseHeaderSize, '.', padding_size); + std::memcpy(p + kResponseHeaderSize + padding_size, "\r\n\r\n", 4); + + return transport_->Write(handshake_buf_.get(), header_write_size_, + io_callback_, traffic_annotation_); +} + +int HttpProxySocket::DoHeaderWriteComplete(int result) { + if (result < 0) + return result; + + if (result != header_write_size_) { + return ERR_FAILED; + } + + completed_handshake_ = true; + next_state_ = STATE_NONE; + return OK; +} + +int HttpProxySocket::GetPeerAddress(IPEndPoint* address) const { + return transport_->GetPeerAddress(address); +} + +int HttpProxySocket::GetLocalAddress(IPEndPoint* address) const { + return transport_->GetLocalAddress(address); +} + +} // namespace net diff --git a/net/tools/naive/http_proxy_socket.h b/net/tools/naive/http_proxy_socket.h new file mode 100644 index 0000000000..85888eee8a --- /dev/null +++ b/net/tools/naive/http_proxy_socket.h @@ -0,0 +1,123 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_HTTP_PROXY_SOCKET_H_ +#define NET_TOOLS_NAIVE_HTTP_PROXY_SOCKET_H_ + +#include +#include +#include +#include + +#include "base/macros.h" +#include "base/memory/ref_counted.h" +#include "net/base/completion_callback.h" +#include "net/base/completion_repeating_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/log/net_log_with_source.h" +#include "net/socket/connection_attempts.h" +#include "net/socket/next_proto.h" +#include "net/socket/stream_socket.h" +#include "net/ssl/ssl_info.h" +#include "net/traffic_annotation/network_traffic_annotation.h" + +namespace net { + +// This StreamSocket is used to setup a HTTP CONNECT tunnel. +class HttpProxySocket : public StreamSocket { + public: + HttpProxySocket(std::unique_ptr transport_socket, + const NetworkTrafficAnnotationTag& traffic_annotation); + + // On destruction Disconnect() is called. + ~HttpProxySocket() override; + + const HostPortPair& request_endpoint() const; + + // StreamSocket implementation. + + int Connect(CompletionOnceCallback callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + const NetLogWithSource& NetLog() const override; + bool WasEverUsed() const override; + bool WasAlpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} + int64_t GetTotalReceivedBytes() const override; + void ApplySocketTag(const SocketTag& tag) override; + + // Socket implementation. + int Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) override; + int Write(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) override; + + int SetReceiveBufferSize(int32_t size) override; + int SetSendBufferSize(int32_t size) override; + + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + + private: + enum State { + STATE_HEADER_READ, + STATE_HEADER_READ_COMPLETE, + STATE_HEADER_WRITE, + STATE_HEADER_WRITE_COMPLETE, + STATE_NONE, + }; + + void DoCallback(int result); + void OnIOComplete(int result); + void OnReadWriteComplete(CompletionOnceCallback callback, int result); + + int DoLoop(int last_io_result); + int DoHeaderWrite(); + int DoHeaderWriteComplete(int result); + int DoHeaderRead(); + int DoHeaderReadComplete(int result); + + CompletionRepeatingCallback io_callback_; + + // Stores the underlying socket. + std::unique_ptr transport_; + + State next_state_; + + // Stores the callback to the layer above, called on completing Connect(). + CompletionOnceCallback user_callback_; + + // This IOBuffer is used by the class to read and write + // SOCKS handshake data. The length contains the expected size to + // read or write. + scoped_refptr handshake_buf_; + + std::string buffer_; + bool completed_handshake_; + bool was_ever_used_; + int header_write_size_; + + HostPortPair request_endpoint_; + + NetLogWithSource net_log_; + + // Traffic annotation for socket control. + NetworkTrafficAnnotationTag traffic_annotation_; + + DISALLOW_COPY_AND_ASSIGN(HttpProxySocket); +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_HTTP_PROXY_SOCKET_H_ diff --git a/net/tools/naive/naive_connection.cc b/net/tools/naive/naive_connection.cc index 935080dd2c..81977723a3 100644 --- a/net/tools/naive/naive_connection.cc +++ b/net/tools/naive/naive_connection.cc @@ -5,11 +5,13 @@ #include "net/tools/naive/naive_connection.h" +#include #include #include "base/bind.h" #include "base/callback_helpers.h" #include "base/logging.h" +#include "base/rand_util.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/base/privacy_mode.h" @@ -26,15 +28,20 @@ namespace net { namespace { -static const int kBufferSize = 64 * 1024; +constexpr int kBufferSize = 64 * 1024; +constexpr int kFirstPaddings = 4; +constexpr int kPaddingHeaderSize = 3; +constexpr int kMaxPaddingSize = 255; } // namespace NaiveConnection::NaiveConnection( unsigned int id, + Direction pad_direction, std::unique_ptr accepted_socket, Delegate* delegate, const NetworkTrafficAnnotationTag& traffic_annotation) : id_(id), + pad_direction_(pad_direction), next_state_(STATE_NONE), delegate_(delegate), client_socket_(std::move(accepted_socket)), @@ -45,6 +52,8 @@ NaiveConnection::NaiveConnection( early_pull_pending_(false), can_push_to_server_(false), early_pull_result_(ERR_IO_PENDING), + num_paddings_{0, 0}, + read_padding_state_(STATE_READ_PAYLOAD_LENGTH_1), full_duplex_(false), time_func_(&base::TimeTicks::Now), traffic_annotation_(traffic_annotation), @@ -211,10 +220,20 @@ void NaiveConnection::Pull(Direction from, Direction to) { if (errors_[kClient] < 0 || errors_[kServer] < 0) return; - read_buffers_[from] = new IOBuffer(kBufferSize); + int read_size = kBufferSize; + if (from == pad_direction_ && num_paddings_[from] < kFirstPaddings) { + auto* buffer = new GrowableIOBuffer; + buffer->SetCapacity(kBufferSize); + buffer->set_offset(kPaddingHeaderSize); + read_buffers_[from] = buffer; + read_size = kBufferSize - kPaddingHeaderSize - kMaxPaddingSize; + } else { + read_buffers_[from] = new IOBuffer(kBufferSize); + } + DCHECK(sockets_[from]); int rv = sockets_[from]->Read( - read_buffers_[from].get(), kBufferSize, + read_buffers_[from].get(), read_size, base::BindRepeating(&NaiveConnection::OnPullComplete, weak_ptr_factory_.GetWeakPtr(), from, to)); @@ -226,11 +245,107 @@ void NaiveConnection::Pull(Direction from, Direction to) { } void NaiveConnection::Push(Direction from, Direction to, int size) { - write_buffers_[to] = new DrainableIOBuffer(read_buffers_[from].get(), size); + int write_size = size; + int write_offset = 0; + if (from == pad_direction_ && num_paddings_[from] < kFirstPaddings) { + // Adds padding. + ++num_paddings_[from]; + int padding_size = base::RandInt(0, kMaxPaddingSize); + auto* buffer = static_cast(read_buffers_[from].get()); + buffer->set_offset(0); + uint8_t* p = reinterpret_cast(buffer->data()); + p[0] = size / 256; + p[1] = size % 256; + p[2] = padding_size; + std::memset(p + kPaddingHeaderSize + size, 0, padding_size); + write_size = kPaddingHeaderSize + size + padding_size; + } else if (to == pad_direction_ && num_paddings_[from] < kFirstPaddings) { + // Removes padding. + const char* p = read_buffers_[from]->data(); + bool trivial_padding = false; + if (read_padding_state_ == STATE_READ_PAYLOAD_LENGTH_1 && + size >= kPaddingHeaderSize) { + int payload_size = + static_cast(p[0]) * 256 + static_cast(p[1]); + int padding_size = static_cast(p[2]); + if (size == kPaddingHeaderSize + payload_size + padding_size) { + write_size = payload_size; + write_offset = kPaddingHeaderSize; + ++num_paddings_[from]; + trivial_padding = true; + } + } + if (!trivial_padding) { + auto* unpadded_buffer = new IOBuffer(kBufferSize); + char* unpadded_ptr = unpadded_buffer->data(); + for (int i = 0; i < size;) { + if (num_paddings_[from] >= kFirstPaddings && + read_padding_state_ == STATE_READ_PAYLOAD_LENGTH_1) { + std::memcpy(unpadded_ptr, p + i, size - i); + unpadded_ptr += size - i; + break; + } + int copy_size; + switch (read_padding_state_) { + case STATE_READ_PAYLOAD_LENGTH_1: + payload_length_ = static_cast(p[i]); + ++i; + read_padding_state_ = STATE_READ_PAYLOAD_LENGTH_2; + break; + case STATE_READ_PAYLOAD_LENGTH_2: + payload_length_ = + payload_length_ * 256 + static_cast(p[i]); + ++i; + read_padding_state_ = STATE_READ_PADDING_LENGTH; + break; + case STATE_READ_PADDING_LENGTH: + padding_length_ = static_cast(p[i]); + ++i; + read_padding_state_ = STATE_READ_PAYLOAD; + break; + case STATE_READ_PAYLOAD: + if (payload_length_ <= size - i) { + copy_size = payload_length_; + read_padding_state_ = STATE_READ_PADDING; + } else { + copy_size = size - i; + } + std::memcpy(unpadded_ptr, p + i, copy_size); + unpadded_ptr += copy_size; + i += copy_size; + payload_length_ -= copy_size; + break; + case STATE_READ_PADDING: + if (padding_length_ <= size - i) { + copy_size = padding_length_; + read_padding_state_ = STATE_READ_PAYLOAD_LENGTH_1; + ++num_paddings_[from]; + } else { + copy_size = size - i; + } + i += copy_size; + padding_length_ -= copy_size; + break; + } + } + write_size = unpadded_ptr - unpadded_buffer->data(); + read_buffers_[from] = unpadded_buffer; + } + if (write_size == 0) { + OnPushComplete(from, to, OK); + return; + } + } + + write_buffers_[to] = new DrainableIOBuffer(read_buffers_[from].get(), + write_offset + write_size); + if (write_offset) { + write_buffers_[to]->DidConsume(write_offset); + } write_pending_[to] = true; DCHECK(sockets_[to]); int rv = sockets_[to]->Write( - write_buffers_[to].get(), size, + write_buffers_[to].get(), write_size, base::BindRepeating(&NaiveConnection::OnPushComplete, weak_ptr_factory_.GetWeakPtr(), from, to), traffic_annotation_); @@ -309,7 +424,7 @@ void NaiveConnection::OnPullComplete(Direction from, Direction to, int result) { } void NaiveConnection::OnPushComplete(Direction from, Direction to, int result) { - if (result >= 0) { + if (result >= 0 && write_buffers_[to] != nullptr) { bytes_passed_without_yielding_[from] += result; write_buffers_[to]->DidConsume(result); int size = write_buffers_[to]->BytesRemaining(); diff --git a/net/tools/naive/naive_connection.h b/net/tools/naive/naive_connection.h index 30bc5dcf00..82239916a6 100644 --- a/net/tools/naive/naive_connection.h +++ b/net/tools/naive/naive_connection.h @@ -7,6 +7,7 @@ #define NET_TOOLS_NAIVE_NAIVE_CONNECTION_H_ #include +#include #include "base/macros.h" #include "base/memory/ref_counted.h" @@ -27,6 +28,14 @@ class NaiveConnection { public: using TimeFunc = base::TimeTicks (*)(); + // From this direction. + enum Direction { + kClient = 0, + kServer = 1, + kNumDirections = 2, + kNone = 2, + }; + class Delegate { public: Delegate() {} @@ -42,6 +51,7 @@ class NaiveConnection { }; NaiveConnection(unsigned int id, + Direction pad_direction, std::unique_ptr accepted_socket, Delegate* delegate, const NetworkTrafficAnnotationTag& traffic_annotation); @@ -61,11 +71,12 @@ class NaiveConnection { STATE_NONE, }; - // From this direction. - enum Direction { - kClient = 0, - kServer = 1, - kNumDirections = 2, + enum PaddingState { + STATE_READ_PAYLOAD_LENGTH_1, + STATE_READ_PAYLOAD_LENGTH_2, + STATE_READ_PADDING_LENGTH, + STATE_READ_PAYLOAD, + STATE_READ_PADDING, }; void DoCallback(int result); @@ -86,6 +97,7 @@ class NaiveConnection { void OnPushComplete(Direction from, Direction to, int result); unsigned int id_; + Direction pad_direction_; CompletionRepeatingCallback io_callback_; CompletionOnceCallback connect_callback_; @@ -110,6 +122,11 @@ class NaiveConnection { bool can_push_to_server_; int early_pull_result_; + int num_paddings_[kNumDirections]; + PaddingState read_padding_state_; + int payload_length_; + int padding_length_; + bool full_duplex_; TimeFunc time_func_; diff --git a/net/tools/naive/naive_proxy.cc b/net/tools/naive/naive_proxy.cc index 311017e167..f202391637 100644 --- a/net/tools/naive/naive_proxy.cc +++ b/net/tools/naive/naive_proxy.cc @@ -20,6 +20,8 @@ #include "net/socket/client_socket_pool_manager.h" #include "net/socket/server_socket.h" #include "net/socket/stream_socket.h" +#include "net/third_party/quic/core/quic_versions.h" +#include "net/tools/naive/http_proxy_socket.h" #include "net/tools/naive/socks5_server_socket.h" namespace net { @@ -76,13 +78,20 @@ void NaiveProxy::HandleAcceptResult(int result) { void NaiveProxy::DoConnect() { std::unique_ptr socket; + NaiveConnection::Direction pad_direction; if (protocol_ == kSocks5) { - socket = std::make_unique(std::move(accepted_socket_)); + socket = std::make_unique(std::move(accepted_socket_), + traffic_annotation_); + pad_direction = NaiveConnection::kClient; + } else if (protocol_ == kHttp) { + socket = std::make_unique(std::move(accepted_socket_), + traffic_annotation_); + pad_direction = NaiveConnection::kServer; } else { return; } auto connection_ptr = std::make_unique( - ++last_id_, std::move(socket), this, traffic_annotation_); + ++last_id_, pad_direction, std::move(socket), this, traffic_annotation_); auto* connection = connection_ptr.get(); connection_by_id_[connection->id()] = std::move(connection_ptr); int result = connection->Connect( @@ -119,25 +128,32 @@ int NaiveProxy::OnConnectServer(unsigned int connection_id, HttpRequestInfo req_info; session_->GetSSLConfig(req_info, &server_ssl_config, &proxy_ssl_config); proxy_ssl_config.disable_cert_verification_network_fetches = true; + } else { + proxy_info.UseDirect(); } HostPortPair request_endpoint; if (protocol_ == kSocks5) { const auto* socket = static_cast(client_socket); request_endpoint = socket->request_endpoint(); + } else if (protocol_ == kHttp) { + const auto* socket = static_cast(client_socket); + request_endpoint = socket->request_endpoint(); } - if (request_endpoint.port() == 0) { - LOG(ERROR) << "Connection " << connection_id << " has invalid upstream"; + if (request_endpoint.IsEmpty()) { + LOG(ERROR) << "Connection " << connection_id << " to invalid origin"; return ERR_ADDRESS_INVALID; } LOG(INFO) << "Connection " << connection_id << " to " << request_endpoint.ToString(); - return InitSocketHandleForRawConnect( + auto quic_version = quic::QUIC_VERSION_UNSUPPORTED; + + return InitSocketHandleForRawConnect2( request_endpoint, session_, request_load_flags, request_priority, - proxy_info, server_ssl_config, proxy_ssl_config, PRIVACY_MODE_DISABLED, - net_log_, server_socket, callback); + proxy_info, quic_version, server_ssl_config, proxy_ssl_config, + PRIVACY_MODE_DISABLED, net_log_, server_socket, callback); } void NaiveProxy::OnConnectComplete(int connection_id, int result) { diff --git a/net/tools/naive/naive_proxy_bin.cc b/net/tools/naive/naive_proxy_bin.cc index 30fd89c6bc..e5754cc881 100644 --- a/net/tools/naive/naive_proxy_bin.cc +++ b/net/tools/naive/naive_proxy_bin.cc @@ -108,17 +108,19 @@ std::unique_ptr BuildURLRequestContext( net::NetLog* net_log) { net::URLRequestContextBuilder builder; + builder.DisableHttpCache(); + builder.set_net_log(net_log); + net::ProxyConfig proxy_config; - proxy_config.proxy_rules().ParseFromString(params.proxy_url); + if (params.use_proxy) { + proxy_config.proxy_rules().ParseFromString(params.proxy_url); + } auto proxy_service = net::ProxyResolutionService::CreateWithoutProxyResolver( std::make_unique( net::ProxyConfigWithAnnotation(proxy_config, kTrafficAnnotation)), net_log); proxy_service->ForceReloadProxyConfig(); - builder.set_proxy_resolution_service(std::move(proxy_service)); - builder.DisableHttpCache(); - builder.set_net_log(net_log); if (!params.host_resolver_rules.empty()) { auto remapped_resolver = std::make_unique( @@ -129,15 +131,17 @@ std::unique_ptr BuildURLRequestContext( auto context = builder.Build(); - net::HttpNetworkSession* session = - context->http_transaction_factory()->GetSession(); - net::HttpAuthCache* auth_cache = session->http_auth_cache(); - GURL auth_origin(params.proxy_url); - net::AuthCredentials credentials(base::ASCIIToUTF16(params.proxy_user), - base::ASCIIToUTF16(params.proxy_pass)); - auth_cache->Add(auth_origin, /*realm=*/std::string(), - net::HttpAuth::AUTH_SCHEME_BASIC, /*challenge=*/"Basic", - credentials, /*path=*/"/"); + if (params.use_proxy) { + net::HttpNetworkSession* session = + context->http_transaction_factory()->GetSession(); + net::HttpAuthCache* auth_cache = session->http_auth_cache(); + GURL auth_origin(params.proxy_url); + net::AuthCredentials credentials(base::ASCIIToUTF16(params.proxy_user), + base::ASCIIToUTF16(params.proxy_pass)); + auth_cache->Add(auth_origin, /*realm=*/std::string(), + net::HttpAuth::AUTH_SCHEME_BASIC, /*challenge=*/"Basic", + credentials, /*path=*/"/"); + } return context; } @@ -205,6 +209,7 @@ bool ParseCommandLineFlags(Params* params) { } } + params->use_proxy = false; GURL url(line.GetSwitchValueASCII("proxy")); if (line.HasSwitch("proxy")) { params->use_proxy = true; diff --git a/net/tools/naive/socks5_server_socket.cc b/net/tools/naive/socks5_server_socket.cc index d11eefec32..65643958a1 100644 --- a/net/tools/naive/socks5_server_socket.cc +++ b/net/tools/naive/socks5_server_socket.cc @@ -32,13 +32,9 @@ const char Socks5ServerSocket::kReplyCommandNotSupported = '\x07'; static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4"); static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6"); -namespace { -constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation = - net::DefineNetworkTrafficAnnotation("naive", ""); -} // namespace - Socks5ServerSocket::Socks5ServerSocket( - std::unique_ptr transport_socket) + std::unique_ptr transport_socket, + const NetworkTrafficAnnotationTag& traffic_annotation) : io_callback_(base::BindRepeating(&Socks5ServerSocket::OnIOComplete, base::Unretained(this))), transport_(std::move(transport_socket)), @@ -50,7 +46,7 @@ Socks5ServerSocket::Socks5ServerSocket( read_header_size_(kReadHeaderSize), was_ever_used_(false), net_log_(transport_->NetLog()), - traffic_annotation_(kTrafficAnnotation) {} + traffic_annotation_(traffic_annotation) {} Socks5ServerSocket::~Socks5ServerSocket() { Disconnect(); diff --git a/net/tools/naive/socks5_server_socket.h b/net/tools/naive/socks5_server_socket.h index ce996cd313..18848bd628 100644 --- a/net/tools/naive/socks5_server_socket.h +++ b/net/tools/naive/socks5_server_socket.h @@ -31,7 +31,8 @@ namespace net { // Currently no SOCKSv5 authentication is supported. class Socks5ServerSocket : public StreamSocket { public: - explicit Socks5ServerSocket(std::unique_ptr transport_socket); + Socks5ServerSocket(std::unique_ptr transport_socket, + const NetworkTrafficAnnotationTag& traffic_annotation); // On destruction Disconnect() is called. ~Socks5ServerSocket() override;