From 443962d602f4f4aa10f471fcfb48322dbdf979e3 Mon Sep 17 00:00:00 2001 From: klzgrad Date: Sat, 8 Dec 2018 00:51:40 -0500 Subject: [PATCH] Add server implementation and tunnel padding --- src/net/tools/naive/http_proxy_socket.cc | 339 +++++++++++++++++++++++ src/net/tools/naive/http_proxy_socket.h | 123 ++++++++ src/net/tools/naive/naive_connection.cc | 126 ++++++++- src/net/tools/naive/naive_connection.h | 27 +- src/net/tools/naive/naive_proxy.cc | 21 +- 5 files changed, 621 insertions(+), 15 deletions(-) create mode 100644 src/net/tools/naive/http_proxy_socket.cc create mode 100644 src/net/tools/naive/http_proxy_socket.h diff --git a/src/net/tools/naive/http_proxy_socket.cc b/src/net/tools/naive/http_proxy_socket.cc new file mode 100644 index 0000000000..3fc4377276 --- /dev/null +++ b/src/net/tools/naive/http_proxy_socket.cc @@ -0,0 +1,339 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/http_proxy_socket.h" + +#include +#include + +#include "base/bind.h" +#include "base/callback_helpers.h" +#include "base/logging.h" +#include "base/rand_util.h" +#include "base/sys_byteorder.h" +#include "net/base/ip_address.h" +#include "net/base/net_errors.h" +#include "net/log/net_log.h" + +namespace net { + +namespace { +constexpr int kBufferSize = 64 * 1024; +constexpr size_t kMaxHeaderSize = 64 * 1024; +constexpr char kResponseHeader[] = "HTTP/1.1 200 OK\r\nPadding: "; +constexpr int kResponseHeaderSize = sizeof(kResponseHeader) - 1; +// A plain 200 is 10 bytes. Expected 48 bytes. "Padding" uses up 7 bytes. +constexpr int kMinPaddingSize = 30; +constexpr int kMaxPaddingSize = kMinPaddingSize + 32; +} // namespace + +HttpProxySocket::HttpProxySocket( + std::unique_ptr transport_socket, + const NetworkTrafficAnnotationTag& traffic_annotation) + : io_callback_(base::BindRepeating(&HttpProxySocket::OnIOComplete, + base::Unretained(this))), + transport_(std::move(transport_socket)), + next_state_(STATE_NONE), + completed_handshake_(false), + was_ever_used_(false), + header_write_size_(-1), + net_log_(transport_->NetLog()), + traffic_annotation_(traffic_annotation) {} + +HttpProxySocket::~HttpProxySocket() { + Disconnect(); +} + +const HostPortPair& HttpProxySocket::request_endpoint() const { + return request_endpoint_; +} + +int HttpProxySocket::Connect(CompletionOnceCallback callback) { + DCHECK(transport_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + + // If already connected, then just return OK. + if (completed_handshake_) + return OK; + + next_state_ = STATE_HEADER_READ; + buffer_.clear(); + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + user_callback_ = std::move(callback); + } + return rv; +} + +void HttpProxySocket::Disconnect() { + completed_handshake_ = false; + transport_->Disconnect(); + + // Reset other states to make sure they aren't mistakenly used later. + // These are the states initialized by Connect(). + next_state_ = STATE_NONE; + user_callback_.Reset(); +} + +bool HttpProxySocket::IsConnected() const { + return completed_handshake_ && transport_->IsConnected(); +} + +bool HttpProxySocket::IsConnectedAndIdle() const { + return completed_handshake_ && transport_->IsConnectedAndIdle(); +} + +const NetLogWithSource& HttpProxySocket::NetLog() const { + return net_log_; +} + +bool HttpProxySocket::WasEverUsed() const { + return was_ever_used_; +} + +bool HttpProxySocket::WasAlpnNegotiated() const { + if (transport_) { + return transport_->WasAlpnNegotiated(); + } + NOTREACHED(); + return false; +} + +NextProto HttpProxySocket::GetNegotiatedProtocol() const { + if (transport_) { + return transport_->GetNegotiatedProtocol(); + } + NOTREACHED(); + return kProtoUnknown; +} + +bool HttpProxySocket::GetSSLInfo(SSLInfo* ssl_info) { + if (transport_) { + return transport_->GetSSLInfo(ssl_info); + } + NOTREACHED(); + return false; +} + +void HttpProxySocket::GetConnectionAttempts(ConnectionAttempts* out) const { + out->clear(); +} + +int64_t HttpProxySocket::GetTotalReceivedBytes() const { + return transport_->GetTotalReceivedBytes(); +} + +void HttpProxySocket::ApplySocketTag(const SocketTag& tag) { + return transport_->ApplySocketTag(tag); +} + +// Read is called by the transport layer above to read. This can only be done +// if the HTTP header is complete. +int HttpProxySocket::Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + DCHECK(callback); + + if (!buffer_.empty()) { + was_ever_used_ = true; + int data_len = buffer_.size(); + if (data_len <= buf_len) { + std::memcpy(buf->data(), buffer_.data(), data_len); + buffer_.clear(); + return data_len; + } else { + std::memcpy(buf->data(), buffer_.data(), buf_len); + buffer_ = buffer_.substr(buf_len); + return buf_len; + } + } + + int rv = transport_->Read( + buf, buf_len, + base::BindOnce(&HttpProxySocket::OnReadWriteComplete, + base::Unretained(this), std::move(callback))); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +// Write is called by the transport layer. This can only be done if the +// SOCKS handshake is complete. +int HttpProxySocket::Write( + IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + DCHECK(callback); + + int rv = transport_->Write( + buf, buf_len, + base::BindOnce(&HttpProxySocket::OnReadWriteComplete, + base::Unretained(this), std::move(callback)), + traffic_annotation); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +int HttpProxySocket::SetReceiveBufferSize(int32_t size) { + return transport_->SetReceiveBufferSize(size); +} + +int HttpProxySocket::SetSendBufferSize(int32_t size) { + return transport_->SetSendBufferSize(size); +} + +void HttpProxySocket::DoCallback(int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(user_callback_); + + // Since Run() may result in Read being called, + // clear user_callback_ up front. + std::move(user_callback_).Run(result); +} + +void HttpProxySocket::OnIOComplete(int result) { + DCHECK_NE(STATE_NONE, next_state_); + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) { + DoCallback(rv); + } +} + +void HttpProxySocket::OnReadWriteComplete(CompletionOnceCallback callback, + int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(callback); + + if (result > 0) + was_ever_used_ = true; + std::move(callback).Run(result); +} + +int HttpProxySocket::DoLoop(int last_io_result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_HEADER_READ: + DCHECK_EQ(OK, rv); + rv = DoHeaderRead(); + break; + case STATE_HEADER_READ_COMPLETE: + rv = DoHeaderReadComplete(rv); + break; + case STATE_HEADER_WRITE: + DCHECK_EQ(OK, rv); + rv = DoHeaderWrite(); + break; + case STATE_HEADER_WRITE_COMPLETE: + rv = DoHeaderWriteComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +int HttpProxySocket::DoHeaderRead() { + next_state_ = STATE_HEADER_READ_COMPLETE; + + handshake_buf_ = base::MakeRefCounted(kBufferSize); + return transport_->Read(handshake_buf_.get(), kBufferSize, io_callback_); +} + +int HttpProxySocket::DoHeaderReadComplete(int result) { + if (result < 0) + return result; + + if (result == 0) { + return ERR_CONNECTION_CLOSED; + } + + buffer_.append(handshake_buf_->data(), result); + if (buffer_.size() > kMaxHeaderSize) { + return ERR_MSG_TOO_BIG; + } + + auto header_end = buffer_.find("\r\n\r\n"); + if (header_end == std::string::npos) { + next_state_ = STATE_HEADER_READ; + return OK; + } + + // HttpProxyClientSocket uses CONNECT for all endpoints. + auto first_line_end = buffer_.find("\r\n"); + auto first_space = buffer_.find(' '); + if (first_space == std::string::npos || first_space + 1 >= first_line_end) { + return ERR_INVALID_ARGUMENT; + } + if (buffer_.compare(0, first_space, "CONNECT") != 0) { + return ERR_INVALID_ARGUMENT; + } + auto second_space = buffer_.find(' ', first_space + 1); + if (second_space == std::string::npos || second_space >= first_line_end) { + return ERR_INVALID_ARGUMENT; + } + request_endpoint_ = HostPortPair::FromString( + buffer_.substr(first_space + 1, second_space - (first_space + 1))); + + buffer_ = buffer_.substr(header_end + 4); + + next_state_ = STATE_HEADER_WRITE; + return OK; +} + +int HttpProxySocket::DoHeaderWrite() { + next_state_ = STATE_HEADER_WRITE_COMPLETE; + + // Adds padding. + int padding_size = base::RandInt(kMinPaddingSize, kMaxPaddingSize); + header_write_size_ = kResponseHeaderSize + padding_size + 4; + handshake_buf_ = base::MakeRefCounted(header_write_size_); + char* p = handshake_buf_->data(); + std::memcpy(p, kResponseHeader, kResponseHeaderSize); + std::memset(p + kResponseHeaderSize, '.', padding_size); + std::memcpy(p + kResponseHeaderSize + padding_size, "\r\n\r\n", 4); + + return transport_->Write(handshake_buf_.get(), header_write_size_, + io_callback_, traffic_annotation_); +} + +int HttpProxySocket::DoHeaderWriteComplete(int result) { + if (result < 0) + return result; + + if (result != header_write_size_) { + return ERR_FAILED; + } + + completed_handshake_ = true; + next_state_ = STATE_NONE; + return OK; +} + +int HttpProxySocket::GetPeerAddress(IPEndPoint* address) const { + return transport_->GetPeerAddress(address); +} + +int HttpProxySocket::GetLocalAddress(IPEndPoint* address) const { + return transport_->GetLocalAddress(address); +} + +} // namespace net diff --git a/src/net/tools/naive/http_proxy_socket.h b/src/net/tools/naive/http_proxy_socket.h new file mode 100644 index 0000000000..62d846018e --- /dev/null +++ b/src/net/tools/naive/http_proxy_socket.h @@ -0,0 +1,123 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_HTTP_PROXY_SOCKET_H_ +#define NET_TOOLS_NAIVE_HTTP_PROXY_SOCKET_H_ + +#include +#include +#include +#include + +#include "base/macros.h" +#include "base/memory/scoped_refptr.h" +#include "net/base/completion_once_callback.h" +#include "net/base/completion_repeating_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/log/net_log_with_source.h" +#include "net/socket/connection_attempts.h" +#include "net/socket/next_proto.h" +#include "net/socket/stream_socket.h" +#include "net/ssl/ssl_info.h" + +namespace net { +struct NetworkTrafficAnnotationTag; + +// This StreamSocket is used to setup a HTTP CONNECT tunnel. +class HttpProxySocket : public StreamSocket { + public: + HttpProxySocket(std::unique_ptr transport_socket, + const NetworkTrafficAnnotationTag& traffic_annotation); + + // On destruction Disconnect() is called. + ~HttpProxySocket() override; + + const HostPortPair& request_endpoint() const; + + // StreamSocket implementation. + + int Connect(CompletionOnceCallback callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + const NetLogWithSource& NetLog() const override; + bool WasEverUsed() const override; + bool WasAlpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} + int64_t GetTotalReceivedBytes() const override; + void ApplySocketTag(const SocketTag& tag) override; + + // Socket implementation. + int Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) override; + int Write(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) override; + + int SetReceiveBufferSize(int32_t size) override; + int SetSendBufferSize(int32_t size) override; + + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + + private: + enum State { + STATE_HEADER_READ, + STATE_HEADER_READ_COMPLETE, + STATE_HEADER_WRITE, + STATE_HEADER_WRITE_COMPLETE, + STATE_NONE, + }; + + void DoCallback(int result); + void OnIOComplete(int result); + void OnReadWriteComplete(CompletionOnceCallback callback, int result); + + int DoLoop(int last_io_result); + int DoHeaderWrite(); + int DoHeaderWriteComplete(int result); + int DoHeaderRead(); + int DoHeaderReadComplete(int result); + + CompletionRepeatingCallback io_callback_; + + // Stores the underlying socket. + std::unique_ptr transport_; + + State next_state_; + + // Stores the callback to the layer above, called on completing Connect(). + CompletionOnceCallback user_callback_; + + // This IOBuffer is used by the class to read and write + // SOCKS handshake data. The length contains the expected size to + // read or write. + scoped_refptr handshake_buf_; + + std::string buffer_; + bool completed_handshake_; + bool was_ever_used_; + int header_write_size_; + + HostPortPair request_endpoint_; + + NetLogWithSource net_log_; + + // Traffic annotation for socket control. + const NetworkTrafficAnnotationTag& traffic_annotation_; + + DISALLOW_COPY_AND_ASSIGN(HttpProxySocket); +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_HTTP_PROXY_SOCKET_H_ diff --git a/src/net/tools/naive/naive_connection.cc b/src/net/tools/naive/naive_connection.cc index 82e557d457..435f102fff 100644 --- a/src/net/tools/naive/naive_connection.cc +++ b/src/net/tools/naive/naive_connection.cc @@ -5,11 +5,13 @@ #include "net/tools/naive/naive_connection.h" +#include #include #include "base/bind.h" #include "base/callback_helpers.h" #include "base/logging.h" +#include "base/rand_util.h" #include "base/threading/thread_task_runner_handle.h" #include "net/base/io_buffer.h" #include "net/base/load_flags.h" @@ -28,15 +30,20 @@ namespace net { namespace { -static const int kBufferSize = 64 * 1024; +constexpr int kBufferSize = 64 * 1024; +constexpr int kFirstPaddings = 4; +constexpr int kPaddingHeaderSize = 3; +constexpr int kMaxPaddingSize = 255; } // namespace NaiveConnection::NaiveConnection( unsigned int id, + Direction pad_direction, std::unique_ptr accepted_socket, Delegate* delegate, const NetworkTrafficAnnotationTag& traffic_annotation) : id_(id), + pad_direction_(pad_direction), next_state_(STATE_NONE), delegate_(delegate), client_socket_(std::move(accepted_socket)), @@ -47,6 +54,8 @@ NaiveConnection::NaiveConnection( early_pull_pending_(false), can_push_to_server_(false), early_pull_result_(ERR_IO_PENDING), + num_paddings_{0, 0}, + read_padding_state_(STATE_READ_PAYLOAD_LENGTH_1), full_duplex_(false), time_func_(&base::TimeTicks::Now), traffic_annotation_(traffic_annotation) { @@ -212,10 +221,20 @@ void NaiveConnection::Pull(Direction from, Direction to) { if (errors_[kClient] < 0 || errors_[kServer] < 0) return; - read_buffers_[from] = base::MakeRefCounted(kBufferSize); + int read_size = kBufferSize; + if (from == pad_direction_ && num_paddings_[from] < kFirstPaddings) { + auto buffer = base::MakeRefCounted(); + buffer->SetCapacity(kBufferSize); + buffer->set_offset(kPaddingHeaderSize); + read_buffers_[from] = buffer; + read_size = kBufferSize - kPaddingHeaderSize - kMaxPaddingSize; + } else { + read_buffers_[from] = base::MakeRefCounted(kBufferSize); + } + DCHECK(sockets_[from]); int rv = sockets_[from]->Read( - read_buffers_[from].get(), kBufferSize, + read_buffers_[from].get(), read_size, base::BindRepeating(&NaiveConnection::OnPullComplete, weak_ptr_factory_.GetWeakPtr(), from, to)); @@ -227,12 +246,107 @@ void NaiveConnection::Pull(Direction from, Direction to) { } void NaiveConnection::Push(Direction from, Direction to, int size) { + int write_size = size; + int write_offset = 0; + if (from == pad_direction_ && num_paddings_[from] < kFirstPaddings) { + // Adds padding. + ++num_paddings_[from]; + int padding_size = base::RandInt(0, kMaxPaddingSize); + auto* buffer = static_cast(read_buffers_[from].get()); + buffer->set_offset(0); + uint8_t* p = reinterpret_cast(buffer->data()); + p[0] = size / 256; + p[1] = size % 256; + p[2] = padding_size; + std::memset(p + kPaddingHeaderSize + size, 0, padding_size); + write_size = kPaddingHeaderSize + size + padding_size; + } else if (to == pad_direction_ && num_paddings_[from] < kFirstPaddings) { + // Removes padding. + const char* p = read_buffers_[from]->data(); + bool trivial_padding = false; + if (read_padding_state_ == STATE_READ_PAYLOAD_LENGTH_1 && + size >= kPaddingHeaderSize) { + int payload_size = + static_cast(p[0]) * 256 + static_cast(p[1]); + int padding_size = static_cast(p[2]); + if (size == kPaddingHeaderSize + payload_size + padding_size) { + write_size = payload_size; + write_offset = kPaddingHeaderSize; + ++num_paddings_[from]; + trivial_padding = true; + } + } + if (!trivial_padding) { + auto unpadded_buffer = base::MakeRefCounted(kBufferSize); + char* unpadded_ptr = unpadded_buffer->data(); + for (int i = 0; i < size;) { + if (num_paddings_[from] >= kFirstPaddings && + read_padding_state_ == STATE_READ_PAYLOAD_LENGTH_1) { + std::memcpy(unpadded_ptr, p + i, size - i); + unpadded_ptr += size - i; + break; + } + int copy_size; + switch (read_padding_state_) { + case STATE_READ_PAYLOAD_LENGTH_1: + payload_length_ = static_cast(p[i]); + ++i; + read_padding_state_ = STATE_READ_PAYLOAD_LENGTH_2; + break; + case STATE_READ_PAYLOAD_LENGTH_2: + payload_length_ = + payload_length_ * 256 + static_cast(p[i]); + ++i; + read_padding_state_ = STATE_READ_PADDING_LENGTH; + break; + case STATE_READ_PADDING_LENGTH: + padding_length_ = static_cast(p[i]); + ++i; + read_padding_state_ = STATE_READ_PAYLOAD; + break; + case STATE_READ_PAYLOAD: + if (payload_length_ <= size - i) { + copy_size = payload_length_; + read_padding_state_ = STATE_READ_PADDING; + } else { + copy_size = size - i; + } + std::memcpy(unpadded_ptr, p + i, copy_size); + unpadded_ptr += copy_size; + i += copy_size; + payload_length_ -= copy_size; + break; + case STATE_READ_PADDING: + if (padding_length_ <= size - i) { + copy_size = padding_length_; + read_padding_state_ = STATE_READ_PAYLOAD_LENGTH_1; + ++num_paddings_[from]; + } else { + copy_size = size - i; + } + i += copy_size; + padding_length_ -= copy_size; + break; + } + } + write_size = unpadded_ptr - unpadded_buffer->data(); + read_buffers_[from] = unpadded_buffer; + } + if (write_size == 0) { + OnPushComplete(from, to, OK); + return; + } + } + write_buffers_[to] = base::MakeRefCounted( - std::move(read_buffers_[from]), size); + std::move(read_buffers_[from]), write_offset + write_size); + if (write_offset) { + write_buffers_[to]->DidConsume(write_offset); + } write_pending_[to] = true; DCHECK(sockets_[to]); int rv = sockets_[to]->Write( - write_buffers_[to].get(), size, + write_buffers_[to].get(), write_size, base::BindRepeating(&NaiveConnection::OnPushComplete, weak_ptr_factory_.GetWeakPtr(), from, to), traffic_annotation_); @@ -311,7 +425,7 @@ void NaiveConnection::OnPullComplete(Direction from, Direction to, int result) { } void NaiveConnection::OnPushComplete(Direction from, Direction to, int result) { - if (result >= 0) { + if (result >= 0 && write_buffers_[to] != nullptr) { bytes_passed_without_yielding_[from] += result; write_buffers_[to]->DidConsume(result); int size = write_buffers_[to]->BytesRemaining(); diff --git a/src/net/tools/naive/naive_connection.h b/src/net/tools/naive/naive_connection.h index 3b081dd001..d94d205d6b 100644 --- a/src/net/tools/naive/naive_connection.h +++ b/src/net/tools/naive/naive_connection.h @@ -7,6 +7,7 @@ #define NET_TOOLS_NAIVE_NAIVE_CONNECTION_H_ #include +#include #include "base/macros.h" #include "base/memory/scoped_refptr.h" @@ -27,6 +28,14 @@ class NaiveConnection { public: using TimeFunc = base::TimeTicks (*)(); + // From this direction. + enum Direction { + kClient = 0, + kServer = 1, + kNumDirections = 2, + kNone = 2, + }; + class Delegate { public: Delegate() {} @@ -42,6 +51,7 @@ class NaiveConnection { }; NaiveConnection(unsigned int id, + Direction pad_direction, std::unique_ptr accepted_socket, Delegate* delegate, const NetworkTrafficAnnotationTag& traffic_annotation); @@ -61,11 +71,12 @@ class NaiveConnection { STATE_NONE, }; - // From this direction. - enum Direction { - kClient = 0, - kServer = 1, - kNumDirections = 2, + enum PaddingState { + STATE_READ_PAYLOAD_LENGTH_1, + STATE_READ_PAYLOAD_LENGTH_2, + STATE_READ_PADDING_LENGTH, + STATE_READ_PAYLOAD, + STATE_READ_PADDING, }; void DoCallback(int result); @@ -86,6 +97,7 @@ class NaiveConnection { void OnPushComplete(Direction from, Direction to, int result); unsigned int id_; + Direction pad_direction_; CompletionRepeatingCallback io_callback_; CompletionOnceCallback connect_callback_; @@ -110,6 +122,11 @@ class NaiveConnection { bool can_push_to_server_; int early_pull_result_; + int num_paddings_[kNumDirections]; + PaddingState read_padding_state_; + int payload_length_; + int padding_length_; + bool full_duplex_; TimeFunc time_func_; diff --git a/src/net/tools/naive/naive_proxy.cc b/src/net/tools/naive/naive_proxy.cc index d2d95c59b8..f55163d73b 100644 --- a/src/net/tools/naive/naive_proxy.cc +++ b/src/net/tools/naive/naive_proxy.cc @@ -21,6 +21,7 @@ #include "net/socket/client_socket_pool_manager.h" #include "net/socket/server_socket.h" #include "net/socket/stream_socket.h" +#include "net/tools/naive/http_proxy_socket.h" #include "net/tools/naive/socks5_server_socket.h" namespace net { @@ -76,13 +77,20 @@ void NaiveProxy::HandleAcceptResult(int result) { void NaiveProxy::DoConnect() { std::unique_ptr socket; + NaiveConnection::Direction pad_direction; if (protocol_ == kSocks5) { - socket = std::make_unique(std::move(accepted_socket_)); + socket = std::make_unique(std::move(accepted_socket_), + traffic_annotation_); + pad_direction = NaiveConnection::kClient; + } else if (protocol_ == kHttp) { + socket = std::make_unique(std::move(accepted_socket_), + traffic_annotation_); + pad_direction = NaiveConnection::kServer; } else { return; } auto connection_ptr = std::make_unique( - ++last_id_, std::move(socket), this, traffic_annotation_); + ++last_id_, pad_direction, std::move(socket), this, traffic_annotation_); auto* connection = connection_ptr.get(); connection_by_id_[connection->id()] = std::move(connection_ptr); int result = connection->Connect( @@ -120,15 +128,20 @@ int NaiveProxy::OnConnectServer(unsigned int connection_id, session_->GetSSLConfig(&server_ssl_config, &proxy_ssl_config); proxy_ssl_config.disable_cert_verification_network_fetches = true; + } else { + proxy_info.UseDirect(); } HostPortPair request_endpoint; if (protocol_ == kSocks5) { const auto* socket = static_cast(client_socket); request_endpoint = socket->request_endpoint(); + } else if (protocol_ == kHttp) { + const auto* socket = static_cast(client_socket); + request_endpoint = socket->request_endpoint(); } - if (request_endpoint.port() == 0) { - LOG(ERROR) << "Connection " << connection_id << " has invalid upstream"; + if (request_endpoint.IsEmpty()) { + LOG(ERROR) << "Connection " << connection_id << " to invalid origin"; return ERR_ADDRESS_INVALID; }