mirror of
https://github.com/klzgrad/naiveproxy.git
synced 2024-11-24 06:16:30 +03:00
255 lines
8.3 KiB
C++
255 lines
8.3 KiB
C++
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style license that can be
|
|
// found in the LICENSE file.
|
|
|
|
#include "net/dns/dns_test_util.h"
|
|
|
|
#include <string>
|
|
|
|
#include "base/big_endian.h"
|
|
#include "base/bind.h"
|
|
#include "base/location.h"
|
|
#include "base/memory/weak_ptr.h"
|
|
#include "base/single_thread_task_runner.h"
|
|
#include "base/sys_byteorder.h"
|
|
#include "base/threading/thread_task_runner_handle.h"
|
|
#include "net/base/io_buffer.h"
|
|
#include "net/base/net_errors.h"
|
|
#include "net/dns/address_sorter.h"
|
|
#include "net/dns/dns_query.h"
|
|
#include "net/dns/dns_response.h"
|
|
#include "net/dns/dns_transaction.h"
|
|
#include "net/dns/dns_util.h"
|
|
#include "testing/gtest/include/gtest/gtest.h"
|
|
|
|
namespace net {
|
|
namespace {
|
|
|
|
class MockAddressSorter : public AddressSorter {
|
|
public:
|
|
~MockAddressSorter() override = default;
|
|
void Sort(const AddressList& list, CallbackType callback) const override {
|
|
// Do nothing.
|
|
std::move(callback).Run(true, list);
|
|
}
|
|
};
|
|
|
|
// A DnsTransaction which uses MockDnsClientRuleList to determine the response.
|
|
class MockTransaction : public DnsTransaction,
|
|
public base::SupportsWeakPtr<MockTransaction> {
|
|
public:
|
|
MockTransaction(const MockDnsClientRuleList& rules,
|
|
const std::string& hostname,
|
|
uint16_t qtype,
|
|
DnsTransactionFactory::CallbackType callback)
|
|
: result_(MockDnsClientRule::FAIL),
|
|
hostname_(hostname),
|
|
qtype_(qtype),
|
|
callback_(std::move(callback)),
|
|
started_(false),
|
|
delayed_(false) {
|
|
// Find the relevant rule which matches |qtype| and prefix of |hostname|.
|
|
for (size_t i = 0; i < rules.size(); ++i) {
|
|
const std::string& prefix = rules[i].prefix;
|
|
if ((rules[i].qtype == qtype) &&
|
|
(hostname.size() >= prefix.size()) &&
|
|
(hostname.compare(0, prefix.size(), prefix) == 0)) {
|
|
result_ = rules[i].result;
|
|
delayed_ = rules[i].delay;
|
|
|
|
// Fill in an IP address for the result if one was not specified.
|
|
if (result_.ip.empty() && result_.type == MockDnsClientRule::OK) {
|
|
result_.ip = qtype_ == dns_protocol::kTypeA
|
|
? IPAddress::IPv4Localhost()
|
|
: IPAddress::IPv6Localhost();
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
const std::string& GetHostname() const override { return hostname_; }
|
|
|
|
uint16_t GetType() const override { return qtype_; }
|
|
|
|
void Start() override {
|
|
EXPECT_FALSE(started_);
|
|
started_ = true;
|
|
if (delayed_)
|
|
return;
|
|
// Using WeakPtr to cleanly cancel when transaction is destroyed.
|
|
base::ThreadTaskRunnerHandle::Get()->PostTask(
|
|
FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr()));
|
|
}
|
|
|
|
void FinishDelayedTransaction() {
|
|
EXPECT_TRUE(delayed_);
|
|
delayed_ = false;
|
|
Finish();
|
|
}
|
|
|
|
bool delayed() const { return delayed_; }
|
|
|
|
private:
|
|
void Finish() {
|
|
switch (result_.type) {
|
|
case MockDnsClientRule::NODOMAIN:
|
|
case MockDnsClientRule::EMPTY:
|
|
case MockDnsClientRule::OK: {
|
|
std::string qname;
|
|
DNSDomainFromDot(hostname_, &qname);
|
|
DnsQuery query(0, qname, qtype_);
|
|
|
|
DnsResponse response;
|
|
char* buffer = response.io_buffer()->data();
|
|
int nbytes = query.io_buffer()->size();
|
|
memcpy(buffer, query.io_buffer()->data(), nbytes);
|
|
dns_protocol::Header* header =
|
|
reinterpret_cast<dns_protocol::Header*>(buffer);
|
|
header->flags |= dns_protocol::kFlagResponse;
|
|
if (MockDnsClientRule::NODOMAIN == result_.type) {
|
|
header->flags |= base::HostToNet16(dns_protocol::kRcodeNXDOMAIN);
|
|
}
|
|
|
|
const uint16_t kPointerToQueryName =
|
|
static_cast<uint16_t>(0xc000 | sizeof(*header));
|
|
|
|
const uint32_t kTTL = 86400; // One day.
|
|
|
|
// Size of RDATA which is a IPv4 or IPv6 address.
|
|
if (MockDnsClientRule::OK == result_.type)
|
|
EXPECT_TRUE(result_.ip.IsValid());
|
|
size_t rdata_size = result_.ip.size();
|
|
|
|
if (MockDnsClientRule::OK == result_.type) {
|
|
// Write the answer using the expected IP address.
|
|
// 12 is the sum of sizes of the compressed name reference, TYPE,
|
|
// CLASS, TTL and RDLENGTH.
|
|
size_t answer_size = 12 + rdata_size;
|
|
header->ancount = base::HostToNet16(1);
|
|
base::BigEndianWriter writer(buffer + nbytes, answer_size);
|
|
writer.WriteU16(kPointerToQueryName);
|
|
writer.WriteU16(qtype_);
|
|
writer.WriteU16(dns_protocol::kClassIN);
|
|
writer.WriteU32(kTTL);
|
|
writer.WriteU16(static_cast<uint16_t>(rdata_size));
|
|
writer.WriteBytes(result_.ip.bytes().data(), rdata_size);
|
|
nbytes += answer_size;
|
|
} else {
|
|
// authority will be 12 bytes.
|
|
size_t authority_size = 12;
|
|
header->ancount = base::HostToNet16(0);
|
|
header->nscount = base::HostToNet16(1);
|
|
base::BigEndianWriter writer(buffer + nbytes, authority_size);
|
|
writer.WriteU16(kPointerToQueryName);
|
|
writer.WriteU16(dns_protocol::kTypeSOA);
|
|
writer.WriteU16(dns_protocol::kClassIN);
|
|
writer.WriteU32(kTTL);
|
|
writer.WriteU16(static_cast<uint16_t>(rdata_size));
|
|
nbytes += authority_size;
|
|
}
|
|
EXPECT_TRUE(response.InitParse(nbytes, query));
|
|
if (MockDnsClientRule::NODOMAIN == result_.type) {
|
|
std::move(callback_).Run(this, ERR_NAME_NOT_RESOLVED, &response);
|
|
} else {
|
|
std::move(callback_).Run(this, OK, &response);
|
|
}
|
|
} break;
|
|
case MockDnsClientRule::FAIL:
|
|
std::move(callback_).Run(this, ERR_NAME_NOT_RESOLVED, NULL);
|
|
break;
|
|
case MockDnsClientRule::TIMEOUT:
|
|
std::move(callback_).Run(this, ERR_DNS_TIMED_OUT, NULL);
|
|
break;
|
|
default:
|
|
NOTREACHED();
|
|
break;
|
|
}
|
|
}
|
|
|
|
void SetRequestContext(URLRequestContext*) override {}
|
|
void SetRequestPriority(RequestPriority priority) override {}
|
|
|
|
MockDnsClientRule::Result result_;
|
|
const std::string hostname_;
|
|
const uint16_t qtype_;
|
|
DnsTransactionFactory::CallbackType callback_;
|
|
bool started_;
|
|
bool delayed_;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
// A DnsTransactionFactory which creates MockTransaction.
|
|
class MockTransactionFactory : public DnsTransactionFactory {
|
|
public:
|
|
explicit MockTransactionFactory(const MockDnsClientRuleList& rules)
|
|
: rules_(rules) {}
|
|
|
|
~MockTransactionFactory() override = default;
|
|
|
|
std::unique_ptr<DnsTransaction> CreateTransaction(
|
|
const std::string& hostname,
|
|
uint16_t qtype,
|
|
DnsTransactionFactory::CallbackType callback,
|
|
const NetLogWithSource&) override {
|
|
std::unique_ptr<MockTransaction> transaction =
|
|
std::make_unique<MockTransaction>(rules_, hostname, qtype,
|
|
std::move(callback));
|
|
if (transaction->delayed())
|
|
delayed_transactions_.push_back(transaction->AsWeakPtr());
|
|
return transaction;
|
|
}
|
|
|
|
void AddEDNSOption(const OptRecordRdata::Opt& opt) override {
|
|
NOTREACHED() << "Not implemented";
|
|
}
|
|
|
|
void CompleteDelayedTransactions() {
|
|
DelayedTransactionList old_delayed_transactions;
|
|
old_delayed_transactions.swap(delayed_transactions_);
|
|
for (DelayedTransactionList::iterator it = old_delayed_transactions.begin();
|
|
it != old_delayed_transactions.end(); ++it) {
|
|
if (it->get())
|
|
(*it)->FinishDelayedTransaction();
|
|
}
|
|
}
|
|
|
|
private:
|
|
typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList;
|
|
|
|
MockDnsClientRuleList rules_;
|
|
DelayedTransactionList delayed_transactions_;
|
|
};
|
|
|
|
MockDnsClient::MockDnsClient(const DnsConfig& config,
|
|
const MockDnsClientRuleList& rules)
|
|
: config_(config),
|
|
factory_(new MockTransactionFactory(rules)),
|
|
address_sorter_(new MockAddressSorter()) {
|
|
}
|
|
|
|
MockDnsClient::~MockDnsClient() = default;
|
|
|
|
void MockDnsClient::SetConfig(const DnsConfig& config) {
|
|
config_ = config;
|
|
}
|
|
|
|
const DnsConfig* MockDnsClient::GetConfig() const {
|
|
return config_.IsValid() ? &config_ : NULL;
|
|
}
|
|
|
|
DnsTransactionFactory* MockDnsClient::GetTransactionFactory() {
|
|
return config_.IsValid() ? factory_.get() : NULL;
|
|
}
|
|
|
|
AddressSorter* MockDnsClient::GetAddressSorter() {
|
|
return address_sorter_.get();
|
|
}
|
|
|
|
void MockDnsClient::CompleteDelayedTransactions() {
|
|
factory_->CompleteDelayedTransactions();
|
|
}
|
|
|
|
} // namespace net
|