Support SOCKS proxy authentication

This commit is contained in:
klzgrad 2021-04-19 22:34:08 +08:00
parent 404aab75f6
commit 67e2c2f32a
5 changed files with 185 additions and 37 deletions

View File

@ -29,12 +29,16 @@ namespace net {
NaiveProxy::NaiveProxy(std::unique_ptr<ServerSocket> listen_socket, NaiveProxy::NaiveProxy(std::unique_ptr<ServerSocket> listen_socket,
ClientProtocol protocol, ClientProtocol protocol,
const std::string& listen_user,
const std::string& listen_pass,
int concurrency, int concurrency,
RedirectResolver* resolver, RedirectResolver* resolver,
HttpNetworkSession* session, HttpNetworkSession* session,
const NetworkTrafficAnnotationTag& traffic_annotation) const NetworkTrafficAnnotationTag& traffic_annotation)
: listen_socket_(std::move(listen_socket)), : listen_socket_(std::move(listen_socket)),
protocol_(protocol), protocol_(protocol),
listen_user_(listen_user),
listen_pass_(listen_pass),
concurrency_(std::min(4, std::max(1, concurrency))), concurrency_(std::min(4, std::max(1, concurrency))),
resolver_(resolver), resolver_(resolver),
session_(session), session_(session),
@ -108,6 +112,7 @@ void NaiveProxy::DoConnect() {
if (protocol_ == ClientProtocol::kSocks5) { if (protocol_ == ClientProtocol::kSocks5) {
socket = std::make_unique<Socks5ServerSocket>(std::move(accepted_socket_), socket = std::make_unique<Socks5ServerSocket>(std::move(accepted_socket_),
listen_user_, listen_pass_,
traffic_annotation_); traffic_annotation_);
} else if (protocol_ == ClientProtocol::kHttp) { } else if (protocol_ == ClientProtocol::kHttp) {
socket = std::make_unique<HttpProxySocket>(std::move(accepted_socket_), socket = std::make_unique<HttpProxySocket>(std::move(accepted_socket_),

View File

@ -34,6 +34,8 @@ class NaiveProxy {
public: public:
NaiveProxy(std::unique_ptr<ServerSocket> server_socket, NaiveProxy(std::unique_ptr<ServerSocket> server_socket,
ClientProtocol protocol, ClientProtocol protocol,
const std::string& listen_user,
const std::string& listen_pass,
int concurrency, int concurrency,
RedirectResolver* resolver, RedirectResolver* resolver,
HttpNetworkSession* session, HttpNetworkSession* session,
@ -59,6 +61,8 @@ class NaiveProxy {
std::unique_ptr<ServerSocket> listen_socket_; std::unique_ptr<ServerSocket> listen_socket_;
ClientProtocol protocol_; ClientProtocol protocol_;
std::string listen_user_;
std::string listen_pass_;
int concurrency_; int concurrency_;
ProxyInfo proxy_info_; ProxyInfo proxy_info_;
SSLConfig server_ssl_config_; SSLConfig server_ssl_config_;

View File

@ -19,6 +19,7 @@
#include "base/macros.h" #include "base/macros.h"
#include "base/rand_util.h" #include "base/rand_util.h"
#include "base/run_loop.h" #include "base/run_loop.h"
#include "base/strings/escape.h"
#include "base/strings/string_number_conversions.h" #include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h" #include "base/strings/stringprintf.h"
#include "base/strings/utf_string_conversions.h" #include "base/strings/utf_string_conversions.h"
@ -95,6 +96,8 @@ struct CommandLine {
struct Params { struct Params {
net::ClientProtocol protocol; net::ClientProtocol protocol;
std::string listen_user;
std::string listen_pass;
std::string listen_addr; std::string listen_addr;
int listen_port; int listen_port;
int concurrency; int concurrency;
@ -231,7 +234,8 @@ bool ParseCommandLine(const CommandLine& cmdline, Params* params) {
params->protocol = net::ClientProtocol::kSocks5; params->protocol = net::ClientProtocol::kSocks5;
params->listen_addr = "0.0.0.0"; params->listen_addr = "0.0.0.0";
params->listen_port = 1080; params->listen_port = 1080;
url::AddStandardScheme("socks", url::SCHEME_WITH_HOST_AND_PORT); url::AddStandardScheme("socks",
url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION);
url::AddStandardScheme("redir", url::SCHEME_WITH_HOST_AND_PORT); url::AddStandardScheme("redir", url::SCHEME_WITH_HOST_AND_PORT);
if (!cmdline.listen.empty()) { if (!cmdline.listen.empty()) {
GURL url(cmdline.listen); GURL url(cmdline.listen);
@ -253,6 +257,12 @@ bool ParseCommandLine(const CommandLine& cmdline, Params* params) {
std::cerr << "Invalid scheme in --listen" << std::endl; std::cerr << "Invalid scheme in --listen" << std::endl;
return false; return false;
} }
if (!url.username().empty()) {
params->listen_user = base::UnescapeBinaryURLComponent(url.username());
}
if (!url.password().empty()) {
params->listen_pass = base::UnescapeBinaryURLComponent(url.password());
}
if (!url.host().empty()) { if (!url.host().empty()) {
params->listen_addr = url.host(); params->listen_addr = url.host();
} }
@ -569,6 +579,7 @@ int main(int argc, char* argv[]) {
} }
net::NaiveProxy naive_proxy(std::move(listen_socket), params.protocol, net::NaiveProxy naive_proxy(std::move(listen_socket), params.protocol,
params.listen_user, params.listen_pass,
params.concurrency, resolver.get(), session, params.concurrency, resolver.get(), session,
kTrafficAnnotation); kTrafficAnnotation);

View File

@ -28,11 +28,16 @@ enum SocksCommandType {
}; };
static constexpr unsigned int kGreetReadHeaderSize = 2; static constexpr unsigned int kGreetReadHeaderSize = 2;
static constexpr unsigned int kAuthReadHeaderSize = 2;
static constexpr unsigned int kReadHeaderSize = 5; static constexpr unsigned int kReadHeaderSize = 5;
static constexpr char kSOCKS5Version = '\x05'; static constexpr char kSOCKS5Version = '\x05';
static constexpr char kSOCKS5Reserved = '\x00'; static constexpr char kSOCKS5Reserved = '\x00';
static constexpr char kAuthMethodNone = '\x00'; static constexpr char kAuthMethodNone = '\x00';
static constexpr char kAuthMethodUserPass = '\x02';
static constexpr char kAuthMethodNoAcceptable = '\xff'; static constexpr char kAuthMethodNoAcceptable = '\xff';
static constexpr char kSubnegotiationVersion = '\x01';
static constexpr char kAuthStatusSuccess = '\x00';
static constexpr char kAuthStatusFailure = '\xff';
static constexpr char kReplySuccess = '\x00'; static constexpr char kReplySuccess = '\x00';
static constexpr char kReplyCommandNotSupported = '\x07'; static constexpr char kReplyCommandNotSupported = '\x07';
@ -41,17 +46,18 @@ static_assert(sizeof(struct in6_addr) == 16, "incorrect system size of IPv6");
Socks5ServerSocket::Socks5ServerSocket( Socks5ServerSocket::Socks5ServerSocket(
std::unique_ptr<StreamSocket> transport_socket, std::unique_ptr<StreamSocket> transport_socket,
const std::string& user,
const std::string& pass,
const NetworkTrafficAnnotationTag& traffic_annotation) const NetworkTrafficAnnotationTag& traffic_annotation)
: io_callback_(base::BindRepeating(&Socks5ServerSocket::OnIOComplete, : io_callback_(base::BindRepeating(&Socks5ServerSocket::OnIOComplete,
base::Unretained(this))), base::Unretained(this))),
transport_(std::move(transport_socket)), transport_(std::move(transport_socket)),
next_state_(STATE_NONE), next_state_(STATE_NONE),
completed_handshake_(false), completed_handshake_(false),
bytes_received_(0),
bytes_sent_(0), bytes_sent_(0),
greet_read_header_size_(kGreetReadHeaderSize),
read_header_size_(kReadHeaderSize),
was_ever_used_(false), was_ever_used_(false),
user_(user),
pass_(pass),
net_log_(transport_->NetLog()), net_log_(transport_->NetLog()),
traffic_annotation_(traffic_annotation) {} traffic_annotation_(traffic_annotation) {}
@ -252,6 +258,20 @@ int Socks5ServerSocket::DoLoop(int last_io_result) {
net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_WRITE, net_log_.EndEventWithNetErrorCode(NetLogEventType::SOCKS5_GREET_WRITE,
rv); rv);
break; break;
case STATE_AUTH_READ:
DCHECK_EQ(OK, rv);
rv = DoAuthRead();
break;
case STATE_AUTH_READ_COMPLETE:
rv = DoAuthReadComplete(rv);
break;
case STATE_AUTH_WRITE:
DCHECK_EQ(OK, rv);
rv = DoAuthWrite();
break;
case STATE_AUTH_WRITE_COMPLETE:
rv = DoAuthWriteComplete(rv);
break;
case STATE_HANDSHAKE_READ: case STATE_HANDSHAKE_READ:
DCHECK_EQ(OK, rv); DCHECK_EQ(OK, rv);
net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_READ); net_log_.BeginEvent(NetLogEventType::SOCKS5_HANDSHAKE_READ);
@ -285,11 +305,10 @@ int Socks5ServerSocket::DoGreetRead() {
next_state_ = STATE_GREET_READ_COMPLETE; next_state_ = STATE_GREET_READ_COMPLETE;
if (buffer_.empty()) { if (buffer_.empty()) {
DCHECK_EQ(0U, bytes_received_); read_header_size_ = kGreetReadHeaderSize;
DCHECK_EQ(kGreetReadHeaderSize, greet_read_header_size_);
} }
int handshake_buf_len = greet_read_header_size_ - bytes_received_; int handshake_buf_len = read_header_size_ - buffer_.size();
DCHECK_LT(0, handshake_buf_len); DCHECK_LT(0, handshake_buf_len);
handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len); handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
return transport_->Read(handshake_buf_.get(), handshake_buf_len, return transport_->Read(handshake_buf_.get(), handshake_buf_len,
@ -306,32 +325,37 @@ int Socks5ServerSocket::DoGreetReadComplete(int result) {
return ERR_SOCKS_CONNECTION_FAILED; return ERR_SOCKS_CONNECTION_FAILED;
} }
bytes_received_ += result;
buffer_.append(handshake_buf_->data(), result); buffer_.append(handshake_buf_->data(), result);
// When the first few bytes are read, check how many more are required // When the first few bytes are read, check how many more are required
// and accordingly increase them // and accordingly increase them
if (bytes_received_ == kGreetReadHeaderSize) { if (buffer_.size() == kGreetReadHeaderSize) {
if (buffer_[0] != kSOCKS5Version) { if (buffer_[0] != kSOCKS5Version) {
net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION, net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
"version", buffer_[0]); "version", buffer_[0]);
return ERR_SOCKS_CONNECTION_FAILED; return ERR_SOCKS_CONNECTION_FAILED;
} }
if (buffer_[1] == 0) { int nmethods = buffer_[1];
if (nmethods == 0) {
net_log_.AddEvent(NetLogEventType::SOCKS_NO_REQUESTED_AUTH); net_log_.AddEvent(NetLogEventType::SOCKS_NO_REQUESTED_AUTH);
return ERR_SOCKS_CONNECTION_FAILED; return ERR_SOCKS_CONNECTION_FAILED;
} }
greet_read_header_size_ += buffer_[1]; read_header_size_ += nmethods;
next_state_ = STATE_GREET_READ; next_state_ = STATE_GREET_READ;
return OK; return OK;
} }
if (bytes_received_ == greet_read_header_size_) { if (buffer_.size() == read_header_size_) {
void* match = std::memchr(&buffer_[kGreetReadHeaderSize], kAuthMethodNone, int nmethods = buffer_[1];
greet_read_header_size_ - kGreetReadHeaderSize); char expected_method = kAuthMethodNone;
if (!user_.empty() || !pass_.empty()) {
expected_method = kAuthMethodUserPass;
}
void* match =
std::memchr(&buffer_[kGreetReadHeaderSize], expected_method, nmethods);
if (match) { if (match) {
auth_method_ = kAuthMethodNone; auth_method_ = expected_method;
} else { } else {
auth_method_ = kAuthMethodNoAcceptable; auth_method_ = kAuthMethodNoAcceptable;
} }
@ -368,9 +392,10 @@ int Socks5ServerSocket::DoGreetWriteComplete(int result) {
bytes_sent_ += result; bytes_sent_ += result;
if (bytes_sent_ == buffer_.size()) { if (bytes_sent_ == buffer_.size()) {
buffer_.clear(); buffer_.clear();
bytes_received_ = 0; if (auth_method_ == kAuthMethodNone) {
if (auth_method_ != kAuthMethodNoAcceptable) {
next_state_ = STATE_HANDSHAKE_READ; next_state_ = STATE_HANDSHAKE_READ;
} else if (auth_method_ == kAuthMethodUserPass) {
next_state_ = STATE_AUTH_READ;
} else { } else {
net_log_.AddEvent(NetLogEventType::SOCKS_NO_ACCEPTABLE_AUTH); net_log_.AddEvent(NetLogEventType::SOCKS_NO_ACCEPTABLE_AUTH);
return ERR_SOCKS_CONNECTION_FAILED; return ERR_SOCKS_CONNECTION_FAILED;
@ -381,15 +406,112 @@ int Socks5ServerSocket::DoGreetWriteComplete(int result) {
return OK; return OK;
} }
int Socks5ServerSocket::DoAuthRead() {
next_state_ = STATE_AUTH_READ_COMPLETE;
if (buffer_.empty()) {
read_header_size_ = kAuthReadHeaderSize;
}
int handshake_buf_len = read_header_size_ - buffer_.size();
DCHECK_LT(0, handshake_buf_len);
handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
return transport_->Read(handshake_buf_.get(), handshake_buf_len,
io_callback_);
}
int Socks5ServerSocket::DoAuthReadComplete(int result) {
if (result < 0)
return result;
if (result == 0) {
return ERR_SOCKS_CONNECTION_FAILED;
}
buffer_.append(handshake_buf_->data(), result);
// When the first few bytes are read, check how many more are required
// and accordingly increase them
if (buffer_.size() == kAuthReadHeaderSize) {
if (buffer_[0] != kSubnegotiationVersion) {
net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
"version", buffer_[0]);
return ERR_SOCKS_CONNECTION_FAILED;
}
int username_len = buffer_[1];
read_header_size_ += username_len + 1;
next_state_ = STATE_AUTH_READ;
return OK;
}
if (buffer_.size() == read_header_size_) {
int username_len = buffer_[1];
int password_len = buffer_[kAuthReadHeaderSize + username_len];
size_t password_offset = kAuthReadHeaderSize + username_len + 1;
if (buffer_.size() == password_offset && password_len != 0) {
read_header_size_ += password_len;
next_state_ = STATE_AUTH_READ;
return OK;
}
if (buffer_.compare(kAuthReadHeaderSize, username_len, user_) == 0 &&
buffer_.compare(password_offset, password_len, pass_) == 0) {
auth_status_ = kAuthStatusSuccess;
} else {
auth_status_ = kAuthStatusFailure;
}
buffer_.clear();
next_state_ = STATE_AUTH_WRITE;
return OK;
}
next_state_ = STATE_AUTH_READ;
return OK;
}
int Socks5ServerSocket::DoAuthWrite() {
if (buffer_.empty()) {
const char write_data[] = {kSubnegotiationVersion, auth_status_};
buffer_ = std::string(write_data, base::size(write_data));
bytes_sent_ = 0;
}
next_state_ = STATE_AUTH_WRITE_COMPLETE;
int handshake_buf_len = buffer_.size() - bytes_sent_;
DCHECK_LT(0, handshake_buf_len);
handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
std::memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
handshake_buf_len);
return transport_->Write(handshake_buf_.get(), handshake_buf_len,
io_callback_, traffic_annotation_);
}
int Socks5ServerSocket::DoAuthWriteComplete(int result) {
if (result < 0)
return result;
bytes_sent_ += result;
if (bytes_sent_ == buffer_.size()) {
buffer_.clear();
if (auth_status_ == kAuthStatusSuccess) {
next_state_ = STATE_HANDSHAKE_READ;
} else {
return ERR_SOCKS_CONNECTION_FAILED;
}
} else {
next_state_ = STATE_AUTH_WRITE;
}
return OK;
}
int Socks5ServerSocket::DoHandshakeRead() { int Socks5ServerSocket::DoHandshakeRead() {
next_state_ = STATE_HANDSHAKE_READ_COMPLETE; next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
if (buffer_.empty()) { if (buffer_.empty()) {
DCHECK_EQ(0U, bytes_received_); read_header_size_ = kReadHeaderSize;
DCHECK_EQ(kReadHeaderSize, read_header_size_);
} }
int handshake_buf_len = read_header_size_ - bytes_received_; int handshake_buf_len = read_header_size_ - buffer_.size();
DCHECK_LT(0, handshake_buf_len); DCHECK_LT(0, handshake_buf_len);
handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len); handshake_buf_ = base::MakeRefCounted<IOBuffer>(handshake_buf_len);
return transport_->Read(handshake_buf_.get(), handshake_buf_len, return transport_->Read(handshake_buf_.get(), handshake_buf_len,
@ -408,11 +530,10 @@ int Socks5ServerSocket::DoHandshakeReadComplete(int result) {
} }
buffer_.append(handshake_buf_->data(), result); buffer_.append(handshake_buf_->data(), result);
bytes_received_ += result;
// When the first few bytes are read, check how many more are required // When the first few bytes are read, check how many more are required
// and accordingly increase them // and accordingly increase them
if (bytes_received_ == kReadHeaderSize) { if (buffer_.size() == kReadHeaderSize) {
if (buffer_[0] != kSOCKS5Version || buffer_[2] != kSOCKS5Reserved) { if (buffer_[0] != kSOCKS5Version || buffer_[2] != kSOCKS5Reserved) {
net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION, net_log_.AddEventWithIntParams(NetLogEventType::SOCKS_UNEXPECTED_VERSION,
"version", buffer_[0]); "version", buffer_[0]);
@ -463,7 +584,7 @@ int Socks5ServerSocket::DoHandshakeReadComplete(int result) {
} }
// When the final bytes are read, setup handshake. // When the final bytes are read, setup handshake.
if (bytes_received_ == read_header_size_) { if (buffer_.size() == read_header_size_) {
size_t port_start = read_header_size_ - sizeof(uint16_t); size_t port_start = read_header_size_ - sizeof(uint16_t);
uint16_t port_net; uint16_t port_net;
std::memcpy(&port_net, &buffer_[port_start], sizeof(uint16_t)); std::memcpy(&port_net, &buffer_[port_start], sizeof(uint16_t));
@ -495,16 +616,14 @@ int Socks5ServerSocket::DoHandshakeWrite() {
if (buffer_.empty()) { if (buffer_.empty()) {
const char write_data[] = { const char write_data[] = {
// clang-format off
kSOCKS5Version, kSOCKS5Version,
reply_, reply_,
kSOCKS5Reserved, kSOCKS5Reserved,
kEndPointResolvedIPv4, kEndPointResolvedIPv4,
0x00, 0x00, 0x00, 0x00, 0x00, // BND.ADDR
0x00, 0x00, 0x00, // BND.PORT
0x00, // clang-format on
0x00, // BND.ADDR
0x00,
0x00, // BND.PORT
}; };
buffer_ = std::string(write_data, base::size(write_data)); buffer_ = std::string(write_data, base::size(write_data));
bytes_sent_ = 0; bytes_sent_ = 0;
@ -536,10 +655,8 @@ int Socks5ServerSocket::DoHandshakeWriteComplete(int result) {
"error_code", reply_); "error_code", reply_);
return ERR_SOCKS_CONNECTION_FAILED; return ERR_SOCKS_CONNECTION_FAILED;
} }
} else if (bytes_sent_ < buffer_.size()) {
next_state_ = STATE_HANDSHAKE_WRITE;
} else { } else {
NOTREACHED(); next_state_ = STATE_HANDSHAKE_WRITE;
} }
return OK; return OK;

View File

@ -32,6 +32,8 @@ struct NetworkTrafficAnnotationTag;
class Socks5ServerSocket : public StreamSocket { class Socks5ServerSocket : public StreamSocket {
public: public:
Socks5ServerSocket(std::unique_ptr<StreamSocket> transport_socket, Socks5ServerSocket(std::unique_ptr<StreamSocket> transport_socket,
const std::string& user,
const std::string& pass,
const NetworkTrafficAnnotationTag& traffic_annotation); const NetworkTrafficAnnotationTag& traffic_annotation);
// On destruction Disconnect() is called. // On destruction Disconnect() is called.
@ -78,6 +80,10 @@ class Socks5ServerSocket : public StreamSocket {
STATE_GREET_READ_COMPLETE, STATE_GREET_READ_COMPLETE,
STATE_GREET_WRITE, STATE_GREET_WRITE,
STATE_GREET_WRITE_COMPLETE, STATE_GREET_WRITE_COMPLETE,
STATE_AUTH_READ,
STATE_AUTH_READ_COMPLETE,
STATE_AUTH_WRITE,
STATE_AUTH_WRITE_COMPLETE,
STATE_HANDSHAKE_WRITE, STATE_HANDSHAKE_WRITE,
STATE_HANDSHAKE_WRITE_COMPLETE, STATE_HANDSHAKE_WRITE_COMPLETE,
STATE_HANDSHAKE_READ, STATE_HANDSHAKE_READ,
@ -97,10 +103,14 @@ class Socks5ServerSocket : public StreamSocket {
void OnReadWriteComplete(CompletionOnceCallback callback, int result); void OnReadWriteComplete(CompletionOnceCallback callback, int result);
int DoLoop(int last_io_result); int DoLoop(int last_io_result);
int DoGreetWrite();
int DoGreetWriteComplete(int result);
int DoGreetRead(); int DoGreetRead();
int DoGreetReadComplete(int result); int DoGreetReadComplete(int result);
int DoGreetWrite();
int DoGreetWriteComplete(int result);
int DoAuthRead();
int DoAuthReadComplete(int result);
int DoAuthWrite();
int DoAuthWriteComplete(int result);
int DoHandshakeRead(); int DoHandshakeRead();
int DoHandshakeReadComplete(int result); int DoHandshakeReadComplete(int result);
int DoHandshakeWrite(); int DoHandshakeWrite();
@ -129,11 +139,9 @@ class Socks5ServerSocket : public StreamSocket {
// overlying connection is free to communicate. // overlying connection is free to communicate.
bool completed_handshake_; bool completed_handshake_;
// These contain the bytes received / sent by the SOCKS handshake. // Contains the bytes sent by the SOCKS handshake.
size_t bytes_received_;
size_t bytes_sent_; size_t bytes_sent_;
size_t greet_read_header_size_;
size_t read_header_size_; size_t read_header_size_;
bool was_ever_used_; bool was_ever_used_;
@ -141,7 +149,10 @@ class Socks5ServerSocket : public StreamSocket {
SocksEndPointAddressType address_type_; SocksEndPointAddressType address_type_;
int address_size_; int address_size_;
std::string user_;
std::string pass_;
char auth_method_; char auth_method_;
char auth_status_;
char reply_; char reply_;
HostPortPair request_endpoint_; HostPortPair request_endpoint_;