diff --git a/src/net/tools/naive/naive_proxy_bin.cc b/src/net/tools/naive/naive_proxy_bin.cc index c30612c772..4a133396ba 100644 --- a/src/net/tools/naive/naive_proxy_bin.cc +++ b/src/net/tools/naive/naive_proxy_bin.cc @@ -12,6 +12,7 @@ #include "base/at_exit.h" #include "base/command_line.h" #include "base/files/file_path.h" +#include "base/json/json_file_value_serializer.h" #include "base/json/json_writer.h" #include "base/logging.h" #include "base/macros.h" @@ -70,6 +71,17 @@ constexpr int kExpectedMaxUsers = 8; constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation = net::DefineNetworkTrafficAnnotation("naive", ""); +struct CommandLine { + std::string listen; + std::string proxy; + bool padding; + std::string host_resolver_rules; + bool no_log; + base::FilePath log; + base::FilePath log_net_log; + base::FilePath ssl_key_log_file; +}; + struct Params { net::NaiveConnection::Protocol protocol; std::string listen_addr; @@ -80,7 +92,6 @@ struct Params { std::u16string proxy_pass; std::string host_resolver_rules; logging::LoggingSettings log_settings; - base::FilePath log_path; base::FilePath net_log_path; base::FilePath ssl_key_path; }; @@ -108,6 +119,7 @@ std::unique_ptr BuildURLRequestContext( net::ProxyConfig proxy_config; proxy_config.proxy_rules().ParseFromString(params.proxy_url); + LOG(INFO) << "Proxying via " << params.proxy_url; auto proxy_service = net::ConfiguredProxyResolutionService::CreateWithoutProxyResolver( std::make_unique( @@ -146,11 +158,9 @@ std::unique_ptr BuildURLRequestContext( return context; } -bool ParseCommandLineFlags(Params* params) { - const base::CommandLine& line = *base::CommandLine::ForCurrentProcess(); - - if (line.HasSwitch("h") || line.HasSwitch("help")) { - std::cout << "Usage: naive [options]\n" +void GetCommandLine(const base::CommandLine& proc, CommandLine* cmdline) { + if (proc.HasSwitch("h") || proc.HasSwitch("help")) { + std::cout << "Usage: naive { OPTIONS | config.json }\n" "\n" "Options:\n" "-h, --help Show this message\n" @@ -160,26 +170,90 @@ bool ParseCommandLineFlags(Params* params) { "--proxy=://[:@][:]\n" " proto: https, quic\n" "--padding Use padding\n" + "--host-resolver-rules=... Resolver rules\n" "--log[=] Log to stderr, or file\n" "--log-net-log= Save NetLog\n" "--ssl-key-log-file= Save SSL keys for Wireshark\n" << std::endl; exit(EXIT_SUCCESS); - return false; } - if (line.HasSwitch("version")) { - std::cout << "Version: " << version_info::GetVersionNumber() << std::endl; + if (proc.HasSwitch("version")) { + std::cout << "naive " << version_info::GetVersionNumber() << std::endl; exit(EXIT_SUCCESS); - return false; } + cmdline->listen = proc.GetSwitchValueASCII("listen"); + cmdline->proxy = proc.GetSwitchValueASCII("proxy"); + cmdline->padding = proc.HasSwitch("padding"); + cmdline->host_resolver_rules = + proc.GetSwitchValueASCII("host-resolver-rules"); + cmdline->no_log = !proc.HasSwitch("log"); + cmdline->log = proc.GetSwitchValuePath("log"); + cmdline->log_net_log = proc.GetSwitchValuePath("log-net-log"); + cmdline->ssl_key_log_file = proc.GetSwitchValuePath("ssl-key-log-file"); +} + +void GetCommandLineFromConfig(const base::FilePath& config_path, + CommandLine* cmdline) { + JSONFileValueDeserializer reader(config_path); + int error_code; + std::string error_message; + auto value = reader.Deserialize(&error_code, &error_message); + if (value == nullptr) { + std::cerr << "Error reading " << config_path << ": (" << error_code << ") " + << error_message << std::endl; + exit(EXIT_FAILURE); + } + if (!value->is_dict()) { + std::cerr << "Invalid config format" << std::endl; + exit(EXIT_FAILURE); + } + if (value->FindKeyOfType("listen", base::Value::Type::STRING)) { + cmdline->listen = value->FindKey("listen")->GetString(); + } + if (value->FindKeyOfType("proxy", base::Value::Type::STRING)) { + cmdline->proxy = value->FindKey("proxy")->GetString(); + } + cmdline->padding = false; + if (value->FindKeyOfType("padding", base::Value::Type::BOOLEAN)) { + cmdline->padding = value->FindKey("padding")->GetBool(); + } + if (value->FindKeyOfType("host-resolver-rules", base::Value::Type::STRING)) { + cmdline->host_resolver_rules = + value->FindKey("host-resolver-rules")->GetString(); + } + cmdline->no_log = true; + if (value->FindKeyOfType("log", base::Value::Type::STRING)) { + cmdline->no_log = false; + cmdline->log = + base::FilePath::FromUTF8Unsafe(value->FindKey("log")->GetString()); + } + if (value->FindKeyOfType("log-net-log", base::Value::Type::STRING)) { + cmdline->log_net_log = base::FilePath::FromUTF8Unsafe( + value->FindKey("log-net-log")->GetString()); + } + if (value->FindKeyOfType("ssl-key-log-file", base::Value::Type::STRING)) { + cmdline->ssl_key_log_file = base::FilePath::FromUTF8Unsafe( + value->FindKey("ssl-key-log-file")->GetString()); + } +} + +std::string GetProxyFromURL(const GURL& url) { + std::string str = url.GetWithEmptyPath().spec(); + if (str.size() && str.back() == '/') { + str.pop_back(); + } + return str; +} + +bool ParseCommandLine(const CommandLine& cmdline, Params* params) { params->protocol = net::NaiveConnection::kSocks5; params->listen_addr = "0.0.0.0"; params->listen_port = 1080; url::AddStandardScheme("socks", url::SCHEME_WITH_HOST_AND_PORT); - if (line.HasSwitch("listen")) { - GURL url(line.GetSwitchValueASCII("listen")); + if (!cmdline.listen.empty()) { + GURL url(cmdline.listen); if (url.scheme() == "socks") { params->protocol = net::NaiveConnection::kSocks5; params->listen_port = 1080; @@ -187,7 +261,7 @@ bool ParseCommandLineFlags(Params* params) { params->protocol = net::NaiveConnection::kHttp; params->listen_port = 8080; } else { - LOG(ERROR) << "Invalid scheme in --listen"; + std::cerr << "Invalid scheme in --listen" << std::endl; return false; } if (!url.host().empty()) { @@ -195,12 +269,12 @@ bool ParseCommandLineFlags(Params* params) { } if (!url.port().empty()) { if (!base::StringToInt(url.port(), ¶ms->listen_port)) { - LOG(ERROR) << "Invalid port in --listen"; + std::cerr << "Invalid port in --listen" << std::endl; return false; } if (params->listen_port <= 0 || params->listen_port > std::numeric_limits::max()) { - LOG(ERROR) << "Invalid port in --listen"; + std::cerr << "Invalid port in --listen" << std::endl; return false; } } @@ -209,50 +283,37 @@ bool ParseCommandLineFlags(Params* params) { url::AddStandardScheme("quic", url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION); params->proxy_url = "direct://"; - GURL url(line.GetSwitchValueASCII("proxy")); - if (line.HasSwitch("proxy")) { + GURL url(cmdline.proxy); + GURL::Replacements remove_auth; + remove_auth.ClearUsername(); + remove_auth.ClearPassword(); + GURL url_no_auth = url.ReplaceComponents(remove_auth); + if (!cmdline.proxy.empty()) { if (!url.is_valid()) { - LOG(ERROR) << "Invalid proxy URL"; + std::cerr << "Invalid proxy URL" << std::endl; return false; } - if (url.scheme() != "https" && url.scheme() != "quic") { - LOG(ERROR) << "Must be HTTPS or QUIC proxy"; - return false; - } - params->proxy_url = url::SchemeHostPort(url).Serialize(); + params->proxy_url = GetProxyFromURL(url_no_auth); net::GetIdentityFromURL(url, ¶ms->proxy_user, ¶ms->proxy_pass); } - params->use_padding = false; - if (line.HasSwitch("padding")) { - params->use_padding = true; - } + params->use_padding = cmdline.padding; - if (line.HasSwitch("host-resolver-rules")) { - params->host_resolver_rules = - line.GetSwitchValueASCII("host-resolver-rules"); - } + params->host_resolver_rules = cmdline.host_resolver_rules; - if (line.HasSwitch("log")) { - params->log_settings.logging_dest = logging::LOG_DEFAULT; - params->log_path = line.GetSwitchValuePath("log"); - if (!params->log_path.empty()) { + if (!cmdline.no_log) { + if (!cmdline.log.empty()) { params->log_settings.logging_dest = logging::LOG_TO_FILE; - } else if (params->log_settings.logging_dest == logging::LOG_TO_FILE) { - params->log_path = base::FilePath::FromUTF8Unsafe("naive.log"); + params->log_settings.log_file_path = cmdline.log.value().c_str(); + } else { + params->log_settings.logging_dest = logging::LOG_TO_STDERR; } - params->log_settings.log_file_path = params->log_path.value().c_str(); } else { params->log_settings.logging_dest = logging::LOG_NONE; } - if (line.HasSwitch("log-net-log")) { - params->net_log_path = line.GetSwitchValuePath("log-net-log"); - } - - if (line.HasSwitch("ssl-key-log-file")) { - params->ssl_key_path = line.GetSwitchValuePath("ssl-key-log-file"); - } + params->net_log_path = cmdline.log_net_log; + params->ssl_key_path = cmdline.ssl_key_log_file; return true; } @@ -314,8 +375,22 @@ int main(int argc, char* argv[]) { base::CommandLine::Init(argc, argv); + CommandLine cmdline; Params params; - if (!ParseCommandLineFlags(¶ms)) { + const auto& proc = *base::CommandLine::ForCurrentProcess(); + const auto& args = proc.GetArgs(); + if (args.empty()) { + if (proc.argv().size() >= 2) { + GetCommandLine(proc, &cmdline); + } else { + auto path = base::FilePath::FromUTF8Unsafe("config.json"); + GetCommandLineFromConfig(path, &cmdline); + } + } else { + base::FilePath path(args[0]); + GetCommandLineFromConfig(path, &cmdline); + } + if (!ParseCommandLine(cmdline, ¶ms)) { return EXIT_FAILURE; } @@ -368,6 +443,8 @@ int main(int argc, char* argv[]) { LOG(ERROR) << "Failed to listen: " << result; return EXIT_FAILURE; } + LOG(INFO) << "Listening on " << params.listen_addr << ":" + << params.listen_port; net::NaiveProxy naive_proxy(std::move(listen_socket), params.protocol, params.use_padding, session, kTrafficAnnotation);