From 7acf2e69408e6e6373d2d65281d36fc6482d3a31 Mon Sep 17 00:00:00 2001 From: klzgrad Date: Thu, 25 Jan 2018 10:33:13 -0500 Subject: [PATCH] Yield for socket pump I/O properly Avoid doing one direction for too long. --- net/tools/naive/naive_client_connection.cc | 89 +++++++++++++--------- net/tools/naive/naive_client_connection.h | 35 ++++++--- 2 files changed, 77 insertions(+), 47 deletions(-) diff --git a/net/tools/naive/naive_client_connection.cc b/net/tools/naive/naive_client_connection.cc index 53d0a57fde..b7e836d18f 100644 --- a/net/tools/naive/naive_client_connection.cc +++ b/net/tools/naive/naive_client_connection.cc @@ -21,6 +21,7 @@ #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool_manager.h" #include "net/socket/stream_socket.h" +#include "net/spdy/chromium/spdy_session.h" #include "net/tools/naive/socks5_server_socket.h" namespace net { @@ -41,9 +42,8 @@ NaiveClientConnection::NaiveClientConnection( client_socket_( std::make_unique(std::move(accepted_socket))), server_socket_handle_(std::make_unique()), - client_error_(OK), - server_error_(OK), full_duplex_(false), + time_func_(&base::TimeTicks::Now), weak_ptr_factory_(this) { io_callback_ = base::Bind(&NaiveClientConnection::OnIOComplete, weak_ptr_factory_.GetWeakPtr()); @@ -185,34 +185,49 @@ int NaiveClientConnection::Run(const CompletionCallback& callback) { run_callback_ = callback; - Pull(client_socket_.get(), server_socket_handle_->socket()); - Pull(server_socket_handle_->socket(), client_socket_.get()); + sockets_[kClient] = client_socket_.get(); + sockets_[kServer] = server_socket_handle_->socket(); + + errors_[kClient] = OK; + errors_[kServer] = OK; + + bytes_passed_without_yielding_[kClient] = 0; + bytes_passed_without_yielding_[kServer] = 0; + + yield_after_time_[kClient] = + time_func_() + + base::TimeDelta::FromMilliseconds(kYieldAfterDurationMilliseconds); + yield_after_time_[kServer] = yield_after_time_[kClient]; + + Pull(kClient, kServer); + Pull(kServer, kClient); + return ERR_IO_PENDING; } -void NaiveClientConnection::Pull(StreamSocket* from, StreamSocket* to) { - if (client_error_ < 0 || server_error_ < 0) +void NaiveClientConnection::Pull(Direction from, Direction to) { + if (errors_[kClient] < 0 || errors_[kServer] < 0) return; auto buffer = base::MakeRefCounted(kBufferSize); - int rv = - from->Read(buffer.get(), kBufferSize, - base::Bind(&NaiveClientConnection::OnReadComplete, - weak_ptr_factory_.GetWeakPtr(), from, to, buffer)); + int rv = sockets_[from]->Read( + buffer.get(), kBufferSize, + base::Bind(&NaiveClientConnection::OnReadComplete, + weak_ptr_factory_.GetWeakPtr(), from, to, buffer)); if (rv != ERR_IO_PENDING) OnReadComplete(from, to, buffer, rv); } -void NaiveClientConnection::Push(StreamSocket* from, - StreamSocket* to, +void NaiveClientConnection::Push(Direction from, + Direction to, scoped_refptr buffer, int size) { - if (client_error_ < 0 || server_error_ < 0) + if (errors_[kClient] < 0 || errors_[kServer] < 0) return; auto drainable = base::MakeRefCounted(buffer.get(), size); - int rv = to->Write( + int rv = sockets_[to]->Write( drainable.get(), size, base::Bind(&NaiveClientConnection::OnWriteComplete, weak_ptr_factory_.GetWeakPtr(), from, to, drainable)); @@ -221,31 +236,20 @@ void NaiveClientConnection::Push(StreamSocket* from, OnWriteComplete(from, to, drainable, rv); } -void NaiveClientConnection::OnIOError(StreamSocket* socket, int error) { +void NaiveClientConnection::OnIOError(Direction from, int error) { // Avoids running run_callback_ again. - if (client_error_ < 0 || server_error_ < 0) + if (errors_[kClient] < 0 || errors_[kServer] < 0) return; - if (socket == client_socket_.get()) { - if (client_error_ == OK) { - DCHECK(run_callback_); - base::ResetAndReturn(&run_callback_).Run(error); - } - client_error_ = error; - return; - } - if (socket == server_socket_handle_->socket()) { - if (server_error_ == OK) { - DCHECK(run_callback_); - base::ResetAndReturn(&run_callback_).Run(error); - } - server_error_ = error; - return; + if (errors_[from] == OK) { + DCHECK(run_callback_); + base::ResetAndReturn(&run_callback_).Run(error); } + errors_[from] = error; } -void NaiveClientConnection::OnReadComplete(StreamSocket* from, - StreamSocket* to, +void NaiveClientConnection::OnReadComplete(Direction from, + Direction to, scoped_refptr buffer, int result) { if (result <= 0) { @@ -257,8 +261,8 @@ void NaiveClientConnection::OnReadComplete(StreamSocket* from, } void NaiveClientConnection::OnWriteComplete( - StreamSocket* from, - StreamSocket* to, + Direction from, + Direction to, scoped_refptr drainable, int result) { if (result < 0) { @@ -266,6 +270,8 @@ void NaiveClientConnection::OnWriteComplete( return; } + bytes_passed_without_yielding_[from] += result; + drainable->DidConsume(result); int size = drainable->BytesRemaining(); if (size > 0) { @@ -273,7 +279,18 @@ void NaiveClientConnection::OnWriteComplete( return; } - Pull(from, to); + if (bytes_passed_without_yielding_[from] > kYieldAfterBytesRead || + time_func_() > yield_after_time_[from]) { + bytes_passed_without_yielding_[from] = 0; + yield_after_time_[from] = + time_func_() + + base::TimeDelta::FromMilliseconds(kYieldAfterDurationMilliseconds); + base::ThreadTaskRunnerHandle::Get()->PostTask( + FROM_HERE, base::Bind(&NaiveClientConnection::Pull, + weak_ptr_factory_.GetWeakPtr(), from, to)); + } else { + Pull(from, to); + } } } // namespace net diff --git a/net/tools/naive/naive_client_connection.h b/net/tools/naive/naive_client_connection.h index bcdee0bf5c..095e126602 100644 --- a/net/tools/naive/naive_client_connection.h +++ b/net/tools/naive/naive_client_connection.h @@ -11,6 +11,7 @@ #include "base/macros.h" #include "base/memory/ref_counted.h" #include "base/memory/weak_ptr.h" +#include "base/time/time.h" #include "net/base/completion_callback.h" #include "net/base/host_port_pair.h" #include "net/log/net_log_with_source.h" @@ -28,6 +29,8 @@ class StreamSocket; class NaiveClientConnection { public: + using TimeFunc = base::TimeTicks (*)(); + NaiveClientConnection(int id, std::unique_ptr accepted_socket, HttpNetworkSession* session); @@ -47,6 +50,13 @@ class NaiveClientConnection { STATE_NONE, }; + // From this direction. + enum Direction { + kClient = 0, + kServer = 1, + kNumDirections = 2, + }; + void DoCallback(int result); void OnIOComplete(int result); int DoLoop(int last_io_result); @@ -54,18 +64,18 @@ class NaiveClientConnection { int DoConnectClientComplete(int result); int DoConnectServer(); int DoConnectServerComplete(int result); - void Pull(StreamSocket* from, StreamSocket* to); - void Push(StreamSocket* from, - StreamSocket* to, + void Pull(Direction from, Direction to); + void Push(Direction from, + Direction to, scoped_refptr buffer, int size); - void OnIOError(StreamSocket* socket, int error); - void OnReadComplete(StreamSocket* from, - StreamSocket* to, + void OnIOError(Direction from, int error); + void OnReadComplete(Direction from, + Direction to, scoped_refptr buffer, int result); - void OnWriteComplete(StreamSocket* from, - StreamSocket* to, + void OnWriteComplete(Direction from, + Direction to, scoped_refptr drainable, int result); @@ -83,14 +93,17 @@ class NaiveClientConnection { HostPortPair request_endpoint_; std::unique_ptr client_socket_; - std::unique_ptr server_socket_; std::unique_ptr server_socket_handle_; - int client_error_; - int server_error_; + StreamSocket* sockets_[kNumDirections]; + int errors_[kNumDirections]; + int bytes_passed_without_yielding_[kNumDirections]; + base::TimeTicks yield_after_time_[kNumDirections]; bool full_duplex_; + TimeFunc time_func_; + base::WeakPtrFactory weak_ptr_factory_; DISALLOW_COPY_AND_ASSIGN(NaiveClientConnection);