diff --git a/src/net/BUILD.gn b/src/net/BUILD.gn index e5f9e9c550..e79db45cca 100644 --- a/src/net/BUILD.gn +++ b/src/net/BUILD.gn @@ -1721,6 +1721,8 @@ executable("naive") { "tools/naive/naive_proxy_bin.cc", "tools/naive/http_proxy_socket.cc", "tools/naive/http_proxy_socket.h", + "tools/naive/redirect_resolver.h", + "tools/naive/redirect_resolver.cc", "tools/naive/socks5_server_socket.cc", "tools/naive/socks5_server_socket.h", ] diff --git a/src/net/tools/naive/naive_connection.cc b/src/net/tools/naive/naive_connection.cc index ac03ea5b43..339990beb0 100644 --- a/src/net/tools/naive/naive_connection.cc +++ b/src/net/tools/naive/naive_connection.cc @@ -24,6 +24,7 @@ #include "net/socket/stream_socket.h" #include "net/spdy/spdy_session.h" #include "net/tools/naive/http_proxy_socket.h" +#include "net/tools/naive/redirect_resolver.h" #include "net/tools/naive/socks5_server_socket.h" #if defined(OS_LINUX) @@ -52,6 +53,7 @@ NaiveConnection::NaiveConnection( const ProxyInfo& proxy_info, const SSLConfig& server_ssl_config, const SSLConfig& proxy_ssl_config, + RedirectResolver* resolver, HttpNetworkSession* session, const NetLogWithSource& net_log, std::unique_ptr accepted_socket, @@ -62,6 +64,7 @@ NaiveConnection::NaiveConnection( proxy_info_(proxy_info), server_ssl_config_(server_ssl_config), proxy_ssl_config_(proxy_ssl_config), + resolver_(resolver), session_(session), net_log_(net_log), next_state_(STATE_NONE), @@ -208,7 +211,17 @@ int NaiveConnection::DoConnectServer() { if (rv == 0) { IPEndPoint ipe; if (ipe.FromSockAddr(dst.addr, dst.addr_len)) { - origin = HostPortPair::FromIPEndPoint(ipe); + const auto& addr = ipe.address(); + auto name = resolver_->FindNameByAddress(addr); + if (!name.empty()) { + origin = HostPortPair(name, ipe.port()); + } else if (!resolver_->IsInResolvedRange(addr)) { + origin = HostPortPair::FromIPEndPoint(ipe); + } else { + LOG(ERROR) << "Connection " << id_ << " to unresolved name for " + << addr.ToString(); + return ERR_ADDRESS_INVALID; + } } } #else diff --git a/src/net/tools/naive/naive_connection.h b/src/net/tools/naive/naive_connection.h index 5075a331d3..ef9c7f8f11 100644 --- a/src/net/tools/naive/naive_connection.h +++ b/src/net/tools/naive/naive_connection.h @@ -27,6 +27,7 @@ class ProxyInfo; class StreamSocket; struct NetworkTrafficAnnotationTag; struct SSLConfig; +class RedirectResolver; class NaiveConnection { public: @@ -52,6 +53,7 @@ class NaiveConnection { const ProxyInfo& proxy_info, const SSLConfig& server_ssl_config, const SSLConfig& proxy_ssl_config, + RedirectResolver* resolver, HttpNetworkSession* session, const NetLogWithSource& net_log, std::unique_ptr accepted_socket, @@ -103,6 +105,7 @@ class NaiveConnection { const ProxyInfo& proxy_info_; const SSLConfig& server_ssl_config_; const SSLConfig& proxy_ssl_config_; + RedirectResolver* resolver_; HttpNetworkSession* session_; const NetLogWithSource& net_log_; diff --git a/src/net/tools/naive/naive_proxy.cc b/src/net/tools/naive/naive_proxy.cc index ea1e8d7064..aeef545e0d 100644 --- a/src/net/tools/naive/naive_proxy.cc +++ b/src/net/tools/naive/naive_proxy.cc @@ -28,11 +28,13 @@ namespace net { NaiveProxy::NaiveProxy(std::unique_ptr listen_socket, NaiveConnection::Protocol protocol, bool use_padding, + RedirectResolver* resolver, HttpNetworkSession* session, const NetworkTrafficAnnotationTag& traffic_annotation) : listen_socket_(std::move(listen_socket)), protocol_(protocol), use_padding_(use_padding), + resolver_(resolver), session_(session), net_log_( NetLogWithSource::Make(session->net_log(), NetLogSourceType::NONE)), @@ -110,7 +112,7 @@ void NaiveProxy::DoConnect() { } auto connection_ptr = std::make_unique( ++last_id_, protocol_, pad_direction, proxy_info_, server_ssl_config_, - proxy_ssl_config_, session_, net_log_, std::move(socket), + proxy_ssl_config_, resolver_, session_, net_log_, std::move(socket), traffic_annotation_); auto* connection = connection_ptr.get(); connection_by_id_[connection->id()] = std::move(connection_ptr); diff --git a/src/net/tools/naive/naive_proxy.h b/src/net/tools/naive/naive_proxy.h index 270106ae42..fbae124249 100644 --- a/src/net/tools/naive/naive_proxy.h +++ b/src/net/tools/naive/naive_proxy.h @@ -25,12 +25,14 @@ class NaiveConnection; class ServerSocket; class StreamSocket; struct NetworkTrafficAnnotationTag; +class RedirectResolver; class NaiveProxy { public: NaiveProxy(std::unique_ptr server_socket, NaiveConnection::Protocol protocol, bool use_padding, + RedirectResolver* resolver, HttpNetworkSession* session, const NetworkTrafficAnnotationTag& traffic_annotation); ~NaiveProxy(); @@ -58,6 +60,7 @@ class NaiveProxy { ProxyInfo proxy_info_; SSLConfig server_ssl_config_; SSLConfig proxy_ssl_config_; + RedirectResolver* resolver_; HttpNetworkSession* session_; NetLogWithSource net_log_; diff --git a/src/net/tools/naive/naive_proxy_bin.cc b/src/net/tools/naive/naive_proxy_bin.cc index 54c4e91645..cdcdbe30f0 100644 --- a/src/net/tools/naive/naive_proxy_bin.cc +++ b/src/net/tools/naive/naive_proxy_bin.cc @@ -48,9 +48,11 @@ #include "net/socket/client_socket_pool_manager.h" #include "net/socket/ssl_client_socket.h" #include "net/socket/tcp_server_socket.h" +#include "net/socket/udp_server_socket.h" #include "net/ssl/ssl_key_logger_impl.h" #include "net/third_party/quiche/src/quic/core/quic_versions.h" #include "net/tools/naive/naive_proxy.h" +#include "net/tools/naive/redirect_resolver.h" #include "net/traffic_annotation/network_traffic_annotation.h" #include "net/url_request/url_request_context.h" #include "net/url_request/url_request_context_builder.h" @@ -76,6 +78,7 @@ struct CommandLine { std::string proxy; bool padding; std::string host_resolver_rules; + std::string resolver_range; bool no_log; base::FilePath log; base::FilePath log_net_log; @@ -91,6 +94,8 @@ struct Params { std::u16string proxy_user; std::u16string proxy_pass; std::string host_resolver_rules; + net::IPAddress resolver_range; + size_t resolver_prefix; logging::LoggingSettings log_settings; base::FilePath net_log_path; base::FilePath ssl_key_path; @@ -171,6 +176,7 @@ void GetCommandLine(const base::CommandLine& proc, CommandLine* cmdline) { " proto: https, quic\n" "--padding Use padding\n" "--host-resolver-rules=... Resolver rules\n" + "--resolver-range=... Redirect resolver range\n" "--log[=] Log to stderr, or file\n" "--log-net-log= Save NetLog\n" "--ssl-key-log-file= Save SSL keys for Wireshark\n" @@ -188,6 +194,7 @@ void GetCommandLine(const base::CommandLine& proc, CommandLine* cmdline) { cmdline->padding = proc.HasSwitch("padding"); cmdline->host_resolver_rules = proc.GetSwitchValueASCII("host-resolver-rules"); + cmdline->resolver_range = proc.GetSwitchValueASCII("resolver-range"); cmdline->no_log = !proc.HasSwitch("log"); cmdline->log = proc.GetSwitchValuePath("log"); cmdline->log_net_log = proc.GetSwitchValuePath("log-net-log"); @@ -223,6 +230,9 @@ void GetCommandLineFromConfig(const base::FilePath& config_path, cmdline->host_resolver_rules = value->FindKey("host-resolver-rules")->GetString(); } + if (value->FindKeyOfType("resolver-range", base::Value::Type::STRING)) { + cmdline->resolver_range = value->FindKey("resolver-range")->GetString(); + } cmdline->no_log = true; if (value->FindKeyOfType("log", base::Value::Type::STRING)) { cmdline->no_log = false; @@ -262,8 +272,13 @@ bool ParseCommandLine(const CommandLine& cmdline, Params* params) { params->protocol = net::NaiveConnection::kHttp; params->listen_port = 8080; } else if (url.scheme() == "redir") { +#if defined(OS_LINUX) params->protocol = net::NaiveConnection::kRedir; params->listen_port = 1080; +#else + std::cerr << "Redir protocol only supports Linux." << std::endl; + return false; +#endif } else { std::cerr << "Invalid scheme in --listen" << std::endl; return false; @@ -305,6 +320,22 @@ bool ParseCommandLine(const CommandLine& cmdline, Params* params) { params->host_resolver_rules = cmdline.host_resolver_rules; + if (params->protocol == net::NaiveConnection::kRedir) { + std::string range = "100.64.0.0/10"; + if (!cmdline.resolver_range.empty()) + range = cmdline.resolver_range; + + if (!net::ParseCIDRBlock(range, ¶ms->resolver_range, + ¶ms->resolver_prefix)) { + std::cerr << "Invalid resolver range" << std::endl; + return false; + } + if (params->resolver_range.IsIPv6()) { + std::cerr << "IPv6 resolver range not supported" << std::endl; + return false; + } + } + if (!cmdline.no_log) { if (!cmdline.log.empty()) { params->log_settings.logging_dest = logging::LOG_TO_FILE; @@ -450,8 +481,32 @@ int main(int argc, char* argv[]) { LOG(INFO) << "Listening on " << params.listen_addr << ":" << params.listen_port; + std::unique_ptr resolver; + if (params.protocol == net::NaiveConnection::kRedir) { + auto resolver_socket = + std::make_unique(net_log, net::NetLogSource()); + resolver_socket->AllowAddressReuse(); + net::IPAddress listen_addr; + if (!listen_addr.AssignFromIPLiteral(params.listen_addr)) { + LOG(ERROR) << "Failed to open resolver: " << net::ERR_ADDRESS_INVALID; + return EXIT_FAILURE; + } + + result = resolver_socket->Listen( + net::IPEndPoint(listen_addr, params.listen_port)); + if (result != net::OK) { + LOG(ERROR) << "Failed to open resolver: " << result; + return EXIT_FAILURE; + } + + resolver = std::make_unique( + std::move(resolver_socket), params.resolver_range, + params.resolver_prefix); + } + net::NaiveProxy naive_proxy(std::move(listen_socket), params.protocol, - params.use_padding, session, kTrafficAnnotation); + params.use_padding, resolver.get(), session, + kTrafficAnnotation); base::RunLoop().Run(); diff --git a/src/net/tools/naive/redirect_resolver.cc b/src/net/tools/naive/redirect_resolver.cc new file mode 100644 index 0000000000..363e4c2661 --- /dev/null +++ b/src/net/tools/naive/redirect_resolver.cc @@ -0,0 +1,246 @@ +// Copyright 2019 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/naive/redirect_resolver.h" + +#include +#include +#include + +#include "base/logging.h" +#include "base/threading/thread_task_runner_handle.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/dns/dns_query.h" +#include "net/dns/dns_response.h" +#include "net/dns/dns_util.h" +#include "net/socket/datagram_server_socket.h" +#include "third_party/abseil-cpp/absl/types/optional.h" + +namespace { +constexpr int kUdpReadBufferSize = 1024; +constexpr int kResolutionTtl = 60; +constexpr int kResolutionRecycleTime = 60 * 5; + +std::string PackedIPv4ToString(uint32_t addr) { + return net::IPAddress(addr >> 24, addr >> 16, addr >> 8, addr).ToString(); +} +} // namespace + +namespace net { + +Resolution::Resolution() = default; + +Resolution::~Resolution() = default; + +RedirectResolver::RedirectResolver(std::unique_ptr socket, + const IPAddress& range, + size_t prefix) + : socket_(std::move(socket)), + range_(range), + prefix_(prefix), + offset_(0), + buffer_(base::MakeRefCounted(kUdpReadBufferSize)) { + DCHECK(socket_); + // Start accepting connections in next run loop in case when delegate is not + // ready to get callbacks. + base::ThreadTaskRunnerHandle::Get()->PostTask( + FROM_HERE, base::BindOnce(&RedirectResolver::DoRead, + weak_ptr_factory_.GetWeakPtr())); +} + +RedirectResolver::~RedirectResolver() = default; + +void RedirectResolver::DoRead() { + for (;;) { + int rv = socket_->RecvFrom( + buffer_.get(), kUdpReadBufferSize, &recv_address_, + base::BindOnce(&RedirectResolver::OnRecv, base::Unretained(this))); + if (rv == ERR_IO_PENDING) + return; + rv = HandleReadResult(rv); + if (rv == ERR_IO_PENDING) + return; + if (rv < 0) { + LOG(INFO) << "DoRead: ignoring error " << rv; + } + } +} + +void RedirectResolver::OnRecv(int result) { + int rv; + rv = HandleReadResult(result); + if (rv == ERR_IO_PENDING) + return; + if (rv < 0) { + LOG(INFO) << "OnRecv: ignoring error " << result; + } + + DoRead(); +} + +void RedirectResolver::OnSend(int result) { + if (result < 0) { + LOG(INFO) << "OnSend: ignoring error " << result; + } + + DoRead(); +} + +int RedirectResolver::HandleReadResult(int result) { + if (result < 0) + return result; + + DnsQuery query(buffer_.get()); + if (!query.Parse(result)) { + LOG(INFO) << "Malformed DNS query from " << recv_address_.ToString(); + return ERR_INVALID_ARGUMENT; + } + + int size; + if (query.qtype() == dns_protocol::kTypeA) { + Resolution res; + + auto name_or = DnsDomainToString(query.qname()); + if (!name_or) { + LOG(INFO) << "Malformed DNS query from " << recv_address_.ToString(); + return ERR_INVALID_ARGUMENT; + } + const auto& name = name_or.value(); + + auto by_name_lookup = resolution_by_name_.emplace(name, resolutions_.end()); + auto by_name = by_name_lookup.first; + bool has_name = !by_name_lookup.second; + if (has_name) { + auto res_it = by_name->second; + auto by_addr = res_it->by_addr; + uint32_t addr = res_it->addr; + + resolutions_.erase(res_it); + resolutions_.emplace_back(); + res_it = std::prev(resolutions_.end()); + + by_name->second = res_it; + by_addr->second = res_it; + res_it->addr = addr; + res_it->name = name; + res_it->time = base::TimeTicks::Now(); + res_it->by_name = by_name; + res_it->by_addr = by_addr; + } else { + uint32_t addr = (range_.bytes()[0] << 24) | (range_.bytes()[1] << 16) | + (range_.bytes()[2] << 8) | range_.bytes()[3]; + uint32_t subnet = ~0U >> prefix_; + addr &= ~subnet; + addr += offset_; + offset_ = (offset_ + 1) & subnet; + + auto by_addr_lookup = + resolution_by_addr_.emplace(addr, resolutions_.end()); + auto by_addr = by_addr_lookup.first; + bool has_addr = !by_addr_lookup.second; + if (has_addr) { + // Too few available addresses. Overwrites old one. + auto res_it = by_addr->second; + + LOG(INFO) << "Overwrite " << res_it->name << " " + << PackedIPv4ToString(res_it->addr) << " with " << name << " " + << PackedIPv4ToString(addr); + resolution_by_name_.erase(res_it->by_name); + resolutions_.erase(res_it); + resolutions_.emplace_back(); + res_it = std::prev(resolutions_.end()); + + by_name->second = res_it; + by_addr->second = res_it; + res_it->addr = addr; + res_it->name = name; + res_it->time = base::TimeTicks::Now(); + res_it->by_name = by_name; + res_it->by_addr = by_addr; + } else { + LOG(INFO) << "Add " << name << " " << PackedIPv4ToString(addr); + resolutions_.emplace_back(); + auto res_it = std::prev(resolutions_.end()); + + by_name->second = res_it; + by_addr->second = res_it; + res_it->addr = addr; + res_it->name = name; + res_it->time = base::TimeTicks::Now(); + res_it->by_name = by_name; + res_it->by_addr = by_addr; + + // Collects garbage. + auto now = base::TimeTicks::Now(); + for (auto it = resolutions_.begin(); + it != resolutions_.end() && + (now - it->time).InSeconds() > kResolutionRecycleTime;) { + auto next = std::next(it); + LOG(INFO) << "Drop " << it->name << " " + << PackedIPv4ToString(it->addr); + resolution_by_name_.erase(it->by_name); + resolution_by_addr_.erase(it->by_addr); + resolutions_.erase(it); + it = next; + } + } + } + + DnsResourceRecord record; + record.name = name; + record.type = dns_protocol::kTypeA; + record.klass = dns_protocol::kClassIN; + record.ttl = kResolutionTtl; + uint32_t addr = by_name->second->addr; + record.SetOwnedRdata(IPAddressToPackedString( + IPAddress(addr >> 24, addr >> 16, addr >> 8, addr))); + absl::optional query_opt; + query_opt.emplace(query.id(), query.qname(), query.qtype()); + DnsResponse response(query.id(), /*is_authoritative=*/false, + /*answers=*/{std::move(record)}, + /*authority_records=*/{}, /*additional_records=*/{}, + query_opt); + size = response.io_buffer_size(); + if (size > buffer_->size() || !response.io_buffer()) { + return ERR_NO_BUFFER_SPACE; + } + std::memcpy(buffer_->data(), response.io_buffer()->data(), size); + } else { + absl::optional query_opt; + query_opt.emplace(query.id(), query.qname(), query.qtype()); + DnsResponse response(query.id(), /*is_authoritative=*/false, /*answers=*/{}, + /*authority_records=*/{}, /*additional_records=*/{}, + query_opt, dns_protocol::kRcodeSERVFAIL); + size = response.io_buffer_size(); + if (size > buffer_->size() || !response.io_buffer()) { + return ERR_NO_BUFFER_SPACE; + } + std::memcpy(buffer_->data(), response.io_buffer()->data(), size); + } + + return socket_->SendTo( + buffer_.get(), size, recv_address_, + base::BindOnce(&RedirectResolver::OnSend, base::Unretained(this))); +} + +bool RedirectResolver::IsInResolvedRange(const IPAddress& address) const { + if (!address.IsIPv4()) + return false; + return IPAddressMatchesPrefix(address, range_, prefix_); +} + +std::string RedirectResolver::FindNameByAddress( + const IPAddress& address) const { + if (!address.IsIPv4()) + return {}; + uint32_t addr = (address.bytes()[0] << 24) | (address.bytes()[1] << 16) | + (address.bytes()[2] << 8) | address.bytes()[3]; + auto by_addr = resolution_by_addr_.find(addr); + if (by_addr == resolution_by_addr_.end()) + return {}; + return by_addr->second->name; +} + +} // namespace net diff --git a/src/net/tools/naive/redirect_resolver.h b/src/net/tools/naive/redirect_resolver.h new file mode 100644 index 0000000000..c8551ac5bd --- /dev/null +++ b/src/net/tools/naive/redirect_resolver.h @@ -0,0 +1,70 @@ +// Copyright 2019 klzgrad . All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_NAIVE_REDIRECT_RESOLVER_H_ +#define NET_TOOLS_NAIVE_REDIRECT_RESOLVER_H_ + +#include +#include +#include +#include +#include + +#include "base/macros.h" +#include "base/memory/ref_counted.h" +#include "base/memory/weak_ptr.h" +#include "base/time/time.h" +#include "net/base/ip_address.h" +#include "net/base/ip_endpoint.h" + +namespace net { + +class DatagramServerSocket; +class IOBufferWithSize; + +struct Resolution { + Resolution(); + ~Resolution(); + + uint32_t addr; + std::string name; + base::TimeTicks time; + std::map::iterator>::iterator by_name; + std::map::iterator>::iterator by_addr; +}; + +class RedirectResolver { + public: + RedirectResolver(std::unique_ptr socket, + const IPAddress& range, + size_t prefix); + ~RedirectResolver(); + + bool IsInResolvedRange(const IPAddress& address) const; + std::string FindNameByAddress(const IPAddress& address) const; + + private: + void DoRead(); + void OnRecv(int result); + void OnSend(int result); + int HandleReadResult(int result); + + std::unique_ptr socket_; + IPAddress range_; + size_t prefix_; + uint32_t offset_; + scoped_refptr buffer_; + IPEndPoint recv_address_; + + std::map::iterator> resolution_by_name_; + std::map::iterator> resolution_by_addr_; + std::list resolutions_; + + base::WeakPtrFactory weak_ptr_factory_{this}; + + DISALLOW_COPY_AND_ASSIGN(RedirectResolver); +}; + +} // namespace net +#endif // NET_TOOLS_NAIVE_REDIRECT_RESOLVER_H_