// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/ssl/default_channel_id_store.h"

#include <memory>
#include <utility>

#include "base/bind.h"
#include "base/metrics/histogram_macros.h"
#include "crypto/ec_private_key.h"
#include "net/base/net_errors.h"

namespace {

// Predicate that accepts every server identifier; DeleteAll() binds it so that
// no entry is filtered out by domain.
bool AllDomainsPredicate(const std::string& /* domain */) {
  return true;
}

}  // namespace

namespace net {

// --------------------------------------------------------------------------
// Task
class DefaultChannelIDStore::Task {
 public:
  virtual ~Task();

  // Runs the task and invokes the client callback on the thread that
  // originally constructed the task.
  virtual void Run(DefaultChannelIDStore* store) = 0;

 protected:
  void InvokeCallback(base::OnceClosure callback) const;
};

DefaultChannelIDStore::Task::~Task() = default;

void DefaultChannelIDStore::Task::InvokeCallback(
    base::OnceClosure callback) const {
  if (!callback.is_null())
    std::move(callback).Run();
}

// --------------------------------------------------------------------------
// GetChannelIDTask
class DefaultChannelIDStore::GetChannelIDTask
    : public DefaultChannelIDStore::Task {
 public:
  // Deferred lookup of the key for |server_identifier|; the result is handed
  // to |callback| when the task runs.
  GetChannelIDTask(const std::string& server_identifier,
                   GetChannelIDCallback callback);
  ~GetChannelIDTask() override;
  void Run(DefaultChannelIDStore* store) override;

 private:
  std::string server_identifier_;
  GetChannelIDCallback callback_;
};

DefaultChannelIDStore::GetChannelIDTask::GetChannelIDTask(
    const std::string& server_identifier,
    GetChannelIDCallback callback)
    : server_identifier_(server_identifier), callback_(std::move(callback)) {}

DefaultChannelIDStore::GetChannelIDTask::~GetChannelIDTask() = default;

void DefaultChannelIDStore::GetChannelIDTask::Run(
    DefaultChannelIDStore* store) {
  // The task only runs after loading, so the synchronous lookup must complete
  // immediately; the empty callback is never used.
  std::unique_ptr<crypto::ECPrivateKey> key;
  int result =
      store->GetChannelID(server_identifier_, &key, GetChannelIDCallback());
  DCHECK(result != ERR_IO_PENDING);

  base::OnceClosure reply = base::BindOnce(std::move(callback_), result,
                                           server_identifier_, std::move(key));
  InvokeCallback(std::move(reply));
}

// --------------------------------------------------------------------------
// SetChannelIDTask
class DefaultChannelIDStore::SetChannelIDTask
    : public DefaultChannelIDStore::Task {
 public:
  // Takes ownership of |channel_id|. Marked explicit so a bare
  // std::unique_ptr<ChannelID> cannot implicitly convert into a task.
  explicit SetChannelIDTask(std::unique_ptr<ChannelID> channel_id);
  ~SetChannelIDTask() override;
  void Run(DefaultChannelIDStore* store) override;

 private:
  std::unique_ptr<ChannelID> channel_id_;
};

DefaultChannelIDStore::SetChannelIDTask::SetChannelIDTask(
    std::unique_ptr<ChannelID> channel_id)
    : channel_id_(std::move(channel_id)) {}

DefaultChannelIDStore::SetChannelIDTask::~SetChannelIDTask() = default;

// Hands the owned ChannelID to the store; this task has no client callback.
void DefaultChannelIDStore::SetChannelIDTask::Run(
    DefaultChannelIDStore* store) {
  store->SyncSetChannelID(std::move(channel_id_));
}

// --------------------------------------------------------------------------
// DeleteChannelIDTask
class DefaultChannelIDStore::DeleteChannelIDTask
    : public DefaultChannelIDStore::Task {
 public:
  // Deferred removal of the entry keyed by |server_identifier|; |callback| is
  // run afterwards (it may be null).
  DeleteChannelIDTask(const std::string& server_identifier,
                      base::OnceClosure callback);
  ~DeleteChannelIDTask() override;
  void Run(DefaultChannelIDStore* store) override;

 private:
  std::string server_identifier_;
  base::OnceClosure callback_;
};

DefaultChannelIDStore::DeleteChannelIDTask::DeleteChannelIDTask(
    const std::string& server_identifier,
    base::OnceClosure callback)
    : server_identifier_(server_identifier), callback_(std::move(callback)) {}

DefaultChannelIDStore::DeleteChannelIDTask::~DeleteChannelIDTask() = default;

void DefaultChannelIDStore::DeleteChannelIDTask::Run(
    DefaultChannelIDStore* store) {
  const std::string& target = server_identifier_;
  store->SyncDeleteChannelID(target);
  // Notify only after the deletion has been applied.
  InvokeCallback(std::move(callback_));
}

// --------------------------------------------------------------------------
// DeleteForDomainsCreatedBetweenTask
class DefaultChannelIDStore::DeleteForDomainsCreatedBetweenTask
    : public DefaultChannelIDStore::Task {
 public:
  // Deferred bulk deletion: removes entries whose identifier satisfies
  // |domain_predicate| and whose creation time falls in the given window, then
  // runs |callback| (which may be null).
  DeleteForDomainsCreatedBetweenTask(
      const base::Callback<bool(const std::string&)>& domain_predicate,
      base::Time delete_begin,
      base::Time delete_end,
      base::OnceClosure callback);
  ~DeleteForDomainsCreatedBetweenTask() override;
  void Run(DefaultChannelIDStore* store) override;

 private:
  const base::Callback<bool(const std::string&)> domain_predicate_;
  base::Time delete_begin_;
  base::Time delete_end_;
  base::OnceClosure callback_;
};

DefaultChannelIDStore::DeleteForDomainsCreatedBetweenTask::
    DeleteForDomainsCreatedBetweenTask(
        const base::Callback<bool(const std::string&)>& domain_predicate,
        base::Time delete_begin,
        base::Time delete_end,
        base::OnceClosure callback)
    : domain_predicate_(domain_predicate),
      delete_begin_(delete_begin),
      delete_end_(delete_end),
      callback_(std::move(callback)) {}

DefaultChannelIDStore::DeleteForDomainsCreatedBetweenTask::
    ~DeleteForDomainsCreatedBetweenTask() = default;

void DefaultChannelIDStore::DeleteForDomainsCreatedBetweenTask::Run(
    DefaultChannelIDStore* store) {
  const base::Time& begin = delete_begin_;
  const base::Time& end = delete_end_;
  store->SyncDeleteForDomainsCreatedBetween(domain_predicate_, begin, end);
  // Notify only after the deletion has been applied.
  InvokeCallback(std::move(callback_));
}

// --------------------------------------------------------------------------
// GetAllChannelIDsTask
class DefaultChannelIDStore::GetAllChannelIDsTask
    : public DefaultChannelIDStore::Task {
 public:
  // Deferred snapshot of every stored ChannelID, delivered to |callback|.
  explicit GetAllChannelIDsTask(GetChannelIDListCallback callback);
  ~GetAllChannelIDsTask() override;
  void Run(DefaultChannelIDStore* store) override;

 private:
  // Note: the previous, never-used |server_identifier_| member was removed;
  // this task does not operate on a single identifier.
  GetChannelIDListCallback callback_;
};

DefaultChannelIDStore::GetAllChannelIDsTask::GetAllChannelIDsTask(
    GetChannelIDListCallback callback)
    : callback_(std::move(callback)) {}

DefaultChannelIDStore::GetAllChannelIDsTask::~GetAllChannelIDsTask() = default;

void DefaultChannelIDStore::GetAllChannelIDsTask::Run(
    DefaultChannelIDStore* store) {
  ChannelIDList key_list;
  store->SyncGetAllChannelIDs(&key_list);

  // |key_list| is a local that is not used again, so move it into the bound
  // state instead of copying every ChannelID.
  InvokeCallback(base::BindOnce(std::move(callback_), std::move(key_list)));
}

// --------------------------------------------------------------------------
// DefaultChannelIDStore

// |store| may be null, in which case the store is ephemeral (see IsEphemeral()).
// Loading is deferred until the first operation calls InitIfNecessary().
DefaultChannelIDStore::DefaultChannelIDStore(PersistentStore* store)
    : initialized_(false),
      loaded_(false),
      store_(store),
      weak_ptr_factory_(this) {}

// Looks up the private key for |server_identifier|.
// Returns OK and fills |*key_result| with a copy of the key on a hit, or
// ERR_FILE_NOT_FOUND on a miss. If the backing store has not finished loading,
// the lookup is queued, ERR_IO_PENDING is returned, and |callback| receives the
// result once OnLoaded() drains the queue.
int DefaultChannelIDStore::GetChannelID(
    const std::string& server_identifier,
    std::unique_ptr<crypto::ECPrivateKey>* key_result,
    GetChannelIDCallback callback) {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  InitIfNecessary();

  if (!loaded_) {
    EnqueueTask(std::make_unique<GetChannelIDTask>(server_identifier,
                                                   std::move(callback)));
    return ERR_IO_PENDING;
  }

  auto it = channel_ids_.find(server_identifier);
  if (it == channel_ids_.end())
    return ERR_FILE_NOT_FOUND;

  // Hand out a copy so the caller never aliases the stored key.
  *key_result = it->second->key()->Copy();
  return OK;
}

void DefaultChannelIDStore::SetChannelID(
    std::unique_ptr<ChannelID> channel_id) {
  auto* task = new SetChannelIDTask(std::move(channel_id));
  RunOrEnqueueTask(std::unique_ptr<Task>(task));
}

void DefaultChannelIDStore::DeleteChannelID(
    const std::string& server_identifier,
    base::OnceClosure callback) {
  RunOrEnqueueTask(std::unique_ptr<Task>(
      new DeleteChannelIDTask(server_identifier, std::move(callback))));
}

// Removes entries whose server identifier matches |domain_predicate| and whose
// creation time lies in [delete_begin, delete_end); a null bound is treated as
// unbounded by SyncDeleteForDomainsCreatedBetween(). Runs |callback| afterwards.
void DefaultChannelIDStore::DeleteForDomainsCreatedBetween(
    const base::Callback<bool(const std::string&)>& domain_predicate,
    base::Time delete_begin,
    base::Time delete_end,
    base::OnceClosure callback) {
  RunOrEnqueueTask(std::make_unique<DeleteForDomainsCreatedBetweenTask>(
      domain_predicate, delete_begin, delete_end, std::move(callback)));
}

// Removes every entry regardless of domain or creation time, then runs
// |callback|.
void DefaultChannelIDStore::DeleteAll(base::OnceClosure callback) {
  // Default-constructed (null) times disable both ends of the time filter.
  const base::Time unbounded;
  DeleteForDomainsCreatedBetween(base::Bind(&AllDomainsPredicate), unbounded,
                                 unbounded, std::move(callback));
}

// Delivers a snapshot of all stored ChannelIDs to |callback|; deferred until
// the backing store has loaded.
void DefaultChannelIDStore::GetAllChannelIDs(
    GetChannelIDListCallback callback) {
  RunOrEnqueueTask(std::make_unique<GetAllChannelIDsTask>(std::move(callback)));
}

// Asks the backing store to flush pending writes. An ephemeral store (null
// |store_|) has nothing to flush; every other use of |store_| in this file is
// likewise guarded against null.
void DefaultChannelIDStore::Flush() {
  if (store_)
    store_->Flush();
}

// Number of entries currently held in memory. Does not trigger loading.
int DefaultChannelIDStore::GetChannelIDCount() {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  // The interface returns int; make the size_t narrowing explicit.
  return static_cast<int>(channel_ids_.size());
}

// Forwards the force-keep-session-state request to the backing store, if one
// exists. Initializes the store first.
void DefaultChannelIDStore::SetForceKeepSessionState() {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  InitIfNecessary();

  if (!store_)
    return;
  store_->SetForceKeepSessionState();
}

// The map owns its ChannelID pointers, so they are freed on destruction.
DefaultChannelIDStore::~DefaultChannelIDStore() {
  DeleteAllInMemory();
}

// Frees every in-memory ChannelID and empties the map. The persistent store is
// not touched.
void DefaultChannelIDStore::DeleteAllInMemory() {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);

  for (auto& entry : channel_ids_)
    delete entry.second;
  channel_ids_.clear();
}

// Starts an asynchronous load from the backing store. OnLoaded() is bound
// through a weak pointer, so it is dropped if this object is destroyed first.
void DefaultChannelIDStore::InitStore() {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  DCHECK(store_) << "Store must exist to initialize";
  DCHECK(!loaded_);

  auto on_loaded = base::Bind(&DefaultChannelIDStore::OnLoaded,
                              weak_ptr_factory_.GetWeakPtr());
  store_->Load(std::move(on_loaded));
}

void DefaultChannelIDStore::OnLoaded(
    std::unique_ptr<std::vector<std::unique_ptr<ChannelID>>> channel_ids) {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  for (std::vector<std::unique_ptr<ChannelID>>::iterator it =
           channel_ids->begin();
       it != channel_ids->end(); ++it) {
    DCHECK(channel_ids_.find((*it)->server_identifier()) ==
           channel_ids_.end());
    std::string ident = (*it)->server_identifier();
    channel_ids_[ident] = it->release();
  }
  channel_ids->clear();

  loaded_ = true;

  for (std::unique_ptr<Task>& i : waiting_tasks_)
    i->Run(this);
  waiting_tasks_.clear();
}

void DefaultChannelIDStore::SyncSetChannelID(
    std::unique_ptr<ChannelID> channel_id) {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  DCHECK(loaded_);

  InternalDeleteChannelID(channel_id->server_identifier());
  InternalInsertChannelID(std::move(channel_id));
}

// Synchronous removal of |server_identifier|'s entry; requires a loaded store.
void DefaultChannelIDStore::SyncDeleteChannelID(
    const std::string& server_identifier) {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  DCHECK(loaded_);

  const std::string& key = server_identifier;
  InternalDeleteChannelID(key);
}

// Deletes every entry created in [delete_begin, delete_end) whose identifier
// satisfies |domain_predicate|. A null bound leaves that side unbounded. Matches
// are removed from the backing store (if any) and freed from memory.
void DefaultChannelIDStore::SyncDeleteForDomainsCreatedBetween(
    const base::Callback<bool(const std::string&)>& domain_predicate,
    base::Time delete_begin,
    base::Time delete_end) {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  DCHECK(loaded_);

  auto it = channel_ids_.begin();
  while (it != channel_ids_.end()) {
    ChannelID* candidate = it->second;
    const base::Time created = candidate->creation_time();

    const bool after_begin = delete_begin.is_null() || created >= delete_begin;
    const bool before_end = delete_end.is_null() || created < delete_end;
    // The predicate is only consulted once the time window matches.
    if (!after_begin || !before_end ||
        !domain_predicate.Run(candidate->server_identifier())) {
      ++it;
      continue;
    }

    if (store_)
      store_->DeleteChannelID(*candidate);
    delete candidate;
    it = channel_ids_.erase(it);
  }
}

// Appends a copy of every in-memory ChannelID to |channel_id_list|, in map
// iteration order. Requires a loaded store.
void DefaultChannelIDStore::SyncGetAllChannelIDs(
    ChannelIDList* channel_id_list) {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  DCHECK(loaded_);

  for (const auto& entry : channel_ids_)
    channel_id_list->push_back(*entry.second);
}

// Parks |task| until OnLoaded() runs the queue; only valid before loading.
void DefaultChannelIDStore::EnqueueTask(std::unique_ptr<Task> task) {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  DCHECK(!loaded_);

  waiting_tasks_.emplace_back(std::move(task));
}

// Runs |task| immediately if the store is loaded, otherwise queues it for
// OnLoaded(). Kicks off initialization on first use.
void DefaultChannelIDStore::RunOrEnqueueTask(std::unique_ptr<Task> task) {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  InitIfNecessary();

  if (loaded_)
    task->Run(this);
  else
    EnqueueTask(std::move(task));
}

void DefaultChannelIDStore::InternalDeleteChannelID(
    const std::string& server_identifier) {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  DCHECK(loaded_);

  ChannelIDMap::iterator it = channel_ids_.find(server_identifier);
  if (it == channel_ids_.end())
    return;  // There is nothing to delete.

  ChannelID* channel_id = it->second;
  if (store_)
    store_->DeleteChannelID(*channel_id);
  channel_ids_.erase(it);
  delete channel_id;
}

void DefaultChannelIDStore::InternalInsertChannelID(
    std::unique_ptr<ChannelID> channel_id) {
  DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
  DCHECK(loaded_);

  if (store_)
    store_->AddChannelID(*channel_id);
  const std::string& server_identifier = channel_id->server_identifier();
  channel_ids_[server_identifier] = channel_id.release();
}

// True when this store has no PersistentStore behind it (|store_| is null).
bool DefaultChannelIDStore::IsEphemeral() {
  const bool has_backing_store = static_cast<bool>(store_);
  return !has_backing_store;
}

// Out-of-line defaulted constructor for the PersistentStore interface.
DefaultChannelIDStore::PersistentStore::PersistentStore() = default;

// Out-of-line defaulted destructor for the PersistentStore interface.
DefaultChannelIDStore::PersistentStore::~PersistentStore() = default;

}  // namespace net