Refactor config parsing

This commit is contained in:
klzgrad 2024-05-04 21:10:11 +08:00
parent c975384d22
commit d434781c6a
6 changed files with 459 additions and 362 deletions

View File

@ -1744,6 +1744,12 @@ static_library("preload_decoder") {
executable("naive") {
sources = [
"tools/naive/http_proxy_server_socket.cc",
"tools/naive/http_proxy_server_socket.h",
"tools/naive/naive_command_line.cc",
"tools/naive/naive_command_line.h",
"tools/naive/naive_config.cc",
"tools/naive/naive_config.h",
"tools/naive/naive_connection.cc",
"tools/naive/naive_connection.h",
"tools/naive/naive_padding_framer.cc",
@ -1752,15 +1758,13 @@ executable("naive") {
"tools/naive/naive_padding_socket.h",
"tools/naive/naive_protocol.cc",
"tools/naive/naive_protocol.h",
"tools/naive/naive_proxy_bin.cc",
"tools/naive/naive_proxy_delegate.cc",
"tools/naive/naive_proxy_delegate.h",
"tools/naive/naive_proxy.cc",
"tools/naive/naive_proxy.h",
"tools/naive/naive_proxy_bin.cc",
"tools/naive/naive_proxy_delegate.h",
"tools/naive/naive_proxy_delegate.cc",
"tools/naive/http_proxy_server_socket.cc",
"tools/naive/http_proxy_server_socket.h",
"tools/naive/redirect_resolver.h",
"tools/naive/redirect_resolver.cc",
"tools/naive/redirect_resolver.h",
"tools/naive/socks5_server_socket.cc",
"tools/naive/socks5_server_socket.h",
]

View File

@ -0,0 +1,61 @@
// Copyright 2024 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/tools/naive/naive_command_line.h"
#include <memory>
#include <utility>
#include "base/strings/utf_string_conversions.h"
DuplicateSwitchCollector::DuplicateSwitchCollector() = default;
DuplicateSwitchCollector::~DuplicateSwitchCollector() = default;
void DuplicateSwitchCollector::ResolveDuplicate(
std::string_view key,
base::CommandLine::StringPieceType new_value,
base::CommandLine::StringType& out_value) {
out_value = new_value;
values_by_key_[std::string(key)].push_back(
base::CommandLine::StringType(new_value));
}
const std::vector<base::CommandLine::StringType>&
DuplicateSwitchCollector::GetValuesByKey(std::string_view key) {
return values_by_key_[std::string(key)];
}
namespace {
DuplicateSwitchCollector* g_duplicate_switch_collector;
}
void DuplicateSwitchCollector::InitInstance() {
auto new_duplicate_switch_collector =
std::make_unique<DuplicateSwitchCollector>();
g_duplicate_switch_collector = new_duplicate_switch_collector.get();
base::CommandLine::SetDuplicateSwitchHandler(
std::move(new_duplicate_switch_collector));
}
DuplicateSwitchCollector& DuplicateSwitchCollector::GetInstance() {
CHECK(g_duplicate_switch_collector != nullptr);
return *g_duplicate_switch_collector;
}
base::Value::Dict GetSwitchesAsValue(const base::CommandLine& cmdline) {
base::Value::Dict dict;
for (const auto& [key, value] : cmdline.GetSwitches()) {
const std::vector<base::CommandLine::StringType>& values =
DuplicateSwitchCollector::GetInstance().GetValuesByKey(key);
if (values.size() > 1) {
base::Value::List list;
for (const base::CommandLine::StringType& v : values) {
list.Append(v);
}
dict.Set(key, std::move(list));
} else {
dict.Set(key, value);
}
}
return dict;
}

View File

@ -0,0 +1,36 @@
// Copyright 2024 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_TOOLS_NAIVE_NAIVE_COMMAND_LINE_H_
#define NET_TOOLS_NAIVE_NAIVE_COMMAND_LINE_H_
#include <map>
#include <string>
#include <string_view>
#include <vector>
#include "base/command_line.h"
#include "base/values.h"
class DuplicateSwitchCollector : public base::DuplicateSwitchHandler {
public:
DuplicateSwitchCollector();
~DuplicateSwitchCollector() override;
void ResolveDuplicate(std::string_view key,
base::CommandLine::StringPieceType new_value,
base::CommandLine::StringType& out_value) override;
const std::vector<base::CommandLine::StringType>& GetValuesByKey(
std::string_view key);
static void InitInstance();
static DuplicateSwitchCollector& GetInstance();
private:
std::map<std::string, std::vector<base::CommandLine::StringType>>
values_by_key_;
};
base::Value::Dict GetSwitchesAsValue(const base::CommandLine& cmdline);
#endif // NET_TOOLS_NAIVE_NAIVE_COMMAND_LINE_H_

View File

@ -0,0 +1,198 @@
// Copyright 2024 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/tools/naive/naive_config.h"
#include <iostream>
#include "base/strings/escape.h"
#include "base/strings/string_number_conversions.h"
#include "net/base/url_util.h"
#include "url/gurl.h"
namespace net {
NaiveListenConfig::NaiveListenConfig() = default;
NaiveListenConfig::NaiveListenConfig(const NaiveListenConfig&) = default;
NaiveListenConfig::~NaiveListenConfig() = default;
bool NaiveListenConfig::Parse(const std::string& str) {
GURL url(str);
if (url.scheme() == "socks") {
protocol = ClientProtocol::kSocks5;
} else if (url.scheme() == "http") {
protocol = ClientProtocol::kHttp;
} else if (url.scheme() == "redir") {
#if BUILDFLAG(IS_LINUX)
protocol = ClientProtocol::kRedir;
#else
std::cerr << "Redir protocol only supports Linux." << std::endl;
return false;
#endif
} else {
std::cerr << "Invalid scheme in " << str << std::endl;
return false;
}
if (!url.username().empty()) {
user = base::UnescapeBinaryURLComponent(url.username());
}
if (!url.password().empty()) {
pass = base::UnescapeBinaryURLComponent(url.password());
}
if (!url.host().empty()) {
addr = url.HostNoBrackets();
}
int effective_port = url.EffectiveIntPort();
if (effective_port == url::PORT_INVALID) {
std::cerr << "Invalid port in " << str << std::endl;
return false;
}
if (effective_port != url::PORT_UNSPECIFIED) {
port = effective_port;
}
return true;
}
NaiveConfig::NaiveConfig() = default;
NaiveConfig::NaiveConfig(const NaiveConfig&) = default;
NaiveConfig::~NaiveConfig() = default;
bool NaiveConfig::Parse(const base::Value::Dict& value) {
if (const base::Value* v = value.Find("listen")) {
listen.clear();
if (const std::string* str = v->GetIfString()) {
if (!listen.emplace_back().Parse(*str)) {
return false;
}
} else if (const base::Value::List* strs = v->GetIfList()) {
for (const auto& str_e : *strs) {
if (const std::string* s = str_e.GetIfString()) {
if (!listen.emplace_back().Parse(*s)) {
return false;
}
} else {
std::cerr << "Invalid listen element" << std::endl;
return false;
}
}
} else {
std::cerr << "Invalid listen" << std::endl;
return false;
}
}
if (const base::Value* v = value.Find("insecure-concurrency")) {
if (std::optional<int> i = v->GetIfInt()) {
insecure_concurrency = *i;
} else if (const std::string* str = v->GetIfString()) {
if (!base::StringToInt(*str, &insecure_concurrency)) {
std::cerr << "Invalid concurrency" << std::endl;
return false;
}
} else {
std::cerr << "Invalid concurrency" << std::endl;
return false;
}
if (insecure_concurrency < 1) {
std::cerr << "Invalid concurrency" << std::endl;
return false;
}
}
if (const base::Value* v = value.Find("extra-headers")) {
if (const std::string* str = v->GetIfString()) {
extra_headers.AddHeadersFromString(*str);
} else {
std::cerr << "Invalid extra-headers" << std::endl;
return false;
}
}
if (const base::Value* v = value.Find("proxy")) {
if (const std::string* str = v->GetIfString(); str && !str->empty()) {
GURL url(*str);
net::GetIdentityFromURL(url, &proxy_user, &proxy_pass);
GURL::Replacements remove_auth;
remove_auth.ClearUsername();
remove_auth.ClearPassword();
GURL url_no_auth = url.ReplaceComponents(remove_auth);
proxy_url = url_no_auth.GetWithEmptyPath().spec();
if (proxy_url.empty()) {
std::cerr << "Invalid proxy" << std::endl;
return false;
} else if (proxy_url.back() == '/') {
proxy_url.pop_back();
}
} else {
std::cerr << "Invalid proxy" << std::endl;
return false;
}
}
if (const base::Value* v = value.Find("host-resolver-rules")) {
if (const std::string* str = v->GetIfString()) {
host_resolver_rules = *str;
} else {
std::cerr << "Invalid host-resolver-rules" << std::endl;
return false;
}
}
if (const base::Value* v = value.Find("resolver-range")) {
if (const std::string* str = v->GetIfString(); str && !str->empty()) {
if (!net::ParseCIDRBlock(*str, &resolver_range, &resolver_prefix)) {
std::cerr << "Invalid resolver-range" << std::endl;
return false;
}
if (resolver_range.IsIPv6()) {
std::cerr << "IPv6 resolver range not supported" << std::endl;
return false;
}
} else {
std::cerr << "Invalid resolver-range" << std::endl;
return false;
}
}
if (const base::Value* v = value.Find("log")) {
if (const std::string* str = v->GetIfString()) {
if (!str->empty()) {
log.logging_dest = logging::LOG_TO_FILE;
log_file = base::FilePath::FromUTF8Unsafe(*str);
log.log_file_path = log_file.value().c_str();
} else {
log.logging_dest = logging::LOG_TO_STDERR;
}
} else {
std::cerr << "Invalid log" << std::endl;
return false;
}
}
if (const base::Value* v = value.Find("log-net-log")) {
if (const std::string* str = v->GetIfString(); str && !str->empty()) {
log_net_log = base::FilePath::FromUTF8Unsafe(*str);
} else {
std::cerr << "Invalid log-net-log" << std::endl;
return false;
}
}
if (const base::Value* v = value.Find("ssl-key-log-file")) {
if (const std::string* str = v->GetIfString(); str && !str->empty()) {
ssl_key_log_file = base::FilePath::FromUTF8Unsafe(*str);
} else {
std::cerr << "Invalid ssl-key-log-file" << std::endl;
return false;
}
}
return true;
}
} // namespace net

View File

@ -0,0 +1,64 @@
// Copyright 2024 klzgrad <kizdiv@gmail.com>. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_TOOLS_NAIVE_NAIVE_CONFIG_H_
#define NET_TOOLS_NAIVE_NAIVE_CONFIG_H_
#include <optional>
#include <string>
#include <vector>
#include "base/files/file_path.h"
#include "base/logging.h"
#include "base/values.h"
#include "net/base/ip_address.h"
#include "net/http/http_request_headers.h"
#include "net/tools/naive/naive_protocol.h"
namespace net {
struct NaiveListenConfig {
ClientProtocol protocol = ClientProtocol::kSocks5;
std::string user;
std::string pass;
std::string addr = "0.0.0.0";
int port = 1080;
NaiveListenConfig();
NaiveListenConfig(const NaiveListenConfig&);
~NaiveListenConfig();
bool Parse(const std::string& str);
};
struct NaiveConfig {
std::vector<NaiveListenConfig> listen = {NaiveListenConfig()};
int insecure_concurrency = 1;
HttpRequestHeaders extra_headers;
std::string proxy_url = "direct://";
std::u16string proxy_user;
std::u16string proxy_pass;
std::string host_resolver_rules;
IPAddress resolver_range = {100, 64, 0, 0};
size_t resolver_prefix = 10;
logging::LoggingSettings log = {.logging_dest = logging::LOG_NONE};
base::FilePath log_file;
base::FilePath log_net_log;
base::FilePath ssl_key_log_file;
NaiveConfig();
NaiveConfig(const NaiveConfig&);
~NaiveConfig();
bool Parse(const base::Value::Dict& value);
};
} // namespace net
#endif // NET_TOOLS_NAIVE_NAIVE_CONFIG_H_

View File

@ -7,6 +7,7 @@
#include <iostream>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include "base/allocator/allocator_check.h"
@ -60,8 +61,11 @@
#include "net/socket/ssl_client_socket.h"
#include "net/socket/tcp_server_socket.h"
#include "net/socket/udp_server_socket.h"
#include "net/ssl/ssl_config_service.h"
#include "net/ssl/ssl_key_logger_impl.h"
#include "net/third_party/quiche/src/quiche/quic/core/quic_versions.h"
#include "net/tools/naive/naive_command_line.h"
#include "net/tools/naive/naive_config.h"
#include "net/tools/naive/naive_protocol.h"
#include "net/tools/naive/naive_proxy.h"
#include "net/tools/naive/naive_proxy_delegate.h"
@ -87,42 +91,6 @@ constexpr int kExpectedMaxUsers = 8;
constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation =
net::DefineNetworkTrafficAnnotation("naive", "");
struct CommandLine {
std::vector<std::string> listens;
std::string proxy;
std::string concurrency;
std::string extra_headers;
std::string host_resolver_rules;
std::string resolver_range;
bool no_log;
base::FilePath log;
base::FilePath log_net_log;
base::FilePath ssl_key_log_file;
};
struct ListenParams {
net::ClientProtocol protocol;
std::string listen_user;
std::string listen_pass;
std::string listen_addr;
int listen_port;
};
struct Params {
std::vector<ListenParams> listens;
int concurrency;
net::HttpRequestHeaders extra_headers;
std::string proxy_url;
std::u16string proxy_user;
std::u16string proxy_pass;
std::string host_resolver_rules;
net::IPAddress resolver_range;
size_t resolver_prefix;
logging::LoggingSettings log_settings;
base::FilePath net_log_path;
base::FilePath ssl_key_path;
};
std::unique_ptr<base::Value::Dict> GetConstants() {
base::Value::Dict constants_dict = net::GetNetConstants();
base::Value::Dict dict;
@ -134,278 +102,6 @@ std::unique_ptr<base::Value::Dict> GetConstants() {
constants_dict.Set("clientInfo", std::move(dict));
return std::make_unique<base::Value::Dict>(std::move(constants_dict));
}
class MultipleListenCollector : public base::DuplicateSwitchHandler {
public:
void ResolveDuplicate(std::string_view key,
base::CommandLine::StringPieceType new_value,
base::CommandLine::StringType& out_value) override {
out_value = new_value;
if (key == "listen") {
#if BUILDFLAG(IS_WIN)
all_values_.push_back(base::WideToUTF8(new_value));
#else
all_values_.push_back(std::string(new_value));
#endif
}
}
const std::vector<std::string>& GetAllValues() const {
return all_values_;
}
private:
std::vector<std::string> all_values_;
};
void GetCommandLine(const base::CommandLine& proc,
CommandLine* cmdline,
MultipleListenCollector& multiple_listens) {
if (proc.HasSwitch("h") || proc.HasSwitch("help")) {
std::cout << "Usage: naive { OPTIONS | config.json }\n"
"\n"
"Options:\n"
"-h, --help Show this message\n"
"--version Print version\n"
"--listen=<proto>://[addr][:port] [--listen=...]\n"
" proto: socks, http\n"
" redir (Linux only)\n"
"--proxy=<proto>://[<user>:<pass>@]<hostname>[:<port>]\n"
" proto: https, quic\n"
"--insecure-concurrency=<N> Use N connections, insecure\n"
"--extra-headers=... Extra headers split by CRLF\n"
"--host-resolver-rules=... Resolver rules\n"
"--resolver-range=... Redirect resolver range\n"
"--log[=<path>] Log to stderr, or file\n"
"--log-net-log=<path> Save NetLog\n"
"--ssl-key-log-file=<path> Save SSL keys for Wireshark\n"
<< std::endl;
exit(EXIT_SUCCESS);
}
if (proc.HasSwitch("version")) {
std::cout << "naive " << version_info::GetVersionNumber() << std::endl;
exit(EXIT_SUCCESS);
}
cmdline->listens = multiple_listens.GetAllValues();
cmdline->proxy = proc.GetSwitchValueASCII("proxy");
cmdline->concurrency = proc.GetSwitchValueASCII("insecure-concurrency");
cmdline->extra_headers = proc.GetSwitchValueASCII("extra-headers");
cmdline->host_resolver_rules =
proc.GetSwitchValueASCII("host-resolver-rules");
cmdline->resolver_range = proc.GetSwitchValueASCII("resolver-range");
cmdline->no_log = !proc.HasSwitch("log");
cmdline->log = proc.GetSwitchValuePath("log");
cmdline->log_net_log = proc.GetSwitchValuePath("log-net-log");
cmdline->ssl_key_log_file = proc.GetSwitchValuePath("ssl-key-log-file");
}
void GetCommandLineFromConfig(const base::FilePath& config_path,
CommandLine* cmdline) {
JSONFileValueDeserializer reader(config_path);
int error_code;
std::string error_message;
std::unique_ptr<base::Value> value =
reader.Deserialize(&error_code, &error_message);
if (value == nullptr) {
std::cerr << "Error reading " << config_path << ": (" << error_code << ") "
<< error_message << std::endl;
exit(EXIT_FAILURE);
}
base::Value::Dict* value_dict = value->GetIfDict();
if (value_dict == nullptr) {
std::cerr << "Invalid config format" << std::endl;
exit(EXIT_FAILURE);
}
const std::string* listen = value_dict->FindString("listen");
if (listen != nullptr) {
cmdline->listens = {*listen};
} else {
const base::Value::List* listen_list = value_dict->FindList("listen");
if (listen_list != nullptr) {
for (const auto& listen_element : *listen_list) {
const std::string* listen_elemet_str = listen_element.GetIfString();
if (listen_elemet_str == nullptr) {
std::cerr << "Invalid listen element" << std::endl;
exit(EXIT_FAILURE);
}
cmdline->listens.push_back(*listen_elemet_str);
}
}
}
const std::string* proxy = value_dict->FindString("proxy");
if (proxy) {
cmdline->proxy = *proxy;
}
const std::string* concurrency =
value_dict->FindString("insecure-concurrency");
if (concurrency) {
cmdline->concurrency = *concurrency;
}
const std::string* extra_headers = value_dict->FindString("extra-headers");
if (extra_headers) {
cmdline->extra_headers = *extra_headers;
}
const std::string* host_resolver_rules =
value_dict->FindString("host-resolver-rules");
if (host_resolver_rules) {
cmdline->host_resolver_rules = *host_resolver_rules;
}
const std::string* resolver_range = value_dict->FindString("resolver-range");
if (resolver_range) {
cmdline->resolver_range = *resolver_range;
}
cmdline->no_log = true;
const std::string* log = value_dict->FindString("log");
if (log) {
cmdline->no_log = false;
cmdline->log = base::FilePath::FromUTF8Unsafe(*log);
}
const std::string* log_net_log = value_dict->FindString("log-net-log");
if (log_net_log) {
cmdline->log_net_log = base::FilePath::FromUTF8Unsafe(*log_net_log);
}
const std::string* ssl_key_log_file =
value_dict->FindString("ssl-key-log-file");
if (ssl_key_log_file) {
cmdline->ssl_key_log_file =
base::FilePath::FromUTF8Unsafe(*ssl_key_log_file);
}
}
bool ParseListenParams(const std::string& listen_str,
ListenParams& listen_params) {
GURL url(listen_str);
if (url.scheme() == "socks") {
listen_params.protocol = net::ClientProtocol::kSocks5;
} else if (url.scheme() == "http") {
listen_params.protocol = net::ClientProtocol::kHttp;
} else if (url.scheme() == "redir") {
#if BUILDFLAG(IS_LINUX)
listen_params.protocol = net::ClientProtocol::kRedir;
#else
std::cerr << "Redir protocol only supports Linux." << std::endl;
return false;
#endif
} else {
std::cerr << "Invalid scheme in --listen" << std::endl;
return false;
}
if (!url.username().empty()) {
listen_params.listen_user =
base::UnescapeBinaryURLComponent(url.username());
}
if (!url.password().empty()) {
listen_params.listen_pass =
base::UnescapeBinaryURLComponent(url.password());
}
if (!url.host().empty()) {
listen_params.listen_addr = url.HostNoBrackets();
} else {
listen_params.listen_addr = "0.0.0.0";
}
int port = url.EffectiveIntPort();
if (port == url::PORT_INVALID) {
std::cerr << "Invalid port in --listen" << std::endl;
return false;
} else if (port == url::PORT_UNSPECIFIED) {
port = 1080;
}
listen_params.listen_port = port;
return true;
}
bool ParseCommandLine(const CommandLine& cmdline, Params* params) {
url::AddStandardScheme("socks",
url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION);
url::AddStandardScheme("redir", url::SCHEME_WITH_HOST_AND_PORT);
bool any_redir_protocol = false;
if (!cmdline.listens.empty()) {
for (const std::string& listen : cmdline.listens) {
ListenParams listen_params;
if (!ParseListenParams(listen, listen_params)) {
std::cerr << "Invalid listen: " << listen << std::endl;
return false;
}
if (listen_params.protocol == net::ClientProtocol::kRedir) {
any_redir_protocol = true;
}
params->listens.push_back(listen_params);
}
} else {
ListenParams default_listen = {
.protocol = net::ClientProtocol::kSocks5,
.listen_addr = "0.0.0.0",
.listen_port = 1080,
};
params->listens = {default_listen};
}
params->proxy_url = "direct://";
GURL url(cmdline.proxy);
GURL::Replacements remove_auth;
remove_auth.ClearUsername();
remove_auth.ClearPassword();
GURL url_no_auth = url.ReplaceComponents(remove_auth);
if (!cmdline.proxy.empty()) {
params->proxy_url = url_no_auth.GetWithEmptyPath().spec();
if (params->proxy_url.empty()) {
std::cerr << "Invalid proxy URL" << std::endl;
return false;
} else if (params->proxy_url.back() == '/') {
params->proxy_url.pop_back();
}
net::GetIdentityFromURL(url, &params->proxy_user, &params->proxy_pass);
}
if (!cmdline.concurrency.empty()) {
if (!base::StringToInt(cmdline.concurrency, &params->concurrency) ||
params->concurrency < 1) {
std::cerr << "Invalid concurrency" << std::endl;
return false;
}
} else {
params->concurrency = 1;
}
params->extra_headers.AddHeadersFromString(cmdline.extra_headers);
params->host_resolver_rules = cmdline.host_resolver_rules;
if (any_redir_protocol) {
std::string range = "100.64.0.0/10";
if (!cmdline.resolver_range.empty())
range = cmdline.resolver_range;
if (!net::ParseCIDRBlock(range, &params->resolver_range,
&params->resolver_prefix)) {
std::cerr << "Invalid resolver range" << std::endl;
return false;
}
if (params->resolver_range.IsIPv6()) {
std::cerr << "IPv6 resolver range not supported" << std::endl;
return false;
}
}
if (!cmdline.no_log) {
if (!cmdline.log.empty()) {
params->log_settings.logging_dest = logging::LOG_TO_FILE;
params->log_settings.log_file_path = cmdline.log.value().c_str();
} else {
params->log_settings.logging_dest = logging::LOG_TO_STDERR;
}
} else {
params->log_settings.logging_dest = logging::LOG_NONE;
}
params->net_log_path = cmdline.log_net_log;
params->ssl_key_path = cmdline.ssl_key_log_file;
return true;
}
} // namespace
namespace net {
@ -471,7 +167,7 @@ std::unique_ptr<URLRequestContext> BuildCertURLRequestContext(NetLog* net_log) {
// Builds a URLRequestContext assuming there's only a single loop.
std::unique_ptr<URLRequestContext> BuildURLRequestContext(
const Params& params,
const NaiveConfig& config,
scoped_refptr<CertNetFetcherURLRequest> cert_net_fetcher,
NetLog* net_log) {
URLRequestContextBuilder builder;
@ -479,7 +175,7 @@ std::unique_ptr<URLRequestContext> BuildURLRequestContext(
builder.DisableHttpCache();
builder.set_net_log(net_log);
std::string proxy_url = params.proxy_url;
std::string proxy_url = config.proxy_url;
bool force_quic = false;
if (proxy_url.compare(0, 7, "quic://") == 0) {
proxy_url.replace(0, 4, "https");
@ -505,21 +201,21 @@ std::unique_ptr<URLRequestContext> BuildURLRequestContext(
proxy_service->ForceReloadProxyConfig();
builder.set_proxy_resolution_service(std::move(proxy_service));
if (!params.host_resolver_rules.empty()) {
builder.set_host_mapping_rules(params.host_resolver_rules);
if (!config.host_resolver_rules.empty()) {
builder.set_host_mapping_rules(config.host_resolver_rules);
}
builder.SetCertVerifier(
CertVerifier::CreateDefault(std::move(cert_net_fetcher)));
builder.set_proxy_delegate(std::make_unique<NaiveProxyDelegate>(
params.extra_headers,
config.extra_headers,
std::vector<PaddingType>{PaddingType::kVariant1, PaddingType::kNone}));
auto context = builder.Build();
if (!params.proxy_url.empty() && !params.proxy_user.empty() &&
!params.proxy_pass.empty()) {
if (!config.proxy_url.empty() && !config.proxy_user.empty() &&
!config.proxy_pass.empty()) {
auto* session = context->http_transaction_factory()->GetSession();
auto* auth_cache = session->http_auth_cache();
GURL proxy_gurl(proxy_url);
@ -530,7 +226,7 @@ std::unique_ptr<URLRequestContext> BuildURLRequestContext(
net::HostPortPair::FromURL(proxy_gurl));
}
url::SchemeHostPort auth_origin(proxy_gurl);
AuthCredentials credentials(params.proxy_user, params.proxy_pass);
AuthCredentials credentials(config.proxy_user, config.proxy_pass);
auth_cache->Add(auth_origin, HttpAuth::AUTH_PROXY,
/*realm=*/{}, HttpAuth::AUTH_SCHEME_BASIC, {},
/*challenge=*/"Basic", credentials, /*path=*/"/");
@ -565,9 +261,7 @@ int main(int argc, char* argv[]) {
// content/app/content_main.cc: RunContentProcess()
base::EnableTerminationOnOutOfMemory();
auto multiple_listens = std::make_unique<MultipleListenCollector>();
MultipleListenCollector& multiple_listens_ref = *multiple_listens;
base::CommandLine::SetDuplicateSwitchHandler(std::move(multiple_listens));
DuplicateSwitchCollector::InitInstance();
// content/app/content_main.cc: RunContentProcess()
base::CommandLine::Init(argc, argv);
@ -606,6 +300,9 @@ int main(int argc, char* argv[]) {
url::AddStandardScheme("quic",
url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION);
url::AddStandardScheme("socks",
url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION);
url::AddStandardScheme("redir", url::SCHEME_WITH_HOST_AND_PORT);
net::ClientSocketPoolManager::set_max_sockets_per_pool(
net::HttpNetworkSession::NORMAL_SOCKET_POOL,
kDefaultMaxSocketsPerPool * kExpectedMaxUsers);
@ -616,30 +313,69 @@ int main(int argc, char* argv[]) {
net::HttpNetworkSession::NORMAL_SOCKET_POOL,
kDefaultMaxSocketsPerGroup * kExpectedMaxUsers);
CommandLine cmdline;
Params params;
const auto& proc = *base::CommandLine::ForCurrentProcess();
const auto& args = proc.GetArgs();
if (args.empty()) {
if (proc.argv().size() >= 2) {
GetCommandLine(proc, &cmdline, multiple_listens_ref);
base::Value::Dict config_dict;
if (args.empty() && proc.argv().size() >= 2) {
config_dict = GetSwitchesAsValue(proc);
} else {
auto path = base::FilePath::FromUTF8Unsafe("config.json");
GetCommandLineFromConfig(path, &cmdline);
}
base::FilePath config_file;
if (!args.empty()) {
config_file = base::FilePath(args[0]);
} else {
base::FilePath path(args[0]);
GetCommandLineFromConfig(path, &cmdline);
config_file = base::FilePath::FromUTF8Unsafe("config.json");
}
if (!ParseCommandLine(cmdline, &params)) {
JSONFileValueDeserializer reader(config_file);
int error_code;
std::string error_message;
std::unique_ptr<base::Value> value =
reader.Deserialize(&error_code, &error_message);
if (value == nullptr) {
std::cerr << "Error reading " << config_file << ": (" << error_code
<< ") " << error_message << std::endl;
return EXIT_FAILURE;
}
CHECK(logging::InitLogging(params.log_settings));
if (const base::Value::Dict* dict = value->GetIfDict()) {
config_dict = dict->Clone();
}
}
if (!params.ssl_key_path.empty()) {
if (config_dict.contains("h") || config_dict.contains("help")) {
std::cout << "Usage: naive { OPTIONS | config.json }\n"
"\n"
"Options:\n"
"-h, --help Show this message\n"
"--version Print version\n"
"--listen=<proto>://[addr][:port] [--listen=...]\n"
" proto: socks, http\n"
" redir (Linux only)\n"
"--proxy=<proto>://[<user>:<pass>@]<hostname>[:<port>]\n"
" proto: https, quic\n"
"--insecure-concurrency=<N> Use N connections, insecure\n"
"--extra-headers=... Extra headers split by CRLF\n"
"--host-resolver-rules=... Resolver rules\n"
"--resolver-range=... Redirect resolver range\n"
"--log[=<path>] Log to stderr, or file\n"
"--log-net-log=<path> Save NetLog\n"
"--ssl-key-log-file=<path> Save SSL keys for Wireshark\n"
<< std::endl;
exit(EXIT_SUCCESS);
}
if (config_dict.contains("version")) {
std::cout << "naive " << version_info::GetVersionNumber() << std::endl;
exit(EXIT_SUCCESS);
}
net::NaiveConfig config;
if (!config.Parse(config_dict)) {
return EXIT_FAILURE;
}
CHECK(logging::InitLogging(config.log));
if (!config.ssl_key_log_file.empty()) {
net::SSLClientSocket::SetSSLKeyLogger(
std::make_unique<net::SSLKeyLoggerImpl>(params.ssl_key_path));
std::make_unique<net::SSLKeyLoggerImpl>(config.ssl_key_log_file));
}
// The declaration order for net_log and printing_log_observer is
@ -648,15 +384,15 @@ int main(int argc, char* argv[]) {
// printing_log_observer.
net::NetLog* net_log = net::NetLog::Get();
std::unique_ptr<net::FileNetLogObserver> observer;
if (!params.net_log_path.empty()) {
if (!config.log_net_log.empty()) {
observer = net::FileNetLogObserver::CreateUnbounded(
params.net_log_path, net::NetLogCaptureMode::kDefault, GetConstants());
config.log_net_log, net::NetLogCaptureMode::kDefault, GetConstants());
observer->StartObserving(net_log);
}
// Avoids net log overhead if verbose logging is disabled.
std::unique_ptr<net::PrintingLogObserver> printing_log_observer;
if (params.log_settings.logging_dest != logging::LOG_NONE && VLOG_IS_ON(1)) {
if (config.log.logging_dest != logging::LOG_NONE && VLOG_IS_ON(1)) {
printing_log_observer = std::make_unique<net::PrintingLogObserver>();
net_log->AddObserver(printing_log_observer.get(),
net::NetLogCaptureMode::kDefault);
@ -675,43 +411,41 @@ int main(int argc, char* argv[]) {
cert_net_fetcher->SetURLRequestContext(cert_context.get());
#endif
auto context =
net::BuildURLRequestContext(params, std::move(cert_net_fetcher), net_log);
net::BuildURLRequestContext(config, std::move(cert_net_fetcher), net_log);
auto* session = context->http_transaction_factory()->GetSession();
std::vector<std::unique_ptr<net::NaiveProxy>> naive_proxies;
std::unique_ptr<net::RedirectResolver> resolver;
for (const ListenParams& listen_params : params.listens) {
for (const net::NaiveListenConfig& listen_config : config.listen) {
auto listen_socket =
std::make_unique<net::TCPServerSocket>(net_log, net::NetLogSource());
int result = listen_socket->ListenWithAddressAndPort(
listen_params.listen_addr, listen_params.listen_port, kListenBackLog);
listen_config.addr, listen_config.port, kListenBackLog);
if (result != net::OK) {
LOG(ERROR) << "Failed to listen on "
<< net::ToString(listen_params.protocol) << "://"
<< listen_params.listen_addr << " "
<< listen_params.listen_port << ": "
<< net::ToString(listen_config.protocol) << "://"
<< listen_config.addr << " " << listen_config.port << ": "
<< net::ErrorToShortString(result);
return EXIT_FAILURE;
}
LOG(INFO) << "Listening on " << net::ToString(listen_params.protocol)
<< "://" << listen_params.listen_addr << ":"
<< listen_params.listen_port;
LOG(INFO) << "Listening on " << net::ToString(listen_config.protocol)
<< "://" << listen_config.addr << ":" << listen_config.port;
if (resolver == nullptr &&
listen_params.protocol == net::ClientProtocol::kRedir) {
listen_config.protocol == net::ClientProtocol::kRedir) {
auto resolver_socket =
std::make_unique<net::UDPServerSocket>(net_log, net::NetLogSource());
resolver_socket->AllowAddressReuse();
net::IPAddress listen_addr;
if (!listen_addr.AssignFromIPLiteral(listen_params.listen_addr)) {
LOG(ERROR) << "Failed to open resolver: " << listen_params.listen_addr;
if (!listen_addr.AssignFromIPLiteral(listen_config.addr)) {
LOG(ERROR) << "Failed to open resolver: " << listen_config.addr;
return EXIT_FAILURE;
}
result = resolver_socket->Listen(
net::IPEndPoint(listen_addr, listen_params.listen_port));
net::IPEndPoint(listen_addr, listen_config.port));
if (result != net::OK) {
LOG(ERROR) << "Failed to open resolver: "
<< net::ErrorToShortString(result);
@ -719,14 +453,14 @@ int main(int argc, char* argv[]) {
}
resolver = std::make_unique<net::RedirectResolver>(
std::move(resolver_socket), params.resolver_range,
params.resolver_prefix);
std::move(resolver_socket), config.resolver_range,
config.resolver_prefix);
}
auto naive_proxy = std::make_unique<net::NaiveProxy>(
std::move(listen_socket), listen_params.protocol,
listen_params.listen_user, listen_params.listen_pass,
params.concurrency, resolver.get(), session, kTrafficAnnotation,
std::move(listen_socket), listen_config.protocol, listen_config.user,
listen_config.pass, config.insecure_concurrency, resolver.get(),
session, kTrafficAnnotation,
std::vector<net::PaddingType>{net::PaddingType::kVariant1,
net::PaddingType::kNone});
naive_proxies.push_back(std::move(naive_proxy));