naive: Yield for socket pump I/O properly

Avoid doing one direction for too long.
This commit is contained in:
klzgrad 2018-01-25 23:33:13 +08:00
parent 3f9e1f4c67
commit 21c9076faa
2 changed files with 77 additions and 47 deletions

View File

@ -21,6 +21,7 @@
#include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_handle.h"
#include "net/socket/client_socket_pool_manager.h" #include "net/socket/client_socket_pool_manager.h"
#include "net/socket/stream_socket.h" #include "net/socket/stream_socket.h"
#include "net/spdy/chromium/spdy_session.h"
#include "net/tools/naive/socks5_server_socket.h" #include "net/tools/naive/socks5_server_socket.h"
namespace net { namespace net {
@ -41,9 +42,8 @@ NaiveClientConnection::NaiveClientConnection(
client_socket_( client_socket_(
std::make_unique<Socks5ServerSocket>(std::move(accepted_socket))), std::make_unique<Socks5ServerSocket>(std::move(accepted_socket))),
server_socket_handle_(std::make_unique<ClientSocketHandle>()), server_socket_handle_(std::make_unique<ClientSocketHandle>()),
client_error_(OK),
server_error_(OK),
full_duplex_(false), full_duplex_(false),
time_func_(&base::TimeTicks::Now),
weak_ptr_factory_(this) { weak_ptr_factory_(this) {
io_callback_ = base::Bind(&NaiveClientConnection::OnIOComplete, io_callback_ = base::Bind(&NaiveClientConnection::OnIOComplete,
weak_ptr_factory_.GetWeakPtr()); weak_ptr_factory_.GetWeakPtr());
@ -185,34 +185,49 @@ int NaiveClientConnection::Run(const CompletionCallback& callback) {
run_callback_ = callback; run_callback_ = callback;
Pull(client_socket_.get(), server_socket_handle_->socket()); sockets_[kClient] = client_socket_.get();
Pull(server_socket_handle_->socket(), client_socket_.get()); sockets_[kServer] = server_socket_handle_->socket();
errors_[kClient] = OK;
errors_[kServer] = OK;
bytes_passed_without_yielding_[kClient] = 0;
bytes_passed_without_yielding_[kServer] = 0;
yield_after_time_[kClient] =
time_func_() +
base::TimeDelta::FromMilliseconds(kYieldAfterDurationMilliseconds);
yield_after_time_[kServer] = yield_after_time_[kClient];
Pull(kClient, kServer);
Pull(kServer, kClient);
return ERR_IO_PENDING; return ERR_IO_PENDING;
} }
void NaiveClientConnection::Pull(StreamSocket* from, StreamSocket* to) { void NaiveClientConnection::Pull(Direction from, Direction to) {
if (client_error_ < 0 || server_error_ < 0) if (errors_[kClient] < 0 || errors_[kServer] < 0)
return; return;
auto buffer = base::MakeRefCounted<IOBuffer>(kBufferSize); auto buffer = base::MakeRefCounted<IOBuffer>(kBufferSize);
int rv = int rv = sockets_[from]->Read(
from->Read(buffer.get(), kBufferSize, buffer.get(), kBufferSize,
base::Bind(&NaiveClientConnection::OnReadComplete, base::Bind(&NaiveClientConnection::OnReadComplete,
weak_ptr_factory_.GetWeakPtr(), from, to, buffer)); weak_ptr_factory_.GetWeakPtr(), from, to, buffer));
if (rv != ERR_IO_PENDING) if (rv != ERR_IO_PENDING)
OnReadComplete(from, to, buffer, rv); OnReadComplete(from, to, buffer, rv);
} }
void NaiveClientConnection::Push(StreamSocket* from, void NaiveClientConnection::Push(Direction from,
StreamSocket* to, Direction to,
scoped_refptr<IOBuffer> buffer, scoped_refptr<IOBuffer> buffer,
int size) { int size) {
if (client_error_ < 0 || server_error_ < 0) if (errors_[kClient] < 0 || errors_[kServer] < 0)
return; return;
auto drainable = base::MakeRefCounted<DrainableIOBuffer>(buffer.get(), size); auto drainable = base::MakeRefCounted<DrainableIOBuffer>(buffer.get(), size);
int rv = to->Write( int rv = sockets_[to]->Write(
drainable.get(), size, drainable.get(), size,
base::Bind(&NaiveClientConnection::OnWriteComplete, base::Bind(&NaiveClientConnection::OnWriteComplete,
weak_ptr_factory_.GetWeakPtr(), from, to, drainable)); weak_ptr_factory_.GetWeakPtr(), from, to, drainable));
@ -221,31 +236,20 @@ void NaiveClientConnection::Push(StreamSocket* from,
OnWriteComplete(from, to, drainable, rv); OnWriteComplete(from, to, drainable, rv);
} }
void NaiveClientConnection::OnIOError(StreamSocket* socket, int error) { void NaiveClientConnection::OnIOError(Direction from, int error) {
// Avoids running run_callback_ again. // Avoids running run_callback_ again.
if (client_error_ < 0 || server_error_ < 0) if (errors_[kClient] < 0 || errors_[kServer] < 0)
return; return;
if (socket == client_socket_.get()) { if (errors_[from] == OK) {
if (client_error_ == OK) { DCHECK(run_callback_);
DCHECK(run_callback_); base::ResetAndReturn(&run_callback_).Run(error);
base::ResetAndReturn(&run_callback_).Run(error);
}
client_error_ = error;
return;
}
if (socket == server_socket_handle_->socket()) {
if (server_error_ == OK) {
DCHECK(run_callback_);
base::ResetAndReturn(&run_callback_).Run(error);
}
server_error_ = error;
return;
} }
errors_[from] = error;
} }
void NaiveClientConnection::OnReadComplete(StreamSocket* from, void NaiveClientConnection::OnReadComplete(Direction from,
StreamSocket* to, Direction to,
scoped_refptr<IOBuffer> buffer, scoped_refptr<IOBuffer> buffer,
int result) { int result) {
if (result <= 0) { if (result <= 0) {
@ -257,8 +261,8 @@ void NaiveClientConnection::OnReadComplete(StreamSocket* from,
} }
void NaiveClientConnection::OnWriteComplete( void NaiveClientConnection::OnWriteComplete(
StreamSocket* from, Direction from,
StreamSocket* to, Direction to,
scoped_refptr<DrainableIOBuffer> drainable, scoped_refptr<DrainableIOBuffer> drainable,
int result) { int result) {
if (result < 0) { if (result < 0) {
@ -266,6 +270,8 @@ void NaiveClientConnection::OnWriteComplete(
return; return;
} }
bytes_passed_without_yielding_[from] += result;
drainable->DidConsume(result); drainable->DidConsume(result);
int size = drainable->BytesRemaining(); int size = drainable->BytesRemaining();
if (size > 0) { if (size > 0) {
@ -273,7 +279,18 @@ void NaiveClientConnection::OnWriteComplete(
return; return;
} }
Pull(from, to); if (bytes_passed_without_yielding_[from] > kYieldAfterBytesRead ||
time_func_() > yield_after_time_[from]) {
bytes_passed_without_yielding_[from] = 0;
yield_after_time_[from] =
time_func_() +
base::TimeDelta::FromMilliseconds(kYieldAfterDurationMilliseconds);
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::Bind(&NaiveClientConnection::Pull,
weak_ptr_factory_.GetWeakPtr(), from, to));
} else {
Pull(from, to);
}
} }
} // namespace net } // namespace net

View File

@ -11,6 +11,7 @@
#include "base/macros.h" #include "base/macros.h"
#include "base/memory/ref_counted.h" #include "base/memory/ref_counted.h"
#include "base/memory/weak_ptr.h" #include "base/memory/weak_ptr.h"
#include "base/time/time.h"
#include "net/base/completion_callback.h" #include "net/base/completion_callback.h"
#include "net/base/host_port_pair.h" #include "net/base/host_port_pair.h"
#include "net/log/net_log_with_source.h" #include "net/log/net_log_with_source.h"
@ -28,6 +29,8 @@ class StreamSocket;
class NaiveClientConnection { class NaiveClientConnection {
public: public:
using TimeFunc = base::TimeTicks (*)();
NaiveClientConnection(int id, NaiveClientConnection(int id,
std::unique_ptr<StreamSocket> accepted_socket, std::unique_ptr<StreamSocket> accepted_socket,
HttpNetworkSession* session); HttpNetworkSession* session);
@ -47,6 +50,13 @@ class NaiveClientConnection {
STATE_NONE, STATE_NONE,
}; };
// From this direction.
enum Direction {
kClient = 0,
kServer = 1,
kNumDirections = 2,
};
void DoCallback(int result); void DoCallback(int result);
void OnIOComplete(int result); void OnIOComplete(int result);
int DoLoop(int last_io_result); int DoLoop(int last_io_result);
@ -54,18 +64,18 @@ class NaiveClientConnection {
int DoConnectClientComplete(int result); int DoConnectClientComplete(int result);
int DoConnectServer(); int DoConnectServer();
int DoConnectServerComplete(int result); int DoConnectServerComplete(int result);
void Pull(StreamSocket* from, StreamSocket* to); void Pull(Direction from, Direction to);
void Push(StreamSocket* from, void Push(Direction from,
StreamSocket* to, Direction to,
scoped_refptr<IOBuffer> buffer, scoped_refptr<IOBuffer> buffer,
int size); int size);
void OnIOError(StreamSocket* socket, int error); void OnIOError(Direction from, int error);
void OnReadComplete(StreamSocket* from, void OnReadComplete(Direction from,
StreamSocket* to, Direction to,
scoped_refptr<IOBuffer> buffer, scoped_refptr<IOBuffer> buffer,
int result); int result);
void OnWriteComplete(StreamSocket* from, void OnWriteComplete(Direction from,
StreamSocket* to, Direction to,
scoped_refptr<DrainableIOBuffer> drainable, scoped_refptr<DrainableIOBuffer> drainable,
int result); int result);
@ -83,14 +93,17 @@ class NaiveClientConnection {
HostPortPair request_endpoint_; HostPortPair request_endpoint_;
std::unique_ptr<Socks5ServerSocket> client_socket_; std::unique_ptr<Socks5ServerSocket> client_socket_;
std::unique_ptr<StreamSocket> server_socket_;
std::unique_ptr<ClientSocketHandle> server_socket_handle_; std::unique_ptr<ClientSocketHandle> server_socket_handle_;
int client_error_; StreamSocket* sockets_[kNumDirections];
int server_error_; int errors_[kNumDirections];
int bytes_passed_without_yielding_[kNumDirections];
base::TimeTicks yield_after_time_[kNumDirections];
bool full_duplex_; bool full_duplex_;
TimeFunc time_func_;
base::WeakPtrFactory<NaiveClientConnection> weak_ptr_factory_; base::WeakPtrFactory<NaiveClientConnection> weak_ptr_factory_;
DISALLOW_COPY_AND_ASSIGN(NaiveClientConnection); DISALLOW_COPY_AND_ASSIGN(NaiveClientConnection);