// Copyright 2014 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/ssl/channel_id_service.h" #include #include #include #include #include "base/atomic_sequence_num.h" #include "base/bind.h" #include "base/bind_helpers.h" #include "base/callback_helpers.h" #include "base/compiler_specific.h" #include "base/location.h" #include "base/logging.h" #include "base/macros.h" #include "base/memory/ptr_util.h" #include "base/metrics/histogram_macros.h" #include "base/rand_util.h" #include "base/single_thread_task_runner.h" #include "base/task/post_task.h" #include "base/task_runner.h" #include "base/threading/thread_task_runner_handle.h" #include "crypto/ec_private_key.h" #include "net/base/net_errors.h" #include "net/base/registry_controlled_domains/registry_controlled_domain.h" #include "net/cert/x509_certificate.h" #include "net/cert/x509_util.h" #include "url/gurl.h" namespace net { namespace { base::AtomicSequenceNumber g_next_id; // On success, returns a ChannelID object and sets |*error| to OK. // Otherwise, returns NULL, and |*error| will be set to a net error code. // |serial_number| is passed in because base::RandInt cannot be called from an // unjoined thread, due to relying on a non-leaked LazyInstance std::unique_ptr GenerateChannelID( const std::string& server_identifier, int* error) { std::unique_ptr result; base::Time creation_time = base::Time::Now(); std::unique_ptr key(crypto::ECPrivateKey::Create()); if (!key) { DLOG(ERROR) << "Unable to create channel ID key pair"; *error = ERR_KEY_GENERATION_FAILED; return result; } result.reset(new ChannelIDStore::ChannelID(server_identifier, creation_time, std::move(key))); *error = OK; return result; } } // namespace // ChannelIDServiceWorker takes care of the blocking process of performing key // generation. Will take care of deleting itself once Start() is called. class ChannelIDServiceWorker { public: typedef base::OnceCallback< void(const std::string&, int, std::unique_ptr)> WorkerDoneCallback; ChannelIDServiceWorker(const std::string& server_identifier, WorkerDoneCallback callback) : server_identifier_(server_identifier), origin_task_runner_(base::ThreadTaskRunnerHandle::Get()), callback_(std::move(callback)) {} // Starts the worker asynchronously. void Start(const scoped_refptr& task_runner) { DCHECK(origin_task_runner_->RunsTasksInCurrentSequence()); auto callback = base::BindOnce(&ChannelIDServiceWorker::Run, base::Owned(this)); if (task_runner) { task_runner->PostTask(FROM_HERE, std::move(callback)); } else { base::PostTaskWithTraits( FROM_HERE, {base::MayBlock(), base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN}, std::move(callback)); } } private: void Run() { // Runs on a worker thread. int error = ERR_FAILED; std::unique_ptr channel_id = GenerateChannelID(server_identifier_, &error); origin_task_runner_->PostTask( FROM_HERE, base::BindOnce(std::move(callback_), server_identifier_, error, base::Passed(&channel_id))); } const std::string server_identifier_; scoped_refptr origin_task_runner_; WorkerDoneCallback callback_; DISALLOW_COPY_AND_ASSIGN(ChannelIDServiceWorker); }; // A ChannelIDServiceJob is a one-to-one counterpart of an // ChannelIDServiceWorker. It lives only on the ChannelIDService's // origin task runner's thread. class ChannelIDServiceJob { public: ChannelIDServiceJob(bool create_if_missing) : create_if_missing_(create_if_missing) { } ~ChannelIDServiceJob() { DCHECK(requests_.empty()); } void AddRequest(ChannelIDService::Request* request, bool create_if_missing = false) { create_if_missing_ |= create_if_missing; requests_.push_back(request); } void HandleResult(int error, std::unique_ptr key) { PostAll(error, std::move(key)); } bool CreateIfMissing() const { return create_if_missing_; } void CancelRequest(ChannelIDService::Request* req) { auto it = std::find(requests_.begin(), requests_.end(), req); if (it != requests_.end()) requests_.erase(it); } private: void PostAll(int error, std::unique_ptr key) { std::vector requests; requests_.swap(requests); for (std::vector::iterator i = requests.begin(); i != requests.end(); i++) { std::unique_ptr key_copy; if (key) key_copy = key->Copy(); (*i)->Post(error, std::move(key_copy)); } } std::vector requests_; bool create_if_missing_; }; ChannelIDService::Request::Request() : service_(NULL) { } ChannelIDService::Request::~Request() { Cancel(); } void ChannelIDService::Request::Cancel() { if (service_) { callback_.Reset(); job_->CancelRequest(this); service_ = NULL; } } void ChannelIDService::Request::RequestStarted( ChannelIDService* service, CompletionOnceCallback callback, std::unique_ptr* key, ChannelIDServiceJob* job) { DCHECK(service_ == NULL); service_ = service; callback_ = std::move(callback); key_ = key; job_ = job; } void ChannelIDService::Request::Post( int error, std::unique_ptr key) { service_ = NULL; DCHECK(!callback_.is_null()); if (key) *key_ = std::move(key); // Running the callback might delete |this| (e.g. the callback cleans up // resources created for the request), so we can't touch any of our // members afterwards. Reset callback_ first. base::ResetAndReturn(&callback_).Run(error); } ChannelIDService::ChannelIDService(ChannelIDStore* channel_id_store) : channel_id_store_(channel_id_store), id_(g_next_id.GetNext()), requests_(0), key_store_hits_(0), inflight_joins_(0), workers_created_(0), weak_ptr_factory_(this) {} ChannelIDService::~ChannelIDService() { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); } // static std::string ChannelIDService::GetDomainForHost(const std::string& host) { std::string domain = registry_controlled_domains::GetDomainAndRegistry( host, registry_controlled_domains::INCLUDE_PRIVATE_REGISTRIES); if (domain.empty()) return host; return domain; } int ChannelIDService::GetOrCreateChannelID( const std::string& host, std::unique_ptr* key, CompletionOnceCallback callback, Request* out_req) { DVLOG(1) << __func__ << " " << host; DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); if (callback.is_null() || !key || host.empty()) { return ERR_INVALID_ARGUMENT; } std::string domain = GetDomainForHost(host); if (domain.empty()) { return ERR_INVALID_ARGUMENT; } requests_++; // See if a request for the same domain is currently in flight. bool create_if_missing = true; if (JoinToInFlightRequest(domain, key, create_if_missing, &callback, out_req)) { return ERR_IO_PENDING; } int err = LookupChannelID(domain, key, create_if_missing, &callback, out_req); if (err == ERR_FILE_NOT_FOUND) { // Sync lookup did not find a valid channel ID. Start generating a new one. workers_created_++; ChannelIDServiceWorker* worker = new ChannelIDServiceWorker( domain, base::BindOnce(&ChannelIDService::GeneratedChannelID, weak_ptr_factory_.GetWeakPtr())); worker->Start(task_runner_); // We are waiting for key generation. Create a job & request to track it. ChannelIDServiceJob* job = new ChannelIDServiceJob(create_if_missing); inflight_[domain] = base::WrapUnique(job); job->AddRequest(out_req); out_req->RequestStarted(this, std::move(callback), key, job); return ERR_IO_PENDING; } return err; } int ChannelIDService::GetChannelID(const std::string& host, std::unique_ptr* key, CompletionOnceCallback callback, Request* out_req) { DVLOG(1) << __func__ << " " << host; DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); if (callback.is_null() || !key || host.empty()) { return ERR_INVALID_ARGUMENT; } std::string domain = GetDomainForHost(host); if (domain.empty()) { return ERR_INVALID_ARGUMENT; } requests_++; // See if a request for the same domain currently in flight. bool create_if_missing = false; if (JoinToInFlightRequest(domain, key, create_if_missing, &callback, out_req)) { return ERR_IO_PENDING; } int err = LookupChannelID(domain, key, create_if_missing, &callback, out_req); return err; } void ChannelIDService::GotChannelID(int err, const std::string& server_identifier, std::unique_ptr key) { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); auto j = inflight_.find(server_identifier); if (j == inflight_.end()) { NOTREACHED(); return; } if (err == OK) { // Async DB lookup found a valid channel ID. key_store_hits_++; // ChannelIDService::Request::Post will do the histograms and stuff. HandleResult(OK, server_identifier, std::move(key)); return; } // Async lookup failed or the channel ID was missing. Return the error // directly, unless the channel ID was missing and a request asked to create // one. if (err != ERR_FILE_NOT_FOUND || !j->second->CreateIfMissing()) { HandleResult(err, server_identifier, std::move(key)); return; } // At least one request asked to create a channel ID => start generating a new // one. workers_created_++; ChannelIDServiceWorker* worker = new ChannelIDServiceWorker( server_identifier, base::BindOnce(&ChannelIDService::GeneratedChannelID, weak_ptr_factory_.GetWeakPtr())); worker->Start(task_runner_); } ChannelIDStore* ChannelIDService::GetChannelIDStore() { return channel_id_store_.get(); } void ChannelIDService::GeneratedChannelID( const std::string& server_identifier, int error, std::unique_ptr channel_id) { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); std::unique_ptr key; if (error == OK) { key = channel_id->key()->Copy(); channel_id_store_->SetChannelID(std::move(channel_id)); } HandleResult(error, server_identifier, std::move(key)); } void ChannelIDService::HandleResult(int error, const std::string& server_identifier, std::unique_ptr key) { DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); auto j = inflight_.find(server_identifier); if (j == inflight_.end()) { NOTREACHED(); return; } std::unique_ptr job = std::move(j->second); inflight_.erase(j); job->HandleResult(error, std::move(key)); } bool ChannelIDService::JoinToInFlightRequest( const std::string& domain, std::unique_ptr* key, bool create_if_missing, CompletionOnceCallback* callback, Request* out_req) { auto j = inflight_.find(domain); if (j == inflight_.end()) return false; // A request for the same domain is in flight already. We'll attach our // callback, but we'll also mark it as requiring a channel ID if one's mising. ChannelIDServiceJob* job = j->second.get(); inflight_joins_++; job->AddRequest(out_req, create_if_missing); out_req->RequestStarted(this, std::move(*callback), key, job); return true; } int ChannelIDService::LookupChannelID( const std::string& domain, std::unique_ptr* key, bool create_if_missing, CompletionOnceCallback* callback, Request* out_req) { // Check if a channel ID key already exists for this domain. int err = channel_id_store_->GetChannelID( domain, key, base::BindOnce(&ChannelIDService::GotChannelID, weak_ptr_factory_.GetWeakPtr())); if (err == OK) { // Sync lookup found a valid channel ID. DVLOG(1) << "Channel ID store had valid key for " << domain; key_store_hits_++; return OK; } if (err == ERR_IO_PENDING) { // We are waiting for async DB lookup. Create a job & request to track it. ChannelIDServiceJob* job = new ChannelIDServiceJob(create_if_missing); inflight_[domain] = base::WrapUnique(job); job->AddRequest(out_req); out_req->RequestStarted(this, std::move(*callback), key, job); return ERR_IO_PENDING; } return err; } int ChannelIDService::channel_id_count() { return channel_id_store_->GetChannelIDCount(); } } // namespace net