Support loading config.json

This commit is contained in:
klzgrad 2019-01-10 00:30:56 -05:00
parent 0387b45f0d
commit a070cc7ad5

View File

@ -12,6 +12,7 @@
#include "base/at_exit.h" #include "base/at_exit.h"
#include "base/command_line.h" #include "base/command_line.h"
#include "base/files/file_path.h" #include "base/files/file_path.h"
#include "base/json/json_file_value_serializer.h"
#include "base/json/json_writer.h" #include "base/json/json_writer.h"
#include "base/logging.h" #include "base/logging.h"
#include "base/macros.h" #include "base/macros.h"
@ -68,6 +69,17 @@ constexpr char kDefaultHostName[] = "example";
constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation = constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation =
net::DefineNetworkTrafficAnnotation("naive", ""); net::DefineNetworkTrafficAnnotation("naive", "");
struct CommandLine {
std::string listen;
std::string proxy;
bool padding;
std::string host_resolver_rules;
bool no_log;
base::FilePath log;
base::FilePath log_net_log;
base::FilePath ssl_key_log_file;
};
struct Params { struct Params {
net::NaiveConnection::Protocol protocol; net::NaiveConnection::Protocol protocol;
std::string listen_addr; std::string listen_addr;
@ -116,6 +128,7 @@ std::unique_ptr<net::URLRequestContext> BuildURLRequestContext(
net::ProxyConfig proxy_config; net::ProxyConfig proxy_config;
proxy_config.proxy_rules().ParseFromString(params.proxy_url); proxy_config.proxy_rules().ParseFromString(params.proxy_url);
LOG(INFO) << "Proxying via " << params.proxy_url;
auto proxy_service = net::ProxyResolutionService::CreateWithoutProxyResolver( auto proxy_service = net::ProxyResolutionService::CreateWithoutProxyResolver(
std::make_unique<net::ProxyConfigServiceFixed>( std::make_unique<net::ProxyConfigServiceFixed>(
net::ProxyConfigWithAnnotation(proxy_config, kTrafficAnnotation)), net::ProxyConfigWithAnnotation(proxy_config, kTrafficAnnotation)),
@ -148,11 +161,9 @@ std::unique_ptr<net::URLRequestContext> BuildURLRequestContext(
return context; return context;
} }
bool ParseCommandLineFlags(Params* params) { void GetCommandLine(const base::CommandLine& proc, CommandLine* cmdline) {
const base::CommandLine& line = *base::CommandLine::ForCurrentProcess(); if (proc.HasSwitch("h") || proc.HasSwitch("help")) {
std::cout << "Usage: naive { OPTIONS | config.json }\n"
if (line.HasSwitch("h") || line.HasSwitch("help")) {
std::cout << "Usage: naive [options]\n"
"\n" "\n"
"Options:\n" "Options:\n"
"-h, --help Show this message\n" "-h, --help Show this message\n"
@ -162,26 +173,90 @@ bool ParseCommandLineFlags(Params* params) {
"--proxy=<proto>://[<user>:<pass>@]<hostname>[:<port>]\n" "--proxy=<proto>://[<user>:<pass>@]<hostname>[:<port>]\n"
" proto: https, quic\n" " proto: https, quic\n"
"--padding Use padding\n" "--padding Use padding\n"
"--host-resolver-rules=... Resolver rules\n"
"--log[=<path>] Log to stderr, or file\n" "--log[=<path>] Log to stderr, or file\n"
"--log-net-log=<path> Save NetLog\n" "--log-net-log=<path> Save NetLog\n"
"--ssl-key-log-file=<path> Save SSL keys for Wireshark\n" "--ssl-key-log-file=<path> Save SSL keys for Wireshark\n"
<< std::endl; << std::endl;
exit(EXIT_SUCCESS); exit(EXIT_SUCCESS);
return false;
} }
if (line.HasSwitch("version")) { if (proc.HasSwitch("version")) {
std::cout << "Version: " << version_info::GetVersionNumber() << std::endl; std::cout << "naive " << version_info::GetVersionNumber() << std::endl;
exit(EXIT_SUCCESS); exit(EXIT_SUCCESS);
return false;
} }
cmdline->listen = proc.GetSwitchValueASCII("listen");
cmdline->proxy = proc.GetSwitchValueASCII("proxy");
cmdline->padding = proc.HasSwitch("padding");
cmdline->host_resolver_rules =
proc.GetSwitchValueASCII("host-resolver-rules");
cmdline->no_log = !proc.HasSwitch("log");
cmdline->log = proc.GetSwitchValuePath("log");
cmdline->log_net_log = proc.GetSwitchValuePath("log-net-log");
cmdline->ssl_key_log_file = proc.GetSwitchValuePath("ssl-key-log-file");
}
void GetCommandLineFromConfig(const base::FilePath& config_path,
CommandLine* cmdline) {
JSONFileValueDeserializer reader(config_path);
int error_code;
std::string error_message;
auto value = reader.Deserialize(&error_code, &error_message);
if (value == nullptr) {
std::cerr << "Error reading " << config_path << ": (" << error_code << ") "
<< error_message << std::endl;
exit(EXIT_FAILURE);
}
if (!value->is_dict()) {
std::cerr << "Invalid config format" << std::endl;
exit(EXIT_FAILURE);
}
if (value->FindKeyOfType("listen", base::Value::Type::STRING)) {
cmdline->listen = value->FindKey("listen")->GetString();
}
if (value->FindKeyOfType("proxy", base::Value::Type::STRING)) {
cmdline->proxy = value->FindKey("proxy")->GetString();
}
cmdline->padding = false;
if (value->FindKeyOfType("padding", base::Value::Type::BOOLEAN)) {
cmdline->padding = value->FindKey("padding")->GetBool();
}
if (value->FindKeyOfType("host-resolver-rules", base::Value::Type::STRING)) {
cmdline->host_resolver_rules =
value->FindKey("host-resolver-rules")->GetString();
}
cmdline->no_log = true;
if (value->FindKeyOfType("log", base::Value::Type::STRING)) {
cmdline->no_log = false;
cmdline->log =
base::FilePath::FromUTF8Unsafe(value->FindKey("log")->GetString());
}
if (value->FindKeyOfType("log-net-log", base::Value::Type::STRING)) {
cmdline->log_net_log = base::FilePath::FromUTF8Unsafe(
value->FindKey("log-net-log")->GetString());
}
if (value->FindKeyOfType("ssl-key-log-file", base::Value::Type::STRING)) {
cmdline->ssl_key_log_file = base::FilePath::FromUTF8Unsafe(
value->FindKey("ssl-key-log-file")->GetString());
}
}
std::string GetProxyFromURL(const GURL& url) {
std::string str = url.GetWithEmptyPath().spec();
if (str.size() && str.back() == '/') {
str.pop_back();
}
return str;
}
bool ParseCommandLine(const CommandLine& cmdline, Params* params) {
params->protocol = net::NaiveConnection::kSocks5; params->protocol = net::NaiveConnection::kSocks5;
params->listen_addr = "0.0.0.0"; params->listen_addr = "0.0.0.0";
params->listen_port = 1080; params->listen_port = 1080;
url::AddStandardScheme("socks", url::SCHEME_WITH_HOST_AND_PORT); url::AddStandardScheme("socks", url::SCHEME_WITH_HOST_AND_PORT);
if (line.HasSwitch("listen")) { if (!cmdline.listen.empty()) {
GURL url(line.GetSwitchValueASCII("listen")); GURL url(cmdline.listen);
if (url.scheme() == "socks") { if (url.scheme() == "socks") {
params->protocol = net::NaiveConnection::kSocks5; params->protocol = net::NaiveConnection::kSocks5;
params->listen_port = 1080; params->listen_port = 1080;
@ -189,7 +264,7 @@ bool ParseCommandLineFlags(Params* params) {
params->protocol = net::NaiveConnection::kHttp; params->protocol = net::NaiveConnection::kHttp;
params->listen_port = 8080; params->listen_port = 8080;
} else { } else {
LOG(ERROR) << "Invalid scheme in --listen"; std::cerr << "Invalid scheme in --listen" << std::endl;
return false; return false;
} }
if (!url.host().empty()) { if (!url.host().empty()) {
@ -197,12 +272,12 @@ bool ParseCommandLineFlags(Params* params) {
} }
if (!url.port().empty()) { if (!url.port().empty()) {
if (!base::StringToInt(url.port(), &params->listen_port)) { if (!base::StringToInt(url.port(), &params->listen_port)) {
LOG(ERROR) << "Invalid port in --listen"; std::cerr << "Invalid port in --listen" << std::endl;
return false; return false;
} }
if (params->listen_port <= 0 || if (params->listen_port <= 0 ||
params->listen_port > std::numeric_limits<uint16_t>::max()) { params->listen_port > std::numeric_limits<uint16_t>::max()) {
LOG(ERROR) << "Invalid port in --listen"; std::cerr << "Invalid port in --listen" << std::endl;
return false; return false;
} }
} }
@ -211,36 +286,32 @@ bool ParseCommandLineFlags(Params* params) {
url::AddStandardScheme("quic", url::AddStandardScheme("quic",
url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION); url::SCHEME_WITH_HOST_PORT_AND_USER_INFORMATION);
params->proxy_url = "direct://"; params->proxy_url = "direct://";
GURL url(line.GetSwitchValueASCII("proxy")); GURL url(cmdline.proxy);
if (line.HasSwitch("proxy")) { GURL::Replacements remove_auth;
remove_auth.ClearUsername();
remove_auth.ClearPassword();
GURL url_no_auth = url.ReplaceComponents(remove_auth);
if (!cmdline.proxy.empty()) {
if (!url.is_valid()) { if (!url.is_valid()) {
LOG(ERROR) << "Invalid proxy URL"; std::cerr << "Invalid proxy URL" << std::endl;
return false; return false;
} }
if (url.scheme() != "https" && url.scheme() != "quic") { params->proxy_url = GetProxyFromURL(url_no_auth);
LOG(ERROR) << "Must be HTTPS or QUIC proxy";
return false;
}
params->proxy_url = url::SchemeHostPort(url).Serialize();
params->proxy_user = url.username(); params->proxy_user = url.username();
params->proxy_pass = url.password(); params->proxy_pass = url.password();
} }
params->use_padding = false; params->use_padding = cmdline.padding;
if (line.HasSwitch("padding")) {
params->use_padding = true;
}
if (line.HasSwitch("host-resolver-rules")) { if (!cmdline.host_resolver_rules.empty()) {
params->host_resolver_rules = params->host_resolver_rules = cmdline.host_resolver_rules;
line.GetSwitchValueASCII("host-resolver-rules");
} else { } else {
// SNI should only contain DNS hostnames not IP addresses per RFC 6066. // SNI should only contain DNS hostnames not IP addresses per RFC 6066.
if (url.HostIsIPAddress()) { if (url.HostIsIPAddress()) {
GURL::Replacements replacements; GURL::Replacements set_host;
replacements.SetHostStr(kDefaultHostName); set_host.SetHostStr(kDefaultHostName);
params->proxy_url = params->proxy_url =
url::SchemeHostPort(url.ReplaceComponents(replacements)).Serialize(); GetProxyFromURL(url_no_auth.ReplaceComponents(set_host));
LOG(INFO) << "Using '" << kDefaultHostName << "' as the hostname for " LOG(INFO) << "Using '" << kDefaultHostName << "' as the hostname for "
<< url.host(); << url.host();
params->host_resolver_rules = params->host_resolver_rules =
@ -248,26 +319,20 @@ bool ParseCommandLineFlags(Params* params) {
} }
} }
if (line.HasSwitch("log")) { if (!cmdline.no_log) {
params->log_settings.logging_dest = logging::LOG_DEFAULT;
params->log_path = line.GetSwitchValuePath("log");
if (!params->log_path.empty()) { if (!params->log_path.empty()) {
params->log_settings.logging_dest = logging::LOG_TO_FILE; params->log_settings.logging_dest = logging::LOG_TO_FILE;
} else if (params->log_settings.logging_dest == logging::LOG_TO_FILE) { params->log_path = cmdline.log;
params->log_path = base::FilePath::FromUTF8Unsafe("naive.log"); } else {
params->log_settings.logging_dest = logging::LOG_TO_STDERR;
} }
params->log_settings.log_file = params->log_path.value().c_str(); params->log_settings.log_file = params->log_path.value().c_str();
} else { } else {
params->log_settings.logging_dest = logging::LOG_NONE; params->log_settings.logging_dest = logging::LOG_NONE;
} }
if (line.HasSwitch("log-net-log")) { params->net_log_path = cmdline.log_net_log;
params->net_log_path = line.GetSwitchValuePath("log-net-log"); params->ssl_key_path = cmdline.ssl_key_log_file;
}
if (line.HasSwitch("ssl-key-log-file")) {
params->ssl_key_path = line.GetSwitchValuePath("ssl-key-log-file");
}
return true; return true;
} }
@ -329,8 +394,22 @@ int main(int argc, char* argv[]) {
base::CommandLine::Init(argc, argv); base::CommandLine::Init(argc, argv);
CommandLine cmdline;
Params params; Params params;
if (!ParseCommandLineFlags(&params)) { const auto& proc = *base::CommandLine::ForCurrentProcess();
const auto& args = proc.GetArgs();
if (args.empty()) {
if (proc.argv().size() >= 2) {
GetCommandLine(proc, &cmdline);
} else {
auto path = base::FilePath::FromUTF8Unsafe("config.json");
GetCommandLineFromConfig(path, &cmdline);
}
} else {
base::FilePath path(args[0]);
GetCommandLineFromConfig(path, &cmdline);
}
if (!ParseCommandLine(cmdline, &params)) {
return EXIT_FAILURE; return EXIT_FAILURE;
} }
@ -380,6 +459,8 @@ int main(int argc, char* argv[]) {
LOG(ERROR) << "Failed to listen: " << result; LOG(ERROR) << "Failed to listen: " << result;
return EXIT_FAILURE; return EXIT_FAILURE;
} }
LOG(INFO) << "Listening on " << params.listen_addr << ":"
<< params.listen_port;
net::NaiveProxy naive_proxy(std::move(listen_socket), params.protocol, net::NaiveProxy naive_proxy(std::move(listen_socket), params.protocol,
params.use_padding, session, kTrafficAnnotation); params.use_padding, session, kTrafficAnnotation);