mirror of
https://github.com/klzgrad/naiveproxy.git
synced 2024-11-28 00:06:09 +03:00
353 lines
9.8 KiB
C++
353 lines
9.8 KiB
C++
// Copyright 2017 The Chromium Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style license that can be
|
|
// found in the LICENSE file.
|
|
|
|
#include "net/test/tcp_socket_proxy.h"

#include <memory>
#include <utility>
#include <vector>

#include "base/callback.h"
#include "base/memory/weak_ptr.h"
#include "base/single_thread_task_runner.h"
#include "base/synchronization/waitable_event.h"
#include "base/threading/thread_checker.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/socket/stream_socket.h"
#include "net/socket/tcp_client_socket.h"
#include "net/socket/tcp_server_socket.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
|
|
|
|
namespace net {
|
|
|
|
namespace {
|
|
|
|
// Capacity of the read buffer allocated by SocketDataPump, and the maximum
// length requested from the source socket on each Read() call.
constexpr int kBufferSize = 1024;
|
|
|
|
// Helper that reads data from one socket and then forwards to another socket.
|
|
class SocketDataPump {
|
|
public:
|
|
SocketDataPump(StreamSocket* from_socket,
|
|
StreamSocket* to_socket,
|
|
base::OnceClosure on_done_callback)
|
|
: from_socket_(from_socket),
|
|
to_socket_(to_socket),
|
|
on_done_callback_(std::move(on_done_callback)) {
|
|
read_buffer_ = base::MakeRefCounted<IOBuffer>(kBufferSize);
|
|
}
|
|
|
|
~SocketDataPump() { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); }
|
|
|
|
void Start() { Read(); }
|
|
|
|
private:
|
|
void Read() {
|
|
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
|
|
DCHECK(!write_buffer_);
|
|
|
|
int result = from_socket_->Read(
|
|
read_buffer_.get(), kBufferSize,
|
|
base::Bind(&SocketDataPump::HandleReadResult, base::Unretained(this)));
|
|
if (result != ERR_IO_PENDING)
|
|
HandleReadResult(result);
|
|
}
|
|
|
|
void HandleReadResult(int result) {
|
|
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
|
|
|
|
if (result <= 0) {
|
|
std::move(on_done_callback_).Run();
|
|
return;
|
|
}
|
|
|
|
write_buffer_ =
|
|
base::MakeRefCounted<DrainableIOBuffer>(read_buffer_, result);
|
|
Write();
|
|
}
|
|
|
|
void Write() {
|
|
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
|
|
DCHECK(write_buffer_);
|
|
|
|
int result = to_socket_->Write(
|
|
write_buffer_.get(), write_buffer_->BytesRemaining(),
|
|
base::Bind(&SocketDataPump::HandleWriteResult, base::Unretained(this)),
|
|
TRAFFIC_ANNOTATION_FOR_TESTS);
|
|
if (result != ERR_IO_PENDING)
|
|
HandleWriteResult(result);
|
|
}
|
|
|
|
void HandleWriteResult(int result) {
|
|
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
|
|
|
|
if (result <= 0) {
|
|
std::move(on_done_callback_).Run();
|
|
return;
|
|
}
|
|
|
|
write_buffer_->DidConsume(result);
|
|
if (write_buffer_->BytesRemaining()) {
|
|
Write();
|
|
} else {
|
|
write_buffer_ = nullptr;
|
|
Read();
|
|
}
|
|
}
|
|
|
|
StreamSocket* from_socket_;
|
|
StreamSocket* to_socket_;
|
|
|
|
scoped_refptr<IOBuffer> read_buffer_;
|
|
scoped_refptr<DrainableIOBuffer> write_buffer_;
|
|
|
|
base::OnceClosure on_done_callback_;
|
|
|
|
THREAD_CHECKER(thread_checker_);
|
|
|
|
DISALLOW_COPY_AND_ASSIGN(SocketDataPump);
|
|
};
|
|
|
|
// ConnectionProxy is responsible for proxying one connection to a remote
|
|
// address.
|
|
class ConnectionProxy {
|
|
public:
|
|
explicit ConnectionProxy(std::unique_ptr<StreamSocket> local_socket);
|
|
~ConnectionProxy();
|
|
|
|
void Start(const IPEndPoint& remote_endpoint,
|
|
base::OnceClosure on_done_callback);
|
|
|
|
private:
|
|
void Close();
|
|
|
|
void HandleConnectResult(const IPEndPoint& remote_endpoint, int result);
|
|
|
|
base::OnceClosure on_done_callback_;
|
|
|
|
std::unique_ptr<StreamSocket> local_socket_;
|
|
std::unique_ptr<StreamSocket> remote_socket_;
|
|
|
|
std::unique_ptr<SocketDataPump> incoming_pump_;
|
|
std::unique_ptr<SocketDataPump> outgoing_pump_;
|
|
|
|
THREAD_CHECKER(thread_checker_);
|
|
|
|
base::WeakPtrFactory<ConnectionProxy> weak_factory_;
|
|
|
|
DISALLOW_COPY_AND_ASSIGN(ConnectionProxy);
|
|
};
|
|
|
|
// Stores the already-accepted |local_socket|; the remote side is created
// later, in Start().
ConnectionProxy::ConnectionProxy(std::unique_ptr<StreamSocket> local_socket)
    : local_socket_(std::move(local_socket)),
      weak_factory_(this) {}
|
|
|
|
// Only verifies that destruction happens on the thread the proxy is bound to.
ConnectionProxy::~ConnectionProxy() {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
}
|
|
|
|
void ConnectionProxy::Start(const IPEndPoint& remote_endpoint,
|
|
base::OnceClosure on_done_callback) {
|
|
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
|
|
|
|
on_done_callback_ = std::move(on_done_callback);
|
|
remote_socket_ = std::make_unique<TCPClientSocket>(
|
|
AddressList(remote_endpoint), nullptr, nullptr, NetLogSource());
|
|
int result = remote_socket_->Connect(
|
|
base::Bind(&ConnectionProxy::HandleConnectResult, base::Unretained(this),
|
|
remote_endpoint));
|
|
if (result != ERR_IO_PENDING)
|
|
HandleConnectResult(remote_endpoint, result);
|
|
}
|
|
|
|
void ConnectionProxy::HandleConnectResult(const IPEndPoint& remote_endpoint,
|
|
int result) {
|
|
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
|
|
DCHECK(!incoming_pump_);
|
|
DCHECK(!outgoing_pump_);
|
|
|
|
if (result < 0) {
|
|
LOG(ERROR) << "Connection to " << remote_endpoint.ToString()
|
|
<< " failed: " << ErrorToString(result);
|
|
Close();
|
|
return;
|
|
}
|
|
|
|
incoming_pump_ = std::make_unique<SocketDataPump>(
|
|
remote_socket_.get(), local_socket_.get(),
|
|
base::BindOnce(&ConnectionProxy::Close, base::Unretained(this)));
|
|
outgoing_pump_ = std::make_unique<SocketDataPump>(
|
|
local_socket_.get(), remote_socket_.get(),
|
|
base::BindOnce(&ConnectionProxy::Close, base::Unretained(this)));
|
|
|
|
auto self = weak_factory_.GetWeakPtr();
|
|
incoming_pump_->Start();
|
|
if (!self)
|
|
return;
|
|
|
|
outgoing_pump_->Start();
|
|
}
|
|
|
|
// Destroys both sockets, then notifies the owner through |on_done_callback_|.
// The callback is invoked last because, in this file, it results in this
// object being deleted.
void ConnectionProxy::Close() {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);

  local_socket_.reset();
  remote_socket_.reset();

  std::move(on_done_callback_).Run();
}
|
|
|
|
} // namespace
|
|
|
|
// TcpSocketProxy implementation that runs on a background IO thread.
|
|
class TcpSocketProxy::Core {
|
|
public:
|
|
Core();
|
|
~Core();
|
|
|
|
void Initialize(int local_port, base::WaitableEvent* initialized_event);
|
|
void Start(const IPEndPoint& remote_endpoint);
|
|
uint16_t local_port() const { return local_port_; }
|
|
|
|
private:
|
|
void DoAcceptLoop();
|
|
void OnAcceptResult(int result);
|
|
void HandleAcceptResult(int result);
|
|
void OnConnectionClosed(ConnectionProxy* connection);
|
|
|
|
IPEndPoint remote_endpoint_;
|
|
|
|
std::unique_ptr<TCPServerSocket> socket_;
|
|
|
|
uint16_t local_port_ = 0;
|
|
std::vector<std::unique_ptr<ConnectionProxy>> connections_;
|
|
|
|
std::unique_ptr<StreamSocket> accepted_socket_;
|
|
|
|
DISALLOW_COPY_AND_ASSIGN(Core);
|
|
};
|
|
|
|
// All members are set up by their in-class or default initializers.
TcpSocketProxy::Core::Core() = default;
|
|
|
|
void TcpSocketProxy::Core::Initialize(int local_port,
|
|
base::WaitableEvent* initialized_event) {
|
|
DCHECK(!socket_);
|
|
|
|
local_port_ = 0;
|
|
|
|
socket_ = std::make_unique<TCPServerSocket>(nullptr, net::NetLogSource());
|
|
int result =
|
|
socket_->Listen(IPEndPoint(IPAddress::IPv4Localhost(), local_port), 5);
|
|
if (result != OK) {
|
|
LOG(ERROR) << "TcpServerSocket::Listen() returned "
|
|
<< ErrorToString(result);
|
|
} else {
|
|
// Get local port number.
|
|
IPEndPoint address;
|
|
result = socket_->GetLocalAddress(&address);
|
|
if (result != OK) {
|
|
LOG(ERROR) << "TcpServerSocket::GetLocalAddress() returned "
|
|
<< ErrorToString(result);
|
|
} else {
|
|
local_port_ = address.port();
|
|
}
|
|
}
|
|
|
|
if (initialized_event)
|
|
initialized_event->Signal();
|
|
}
|
|
|
|
// Records where accepted connections should be forwarded and begins
// accepting. Requires that Initialize() has already created |socket_|.
void TcpSocketProxy::Core::Start(const IPEndPoint& remote_endpoint) {
  DCHECK(socket_);

  remote_endpoint_ = remote_endpoint;

  DoAcceptLoop();
}
|
|
|
|
// Members clean up after themselves; nothing extra to do.
TcpSocketProxy::Core::~Core() = default;
|
|
|
|
void TcpSocketProxy::Core::DoAcceptLoop() {
|
|
int result = OK;
|
|
while (result == OK) {
|
|
result = socket_->Accept(
|
|
&accepted_socket_,
|
|
base::Bind(&Core::OnAcceptResult, base::Unretained(this)));
|
|
if (result != ERR_IO_PENDING)
|
|
HandleAcceptResult(result);
|
|
}
|
|
}
|
|
|
|
void TcpSocketProxy::Core::OnAcceptResult(int result) {
|
|
HandleAcceptResult(result);
|
|
if (result == OK)
|
|
DoAcceptLoop();
|
|
}
|
|
|
|
void TcpSocketProxy::Core::HandleAcceptResult(int result) {
|
|
DCHECK_NE(result, ERR_IO_PENDING);
|
|
|
|
if (result < 0) {
|
|
LOG(ERROR) << "Error when accepting a connection: "
|
|
<< ErrorToString(result);
|
|
return;
|
|
}
|
|
|
|
std::unique_ptr<ConnectionProxy> connection_proxy =
|
|
std::make_unique<ConnectionProxy>(std::move(accepted_socket_));
|
|
ConnectionProxy* connection_proxy_ptr = connection_proxy.get();
|
|
connections_.push_back(std::move(connection_proxy));
|
|
|
|
// Start() may invoke the callback so it needs to be called after the
|
|
// connection is pushed to connections_.
|
|
connection_proxy_ptr->Start(
|
|
remote_endpoint_,
|
|
base::BindOnce(&Core::OnConnectionClosed, base::Unretained(this),
|
|
connection_proxy_ptr));
|
|
}
|
|
|
|
// Called when |connection| has finished. Erases its entry from
// |connections_|, which destroys the ConnectionProxy. |connection| is
// expected to always be present in the list.
void TcpSocketProxy::Core::OnConnectionClosed(ConnectionProxy* connection) {
  auto entry = connections_.begin();
  while (entry != connections_.end() && entry->get() != connection)
    ++entry;

  if (entry != connections_.end()) {
    connections_.erase(entry);
    return;
  }

  NOTREACHED();
}
|
|
|
|
// Creates the proxy together with its Core. |io_task_runner| is the task
// runner on which the Core is initialized, started and eventually deleted.
//
// |io_task_runner| is received by value, so it is moved into the member
// rather than copied; this avoids a redundant reference-count
// increment/decrement pair.
TcpSocketProxy::TcpSocketProxy(
    scoped_refptr<base::SingleThreadTaskRunner> io_task_runner)
    : io_task_runner_(std::move(io_task_runner)),
      core_(std::make_unique<Core>()) {}
|
|
|
|
// Initializes the Core so that it listens on |local_port|. When called off
// the IO task runner's thread, the work is posted there and this call blocks
// until it has finished. Returns true if a non-zero local port was obtained.
bool TcpSocketProxy::Initialize(int local_port) {
  DCHECK(!local_port_);

  if (!io_task_runner_->BelongsToCurrentThread()) {
    // The event lives on this stack frame; Wait() below keeps it alive until
    // Core::Initialize() has signaled it.
    base::WaitableEvent initialized_event(
        base::WaitableEvent::ResetPolicy::MANUAL,
        base::WaitableEvent::InitialState::NOT_SIGNALED);
    io_task_runner_->PostTask(
        FROM_HERE,
        base::BindOnce(&Core::Initialize, base::Unretained(core_.get()),
                       local_port, &initialized_event));
    initialized_event.Wait();
  } else {
    // Already on the IO thread: run inline, no event needed.
    core_->Initialize(local_port, nullptr);
  }

  local_port_ = core_->local_port();

  return local_port_ != 0;
}
|
|
|
|
// Hands the Core to the IO task runner for deletion instead of destroying it
// on the calling thread.
TcpSocketProxy::~TcpSocketProxy() {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);

  io_task_runner_->DeleteSoon(FROM_HERE, std::move(core_));
}
|
|
|
|
// Asks the Core, on the IO task runner, to start proxying accepted
// connections to |remote_endpoint|. The Core is bound unretained; presumably
// this is safe because its deletion is also posted to |io_task_runner_| by
// the destructor (relies on task ordering - confirm).
void TcpSocketProxy::Start(const IPEndPoint& remote_endpoint) {
  io_task_runner_->PostTask(
      FROM_HERE,
      base::BindOnce(&Core::Start, base::Unretained(core_.get()),
                     remote_endpoint));
}
|
|
|
|
} // namespace net
|