From 89e3c2207b248c933fbf1cc3950734c16a7e8c08 Mon Sep 17 00:00:00 2001 From: klzgrad Date: Thu, 10 Jan 2019 00:30:56 -0500 Subject: [PATCH] Support loading config.json --- src/net/tools/naive/naive_proxy_bin.cc | 139 ++++++++++++++++++------- 1 file changed, 101 insertions(+), 38 deletions(-) diff --git a/src/net/tools/naive/naive_proxy_bin.cc b/src/net/tools/naive/naive_proxy_bin.cc index c0c4c86223..3df53890f7 100644 --- a/src/net/tools/naive/naive_proxy_bin.cc +++ b/src/net/tools/naive/naive_proxy_bin.cc @@ -12,6 +12,7 @@ #include "base/at_exit.h" #include "base/command_line.h" #include "base/files/file_path.h" +#include "base/json/json_file_value_serializer.h" #include "base/json/json_writer.h" #include "base/logging.h" #include "base/macros.h" @@ -68,6 +69,17 @@ constexpr char kDefaultHostName[] = "example"; constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation = net::DefineNetworkTrafficAnnotation("naive", ""); +struct CommandLine { + std::string listen; + std::string proxy; + bool padding; + std::string host_resolver_rules; + bool no_log; + base::FilePath log; + base::FilePath log_net_log; + base::FilePath ssl_key_log_file; +}; + struct Params { net::NaiveConnection::Protocol protocol; std::string listen_addr; @@ -147,11 +159,9 @@ std::unique_ptr BuildURLRequestContext( return context; } -bool ParseCommandLineFlags(Params* params) { - const base::CommandLine& line = *base::CommandLine::ForCurrentProcess(); - - if (line.HasSwitch("h") || line.HasSwitch("help")) { - std::cout << "Usage: naive [options]\n" +void GetCommandLine(const base::CommandLine& proc, CommandLine* cmdline) { + if (proc.HasSwitch("h") || proc.HasSwitch("help")) { + std::cout << "Usage: naive { OPTIONS | config.json }\n" "\n" "Options:\n" "-h, --help Show this message\n" @@ -161,26 +171,75 @@ bool ParseCommandLineFlags(Params* params) { "--proxy=://[:@][:]\n" " proto: https, quic\n" "--padding Use padding\n" + "--host-resolver-rules=... Resolver rules\n" "--log[=] Log to stderr, or file\n" "--log-net-log= Save NetLog\n" "--ssl-key-log-file= Save SSL keys for Wireshark\n" << std::endl; exit(EXIT_SUCCESS); - return false; } - if (line.HasSwitch("version")) { - std::cout << "Version: " << version_info::GetVersionNumber() << std::endl; + if (proc.HasSwitch("version")) { + std::cout << "naive " << version_info::GetVersionNumber() << std::endl; exit(EXIT_SUCCESS); - return false; } + cmdline->listen = proc.GetSwitchValueASCII("listen"); + cmdline->proxy = proc.GetSwitchValueASCII("proxy"); + cmdline->padding = proc.HasSwitch("padding"); + cmdline->host_resolver_rules = proc.GetSwitchValueASCII("host-resolver-rules"); + cmdline->no_log = !proc.HasSwitch("log"); + cmdline->log = proc.GetSwitchValuePath("log"); + cmdline->log_net_log = proc.GetSwitchValuePath("log-net-log"); + cmdline->ssl_key_log_file = proc.GetSwitchValuePath("ssl-key-log-file"); +} + +void GetCommandLineFromConfig(const base::FilePath& config_path, CommandLine* cmdline) { + JSONFileValueDeserializer reader(config_path); + int error_code; + std::string error_message; + auto value = reader.Deserialize(&error_code, &error_message); + if (value == nullptr) { + std::cerr << "Error reading " << config_path << ": (" << error_code << ") " << error_message << std::endl; + exit(EXIT_FAILURE); + } + if (!value->is_dict()) { + std::cerr << "Invalid config format" << std::endl; + exit(EXIT_FAILURE); + } + if (value->FindKeyOfType("listen", base::Value::Type::STRING)) { + cmdline->listen = value->FindKey("listen")->GetString(); + } + if (value->FindKeyOfType("proxy", base::Value::Type::STRING)) { + cmdline->proxy = value->FindKey("proxy")->GetString(); + } + cmdline->padding = false; + if (value->FindKeyOfType("padding", base::Value::Type::BOOLEAN)) { + cmdline->padding = value->FindKey("padding")->GetBool(); + } + if (value->FindKeyOfType("host-resolver-rules", base::Value::Type::STRING)) { + cmdline->host_resolver_rules = value->FindKey("host-resolver-rules")->GetString(); + } + cmdline->no_log = true; + if (value->FindKeyOfType("log", base::Value::Type::STRING)) { + cmdline->no_log = false; + cmdline->log = base::FilePath::FromUTF8Unsafe(value->FindKey("log")->GetString()); + } + if (value->FindKeyOfType("log-net-log", base::Value::Type::STRING)) { + cmdline->log_net_log = base::FilePath::FromUTF8Unsafe(value->FindKey("log-net-log")->GetString()); + } + if (value->FindKeyOfType("ssl-key-log-file", base::Value::Type::STRING)) { + cmdline->ssl_key_log_file = base::FilePath::FromUTF8Unsafe(value->FindKey("ssl-key-log-file")->GetString()); + } +} + +bool ParseCommandLine(const CommandLine& cmdline, Params* params) { params->protocol = net::NaiveConnection::kSocks5; params->listen_addr = "0.0.0.0"; params->listen_port = 1080; url::AddStandardScheme("socks", url::SCHEME_WITH_HOST_AND_PORT); - if (line.HasSwitch("listen")) { - GURL url(line.GetSwitchValueASCII("listen")); + if (!cmdline.listen.empty()) { + GURL url(cmdline.listen); if (url.scheme() == "socks") { params->protocol = net::NaiveConnection::kSocks5; params->listen_port = 1080; @@ -188,7 +247,7 @@ bool ParseCommandLineFlags(Params* params) { params->protocol = net::NaiveConnection::kHttp; params->listen_port = 8080; } else { - LOG(ERROR) << "Invalid scheme in --listen"; + std::cerr << "Invalid scheme in --listen" << std::endl; return false; } if (!url.host().empty()) { @@ -196,12 +255,12 @@ bool ParseCommandLineFlags(Params* params) { } if (!url.port().empty()) { if (!base::StringToInt(url.port(), ¶ms->listen_port)) { - LOG(ERROR) << "Invalid port in --listen"; + std::cerr << "Invalid port in --listen" << std::endl; return false; } if (params->listen_port <= 0 || params->listen_port > std::numeric_limits::max()) { - LOG(ERROR) << "Invalid port in --listen"; + std::cerr << "Invalid port in --listen" << std::endl; return false; } } @@ -210,14 +269,14 @@ bool ParseCommandLineFlags(Params* params) { url::AddStandardScheme("quic", url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION); params->proxy_url = "direct://"; - GURL url(line.GetSwitchValueASCII("proxy")); - if (line.HasSwitch("proxy")) { + GURL url(cmdline.proxy); + if (!cmdline.proxy.empty()) { if (!url.is_valid()) { - LOG(ERROR) << "Invalid proxy URL"; + std::cerr << "Invalid proxy URL" << std::endl; return false; } if (url.scheme() != "https" && url.scheme() != "quic") { - LOG(ERROR) << "Must be HTTPS or QUIC proxy"; + std::cerr << "Must be HTTPS or QUIC proxy" << std::endl; return false; } params->proxy_url = url::SchemeHostPort(url).Serialize(); @@ -225,14 +284,10 @@ bool ParseCommandLineFlags(Params* params) { params->proxy_pass = url.password(); } - params->use_padding = false; - if (line.HasSwitch("padding")) { - params->use_padding = true; - } + params->use_padding = cmdline.padding; - if (line.HasSwitch("host-resolver-rules")) { - params->host_resolver_rules = - line.GetSwitchValueASCII("host-resolver-rules"); + if (!cmdline.host_resolver_rules.empty()) { + params->host_resolver_rules = cmdline.host_resolver_rules; } else { // SNI should only contain DNS hostnames not IP addresses per RFC 6066. if (url.HostIsIPAddress()) { @@ -247,26 +302,20 @@ bool ParseCommandLineFlags(Params* params) { } } - if (line.HasSwitch("log")) { - params->log_settings.logging_dest = logging::LOG_DEFAULT; - params->log_path = line.GetSwitchValuePath("log"); + if (!cmdline.no_log) { if (!params->log_path.empty()) { params->log_settings.logging_dest = logging::LOG_TO_FILE; - } else if (params->log_settings.logging_dest == logging::LOG_TO_FILE) { - params->log_path = base::FilePath::FromUTF8Unsafe("naive.log"); + params->log_path = cmdline.log; + } else { + params->log_settings.logging_dest = logging::LOG_TO_SYSTEM_DEBUG_LOG; } params->log_settings.log_file = params->log_path.value().c_str(); } else { params->log_settings.logging_dest = logging::LOG_NONE; } - if (line.HasSwitch("log-net-log")) { - params->net_log_path = line.GetSwitchValuePath("log-net-log"); - } - - if (line.HasSwitch("ssl-key-log-file")) { - params->ssl_key_path = line.GetSwitchValuePath("ssl-key-log-file"); - } + params->net_log_path = cmdline.log_net_log; + params->ssl_key_path = cmdline.ssl_key_log_file; return true; } @@ -330,8 +379,22 @@ int main(int argc, char* argv[]) { base::CommandLine::Init(argc, argv); + CommandLine cmdline; Params params; - if (!ParseCommandLineFlags(¶ms)) { + const auto& proc = *base::CommandLine::ForCurrentProcess(); + const auto& args = proc.GetArgs(); + if (args.empty()) { + if (proc.argv().size() >= 2) { + GetCommandLine(proc, &cmdline); + } else { + auto path = base::FilePath::FromUTF8Unsafe("config.json"); + GetCommandLineFromConfig(path, &cmdline); + } + } else { + base::FilePath path(args[0]); + GetCommandLineFromConfig(path, &cmdline); + } + if (!ParseCommandLine(cmdline, ¶ms)) { return EXIT_FAILURE; }