// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/base/mock_file_stream.h" #include #include "base/bind.h" #include "base/location.h" #include "base/single_thread_task_runner.h" #include "base/threading/thread_task_runner_handle.h" namespace net { namespace testing { MockFileStream::MockFileStream( const scoped_refptr& task_runner) : FileStream(task_runner), forced_error_(OK), async_error_(false), throttled_(false), weak_factory_(this) { } MockFileStream::MockFileStream( base::File file, const scoped_refptr& task_runner) : FileStream(std::move(file), task_runner), forced_error_(OK), async_error_(false), throttled_(false), weak_factory_(this) {} MockFileStream::~MockFileStream() = default; int MockFileStream::Seek(int64_t offset, const Int64CompletionCallback& callback) { Int64CompletionCallback wrapped_callback = base::Bind(&MockFileStream::DoCallback64, weak_factory_.GetWeakPtr(), callback); if (forced_error_ == OK) return FileStream::Seek(offset, wrapped_callback); return ErrorCallback64(wrapped_callback); } int MockFileStream::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { CompletionCallback wrapped_callback = base::Bind(&MockFileStream::DoCallback, weak_factory_.GetWeakPtr(), callback); if (forced_error_ == OK) return FileStream::Read(buf, buf_len, wrapped_callback); return ErrorCallback(wrapped_callback); } int MockFileStream::Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { CompletionCallback wrapped_callback = base::Bind(&MockFileStream::DoCallback, weak_factory_.GetWeakPtr(), callback); if (forced_error_ == OK) return FileStream::Write(buf, buf_len, wrapped_callback); return ErrorCallback(wrapped_callback); } int MockFileStream::Flush(const CompletionCallback& callback) { CompletionCallback wrapped_callback = base::Bind(&MockFileStream::DoCallback, weak_factory_.GetWeakPtr(), callback); if (forced_error_ == OK) return FileStream::Flush(wrapped_callback); return ErrorCallback(wrapped_callback); } void MockFileStream::ThrottleCallbacks() { CHECK(!throttled_); throttled_ = true; } void MockFileStream::ReleaseCallbacks() { CHECK(throttled_); throttled_ = false; if (!throttled_task_.is_null()) { base::Closure throttled_task = throttled_task_; throttled_task_.Reset(); base::ThreadTaskRunnerHandle::Get()->PostTask(FROM_HERE, throttled_task); } } void MockFileStream::DoCallback(const CompletionCallback& callback, int result) { if (!throttled_) { callback.Run(result); return; } CHECK(throttled_task_.is_null()); throttled_task_ = base::Bind(callback, result); } void MockFileStream::DoCallback64(const Int64CompletionCallback& callback, int64_t result) { if (!throttled_) { callback.Run(result); return; } CHECK(throttled_task_.is_null()); throttled_task_ = base::Bind(callback, result); } int MockFileStream::ErrorCallback(const CompletionCallback& callback) { CHECK_NE(OK, forced_error_); if (async_error_) { base::ThreadTaskRunnerHandle::Get()->PostTask( FROM_HERE, base::Bind(callback, forced_error_)); clear_forced_error(); return ERR_IO_PENDING; } int ret = forced_error_; clear_forced_error(); return ret; } int64_t MockFileStream::ErrorCallback64( const Int64CompletionCallback& callback) { CHECK_NE(OK, forced_error_); if (async_error_) { base::ThreadTaskRunnerHandle::Get()->PostTask( FROM_HERE, base::Bind(callback, forced_error_)); clear_forced_error(); return ERR_IO_PENDING; } int64_t ret = forced_error_; clear_forced_error(); return ret; } } // namespace testing } // namespace net