From cece06a2cd1e07a788e4888779180f38ef3e27fb Mon Sep 17 00:00:00 2001 From: klzgrad Date: Sat, 20 Jan 2018 12:15:45 -0500 Subject: [PATCH] Add initial implementation of Naive client --- src/net/log/net_log_event_type_list.h | 5 + src/net/tools/naive/naive_connection.cc | 342 ++++++++++++ src/net/tools/naive/naive_connection.h | 126 +++++ src/net/tools/naive/naive_proxy.cc | 209 ++++++++ src/net/tools/naive/naive_proxy.h | 84 +++ src/net/tools/naive/naive_proxy_bin.cc | 381 ++++++++++++++ src/net/tools/naive/socks5_server_socket.cc | 554 ++++++++++++++++++++ src/net/tools/naive/socks5_server_socket.h | 172 ++++++ 8 files changed, 1873 insertions(+) create mode 100644 src/net/tools/naive/naive_connection.cc create mode 100644 src/net/tools/naive/naive_connection.h create mode 100644 src/net/tools/naive/naive_proxy.cc create mode 100644 src/net/tools/naive/naive_proxy.h create mode 100644 src/net/tools/naive/naive_proxy_bin.cc create mode 100644 src/net/tools/naive/socks5_server_socket.cc create mode 100644 src/net/tools/naive/socks5_server_socket.h diff --git a/src/net/log/net_log_event_type_list.h b/src/net/log/net_log_event_type_list.h index b1aad66ea6..bb04c48a63 100644 --- a/src/net/log/net_log_event_type_list.h +++ b/src/net/log/net_log_event_type_list.h @@ -438,6 +438,11 @@ EVENT_TYPE(SOCKS_HOSTNAME_TOO_BIG) EVENT_TYPE(SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING) EVENT_TYPE(SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE) +EVENT_TYPE(SOCKS_NO_REQUESTED_AUTH) +EVENT_TYPE(SOCKS_NO_ACCEPTABLE_AUTH) +EVENT_TYPE(SOCKS_ZERO_LENGTH_DOMAIN) +EVENT_TYPE(SOCKS_UNEXPECTED_COMMAND) + // This event indicates that a bad version number was received in the // proxy server's response. The extra parameters show its value: // { diff --git a/src/net/tools/naive/naive_connection.cc b/src/net/tools/naive/naive_connection.cc new file mode 100644 index 0000000000..46cf37c9e6 --- /dev/null +++ b/src/net/tools/naive/naive_connection.cc @@ -0,0 +1,342 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/naive_connection.h" + +#include + +#include "base/bind.h" +#include "base/callback_helpers.h" +#include "base/logging.h" +#include "base/threading/thread_task_runner_handle.h" +#include "net/base/io_buffer.h" +#include "net/base/load_flags.h" +#include "net/base/net_errors.h" +#include "net/base/privacy_mode.h" +#include "net/http/http_network_session.h" +#include "net/proxy_resolution/proxy_config.h" +#include "net/proxy_resolution/proxy_info.h" +#include "net/proxy_resolution/proxy_list.h" +#include "net/proxy_resolution/proxy_resolution_service.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_manager.h" +#include "net/socket/stream_socket.h" +#include "net/spdy/spdy_session.h" + +namespace net { + +namespace { +static const int kBufferSize = 64 * 1024; +} // namespace + +NaiveConnection::NaiveConnection( + unsigned int id, + std::unique_ptr accepted_socket, + Delegate* delegate, + const NetworkTrafficAnnotationTag& traffic_annotation) + : id_(id), + next_state_(STATE_NONE), + delegate_(delegate), + client_socket_(std::move(accepted_socket)), + server_socket_handle_(std::make_unique()), + sockets_{client_socket_.get(), nullptr}, + errors_{OK, OK}, + write_pending_{false, false}, + early_pull_pending_(false), + can_push_to_server_(false), + early_pull_result_(ERR_IO_PENDING), + full_duplex_(false), + time_func_(&base::TimeTicks::Now), + traffic_annotation_(traffic_annotation) { + io_callback_ = base::BindRepeating(&NaiveConnection::OnIOComplete, + weak_ptr_factory_.GetWeakPtr()); +} + +NaiveConnection::~NaiveConnection() { + Disconnect(); +} + +int NaiveConnection::Connect(CompletionOnceCallback callback) { + DCHECK(client_socket_); + DCHECK_EQ(next_state_, STATE_NONE); + DCHECK(!connect_callback_); + + if (full_duplex_) + return OK; + + next_state_ = STATE_CONNECT_CLIENT; + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + connect_callback_ = std::move(callback); + } + return rv; +} + +void NaiveConnection::Disconnect() { + full_duplex_ = false; + // Closes server side first because latency is higher. + if (server_socket_handle_->socket()) + server_socket_handle_->socket()->Disconnect(); + client_socket_->Disconnect(); + + next_state_ = STATE_NONE; + connect_callback_.Reset(); + run_callback_.Reset(); +} + +void NaiveConnection::DoCallback(int result) { + DCHECK_NE(result, ERR_IO_PENDING); + DCHECK(connect_callback_); + + // Since Run() may result in Read being called, + // clear connect_callback_ up front. + std::move(connect_callback_).Run(result); +} + +void NaiveConnection::OnIOComplete(int result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) { + DoCallback(rv); + } +} + +int NaiveConnection::DoLoop(int last_io_result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_CONNECT_CLIENT: + DCHECK_EQ(rv, OK); + rv = DoConnectClient(); + break; + case STATE_CONNECT_CLIENT_COMPLETE: + rv = DoConnectClientComplete(rv); + break; + case STATE_CONNECT_SERVER: + DCHECK_EQ(rv, OK); + rv = DoConnectServer(); + break; + case STATE_CONNECT_SERVER_COMPLETE: + rv = DoConnectServerComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +int NaiveConnection::DoConnectClient() { + next_state_ = STATE_CONNECT_CLIENT_COMPLETE; + + return client_socket_->Connect(io_callback_); +} + +int NaiveConnection::DoConnectClientComplete(int result) { + if (result < 0) + return result; + + early_pull_pending_ = true; + Pull(kClient, kServer); + if (early_pull_result_ != ERR_IO_PENDING) { + // Pull has completed synchronously. + if (early_pull_result_ <= 0) { + return early_pull_result_ ? early_pull_result_ : ERR_CONNECTION_CLOSED; + } + } + + next_state_ = STATE_CONNECT_SERVER; + return OK; +} + +int NaiveConnection::DoConnectServer() { + DCHECK(delegate_); + next_state_ = STATE_CONNECT_SERVER_COMPLETE; + + return delegate_->OnConnectServer(id_, client_socket_.get(), + server_socket_handle_.get(), io_callback_); +} + +int NaiveConnection::DoConnectServerComplete(int result) { + if (result < 0) + return result; + + DCHECK(server_socket_handle_->socket()); + sockets_[kServer] = server_socket_handle_->socket(); + + full_duplex_ = true; + next_state_ = STATE_NONE; + return OK; +} + +int NaiveConnection::Run(CompletionOnceCallback callback) { + DCHECK(sockets_[kClient]); + DCHECK(sockets_[kServer]); + DCHECK_EQ(next_state_, STATE_NONE); + DCHECK(!connect_callback_); + + if (errors_[kClient] != OK) + return errors_[kClient]; + if (errors_[kServer] != OK) + return errors_[kServer]; + + run_callback_ = std::move(callback); + + bytes_passed_without_yielding_[kClient] = 0; + bytes_passed_without_yielding_[kServer] = 0; + + yield_after_time_[kClient] = + time_func_() + + base::TimeDelta::FromMilliseconds(kYieldAfterDurationMilliseconds); + yield_after_time_[kServer] = yield_after_time_[kClient]; + + can_push_to_server_ = true; + if (!early_pull_pending_) { + DCHECK_GT(early_pull_result_, 0); + Push(kClient, kServer, early_pull_result_); + } + Pull(kServer, kClient); + + return ERR_IO_PENDING; +} + +void NaiveConnection::Pull(Direction from, Direction to) { + if (errors_[kClient] < 0 || errors_[kServer] < 0) + return; + + read_buffers_[from] = new IOBuffer(kBufferSize); + DCHECK(sockets_[from]); + int rv = sockets_[from]->Read( + read_buffers_[from].get(), kBufferSize, + base::BindRepeating(&NaiveConnection::OnPullComplete, + weak_ptr_factory_.GetWeakPtr(), from, to)); + + if (from == kClient && early_pull_pending_) + early_pull_result_ = rv; + + if (rv != ERR_IO_PENDING) + OnPullComplete(from, to, rv); +} + +void NaiveConnection::Push(Direction from, Direction to, int size) { + write_buffers_[to] = new DrainableIOBuffer(read_buffers_[from].get(), size); + write_pending_[to] = true; + DCHECK(sockets_[to]); + int rv = sockets_[to]->Write( + write_buffers_[to].get(), size, + base::BindRepeating(&NaiveConnection::OnPushComplete, + weak_ptr_factory_.GetWeakPtr(), from, to), + traffic_annotation_); + + if (rv != ERR_IO_PENDING) + OnPushComplete(from, to, rv); +} + +void NaiveConnection::Disconnect(Direction side) { + if (sockets_[side]) { + sockets_[side]->Disconnect(); + sockets_[side] = nullptr; + write_pending_[side] = false; + } +} + +bool NaiveConnection::IsConnected(Direction side) { + return sockets_[side]; +} + +void NaiveConnection::OnBothDisconnected() { + if (run_callback_) { + int error = OK; + if (errors_[kClient] != ERR_CONNECTION_CLOSED && errors_[kClient] < 0) + error = errors_[kClient]; + if (errors_[kServer] != ERR_CONNECTION_CLOSED && errors_[kClient] < 0) + error = errors_[kServer]; + std::move(run_callback_).Run(error); + } +} + +void NaiveConnection::OnPullError(Direction from, Direction to, int error) { + DCHECK_LT(error, 0); + + errors_[from] = error; + Disconnect(from); + + if (!write_pending_[to]) + Disconnect(to); + + if (!IsConnected(from) && !IsConnected(to)) + OnBothDisconnected(); +} + +void NaiveConnection::OnPushError(Direction from, Direction to, int error) { + DCHECK_LE(error, 0); + DCHECK(!write_pending_[to]); + + if (error < 0) { + errors_[to] = error; + Disconnect(kServer); + Disconnect(kClient); + } else if (!IsConnected(from)) { + Disconnect(to); + } + + if (!IsConnected(from) && !IsConnected(to)) + OnBothDisconnected(); +} + +void NaiveConnection::OnPullComplete(Direction from, Direction to, int result) { + if (from == kClient && early_pull_pending_) { + early_pull_pending_ = false; + early_pull_result_ = result; + } + + if (result <= 0) { + OnPullError(from, to, result ? result : ERR_CONNECTION_CLOSED); + return; + } + + if (from == kClient && !can_push_to_server_) + return; + + Push(from, to, result); +} + +void NaiveConnection::OnPushComplete(Direction from, Direction to, int result) { + if (result >= 0) { + bytes_passed_without_yielding_[from] += result; + write_buffers_[to]->DidConsume(result); + int size = write_buffers_[to]->BytesRemaining(); + if (size > 0) { + Push(from, to, size); + return; + } + } + + write_pending_[to] = false; + // Checks for termination even if result is OK. + OnPushError(from, to, result >= 0 ? OK : result); + + if (bytes_passed_without_yielding_[from] > kYieldAfterBytesRead || + time_func_() > yield_after_time_[from]) { + bytes_passed_without_yielding_[from] = 0; + yield_after_time_[from] = + time_func_() + + base::TimeDelta::FromMilliseconds(kYieldAfterDurationMilliseconds); + base::ThreadTaskRunnerHandle::Get()->PostTask( + FROM_HERE, + base::BindRepeating(&NaiveConnection::Pull, + weak_ptr_factory_.GetWeakPtr(), from, to)); + } else { + Pull(from, to); + } +} + +} // namespace net diff --git a/src/net/tools/naive/naive_connection.h b/src/net/tools/naive/naive_connection.h new file mode 100644 index 0000000000..2bf976543a --- /dev/null +++ b/src/net/tools/naive/naive_connection.h @@ -0,0 +1,126 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_NAIVE_CONNECTION_H_ +#define NET_TOOLS_NAIVE_NAIVE_CONNECTION_H_ + +#include + +#include "base/macros.h" +#include "base/memory/ref_counted.h" +#include "base/memory/weak_ptr.h" +#include "base/time/time.h" +#include "net/base/completion_once_callback.h" +#include "net/base/completion_repeating_callback.h" + +namespace net { + +class ClientSocketHandle; +class DrainableIOBuffer; +class IOBuffer; +class StreamSocket; +struct NetworkTrafficAnnotationTag; + +class NaiveConnection { + public: + using TimeFunc = base::TimeTicks (*)(); + + class Delegate { + public: + Delegate() {} + virtual ~Delegate() {} + + virtual int OnConnectServer(unsigned int connection_id, + const StreamSocket* accepted_socket, + ClientSocketHandle* server_socket, + CompletionRepeatingCallback callback) = 0; + + private: + DISALLOW_COPY_AND_ASSIGN(Delegate); + }; + + NaiveConnection(unsigned int id, + std::unique_ptr accepted_socket, + Delegate* delegate, + const NetworkTrafficAnnotationTag& traffic_annotation); + ~NaiveConnection(); + + unsigned int id() const { return id_; } + int Connect(CompletionOnceCallback callback); + void Disconnect(); + int Run(CompletionOnceCallback callback); + + private: + enum State { + STATE_CONNECT_CLIENT, + STATE_CONNECT_CLIENT_COMPLETE, + STATE_CONNECT_SERVER, + STATE_CONNECT_SERVER_COMPLETE, + STATE_NONE, + }; + + // From this direction. + enum Direction { + kClient = 0, + kServer = 1, + kNumDirections = 2, + }; + + void DoCallback(int result); + void OnIOComplete(int result); + int DoLoop(int last_io_result); + int DoConnectClient(); + int DoConnectClientComplete(int result); + int DoConnectServer(); + int DoConnectServerComplete(int result); + void Pull(Direction from, Direction to); + void Push(Direction from, Direction to, int size); + void Disconnect(Direction side); + bool IsConnected(Direction side); + void OnBothDisconnected(); + void OnPullError(Direction from, Direction to, int error); + void OnPushError(Direction from, Direction to, int error); + void OnPullComplete(Direction from, Direction to, int result); + void OnPushComplete(Direction from, Direction to, int result); + + unsigned int id_; + + CompletionRepeatingCallback io_callback_; + CompletionOnceCallback connect_callback_; + CompletionOnceCallback run_callback_; + + State next_state_; + + Delegate* delegate_; + + std::unique_ptr client_socket_; + std::unique_ptr server_socket_handle_; + + StreamSocket* sockets_[kNumDirections]; + scoped_refptr read_buffers_[kNumDirections]; + scoped_refptr write_buffers_[kNumDirections]; + int errors_[kNumDirections]; + bool write_pending_[kNumDirections]; + int bytes_passed_without_yielding_[kNumDirections]; + base::TimeTicks yield_after_time_[kNumDirections]; + + bool early_pull_pending_; + bool can_push_to_server_; + int early_pull_result_; + + bool full_duplex_; + + TimeFunc time_func_; + + // Traffic annotation for socket control. + const NetworkTrafficAnnotationTag& traffic_annotation_; + + base::WeakPtrFactory weak_ptr_factory_{this}; + + DISALLOW_COPY_AND_ASSIGN(NaiveConnection); +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_NAIVE_CONNECTION_H_ diff --git a/src/net/tools/naive/naive_proxy.cc b/src/net/tools/naive/naive_proxy.cc new file mode 100644 index 0000000000..e54dcd61e8 --- /dev/null +++ b/src/net/tools/naive/naive_proxy.cc @@ -0,0 +1,209 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/naive_proxy.h" + +#include + +#include "base/bind.h" +#include "base/location.h" +#include "base/logging.h" +#include "base/threading/thread_task_runner_handle.h" +#include "net/base/net_errors.h" +#include "net/http/http_network_session.h" +#include "net/proxy_resolution/proxy_config.h" +#include "net/proxy_resolution/proxy_info.h" +#include "net/proxy_resolution/proxy_list.h" +#include "net/proxy_resolution/proxy_resolution_service.h" +#include "net/socket/client_socket_pool_manager.h" +#include "net/socket/server_socket.h" +#include "net/socket/stream_socket.h" +#include "net/tools/naive/socks5_server_socket.h" + +namespace net { + +NaiveProxy::NaiveProxy(std::unique_ptr listen_socket, + Protocol protocol, + bool use_proxy, + HttpNetworkSession* session, + const NetworkTrafficAnnotationTag& traffic_annotation) + : listen_socket_(std::move(listen_socket)), + protocol_(protocol), + use_proxy_(use_proxy), + session_(session), + net_log_( + NetLogWithSource::Make(session->net_log(), NetLogSourceType::NONE)), + last_id_(0), + traffic_annotation_(traffic_annotation) { + DCHECK(listen_socket_); + // Start accepting connections in next run loop in case when delegate is not + // ready to get callbacks. + base::ThreadTaskRunnerHandle::Get()->PostTask( + FROM_HERE, base::BindOnce(&NaiveProxy::DoAcceptLoop, + weak_ptr_factory_.GetWeakPtr())); +} + +NaiveProxy::~NaiveProxy() = default; + +void NaiveProxy::DoAcceptLoop() { + int result; + do { + result = listen_socket_->Accept( + &accepted_socket_, base::BindRepeating(&NaiveProxy::OnAcceptComplete, + weak_ptr_factory_.GetWeakPtr())); + if (result == ERR_IO_PENDING) + return; + HandleAcceptResult(result); + } while (result == OK); +} + +void NaiveProxy::OnAcceptComplete(int result) { + HandleAcceptResult(result); + if (result == OK) + DoAcceptLoop(); +} + +void NaiveProxy::HandleAcceptResult(int result) { + if (result != OK) { + LOG(ERROR) << "Accept error: rv=" << result; + return; + } + DoConnect(); +} + +void NaiveProxy::DoConnect() { + std::unique_ptr socket; + if (protocol_ == kSocks5) { + socket = std::make_unique(std::move(accepted_socket_)); + } else { + return; + } + auto connection_ptr = std::make_unique( + ++last_id_, std::move(socket), this, traffic_annotation_); + auto* connection = connection_ptr.get(); + connection_by_id_[connection->id()] = std::move(connection_ptr); + int result = connection->Connect( + base::BindRepeating(&NaiveProxy::OnConnectComplete, + weak_ptr_factory_.GetWeakPtr(), connection->id())); + if (result == ERR_IO_PENDING) + return; + HandleConnectResult(connection, result); +} + +int NaiveProxy::OnConnectServer(unsigned int connection_id, + const StreamSocket* client_socket, + ClientSocketHandle* server_socket, + CompletionRepeatingCallback callback) { + // Ignores socket limit set by socket pool for this type of socket. + constexpr int request_load_flags = LOAD_IGNORE_LIMITS; + constexpr RequestPriority request_priority = MAXIMUM_PRIORITY; + + ProxyInfo proxy_info; + SSLConfig server_ssl_config; + SSLConfig proxy_ssl_config; + + if (use_proxy_) { + const auto& proxy_config = session_->proxy_resolution_service()->config(); + DCHECK(proxy_config); + const ProxyList& proxy_list = + proxy_config.value().value().proxy_rules().single_proxies; + if (proxy_list.IsEmpty()) + return ERR_MANDATORY_PROXY_CONFIGURATION_FAILED; + proxy_info.UseProxyList(proxy_list); + proxy_info.set_traffic_annotation( + net::MutableNetworkTrafficAnnotationTag(traffic_annotation_)); + + HttpRequestInfo req_info; + session_->GetSSLConfig(req_info, &server_ssl_config, &proxy_ssl_config); + proxy_ssl_config.disable_cert_verification_network_fetches = true; + } + + HostPortPair request_endpoint; + if (protocol_ == kSocks5) { + const auto* socket = static_cast(client_socket); + request_endpoint = socket->request_endpoint(); + } + if (request_endpoint.port() == 0) { + LOG(ERROR) << "Connection " << connection_id << " has invalid upstream"; + return ERR_ADDRESS_INVALID; + } + + LOG(INFO) << "Connection " << connection_id << " to " + << request_endpoint.ToString(); + + return InitSocketHandleForRawConnect( + request_endpoint, session_, request_load_flags, request_priority, + proxy_info, server_ssl_config, proxy_ssl_config, PRIVACY_MODE_DISABLED, + net_log_, server_socket, callback); +} + +void NaiveProxy::OnConnectComplete(int connection_id, int result) { + NaiveConnection* connection = FindConnection(connection_id); + if (!connection) + return; + HandleConnectResult(connection, result); +} + +void NaiveProxy::HandleConnectResult(NaiveConnection* connection, int result) { + if (result != OK) { + Close(connection->id(), result); + return; + } + DoRun(connection); +} + +void NaiveProxy::DoRun(NaiveConnection* connection) { + int result = connection->Run( + base::BindRepeating(&NaiveProxy::OnRunComplete, + weak_ptr_factory_.GetWeakPtr(), connection->id())); + if (result == ERR_IO_PENDING) + return; + HandleRunResult(connection, result); +} + +void NaiveProxy::OnRunComplete(int connection_id, int result) { + NaiveConnection* connection = FindConnection(connection_id); + if (!connection) + return; + HandleRunResult(connection, result); +} + +void NaiveProxy::HandleRunResult(NaiveConnection* connection, int result) { + Close(connection->id(), result); +} + +void NaiveProxy::Close(int connection_id, int reason) { + auto it = connection_by_id_.find(connection_id); + if (it == connection_by_id_.end()) + return; + + LOG(INFO) << "Connection " << connection_id + << " closed: " << ErrorToShortString(reason); + + // The call stack might have callbacks which still have the pointer of + // connection. Instead of referencing connection with ID all the time, + // destroys the connection in next run loop to make sure any pending + // callbacks in the call stack return. + base::ThreadTaskRunnerHandle::Get()->DeleteSoon(FROM_HERE, + std::move(it->second)); + connection_by_id_.erase(it); +} + +NaiveConnection* NaiveProxy::FindConnection(int connection_id) { + auto it = connection_by_id_.find(connection_id); + if (it == connection_by_id_.end()) + return nullptr; + return it->second.get(); +} + +// This is called after any delegate callbacks are called to check if Close() +// has been called during callback processing. Using the pointer of connection, +// |connection| is safe here because Close() deletes the connection in next run +// loop. +bool NaiveProxy::HasClosedConnection(NaiveConnection* connection) { + return FindConnection(connection->id()) != connection; +} + +} // namespace net diff --git a/src/net/tools/naive/naive_proxy.h b/src/net/tools/naive/naive_proxy.h new file mode 100644 index 0000000000..e304085571 --- /dev/null +++ b/src/net/tools/naive/naive_proxy.h @@ -0,0 +1,84 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_NAIVE_PROXY_H_ +#define NET_TOOLS_NAIVE_NAIVE_PROXY_H_ + +#include +#include + +#include "base/macros.h" +#include "base/memory/weak_ptr.h" +#include "net/base/completion_repeating_callback.h" +#include "net/log/net_log_with_source.h" +#include "net/tools/naive/naive_connection.h" + +namespace net { + +class ClientSocketHandle; +class HttpNetworkSession; +class NaiveConnection; +class ServerSocket; +class StreamSocket; +struct NetworkTrafficAnnotationTag; + +class NaiveProxy : public NaiveConnection::Delegate { + public: + enum Protocol { + kSocks5, + kHttp, + }; + + NaiveProxy(std::unique_ptr server_socket, + Protocol protocol, + bool use_proxy, + HttpNetworkSession* session, + const NetworkTrafficAnnotationTag& traffic_annotation); + ~NaiveProxy() override; + + int OnConnectServer(unsigned int connection_id, + const StreamSocket* accepted_socket, + ClientSocketHandle* server_socket, + CompletionRepeatingCallback callback) override; + + private: + void DoAcceptLoop(); + void OnAcceptComplete(int result); + void HandleAcceptResult(int result); + + void DoConnect(); + void OnConnectComplete(int connection_id, int result); + void HandleConnectResult(NaiveConnection* connection, int result); + + void DoRun(NaiveConnection* connection); + void OnRunComplete(int connection_id, int result); + void HandleRunResult(NaiveConnection* connection, int result); + + void Close(int connection_id, int reason); + + NaiveConnection* FindConnection(int connection_id); + bool HasClosedConnection(NaiveConnection* connection); + + std::unique_ptr listen_socket_; + Protocol protocol_; + bool use_proxy_; + HttpNetworkSession* session_; + NetLogWithSource net_log_; + + unsigned int last_id_; + + std::unique_ptr accepted_socket_; + + std::map> connection_by_id_; + + const NetworkTrafficAnnotationTag& traffic_annotation_; + + base::WeakPtrFactory weak_ptr_factory_{this}; + + DISALLOW_COPY_AND_ASSIGN(NaiveProxy); +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_NAIVE_PROXY_H_ diff --git a/src/net/tools/naive/naive_proxy_bin.cc b/src/net/tools/naive/naive_proxy_bin.cc new file mode 100644 index 0000000000..48491fc81c --- /dev/null +++ b/src/net/tools/naive/naive_proxy_bin.cc @@ -0,0 +1,381 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include +#include +#include + +#include "base/at_exit.h" +#include "base/command_line.h" +#include "base/files/file_path.h" +#include "base/json/json_writer.h" +#include "base/logging.h" +#include "base/macros.h" +#include "base/run_loop.h" +#include "base/strings/string_number_conversions.h" +#include "base/strings/stringprintf.h" +#include "base/strings/utf_string_conversions.h" +#include "base/system/sys_info.h" +#include "base/task/single_thread_task_executor.h" +#include "base/task/thread_pool/thread_pool_instance.h" +#include "base/values.h" +#include "build/build_config.h" +#include "components/version_info/version_info.h" +#include "net/base/auth.h" +#include "net/base/network_isolation_key.h" +#include "net/dns/host_resolver.h" +#include "net/dns/mapped_host_resolver.h" +#include "net/http/http_auth.h" +#include "net/http/http_auth_cache.h" +#include "net/http/http_network_session.h" +#include "net/http/http_transaction_factory.h" +#include "net/log/file_net_log_observer.h" +#include "net/log/net_log.h" +#include "net/log/net_log_capture_mode.h" +#include "net/log/net_log_entry.h" +#include "net/log/net_log_source.h" +#include "net/log/net_log_util.h" +#include "net/proxy_resolution/proxy_config.h" +#include "net/proxy_resolution/proxy_config_service_fixed.h" +#include "net/proxy_resolution/proxy_config_with_annotation.h" +#include "net/proxy_resolution/proxy_resolution_service.h" +#include "net/socket/client_socket_pool_manager.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/tcp_server_socket.h" +#include "net/ssl/ssl_key_logger_impl.h" +#include "net/tools/naive/naive_proxy.h" +#include "net/traffic_annotation/network_traffic_annotation.h" +#include "net/url_request/url_request_context.h" +#include "net/url_request/url_request_context_builder.h" +#include "url/gurl.h" +#include "url/scheme_host_port.h" + +#if defined(OS_MACOSX) +#include "base/mac/scoped_nsautorelease_pool.h" +#endif + +namespace { + +constexpr int kListenBackLog = 512; +constexpr int kDefaultMaxSocketsPerPool = 256; +constexpr int kDefaultMaxSocketsPerGroup = 255; +constexpr int kExpectedMaxUsers = 8; +constexpr char kDefaultHostName[] = "example"; +constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation = + net::DefineNetworkTrafficAnnotation("naive", ""); + +struct Params { + std::string listen_addr; + int listen_port; + net::NaiveProxy::Protocol protocol; + bool use_proxy; + std::string proxy_url; + std::string proxy_user; + std::string proxy_pass; + std::string host_resolver_rules; + logging::LoggingSettings log_settings; + base::FilePath net_log_path; + base::FilePath ssl_key_path; +}; + +std::unique_ptr GetConstants( + const base::CommandLine::StringType& command_line_string) { + auto constants_dict = net::GetNetConstants(); + DCHECK(constants_dict); + + // Add a dictionary with the version of the client and its command line + // arguments. + auto dict = std::make_unique(); + + // We have everything we need to send the right values. + std::string os_type = base::StringPrintf( + "%s: %s (%s)", base::SysInfo::OperatingSystemName().c_str(), + base::SysInfo::OperatingSystemVersion().c_str(), + base::SysInfo::OperatingSystemArchitecture().c_str()); + dict->SetString("os_type", os_type); + dict->SetString("command_line", command_line_string); + + constants_dict->Set("clientInfo", std::move(dict)); + + return std::move(constants_dict); +} + +// Builds a URLRequestContext assuming there's only a single loop. +std::unique_ptr BuildURLRequestContext( + const Params& params, + net::NetLog* net_log) { + net::URLRequestContextBuilder builder; + + net::ProxyConfig proxy_config; + proxy_config.proxy_rules().ParseFromString(params.proxy_url); + auto proxy_service = net::ProxyResolutionService::CreateWithoutProxyResolver( + std::make_unique( + net::ProxyConfigWithAnnotation(proxy_config, kTrafficAnnotation)), + net_log); + proxy_service->ForceReloadProxyConfig(); + + builder.set_proxy_resolution_service(std::move(proxy_service)); + builder.DisableHttpCache(); + builder.set_net_log(net_log); + + if (!params.host_resolver_rules.empty()) { + builder.set_host_mapping_rules(params.host_resolver_rules); + } + + auto context = builder.Build(); + + net::HttpNetworkSession* session = + context->http_transaction_factory()->GetSession(); + net::HttpAuthCache* auth_cache = session->http_auth_cache(); + GURL auth_origin(params.proxy_url); + net::AuthCredentials credentials(base::ASCIIToUTF16(params.proxy_user), + base::ASCIIToUTF16(params.proxy_pass)); + auth_cache->Add(auth_origin, net::HttpAuth::AUTH_PROXY, + /*realm=*/std::string(), net::HttpAuth::AUTH_SCHEME_BASIC, + net::NetworkIsolationKey(), /*challenge=*/"Basic", + credentials, /*path=*/"/"); + + return context; +} + +bool ParseCommandLineFlags(Params* params) { + const base::CommandLine& line = *base::CommandLine::ForCurrentProcess(); + + if (line.HasSwitch("h") || line.HasSwitch("help")) { + LOG(INFO) << "Usage: naive [options]\n" + "\n" + "Options:\n" + "-h, --help Show this message\n" + "--version Print version\n" + "--addr=
Address to listen on (0.0.0.0)\n" + "--port= Port to listen on (1080)\n" + "--proto=[socks|http] Protocol to accept (socks)\n" + "--proxy=https://:@[:]\n" + " Proxy specification.\n" + "--log Log to stderr, otherwise no log\n" + "--log-net-log= Save NetLog\n" + "--ssl-key-log-file= Save SSL keys for Wireshark\n"; + exit(EXIT_SUCCESS); + return false; + } + + if (line.HasSwitch("version")) { + LOG(INFO) << "Version: " << version_info::GetVersionNumber(); + exit(EXIT_SUCCESS); + return false; + } + + params->listen_addr = "0.0.0.0"; + if (line.HasSwitch("addr")) { + params->listen_addr = line.GetSwitchValueASCII("addr"); + } + if (params->listen_addr.empty()) { + LOG(ERROR) << "Invalid --addr"; + return false; + } + + params->listen_port = 1080; + if (line.HasSwitch("port")) { + if (!base::StringToInt(line.GetSwitchValueASCII("port"), + ¶ms->listen_port)) { + LOG(ERROR) << "Invalid --port"; + return false; + } + if (params->listen_port <= 0 || + params->listen_port > std::numeric_limits::max()) { + LOG(ERROR) << "Invalid --port"; + return false; + } + } + + params->protocol = net::NaiveProxy::kSocks5; + if (line.HasSwitch("proto")) { + const auto& proto = line.GetSwitchValueASCII("proto"); + if (proto == "socks") { + params->protocol = net::NaiveProxy::kSocks5; + } else if (proto == "http") { + params->protocol = net::NaiveProxy::kHttp; + } else { + LOG(ERROR) << "Invalid --proto"; + return false; + } + } + + GURL url(line.GetSwitchValueASCII("proxy")); + if (line.HasSwitch("proxy")) { + params->use_proxy = true; + if (!url.is_valid()) { + LOG(ERROR) << "Invalid proxy URL"; + return false; + } + if (url.scheme() != "https") { + LOG(ERROR) << "Must be HTTPS proxy"; + return false; + } + if (url.username().empty() || url.password().empty()) { + LOG(ERROR) << "Missing user or pass"; + return false; + } + params->proxy_url = url::SchemeHostPort(url).Serialize(); + params->proxy_user = url.username(); + params->proxy_pass = url.password(); + } + + if (line.HasSwitch("host-resolver-rules")) { + params->host_resolver_rules = + line.GetSwitchValueASCII("host-resolver-rules"); + } else { + // SNI should only contain DNS hostnames not IP addresses per RFC 6066. + if (url.HostIsIPAddress()) { + GURL::Replacements replacements; + replacements.SetHostStr(kDefaultHostName); + params->proxy_url = + url::SchemeHostPort(url.ReplaceComponents(replacements)).Serialize(); + LOG(INFO) << "Using '" << kDefaultHostName << "' as the hostname for " + << url.host(); + params->host_resolver_rules = + std::string("MAP ") + kDefaultHostName + " " + url.host(); + } + } + + if (line.HasSwitch("log")) { + params->log_settings.logging_dest = logging::LOG_TO_SYSTEM_DEBUG_LOG; + } else { + params->log_settings.logging_dest = logging::LOG_NONE; + } + + if (line.HasSwitch("log-net-log")) { + params->net_log_path = line.GetSwitchValuePath("log-net-log"); + } + + if (line.HasSwitch("ssl-key-log-file")) { + params->ssl_key_path = line.GetSwitchValuePath("ssl-key-log-file"); + } + + return true; +} + +// NetLog::ThreadSafeObserver implementation that simply prints events +// to the logs. +class PrintingLogObserver : public net::NetLog::ThreadSafeObserver { + public: + PrintingLogObserver() = default; + + ~PrintingLogObserver() override { + // This is guaranteed to be safe as this program is single threaded. + net_log()->RemoveObserver(this); + } + + // NetLog::ThreadSafeObserver implementation: + void OnAddEntry(const net::NetLogEntry& entry) override { + switch (entry.type) { + case net::NetLogEventType::SOCKET_POOL_STALLED_MAX_SOCKETS: + case net::NetLogEventType::SOCKET_POOL_STALLED_MAX_SOCKETS_PER_GROUP: + case net::NetLogEventType:: + HTTP2_SESSION_STREAM_STALLED_BY_SESSION_SEND_WINDOW: + case net::NetLogEventType:: + HTTP2_SESSION_STREAM_STALLED_BY_STREAM_SEND_WINDOW: + case net::NetLogEventType::HTTP2_SESSION_STALLED_MAX_STREAMS: + case net::NetLogEventType::HTTP2_STREAM_FLOW_CONTROL_UNSTALLED: + break; + default: + return; + } + const char* const source_type = + net::NetLog::SourceTypeToString(entry.source.type); + const char* const event_type = net::NetLog::EventTypeToString(entry.type); + const char* const event_phase = + net::NetLog::EventPhaseToString(entry.phase); + base::Value params(entry.ToValue()); + std::string params_str; + base::JSONWriter::Write(params, ¶ms_str); + params_str.insert(0, ": "); + + LOG(INFO) << source_type << "(" << entry.source.id << "): " << event_type + << ": " << event_phase << params_str; + } + + private: + DISALLOW_COPY_AND_ASSIGN(PrintingLogObserver); +}; + +} // namespace + +int main(int argc, char* argv[]) { + base::SingleThreadTaskExecutor io_task_executor(base::MessagePumpType::IO); + base::ThreadPoolInstance::CreateAndStartWithDefaultParams("naive"); + base::AtExitManager exit_manager; + +#if defined(OS_MACOSX) + base::mac::ScopedNSAutoreleasePool pool; +#endif + + base::CommandLine::Init(argc, argv); + + Params params; + if (!ParseCommandLineFlags(¶ms)) { + return EXIT_FAILURE; + } + + net::ClientSocketPoolManager::set_max_sockets_per_pool( + net::HttpNetworkSession::NORMAL_SOCKET_POOL, + kDefaultMaxSocketsPerPool * kExpectedMaxUsers); + net::ClientSocketPoolManager::set_max_sockets_per_proxy_server( + net::HttpNetworkSession::NORMAL_SOCKET_POOL, + kDefaultMaxSocketsPerPool * kExpectedMaxUsers); + net::ClientSocketPoolManager::set_max_sockets_per_group( + net::HttpNetworkSession::NORMAL_SOCKET_POOL, + kDefaultMaxSocketsPerGroup * kExpectedMaxUsers); + + CHECK(logging::InitLogging(params.log_settings)); + + if (!params.ssl_key_path.empty()) { + net::SSLClientSocket::SetSSLKeyLogger( + std::make_unique(params.ssl_key_path)); + } + + // The declaration order for net_log and printing_log_observer is + // important. The destructor of PrintingLogObserver removes itself + // from net_log, so net_log must be available for entire lifetime of + // printing_log_observer. + net::NetLog* net_log = net::NetLog::Get(); + std::unique_ptr observer; + if (!params.net_log_path.empty()) { + const auto& cmdline = + base::CommandLine::ForCurrentProcess()->GetCommandLineString(); + observer = net::FileNetLogObserver::CreateUnbounded(params.net_log_path, + GetConstants(cmdline)); + observer->StartObserving(net_log, net::NetLogCaptureMode::kDefault); + } + + // Avoids net log overhead if logging is disabled. + std::unique_ptr printing_log_observer; + if (params.log_settings.logging_dest != logging::LOG_NONE) { + PrintingLogObserver printing_log_observer; + net_log->AddObserver(&printing_log_observer, + net::NetLogCaptureMode::kDefault); + } + + auto context = BuildURLRequestContext(params, net_log); + + auto listen_socket = + std::make_unique(net_log, net::NetLogSource()); + + int result = listen_socket->ListenWithAddressAndPort( + params.listen_addr, params.listen_port, kListenBackLog); + if (result != net::OK) { + LOG(ERROR) << "Failed to listen: " << result; + return EXIT_FAILURE; + } + + net::NaiveProxy naive_proxy( + std::move(listen_socket), params.protocol, params.use_proxy, + context->http_transaction_factory()->GetSession(), kTrafficAnnotation); + + base::RunLoop().Run(); + + return EXIT_SUCCESS; +} diff --git a/src/net/tools/naive/socks5_server_socket.cc b/src/net/tools/naive/socks5_server_socket.cc new file mode 100644 index 0000000000..4ac9163990 --- /dev/null +++ b/src/net/tools/naive/socks5_server_socket.cc @@ -0,0 +1,554 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/socks5_server_socket.h" + +#include +#include + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/callback_helpers.h" +#include "base/logging.h" +#include "base/stl_util.h" +#include "base/sys_byteorder.h" +#include "net/base/ip_address.h" +#include "net/base/net_errors.h" +#include "net/log/net_log.h" +#include "net/log/net_log_event_type.h" + +namespace net { + +const unsigned int Socks5ServerSocket::kGreetReadHeaderSize = 2; +const unsigned int Socks5ServerSocket::kReadHeaderSize = 5; +const char Socks5ServerSocket::kSOCKS5Version = '\x05'; +const char Socks5ServerSocket::kSOCKS5Reserved = '\x00'; +const char Socks5ServerSocket::kAuthMethodNone = '\x00'; +const char Socks5ServerSocket::kAuthMethodNoAcceptable = '\xff'; +const char Socks5ServerSocket::kReplySuccess = '\x00'; +const char Socks5ServerSocket::kReplyCommandNotSupported = '\x07'; + +static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4"); +static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6"); + +namespace { +constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation = + net::DefineNetworkTrafficAnnotation("naive", ""); +} // namespace + +Socks5ServerSocket::Socks5ServerSocket( + std::unique_ptr transport_socket) + : io_callback_(base::BindRepeating(&Socks5ServerSocket::OnIOComplete, + base::Unretained(this))), + transport_(std::move(transport_socket)), + next_state_(STATE_NONE), + completed_handshake_(false), + bytes_received_(0), + bytes_sent_(0), + greet_read_header_size_(kGreetReadHeaderSize), + read_header_size_(kReadHeaderSize), + was_ever_used_(false), + net_log_(transport_->NetLog()), + traffic_annotation_(kTrafficAnnotation) {} + +Socks5ServerSocket::~Socks5ServerSocket() { + Disconnect(); +} + +const HostPortPair& Socks5ServerSocket::request_endpoint() const { + return request_endpoint_; +} + +int Socks5ServerSocket::Connect(CompletionOnceCallback callback) { + DCHECK(transport_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + + // If already connected, then just return OK. + if (completed_handshake_) + return OK; + + net_log_.BeginEvent(NetLogEventType::SOCKS5_CONNECT); + + next_state_ = STATE_GREET_READ; + buffer_.clear(); + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + user_callback_ = std::move(callback); + } else { + net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_CONNECT, rv); + } + return rv; +} + +void Socks5ServerSocket::Disconnect() { + completed_handshake_ = false; + transport_->Disconnect(); + + // Reset other states to make sure they aren't mistakenly used later. + // These are the states initialized by Connect(). + next_state_ = STATE_NONE; + user_callback_.Reset(); +} + +bool Socks5ServerSocket::IsConnected() const { + return completed_handshake_ && transport_->IsConnected(); +} + +bool Socks5ServerSocket::IsConnectedAndIdle() const { + return completed_handshake_ && transport_->IsConnectedAndIdle(); +} + +const NetLogWithSource& Socks5ServerSocket::NetLog() const { + return net_log_; +} + +bool Socks5ServerSocket::WasEverUsed() const { + return was_ever_used_; +} + +bool Socks5ServerSocket::WasAlpnNegotiated() const { + if (transport_) { + return transport_->WasAlpnNegotiated(); + } + NOTREACHED(); + return false; +} + +NextProto Socks5ServerSocket::GetNegotiatedProtocol() const { + if (transport_) { + return transport_->GetNegotiatedProtocol(); + } + NOTREACHED(); + return kProtoUnknown; +} + +bool Socks5ServerSocket::GetSSLInfo(SSLInfo* ssl_info) { + if (transport_) { + return transport_->GetSSLInfo(ssl_info); + } + NOTREACHED(); + return false; +} + +void Socks5ServerSocket::GetConnectionAttempts(ConnectionAttempts* out) const { + out->clear(); +} + +int64_t Socks5ServerSocket::GetTotalReceivedBytes() const { + return transport_->GetTotalReceivedBytes(); +} + +void Socks5ServerSocket::ApplySocketTag(const SocketTag& tag) { + return transport_->ApplySocketTag(tag); +} + +// Read is called by the transport layer above to read. This can only be done +// if the SOCKS handshake is complete. +int Socks5ServerSocket::Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + DCHECK(callback); + + int rv = transport_->Read( + buf, buf_len, + base::BindOnce(&Socks5ServerSocket::OnReadWriteComplete, + base::Unretained(this), std::move(callback))); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +// Write is called by the transport layer. This can only be done if the +// SOCKS handshake is complete. +int Socks5ServerSocket::Write( + IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(!user_callback_); + DCHECK(callback); + + int rv = transport_->Write( + buf, buf_len, + base::BindOnce(&Socks5ServerSocket::OnReadWriteComplete, + base::Unretained(this), std::move(callback)), + traffic_annotation); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +int Socks5ServerSocket::SetReceiveBufferSize(int32_t size) { + return transport_->SetReceiveBufferSize(size); +} + +int Socks5ServerSocket::SetSendBufferSize(int32_t size) { + return transport_->SetSendBufferSize(size); +} + +void Socks5ServerSocket::DoCallback(int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(user_callback_); + + // Since Run() may result in Read being called, + // clear user_callback_ up front. + std::move(user_callback_).Run(result); +} + +void Socks5ServerSocket::OnIOComplete(int result) { + DCHECK_NE(STATE_NONE, next_state_); + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) { + net_log_.EndEvent(NetLogEventType::SOCKS5_CONNECT); + DoCallback(rv); + } +} + +void Socks5ServerSocket::OnReadWriteComplete(CompletionOnceCallback callback, + int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(callback); + + if (result > 0) + was_ever_used_ = true; + std::move(callback).Run(result); +} + +int Socks5ServerSocket::DoLoop(int last_io_result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_GREET_READ: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_READ); + rv = DoGreetRead(); + break; + case STATE_GREET_READ_COMPLETE: + rv = DoGreetReadComplete(rv); + net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_WRITE, + rv); + break; + case STATE_GREET_WRITE: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_READ); + rv = DoGreetWrite(); + break; + case STATE_GREET_WRITE_COMPLETE: + rv = DoGreetWriteComplete(rv); + net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_READ, + rv); + break; + case STATE_HANDSHAKE_READ: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_READ); + rv = DoHandshakeRead(); + break; + case STATE_HANDSHAKE_READ_COMPLETE: + rv = DoHandshakeReadComplete(rv); + net_log_.EndEventWithNetErrorCode( + NetLogEventType::SOCKS5_HANDSHAKE_READ, rv); + break; + case STATE_HANDSHAKE_WRITE: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_WRITE); + rv = DoHandshakeWrite(); + break; + case STATE_HANDSHAKE_WRITE_COMPLETE: + rv = DoHandshakeWriteComplete(rv); + net_log_.EndEventWithNetErrorCode( + NetLogEventType::SOCKS5_HANDSHAKE_WRITE, rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +int Socks5ServerSocket::DoGreetRead() { + next_state_ = STATE_GREET_READ_COMPLETE; + + if (buffer_.empty()) { + DCHECK_EQ(0U, bytes_received_); + DCHECK_EQ(kGreetReadHeaderSize, greet_read_header_size_); + } + + int handshake_buf_len = greet_read_header_size_ - bytes_received_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = new IOBuffer(handshake_buf_len); + return transport_->Read(handshake_buf_.get(), handshake_buf_len, + io_callback_); +} + +int Socks5ServerSocket::DoGreetReadComplete(int result) { + if (result < 0) + return result; + + if (result == 0) { + net_log_.AddEvent( + NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING); + return ERR_SOCKS_CONNECTION_FAILED; + } + + bytes_received_ += result; + buffer_.append(handshake_buf_->data(), result); + + // When the first few bytes are read, check how many more are required + // and accordingly increase them + if (bytes_received_ == kGreetReadHeaderSize) { + if (buffer_[0] != kSOCKS5Version) { + net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION, + "version", buffer_[0]); + return ERR_SOCKS_CONNECTION_FAILED; + } + if (buffer_[1] == 0) { + net_log_.AddEvent(NetLogEventType::SOCKS_NO_REQUESTED_AUTH); + return ERR_SOCKS_CONNECTION_FAILED; + } + + greet_read_header_size_ += buffer_[1]; + next_state_ = STATE_GREET_READ; + return OK; + } + + if (bytes_received_ == greet_read_header_size_) { + void* match = std::memchr(&buffer_[kGreetReadHeaderSize], kAuthMethodNone, + greet_read_header_size_ - kGreetReadHeaderSize); + if (match) { + auth_method_ = kAuthMethodNone; + } else { + auth_method_ = kAuthMethodNoAcceptable; + } + buffer_.clear(); + next_state_ = STATE_GREET_WRITE; + return OK; + } + + next_state_ = STATE_GREET_READ; + return OK; +} + +int Socks5ServerSocket::DoGreetWrite() { + if (buffer_.empty()) { + const char write_data[] = {kSOCKS5Version, auth_method_}; + buffer_ = std::string(write_data, base::size(write_data)); + bytes_sent_ = 0; + } + + next_state_ = STATE_GREET_WRITE_COMPLETE; + int handshake_buf_len = buffer_.size() - bytes_sent_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = new IOBuffer(handshake_buf_len); + std::memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_], + handshake_buf_len); + return transport_->Write(handshake_buf_.get(), handshake_buf_len, + io_callback_, traffic_annotation_); +} + +int Socks5ServerSocket::DoGreetWriteComplete(int result) { + if (result < 0) + return result; + + bytes_sent_ += result; + if (bytes_sent_ == buffer_.size()) { + buffer_.clear(); + bytes_received_ = 0; + if (auth_method_ != kAuthMethodNoAcceptable) { + next_state_ = STATE_HANDSHAKE_READ; + } else { + net_log_.AddEvent(NetLogEventType::SOCKS_NO_ACCEPTABLE_AUTH); + return ERR_SOCKS_CONNECTION_FAILED; + } + } else { + next_state_ = STATE_GREET_WRITE; + } + return OK; +} + +int Socks5ServerSocket::DoHandshakeRead() { + next_state_ = STATE_HANDSHAKE_READ_COMPLETE; + + if (buffer_.empty()) { + DCHECK_EQ(0U, bytes_received_); + DCHECK_EQ(kReadHeaderSize, read_header_size_); + } + + int handshake_buf_len = read_header_size_ - bytes_received_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = new IOBuffer(handshake_buf_len); + return transport_->Read(handshake_buf_.get(), handshake_buf_len, + io_callback_); +} + +int Socks5ServerSocket::DoHandshakeReadComplete(int result) { + if (result < 0) + return result; + + // The underlying socket closed unexpectedly. + if (result == 0) { + net_log_.AddEvent( + NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE); + return ERR_SOCKS_CONNECTION_FAILED; + } + + buffer_.append(handshake_buf_->data(), result); + bytes_received_ += result; + + // When the first few bytes are read, check how many more are required + // and accordingly increase them + if (bytes_received_ == kReadHeaderSize) { + if (buffer_[0] != kSOCKS5Version || buffer_[2] != kSOCKS5Reserved) { + net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION, + "version", buffer_[0]); + return ERR_SOCKS_CONNECTION_FAILED; + } + SocksCommandType command = static_cast(buffer_[1]); + if (command == kCommandConnect) { + // The proxy replies with success immediately without first connecting + // to the requested endpoint. + reply_ = kReplySuccess; + } else if (command == kCommandBind || command == kCommandUDPAssociate) { + reply_ = kReplyCommandNotSupported; + } else { + net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_COMMAND, + "commmand", buffer_[1]); + return ERR_SOCKS_CONNECTION_FAILED; + } + + // We check the type of IP/Domain the server returns and accordingly + // increase the size of the request. For domains, we need to read the + // size of the domain, so the initial request size is upto the domain + // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is + // read, we substract 1 byte from the additional request size. + address_type_ = static_cast(buffer_[3]); + if (address_type_ == kEndPointDomain) { + address_size_ = static_cast(buffer_[4]); + if (address_size_ == 0) { + net_log_.AddEvent(NetLogEventType::SOCKS_ZERO_LENGTH_DOMAIN); + return ERR_SOCKS_CONNECTION_FAILED; + } + } else if (address_type_ == kEndPointResolvedIPv4) { + address_size_ = sizeof(struct in_addr); + --read_header_size_; + } else if (address_type_ == kEndPointResolvedIPv6) { + address_size_ = sizeof(struct in6_addr); + --read_header_size_; + } else { + // Aborts connection on unspecified address type. + net_log_.AddEventWithIntParams( + NetLogEventType::SOCKS_UNKNOWN_ADDRESS_TYPE, "address_type", + buffer_[3]); + return ERR_SOCKS_CONNECTION_FAILED; + } + + read_header_size_ += address_size_ + sizeof(uint16_t); + next_state_ = STATE_HANDSHAKE_READ; + return OK; + } + + // When the final bytes are read, setup handshake. + if (bytes_received_ == read_header_size_) { + size_t port_start = read_header_size_ - sizeof(uint16_t); + uint16_t port_net; + std::memcpy(&port_net, &buffer_[port_start], sizeof(uint16_t)); + uint16_t port_host = base::NetToHost16(port_net); + + size_t address_start = port_start - address_size_; + if (address_type_ == kEndPointDomain) { + std::string domain(&buffer_[address_start], address_size_); + request_endpoint_ = HostPortPair(domain, port_host); + } else { + IPAddress ip_addr( + reinterpret_cast(&buffer_[address_start]), + address_size_); + IPEndPoint endpoint(ip_addr, port_host); + request_endpoint_ = HostPortPair::FromIPEndPoint(endpoint); + } + buffer_.clear(); + next_state_ = STATE_HANDSHAKE_WRITE; + return OK; + } + + next_state_ = STATE_HANDSHAKE_READ; + return OK; +} + +// Writes the SOCKS handshake data to the underlying socket connection. +int Socks5ServerSocket::DoHandshakeWrite() { + next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; + + if (buffer_.empty()) { + const char write_data[] = { + kSOCKS5Version, + reply_, + kSOCKS5Reserved, + kEndPointResolvedIPv4, + 0x00, + 0x00, + 0x00, + 0x00, // BND.ADDR + 0x00, + 0x00, // BND.PORT + }; + buffer_ = std::string(write_data, base::size(write_data)); + bytes_sent_ = 0; + } + + int handshake_buf_len = buffer_.size() - bytes_sent_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = new IOBuffer(handshake_buf_len); + std::memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], handshake_buf_len); + return transport_->Write(handshake_buf_.get(), handshake_buf_len, + io_callback_, traffic_annotation_); +} + +int Socks5ServerSocket::DoHandshakeWriteComplete(int result) { + if (result < 0) + return result; + + // We ignore the case when result is 0, since the underlying Write + // may return spurious writes while waiting on the socket. + + bytes_sent_ += result; + if (bytes_sent_ == buffer_.size()) { + buffer_.clear(); + if (reply_ == kReplySuccess) { + completed_handshake_ = true; + next_state_ = STATE_NONE; + } else { + net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_SERVER_ERROR, + "error_code", reply_); + return ERR_SOCKS_CONNECTION_FAILED; + } + } else if (bytes_sent_ < buffer_.size()) { + next_state_ = STATE_HANDSHAKE_WRITE; + } else { + NOTREACHED(); + } + + return OK; +} + +int Socks5ServerSocket::GetPeerAddress(IPEndPoint* address) const { + return transport_->GetPeerAddress(address); +} + +int Socks5ServerSocket::GetLocalAddress(IPEndPoint* address) const { + return transport_->GetLocalAddress(address); +} + +} // namespace net diff --git a/src/net/tools/naive/socks5_server_socket.h b/src/net/tools/naive/socks5_server_socket.h new file mode 100644 index 0000000000..4ee7abab35 --- /dev/null +++ b/src/net/tools/naive/socks5_server_socket.h @@ -0,0 +1,172 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_ +#define NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_ + +#include +#include +#include +#include + +#include "base/macros.h" +#include "base/memory/ref_counted.h" +#include "net/base/completion_once_callback.h" +#include "net/base/completion_repeating_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/log/net_log_with_source.h" +#include "net/socket/connection_attempts.h" +#include "net/socket/next_proto.h" +#include "net/socket/stream_socket.h" +#include "net/ssl/ssl_info.h" + +namespace net { +struct NetworkTrafficAnnotationTag; + +// This StreamSocket is used to setup a SOCKSv5 handshake with a socks client. +// Currently no SOCKSv5 authentication is supported. +class Socks5ServerSocket : public StreamSocket { + public: + explicit Socks5ServerSocket(std::unique_ptr transport_socket); + + // On destruction Disconnect() is called. + ~Socks5ServerSocket() override; + + const HostPortPair& request_endpoint() const; + + // StreamSocket implementation. + + // Does the SOCKS handshake and completes the protocol. + int Connect(CompletionOnceCallback callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + const NetLogWithSource& NetLog() const override; + bool WasEverUsed() const override; + bool WasAlpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} + int64_t GetTotalReceivedBytes() const override; + void ApplySocketTag(const SocketTag& tag) override; + + // Socket implementation. + int Read(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback) override; + int Write(IOBuffer* buf, + int buf_len, + CompletionOnceCallback callback, + const NetworkTrafficAnnotationTag& traffic_annotation) override; + + int SetReceiveBufferSize(int32_t size) override; + int SetSendBufferSize(int32_t size) override; + + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + + private: + enum State { + STATE_GREET_READ, + STATE_GREET_READ_COMPLETE, + STATE_GREET_WRITE, + STATE_GREET_WRITE_COMPLETE, + STATE_HANDSHAKE_WRITE, + STATE_HANDSHAKE_WRITE_COMPLETE, + STATE_HANDSHAKE_READ, + STATE_HANDSHAKE_READ_COMPLETE, + STATE_NONE, + }; + + // Addressing type that can be specified in requests or responses. + enum SocksEndPointAddressType { + kEndPointDomain = 0x03, + kEndPointResolvedIPv4 = 0x01, + kEndPointResolvedIPv6 = 0x04, + }; + + enum SocksCommandType { + kCommandConnect = 0x01, + kCommandBind = 0x02, + kCommandUDPAssociate = 0x03, + }; + + static const unsigned int kGreetReadHeaderSize; + static const unsigned int kReadHeaderSize; + static const char kSOCKS5Version; + static const char kSOCKS5Reserved; + static const char kAuthMethodNone; + static const char kAuthMethodNoAcceptable; + static const char kReplySuccess; + static const char kReplyCommandNotSupported; + + void DoCallback(int result); + void OnIOComplete(int result); + void OnReadWriteComplete(CompletionOnceCallback callback, int result); + + int DoLoop(int last_io_result); + int DoGreetWrite(); + int DoGreetWriteComplete(int result); + int DoGreetRead(); + int DoGreetReadComplete(int result); + int DoHandshakeRead(); + int DoHandshakeReadComplete(int result); + int DoHandshakeWrite(); + int DoHandshakeWriteComplete(int result); + + CompletionRepeatingCallback io_callback_; + + // Stores the underlying socket. + std::unique_ptr transport_; + + State next_state_; + + // Stores the callback to the layer above, called on completing Connect(). + CompletionOnceCallback user_callback_; + + // This IOBuffer is used by the class to read and write + // SOCKS handshake data. The length contains the expected size to + // read or write. + scoped_refptr handshake_buf_; + + // While writing, this buffer stores the complete write handshake data. + // While reading, it stores the handshake information received so far. + std::string buffer_; + + // This becomes true when the SOCKS handshake has completed and the + // overlying connection is free to communicate. + bool completed_handshake_; + + // These contain the bytes received / sent by the SOCKS handshake. + size_t bytes_received_; + size_t bytes_sent_; + + size_t greet_read_header_size_; + size_t read_header_size_; + + bool was_ever_used_; + + SocksEndPointAddressType address_type_; + int address_size_; + + char auth_method_; + char reply_; + + HostPortPair request_endpoint_; + + NetLogWithSource net_log_; + + // Traffic annotation for socket control. + const NetworkTrafficAnnotationTag& traffic_annotation_; + + DISALLOW_COPY_AND_ASSIGN(Socks5ServerSocket); +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_