mirror of
https://github.com/klzgrad/naiveproxy.git
synced 2024-11-24 06:16:30 +03:00
Yield for socket pump I/O properly
Avoid doing one direction for too long.
This commit is contained in:
parent
2fe35ccbe9
commit
7acf2e6940
@ -21,6 +21,7 @@
|
||||
#include "net/socket/client_socket_handle.h"
|
||||
#include "net/socket/client_socket_pool_manager.h"
|
||||
#include "net/socket/stream_socket.h"
|
||||
#include "net/spdy/chromium/spdy_session.h"
|
||||
#include "net/tools/naive/socks5_server_socket.h"
|
||||
|
||||
namespace net {
|
||||
@ -41,9 +42,8 @@ NaiveClientConnection::NaiveClientConnection(
|
||||
client_socket_(
|
||||
std::make_unique<Socks5ServerSocket>(std::move(accepted_socket))),
|
||||
server_socket_handle_(std::make_unique<ClientSocketHandle>()),
|
||||
client_error_(OK),
|
||||
server_error_(OK),
|
||||
full_duplex_(false),
|
||||
time_func_(&base::TimeTicks::Now),
|
||||
weak_ptr_factory_(this) {
|
||||
io_callback_ = base::Bind(&NaiveClientConnection::OnIOComplete,
|
||||
weak_ptr_factory_.GetWeakPtr());
|
||||
@ -185,34 +185,49 @@ int NaiveClientConnection::Run(const CompletionCallback& callback) {
|
||||
|
||||
run_callback_ = callback;
|
||||
|
||||
Pull(client_socket_.get(), server_socket_handle_->socket());
|
||||
Pull(server_socket_handle_->socket(), client_socket_.get());
|
||||
sockets_[kClient] = client_socket_.get();
|
||||
sockets_[kServer] = server_socket_handle_->socket();
|
||||
|
||||
errors_[kClient] = OK;
|
||||
errors_[kServer] = OK;
|
||||
|
||||
bytes_passed_without_yielding_[kClient] = 0;
|
||||
bytes_passed_without_yielding_[kServer] = 0;
|
||||
|
||||
yield_after_time_[kClient] =
|
||||
time_func_() +
|
||||
base::TimeDelta::FromMilliseconds(kYieldAfterDurationMilliseconds);
|
||||
yield_after_time_[kServer] = yield_after_time_[kClient];
|
||||
|
||||
Pull(kClient, kServer);
|
||||
Pull(kServer, kClient);
|
||||
|
||||
return ERR_IO_PENDING;
|
||||
}
|
||||
|
||||
void NaiveClientConnection::Pull(StreamSocket* from, StreamSocket* to) {
|
||||
if (client_error_ < 0 || server_error_ < 0)
|
||||
void NaiveClientConnection::Pull(Direction from, Direction to) {
|
||||
if (errors_[kClient] < 0 || errors_[kServer] < 0)
|
||||
return;
|
||||
|
||||
auto buffer = base::MakeRefCounted<IOBuffer>(kBufferSize);
|
||||
int rv =
|
||||
from->Read(buffer.get(), kBufferSize,
|
||||
base::Bind(&NaiveClientConnection::OnReadComplete,
|
||||
weak_ptr_factory_.GetWeakPtr(), from, to, buffer));
|
||||
int rv = sockets_[from]->Read(
|
||||
buffer.get(), kBufferSize,
|
||||
base::Bind(&NaiveClientConnection::OnReadComplete,
|
||||
weak_ptr_factory_.GetWeakPtr(), from, to, buffer));
|
||||
|
||||
if (rv != ERR_IO_PENDING)
|
||||
OnReadComplete(from, to, buffer, rv);
|
||||
}
|
||||
|
||||
void NaiveClientConnection::Push(StreamSocket* from,
|
||||
StreamSocket* to,
|
||||
void NaiveClientConnection::Push(Direction from,
|
||||
Direction to,
|
||||
scoped_refptr<IOBuffer> buffer,
|
||||
int size) {
|
||||
if (client_error_ < 0 || server_error_ < 0)
|
||||
if (errors_[kClient] < 0 || errors_[kServer] < 0)
|
||||
return;
|
||||
|
||||
auto drainable = base::MakeRefCounted<DrainableIOBuffer>(buffer.get(), size);
|
||||
int rv = to->Write(
|
||||
int rv = sockets_[to]->Write(
|
||||
drainable.get(), size,
|
||||
base::Bind(&NaiveClientConnection::OnWriteComplete,
|
||||
weak_ptr_factory_.GetWeakPtr(), from, to, drainable));
|
||||
@ -221,31 +236,20 @@ void NaiveClientConnection::Push(StreamSocket* from,
|
||||
OnWriteComplete(from, to, drainable, rv);
|
||||
}
|
||||
|
||||
void NaiveClientConnection::OnIOError(StreamSocket* socket, int error) {
|
||||
void NaiveClientConnection::OnIOError(Direction from, int error) {
|
||||
// Avoids running run_callback_ again.
|
||||
if (client_error_ < 0 || server_error_ < 0)
|
||||
if (errors_[kClient] < 0 || errors_[kServer] < 0)
|
||||
return;
|
||||
|
||||
if (socket == client_socket_.get()) {
|
||||
if (client_error_ == OK) {
|
||||
DCHECK(run_callback_);
|
||||
base::ResetAndReturn(&run_callback_).Run(error);
|
||||
}
|
||||
client_error_ = error;
|
||||
return;
|
||||
}
|
||||
if (socket == server_socket_handle_->socket()) {
|
||||
if (server_error_ == OK) {
|
||||
DCHECK(run_callback_);
|
||||
base::ResetAndReturn(&run_callback_).Run(error);
|
||||
}
|
||||
server_error_ = error;
|
||||
return;
|
||||
if (errors_[from] == OK) {
|
||||
DCHECK(run_callback_);
|
||||
base::ResetAndReturn(&run_callback_).Run(error);
|
||||
}
|
||||
errors_[from] = error;
|
||||
}
|
||||
|
||||
void NaiveClientConnection::OnReadComplete(StreamSocket* from,
|
||||
StreamSocket* to,
|
||||
void NaiveClientConnection::OnReadComplete(Direction from,
|
||||
Direction to,
|
||||
scoped_refptr<IOBuffer> buffer,
|
||||
int result) {
|
||||
if (result <= 0) {
|
||||
@ -257,8 +261,8 @@ void NaiveClientConnection::OnReadComplete(StreamSocket* from,
|
||||
}
|
||||
|
||||
void NaiveClientConnection::OnWriteComplete(
|
||||
StreamSocket* from,
|
||||
StreamSocket* to,
|
||||
Direction from,
|
||||
Direction to,
|
||||
scoped_refptr<DrainableIOBuffer> drainable,
|
||||
int result) {
|
||||
if (result < 0) {
|
||||
@ -266,6 +270,8 @@ void NaiveClientConnection::OnWriteComplete(
|
||||
return;
|
||||
}
|
||||
|
||||
bytes_passed_without_yielding_[from] += result;
|
||||
|
||||
drainable->DidConsume(result);
|
||||
int size = drainable->BytesRemaining();
|
||||
if (size > 0) {
|
||||
@ -273,7 +279,18 @@ void NaiveClientConnection::OnWriteComplete(
|
||||
return;
|
||||
}
|
||||
|
||||
Pull(from, to);
|
||||
if (bytes_passed_without_yielding_[from] > kYieldAfterBytesRead ||
|
||||
time_func_() > yield_after_time_[from]) {
|
||||
bytes_passed_without_yielding_[from] = 0;
|
||||
yield_after_time_[from] =
|
||||
time_func_() +
|
||||
base::TimeDelta::FromMilliseconds(kYieldAfterDurationMilliseconds);
|
||||
base::ThreadTaskRunnerHandle::Get()->PostTask(
|
||||
FROM_HERE, base::Bind(&NaiveClientConnection::Pull,
|
||||
weak_ptr_factory_.GetWeakPtr(), from, to));
|
||||
} else {
|
||||
Pull(from, to);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace net
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "base/macros.h"
|
||||
#include "base/memory/ref_counted.h"
|
||||
#include "base/memory/weak_ptr.h"
|
||||
#include "base/time/time.h"
|
||||
#include "net/base/completion_callback.h"
|
||||
#include "net/base/host_port_pair.h"
|
||||
#include "net/log/net_log_with_source.h"
|
||||
@ -28,6 +29,8 @@ class StreamSocket;
|
||||
|
||||
class NaiveClientConnection {
|
||||
public:
|
||||
using TimeFunc = base::TimeTicks (*)();
|
||||
|
||||
NaiveClientConnection(int id,
|
||||
std::unique_ptr<StreamSocket> accepted_socket,
|
||||
HttpNetworkSession* session);
|
||||
@ -47,6 +50,13 @@ class NaiveClientConnection {
|
||||
STATE_NONE,
|
||||
};
|
||||
|
||||
// From this direction.
|
||||
enum Direction {
|
||||
kClient = 0,
|
||||
kServer = 1,
|
||||
kNumDirections = 2,
|
||||
};
|
||||
|
||||
void DoCallback(int result);
|
||||
void OnIOComplete(int result);
|
||||
int DoLoop(int last_io_result);
|
||||
@ -54,18 +64,18 @@ class NaiveClientConnection {
|
||||
int DoConnectClientComplete(int result);
|
||||
int DoConnectServer();
|
||||
int DoConnectServerComplete(int result);
|
||||
void Pull(StreamSocket* from, StreamSocket* to);
|
||||
void Push(StreamSocket* from,
|
||||
StreamSocket* to,
|
||||
void Pull(Direction from, Direction to);
|
||||
void Push(Direction from,
|
||||
Direction to,
|
||||
scoped_refptr<IOBuffer> buffer,
|
||||
int size);
|
||||
void OnIOError(StreamSocket* socket, int error);
|
||||
void OnReadComplete(StreamSocket* from,
|
||||
StreamSocket* to,
|
||||
void OnIOError(Direction from, int error);
|
||||
void OnReadComplete(Direction from,
|
||||
Direction to,
|
||||
scoped_refptr<IOBuffer> buffer,
|
||||
int result);
|
||||
void OnWriteComplete(StreamSocket* from,
|
||||
StreamSocket* to,
|
||||
void OnWriteComplete(Direction from,
|
||||
Direction to,
|
||||
scoped_refptr<DrainableIOBuffer> drainable,
|
||||
int result);
|
||||
|
||||
@ -83,14 +93,17 @@ class NaiveClientConnection {
|
||||
HostPortPair request_endpoint_;
|
||||
|
||||
std::unique_ptr<Socks5ServerSocket> client_socket_;
|
||||
std::unique_ptr<StreamSocket> server_socket_;
|
||||
std::unique_ptr<ClientSocketHandle> server_socket_handle_;
|
||||
|
||||
int client_error_;
|
||||
int server_error_;
|
||||
StreamSocket* sockets_[kNumDirections];
|
||||
int errors_[kNumDirections];
|
||||
int bytes_passed_without_yielding_[kNumDirections];
|
||||
base::TimeTicks yield_after_time_[kNumDirections];
|
||||
|
||||
bool full_duplex_;
|
||||
|
||||
TimeFunc time_func_;
|
||||
|
||||
base::WeakPtrFactory<NaiveClientConnection> weak_ptr_factory_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(NaiveClientConnection);
|
||||
|
Loading…
Reference in New Issue
Block a user