naive: Yield for socket pump I/O properly

Avoid doing one direction for too long.
This commit is contained in:
klzgrad 2018-01-25 23:33:13 +08:00
parent 3f9e1f4c67
commit 21c9076faa
2 changed files with 77 additions and 47 deletions

View File

@ -21,6 +21,7 @@
#include "net/socket/client_socket_handle.h"
#include "net/socket/client_socket_pool_manager.h"
#include "net/socket/stream_socket.h"
#include "net/spdy/chromium/spdy_session.h"
#include "net/tools/naive/socks5_server_socket.h"
namespace net {
@ -41,9 +42,8 @@ NaiveClientConnection::NaiveClientConnection(
client_socket_(
std::make_unique<Socks5ServerSocket>(std::move(accepted_socket))),
server_socket_handle_(std::make_unique<ClientSocketHandle>()),
client_error_(OK),
server_error_(OK),
full_duplex_(false),
time_func_(&base::TimeTicks::Now),
weak_ptr_factory_(this) {
io_callback_ = base::Bind(&NaiveClientConnection::OnIOComplete,
weak_ptr_factory_.GetWeakPtr());
@ -185,34 +185,49 @@ int NaiveClientConnection::Run(const CompletionCallback& callback) {
run_callback_ = callback;
Pull(client_socket_.get(), server_socket_handle_->socket());
Pull(server_socket_handle_->socket(), client_socket_.get());
sockets_[kClient] = client_socket_.get();
sockets_[kServer] = server_socket_handle_->socket();
errors_[kClient] = OK;
errors_[kServer] = OK;
bytes_passed_without_yielding_[kClient] = 0;
bytes_passed_without_yielding_[kServer] = 0;
yield_after_time_[kClient] =
time_func_() +
base::TimeDelta::FromMilliseconds(kYieldAfterDurationMilliseconds);
yield_after_time_[kServer] = yield_after_time_[kClient];
Pull(kClient, kServer);
Pull(kServer, kClient);
return ERR_IO_PENDING;
}
void NaiveClientConnection::Pull(StreamSocket* from, StreamSocket* to) {
if (client_error_ < 0 || server_error_ < 0)
void NaiveClientConnection::Pull(Direction from, Direction to) {
if (errors_[kClient] < 0 || errors_[kServer] < 0)
return;
auto buffer = base::MakeRefCounted<IOBuffer>(kBufferSize);
int rv =
from->Read(buffer.get(), kBufferSize,
base::Bind(&NaiveClientConnection::OnReadComplete,
weak_ptr_factory_.GetWeakPtr(), from, to, buffer));
int rv = sockets_[from]->Read(
buffer.get(), kBufferSize,
base::Bind(&NaiveClientConnection::OnReadComplete,
weak_ptr_factory_.GetWeakPtr(), from, to, buffer));
if (rv != ERR_IO_PENDING)
OnReadComplete(from, to, buffer, rv);
}
void NaiveClientConnection::Push(StreamSocket* from,
StreamSocket* to,
void NaiveClientConnection::Push(Direction from,
Direction to,
scoped_refptr<IOBuffer> buffer,
int size) {
if (client_error_ < 0 || server_error_ < 0)
if (errors_[kClient] < 0 || errors_[kServer] < 0)
return;
auto drainable = base::MakeRefCounted<DrainableIOBuffer>(buffer.get(), size);
int rv = to->Write(
int rv = sockets_[to]->Write(
drainable.get(), size,
base::Bind(&NaiveClientConnection::OnWriteComplete,
weak_ptr_factory_.GetWeakPtr(), from, to, drainable));
@ -221,31 +236,20 @@ void NaiveClientConnection::Push(StreamSocket* from,
OnWriteComplete(from, to, drainable, rv);
}
void NaiveClientConnection::OnIOError(StreamSocket* socket, int error) {
void NaiveClientConnection::OnIOError(Direction from, int error) {
// Avoids running run_callback_ again.
if (client_error_ < 0 || server_error_ < 0)
if (errors_[kClient] < 0 || errors_[kServer] < 0)
return;
if (socket == client_socket_.get()) {
if (client_error_ == OK) {
DCHECK(run_callback_);
base::ResetAndReturn(&run_callback_).Run(error);
}
client_error_ = error;
return;
}
if (socket == server_socket_handle_->socket()) {
if (server_error_ == OK) {
DCHECK(run_callback_);
base::ResetAndReturn(&run_callback_).Run(error);
}
server_error_ = error;
return;
if (errors_[from] == OK) {
DCHECK(run_callback_);
base::ResetAndReturn(&run_callback_).Run(error);
}
errors_[from] = error;
}
void NaiveClientConnection::OnReadComplete(StreamSocket* from,
StreamSocket* to,
void NaiveClientConnection::OnReadComplete(Direction from,
Direction to,
scoped_refptr<IOBuffer> buffer,
int result) {
if (result <= 0) {
@ -257,8 +261,8 @@ void NaiveClientConnection::OnReadComplete(StreamSocket* from,
}
void NaiveClientConnection::OnWriteComplete(
StreamSocket* from,
StreamSocket* to,
Direction from,
Direction to,
scoped_refptr<DrainableIOBuffer> drainable,
int result) {
if (result < 0) {
@ -266,6 +270,8 @@ void NaiveClientConnection::OnWriteComplete(
return;
}
bytes_passed_without_yielding_[from] += result;
drainable->DidConsume(result);
int size = drainable->BytesRemaining();
if (size > 0) {
@ -273,7 +279,18 @@ void NaiveClientConnection::OnWriteComplete(
return;
}
Pull(from, to);
if (bytes_passed_without_yielding_[from] > kYieldAfterBytesRead ||
time_func_() > yield_after_time_[from]) {
bytes_passed_without_yielding_[from] = 0;
yield_after_time_[from] =
time_func_() +
base::TimeDelta::FromMilliseconds(kYieldAfterDurationMilliseconds);
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::Bind(&NaiveClientConnection::Pull,
weak_ptr_factory_.GetWeakPtr(), from, to));
} else {
Pull(from, to);
}
}
} // namespace net

View File

@ -11,6 +11,7 @@
#include "base/macros.h"
#include "base/memory/ref_counted.h"
#include "base/memory/weak_ptr.h"
#include "base/time/time.h"
#include "net/base/completion_callback.h"
#include "net/base/host_port_pair.h"
#include "net/log/net_log_with_source.h"
@ -28,6 +29,8 @@ class StreamSocket;
class NaiveClientConnection {
public:
using TimeFunc = base::TimeTicks (*)();
NaiveClientConnection(int id,
std::unique_ptr<StreamSocket> accepted_socket,
HttpNetworkSession* session);
@ -47,6 +50,13 @@ class NaiveClientConnection {
STATE_NONE,
};
// From this direction.
enum Direction {
kClient = 0,
kServer = 1,
kNumDirections = 2,
};
void DoCallback(int result);
void OnIOComplete(int result);
int DoLoop(int last_io_result);
@ -54,18 +64,18 @@ class NaiveClientConnection {
int DoConnectClientComplete(int result);
int DoConnectServer();
int DoConnectServerComplete(int result);
void Pull(StreamSocket* from, StreamSocket* to);
void Push(StreamSocket* from,
StreamSocket* to,
void Pull(Direction from, Direction to);
void Push(Direction from,
Direction to,
scoped_refptr<IOBuffer> buffer,
int size);
void OnIOError(StreamSocket* socket, int error);
void OnReadComplete(StreamSocket* from,
StreamSocket* to,
void OnIOError(Direction from, int error);
void OnReadComplete(Direction from,
Direction to,
scoped_refptr<IOBuffer> buffer,
int result);
void OnWriteComplete(StreamSocket* from,
StreamSocket* to,
void OnWriteComplete(Direction from,
Direction to,
scoped_refptr<DrainableIOBuffer> drainable,
int result);
@ -83,14 +93,17 @@ class NaiveClientConnection {
HostPortPair request_endpoint_;
std::unique_ptr<Socks5ServerSocket> client_socket_;
std::unique_ptr<StreamSocket> server_socket_;
std::unique_ptr<ClientSocketHandle> server_socket_handle_;
int client_error_;
int server_error_;
StreamSocket* sockets_[kNumDirections];
int errors_[kNumDirections];
int bytes_passed_without_yielding_[kNumDirections];
base::TimeTicks yield_after_time_[kNumDirections];
bool full_duplex_;
TimeFunc time_func_;
base::WeakPtrFactory<NaiveClientConnection> weak_ptr_factory_;
DISALLOW_COPY_AND_ASSIGN(NaiveClientConnection);