diff --git a/net/log/net_log_event_type_list.h b/net/log/net_log_event_type_list.h index 551953165f..08e3a0c6ee 100644 --- a/net/log/net_log_event_type_list.h +++ b/net/log/net_log_event_type_list.h @@ -429,6 +429,11 @@ EVENT_TYPE(SOCKS_HOSTNAME_TOO_BIG) EVENT_TYPE(SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING) EVENT_TYPE(SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE) +EVENT_TYPE(SOCKS_NO_REQUESTED_AUTH) +EVENT_TYPE(SOCKS_NO_ACCEPTABLE_AUTH) +EVENT_TYPE(SOCKS_ZERO_LENGTH_DOMAIN) +EVENT_TYPE(SOCKS_UNEXPECTED_COMMAND) + // This event indicates that a bad version number was received in the // proxy server's response. The extra parameters show its value: // { diff --git a/net/tools/naive/naive_client.cc b/net/tools/naive/naive_client.cc new file mode 100644 index 0000000000..bcfeb6aec2 --- /dev/null +++ b/net/tools/naive/naive_client.cc @@ -0,0 +1,146 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/naive_client.h" + +#include + +#include "base/bind.h" +#include "base/location.h" +#include "base/logging.h" +#include "base/threading/thread_task_runner_handle.h" +#include "net/base/net_errors.h" +#include "net/http/http_network_session.h" +#include "net/socket/server_socket.h" +#include "net/socket/stream_socket.h" +#include "net/tools/naive/naive_client_connection.h" +#include "net/tools/naive/socks5_server_socket.h" + +namespace net { + +NaiveClient::NaiveClient(std::unique_ptr server_socket, + HttpNetworkSession* session) + : server_socket_(std::move(server_socket)), + session_(session), + last_id_(0), + weak_ptr_factory_(this) { + DCHECK(server_socket_); + // Start accepting connections in next run loop in case when delegate is not + // ready to get callbacks. + base::ThreadTaskRunnerHandle::Get()->PostTask( + FROM_HERE, base::BindOnce(&NaiveClient::DoAcceptLoop, + weak_ptr_factory_.GetWeakPtr())); +} + +NaiveClient::~NaiveClient() = default; + +void NaiveClient::DoAcceptLoop() { + int result; + do { + result = server_socket_->Accept(&accepted_socket_, + base::Bind(&NaiveClient::OnAcceptComplete, + weak_ptr_factory_.GetWeakPtr())); + if (result == ERR_IO_PENDING) + return; + HandleAcceptResult(result); + } while (result == OK); +} + +void NaiveClient::OnAcceptComplete(int result) { + HandleAcceptResult(result); + if (result == OK) + DoAcceptLoop(); +} + +void NaiveClient::HandleAcceptResult(int result) { + if (result != OK) { + LOG(ERROR) << "Accept error: rv=" << result; + return; + } + DoConnect(); +} + +void NaiveClient::DoConnect() { + auto connection_ptr = std::make_unique( + ++last_id_, std::move(accepted_socket_), session_); + auto* connection = connection_ptr.get(); + connection_by_id_[connection->id()] = std::move(connection_ptr); + int result = connection->Connect(base::Bind(&NaiveClient::OnConnectComplete, + weak_ptr_factory_.GetWeakPtr(), + connection->id())); + if (result == ERR_IO_PENDING) + return; + HandleConnectResult(connection, result); +} + +void NaiveClient::OnConnectComplete(int connection_id, int result) { + NaiveClientConnection* connection = FindConnection(connection_id); + if (!connection) + return; + HandleConnectResult(connection, result); +} + +void NaiveClient::HandleConnectResult(NaiveClientConnection* connection, + int result) { + if (result != OK) { + Close(connection->id()); + return; + } + DoRun(connection); +} + +void NaiveClient::DoRun(NaiveClientConnection* connection) { + int result = connection->Run(base::Bind(&NaiveClient::OnRunComplete, + weak_ptr_factory_.GetWeakPtr(), + connection->id())); + if (result == ERR_IO_PENDING) + return; + HandleRunResult(connection, result); +} + +void NaiveClient::OnRunComplete(int connection_id, int result) { + NaiveClientConnection* connection = FindConnection(connection_id); + if (!connection) + return; + HandleRunResult(connection, result); +} + +void NaiveClient::HandleRunResult(NaiveClientConnection* connection, + int result) { + LOG(INFO) << "Connection " << connection->id() + << " ended: " << ErrorToString(result); + Close(connection->id()); +} + +void NaiveClient::Close(int connection_id) { + auto it = connection_by_id_.find(connection_id); + if (it == connection_by_id_.end()) + return; + + // The call stack might have callbacks which still have the pointer of + // connection. Instead of referencing connection with ID all the time, + // destroys the connection in next run loop to make sure any pending + // callbacks in the call stack return. + base::ThreadTaskRunnerHandle::Get()->DeleteSoon(FROM_HERE, + std::move(it->second)); + connection_by_id_.erase(it); +} + +NaiveClientConnection* NaiveClient::FindConnection(int connection_id) { + auto it = connection_by_id_.find(connection_id); + if (it == connection_by_id_.end()) + return nullptr; + return it->second.get(); +} + +// This is called after any delegate callbacks are called to check if Close() +// has been called during callback processing. Using the pointer of connection, +// |connection| is safe here because Close() deletes the connection in next run +// loop. +bool NaiveClient::HasClosedConnection(NaiveClientConnection* connection) { + return FindConnection(connection->id()) != connection; +} + +} // namespace net diff --git a/net/tools/naive/naive_client.h b/net/tools/naive/naive_client.h new file mode 100644 index 0000000000..90ac49feb2 --- /dev/null +++ b/net/tools/naive/naive_client.h @@ -0,0 +1,62 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_NAIVE_CLIENT_H_ +#define NET_TOOLS_NAIVE_NAIVE_CLIENT_H_ + +#include +#include + +#include "base/macros.h" +#include "base/memory/weak_ptr.h" + +namespace net { + +class HttpNetworkSession; +class ServerSocket; +class StreamSocket; +class NaiveClientConnection; + +class NaiveClient { + public: + NaiveClient(std::unique_ptr server_socket, + HttpNetworkSession* session); + ~NaiveClient(); + + private: + void DoAcceptLoop(); + void OnAcceptComplete(int result); + void HandleAcceptResult(int result); + + void DoConnect(); + void OnConnectComplete(int connection_id, int result); + void HandleConnectResult(NaiveClientConnection* connection, int result); + + void DoRun(NaiveClientConnection* connection); + void OnRunComplete(int connection_id, int result); + void HandleRunResult(NaiveClientConnection* connection, int result); + + void Close(int connection_id); + + NaiveClientConnection* FindConnection(int connection_id); + bool HasClosedConnection(NaiveClientConnection* connection); + + std::unique_ptr server_socket_; + HttpNetworkSession* session_; + + unsigned int last_id_; + + std::unique_ptr accepted_socket_; + + std::map> + connection_by_id_; + + base::WeakPtrFactory weak_ptr_factory_; + + DISALLOW_COPY_AND_ASSIGN(NaiveClient); +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_NAIVE_CLIENT_H_ diff --git a/net/tools/naive/naive_client_bin.cc b/net/tools/naive/naive_client_bin.cc new file mode 100644 index 0000000000..fb9adf7b37 --- /dev/null +++ b/net/tools/naive/naive_client_bin.cc @@ -0,0 +1,325 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include +#include +#include +#include + +#include "base/at_exit.h" +#include "base/command_line.h" +#include "base/files/file_path.h" +#include "base/json/json_writer.h" +#include "base/logging.h" +#include "base/macros.h" +#include "base/message_loop/message_loop.h" +#include "base/run_loop.h" +#include "base/strings/string_number_conversions.h" +#include "base/strings/stringprintf.h" +#include "base/strings/utf_string_conversions.h" +#include "base/sys_info.h" +#include "base/task_scheduler/task_scheduler.h" +#include "base/values.h" +#include "build/build_config.h" +#include "net/base/auth.h" +#include "net/http/http_auth.h" +#include "net/http/http_auth_cache.h" +#include "net/http/http_network_session.h" +#include "net/http/http_transaction_factory.h" +#include "net/log/file_net_log_observer.h" +#include "net/log/net_log.h" +#include "net/log/net_log_capture_mode.h" +#include "net/log/net_log_entry.h" +#include "net/log/net_log_source.h" +#include "net/log/net_log_util.h" +#include "net/proxy/proxy_config.h" +#include "net/proxy/proxy_config_service_fixed.h" +#include "net/proxy/proxy_service.h" +#include "net/socket/client_socket_pool_manager.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/tcp_server_socket.h" +#include "net/tools/naive/naive_client.h" +#include "net/url_request/url_request_context.h" +#include "net/url_request/url_request_context_builder.h" +#include "url/gurl.h" +#include "url/scheme_host_port.h" + +#if defined(OS_MACOSX) +#include "base/mac/scoped_nsautorelease_pool.h" +#endif + +namespace { + +constexpr int kListenBackLog = 512; +constexpr int kDefaultMaxSocketsPerPool = 256; +constexpr int kDefaultMaxSocketsPerGroup = 255; +constexpr int kExpectedMaxUsers = 8; + +std::unique_ptr GetConstants( + const base::CommandLine::StringType& command_line_string) { + auto constants_dict = net::GetNetConstants(); + DCHECK(constants_dict); + + // Add a dictionary with the version of the client and its command line + // arguments. + auto dict = std::make_unique(); + + // We have everything we need to send the right values. + std::string os_type = base::StringPrintf( + "%s: %s (%s)", base::SysInfo::OperatingSystemName().c_str(), + base::SysInfo::OperatingSystemVersion().c_str(), + base::SysInfo::OperatingSystemArchitecture().c_str()); + dict->SetString("os_type", os_type); + dict->SetString("command_line", command_line_string); + + constants_dict->Set("clientInfo", std::move(dict)); + + return std::move(constants_dict); +} + +// Builds a URLRequestContext assuming there's only a single loop. +std::unique_ptr BuildURLRequestContext( + const std::string& proxy_url, + const std::string& proxy_user, + const std::string& proxy_pass, + net::NetLog* net_log) { + net::URLRequestContextBuilder builder; + + net::ProxyConfig proxy_config; + proxy_config.proxy_rules().ParseFromString(proxy_url); + auto proxy_service = net::ProxyService::CreateWithoutProxyResolver( + std::make_unique(proxy_config), net_log); + proxy_service->ForceReloadProxyConfig(); + + builder.set_proxy_service(std::move(proxy_service)); + builder.DisableHttpCache(); + builder.set_net_log(net_log); + + auto context = builder.Build(); + + net::HttpNetworkSession* session = + context->http_transaction_factory()->GetSession(); + net::HttpAuthCache* auth_cache = session->http_auth_cache(); + GURL auth_origin(proxy_url); + net::AuthCredentials credentials(base::ASCIIToUTF16(proxy_user), + base::ASCIIToUTF16(proxy_pass)); + auth_cache->Add(auth_origin, /*realm=*/std::string(), + net::HttpAuth::AUTH_SCHEME_BASIC, /*challenge=*/"Basic", + credentials, /*path=*/"/"); + + return context; +} + +struct Params { + std::string listen_addr; + int listen_port; + std::string proxy_url; + std::string proxy_user; + std::string proxy_pass; + logging::LoggingSettings log_settings; + base::FilePath net_log_path; + base::FilePath ssl_key_path; +}; + +bool ParseCommandLineFlags(Params* params) { + const base::CommandLine& line = *base::CommandLine::ForCurrentProcess(); + + if (line.HasSwitch("h") || line.HasSwitch("help")) { + LOG(INFO) << "Usage: naive_client [options]\n" + "\n" + "Options:\n" + "-h, --help Show this help message and exit\n" + "--addr=
Address to listen on\n" + "--port= Port to listen on\n" + "--proxy=https://:@[:port]\n" + " Proxy specification\n" + "--log Log to stderr, otherwise no log\n" + "--log-net-log= Save NetLog\n" + "--ssl-key-log-file= Save SSL keys for Wireshark\n"; + exit(EXIT_SUCCESS); + return false; + } + + if (!line.HasSwitch("addr")) { + LOG(ERROR) << "Missing --addr"; + return false; + } + params->listen_addr = line.GetSwitchValueASCII("addr"); + if (params->listen_addr.empty()) { + LOG(ERROR) << "Invalid --port"; + return false; + } + + if (!line.HasSwitch("port")) { + LOG(ERROR) << "Missing --port"; + return false; + } + if (!base::StringToInt(line.GetSwitchValueASCII("port"), + ¶ms->listen_port)) { + LOG(ERROR) << "Invalid --port"; + return false; + } + if (params->listen_port <= 0 || + params->listen_port > std::numeric_limits::max()) { + LOG(ERROR) << "Invalid --port"; + return false; + } + + if (!line.HasSwitch("proxy")) { + LOG(ERROR) << "Missing --proxy"; + return false; + } + GURL url(line.GetSwitchValueASCII("proxy")); + if (!url.is_valid()) { + LOG(ERROR) << "Invalid proxy URL"; + return false; + } + if (url.scheme() != "https") { + LOG(ERROR) << "Must be HTTPS proxy"; + return false; + } + if (url.username().empty() || url.password().empty()) { + LOG(ERROR) << "Missing user or pass"; + return false; + } + params->proxy_url = url::SchemeHostPort(url).Serialize(); + params->proxy_user = url.username(); + params->proxy_pass = url.password(); + + if (line.HasSwitch("log")) { + params->log_settings.logging_dest = logging::LOG_TO_SYSTEM_DEBUG_LOG; + } else { + params->log_settings.logging_dest = logging::LOG_NONE; + } + + if (line.HasSwitch("log-net-log")) { + params->net_log_path = line.GetSwitchValuePath("log-net-log"); + } + + if (line.HasSwitch("ssl-key-log-file")) { + params->ssl_key_path = line.GetSwitchValuePath("ssl-key-log-file"); + } + + return true; +} + +// NetLog::ThreadSafeObserver implementation that simply prints events +// to the logs. +class PrintingLogObserver : public net::NetLog::ThreadSafeObserver { + public: + PrintingLogObserver() = default; + + ~PrintingLogObserver() override { + // This is guaranteed to be safe as this program is single threaded. + net_log()->RemoveObserver(this); + } + + // NetLog::ThreadSafeObserver implementation: + void OnAddEntry(const net::NetLogEntry& entry) override { + switch (entry.type()) { + case net::NetLogEventType::SOCKET_POOL_STALLED_MAX_SOCKETS: + case net::NetLogEventType::SOCKET_POOL_STALLED_MAX_SOCKETS_PER_GROUP: + case net::NetLogEventType:: + HTTP2_SESSION_STREAM_STALLED_BY_SESSION_SEND_WINDOW: + case net::NetLogEventType:: + HTTP2_SESSION_STREAM_STALLED_BY_STREAM_SEND_WINDOW: + case net::NetLogEventType::HTTP2_SESSION_STALLED_MAX_STREAMS: + case net::NetLogEventType::HTTP2_STREAM_FLOW_CONTROL_UNSTALLED: + break; + default: + return; + } + const char* const source_type = + net::NetLog::SourceTypeToString(entry.source().type); + const char* const event_type = net::NetLog::EventTypeToString(entry.type()); + const char* const event_phase = + net::NetLog::EventPhaseToString(entry.phase()); + auto params = entry.ParametersToValue(); + std::string params_str; + if (params.get()) { + base::JSONWriter::Write(*params, ¶ms_str); + params_str.insert(0, ": "); + } + + LOG(INFO) << source_type << "(" << entry.source().id << "): " << event_type + << ": " << event_phase << params_str; + } + + private: + DISALLOW_COPY_AND_ASSIGN(PrintingLogObserver); +}; + +} // namespace + +int main(int argc, char* argv[]) { + base::TaskScheduler::CreateAndStartWithDefaultParams(""); + base::AtExitManager exit_manager; + base::MessageLoopForIO main_loop; + +#if defined(OS_MACOSX) + base::mac::ScopedNSAutoreleasePool pool; +#endif + + base::CommandLine::Init(argc, argv); + + Params params; + if (!ParseCommandLineFlags(¶ms)) { + return EXIT_FAILURE; + } + + net::ClientSocketPoolManager::set_max_sockets_per_pool( + net::HttpNetworkSession::NORMAL_SOCKET_POOL, + kDefaultMaxSocketsPerPool * kExpectedMaxUsers); + net::ClientSocketPoolManager::set_max_sockets_per_proxy_server( + net::HttpNetworkSession::NORMAL_SOCKET_POOL, + kDefaultMaxSocketsPerPool * kExpectedMaxUsers); + net::ClientSocketPoolManager::set_max_sockets_per_group( + net::HttpNetworkSession::NORMAL_SOCKET_POOL, + kDefaultMaxSocketsPerGroup * kExpectedMaxUsers); + + CHECK(logging::InitLogging(params.log_settings)); + + if (!params.ssl_key_path.empty()) { + net::SSLClientSocket::SetSSLKeyLogFile(params.ssl_key_path); + } + + // The declaration order for net_log and printing_log_observer is + // important. The destructor of PrintingLogObserver removes itself + // from net_log, so net_log must be available for entire lifetime of + // printing_log_observer. + net::NetLog net_log; + std::unique_ptr observer; + if (!params.net_log_path.empty()) { + const auto& cmdline = + base::CommandLine::ForCurrentProcess()->GetCommandLineString(); + observer = net::FileNetLogObserver::CreateUnbounded(params.net_log_path, + GetConstants(cmdline)); + observer->StartObserving(&net_log, net::NetLogCaptureMode::Default()); + } + PrintingLogObserver printing_log_observer; + net_log.AddObserver(&printing_log_observer, + net::NetLogCaptureMode::Default()); + + auto context = BuildURLRequestContext(params.proxy_url, params.proxy_user, + params.proxy_pass, &net_log); + + auto server_socket = + std::make_unique(&net_log, net::NetLogSource()); + + int result = server_socket->ListenWithAddressAndPort( + params.listen_addr, params.listen_port, kListenBackLog); + if (result != net::OK) { + LOG(ERROR) << "Failed to listen: " << result; + return EXIT_FAILURE; + } + + net::NaiveClient naive_client( + std::move(server_socket), + context->http_transaction_factory()->GetSession()); + + base::RunLoop().Run(); + + return EXIT_SUCCESS; +} diff --git a/net/tools/naive/naive_client_connection.cc b/net/tools/naive/naive_client_connection.cc new file mode 100644 index 0000000000..eaee101c89 --- /dev/null +++ b/net/tools/naive/naive_client_connection.cc @@ -0,0 +1,272 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/naive_client_connection.h" + +#include + +#include "base/bind.h" +#include "base/callback_helpers.h" +#include "base/logging.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/base/privacy_mode.h" +#include "net/http/http_network_session.h" +#include "net/proxy/proxy_config.h" +#include "net/proxy/proxy_info.h" +#include "net/proxy/proxy_list.h" +#include "net/proxy/proxy_service.h" +#include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool_manager.h" +#include "net/socket/stream_socket.h" +#include "net/tools/naive/socks5_server_socket.h" + +namespace net { + +namespace { +static const int kBufferSize = 64 * 1024; +} + +NaiveClientConnection::NaiveClientConnection( + int id, + std::unique_ptr accepted_socket, + HttpNetworkSession* session) + : id_(id), + next_state_(STATE_NONE), + session_(session), + net_log_( + NetLogWithSource::Make(session->net_log(), NetLogSourceType::NONE)), + client_socket_( + std::make_unique(std::move(accepted_socket))), + server_socket_handle_(std::make_unique()), + client_error_(OK), + server_error_(OK), + full_duplex_(false), + weak_ptr_factory_(this) { + io_callback_ = base::Bind(&NaiveClientConnection::OnIOComplete, + weak_ptr_factory_.GetWeakPtr()); +} + +NaiveClientConnection::~NaiveClientConnection() { + Disconnect(); +} + +int NaiveClientConnection::Connect(const CompletionCallback& callback) { + DCHECK(client_socket_); + DCHECK_EQ(next_state_, STATE_NONE); + DCHECK(connect_callback_.is_null()); + + if (full_duplex_) + return OK; + + next_state_ = STATE_CONNECT_CLIENT; + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + connect_callback_ = callback; + } + return rv; +} + +void NaiveClientConnection::Disconnect() { + full_duplex_ = false; + client_socket_->Disconnect(); + if (server_socket_handle_->socket()) + server_socket_handle_->socket()->Disconnect(); + + next_state_ = STATE_NONE; + connect_callback_.Reset(); + run_callback_.Reset(); +} + +void NaiveClientConnection::DoCallback(int result) { + DCHECK_NE(result, ERR_IO_PENDING); + DCHECK(!connect_callback_.is_null()); + + // Since Run() may result in Read being called, + // clear connect_callback_ up front. + base::ResetAndReturn(&connect_callback_).Run(result); +} + +void NaiveClientConnection::OnIOComplete(int result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) { + DoCallback(rv); + } +} + +int NaiveClientConnection::DoLoop(int last_io_result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_CONNECT_CLIENT: + DCHECK_EQ(rv, OK); + rv = DoConnectClient(); + break; + case STATE_CONNECT_CLIENT_COMPLETE: + rv = DoConnectClientComplete(rv); + break; + case STATE_CONNECT_SERVER: + DCHECK_EQ(rv, OK); + rv = DoConnectServer(); + case STATE_CONNECT_SERVER_COMPLETE: + rv = DoConnectServerComplete(rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +int NaiveClientConnection::DoConnectClient() { + next_state_ = STATE_CONNECT_CLIENT_COMPLETE; + + return client_socket_->Connect(&request_endpoint_, io_callback_); +} + +int NaiveClientConnection::DoConnectClientComplete(int result) { + if (result < 0) + return result; + + next_state_ = STATE_CONNECT_SERVER; + return OK; +} + +int NaiveClientConnection::DoConnectServer() { + ProxyInfo proxy_info; + const ProxyList& proxy_list = + session_->proxy_service()->config().proxy_rules().single_proxies; + if (proxy_list.IsEmpty()) + return ERR_MANDATORY_PROXY_CONFIGURATION_FAILED; + proxy_info.UseProxyList(proxy_list); + + HttpRequestInfo req_info; + SSLConfig server_ssl_config; + SSLConfig proxy_ssl_config; + session_->GetSSLConfig(req_info, &server_ssl_config, &proxy_ssl_config); + proxy_ssl_config.rev_checking_enabled = false; + + next_state_ = STATE_CONNECT_SERVER_COMPLETE; + + DCHECK_NE(request_endpoint_.port(), 0); + + LOG(INFO) << "Connection " << id_ << " to " << request_endpoint_.ToString(); + + return InitSocketHandleForRawConnect( + request_endpoint_, session_, proxy_info, server_ssl_config, + proxy_ssl_config, PRIVACY_MODE_DISABLED, net_log_, + server_socket_handle_.get(), io_callback_); +} + +int NaiveClientConnection::DoConnectServerComplete(int result) { + if (result < 0) + return result; + + full_duplex_ = true; + next_state_ = STATE_NONE; + return OK; +} + +int NaiveClientConnection::Run(const CompletionCallback& callback) { + DCHECK(client_socket_); + DCHECK(server_socket_handle_->socket()); + DCHECK_EQ(next_state_, STATE_NONE); + DCHECK(connect_callback_.is_null()); + + run_callback_ = callback; + + Pull(client_socket_.get(), server_socket_handle_->socket()); + Pull(server_socket_handle_->socket(), client_socket_.get()); + return ERR_IO_PENDING; +} + +void NaiveClientConnection::Pull(StreamSocket* from, StreamSocket* to) { + if (client_error_ < 0 || server_error_ < 0) + return; + + auto buffer = base::MakeRefCounted(kBufferSize); + int rv = + from->Read(buffer.get(), kBufferSize, + base::Bind(&NaiveClientConnection::OnReadComplete, + weak_ptr_factory_.GetWeakPtr(), from, to, buffer)); + + if (rv != ERR_IO_PENDING) + OnReadComplete(from, to, buffer, rv); +} + +void NaiveClientConnection::Push(StreamSocket* from, + StreamSocket* to, + scoped_refptr buffer, + int size) { + if (client_error_ < 0 || server_error_ < 0) + return; + + auto drainable = base::MakeRefCounted(buffer.get(), size); + int rv = to->Write( + drainable.get(), size, + base::Bind(&NaiveClientConnection::OnWriteComplete, + weak_ptr_factory_.GetWeakPtr(), from, to, drainable)); + + if (rv != ERR_IO_PENDING) + OnWriteComplete(from, to, drainable, rv); +} + +void NaiveClientConnection::OnIOError(StreamSocket* socket, int error) { + if (socket == client_socket_.get()) { + if (client_error_ == OK) { + base::ResetAndReturn(&run_callback_).Run(error); + } + client_error_ = error; + return; + } + if (socket == server_socket_handle_->socket()) { + if (server_error_ == OK) { + base::ResetAndReturn(&run_callback_).Run(error); + } + server_error_ = error; + return; + } +} + +void NaiveClientConnection::OnReadComplete(StreamSocket* from, + StreamSocket* to, + scoped_refptr buffer, + int result) { + if (result <= 0) { + OnIOError(from, result ? result : ERR_CONNECTION_CLOSED); + return; + } + + Push(from, to, buffer, result); +} + +void NaiveClientConnection::OnWriteComplete( + StreamSocket* from, + StreamSocket* to, + scoped_refptr drainable, + int result) { + if (result < 0) { + OnIOError(to, result); + return; + } + + drainable->DidConsume(result); + int size = drainable->BytesRemaining(); + if (size > 0) { + Push(from, to, drainable.get(), size); + return; + } + + Pull(from, to); +} + +} // namespace net diff --git a/net/tools/naive/naive_client_connection.h b/net/tools/naive/naive_client_connection.h new file mode 100644 index 0000000000..bcdee0bf5c --- /dev/null +++ b/net/tools/naive/naive_client_connection.h @@ -0,0 +1,100 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_NAIVE_CLIENT_CONNECTION_H_ +#define NET_TOOLS_NAIVE_NAIVE_CLIENT_CONNECTION_H_ + +#include + +#include "base/macros.h" +#include "base/memory/ref_counted.h" +#include "base/memory/weak_ptr.h" +#include "net/base/completion_callback.h" +#include "net/base/host_port_pair.h" +#include "net/log/net_log_with_source.h" +#include "net/proxy/proxy_info.h" +#include "net/ssl/ssl_config.h" + +namespace net { + +class ClientSocketHandle; +class DrainableIOBuffer; +class HttpNetworkSession; +class IOBuffer; +class Socks5ServerSocket; +class StreamSocket; + +class NaiveClientConnection { + public: + NaiveClientConnection(int id, + std::unique_ptr accepted_socket, + HttpNetworkSession* session); + ~NaiveClientConnection(); + + int id() const { return id_; } + int Connect(const CompletionCallback& callback); + void Disconnect(); + int Run(const CompletionCallback& callback); + + private: + enum State { + STATE_CONNECT_CLIENT, + STATE_CONNECT_CLIENT_COMPLETE, + STATE_CONNECT_SERVER, + STATE_CONNECT_SERVER_COMPLETE, + STATE_NONE, + }; + + void DoCallback(int result); + void OnIOComplete(int result); + int DoLoop(int last_io_result); + int DoConnectClient(); + int DoConnectClientComplete(int result); + int DoConnectServer(); + int DoConnectServerComplete(int result); + void Pull(StreamSocket* from, StreamSocket* to); + void Push(StreamSocket* from, + StreamSocket* to, + scoped_refptr buffer, + int size); + void OnIOError(StreamSocket* socket, int error); + void OnReadComplete(StreamSocket* from, + StreamSocket* to, + scoped_refptr buffer, + int result); + void OnWriteComplete(StreamSocket* from, + StreamSocket* to, + scoped_refptr drainable, + int result); + + int id_; + + CompletionCallback io_callback_; + CompletionCallback connect_callback_; + CompletionCallback run_callback_; + + State next_state_; + + HttpNetworkSession* session_; + NetLogWithSource net_log_; + + HostPortPair request_endpoint_; + + std::unique_ptr client_socket_; + std::unique_ptr server_socket_; + std::unique_ptr server_socket_handle_; + + int client_error_; + int server_error_; + + bool full_duplex_; + + base::WeakPtrFactory weak_ptr_factory_; + + DISALLOW_COPY_AND_ASSIGN(NaiveClientConnection); +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_NAIVE_CLIENT_CONNECTION_H_ diff --git a/net/tools/naive/socks5_server_socket.cc b/net/tools/naive/socks5_server_socket.cc new file mode 100644 index 0000000000..191e3e304a --- /dev/null +++ b/net/tools/naive/socks5_server_socket.cc @@ -0,0 +1,565 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/socks5_server_socket.h" + +#include +#include + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/callback_helpers.h" +#include "base/logging.h" +#include "base/sys_byteorder.h" +#include "net/base/ip_address.h" +#include "net/base/net_errors.h" +#include "net/log/net_log.h" +#include "net/log/net_log_event_type.h" + +namespace net { + +const unsigned int Socks5ServerSocket::kGreetReadHeaderSize = 2; +const unsigned int Socks5ServerSocket::kReadHeaderSize = 5; +const char Socks5ServerSocket::kSOCKS5Version = '\x05'; +const char Socks5ServerSocket::kSOCKS5Reserved = '\x00'; +const char Socks5ServerSocket::kAuthMethodNone = '\x00'; +const char Socks5ServerSocket::kAuthMethodNoAcceptable = '\xff'; +const char Socks5ServerSocket::kReplySuccess = '\x00'; +const char Socks5ServerSocket::kReplyCommandNotSupported = '\x07'; + +static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4"); +static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6"); + +Socks5ServerSocket::Socks5ServerSocket( + std::unique_ptr transport_socket) + : io_callback_(base::Bind(&Socks5ServerSocket::OnIOComplete, + base::Unretained(this))), + transport_(std::move(transport_socket)), + next_state_(STATE_NONE), + completed_handshake_(false), + bytes_received_(0), + bytes_sent_(0), + greet_read_header_size_(kGreetReadHeaderSize), + read_header_size_(kReadHeaderSize), + was_ever_used_(false), + net_log_(transport_->NetLog()) {} + +Socks5ServerSocket::~Socks5ServerSocket() { + Disconnect(); +} + +int Socks5ServerSocket::Connect(const CompletionCallback& callback) { + DCHECK(transport_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(user_callback_.is_null()); + + // If already connected, then just return OK. + if (completed_handshake_) + return OK; + + net_log_.BeginEvent(NetLogEventType::SOCKS5_CONNECT); + + next_state_ = STATE_GREET_READ; + buffer_.clear(); + + int rv = DoLoop(OK); + if (rv == ERR_IO_PENDING) { + user_callback_ = callback; + } else { + net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_CONNECT, rv); + } + return rv; +} + +int Socks5ServerSocket::Connect(HostPortPair* request_endpoint, + const CompletionCallback& callback) { + int result = + Connect(base::Bind(&Socks5ServerSocket::DoCallbackReturnRequest, + base::Unretained(this), request_endpoint, callback)); + if (result == OK) + *request_endpoint = host_port_pair_; + return result; +} + +void Socks5ServerSocket::DoCallbackReturnRequest( + HostPortPair* request_endpoint, + const CompletionCallback& callback, + int result) { + if (result == OK) + *request_endpoint = host_port_pair_; + callback.Run(result); +} + +void Socks5ServerSocket::Disconnect() { + completed_handshake_ = false; + transport_->Disconnect(); + + // Reset other states to make sure they aren't mistakenly used later. + // These are the states initialized by Connect(). + next_state_ = STATE_NONE; + user_callback_.Reset(); +} + +bool Socks5ServerSocket::IsConnected() const { + return completed_handshake_ && transport_->IsConnected(); +} + +bool Socks5ServerSocket::IsConnectedAndIdle() const { + return completed_handshake_ && transport_->IsConnectedAndIdle(); +} + +const NetLogWithSource& Socks5ServerSocket::NetLog() const { + return net_log_; +} + +void Socks5ServerSocket::SetSubresourceSpeculation() { + if (transport_) { + transport_->SetSubresourceSpeculation(); + } else { + NOTREACHED(); + } +} + +void Socks5ServerSocket::SetOmniboxSpeculation() { + if (transport_) { + transport_->SetOmniboxSpeculation(); + } else { + NOTREACHED(); + } +} + +bool Socks5ServerSocket::WasEverUsed() const { + return was_ever_used_; +} + +bool Socks5ServerSocket::WasAlpnNegotiated() const { + if (transport_) { + return transport_->WasAlpnNegotiated(); + } + NOTREACHED(); + return false; +} + +NextProto Socks5ServerSocket::GetNegotiatedProtocol() const { + if (transport_) { + return transport_->GetNegotiatedProtocol(); + } + NOTREACHED(); + return kProtoUnknown; +} + +bool Socks5ServerSocket::GetSSLInfo(SSLInfo* ssl_info) { + if (transport_) { + return transport_->GetSSLInfo(ssl_info); + } + NOTREACHED(); + return false; +} + +void Socks5ServerSocket::GetConnectionAttempts(ConnectionAttempts* out) const { + out->clear(); +} + +int64_t Socks5ServerSocket::GetTotalReceivedBytes() const { + return transport_->GetTotalReceivedBytes(); +} + +// Read is called by the transport layer above to read. This can only be done +// if the SOCKS handshake is complete. +int Socks5ServerSocket::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(user_callback_.is_null()); + DCHECK(!callback.is_null()); + + int rv = transport_->Read(buf, buf_len, + base::Bind(&Socks5ServerSocket::OnReadWriteComplete, + base::Unretained(this), callback)); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +// Write is called by the transport layer. This can only be done if the +// SOCKS handshake is complete. +int Socks5ServerSocket::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(completed_handshake_); + DCHECK_EQ(STATE_NONE, next_state_); + DCHECK(user_callback_.is_null()); + DCHECK(!callback.is_null()); + + int rv = + transport_->Write(buf, buf_len, + base::Bind(&Socks5ServerSocket::OnReadWriteComplete, + base::Unretained(this), callback)); + if (rv > 0) + was_ever_used_ = true; + return rv; +} + +int Socks5ServerSocket::SetReceiveBufferSize(int32_t size) { + return transport_->SetReceiveBufferSize(size); +} + +int Socks5ServerSocket::SetSendBufferSize(int32_t size) { + return transport_->SetSendBufferSize(size); +} + +void Socks5ServerSocket::DoCallback(int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(!user_callback_.is_null()); + + // Since Run() may result in Read being called, + // clear user_callback_ up front. + base::ResetAndReturn(&user_callback_).Run(result); +} + +void Socks5ServerSocket::OnIOComplete(int result) { + DCHECK_NE(STATE_NONE, next_state_); + int rv = DoLoop(result); + if (rv != ERR_IO_PENDING) { + net_log_.EndEvent(NetLogEventType::SOCKS5_CONNECT); + DoCallback(rv); + } +} + +void Socks5ServerSocket::OnReadWriteComplete(const CompletionCallback& callback, + int result) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(!callback.is_null()); + + if (result > 0) + was_ever_used_ = true; + callback.Run(result); +} + +int Socks5ServerSocket::DoLoop(int last_io_result) { + DCHECK_NE(next_state_, STATE_NONE); + int rv = last_io_result; + do { + State state = next_state_; + next_state_ = STATE_NONE; + switch (state) { + case STATE_GREET_READ: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_READ); + rv = DoGreetRead(); + break; + case STATE_GREET_READ_COMPLETE: + rv = DoGreetReadComplete(rv); + net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_WRITE, + rv); + break; + case STATE_GREET_WRITE: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_READ); + rv = DoGreetWrite(); + break; + case STATE_GREET_WRITE_COMPLETE: + rv = DoGreetWriteComplete(rv); + net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_READ, + rv); + break; + case STATE_HANDSHAKE_READ: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_READ); + rv = DoHandshakeRead(); + break; + case STATE_HANDSHAKE_READ_COMPLETE: + rv = DoHandshakeReadComplete(rv); + net_log_.EndEventWithNetErrorCode( + NetLogEventType::SOCKS5_HANDSHAKE_READ, rv); + break; + case STATE_HANDSHAKE_WRITE: + DCHECK_EQ(OK, rv); + net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_WRITE); + rv = DoHandshakeWrite(); + break; + case STATE_HANDSHAKE_WRITE_COMPLETE: + rv = DoHandshakeWriteComplete(rv); + net_log_.EndEventWithNetErrorCode( + NetLogEventType::SOCKS5_HANDSHAKE_WRITE, rv); + break; + default: + NOTREACHED() << "bad state"; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); + return rv; +} + +int Socks5ServerSocket::DoGreetRead() { + next_state_ = STATE_GREET_READ_COMPLETE; + + if (buffer_.empty()) { + DCHECK_EQ(0U, bytes_received_); + DCHECK_EQ(kGreetReadHeaderSize, greet_read_header_size_); + } + + int handshake_buf_len = greet_read_header_size_ - bytes_received_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = new IOBuffer(handshake_buf_len); + return transport_->Read(handshake_buf_.get(), handshake_buf_len, + io_callback_); +} + +int Socks5ServerSocket::DoGreetReadComplete(int result) { + if (result < 0) + return result; + + if (result == 0) { + net_log_.AddEvent( + NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING); + return ERR_SOCKS_CONNECTION_FAILED; + } + + bytes_received_ += result; + buffer_.append(handshake_buf_->data(), result); + + // When the first few bytes are read, check how many more are required + // and accordingly increase them + if (bytes_received_ == kGreetReadHeaderSize) { + if (buffer_[0] != kSOCKS5Version) { + net_log_.AddEvent(NetLogEventType::SOCKS_UNEXPECTED_VERSION, + NetLog::IntCallback("version", buffer_[0])); + return ERR_SOCKS_CONNECTION_FAILED; + } + if (buffer_[1] == 0) { + net_log_.AddEvent(NetLogEventType::SOCKS_NO_REQUESTED_AUTH); + return ERR_SOCKS_CONNECTION_FAILED; + } + + greet_read_header_size_ += buffer_[1]; + next_state_ = STATE_GREET_READ; + return OK; + } + + if (bytes_received_ == greet_read_header_size_) { + void* match = std::memchr(&buffer_[kGreetReadHeaderSize], kAuthMethodNone, + greet_read_header_size_ - kGreetReadHeaderSize); + if (match) { + auth_method_ = kAuthMethodNone; + } else { + auth_method_ = kAuthMethodNoAcceptable; + } + buffer_.clear(); + next_state_ = STATE_GREET_WRITE; + return OK; + } + + next_state_ = STATE_GREET_READ; + return OK; +} + +int Socks5ServerSocket::DoGreetWrite() { + if (buffer_.empty()) { + const char write_data[] = {kSOCKS5Version, auth_method_}; + buffer_ = std::string(write_data, arraysize(write_data)); + bytes_sent_ = 0; + } + + next_state_ = STATE_GREET_WRITE_COMPLETE; + int handshake_buf_len = buffer_.size() - bytes_sent_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = new IOBuffer(handshake_buf_len); + std::memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_], + handshake_buf_len); + return transport_->Write(handshake_buf_.get(), handshake_buf_len, + io_callback_); +} + +int Socks5ServerSocket::DoGreetWriteComplete(int result) { + if (result < 0) + return result; + + bytes_sent_ += result; + if (bytes_sent_ == buffer_.size()) { + buffer_.clear(); + bytes_received_ = 0; + if (auth_method_ != kAuthMethodNoAcceptable) { + next_state_ = STATE_HANDSHAKE_READ; + } else { + net_log_.AddEvent(NetLogEventType::SOCKS_NO_ACCEPTABLE_AUTH); + return ERR_SOCKS_CONNECTION_FAILED; + } + } else { + next_state_ = STATE_GREET_WRITE; + } + return OK; +} + +int Socks5ServerSocket::DoHandshakeRead() { + next_state_ = STATE_HANDSHAKE_READ_COMPLETE; + + if (buffer_.empty()) { + DCHECK_EQ(0U, bytes_received_); + DCHECK_EQ(kReadHeaderSize, read_header_size_); + } + + int handshake_buf_len = read_header_size_ - bytes_received_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = new IOBuffer(handshake_buf_len); + return transport_->Read(handshake_buf_.get(), handshake_buf_len, + io_callback_); +} + +int Socks5ServerSocket::DoHandshakeReadComplete(int result) { + if (result < 0) + return result; + + // The underlying socket closed unexpectedly. + if (result == 0) { + net_log_.AddEvent( + NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE); + return ERR_SOCKS_CONNECTION_FAILED; + } + + buffer_.append(handshake_buf_->data(), result); + bytes_received_ += result; + + // When the first few bytes are read, check how many more are required + // and accordingly increase them + if (bytes_received_ == kReadHeaderSize) { + if (buffer_[0] != kSOCKS5Version || buffer_[2] != kSOCKS5Reserved) { + net_log_.AddEvent(NetLogEventType::SOCKS_UNEXPECTED_VERSION, + NetLog::IntCallback("version", buffer_[0])); + return ERR_SOCKS_CONNECTION_FAILED; + } + SocksCommandType command = static_cast(buffer_[1]); + if (command == kCommandConnect) { + // The proxy replies with success immediately without first connecting + // to the requested endpoint. + reply_ = kReplySuccess; + } else if (command == kCommandBind || command == kCommandUDPAssociate) { + reply_ = kReplyCommandNotSupported; + } else { + net_log_.AddEvent(NetLogEventType::SOCKS_UNEXPECTED_COMMAND, + NetLog::IntCallback("commmand", buffer_[1])); + return ERR_SOCKS_CONNECTION_FAILED; + } + + // We check the type of IP/Domain the server returns and accordingly + // increase the size of the request. For domains, we need to read the + // size of the domain, so the initial request size is upto the domain + // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is + // read, we substract 1 byte from the additional request size. + address_type_ = static_cast(buffer_[3]); + if (address_type_ == kEndPointDomain) { + address_size_ = static_cast(buffer_[4]); + if (address_size_ == 0) { + net_log_.AddEvent(NetLogEventType::SOCKS_ZERO_LENGTH_DOMAIN); + return ERR_SOCKS_CONNECTION_FAILED; + } + } else if (address_type_ == kEndPointResolvedIPv4) { + address_size_ = sizeof(struct in_addr); + --read_header_size_; + } else if (address_type_ == kEndPointResolvedIPv6) { + address_size_ = sizeof(struct in6_addr); + --read_header_size_; + } else { + // Aborts connection on unspecified address type. + net_log_.AddEvent(NetLogEventType::SOCKS_UNKNOWN_ADDRESS_TYPE, + NetLog::IntCallback("address_type", buffer_[3])); + return ERR_SOCKS_CONNECTION_FAILED; + } + + read_header_size_ += address_size_ + sizeof(uint16_t); + next_state_ = STATE_HANDSHAKE_READ; + return OK; + } + + // When the final bytes are read, setup handshake. + if (bytes_received_ == read_header_size_) { + size_t port_start = read_header_size_ - sizeof(uint16_t); + uint16_t port_net; + std::memcpy(&port_net, &buffer_[port_start], sizeof(uint16_t)); + uint16_t port_host = base::NetToHost16(port_net); + + size_t address_start = port_start - address_size_; + if (address_type_ == kEndPointDomain) { + std::string domain(&buffer_[address_start], address_size_); + host_port_pair_ = HostPortPair(domain, port_host); + } else { + IPAddress ip_addr( + reinterpret_cast(&buffer_[address_start]), + address_size_); + IPEndPoint endpoint(ip_addr, port_host); + host_port_pair_ = HostPortPair::FromIPEndPoint(endpoint); + } + buffer_.clear(); + next_state_ = STATE_HANDSHAKE_WRITE; + return OK; + } + + next_state_ = STATE_HANDSHAKE_READ; + return OK; +} + +// Writes the SOCKS handshake data to the underlying socket connection. +int Socks5ServerSocket::DoHandshakeWrite() { + next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; + + if (buffer_.empty()) { + const char write_data[] = { + kSOCKS5Version, + reply_, + kSOCKS5Reserved, + kEndPointResolvedIPv4, + 0x00, 0x00, 0x00, 0x00, // BND.ADDR + 0x00, 0x00, // BND.PORT + }; + buffer_ = std::string(write_data, arraysize(write_data)); + bytes_sent_ = 0; + } + + int handshake_buf_len = buffer_.size() - bytes_sent_; + DCHECK_LT(0, handshake_buf_len); + handshake_buf_ = new IOBuffer(handshake_buf_len); + std::memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], handshake_buf_len); + return transport_->Write(handshake_buf_.get(), handshake_buf_len, + io_callback_); +} + +int Socks5ServerSocket::DoHandshakeWriteComplete(int result) { + if (result < 0) + return result; + + // We ignore the case when result is 0, since the underlying Write + // may return spurious writes while waiting on the socket. + + bytes_sent_ += result; + if (bytes_sent_ == buffer_.size()) { + buffer_.clear(); + if (reply_ == kReplySuccess) { + completed_handshake_ = true; + next_state_ = STATE_NONE; + } else { + net_log_.AddEvent(NetLogEventType::SOCKS_SERVER_ERROR, + NetLog::IntCallback("error_code", reply_)); + return ERR_SOCKS_CONNECTION_FAILED; + } + } else if (bytes_sent_ < buffer_.size()) { + next_state_ = STATE_HANDSHAKE_WRITE; + } else { + NOTREACHED(); + } + + return OK; +} + +int Socks5ServerSocket::GetPeerAddress(IPEndPoint* address) const { + return transport_->GetPeerAddress(address); +} + +int Socks5ServerSocket::GetLocalAddress(IPEndPoint* address) const { + return transport_->GetLocalAddress(address); +} + +} // namespace net diff --git a/net/tools/naive/socks5_server_socket.h b/net/tools/naive/socks5_server_socket.h new file mode 100644 index 0000000000..20423a7b45 --- /dev/null +++ b/net/tools/naive/socks5_server_socket.h @@ -0,0 +1,172 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Copyright 2018 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_ +#define NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_ + +#include +#include +#include +#include + +#include "base/macros.h" +#include "base/memory/ref_counted.h" +#include "net/base/completion_callback.h" +#include "net/base/host_port_pair.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_export.h" +#include "net/log/net_log_with_source.h" +#include "net/socket/connection_attempts.h" +#include "net/socket/next_proto.h" +#include "net/socket/stream_socket.h" +#include "net/ssl/ssl_info.h" + +namespace net { + +// This StreamSocket is used to setup a SOCKSv5 handshake with a socks client. +// Currently no SOCKSv5 authentication is supported. +class NET_EXPORT_PRIVATE Socks5ServerSocket : public StreamSocket { + public: + explicit Socks5ServerSocket(std::unique_ptr transport_socket); + + // On destruction Disconnect() is called. + ~Socks5ServerSocket() override; + + int Connect(HostPortPair* request_endpoint, + const CompletionCallback& callback); + + // StreamSocket implementation. + + // Does the SOCKS handshake and completes the protocol. + int Connect(const CompletionCallback& callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + const NetLogWithSource& NetLog() const override; + void SetSubresourceSpeculation() override; + void SetOmniboxSpeculation() override; + bool WasEverUsed() const override; + bool WasAlpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} + int64_t GetTotalReceivedBytes() const override; + + // Socket implementation. + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + + int SetReceiveBufferSize(int32_t size) override; + int SetSendBufferSize(int32_t size) override; + + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + + private: + enum State { + STATE_GREET_READ, + STATE_GREET_READ_COMPLETE, + STATE_GREET_WRITE, + STATE_GREET_WRITE_COMPLETE, + STATE_HANDSHAKE_WRITE, + STATE_HANDSHAKE_WRITE_COMPLETE, + STATE_HANDSHAKE_READ, + STATE_HANDSHAKE_READ_COMPLETE, + STATE_NONE, + }; + + // Addressing type that can be specified in requests or responses. + enum SocksEndPointAddressType { + kEndPointDomain = 0x03, + kEndPointResolvedIPv4 = 0x01, + kEndPointResolvedIPv6 = 0x04, + }; + + enum SocksCommandType { + kCommandConnect = 0x01, + kCommandBind = 0x02, + kCommandUDPAssociate = 0x03, + }; + + static const unsigned int kGreetReadHeaderSize; + static const unsigned int kReadHeaderSize; + static const char kSOCKS5Version; + static const char kSOCKS5Reserved; + static const char kAuthMethodNone; + static const char kAuthMethodNoAcceptable; + static const char kReplySuccess; + static const char kReplyCommandNotSupported; + + void DoCallback(int result); + void DoCallbackReturnRequest(HostPortPair* request_endpoint, + const CompletionCallback& callback, + int result); + void OnIOComplete(int result); + void OnReadWriteComplete(const CompletionCallback& callback, int result); + + int DoLoop(int last_io_result); + int DoGreetWrite(); + int DoGreetWriteComplete(int result); + int DoGreetRead(); + int DoGreetReadComplete(int result); + int DoHandshakeRead(); + int DoHandshakeReadComplete(int result); + int DoHandshakeWrite(); + int DoHandshakeWriteComplete(int result); + + CompletionCallback io_callback_; + + // Stores the underlying socket. + std::unique_ptr transport_; + + State next_state_; + + // Stores the callback to the layer above, called on completing Connect(). + CompletionCallback user_callback_; + + // This IOBuffer is used by the class to read and write + // SOCKS handshake data. The length contains the expected size to + // read or write. + scoped_refptr handshake_buf_; + + // While writing, this buffer stores the complete write handshake data. + // While reading, it stores the handshake information received so far. + std::string buffer_; + + // This becomes true when the SOCKS handshake has completed and the + // overlying connection is free to communicate. + bool completed_handshake_; + + // These contain the bytes received / sent by the SOCKS handshake. + size_t bytes_received_; + size_t bytes_sent_; + + size_t greet_read_header_size_; + size_t read_header_size_; + + bool was_ever_used_; + + SocksEndPointAddressType address_type_; + int address_size_; + + char auth_method_; + char reply_; + + HostPortPair host_port_pair_; + + NetLogWithSource net_log_; + + DISALLOW_COPY_AND_ASSIGN(Socks5ServerSocket); +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_