Use a default hostname for IP addresses

Also set defaults for the listening address and port.
This commit is contained in:
klzgrad 2018-01-29 00:04:45 -05:00
parent 6a48db64fd
commit 27473896b4

View File

@ -24,6 +24,8 @@
#include "base/values.h" #include "base/values.h"
#include "build/build_config.h" #include "build/build_config.h"
#include "net/base/auth.h" #include "net/base/auth.h"
#include "net/dns/host_resolver.h"
#include "net/dns/mapped_host_resolver.h"
#include "net/http/http_auth.h" #include "net/http/http_auth.h"
#include "net/http/http_auth_cache.h" #include "net/http/http_auth_cache.h"
#include "net/http/http_network_session.h" #include "net/http/http_network_session.h"
@ -56,6 +58,19 @@ constexpr int kListenBackLog = 512;
constexpr int kDefaultMaxSocketsPerPool = 256; constexpr int kDefaultMaxSocketsPerPool = 256;
constexpr int kDefaultMaxSocketsPerGroup = 255; constexpr int kDefaultMaxSocketsPerGroup = 255;
constexpr int kExpectedMaxUsers = 8; constexpr int kExpectedMaxUsers = 8;
constexpr char kDefaultHostName[] = "example";
struct Params {
std::string listen_addr;
int listen_port;
std::string proxy_url;
std::string proxy_user;
std::string proxy_pass;
std::string host_resolver_rules;
logging::LoggingSettings log_settings;
base::FilePath net_log_path;
base::FilePath ssl_key_path;
};
std::unique_ptr<base::Value> GetConstants( std::unique_ptr<base::Value> GetConstants(
const base::CommandLine::StringType& command_line_string) { const base::CommandLine::StringType& command_line_string) {
@ -81,14 +96,12 @@ std::unique_ptr<base::Value> GetConstants(
// Builds a URLRequestContext assuming there's only a single loop. // Builds a URLRequestContext assuming there's only a single loop.
std::unique_ptr<net::URLRequestContext> BuildURLRequestContext( std::unique_ptr<net::URLRequestContext> BuildURLRequestContext(
const std::string& proxy_url, const Params& params,
const std::string& proxy_user,
const std::string& proxy_pass,
net::NetLog* net_log) { net::NetLog* net_log) {
net::URLRequestContextBuilder builder; net::URLRequestContextBuilder builder;
net::ProxyConfig proxy_config; net::ProxyConfig proxy_config;
proxy_config.proxy_rules().ParseFromString(proxy_url); proxy_config.proxy_rules().ParseFromString(params.proxy_url);
auto proxy_service = net::ProxyService::CreateWithoutProxyResolver( auto proxy_service = net::ProxyService::CreateWithoutProxyResolver(
std::make_unique<net::ProxyConfigServiceFixed>(proxy_config), net_log); std::make_unique<net::ProxyConfigServiceFixed>(proxy_config), net_log);
proxy_service->ForceReloadProxyConfig(); proxy_service->ForceReloadProxyConfig();
@ -97,14 +110,21 @@ std::unique_ptr<net::URLRequestContext> BuildURLRequestContext(
builder.DisableHttpCache(); builder.DisableHttpCache();
builder.set_net_log(net_log); builder.set_net_log(net_log);
if (!params.host_resolver_rules.empty()) {
auto remapped_resolver = std::make_unique<net::MappedHostResolver>(
net::HostResolver::CreateDefaultResolver(net_log));
remapped_resolver->SetRulesFromString(params.host_resolver_rules);
builder.set_host_resolver(std::move(remapped_resolver));
}
auto context = builder.Build(); auto context = builder.Build();
net::HttpNetworkSession* session = net::HttpNetworkSession* session =
context->http_transaction_factory()->GetSession(); context->http_transaction_factory()->GetSession();
net::HttpAuthCache* auth_cache = session->http_auth_cache(); net::HttpAuthCache* auth_cache = session->http_auth_cache();
GURL auth_origin(proxy_url); GURL auth_origin(params.proxy_url);
net::AuthCredentials credentials(base::ASCIIToUTF16(proxy_user), net::AuthCredentials credentials(base::ASCIIToUTF16(params.proxy_user),
base::ASCIIToUTF16(proxy_pass)); base::ASCIIToUTF16(params.proxy_pass));
auth_cache->Add(auth_origin, /*realm=*/std::string(), auth_cache->Add(auth_origin, /*realm=*/std::string(),
net::HttpAuth::AUTH_SCHEME_BASIC, /*challenge=*/"Basic", net::HttpAuth::AUTH_SCHEME_BASIC, /*challenge=*/"Basic",
credentials, /*path=*/"/"); credentials, /*path=*/"/");
@ -112,17 +132,6 @@ std::unique_ptr<net::URLRequestContext> BuildURLRequestContext(
return context; return context;
} }
struct Params {
std::string listen_addr;
int listen_port;
std::string proxy_url;
std::string proxy_user;
std::string proxy_pass;
logging::LoggingSettings log_settings;
base::FilePath net_log_path;
base::FilePath ssl_key_path;
};
bool ParseCommandLineFlags(Params* params) { bool ParseCommandLineFlags(Params* params) {
const base::CommandLine& line = *base::CommandLine::ForCurrentProcess(); const base::CommandLine& line = *base::CommandLine::ForCurrentProcess();
@ -130,11 +139,11 @@ bool ParseCommandLineFlags(Params* params) {
LOG(INFO) << "Usage: naive_client [options]\n" LOG(INFO) << "Usage: naive_client [options]\n"
"\n" "\n"
"Options:\n" "Options:\n"
"-h, --help Show this help message and exit\n" "-h, --help Show this message\n"
"--addr=<address> Address to listen on\n" "--addr=<address> Address to listen on (0.0.0.0)\n"
"--port=<port> Port to listen on\n" "--port=<port> Port to listen on (1080)\n"
"--proxy=https://<user>:<pass>@<domain>[:port]\n" "--proxy=https://<user>:<pass>@<hostname>[:port]\n"
" Proxy specification\n" " Proxy specification.\n"
"--log Log to stderr, otherwise no log\n" "--log Log to stderr, otherwise no log\n"
"--log-net-log=<path> Save NetLog\n" "--log-net-log=<path> Save NetLog\n"
"--ssl-key-log-file=<path> Save SSL keys for Wireshark\n"; "--ssl-key-log-file=<path> Save SSL keys for Wireshark\n";
@ -142,20 +151,17 @@ bool ParseCommandLineFlags(Params* params) {
return false; return false;
} }
if (!line.HasSwitch("addr")) { params->listen_addr = "0.0.0.0";
LOG(ERROR) << "Missing --addr"; if (line.HasSwitch("addr")) {
return false;
}
params->listen_addr = line.GetSwitchValueASCII("addr"); params->listen_addr = line.GetSwitchValueASCII("addr");
}
if (params->listen_addr.empty()) { if (params->listen_addr.empty()) {
LOG(ERROR) << "Invalid --port"; LOG(ERROR) << "Invalid --addr";
return false; return false;
} }
if (!line.HasSwitch("port")) { params->listen_port = 1080;
LOG(ERROR) << "Missing --port"; if (line.HasSwitch("port")) {
return false;
}
if (!base::StringToInt(line.GetSwitchValueASCII("port"), if (!base::StringToInt(line.GetSwitchValueASCII("port"),
&params->listen_port)) { &params->listen_port)) {
LOG(ERROR) << "Invalid --port"; LOG(ERROR) << "Invalid --port";
@ -166,6 +172,7 @@ bool ParseCommandLineFlags(Params* params) {
LOG(ERROR) << "Invalid --port"; LOG(ERROR) << "Invalid --port";
return false; return false;
} }
}
if (!line.HasSwitch("proxy")) { if (!line.HasSwitch("proxy")) {
LOG(ERROR) << "Missing --proxy"; LOG(ERROR) << "Missing --proxy";
@ -188,6 +195,23 @@ bool ParseCommandLineFlags(Params* params) {
params->proxy_user = url.username(); params->proxy_user = url.username();
params->proxy_pass = url.password(); params->proxy_pass = url.password();
if (line.HasSwitch("host-resolver-rules")) {
params->host_resolver_rules =
line.GetSwitchValueASCII("host-resolver-rules");
} else {
// SNI should only contain DNS hostnames not IP addresses per RFC 6066.
if (url.HostIsIPAddress()) {
GURL::Replacements replacements;
replacements.SetHostStr(kDefaultHostName);
params->proxy_url =
url::SchemeHostPort(url.ReplaceComponents(replacements)).Serialize();
LOG(INFO) << "Using '" << kDefaultHostName << "' as the hostname for "
<< url.host();
params->host_resolver_rules =
std::string("MAP ") + kDefaultHostName + " " + url.host();
}
}
if (line.HasSwitch("log")) { if (line.HasSwitch("log")) {
params->log_settings.logging_dest = logging::LOG_TO_SYSTEM_DEBUG_LOG; params->log_settings.logging_dest = logging::LOG_TO_SYSTEM_DEBUG_LOG;
} else { } else {
@ -302,8 +326,7 @@ int main(int argc, char* argv[]) {
net_log.AddObserver(&printing_log_observer, net_log.AddObserver(&printing_log_observer,
net::NetLogCaptureMode::Default()); net::NetLogCaptureMode::Default());
auto context = BuildURLRequestContext(params.proxy_url, params.proxy_user, auto context = BuildURLRequestContext(params, &net_log);
params.proxy_pass, &net_log);
auto server_socket = auto server_socket =
std::make_unique<net::TCPServerSocket>(&net_log, net::NetLogSource()); std::make_unique<net::TCPServerSocket>(&net_log, net::NetLogSource());