diff --git a/net/tools/naive/naive_client_connection.cc b/net/tools/naive/naive_client_connection.cc index b7e836d18f..bcb0d9dd48 100644 --- a/net/tools/naive/naive_client_connection.cc +++ b/net/tools/naive/naive_client_connection.cc @@ -42,6 +42,12 @@ NaiveClientConnection::NaiveClientConnection( client_socket_( std::make_unique(std::move(accepted_socket))), server_socket_handle_(std::make_unique()), + sockets_{client_socket_.get(), nullptr}, + errors_{OK, OK}, + write_pending_{false, false}, + early_pull_pending_(false), + can_push_to_server_(false), + early_pull_result_(ERR_IO_PENDING), full_duplex_(false), time_func_(&base::TimeTicks::Now), weak_ptr_factory_(this) { @@ -138,6 +144,15 @@ int NaiveClientConnection::DoConnectClientComplete(int result) { if (result < 0) return result; + early_pull_pending_ = true; + Pull(kClient, kServer); + if (early_pull_result_ != ERR_IO_PENDING) { + // Pull has completed synchronously. + if (early_pull_result_ <= 0) { + return early_pull_result_ ? early_pull_result_ : ERR_CONNECTION_CLOSED; + } + } + next_state_ = STATE_CONNECT_SERVER; return OK; } @@ -172,25 +187,27 @@ int NaiveClientConnection::DoConnectServerComplete(int result) { if (result < 0) return result; + DCHECK(server_socket_handle_->socket()); + sockets_[kServer] = server_socket_handle_->socket(); + full_duplex_ = true; next_state_ = STATE_NONE; return OK; } int NaiveClientConnection::Run(const CompletionCallback& callback) { - DCHECK(client_socket_); - DCHECK(server_socket_handle_->socket()); + DCHECK(sockets_[kClient]); + DCHECK(sockets_[kServer]); DCHECK_EQ(next_state_, STATE_NONE); DCHECK(!connect_callback_); + if (errors_[kClient] != OK) + return errors_[kClient]; + if (errors_[kServer] != OK) + return errors_[kServer]; + run_callback_ = callback; - sockets_[kClient] = client_socket_.get(); - sockets_[kServer] = server_socket_handle_->socket(); - - errors_[kClient] = OK; - errors_[kServer] = OK; - bytes_passed_without_yielding_[kClient] = 0; bytes_passed_without_yielding_[kServer] = 0; @@ -199,7 +216,11 @@ int NaiveClientConnection::Run(const CompletionCallback& callback) { base::TimeDelta::FromMilliseconds(kYieldAfterDurationMilliseconds); yield_after_time_[kServer] = yield_after_time_[kClient]; - Pull(kClient, kServer); + can_push_to_server_ = true; + if (!early_pull_pending_) { + DCHECK_GT(early_pull_result_, 0); + Push(kClient, kServer, early_pull_result_); + } Pull(kServer, kClient); return ERR_IO_PENDING; @@ -209,75 +230,124 @@ void NaiveClientConnection::Pull(Direction from, Direction to) { if (errors_[kClient] < 0 || errors_[kServer] < 0) return; - auto buffer = base::MakeRefCounted(kBufferSize); + read_buffers_[from] = new IOBuffer(kBufferSize); + DCHECK(sockets_[from]); int rv = sockets_[from]->Read( - buffer.get(), kBufferSize, - base::Bind(&NaiveClientConnection::OnReadComplete, - weak_ptr_factory_.GetWeakPtr(), from, to, buffer)); + read_buffers_[from].get(), kBufferSize, + base::Bind(&NaiveClientConnection::OnPullComplete, + weak_ptr_factory_.GetWeakPtr(), from, to)); + + if (from == kClient && early_pull_pending_) + early_pull_result_ = rv; if (rv != ERR_IO_PENDING) - OnReadComplete(from, to, buffer, rv); + OnPullComplete(from, to, rv); } -void NaiveClientConnection::Push(Direction from, - Direction to, - scoped_refptr buffer, - int size) { - if (errors_[kClient] < 0 || errors_[kServer] < 0) - return; - - auto drainable = base::MakeRefCounted(buffer.get(), size); - int rv = sockets_[to]->Write( - drainable.get(), size, - base::Bind(&NaiveClientConnection::OnWriteComplete, - weak_ptr_factory_.GetWeakPtr(), from, to, drainable)); +void NaiveClientConnection::Push(Direction from, Direction to, int size) { + write_buffers_[to] = new DrainableIOBuffer(read_buffers_[from].get(), size); + write_pending_[to] = true; + DCHECK(sockets_[to]); + int rv = + sockets_[to]->Write(write_buffers_[to].get(), size, + base::Bind(&NaiveClientConnection::OnPushComplete, + weak_ptr_factory_.GetWeakPtr(), from, to)); if (rv != ERR_IO_PENDING) - OnWriteComplete(from, to, drainable, rv); + OnPushComplete(from, to, rv); } -void NaiveClientConnection::OnIOError(Direction from, int error) { - // Avoids running run_callback_ again. - if (errors_[kClient] < 0 || errors_[kServer] < 0) - return; +void NaiveClientConnection::Disconnect(Direction side) { + if (sockets_[side]) { + sockets_[side]->Disconnect(); + sockets_[side] = nullptr; + write_pending_[side] = false; + } +} - if (errors_[from] == OK) { - DCHECK(run_callback_); +bool NaiveClientConnection::IsConnected(Direction side) { + return sockets_[side]; +} + +void NaiveClientConnection::OnBothDisconnected() { + if (run_callback_) { + int error = OK; + if (errors_[kClient] != ERR_CONNECTION_CLOSED && errors_[kClient] < 0) + error = errors_[kClient]; + if (errors_[kServer] != ERR_CONNECTION_CLOSED && errors_[kClient] < 0) + error = errors_[kServer]; base::ResetAndReturn(&run_callback_).Run(error); } +} + +void NaiveClientConnection::OnPullError(Direction from, + Direction to, + int error) { + DCHECK_LT(error, 0); + errors_[from] = error; + Disconnect(from); + + if (!write_pending_[to]) + Disconnect(to); + + if (!IsConnected(from) && !IsConnected(to)) + OnBothDisconnected(); } -void NaiveClientConnection::OnReadComplete(Direction from, +void NaiveClientConnection::OnPushError(Direction from, + Direction to, + int error) { + DCHECK_LE(error, 0); + DCHECK(!write_pending_[to]); + + if (error < 0) { + errors_[to] = error; + Disconnect(kServer); + Disconnect(kClient); + } else if (!IsConnected(from)) { + Disconnect(to); + } + + if (!IsConnected(from) && !IsConnected(to)) + OnBothDisconnected(); +} + +void NaiveClientConnection::OnPullComplete(Direction from, Direction to, - scoped_refptr buffer, int result) { + if (from == kClient && early_pull_pending_) { + early_pull_pending_ = false; + early_pull_result_ = result; + } + if (result <= 0) { - OnIOError(from, result ? result : ERR_CONNECTION_CLOSED); + OnPullError(from, to, result ? result : ERR_CONNECTION_CLOSED); return; } - Push(from, to, buffer, result); + if (from == kClient && !can_push_to_server_) + return; + + Push(from, to, result); } -void NaiveClientConnection::OnWriteComplete( - Direction from, - Direction to, - scoped_refptr drainable, - int result) { - if (result < 0) { - OnIOError(to, result); - return; +void NaiveClientConnection::OnPushComplete(Direction from, + Direction to, + int result) { + if (result >= 0) { + bytes_passed_without_yielding_[from] += result; + write_buffers_[to]->DidConsume(result); + int size = write_buffers_[to]->BytesRemaining(); + if (size > 0) { + Push(from, to, size); + return; + } } - bytes_passed_without_yielding_[from] += result; - - drainable->DidConsume(result); - int size = drainable->BytesRemaining(); - if (size > 0) { - Push(from, to, drainable.get(), size); - return; - } + write_pending_[to] = false; + // Checks for termination even if result is OK. + OnPushError(from, to, result >= 0 ? OK : result); if (bytes_passed_without_yielding_[from] > kYieldAfterBytesRead || time_func_() > yield_after_time_[from]) { diff --git a/net/tools/naive/naive_client_connection.h b/net/tools/naive/naive_client_connection.h index 095e126602..c0a56ef1ff 100644 --- a/net/tools/naive/naive_client_connection.h +++ b/net/tools/naive/naive_client_connection.h @@ -65,19 +65,14 @@ class NaiveClientConnection { int DoConnectServer(); int DoConnectServerComplete(int result); void Pull(Direction from, Direction to); - void Push(Direction from, - Direction to, - scoped_refptr buffer, - int size); - void OnIOError(Direction from, int error); - void OnReadComplete(Direction from, - Direction to, - scoped_refptr buffer, - int result); - void OnWriteComplete(Direction from, - Direction to, - scoped_refptr drainable, - int result); + void Push(Direction from, Direction to, int size); + void Disconnect(Direction side); + bool IsConnected(Direction side); + void OnBothDisconnected(); + void OnPullError(Direction from, Direction to, int error); + void OnPushError(Direction from, Direction to, int error); + void OnPullComplete(Direction from, Direction to, int result); + void OnPushComplete(Direction from, Direction to, int result); int id_; @@ -96,10 +91,17 @@ class NaiveClientConnection { std::unique_ptr server_socket_handle_; StreamSocket* sockets_[kNumDirections]; + scoped_refptr read_buffers_[kNumDirections]; + scoped_refptr write_buffers_[kNumDirections]; int errors_[kNumDirections]; + bool write_pending_[kNumDirections]; int bytes_passed_without_yielding_[kNumDirections]; base::TimeTicks yield_after_time_[kNumDirections]; + bool early_pull_pending_; + bool can_push_to_server_; + int early_pull_result_; + bool full_duplex_; TimeFunc time_func_;