Add initial implementation of Naive client

This commit is contained in:
klzgrad 2018-01-20 12:15:45 -05:00
parent 847e762352
commit be525e4436
16 changed files with 3542 additions and 0 deletions

View File

@ -1725,3 +1725,29 @@ static_library("preload_decoder") {
]
deps = [ "//base" ]
}
executable("naive") {
sources = [
"tools/naive/naive_connection.cc",
"tools/naive/naive_connection.h",
"tools/naive/naive_proxy.cc",
"tools/naive/naive_proxy.h",
"tools/naive/naive_proxy_bin.cc",
"tools/naive/naive_proxy_delegate.h",
"tools/naive/naive_proxy_delegate.cc",
"tools/naive/http_proxy_socket.cc",
"tools/naive/http_proxy_socket.h",
"tools/naive/redirect_resolver.h",
"tools/naive/redirect_resolver.cc",
"tools/naive/socks5_server_socket.cc",
"tools/naive/socks5_server_socket.h",
]
deps = [
":net",
"//base",
"//build/win:default_exe_manifest",
"//components/version_info:version_info",
"//url",
]
}

View File

@ -488,6 +488,11 @@ EVENT_TYPE(SOCKS_HOSTNAME_TOO_BIG)
EVENT_TYPE(SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING)
EVENT_TYPE(SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE)
EVENT_TYPE(SOCKS_NO_REQUESTED_AUTH)
EVENT_TYPE(SOCKS_NO_ACCEPTABLE_AUTH)
EVENT_TYPE(SOCKS_ZERO_LENGTH_DOMAIN)
EVENT_TYPE(SOCKS_UNEXPECTED_COMMAND)
// This event indicates that a bad version number was received in the
// proxy server's response. The extra parameters show its value:
// {

View File

@ -0,0 +1,356 @@
// Copyright 2018 The Chromium Authors. All rights reserved.
// Copyright 2018 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/tools/naive/http_proxy_socket.h"
#include <cstring>
#include <utility>
#include "base/bind.h"
#include "base/callback_helpers.h"
#include "base/logging.h"
#include "base/rand_util.h"
#include "base/sys_byteorder.h"
#include "net/base/ip_address.h"
#include "net/base/net_errors.h"
#include "net/http/http_request_headers.h"
#include "net/log/net_log.h"
#include "net/third_party/quiche/src/quiche/spdy/core/hpack/hpack_constants.h"
#include "net/tools/naive/naive_proxy_delegate.h"
namespace net {
namespace {
constexpr int kBufferSize = 64 * 1024;
constexpr size_t kMaxHeaderSize = 64 * 1024;
constexpr char kResponseHeader[] = "HTTP/1.1 200 OK\r\nPadding: ";
constexpr int kResponseHeaderSize = sizeof(kResponseHeader) - 1;
// A plain 200 is 10 bytes. Expected 48 bytes. "Padding" uses up 7 bytes.
constexpr int kMinPaddingSize = 30;
constexpr int kMaxPaddingSize = kMinPaddingSize + 32;
} // namespace
HttpProxySocket::HttpProxySocket(
std::unique_ptr<StreamSocket> transport_socket,
ClientPaddingDetectorDelegate* padding_detector_delegate,
const NetworkTrafficAnnotationTag& traffic_annotation)
: io_callback_(base::BindRepeating(&HttpProxySocket::OnIOComplete,
base::Unretained(this))),
transport_(std::move(transport_socket)),
padding_detector_delegate_(padding_detector_delegate),
next_state_(STATE_NONE),
completed_handshake_(false),
was_ever_used_(false),
header_write_size_(-1),
net_log_(transport_->NetLog()),
traffic_annotation_(traffic_annotation) {}
HttpProxySocket::~HttpProxySocket() {
Disconnect();
}
const HostPortPair& HttpProxySocket::request_endpoint() const {
return request_endpoint_;
}
int HttpProxySocket::Connect(CompletionOnceCallback callback) {
DCHECK(transport_);
DCHECK_EQ(STATE_NONE, next_state_);
DCHECK(!user_callback_);
// If already connected, then just return OK.
if (completed_handshake_)
return OK;
next_state_ = STATE_HEADER_READ;
buffer_.clear();
int rv = DoLoop(OK);
if (rv == ERR_IO_PENDING) {
user_callback_ = std::move(callback);
}
return rv;
}
void HttpProxySocket::Disconnect() {
completed_handshake_ = false;
transport_->Disconnect();
// Reset other states to make sure they aren't mistakenly used later.
// These are the states initialized by Connect().
next_state_ = STATE_NONE;
user_callback_.Reset();
}
bool HttpProxySocket::IsConnected() const {
return completed_handshake_ && transport_->IsConnected();
}
bool HttpProxySocket::IsConnectedAndIdle() const {
return completed_handshake_ && transport_->IsConnectedAndIdle();
}
const NetLogWithSource& HttpProxySocket::NetLog() const {
return net_log_;
}
bool HttpProxySocket::WasEverUsed() const {
return was_ever_used_;
}
bool HttpProxySocket::WasAlpnNegotiated() const {
if (transport_) {
return transport_->WasAlpnNegotiated();
}
NOTREACHED();
return false;
}
NextProto HttpProxySocket::GetNegotiatedProtocol() const {
if (transport_) {
return transport_->GetNegotiatedProtocol();
}
NOTREACHED();
return kProtoUnknown;
}
bool HttpProxySocket::GetSSLInfo(SSLInfo* ssl_info) {
if (transport_) {
return transport_->GetSSLInfo(ssl_info);
}
NOTREACHED();
return false;
}
int64_t HttpProxySocket::GetTotalReceivedBytes() const {
return transport_->GetTotalReceivedBytes();
}
void HttpProxySocket::ApplySocketTag(const SocketTag& tag) {
return transport_->ApplySocketTag(tag);
}
// Read is called by the transport layer above to read. This can only be done
// if the HTTP header is complete.
int HttpProxySocket::Read(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) {
DCHECK(completed_handshake_);
DCHECK_EQ(STATE_NONE, next_state_);
DCHECK(!user_callback_);
DCHECK(callback);
if (!buffer_.empty()) {
was_ever_used_ = true;
int data_len = buffer_.size();
if (data_len <= buf_len) {
std::memcpy(buf->data(), buffer_.data(), data_len);
buffer_.clear();
return data_len;
} else {
std::memcpy(buf->data(), buffer_.data(), buf_len);
buffer_ = buffer_.substr(buf_len);
return buf_len;
}
}
int rv = transport_->Read(
buf, buf_len,
base::BindOnce(&HttpProxySocket::OnReadWriteComplete,
base::Unretained(this), std::move(callback)));
if (rv > 0)
was_ever_used_ = true;
return rv;
}
// Write is called by the transport layer. This can only be done if the
// SOCKS handshake is complete.
int HttpProxySocket::Write(
IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback,
const NetworkTrafficAnnotationTag& traffic_annotation) {
DCHECK(completed_handshake_);
DCHECK_EQ(STATE_NONE, next_state_);
DCHECK(!user_callback_);
DCHECK(callback);
int rv = transport_->Write(
buf, buf_len,
base::BindOnce(&HttpProxySocket::OnReadWriteComplete,
base::Unretained(this), std::move(callback)),
traffic_annotation);
if (rv > 0)
was_ever_used_ = true;
return rv;
}
int HttpProxySocket::SetReceiveBufferSize(int32_t size) {
return transport_->SetReceiveBufferSize(size);
}
int HttpProxySocket::SetSendBufferSize(int32_t size) {
return transport_->SetSendBufferSize(size);
}
void HttpProxySocket::DoCallback(int result) {
DCHECK_NE(ERR_IO_PENDING, result);
DCHECK(user_callback_);
// Since Run() may result in Read being called,
// clear user_callback_ up front.
std::move(user_callback_).Run(result);
}
void HttpProxySocket::OnIOComplete(int result) {
DCHECK_NE(STATE_NONE, next_state_);
int rv = DoLoop(result);
if (rv != ERR_IO_PENDING) {
DoCallback(rv);
}
}
void HttpProxySocket::OnReadWriteComplete(CompletionOnceCallback callback,
int result) {
DCHECK_NE(ERR_IO_PENDING, result);
DCHECK(callback);
if (result > 0)
was_ever_used_ = true;
std::move(callback).Run(result);
}
int HttpProxySocket::DoLoop(int last_io_result) {
DCHECK_NE(next_state_, STATE_NONE);
int rv = last_io_result;
do {
State state = next_state_;
next_state_ = STATE_NONE;
switch (state) {
case STATE_HEADER_READ:
DCHECK_EQ(OK, rv);
rv = DoHeaderRead();
break;
case STATE_HEADER_READ_COMPLETE:
rv = DoHeaderReadComplete(rv);
break;
case STATE_HEADER_WRITE:
DCHECK_EQ(OK, rv);
rv = DoHeaderWrite();
break;
case STATE_HEADER_WRITE_COMPLETE:
rv = DoHeaderWriteComplete(rv);
break;
default:
NOTREACHED() << "bad state";
rv = ERR_UNEXPECTED;
break;
}
} while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
return rv;
}
int HttpProxySocket::DoHeaderRead() {
next_state_ = STATE_HEADER_READ_COMPLETE;
handshake_buf_ = base::MakeRefCounted<IOBuffer>(kBufferSize);
return transport_->Read(handshake_buf_.get(), kBufferSize, io_callback_);
}
int HttpProxySocket::DoHeaderReadComplete(int result) {
if (result < 0)
return result;
if (result == 0) {
return ERR_CONNECTION_CLOSED;
}
buffer_.append(handshake_buf_->data(), result);
if (buffer_.size() > kMaxHeaderSize) {
return ERR_MSG_TOO_BIG;
}
auto header_end = buffer_.find("\r\n\r\n");
if (header_end == std::string::npos) {
next_state_ = STATE_HEADER_READ;
return OK;
}
// HttpProxyClientSocket uses CONNECT for all endpoints.
auto first_line_end = buffer_.find("\r\n");
auto first_space = buffer_.find(' ');
if (first_space == std::string::npos || first_space + 1 >= first_line_end) {
return ERR_INVALID_ARGUMENT;
}
if (buffer_.compare(0, first_space, "CONNECT") != 0) {
return ERR_INVALID_ARGUMENT;
}
auto second_space = buffer_.find(' ', first_space + 1);
if (second_space == std::string::npos || second_space >= first_line_end) {
return ERR_INVALID_ARGUMENT;
}
request_endpoint_ = HostPortPair::FromString(
buffer_.substr(first_space + 1, second_space - (first_space + 1)));
auto second_line = first_line_end + 2;
HttpRequestHeaders headers;
std::string headers_str;
if (second_line < header_end) {
headers_str = buffer_.substr(second_line, header_end - second_line);
headers.AddHeadersFromString(headers_str);
}
if (headers.HasHeader("padding")) {
padding_detector_delegate_->SetClientPaddingSupport(
PaddingSupport::kCapable);
} else {
padding_detector_delegate_->SetClientPaddingSupport(
PaddingSupport::kIncapable);
}
buffer_ = buffer_.substr(header_end + 4);
next_state_ = STATE_HEADER_WRITE;
return OK;
}
int HttpProxySocket::DoHeaderWrite() {
next_state_ = STATE_HEADER_WRITE_COMPLETE;
// Adds padding.
int padding_size = base::RandInt(kMinPaddingSize, kMaxPaddingSize);
header_write_size_ = kResponseHeaderSize + padding_size + 4;
handshake_buf_ = base::MakeRefCounted<IOBuffer>(header_write_size_);
char* p = handshake_buf_->data();
std::memcpy(p, kResponseHeader, kResponseHeaderSize);
FillNonindexHeaderValue(base::RandUint64(), p + kResponseHeaderSize,
padding_size);
std::memcpy(p + kResponseHeaderSize + padding_size, "\r\n\r\n", 4);
return transport_->Write(handshake_buf_.get(), header_write_size_,
io_callback_, traffic_annotation_);
}
int HttpProxySocket::DoHeaderWriteComplete(int result) {
if (result < 0)
return result;
if (result != header_write_size_) {
return ERR_FAILED;
}
completed_handshake_ = true;
next_state_ = STATE_NONE;
return OK;
}
int HttpProxySocket::GetPeerAddress(IPEndPoint* address) const {
return transport_->GetPeerAddress(address);
}
int HttpProxySocket::GetLocalAddress(IPEndPoint* address) const {
return transport_->GetLocalAddress(address);
}
} // namespace net

View File

@ -0,0 +1,122 @@
// Copyright 2018 The Chromium Authors. All rights reserved.
// Copyright 2018 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_TOOLS_NAIVE_HTTP_PROXY_SOCKET_H_
#define NET_TOOLS_NAIVE_HTTP_PROXY_SOCKET_H_
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include "base/memory/scoped_refptr.h"
#include "net/base/completion_once_callback.h"
#include "net/base/completion_repeating_callback.h"
#include "net/base/host_port_pair.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/log/net_log_with_source.h"
#include "net/socket/connection_attempts.h"
#include "net/socket/next_proto.h"
#include "net/socket/stream_socket.h"
#include "net/ssl/ssl_info.h"
namespace net {
struct NetworkTrafficAnnotationTag;
class ClientPaddingDetectorDelegate;
// This StreamSocket is used to setup a HTTP CONNECT tunnel.
class HttpProxySocket : public StreamSocket {
public:
HttpProxySocket(std::unique_ptr<StreamSocket> transport_socket,
ClientPaddingDetectorDelegate* padding_detector_delegate,
const NetworkTrafficAnnotationTag& traffic_annotation);
HttpProxySocket(const HttpProxySocket&) = delete;
HttpProxySocket& operator=(const HttpProxySocket&) = delete;
// On destruction Disconnect() is called.
~HttpProxySocket() override;
const HostPortPair& request_endpoint() const;
// StreamSocket implementation.
int Connect(CompletionOnceCallback callback) override;
void Disconnect() override;
bool IsConnected() const override;
bool IsConnectedAndIdle() const override;
const NetLogWithSource& NetLog() const override;
bool WasEverUsed() const override;
bool WasAlpnNegotiated() const override;
NextProto GetNegotiatedProtocol() const override;
bool GetSSLInfo(SSLInfo* ssl_info) override;
int64_t GetTotalReceivedBytes() const override;
void ApplySocketTag(const SocketTag& tag) override;
// Socket implementation.
int Read(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override;
int Write(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback,
const NetworkTrafficAnnotationTag& traffic_annotation) override;
int SetReceiveBufferSize(int32_t size) override;
int SetSendBufferSize(int32_t size) override;
int GetPeerAddress(IPEndPoint* address) const override;
int GetLocalAddress(IPEndPoint* address) const override;
private:
enum State {
STATE_HEADER_READ,
STATE_HEADER_READ_COMPLETE,
STATE_HEADER_WRITE,
STATE_HEADER_WRITE_COMPLETE,
STATE_NONE,
};
void DoCallback(int result);
void OnIOComplete(int result);
void OnReadWriteComplete(CompletionOnceCallback callback, int result);
int DoLoop(int last_io_result);
int DoHeaderWrite();
int DoHeaderWriteComplete(int result);
int DoHeaderRead();
int DoHeaderReadComplete(int result);
CompletionRepeatingCallback io_callback_;
// Stores the underlying socket.
std::unique_ptr<StreamSocket> transport_;
ClientPaddingDetectorDelegate* padding_detector_delegate_;
State next_state_;
// Stores the callback to the layer above, called on completing Connect().
CompletionOnceCallback user_callback_;
// This IOBuffer is used by the class to read and write
// SOCKS handshake data. The length contains the expected size to
// read or write.
scoped_refptr<IOBuffer> handshake_buf_;
std::string buffer_;
bool completed_handshake_;
bool was_ever_used_;
int header_write_size_;
HostPortPair request_endpoint_;
NetLogWithSource net_log_;
// Traffic annotation for socket control.
const NetworkTrafficAnnotationTag& traffic_annotation_;
};
} // namespace net
#endif // NET_TOOLS_NAIVE_HTTP_PROXY_SOCKET_H_

View File

@ -0,0 +1,572 @@
// Copyright 2018 The Chromium Authors. All rights reserved.
// Copyright 2018 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/tools/naive/naive_connection.h"
#include <cstring>
#include <utility>
#include "base/bind.h"
#include "base/callback_helpers.h"
#include "base/logging.h"
#include "base/rand_util.h"
#include "base/strings/strcat.h"
#include "base/threading/thread_task_runner_handle.h"
#include "base/time/time.h"
#include "net/base/io_buffer.h"
#include "net/base/load_flags.h"
#include "net/base/net_errors.h"
#include "net/base/privacy_mode.h"
#include "net/base/url_util.h"
#include "net/proxy_resolution/proxy_info.h"
#include "net/socket/client_socket_handle.h"
#include "net/socket/client_socket_pool_manager.h"
#include "net/socket/stream_socket.h"
#include "net/spdy/spdy_session.h"
#include "net/tools/naive/http_proxy_socket.h"
#include "net/tools/naive/redirect_resolver.h"
#include "net/tools/naive/socks5_server_socket.h"
#include "url/scheme_host_port.h"
#if defined(OS_LINUX)
#include <linux/netfilter_ipv4.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include "net/base/ip_endpoint.h"
#include "net/base/sockaddr_storage.h"
#include "net/socket/tcp_client_socket.h"
#endif
namespace net {
namespace {
constexpr int kBufferSize = 64 * 1024;
constexpr int kFirstPaddings = 8;
constexpr int kPaddingHeaderSize = 3;
constexpr int kMaxPaddingSize = 255;
} // namespace
NaiveConnection::NaiveConnection(
unsigned int id,
ClientProtocol protocol,
std::unique_ptr<PaddingDetectorDelegate> padding_detector_delegate,
const ProxyInfo& proxy_info,
const SSLConfig& server_ssl_config,
const SSLConfig& proxy_ssl_config,
RedirectResolver* resolver,
HttpNetworkSession* session,
const NetworkIsolationKey& network_isolation_key,
const NetLogWithSource& net_log,
std::unique_ptr<StreamSocket> accepted_socket,
const NetworkTrafficAnnotationTag& traffic_annotation)
: id_(id),
protocol_(protocol),
padding_detector_delegate_(std::move(padding_detector_delegate)),
proxy_info_(proxy_info),
server_ssl_config_(server_ssl_config),
proxy_ssl_config_(proxy_ssl_config),
resolver_(resolver),
session_(session),
network_isolation_key_(network_isolation_key),
net_log_(net_log),
next_state_(STATE_NONE),
client_socket_(std::move(accepted_socket)),
server_socket_handle_(std::make_unique<ClientSocketHandle>()),
sockets_{client_socket_.get(), nullptr},
errors_{OK, OK},
write_pending_{false, false},
early_pull_pending_(false),
can_push_to_server_(false),
early_pull_result_(ERR_IO_PENDING),
num_paddings_{0, 0},
read_padding_state_(STATE_READ_PAYLOAD_LENGTH_1),
full_duplex_(false),
time_func_(&base::TimeTicks::Now),
traffic_annotation_(traffic_annotation) {
io_callback_ = base::BindRepeating(&NaiveConnection::OnIOComplete,
weak_ptr_factory_.GetWeakPtr());
}
NaiveConnection::~NaiveConnection() {
Disconnect();
}
int NaiveConnection::Connect(CompletionOnceCallback callback) {
DCHECK(client_socket_);
DCHECK_EQ(next_state_, STATE_NONE);
DCHECK(!connect_callback_);
if (full_duplex_)
return OK;
next_state_ = STATE_CONNECT_CLIENT;
int rv = DoLoop(OK);
if (rv == ERR_IO_PENDING) {
connect_callback_ = std::move(callback);
}
return rv;
}
void NaiveConnection::Disconnect() {
full_duplex_ = false;
// Closes server side first because latency is higher.
if (server_socket_handle_->socket())
server_socket_handle_->socket()->Disconnect();
client_socket_->Disconnect();
next_state_ = STATE_NONE;
connect_callback_.Reset();
run_callback_.Reset();
}
void NaiveConnection::DoCallback(int result) {
DCHECK_NE(result, ERR_IO_PENDING);
DCHECK(connect_callback_);
// Since Run() may result in Read being called,
// clear connect_callback_ up front.
std::move(connect_callback_).Run(result);
}
void NaiveConnection::OnIOComplete(int result) {
DCHECK_NE(next_state_, STATE_NONE);
int rv = DoLoop(result);
if (rv != ERR_IO_PENDING) {
DoCallback(rv);
}
}
int NaiveConnection::DoLoop(int last_io_result) {
DCHECK_NE(next_state_, STATE_NONE);
int rv = last_io_result;
do {
State state = next_state_;
next_state_ = STATE_NONE;
switch (state) {
case STATE_CONNECT_CLIENT:
DCHECK_EQ(rv, OK);
rv = DoConnectClient();
break;
case STATE_CONNECT_CLIENT_COMPLETE:
rv = DoConnectClientComplete(rv);
break;
case STATE_CONNECT_SERVER:
DCHECK_EQ(rv, OK);
rv = DoConnectServer();
break;
case STATE_CONNECT_SERVER_COMPLETE:
rv = DoConnectServerComplete(rv);
break;
default:
NOTREACHED() << "bad state";
rv = ERR_UNEXPECTED;
break;
}
} while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
return rv;
}
int NaiveConnection::DoConnectClient() {
next_state_ = STATE_CONNECT_CLIENT_COMPLETE;
return client_socket_->Connect(io_callback_);
}
int NaiveConnection::DoConnectClientComplete(int result) {
if (result < 0)
return result;
// For proxy client sockets, padding support detection is finished after the
// first server response which means there will be one missed early pull. For
// proxy server sockets (HttpProxySocket), padding support detection is
// done during client connect, so there shouldn't be any missed early pull.
if (!padding_detector_delegate_->IsPaddingSupportKnown()) {
early_pull_pending_ = false;
early_pull_result_ = 0;
next_state_ = STATE_CONNECT_SERVER;
return OK;
}
early_pull_pending_ = true;
Pull(kClient, kServer);
if (early_pull_result_ != ERR_IO_PENDING) {
// Pull has completed synchronously.
if (early_pull_result_ <= 0) {
return early_pull_result_ ? early_pull_result_ : ERR_CONNECTION_CLOSED;
}
}
next_state_ = STATE_CONNECT_SERVER;
return OK;
}
int NaiveConnection::DoConnectServer() {
next_state_ = STATE_CONNECT_SERVER_COMPLETE;
HostPortPair origin;
if (protocol_ == ClientProtocol::kSocks5) {
const auto* socket =
static_cast<const Socks5ServerSocket*>(client_socket_.get());
origin = socket->request_endpoint();
} else if (protocol_ == ClientProtocol::kHttp) {
const auto* socket =
static_cast<const HttpProxySocket*>(client_socket_.get());
origin = socket->request_endpoint();
} else if (protocol_ == ClientProtocol::kRedir) {
#if defined(OS_LINUX)
const auto* socket =
static_cast<const TCPClientSocket*>(client_socket_.get());
IPEndPoint peer_endpoint;
int rv;
rv = socket->GetPeerAddress(&peer_endpoint);
if (rv != OK) {
LOG(ERROR) << "Connection " << id_ << " cannot get peer address";
return rv;
}
int sd = socket->SocketDescriptorForTesting();
SockaddrStorage dst;
if (peer_endpoint.GetFamily() == ADDRESS_FAMILY_IPV4 ||
peer_endpoint.address().IsIPv4MappedIPv6()) {
rv = getsockopt(sd, SOL_IP, SO_ORIGINAL_DST, dst.addr, &dst.addr_len);
} else {
rv = getsockopt(sd, SOL_IPV6, SO_ORIGINAL_DST, dst.addr, &dst.addr_len);
}
if (rv == 0) {
IPEndPoint ipe;
if (ipe.FromSockAddr(dst.addr, dst.addr_len)) {
const auto& addr = ipe.address();
auto name = resolver_->FindNameByAddress(addr);
if (!name.empty()) {
origin = HostPortPair(name, ipe.port());
} else if (!resolver_->IsInResolvedRange(addr)) {
origin = HostPortPair::FromIPEndPoint(ipe);
} else {
LOG(ERROR) << "Connection " << id_ << " to unresolved name for "
<< addr.ToString();
return ERR_ADDRESS_INVALID;
}
}
} else {
LOG(ERROR) << "Failed to get original destination address";
return ERR_ADDRESS_INVALID;
}
#else
static_cast<void>(resolver_);
#endif
}
url::CanonHostInfo host_info;
url::SchemeHostPort endpoint(
"http", CanonicalizeHost(origin.HostForURL(), &host_info), origin.port(),
url::SchemeHostPort::ALREADY_CANONICALIZED);
if (!endpoint.IsValid()) {
LOG(ERROR) << "Connection " << id_ << " to invalid origin "
<< origin.ToString();
return ERR_ADDRESS_INVALID;
}
LOG(INFO) << "Connection " << id_ << " to " << origin.ToString();
// Ignores socket limit set by socket pool for this type of socket.
return InitSocketHandleForRawConnect2(
std::move(endpoint), session_, LOAD_IGNORE_LIMITS, MAXIMUM_PRIORITY,
proxy_info_, server_ssl_config_, proxy_ssl_config_, PRIVACY_MODE_DISABLED,
network_isolation_key_, net_log_, server_socket_handle_.get(),
io_callback_);
}
int NaiveConnection::DoConnectServerComplete(int result) {
if (result < 0)
return result;
DCHECK(server_socket_handle_->socket());
sockets_[kServer] = server_socket_handle_->socket();
full_duplex_ = true;
next_state_ = STATE_NONE;
return OK;
}
int NaiveConnection::Run(CompletionOnceCallback callback) {
DCHECK(sockets_[kClient]);
DCHECK(sockets_[kServer]);
DCHECK_EQ(next_state_, STATE_NONE);
DCHECK(!connect_callback_);
if (errors_[kClient] != OK)
return errors_[kClient];
if (errors_[kServer] != OK)
return errors_[kServer];
run_callback_ = std::move(callback);
bytes_passed_without_yielding_[kClient] = 0;
bytes_passed_without_yielding_[kServer] = 0;
yield_after_time_[kClient] =
time_func_() + base::Milliseconds(kYieldAfterDurationMilliseconds);
yield_after_time_[kServer] = yield_after_time_[kClient];
can_push_to_server_ = true;
// early_pull_result_ == 0 means the early pull was not started because
// padding support was not yet known.
if (!early_pull_pending_ && early_pull_result_ == 0) {
Pull(kClient, kServer);
} else if (!early_pull_pending_) {
DCHECK_GT(early_pull_result_, 0);
Push(kClient, kServer, early_pull_result_);
}
Pull(kServer, kClient);
return ERR_IO_PENDING;
}
void NaiveConnection::Pull(Direction from, Direction to) {
if (errors_[kClient] < 0 || errors_[kServer] < 0)
return;
int read_size = kBufferSize;
auto padding_direction = padding_detector_delegate_->GetPaddingDirection();
if (from == padding_direction && num_paddings_[from] < kFirstPaddings) {
auto buffer = base::MakeRefCounted<GrowableIOBuffer>();
buffer->SetCapacity(kBufferSize);
buffer->set_offset(kPaddingHeaderSize);
read_buffers_[from] = buffer;
read_size = kBufferSize - kPaddingHeaderSize - kMaxPaddingSize;
} else {
read_buffers_[from] = base::MakeRefCounted<IOBuffer>(kBufferSize);
}
DCHECK(sockets_[from]);
int rv = sockets_[from]->Read(
read_buffers_[from].get(), read_size,
base::BindRepeating(&NaiveConnection::OnPullComplete,
weak_ptr_factory_.GetWeakPtr(), from, to));
if (from == kClient && early_pull_pending_)
early_pull_result_ = rv;
if (rv != ERR_IO_PENDING)
OnPullComplete(from, to, rv);
}
void NaiveConnection::Push(Direction from, Direction to, int size) {
int write_size = size;
int write_offset = 0;
auto padding_direction = padding_detector_delegate_->GetPaddingDirection();
if (from == padding_direction && num_paddings_[from] < kFirstPaddings) {
// Adds padding.
++num_paddings_[from];
int padding_size = base::RandInt(0, kMaxPaddingSize);
auto* buffer = static_cast<GrowableIOBuffer*>(read_buffers_[from].get());
buffer->set_offset(0);
uint8_t* p = reinterpret_cast<uint8_t*>(buffer->data());
p[0] = size / 256;
p[1] = size % 256;
p[2] = padding_size;
std::memset(p + kPaddingHeaderSize + size, 0, padding_size);
write_size = kPaddingHeaderSize + size + padding_size;
} else if (to == padding_direction && num_paddings_[from] < kFirstPaddings) {
// Removes padding.
const char* p = read_buffers_[from]->data();
bool trivial_padding = false;
if (read_padding_state_ == STATE_READ_PAYLOAD_LENGTH_1 &&
size >= kPaddingHeaderSize) {
int payload_size =
static_cast<uint8_t>(p[0]) * 256 + static_cast<uint8_t>(p[1]);
int padding_size = static_cast<uint8_t>(p[2]);
if (size == kPaddingHeaderSize + payload_size + padding_size) {
write_size = payload_size;
write_offset = kPaddingHeaderSize;
++num_paddings_[from];
trivial_padding = true;
}
}
if (!trivial_padding) {
auto unpadded_buffer = base::MakeRefCounted<IOBuffer>(kBufferSize);
char* unpadded_ptr = unpadded_buffer->data();
for (int i = 0; i < size;) {
if (num_paddings_[from] >= kFirstPaddings &&
read_padding_state_ == STATE_READ_PAYLOAD_LENGTH_1) {
std::memcpy(unpadded_ptr, p + i, size - i);
unpadded_ptr += size - i;
break;
}
int copy_size;
switch (read_padding_state_) {
case STATE_READ_PAYLOAD_LENGTH_1:
payload_length_ = static_cast<uint8_t>(p[i]);
++i;
read_padding_state_ = STATE_READ_PAYLOAD_LENGTH_2;
break;
case STATE_READ_PAYLOAD_LENGTH_2:
payload_length_ =
payload_length_ * 256 + static_cast<uint8_t>(p[i]);
++i;
read_padding_state_ = STATE_READ_PADDING_LENGTH;
break;
case STATE_READ_PADDING_LENGTH:
padding_length_ = static_cast<uint8_t>(p[i]);
++i;
read_padding_state_ = STATE_READ_PAYLOAD;
break;
case STATE_READ_PAYLOAD:
if (payload_length_ <= size - i) {
copy_size = payload_length_;
read_padding_state_ = STATE_READ_PADDING;
} else {
copy_size = size - i;
}
std::memcpy(unpadded_ptr, p + i, copy_size);
unpadded_ptr += copy_size;
i += copy_size;
payload_length_ -= copy_size;
break;
case STATE_READ_PADDING:
if (padding_length_ <= size - i) {
copy_size = padding_length_;
read_padding_state_ = STATE_READ_PAYLOAD_LENGTH_1;
++num_paddings_[from];
} else {
copy_size = size - i;
}
i += copy_size;
padding_length_ -= copy_size;
break;
}
}
write_size = unpadded_ptr - unpadded_buffer->data();
read_buffers_[from] = unpadded_buffer;
}
if (write_size == 0) {
OnPushComplete(from, to, OK);
return;
}
}
write_buffers_[to] = base::MakeRefCounted<DrainableIOBuffer>(
std::move(read_buffers_[from]), write_offset + write_size);
if (write_offset) {
write_buffers_[to]->DidConsume(write_offset);
}
write_pending_[to] = true;
DCHECK(sockets_[to]);
int rv = sockets_[to]->Write(
write_buffers_[to].get(), write_size,
base::BindRepeating(&NaiveConnection::OnPushComplete,
weak_ptr_factory_.GetWeakPtr(), from, to),
traffic_annotation_);
if (rv != ERR_IO_PENDING)
OnPushComplete(from, to, rv);
}
void NaiveConnection::Disconnect(Direction side) {
if (sockets_[side]) {
sockets_[side]->Disconnect();
sockets_[side] = nullptr;
write_pending_[side] = false;
}
}
bool NaiveConnection::IsConnected(Direction side) {
return sockets_[side];
}
void NaiveConnection::OnBothDisconnected() {
if (run_callback_) {
int error = OK;
if (errors_[kClient] != ERR_CONNECTION_CLOSED && errors_[kClient] < 0)
error = errors_[kClient];
if (errors_[kServer] != ERR_CONNECTION_CLOSED && errors_[kClient] < 0)
error = errors_[kServer];
std::move(run_callback_).Run(error);
}
}
void NaiveConnection::OnPullError(Direction from, Direction to, int error) {
DCHECK_LT(error, 0);
errors_[from] = error;
Disconnect(from);
if (!write_pending_[to])
Disconnect(to);
if (!IsConnected(from) && !IsConnected(to))
OnBothDisconnected();
}
void NaiveConnection::OnPushError(Direction from, Direction to, int error) {
DCHECK_LE(error, 0);
DCHECK(!write_pending_[to]);
if (error < 0) {
errors_[to] = error;
Disconnect(kServer);
Disconnect(kClient);
} else if (!IsConnected(from)) {
Disconnect(to);
}
if (!IsConnected(from) && !IsConnected(to))
OnBothDisconnected();
}
void NaiveConnection::OnPullComplete(Direction from, Direction to, int result) {
if (from == kClient && early_pull_pending_) {
early_pull_pending_ = false;
early_pull_result_ = result ? result : ERR_CONNECTION_CLOSED;
}
if (result <= 0) {
OnPullError(from, to, result ? result : ERR_CONNECTION_CLOSED);
return;
}
if (from == kClient && !can_push_to_server_)
return;
Push(from, to, result);
}
void NaiveConnection::OnPushComplete(Direction from, Direction to, int result) {
if (result >= 0 && write_buffers_[to] != nullptr) {
bytes_passed_without_yielding_[from] += result;
write_buffers_[to]->DidConsume(result);
int size = write_buffers_[to]->BytesRemaining();
if (size > 0) {
int rv = sockets_[to]->Write(
write_buffers_[to].get(), size,
base::BindRepeating(&NaiveConnection::OnPushComplete,
weak_ptr_factory_.GetWeakPtr(), from, to),
traffic_annotation_);
if (rv != ERR_IO_PENDING)
OnPushComplete(from, to, rv);
return;
}
}
write_pending_[to] = false;
// Checks for termination even if result is OK.
OnPushError(from, to, result >= 0 ? OK : result);
if (bytes_passed_without_yielding_[from] > kYieldAfterBytesRead ||
time_func_() > yield_after_time_[from]) {
bytes_passed_without_yielding_[from] = 0;
yield_after_time_[from] =
time_func_() + base::Milliseconds(kYieldAfterDurationMilliseconds);
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE,
base::BindRepeating(&NaiveConnection::Pull,
weak_ptr_factory_.GetWeakPtr(), from, to));
} else {
Pull(from, to);
}
}
} // namespace net

View File

@ -0,0 +1,142 @@
// Copyright 2018 The Chromium Authors. All rights reserved.
// Copyright 2018 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_TOOLS_NAIVE_NAIVE_CONNECTION_H_
#define NET_TOOLS_NAIVE_NAIVE_CONNECTION_H_
#include <memory>
#include <string>
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/time/time.h"
#include "net/base/completion_once_callback.h"
#include "net/base/completion_repeating_callback.h"
#include "net/tools/naive/naive_protocol.h"
#include "net/tools/naive/naive_proxy_delegate.h"
namespace net {
class ClientSocketHandle;
class DrainableIOBuffer;
class HttpNetworkSession;
class IOBuffer;
class NetLogWithSource;
class ProxyInfo;
class StreamSocket;
struct NetworkTrafficAnnotationTag;
struct SSLConfig;
class RedirectResolver;
class NetworkIsolationKey;
class NaiveConnection {
public:
using TimeFunc = base::TimeTicks (*)();
NaiveConnection(
unsigned int id,
ClientProtocol protocol,
std::unique_ptr<PaddingDetectorDelegate> padding_detector_delegate,
const ProxyInfo& proxy_info,
const SSLConfig& server_ssl_config,
const SSLConfig& proxy_ssl_config,
RedirectResolver* resolver,
HttpNetworkSession* session,
const NetworkIsolationKey& network_isolation_key,
const NetLogWithSource& net_log,
std::unique_ptr<StreamSocket> accepted_socket,
const NetworkTrafficAnnotationTag& traffic_annotation);
~NaiveConnection();
NaiveConnection(const NaiveConnection&) = delete;
NaiveConnection& operator=(const NaiveConnection&) = delete;
unsigned int id() const { return id_; }
int Connect(CompletionOnceCallback callback);
void Disconnect();
int Run(CompletionOnceCallback callback);
private:
enum State {
STATE_CONNECT_CLIENT,
STATE_CONNECT_CLIENT_COMPLETE,
STATE_CONNECT_SERVER,
STATE_CONNECT_SERVER_COMPLETE,
STATE_NONE,
};
enum PaddingState {
STATE_READ_PAYLOAD_LENGTH_1,
STATE_READ_PAYLOAD_LENGTH_2,
STATE_READ_PADDING_LENGTH,
STATE_READ_PAYLOAD,
STATE_READ_PADDING,
};
void DoCallback(int result);
void OnIOComplete(int result);
int DoLoop(int last_io_result);
int DoConnectClient();
int DoConnectClientComplete(int result);
int DoConnectServer();
int DoConnectServerComplete(int result);
void Pull(Direction from, Direction to);
void Push(Direction from, Direction to, int size);
void Disconnect(Direction side);
bool IsConnected(Direction side);
void OnBothDisconnected();
void OnPullError(Direction from, Direction to, int error);
void OnPushError(Direction from, Direction to, int error);
void OnPullComplete(Direction from, Direction to, int result);
void OnPushComplete(Direction from, Direction to, int result);
unsigned int id_;
ClientProtocol protocol_;
std::unique_ptr<PaddingDetectorDelegate> padding_detector_delegate_;
const ProxyInfo& proxy_info_;
const SSLConfig& server_ssl_config_;
const SSLConfig& proxy_ssl_config_;
RedirectResolver* resolver_;
HttpNetworkSession* session_;
const NetworkIsolationKey& network_isolation_key_;
const NetLogWithSource& net_log_;
CompletionRepeatingCallback io_callback_;
CompletionOnceCallback connect_callback_;
CompletionOnceCallback run_callback_;
State next_state_;
std::unique_ptr<StreamSocket> client_socket_;
std::unique_ptr<ClientSocketHandle> server_socket_handle_;
StreamSocket* sockets_[kNumDirections];
scoped_refptr<IOBuffer> read_buffers_[kNumDirections];
scoped_refptr<DrainableIOBuffer> write_buffers_[kNumDirections];
int errors_[kNumDirections];
bool write_pending_[kNumDirections];
int bytes_passed_without_yielding_[kNumDirections];
base::TimeTicks yield_after_time_[kNumDirections];
bool early_pull_pending_;
bool can_push_to_server_;
int early_pull_result_;
int num_paddings_[kNumDirections];
PaddingState read_padding_state_;
int payload_length_;
int padding_length_;
bool full_duplex_;
TimeFunc time_func_;
// Traffic annotation for socket control.
const NetworkTrafficAnnotationTag& traffic_annotation_;
base::WeakPtrFactory<NaiveConnection> weak_ptr_factory_{this};
};
} // namespace net
#endif // NET_TOOLS_NAIVE_NAIVE_CONNECTION_H_

View File

@ -0,0 +1,24 @@
// Copyright 2020 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_TOOLS_NAIVE_NAIVE_PROTOCOL_H_
#define NET_TOOLS_NAIVE_NAIVE_PROTOCOL_H_
namespace net {
enum class ClientProtocol {
kSocks5,
kHttp,
kRedir,
};
// Adds padding for traffic from this direction.
// Removes padding for traffic from the opposite direction.
enum Direction {
kClient = 0,
kServer = 1,
kNumDirections = 2,
kNone = 2,
};
} // namespace net
#endif // NET_TOOLS_NAIVE_NAIVE_PROTOCOL_H_

View File

@ -0,0 +1,211 @@
// Copyright 2018 The Chromium Authors. All rights reserved.
// Copyright 2018 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/tools/naive/naive_proxy.h"
#include <utility>
#include "base/bind.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/threading/thread_task_runner_handle.h"
#include "net/base/load_flags.h"
#include "net/base/net_errors.h"
#include "net/http/http_network_session.h"
#include "net/proxy_resolution/configured_proxy_resolution_service.h"
#include "net/proxy_resolution/proxy_config.h"
#include "net/proxy_resolution/proxy_list.h"
#include "net/socket/client_socket_pool_manager.h"
#include "net/socket/server_socket.h"
#include "net/socket/stream_socket.h"
#include "net/tools/naive/http_proxy_socket.h"
#include "net/tools/naive/naive_proxy_delegate.h"
#include "net/tools/naive/socks5_server_socket.h"
namespace net {
NaiveProxy::NaiveProxy(std::unique_ptr<ServerSocket> listen_socket,
ClientProtocol protocol,
const std::string& listen_user,
const std::string& listen_pass,
int concurrency,
RedirectResolver* resolver,
HttpNetworkSession* session,
const NetworkTrafficAnnotationTag& traffic_annotation)
: listen_socket_(std::move(listen_socket)),
protocol_(protocol),
listen_user_(listen_user),
listen_pass_(listen_pass),
concurrency_(concurrency),
resolver_(resolver),
session_(session),
net_log_(
NetLogWithSource::Make(session->net_log(), NetLogSourceType::NONE)),
last_id_(0),
traffic_annotation_(traffic_annotation) {
const auto& proxy_config = static_cast<ConfiguredProxyResolutionService*>(
session_->proxy_resolution_service())
->config();
DCHECK(proxy_config);
const ProxyList& proxy_list =
proxy_config.value().value().proxy_rules().single_proxies;
DCHECK(!proxy_list.IsEmpty());
proxy_info_.UseProxyList(proxy_list);
proxy_info_.set_traffic_annotation(
net::MutableNetworkTrafficAnnotationTag(traffic_annotation_));
// See HttpStreamFactory::Job::DoInitConnectionImpl()
proxy_ssl_config_.disable_cert_verification_network_fetches = true;
server_ssl_config_.alpn_protos = session_->GetAlpnProtos();
proxy_ssl_config_.alpn_protos = session_->GetAlpnProtos();
server_ssl_config_.application_settings = session_->GetApplicationSettings();
proxy_ssl_config_.application_settings = session_->GetApplicationSettings();
server_ssl_config_.ignore_certificate_errors =
session_->params().ignore_certificate_errors;
proxy_ssl_config_.ignore_certificate_errors =
session_->params().ignore_certificate_errors;
// TODO(https://crbug.com/964642): Also enable 0-RTT for TLS proxies.
server_ssl_config_.early_data_enabled = session_->params().enable_early_data;
for (int i = 0; i < concurrency_; i++) {
network_isolation_keys_.push_back(NetworkIsolationKey::CreateTransient());
}
DCHECK(listen_socket_);
// Start accepting connections in next run loop in case when delegate is not
// ready to get callbacks.
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::BindOnce(&NaiveProxy::DoAcceptLoop,
weak_ptr_factory_.GetWeakPtr()));
}
NaiveProxy::~NaiveProxy() = default;
void NaiveProxy::DoAcceptLoop() {
int result;
do {
result = listen_socket_->Accept(
&accepted_socket_, base::BindRepeating(&NaiveProxy::OnAcceptComplete,
weak_ptr_factory_.GetWeakPtr()));
if (result == ERR_IO_PENDING)
return;
HandleAcceptResult(result);
} while (result == OK);
}
void NaiveProxy::OnAcceptComplete(int result) {
HandleAcceptResult(result);
if (result == OK)
DoAcceptLoop();
}
void NaiveProxy::HandleAcceptResult(int result) {
if (result != OK) {
LOG(ERROR) << "Accept error: rv=" << result;
return;
}
DoConnect();
}
void NaiveProxy::DoConnect() {
std::unique_ptr<StreamSocket> socket;
auto* proxy_delegate =
static_cast<NaiveProxyDelegate*>(session_->context().proxy_delegate);
DCHECK(proxy_delegate);
DCHECK(!proxy_info_.is_empty());
const auto& proxy_server = proxy_info_.proxy_server();
auto padding_detector_delegate = std::make_unique<PaddingDetectorDelegate>(
proxy_delegate, proxy_server, protocol_);
if (protocol_ == ClientProtocol::kSocks5) {
socket = std::make_unique<Socks5ServerSocket>(std::move(accepted_socket_),
listen_user_, listen_pass_,
traffic_annotation_);
} else if (protocol_ == ClientProtocol::kHttp) {
socket = std::make_unique<HttpProxySocket>(std::move(accepted_socket_),
padding_detector_delegate.get(),
traffic_annotation_);
} else if (protocol_ == ClientProtocol::kRedir) {
socket = std::move(accepted_socket_);
} else {
return;
}
last_id_++;
const auto& nik = network_isolation_keys_[last_id_ % concurrency_];
auto connection_ptr = std::make_unique<NaiveConnection>(
last_id_, protocol_, std::move(padding_detector_delegate), proxy_info_,
server_ssl_config_, proxy_ssl_config_, resolver_, session_, nik, net_log_,
std::move(socket), traffic_annotation_);
auto* connection = connection_ptr.get();
connection_by_id_[connection->id()] = std::move(connection_ptr);
int result = connection->Connect(
base::BindRepeating(&NaiveProxy::OnConnectComplete,
weak_ptr_factory_.GetWeakPtr(), connection->id()));
if (result == ERR_IO_PENDING)
return;
HandleConnectResult(connection, result);
}
void NaiveProxy::OnConnectComplete(unsigned int connection_id, int result) {
auto* connection = FindConnection(connection_id);
if (!connection)
return;
HandleConnectResult(connection, result);
}
void NaiveProxy::HandleConnectResult(NaiveConnection* connection, int result) {
if (result != OK) {
Close(connection->id(), result);
return;
}
DoRun(connection);
}
void NaiveProxy::DoRun(NaiveConnection* connection) {
int result = connection->Run(
base::BindRepeating(&NaiveProxy::OnRunComplete,
weak_ptr_factory_.GetWeakPtr(), connection->id()));
if (result == ERR_IO_PENDING)
return;
HandleRunResult(connection, result);
}
void NaiveProxy::OnRunComplete(unsigned int connection_id, int result) {
auto* connection = FindConnection(connection_id);
if (!connection)
return;
HandleRunResult(connection, result);
}
void NaiveProxy::HandleRunResult(NaiveConnection* connection, int result) {
Close(connection->id(), result);
}
void NaiveProxy::Close(unsigned int connection_id, int reason) {
auto it = connection_by_id_.find(connection_id);
if (it == connection_by_id_.end())
return;
LOG(INFO) << "Connection " << connection_id
<< " closed: " << ErrorToShortString(reason);
// The call stack might have callbacks which still have the pointer of
// connection. Instead of referencing connection with ID all the time,
// destroys the connection in next run loop to make sure any pending
// callbacks in the call stack return.
base::ThreadTaskRunnerHandle::Get()->DeleteSoon(FROM_HERE,
std::move(it->second));
connection_by_id_.erase(it);
}
NaiveConnection* NaiveProxy::FindConnection(unsigned int connection_id) {
auto it = connection_by_id_.find(connection_id);
if (it == connection_by_id_.end())
return nullptr;
return it->second.get();
}
} // namespace net

View File

@ -0,0 +1,89 @@
// Copyright 2018 The Chromium Authors. All rights reserved.
// Copyright 2018 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_TOOLS_NAIVE_NAIVE_PROXY_H_
#define NET_TOOLS_NAIVE_NAIVE_PROXY_H_
#include <map>
#include <memory>
#include <vector>
#include "base/memory/weak_ptr.h"
#include "net/base/completion_repeating_callback.h"
#include "net/base/network_isolation_key.h"
#include "net/log/net_log_with_source.h"
#include "net/proxy_resolution/proxy_info.h"
#include "net/ssl/ssl_config.h"
#include "net/tools/naive/naive_connection.h"
#include "net/tools/naive/naive_protocol.h"
namespace net {
class ClientSocketHandle;
class HttpNetworkSession;
class NaiveConnection;
class ServerSocket;
class StreamSocket;
struct NetworkTrafficAnnotationTag;
class RedirectResolver;
class NaiveProxy {
public:
NaiveProxy(std::unique_ptr<ServerSocket> server_socket,
ClientProtocol protocol,
const std::string& listen_user,
const std::string& listen_pass,
int concurrency,
RedirectResolver* resolver,
HttpNetworkSession* session,
const NetworkTrafficAnnotationTag& traffic_annotation);
~NaiveProxy();
NaiveProxy(const NaiveProxy&) = delete;
NaiveProxy& operator=(const NaiveProxy&) = delete;
private:
void DoAcceptLoop();
void OnAcceptComplete(int result);
void HandleAcceptResult(int result);
void DoConnect();
void OnConnectComplete(unsigned int connection_id, int result);
void HandleConnectResult(NaiveConnection* connection, int result);
void DoRun(NaiveConnection* connection);
void OnRunComplete(unsigned int connection_id, int result);
void HandleRunResult(NaiveConnection* connection, int result);
void Close(unsigned int connection_id, int reason);
NaiveConnection* FindConnection(unsigned int connection_id);
std::unique_ptr<ServerSocket> listen_socket_;
ClientProtocol protocol_;
std::string listen_user_;
std::string listen_pass_;
int concurrency_;
ProxyInfo proxy_info_;
SSLConfig server_ssl_config_;
SSLConfig proxy_ssl_config_;
RedirectResolver* resolver_;
HttpNetworkSession* session_;
NetLogWithSource net_log_;
unsigned int last_id_;
std::unique_ptr<StreamSocket> accepted_socket_;
std::vector<NetworkIsolationKey> network_isolation_keys_;
std::map<unsigned int, std::unique_ptr<NaiveConnection>> connection_by_id_;
const NetworkTrafficAnnotationTag& traffic_annotation_;
base::WeakPtrFactory<NaiveProxy> weak_ptr_factory_{this};
};
} // namespace net
#endif // NET_TOOLS_NAIVE_NAIVE_PROXY_H_

View File

@ -0,0 +1,592 @@
// Copyright 2018 The Chromium Authors. All rights reserved.
// Copyright 2018 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <cstdlib>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include "base/at_exit.h"
#include "base/command_line.h"
#include "base/feature_list.h"
#include "base/files/file_path.h"
#include "base/json/json_file_value_serializer.h"
#include "base/json/json_writer.h"
#include "base/logging.h"
#include "base/rand_util.h"
#include "base/run_loop.h"
#include "base/strings/escape.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/strings/utf_string_conversions.h"
#include "base/system/sys_info.h"
#include "base/task/single_thread_task_executor.h"
#include "base/task/thread_pool/thread_pool_instance.h"
#include "base/values.h"
#include "build/build_config.h"
#include "components/version_info/version_info.h"
#include "net/base/auth.h"
#include "net/base/network_isolation_key.h"
#include "net/base/url_util.h"
#include "net/cert/cert_verifier.h"
#include "net/cert_net/cert_net_fetcher_url_request.h"
#include "net/dns/host_resolver.h"
#include "net/dns/mapped_host_resolver.h"
#include "net/http/http_auth.h"
#include "net/http/http_auth_cache.h"
#include "net/http/http_network_session.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_transaction_factory.h"
#include "net/log/file_net_log_observer.h"
#include "net/log/net_log.h"
#include "net/log/net_log_capture_mode.h"
#include "net/log/net_log_entry.h"
#include "net/log/net_log_event_type.h"
#include "net/log/net_log_source.h"
#include "net/log/net_log_util.h"
#include "net/proxy_resolution/configured_proxy_resolution_service.h"
#include "net/proxy_resolution/proxy_config.h"
#include "net/proxy_resolution/proxy_config_service_fixed.h"
#include "net/proxy_resolution/proxy_config_with_annotation.h"
#include "net/socket/client_socket_pool_manager.h"
#include "net/socket/ssl_client_socket.h"
#include "net/socket/tcp_server_socket.h"
#include "net/socket/udp_server_socket.h"
#include "net/ssl/ssl_key_logger_impl.h"
#include "net/third_party/quiche/src/quiche/quic/core/quic_versions.h"
#include "net/tools/naive/naive_protocol.h"
#include "net/tools/naive/naive_proxy.h"
#include "net/tools/naive/naive_proxy_delegate.h"
#include "net/tools/naive/redirect_resolver.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "net/url_request/url_request_context.h"
#include "net/url_request/url_request_context_builder.h"
#include "url/gurl.h"
#include "url/scheme_host_port.h"
#include "url/url_util.h"
#if defined(OS_MACOSX)
#include "base/mac/scoped_nsautorelease_pool.h"
#endif
namespace {
constexpr int kListenBackLog = 512;
constexpr int kDefaultMaxSocketsPerPool = 256;
constexpr int kDefaultMaxSocketsPerGroup = 255;
constexpr int kExpectedMaxUsers = 8;
constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation =
net::DefineNetworkTrafficAnnotation("naive", "");
struct CommandLine {
std::string listen;
std::string proxy;
std::string concurrency;
std::string extra_headers;
std::string host_resolver_rules;
std::string resolver_range;
bool no_log;
base::FilePath log;
base::FilePath log_net_log;
base::FilePath ssl_key_log_file;
};
struct Params {
net::ClientProtocol protocol;
std::string listen_user;
std::string listen_pass;
std::string listen_addr;
int listen_port;
int concurrency;
net::HttpRequestHeaders extra_headers;
std::string proxy_url;
std::u16string proxy_user;
std::u16string proxy_pass;
std::string host_resolver_rules;
net::IPAddress resolver_range;
size_t resolver_prefix;
logging::LoggingSettings log_settings;
base::FilePath net_log_path;
base::FilePath ssl_key_path;
};
std::unique_ptr<base::Value> GetConstants() {
auto constants_dict = std::make_unique<base::Value>(net::GetNetConstants());
base::DictionaryValue dict;
std::string os_type = base::StringPrintf(
"%s: %s (%s)", base::SysInfo::OperatingSystemName().c_str(),
base::SysInfo::OperatingSystemVersion().c_str(),
base::SysInfo::OperatingSystemArchitecture().c_str());
dict.SetStringPath("os_type", os_type);
constants_dict->SetKey("clientInfo", std::move(dict));
return constants_dict;
}
void GetCommandLine(const base::CommandLine& proc, CommandLine* cmdline) {
if (proc.HasSwitch("h") || proc.HasSwitch("help")) {
std::cout << "Usage: naive { OPTIONS | config.json }\n"
"\n"
"Options:\n"
"-h, --help Show this message\n"
"--version Print version\n"
"--listen=<proto>://[addr][:port]\n"
" proto: socks, http\n"
" redir (Linux only)\n"
"--proxy=<proto>://[<user>:<pass>@]<hostname>[:<port>]\n"
" proto: https, quic\n"
"--insecure-concurrency=<N> Use N connections, insecure\n"
"--extra-headers=... Extra headers split by CRLF\n"
"--host-resolver-rules=... Resolver rules\n"
"--resolver-range=... Redirect resolver range\n"
"--log[=<path>] Log to stderr, or file\n"
"--log-net-log=<path> Save NetLog\n"
"--ssl-key-log-file=<path> Save SSL keys for Wireshark\n"
<< std::endl;
exit(EXIT_SUCCESS);
}
if (proc.HasSwitch("version")) {
std::cout << "naive " << version_info::GetVersionNumber() << std::endl;
exit(EXIT_SUCCESS);
}
cmdline->listen = proc.GetSwitchValueASCII("listen");
cmdline->proxy = proc.GetSwitchValueASCII("proxy");
cmdline->concurrency = proc.GetSwitchValueASCII("insecure-concurrency");
cmdline->extra_headers = proc.GetSwitchValueASCII("extra-headers");
cmdline->host_resolver_rules =
proc.GetSwitchValueASCII("host-resolver-rules");
cmdline->resolver_range = proc.GetSwitchValueASCII("resolver-range");
cmdline->no_log = !proc.HasSwitch("log");
cmdline->log = proc.GetSwitchValuePath("log");
cmdline->log_net_log = proc.GetSwitchValuePath("log-net-log");
cmdline->ssl_key_log_file = proc.GetSwitchValuePath("ssl-key-log-file");
}
void GetCommandLineFromConfig(const base::FilePath& config_path,
CommandLine* cmdline) {
JSONFileValueDeserializer reader(config_path);
int error_code;
std::string error_message;
auto value = reader.Deserialize(&error_code, &error_message);
if (value == nullptr) {
std::cerr << "Error reading " << config_path << ": (" << error_code << ") "
<< error_message << std::endl;
exit(EXIT_FAILURE);
}
if (!value->is_dict()) {
std::cerr << "Invalid config format" << std::endl;
exit(EXIT_FAILURE);
}
const auto* listen = value->FindStringKey("listen");
if (listen) {
cmdline->listen = *listen;
}
const auto* proxy = value->FindStringKey("proxy");
if (proxy) {
cmdline->proxy = *proxy;
}
const auto* concurrency = value->FindStringKey("insecure-concurrency");
if (concurrency) {
cmdline->concurrency = *concurrency;
}
const auto* extra_headers = value->FindStringKey("extra-headers");
if (extra_headers) {
cmdline->extra_headers = *extra_headers;
}
const auto* host_resolver_rules = value->FindStringKey("host-resolver-rules");
if (host_resolver_rules) {
cmdline->host_resolver_rules = *host_resolver_rules;
}
const auto* resolver_range = value->FindStringKey("resolver-range");
if (resolver_range) {
cmdline->resolver_range = *resolver_range;
}
cmdline->no_log = true;
const auto* log = value->FindStringKey("log");
if (log) {
cmdline->no_log = false;
cmdline->log = base::FilePath::FromUTF8Unsafe(*log);
}
const auto* log_net_log = value->FindStringKey("log-net-log");
if (log_net_log) {
cmdline->log_net_log = base::FilePath::FromUTF8Unsafe(*log_net_log);
}
const auto* ssl_key_log_file = value->FindStringKey("ssl-key-log-file");
if (ssl_key_log_file) {
cmdline->ssl_key_log_file =
base::FilePath::FromUTF8Unsafe(*ssl_key_log_file);
}
}
std::string GetProxyFromURL(const GURL& url) {
std::string str = url.GetWithEmptyPath().spec();
if (str.size() && str.back() == '/') {
str.pop_back();
}
return str;
}
bool ParseCommandLine(const CommandLine& cmdline, Params* params) {
params->protocol = net::ClientProtocol::kSocks5;
params->listen_addr = "0.0.0.0";
params->listen_port = 1080;
url::AddStandardScheme("socks",
url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION);
url::AddStandardScheme("redir", url::SCHEME_WITH_HOST_AND_PORT);
if (!cmdline.listen.empty()) {
GURL url(cmdline.listen);
if (url.scheme() == "socks") {
params->protocol = net::ClientProtocol::kSocks5;
params->listen_port = 1080;
} else if (url.scheme() == "http") {
params->protocol = net::ClientProtocol::kHttp;
params->listen_port = 8080;
} else if (url.scheme() == "redir") {
#if defined(OS_LINUX)
params->protocol = net::ClientProtocol::kRedir;
params->listen_port = 1080;
#else
std::cerr << "Redir protocol only supports Linux." << std::endl;
return false;
#endif
} else {
std::cerr << "Invalid scheme in --listen" << std::endl;
return false;
}
if (!url.username().empty()) {
params->listen_user = base::UnescapeBinaryURLComponent(url.username());
}
if (!url.password().empty()) {
params->listen_pass = base::UnescapeBinaryURLComponent(url.password());
}
if (!url.host().empty()) {
params->listen_addr = url.HostNoBrackets();
}
if (!url.port().empty()) {
if (!base::StringToInt(url.port(), &params->listen_port)) {
std::cerr << "Invalid port in --listen" << std::endl;
return false;
}
if (params->listen_port <= 0 ||
params->listen_port > std::numeric_limits<uint16_t>::max()) {
std::cerr << "Invalid port in --listen" << std::endl;
return false;
}
}
}
params->proxy_url = "direct://";
GURL url(cmdline.proxy);
GURL::Replacements remove_auth;
remove_auth.ClearUsername();
remove_auth.ClearPassword();
GURL url_no_auth = url.ReplaceComponents(remove_auth);
if (!cmdline.proxy.empty()) {
if (!url.is_valid()) {
std::cerr << "Invalid proxy URL" << std::endl;
return false;
}
params->proxy_url = GetProxyFromURL(url_no_auth);
net::GetIdentityFromURL(url, &params->proxy_user, &params->proxy_pass);
}
if (!cmdline.concurrency.empty()) {
if (!base::StringToInt(cmdline.concurrency, &params->concurrency) ||
params->concurrency < 1) {
std::cerr << "Invalid concurrency" << std::endl;
return false;
}
} else {
params->concurrency = 1;
}
params->extra_headers.AddHeadersFromString(cmdline.extra_headers);
params->host_resolver_rules = cmdline.host_resolver_rules;
if (params->protocol == net::ClientProtocol::kRedir) {
std::string range = "100.64.0.0/10";
if (!cmdline.resolver_range.empty())
range = cmdline.resolver_range;
if (!net::ParseCIDRBlock(range, &params->resolver_range,
&params->resolver_prefix)) {
std::cerr << "Invalid resolver range" << std::endl;
return false;
}
if (params->resolver_range.IsIPv6()) {
std::cerr << "IPv6 resolver range not supported" << std::endl;
return false;
}
}
if (!cmdline.no_log) {
if (!cmdline.log.empty()) {
params->log_settings.logging_dest = logging::LOG_TO_FILE;
params->log_settings.log_file_path = cmdline.log.value().c_str();
} else {
params->log_settings.logging_dest = logging::LOG_TO_STDERR;
}
} else {
params->log_settings.logging_dest = logging::LOG_NONE;
}
params->net_log_path = cmdline.log_net_log;
params->ssl_key_path = cmdline.ssl_key_log_file;
return true;
}
} // namespace
namespace net {
namespace {
// NetLog::ThreadSafeObserver implementation that simply prints events
// to the logs.
class PrintingLogObserver : public NetLog::ThreadSafeObserver {
public:
PrintingLogObserver() = default;
PrintingLogObserver(const PrintingLogObserver&) = delete;
PrintingLogObserver& operator=(const PrintingLogObserver&) = delete;
~PrintingLogObserver() override {
// This is guaranteed to be safe as this program is single threaded.
net_log()->RemoveObserver(this);
}
// NetLog::ThreadSafeObserver implementation:
void OnAddEntry(const NetLogEntry& entry) override {
switch (entry.type) {
case NetLogEventType::SOCKET_POOL_STALLED_MAX_SOCKETS:
case NetLogEventType::SOCKET_POOL_STALLED_MAX_SOCKETS_PER_GROUP:
case NetLogEventType::HTTP2_SESSION_STREAM_STALLED_BY_SESSION_SEND_WINDOW:
case NetLogEventType::HTTP2_SESSION_STREAM_STALLED_BY_STREAM_SEND_WINDOW:
case NetLogEventType::HTTP2_SESSION_STALLED_MAX_STREAMS:
case NetLogEventType::HTTP2_STREAM_FLOW_CONTROL_UNSTALLED:
break;
default:
return;
}
const char* source_type = NetLog::SourceTypeToString(entry.source.type);
const char* event_type = NetLogEventTypeToString(entry.type);
const char* event_phase = NetLog::EventPhaseToString(entry.phase);
base::Value params(entry.ToValue());
std::string params_str;
base::JSONWriter::Write(params, &params_str);
params_str.insert(0, ": ");
VLOG(1) << source_type << "(" << entry.source.id << "): " << event_type
<< ": " << event_phase << params_str;
}
};
} // namespace
namespace {
std::unique_ptr<URLRequestContext> BuildCertURLRequestContext(NetLog* net_log) {
URLRequestContextBuilder builder;
builder.DisableHttpCache();
builder.set_net_log(net_log);
ProxyConfig proxy_config;
auto proxy_service =
ConfiguredProxyResolutionService::CreateWithoutProxyResolver(
std::make_unique<ProxyConfigServiceFixed>(
ProxyConfigWithAnnotation(proxy_config, kTrafficAnnotation)),
net_log);
proxy_service->ForceReloadProxyConfig();
builder.set_proxy_resolution_service(std::move(proxy_service));
return builder.Build();
}
// Builds a URLRequestContext assuming there's only a single loop.
std::unique_ptr<URLRequestContext> BuildURLRequestContext(
const Params& params,
scoped_refptr<CertNetFetcherURLRequest> cert_net_fetcher,
NetLog* net_log) {
URLRequestContextBuilder builder;
builder.DisableHttpCache();
builder.set_net_log(net_log);
ProxyConfig proxy_config;
proxy_config.proxy_rules().ParseFromString(params.proxy_url);
LOG(INFO) << "Proxying via " << params.proxy_url;
auto proxy_service =
ConfiguredProxyResolutionService::CreateWithoutProxyResolver(
std::make_unique<ProxyConfigServiceFixed>(
ProxyConfigWithAnnotation(proxy_config, kTrafficAnnotation)),
net_log);
proxy_service->ForceReloadProxyConfig();
builder.set_proxy_resolution_service(std::move(proxy_service));
if (!params.host_resolver_rules.empty()) {
builder.set_host_mapping_rules(params.host_resolver_rules);
}
builder.SetCertVerifier(
CertVerifier::CreateDefault(std::move(cert_net_fetcher)));
builder.set_proxy_delegate(
std::make_unique<NaiveProxyDelegate>(params.extra_headers));
auto context = builder.Build();
if (!params.proxy_url.empty() && !params.proxy_user.empty() &&
!params.proxy_pass.empty()) {
auto* session = context->http_transaction_factory()->GetSession();
auto* auth_cache = session->http_auth_cache();
std::string proxy_url = params.proxy_url;
GURL proxy_gurl(proxy_url);
if (proxy_url.compare(0, 7, "quic://") == 0) {
proxy_url.replace(0, 4, "https");
proxy_gurl = GURL(proxy_url);
auto* quic = context->quic_context()->params();
quic->supported_versions = {quic::ParsedQuicVersion::RFCv1()};
quic->origins_to_force_quic_on.insert(
net::HostPortPair::FromURL(proxy_gurl));
}
url::SchemeHostPort auth_origin(proxy_gurl);
AuthCredentials credentials(params.proxy_user, params.proxy_pass);
auth_cache->Add(auth_origin, HttpAuth::AUTH_PROXY,
/*realm=*/{}, HttpAuth::AUTH_SCHEME_BASIC, {},
/*challenge=*/"Basic", credentials, /*path=*/"/");
}
return context;
}
} // namespace
} // namespace net
int main(int argc, char* argv[]) {
url::AddStandardScheme("quic",
url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION);
base::FeatureList::InitializeInstance(
"PartitionConnectionsByNetworkIsolationKey", std::string());
base::SingleThreadTaskExecutor io_task_executor(base::MessagePumpType::IO);
base::ThreadPoolInstance::CreateAndStartWithDefaultParams("naive");
base::AtExitManager exit_manager;
#if defined(OS_MACOSX)
base::mac::ScopedNSAutoreleasePool pool;
#endif
base::CommandLine::Init(argc, argv);
CommandLine cmdline;
Params params;
const auto& proc = *base::CommandLine::ForCurrentProcess();
const auto& args = proc.GetArgs();
if (args.empty()) {
if (proc.argv().size() >= 2) {
GetCommandLine(proc, &cmdline);
} else {
auto path = base::FilePath::FromUTF8Unsafe("config.json");
GetCommandLineFromConfig(path, &cmdline);
}
} else {
base::FilePath path(args[0]);
GetCommandLineFromConfig(path, &cmdline);
}
if (!ParseCommandLine(cmdline, &params)) {
return EXIT_FAILURE;
}
net::ClientSocketPoolManager::set_max_sockets_per_pool(
net::HttpNetworkSession::NORMAL_SOCKET_POOL,
kDefaultMaxSocketsPerPool * kExpectedMaxUsers);
net::ClientSocketPoolManager::set_max_sockets_per_proxy_server(
net::HttpNetworkSession::NORMAL_SOCKET_POOL,
kDefaultMaxSocketsPerPool * kExpectedMaxUsers);
net::ClientSocketPoolManager::set_max_sockets_per_group(
net::HttpNetworkSession::NORMAL_SOCKET_POOL,
kDefaultMaxSocketsPerGroup * kExpectedMaxUsers);
CHECK(logging::InitLogging(params.log_settings));
if (!params.ssl_key_path.empty()) {
net::SSLClientSocket::SetSSLKeyLogger(
std::make_unique<net::SSLKeyLoggerImpl>(params.ssl_key_path));
}
// The declaration order for net_log and printing_log_observer is
// important. The destructor of PrintingLogObserver removes itself
// from net_log, so net_log must be available for entire lifetime of
// printing_log_observer.
net::NetLog* net_log = net::NetLog::Get();
std::unique_ptr<net::FileNetLogObserver> observer;
if (!params.net_log_path.empty()) {
observer = net::FileNetLogObserver::CreateUnbounded(
params.net_log_path, net::NetLogCaptureMode::kDefault, GetConstants());
observer->StartObserving(net_log);
}
// Avoids net log overhead if verbose logging is disabled.
std::unique_ptr<net::PrintingLogObserver> printing_log_observer;
if (params.log_settings.logging_dest != logging::LOG_NONE && VLOG_IS_ON(1)) {
printing_log_observer = std::make_unique<net::PrintingLogObserver>();
net_log->AddObserver(printing_log_observer.get(),
net::NetLogCaptureMode::kDefault);
}
auto cert_context = net::BuildCertURLRequestContext(net_log);
scoped_refptr<net::CertNetFetcherURLRequest> cert_net_fetcher;
// The builtin verifier is supported but not enabled by default on Mac,
// falling back to CreateSystemVerifyProc() which drops the net fetcher.
// Skips defined(OS_MAC) for now, until it is enabled by default.
#if defined(OS_LINUX) || defined(OS_ANDROID)
cert_net_fetcher = base::MakeRefCounted<net::CertNetFetcherURLRequest>();
cert_net_fetcher->SetURLRequestContext(cert_context.get());
#endif
auto context =
net::BuildURLRequestContext(params, std::move(cert_net_fetcher), net_log);
auto* session = context->http_transaction_factory()->GetSession();
auto listen_socket =
std::make_unique<net::TCPServerSocket>(net_log, net::NetLogSource());
int result = listen_socket->ListenWithAddressAndPort(
params.listen_addr, params.listen_port, kListenBackLog);
if (result != net::OK) {
LOG(ERROR) << "Failed to listen: " << result;
return EXIT_FAILURE;
}
LOG(INFO) << "Listening on " << params.listen_addr << ":"
<< params.listen_port;
std::unique_ptr<net::RedirectResolver> resolver;
if (params.protocol == net::ClientProtocol::kRedir) {
auto resolver_socket =
std::make_unique<net::UDPServerSocket>(net_log, net::NetLogSource());
resolver_socket->AllowAddressReuse();
net::IPAddress listen_addr;
if (!listen_addr.AssignFromIPLiteral(params.listen_addr)) {
LOG(ERROR) << "Failed to open resolver: " << net::ERR_ADDRESS_INVALID;
return EXIT_FAILURE;
}
result = resolver_socket->Listen(
net::IPEndPoint(listen_addr, params.listen_port));
if (result != net::OK) {
LOG(ERROR) << "Failed to open resolver: " << result;
return EXIT_FAILURE;
}
resolver = std::make_unique<net::RedirectResolver>(
std::move(resolver_socket), params.resolver_range,
params.resolver_prefix);
}
net::NaiveProxy naive_proxy(std::move(listen_socket), params.protocol,
params.listen_user, params.listen_pass,
params.concurrency, resolver.get(), session,
kTrafficAnnotation);
base::RunLoop().Run();
return EXIT_SUCCESS;
}

View File

@ -0,0 +1,159 @@
// Copyright 2020 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/tools/naive/naive_proxy_delegate.h"
#include <string>
#include "base/logging.h"
#include "base/rand_util.h"
#include "net/base/proxy_string_util.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/third_party/quiche/src/quiche/spdy/core/hpack/hpack_constants.h"
namespace net {
namespace {
bool g_nonindex_codes_initialized;
uint8_t g_nonindex_codes[17];
} // namespace
void InitializeNonindexCodes() {
if (g_nonindex_codes_initialized)
return;
g_nonindex_codes_initialized = true;
unsigned i = 0;
for (const auto& symbol : spdy::HpackHuffmanCodeVector()) {
if (symbol.id >= 0x20 && symbol.id <= 0x7f && symbol.length >= 8) {
g_nonindex_codes[i++] = symbol.id;
if (i >= sizeof(g_nonindex_codes))
break;
}
}
CHECK(i == sizeof(g_nonindex_codes));
}
void FillNonindexHeaderValue(uint64_t unique_bits, char* buf, int len) {
DCHECK(g_nonindex_codes_initialized);
int first = len < 16 ? len : 16;
for (int i = 0; i < first; i++) {
buf[i] = g_nonindex_codes[unique_bits & 0b1111];
unique_bits >>= 4;
}
for (int i = first; i < len; i++) {
buf[i] = g_nonindex_codes[16];
}
}
NaiveProxyDelegate::NaiveProxyDelegate(const HttpRequestHeaders& extra_headers)
: extra_headers_(extra_headers) {
InitializeNonindexCodes();
}
NaiveProxyDelegate::~NaiveProxyDelegate() = default;
void NaiveProxyDelegate::OnBeforeTunnelRequest(
const ProxyServer& proxy_server,
HttpRequestHeaders* extra_headers) {
if (proxy_server.is_direct() || proxy_server.is_socks())
return;
// Sends client-side padding header regardless of server support
std::string padding(base::RandInt(16, 32), '~');
FillNonindexHeaderValue(base::RandUint64(), &padding[0], padding.size());
extra_headers->SetHeader("padding", padding);
// Enables Fast Open in H2/H3 proxy client socket once the state of server
// padding support is known.
if (padding_state_by_server_[proxy_server] != PaddingSupport::kUnknown) {
extra_headers->SetHeader("fastopen", "1");
}
extra_headers->MergeFrom(extra_headers_);
}
Error NaiveProxyDelegate::OnTunnelHeadersReceived(
const ProxyServer& proxy_server,
const HttpResponseHeaders& response_headers) {
if (proxy_server.is_direct() || proxy_server.is_socks())
return OK;
// Detects server padding support, even if it changes dynamically.
bool padding = response_headers.HasHeader("padding");
auto new_state =
padding ? PaddingSupport::kCapable : PaddingSupport::kIncapable;
auto& padding_state = padding_state_by_server_[proxy_server];
if (padding_state == PaddingSupport::kUnknown || padding_state != new_state) {
LOG(INFO) << "Padding capability of " << ProxyServerToProxyUri(proxy_server)
<< (padding ? " detected" : " undetected");
}
padding_state = new_state;
return OK;
}
PaddingSupport NaiveProxyDelegate::GetProxyServerPaddingSupport(
const ProxyServer& proxy_server) {
// Not possible to detect padding capability given underlying protocol.
if (proxy_server.is_direct() || proxy_server.is_socks())
return PaddingSupport::kIncapable;
return padding_state_by_server_[proxy_server];
}
PaddingDetectorDelegate::PaddingDetectorDelegate(
NaiveProxyDelegate* naive_proxy_delegate,
const ProxyServer& proxy_server,
ClientProtocol client_protocol)
: naive_proxy_delegate_(naive_proxy_delegate),
proxy_server_(proxy_server),
client_protocol_(client_protocol),
detected_client_padding_support_(PaddingSupport::kUnknown),
cached_server_padding_support_(PaddingSupport::kUnknown) {}
PaddingDetectorDelegate::~PaddingDetectorDelegate() = default;
bool PaddingDetectorDelegate::IsPaddingSupportKnown() {
auto c = GetClientPaddingSupport();
auto s = GetServerPaddingSupport();
return c != PaddingSupport::kUnknown && s != PaddingSupport::kUnknown;
}
Direction PaddingDetectorDelegate::GetPaddingDirection() {
auto c = GetClientPaddingSupport();
auto s = GetServerPaddingSupport();
// Padding support must be already detected at this point.
CHECK_NE(c, PaddingSupport::kUnknown);
CHECK_NE(s, PaddingSupport::kUnknown);
if (c == PaddingSupport::kCapable && s == PaddingSupport::kIncapable) {
return kServer;
}
if (c == PaddingSupport::kIncapable && s == PaddingSupport::kCapable) {
return kClient;
}
return kNone;
}
void PaddingDetectorDelegate::SetClientPaddingSupport(
PaddingSupport padding_support) {
detected_client_padding_support_ = padding_support;
}
PaddingSupport PaddingDetectorDelegate::GetClientPaddingSupport() {
// Not possible to detect padding capability given underlying protocol.
if (client_protocol_ == ClientProtocol::kSocks5) {
return PaddingSupport::kIncapable;
} else if (client_protocol_ == ClientProtocol::kRedir) {
return PaddingSupport::kIncapable;
}
return detected_client_padding_support_;
}
PaddingSupport PaddingDetectorDelegate::GetServerPaddingSupport() {
if (cached_server_padding_support_ != PaddingSupport::kUnknown)
return cached_server_padding_support_;
cached_server_padding_support_ =
naive_proxy_delegate_->GetProxyServerPaddingSupport(proxy_server_);
return cached_server_padding_support_;
}
} // namespace net

View File

@ -0,0 +1,94 @@
// Copyright 2020 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_TOOLS_NAIVE_NAIVE_PROXY_DELEGATE_H_
#define NET_TOOLS_NAIVE_NAIVE_PROXY_DELEGATE_H_
#include <cstdint>
#include <map>
#include <string>
#include "base/strings/string_piece.h"
#include "net/base/net_errors.h"
#include "net/base/proxy_delegate.h"
#include "net/base/proxy_server.h"
#include "net/proxy_resolution/proxy_retry_info.h"
#include "net/tools/naive/naive_protocol.h"
#include "url/gurl.h"
namespace net {
void InitializeNonindexCodes();
// |unique_bits| SHOULD have relatively unique values.
void FillNonindexHeaderValue(uint64_t unique_bits, char* buf, int len);
class ProxyInfo;
class HttpRequestHeaders;
class HttpResponseHeaders;
enum class PaddingSupport {
kUnknown = 0,
kCapable,
kIncapable,
};
class NaiveProxyDelegate : public ProxyDelegate {
public:
explicit NaiveProxyDelegate(const HttpRequestHeaders& extra_headers);
~NaiveProxyDelegate() override;
void OnResolveProxy(const GURL& url,
const std::string& method,
const ProxyRetryInfoMap& proxy_retry_info,
ProxyInfo* result) override {}
void OnFallback(const ProxyServer& bad_proxy, int net_error) override {}
// This only affects h2 proxy client socket.
void OnBeforeTunnelRequest(const ProxyServer& proxy_server,
HttpRequestHeaders* extra_headers) override;
Error OnTunnelHeadersReceived(
const ProxyServer& proxy_server,
const HttpResponseHeaders& response_headers) override;
PaddingSupport GetProxyServerPaddingSupport(const ProxyServer& proxy_server);
private:
const HttpRequestHeaders& extra_headers_;
std::map<ProxyServer, PaddingSupport> padding_state_by_server_;
};
class ClientPaddingDetectorDelegate {
public:
virtual ~ClientPaddingDetectorDelegate() = default;
virtual void SetClientPaddingSupport(PaddingSupport padding_support) = 0;
};
class PaddingDetectorDelegate : public ClientPaddingDetectorDelegate {
public:
PaddingDetectorDelegate(NaiveProxyDelegate* naive_proxy_delegate,
const ProxyServer& proxy_server,
ClientProtocol client_protocol);
~PaddingDetectorDelegate() override;
bool IsPaddingSupportKnown();
Direction GetPaddingDirection();
void SetClientPaddingSupport(PaddingSupport padding_support) override;
private:
PaddingSupport GetClientPaddingSupport();
PaddingSupport GetServerPaddingSupport();
NaiveProxyDelegate* naive_proxy_delegate_;
const ProxyServer& proxy_server_;
ClientProtocol client_protocol_;
PaddingSupport detected_client_padding_support_;
// The result is only cached during one connection, so it's still dynamically
// updated in the following connections after server changes support.
PaddingSupport cached_server_padding_support_;
};
} // namespace net
#endif // NET_TOOLS_NAIVE_NAIVE_PROXY_DELEGATE_H_

View File

@ -0,0 +1,246 @@
// Copyright 2019 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/tools/naive/redirect_resolver.h"
#include <cstring>
#include <iterator>
#include <utility>
#include "base/logging.h"
#include "base/threading/thread_task_runner_handle.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/dns/dns_query.h"
#include "net/dns/dns_response.h"
#include "net/dns/dns_util.h"
#include "net/socket/datagram_server_socket.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
namespace {
constexpr int kUdpReadBufferSize = 1024;
constexpr int kResolutionTtl = 60;
constexpr int kResolutionRecycleTime = 60 * 5;
std::string PackedIPv4ToString(uint32_t addr) {
return net::IPAddress(addr >> 24, addr >> 16, addr >> 8, addr).ToString();
}
} // namespace
namespace net {
Resolution::Resolution() = default;
Resolution::~Resolution() = default;
RedirectResolver::RedirectResolver(std::unique_ptr<DatagramServerSocket> socket,
const IPAddress& range,
size_t prefix)
: socket_(std::move(socket)),
range_(range),
prefix_(prefix),
offset_(0),
buffer_(base::MakeRefCounted<IOBufferWithSize>(kUdpReadBufferSize)) {
DCHECK(socket_);
// Start accepting connections in next run loop in case when delegate is not
// ready to get callbacks.
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::BindOnce(&RedirectResolver::DoRead,
weak_ptr_factory_.GetWeakPtr()));
}
RedirectResolver::~RedirectResolver() = default;
void RedirectResolver::DoRead() {
for (;;) {
int rv = socket_->RecvFrom(
buffer_.get(), kUdpReadBufferSize, &recv_address_,
base::BindOnce(&RedirectResolver::OnRecv, base::Unretained(this)));
if (rv == ERR_IO_PENDING)
return;
rv = HandleReadResult(rv);
if (rv == ERR_IO_PENDING)
return;
if (rv < 0) {
LOG(INFO) << "DoRead: ignoring error " << rv;
}
}
}
void RedirectResolver::OnRecv(int result) {
int rv;
rv = HandleReadResult(result);
if (rv == ERR_IO_PENDING)
return;
if (rv < 0) {
LOG(INFO) << "OnRecv: ignoring error " << result;
}
DoRead();
}
void RedirectResolver::OnSend(int result) {
if (result < 0) {
LOG(INFO) << "OnSend: ignoring error " << result;
}
DoRead();
}
int RedirectResolver::HandleReadResult(int result) {
if (result < 0)
return result;
DnsQuery query(buffer_.get());
if (!query.Parse(result)) {
LOG(INFO) << "Malformed DNS query from " << recv_address_.ToString();
return ERR_INVALID_ARGUMENT;
}
int size;
if (query.qtype() == dns_protocol::kTypeA) {
Resolution res;
auto name_or = DnsDomainToString(query.qname());
if (!name_or) {
LOG(INFO) << "Malformed DNS query from " << recv_address_.ToString();
return ERR_INVALID_ARGUMENT;
}
const auto& name = name_or.value();
auto by_name_lookup = resolution_by_name_.emplace(name, resolutions_.end());
auto by_name = by_name_lookup.first;
bool has_name = !by_name_lookup.second;
if (has_name) {
auto res_it = by_name->second;
auto by_addr = res_it->by_addr;
uint32_t addr = res_it->addr;
resolutions_.erase(res_it);
resolutions_.emplace_back();
res_it = std::prev(resolutions_.end());
by_name->second = res_it;
by_addr->second = res_it;
res_it->addr = addr;
res_it->name = name;
res_it->time = base::TimeTicks::Now();
res_it->by_name = by_name;
res_it->by_addr = by_addr;
} else {
uint32_t addr = (range_.bytes()[0] << 24) | (range_.bytes()[1] << 16) |
(range_.bytes()[2] << 8) | range_.bytes()[3];
uint32_t subnet = ~0U >> prefix_;
addr &= ~subnet;
addr += offset_;
offset_ = (offset_ + 1) & subnet;
auto by_addr_lookup =
resolution_by_addr_.emplace(addr, resolutions_.end());
auto by_addr = by_addr_lookup.first;
bool has_addr = !by_addr_lookup.second;
if (has_addr) {
// Too few available addresses. Overwrites old one.
auto res_it = by_addr->second;
LOG(INFO) << "Overwrite " << res_it->name << " "
<< PackedIPv4ToString(res_it->addr) << " with " << name << " "
<< PackedIPv4ToString(addr);
resolution_by_name_.erase(res_it->by_name);
resolutions_.erase(res_it);
resolutions_.emplace_back();
res_it = std::prev(resolutions_.end());
by_name->second = res_it;
by_addr->second = res_it;
res_it->addr = addr;
res_it->name = name;
res_it->time = base::TimeTicks::Now();
res_it->by_name = by_name;
res_it->by_addr = by_addr;
} else {
LOG(INFO) << "Add " << name << " " << PackedIPv4ToString(addr);
resolutions_.emplace_back();
auto res_it = std::prev(resolutions_.end());
by_name->second = res_it;
by_addr->second = res_it;
res_it->addr = addr;
res_it->name = name;
res_it->time = base::TimeTicks::Now();
res_it->by_name = by_name;
res_it->by_addr = by_addr;
// Collects garbage.
auto now = base::TimeTicks::Now();
for (auto it = resolutions_.begin();
it != resolutions_.end() &&
(now - it->time).InSeconds() > kResolutionRecycleTime;) {
auto next = std::next(it);
LOG(INFO) << "Drop " << it->name << " "
<< PackedIPv4ToString(it->addr);
resolution_by_name_.erase(it->by_name);
resolution_by_addr_.erase(it->by_addr);
resolutions_.erase(it);
it = next;
}
}
}
DnsResourceRecord record;
record.name = name;
record.type = dns_protocol::kTypeA;
record.klass = dns_protocol::kClassIN;
record.ttl = kResolutionTtl;
uint32_t addr = by_name->second->addr;
record.SetOwnedRdata(IPAddressToPackedString(
IPAddress(addr >> 24, addr >> 16, addr >> 8, addr)));
absl::optional<DnsQuery> query_opt;
query_opt.emplace(query.id(), query.qname(), query.qtype());
DnsResponse response(query.id(), /*is_authoritative=*/false,
/*answers=*/{std::move(record)},
/*authority_records=*/{}, /*additional_records=*/{},
query_opt);
size = response.io_buffer_size();
if (size > buffer_->size() || !response.io_buffer()) {
return ERR_NO_BUFFER_SPACE;
}
std::memcpy(buffer_->data(), response.io_buffer()->data(), size);
} else {
absl::optional<DnsQuery> query_opt;
query_opt.emplace(query.id(), query.qname(), query.qtype());
DnsResponse response(query.id(), /*is_authoritative=*/false, /*answers=*/{},
/*authority_records=*/{}, /*additional_records=*/{},
query_opt, dns_protocol::kRcodeSERVFAIL);
size = response.io_buffer_size();
if (size > buffer_->size() || !response.io_buffer()) {
return ERR_NO_BUFFER_SPACE;
}
std::memcpy(buffer_->data(), response.io_buffer()->data(), size);
}
return socket_->SendTo(
buffer_.get(), size, recv_address_,
base::BindOnce(&RedirectResolver::OnSend, base::Unretained(this)));
}
bool RedirectResolver::IsInResolvedRange(const IPAddress& address) const {
if (!address.IsIPv4())
return false;
return IPAddressMatchesPrefix(address, range_, prefix_);
}
std::string RedirectResolver::FindNameByAddress(
const IPAddress& address) const {
if (!address.IsIPv4())
return {};
uint32_t addr = (address.bytes()[0] << 24) | (address.bytes()[1] << 16) |
(address.bytes()[2] << 8) | address.bytes()[3];
auto by_addr = resolution_by_addr_.find(addr);
if (by_addr == resolution_by_addr_.end())
return {};
return by_addr->second->name;
}
} // namespace net

View File

@ -0,0 +1,69 @@
// Copyright 2019 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_TOOLS_NAIVE_REDIRECT_RESOLVER_H_
#define NET_TOOLS_NAIVE_REDIRECT_RESOLVER_H_
#include <cstdint>
#include <list>
#include <map>
#include <memory>
#include <string>
#include "base/memory/ref_counted.h"
#include "base/memory/weak_ptr.h"
#include "base/time/time.h"
#include "net/base/ip_address.h"
#include "net/base/ip_endpoint.h"
namespace net {
class DatagramServerSocket;
class IOBufferWithSize;
struct Resolution {
Resolution();
~Resolution();
uint32_t addr;
std::string name;
base::TimeTicks time;
std::map<std::string, std::list<Resolution>::iterator>::iterator by_name;
std::map<uint32_t, std::list<Resolution>::iterator>::iterator by_addr;
};
class RedirectResolver {
public:
RedirectResolver(std::unique_ptr<DatagramServerSocket> socket,
const IPAddress& range,
size_t prefix);
~RedirectResolver();
RedirectResolver(const RedirectResolver&) = delete;
RedirectResolver& operator=(const RedirectResolver&) = delete;
bool IsInResolvedRange(const IPAddress& address) const;
std::string FindNameByAddress(const IPAddress& address) const;
private:
void DoRead();
void OnRecv(int result);
void OnSend(int result);
int HandleReadResult(int result);
std::unique_ptr<DatagramServerSocket> socket_;
IPAddress range_;
size_t prefix_;
uint32_t offset_;
scoped_refptr<IOBufferWithSize> buffer_;
IPEndPoint recv_address_;
std::map<std::string, std::list<Resolution>::iterator> resolution_by_name_;
std::map<uint32_t, std::list<Resolution>::iterator> resolution_by_addr_;
std::list<Resolution> resolutions_;
base::WeakPtrFactory<RedirectResolver> weak_ptr_factory_{this};
};
} // namespace net
#endif // NET_TOOLS_NAIVE_REDIRECT_RESOLVER_H_

View File

@ -0,0 +1,669 @@
// Copyright 2018 The Chromium Authors. All rights reserved.
// Copyright 2018 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/tools/naive/socks5_server_socket.h"
#include <cstring>
#include <utility>
#include "base/bind.h"
#include "base/callback_helpers.h"
#include "base/cxx17_backports.h"
#include "base/logging.h"
#include "base/sys_byteorder.h"
#include "net/base/ip_address.h"
#include "net/base/net_errors.h"
#include "net/base/sys_addrinfo.h"
#include "net/log/net_log.h"
#include "net/log/net_log_event_type.h"
namespace net {
enum SocksCommandType {
kCommandConnect = 0x01,
kCommandBind = 0x02,
kCommandUDPAssociate = 0x03,
};
static constexpr unsigned int kGreetReadHeaderSize = 2;
static constexpr unsigned int kAuthReadHeaderSize = 2;
static constexpr unsigned int kReadHeaderSize = 5;
static constexpr char kSOCKS5Version = '\x05';
static constexpr char kSOCKS5Reserved = '\x00';
static constexpr char kAuthMethodNone = '\x00';
static constexpr char kAuthMethodUserPass = '\x02';
static constexpr char kAuthMethodNoAcceptable = '\xff';
static constexpr char kSubnegotiationVersion = '\x01';
static constexpr char kAuthStatusSuccess = '\x00';
static constexpr char kAuthStatusFailure = '\xff';
static constexpr char kReplySuccess = '\x00';
static constexpr char kReplyCommandNotSupported = '\x07';
static_assert(sizeof(struct in_addr) == 4, "incorrect system size of IPv4");
static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6");
Socks5ServerSocket::Socks5ServerSocket(
std::unique_ptr<StreamSocket> transport_socket,
const std::string& user,
const std::string& pass,
const NetworkTrafficAnnotationTag& traffic_annotation)
: io_callback_(base::BindRepeating(&Socks5ServerSocket::OnIOComplete,
base::Unretained(this))),
transport_(std::move(transport_socket)),
next_state_(STATE_NONE),
completed_handshake_(false),
bytes_sent_(0),
was_ever_used_(false),
user_(user),
pass_(pass),
net_log_(transport_->NetLog()),
traffic_annotation_(traffic_annotation) {}
Socks5ServerSocket::~Socks5ServerSocket() {
Disconnect();
}
const HostPortPair& Socks5ServerSocket::request_endpoint() const {
return request_endpoint_;
}
int Socks5ServerSocket::Connect(CompletionOnceCallback callback) {
DCHECK(transport_);
DCHECK_EQ(STATE_NONE, next_state_);
DCHECK(!user_callback_);
// If already connected, then just return OK.
if (completed_handshake_)
return OK;
net_log_.BeginEvent(NetLogEventType::SOCKS5_CONNECT);
next_state_ = STATE_GREET_READ;
buffer_.clear();
int rv = DoLoop(OK);
if (rv == ERR_IO_PENDING) {
user_callback_ = std::move(callback);
} else {
net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_CONNECT, rv);
}
return rv;
}
void Socks5ServerSocket::Disconnect() {
completed_handshake_ = false;
transport_->Disconnect();
// Reset other states to make sure they aren't mistakenly used later.
// These are the states initialized by Connect().
next_state_ = STATE_NONE;
user_callback_.Reset();
}
bool Socks5ServerSocket::IsConnected() const {
return completed_handshake_ && transport_->IsConnected();
}
bool Socks5ServerSocket::IsConnectedAndIdle() const {
return completed_handshake_ && transport_->IsConnectedAndIdle();
}
const NetLogWithSource& Socks5ServerSocket::NetLog() const {
return net_log_;
}
bool Socks5ServerSocket::WasEverUsed() const {
return was_ever_used_;
}
bool Socks5ServerSocket::WasAlpnNegotiated() const {
if (transport_) {
return transport_->WasAlpnNegotiated();
}
NOTREACHED();
return false;
}
NextProto Socks5ServerSocket::GetNegotiatedProtocol() const {
if (transport_) {
return transport_->GetNegotiatedProtocol();
}
NOTREACHED();
return kProtoUnknown;
}
bool Socks5ServerSocket::GetSSLInfo(SSLInfo* ssl_info) {
if (transport_) {
return transport_->GetSSLInfo(ssl_info);
}
NOTREACHED();
return false;
}
int64_t Socks5ServerSocket::GetTotalReceivedBytes() const {
return transport_->GetTotalReceivedBytes();
}
void Socks5ServerSocket::ApplySocketTag(const SocketTag& tag) {
return transport_->ApplySocketTag(tag);
}
// Read is called by the transport layer above to read. This can only be done
// if the SOCKS handshake is complete.
int Socks5ServerSocket::Read(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) {
DCHECK(completed_handshake_);
DCHECK_EQ(STATE_NONE, next_state_);
DCHECK(!user_callback_);
DCHECK(callback);
int rv = transport_->Read(
buf, buf_len,
base::BindOnce(&Socks5ServerSocket::OnReadWriteComplete,
base::Unretained(this), std::move(callback)));
if (rv > 0)
was_ever_used_ = true;
return rv;
}
// Write is called by the transport layer. This can only be done if the
// SOCKS handshake is complete.
int Socks5ServerSocket::Write(
IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback,
const NetworkTrafficAnnotationTag& traffic_annotation) {
DCHECK(completed_handshake_);
DCHECK_EQ(STATE_NONE, next_state_);
DCHECK(!user_callback_);
DCHECK(callback);
int rv = transport_->Write(
buf, buf_len,
base::BindOnce(&Socks5ServerSocket::OnReadWriteComplete,
base::Unretained(this), std::move(callback)),
traffic_annotation);
if (rv > 0)
was_ever_used_ = true;
return rv;
}
int Socks5ServerSocket::SetReceiveBufferSize(int32_t size) {
return transport_->SetReceiveBufferSize(size);
}
int Socks5ServerSocket::SetSendBufferSize(int32_t size) {
return transport_->SetSendBufferSize(size);
}
void Socks5ServerSocket::DoCallback(int result) {
DCHECK_NE(ERR_IO_PENDING, result);
DCHECK(user_callback_);
// Since Run() may result in Read being called,
// clear user_callback_ up front.
std::move(user_callback_).Run(result);
}
void Socks5ServerSocket::OnIOComplete(int result) {
DCHECK_NE(STATE_NONE, next_state_);
int rv = DoLoop(result);
if (rv != ERR_IO_PENDING) {
net_log_.EndEvent(NetLogEventType::SOCKS5_CONNECT);
DoCallback(rv);
}
}
void Socks5ServerSocket::OnReadWriteComplete(CompletionOnceCallback callback,
int result) {
DCHECK_NE(ERR_IO_PENDING, result);
DCHECK(callback);
if (result > 0)
was_ever_used_ = true;
std::move(callback).Run(result);
}
int Socks5ServerSocket::DoLoop(int last_io_result) {
DCHECK_NE(next_state_, STATE_NONE);
int rv = last_io_result;
do {
State state = next_state_;
next_state_ = STATE_NONE;
switch (state) {
case STATE_GREET_READ:
DCHECK_EQ(OK, rv);
net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_READ);
rv = DoGreetRead();
break;
case STATE_GREET_READ_COMPLETE:
rv = DoGreetReadComplete(rv);
net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_READ,
rv);
break;
case STATE_GREET_WRITE:
DCHECK_EQ(OK, rv);
net_log_.BeginEvent(NetLogEventType::SOCKS5_GREET_WRITE);
rv = DoGreetWrite();
break;
case STATE_GREET_WRITE_COMPLETE:
rv = DoGreetWriteComplete(rv);
net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_WRITE,
rv);
break;
case STATE_AUTH_READ:
DCHECK_EQ(OK, rv);
rv = DoAuthRead();
break;
case STATE_AUTH_READ_COMPLETE:
rv = DoAuthReadComplete(rv);
break;
case STATE_AUTH_WRITE:
DCHECK_EQ(OK, rv);
rv = DoAuthWrite();
break;
case STATE_AUTH_WRITE_COMPLETE:
rv = DoAuthWriteComplete(rv);
break;
case STATE_HANDSHAKE_READ:
DCHECK_EQ(OK, rv);
net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_READ);
rv = DoHandshakeRead();
break;
case STATE_HANDSHAKE_READ_COMPLETE:
rv = DoHandshakeReadComplete(rv);
net_log_.EndEventWithNetErrorCode(
NetLogEventType::SOCKS5_HANDSHAKE_READ, rv);
break;
case STATE_HANDSHAKE_WRITE:
DCHECK_EQ(OK, rv);
net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_WRITE);
rv = DoHandshakeWrite();
break;
case STATE_HANDSHAKE_WRITE_COMPLETE:
rv = DoHandshakeWriteComplete(rv);
net_log_.EndEventWithNetErrorCode(
NetLogEventType::SOCKS5_HANDSHAKE_WRITE, rv);
break;
default:
NOTREACHED() << "bad state";
rv = ERR_UNEXPECTED;
break;
}
} while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
return rv;
}
int Socks5ServerSocket::DoGreetRead() {
next_state_ = STATE_GREET_READ_COMPLETE;
if (buffer_.empty()) {
read_header_size_ = kGreetReadHeaderSize;
}
int handshake_buf_len = read_header_size_ - buffer_.size();
DCHECK_LT(0, handshake_buf_len);
handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
return transport_->Read(handshake_buf_.get(), handshake_buf_len,
io_callback_);
}
int Socks5ServerSocket::DoGreetReadComplete(int result) {
if (result < 0)
return result;
if (result == 0) {
net_log_.AddEvent(
NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING);
return ERR_SOCKS_CONNECTION_FAILED;
}
buffer_.append(handshake_buf_->data(), result);
// When the first few bytes are read, check how many more are required
// and accordingly increase them
if (buffer_.size() == kGreetReadHeaderSize) {
if (buffer_[0] != kSOCKS5Version) {
net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
"version", buffer_[0]);
return ERR_SOCKS_CONNECTION_FAILED;
}
int nmethods = buffer_[1];
if (nmethods == 0) {
net_log_.AddEvent(NetLogEventType::SOCKS_NO_REQUESTED_AUTH);
return ERR_SOCKS_CONNECTION_FAILED;
}
read_header_size_ += nmethods;
next_state_ = STATE_GREET_READ;
return OK;
}
if (buffer_.size() == read_header_size_) {
int nmethods = buffer_[1];
char expected_method = kAuthMethodNone;
if (!user_.empty() || !pass_.empty()) {
expected_method = kAuthMethodUserPass;
}
void* match =
std::memchr(&buffer_[kGreetReadHeaderSize], expected_method, nmethods);
if (match) {
auth_method_ = expected_method;
} else {
auth_method_ = kAuthMethodNoAcceptable;
}
buffer_.clear();
next_state_ = STATE_GREET_WRITE;
return OK;
}
next_state_ = STATE_GREET_READ;
return OK;
}
int Socks5ServerSocket::DoGreetWrite() {
if (buffer_.empty()) {
const char write_data[] = {kSOCKS5Version, auth_method_};
buffer_ = std::string(write_data, std::size(write_data));
bytes_sent_ = 0;
}
next_state_ = STATE_GREET_WRITE_COMPLETE;
int handshake_buf_len = buffer_.size() - bytes_sent_;
DCHECK_LT(0, handshake_buf_len);
handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
std::memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
handshake_buf_len);
return transport_->Write(handshake_buf_.get(), handshake_buf_len,
io_callback_, traffic_annotation_);
}
int Socks5ServerSocket::DoGreetWriteComplete(int result) {
if (result < 0)
return result;
bytes_sent_ += result;
if (bytes_sent_ == buffer_.size()) {
buffer_.clear();
if (auth_method_ == kAuthMethodNone) {
next_state_ = STATE_HANDSHAKE_READ;
} else if (auth_method_ == kAuthMethodUserPass) {
next_state_ = STATE_AUTH_READ;
} else {
net_log_.AddEvent(NetLogEventType::SOCKS_NO_ACCEPTABLE_AUTH);
return ERR_SOCKS_CONNECTION_FAILED;
}
} else {
next_state_ = STATE_GREET_WRITE;
}
return OK;
}
int Socks5ServerSocket::DoAuthRead() {
next_state_ = STATE_AUTH_READ_COMPLETE;
if (buffer_.empty()) {
read_header_size_ = kAuthReadHeaderSize;
}
int handshake_buf_len = read_header_size_ - buffer_.size();
DCHECK_LT(0, handshake_buf_len);
handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
return transport_->Read(handshake_buf_.get(), handshake_buf_len,
io_callback_);
}
int Socks5ServerSocket::DoAuthReadComplete(int result) {
if (result < 0)
return result;
if (result == 0) {
return ERR_SOCKS_CONNECTION_FAILED;
}
buffer_.append(handshake_buf_->data(), result);
// When the first few bytes are read, check how many more are required
// and accordingly increase them
if (buffer_.size() == kAuthReadHeaderSize) {
if (buffer_[0] != kSubnegotiationVersion) {
net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
"version", buffer_[0]);
return ERR_SOCKS_CONNECTION_FAILED;
}
int username_len = buffer_[1];
read_header_size_ += username_len + 1;
next_state_ = STATE_AUTH_READ;
return OK;
}
if (buffer_.size() == read_header_size_) {
int username_len = buffer_[1];
int password_len = buffer_[kAuthReadHeaderSize + username_len];
size_t password_offset = kAuthReadHeaderSize + username_len + 1;
if (buffer_.size() == password_offset && password_len != 0) {
read_header_size_ += password_len;
next_state_ = STATE_AUTH_READ;
return OK;
}
if (buffer_.compare(kAuthReadHeaderSize, username_len, user_) == 0 &&
buffer_.compare(password_offset, password_len, pass_) == 0) {
auth_status_ = kAuthStatusSuccess;
} else {
auth_status_ = kAuthStatusFailure;
}
buffer_.clear();
next_state_ = STATE_AUTH_WRITE;
return OK;
}
next_state_ = STATE_AUTH_READ;
return OK;
}
int Socks5ServerSocket::DoAuthWrite() {
if (buffer_.empty()) {
const char write_data[] = {kSubnegotiationVersion, auth_status_};
buffer_ = std::string(write_data, std::size(write_data));
bytes_sent_ = 0;
}
next_state_ = STATE_AUTH_WRITE_COMPLETE;
int handshake_buf_len = buffer_.size() - bytes_sent_;
DCHECK_LT(0, handshake_buf_len);
handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
std::memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
handshake_buf_len);
return transport_->Write(handshake_buf_.get(), handshake_buf_len,
io_callback_, traffic_annotation_);
}
int Socks5ServerSocket::DoAuthWriteComplete(int result) {
if (result < 0)
return result;
bytes_sent_ += result;
if (bytes_sent_ == buffer_.size()) {
buffer_.clear();
if (auth_status_ == kAuthStatusSuccess) {
next_state_ = STATE_HANDSHAKE_READ;
} else {
return ERR_SOCKS_CONNECTION_FAILED;
}
} else {
next_state_ = STATE_AUTH_WRITE;
}
return OK;
}
int Socks5ServerSocket::DoHandshakeRead() {
next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
if (buffer_.empty()) {
read_header_size_ = kReadHeaderSize;
}
int handshake_buf_len = read_header_size_ - buffer_.size();
DCHECK_LT(0, handshake_buf_len);
handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
return transport_->Read(handshake_buf_.get(), handshake_buf_len,
io_callback_);
}
int Socks5ServerSocket::DoHandshakeReadComplete(int result) {
if (result < 0)
return result;
// The underlying socket closed unexpectedly.
if (result == 0) {
net_log_.AddEvent(
NetLogEventType::SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE);
return ERR_SOCKS_CONNECTION_FAILED;
}
buffer_.append(handshake_buf_->data(), result);
// When the first few bytes are read, check how many more are required
// and accordingly increase them
if (buffer_.size() == kReadHeaderSize) {
if (buffer_[0] != kSOCKS5Version || buffer_[2] != kSOCKS5Reserved) {
net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
"version", buffer_[0]);
return ERR_SOCKS_CONNECTION_FAILED;
}
SocksCommandType command = static_cast<SocksCommandType>(buffer_[1]);
if (command == kCommandConnect) {
// The proxy replies with success immediately without first connecting
// to the requested endpoint.
reply_ = kReplySuccess;
} else if (command == kCommandBind || command == kCommandUDPAssociate) {
reply_ = kReplyCommandNotSupported;
} else {
net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_COMMAND,
"commmand", buffer_[1]);
return ERR_SOCKS_CONNECTION_FAILED;
}
// We check the type of IP/Domain the server returns and accordingly
// increase the size of the request. For domains, we need to read the
// size of the domain, so the initial request size is upto the domain
// size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is
// read, we substract 1 byte from the additional request size.
address_type_ = static_cast<SocksEndPointAddressType>(buffer_[3]);
if (address_type_ == kEndPointDomain) {
address_size_ = static_cast<uint8_t>(buffer_[4]);
if (address_size_ == 0) {
net_log_.AddEvent(NetLogEventType::SOCKS_ZERO_LENGTH_DOMAIN);
return ERR_SOCKS_CONNECTION_FAILED;
}
} else if (address_type_ == kEndPointResolvedIPv4) {
address_size_ = sizeof(struct in_addr);
--read_header_size_;
} else if (address_type_ == kEndPointResolvedIPv6) {
address_size_ = sizeof(struct in6_addr);
--read_header_size_;
} else {
// Aborts connection on unspecified address type.
net_log_.AddEventWithIntParams(
NetLogEventType::SOCKS_UNKNOWN_ADDRESS_TYPE, "address_type",
buffer_[3]);
return ERR_SOCKS_CONNECTION_FAILED;
}
read_header_size_ += address_size_ + sizeof(uint16_t);
next_state_ = STATE_HANDSHAKE_READ;
return OK;
}
// When the final bytes are read, setup handshake.
if (buffer_.size() == read_header_size_) {
size_t port_start = read_header_size_ - sizeof(uint16_t);
uint16_t port_net;
std::memcpy(&port_net, &buffer_[port_start], sizeof(uint16_t));
uint16_t port_host = base::NetToHost16(port_net);
size_t address_start = port_start - address_size_;
if (address_type_ == kEndPointDomain) {
std::string domain(&buffer_[address_start], address_size_);
request_endpoint_ = HostPortPair(domain, port_host);
} else {
IPAddress ip_addr(
reinterpret_cast<const uint8_t*>(&buffer_[address_start]),
address_size_);
IPEndPoint endpoint(ip_addr, port_host);
request_endpoint_ = HostPortPair::FromIPEndPoint(endpoint);
}
buffer_.clear();
next_state_ = STATE_HANDSHAKE_WRITE;
return OK;
}
next_state_ = STATE_HANDSHAKE_READ;
return OK;
}
// Writes the SOCKS handshake data to the underlying socket connection.
int Socks5ServerSocket::DoHandshakeWrite() {
next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
if (buffer_.empty()) {
const char write_data[] = {
// clang-format off
kSOCKS5Version,
reply_,
kSOCKS5Reserved,
kEndPointResolvedIPv4,
0x00, 0x00, 0x00, 0x00, // BND.ADDR
0x00, 0x00, // BND.PORT
// clang-format on
};
buffer_ = std::string(write_data, std::size(write_data));
bytes_sent_ = 0;
}
int handshake_buf_len = buffer_.size() - bytes_sent_;
DCHECK_LT(0, handshake_buf_len);
handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
std::memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], handshake_buf_len);
return transport_->Write(handshake_buf_.get(), handshake_buf_len,
io_callback_, traffic_annotation_);
}
int Socks5ServerSocket::DoHandshakeWriteComplete(int result) {
if (result < 0)
return result;
// We ignore the case when result is 0, since the underlying Write
// may return spurious writes while waiting on the socket.
bytes_sent_ += result;
if (bytes_sent_ == buffer_.size()) {
buffer_.clear();
if (reply_ == kReplySuccess) {
completed_handshake_ = true;
next_state_ = STATE_NONE;
} else {
net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_SERVER_ERROR,
"error_code", reply_);
return ERR_SOCKS_CONNECTION_FAILED;
}
} else {
next_state_ = STATE_HANDSHAKE_WRITE;
}
return OK;
}
int Socks5ServerSocket::GetPeerAddress(IPEndPoint* address) const {
return transport_->GetPeerAddress(address);
}
int Socks5ServerSocket::GetLocalAddress(IPEndPoint* address) const {
return transport_->GetLocalAddress(address);
}
} // namespace net

View File

@ -0,0 +1,166 @@
// Copyright 2018 The Chromium Authors. All rights reserved.
// Copyright 2018 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_
#define NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include "base/memory/scoped_refptr.h"
#include "net/base/completion_once_callback.h"
#include "net/base/completion_repeating_callback.h"
#include "net/base/host_port_pair.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/log/net_log_with_source.h"
#include "net/socket/connection_attempts.h"
#include "net/socket/next_proto.h"
#include "net/socket/stream_socket.h"
#include "net/ssl/ssl_info.h"
namespace net {
struct NetworkTrafficAnnotationTag;
// This StreamSocket is used to setup a SOCKSv5 handshake with a socks client.
// Currently no SOCKSv5 authentication is supported.
class Socks5ServerSocket : public StreamSocket {
public:
Socks5ServerSocket(std::unique_ptr<StreamSocket> transport_socket,
const std::string& user,
const std::string& pass,
const NetworkTrafficAnnotationTag& traffic_annotation);
// On destruction Disconnect() is called.
~Socks5ServerSocket() override;
Socks5ServerSocket(const Socks5ServerSocket&) = delete;
Socks5ServerSocket& operator=(const Socks5ServerSocket&) = delete;
const HostPortPair& request_endpoint() const;
// StreamSocket implementation.
// Does the SOCKS handshake and completes the protocol.
int Connect(CompletionOnceCallback callback) override;
void Disconnect() override;
bool IsConnected() const override;
bool IsConnectedAndIdle() const override;
const NetLogWithSource& NetLog() const override;
bool WasEverUsed() const override;
bool WasAlpnNegotiated() const override;
NextProto GetNegotiatedProtocol() const override;
bool GetSSLInfo(SSLInfo* ssl_info) override;
int64_t GetTotalReceivedBytes() const override;
void ApplySocketTag(const SocketTag& tag) override;
// Socket implementation.
int Read(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback) override;
int Write(IOBuffer* buf,
int buf_len,
CompletionOnceCallback callback,
const NetworkTrafficAnnotationTag& traffic_annotation) override;
int SetReceiveBufferSize(int32_t size) override;
int SetSendBufferSize(int32_t size) override;
int GetPeerAddress(IPEndPoint* address) const override;
int GetLocalAddress(IPEndPoint* address) const override;
private:
enum State {
STATE_GREET_READ,
STATE_GREET_READ_COMPLETE,
STATE_GREET_WRITE,
STATE_GREET_WRITE_COMPLETE,
STATE_AUTH_READ,
STATE_AUTH_READ_COMPLETE,
STATE_AUTH_WRITE,
STATE_AUTH_WRITE_COMPLETE,
STATE_HANDSHAKE_WRITE,
STATE_HANDSHAKE_WRITE_COMPLETE,
STATE_HANDSHAKE_READ,
STATE_HANDSHAKE_READ_COMPLETE,
STATE_NONE,
};
// Addressing type that can be specified in requests or responses.
enum SocksEndPointAddressType {
kEndPointDomain = 0x03,
kEndPointResolvedIPv4 = 0x01,
kEndPointResolvedIPv6 = 0x04,
};
void DoCallback(int result);
void OnIOComplete(int result);
void OnReadWriteComplete(CompletionOnceCallback callback, int result);
int DoLoop(int last_io_result);
int DoGreetRead();
int DoGreetReadComplete(int result);
int DoGreetWrite();
int DoGreetWriteComplete(int result);
int DoAuthRead();
int DoAuthReadComplete(int result);
int DoAuthWrite();
int DoAuthWriteComplete(int result);
int DoHandshakeRead();
int DoHandshakeReadComplete(int result);
int DoHandshakeWrite();
int DoHandshakeWriteComplete(int result);
CompletionRepeatingCallback io_callback_;
// Stores the underlying socket.
std::unique_ptr<StreamSocket> transport_;
State next_state_;
// Stores the callback to the layer above, called on completing Connect().
CompletionOnceCallback user_callback_;
// This IOBuffer is used by the class to read and write
// SOCKS handshake data. The length contains the expected size to
// read or write.
scoped_refptr<IOBuffer> handshake_buf_;
// While writing, this buffer stores the complete write handshake data.
// While reading, it stores the handshake information received so far.
std::string buffer_;
// This becomes true when the SOCKS handshake has completed and the
// overlying connection is free to communicate.
bool completed_handshake_;
// Contains the bytes sent by the SOCKS handshake.
size_t bytes_sent_;
size_t read_header_size_;
bool was_ever_used_;
SocksEndPointAddressType address_type_;
int address_size_;
std::string user_;
std::string pass_;
char auth_method_;
char auth_status_;
char reply_;
HostPortPair request_endpoint_;
NetLogWithSource net_log_;
// Traffic annotation for socket control.
const NetworkTrafficAnnotationTag& traffic_annotation_;
};
} // namespace net
#endif // NET_TOOLS_NAIVE_SOCKS5_SERVER_SOCKET_H_