Certificate Transparency DNS log client
This can query CT logs over DNS, as defined by:
https://github.com/google/certificate-transparency-rfcs/blob/master/dns/draft-ct-over-dns.md
This is required for obtaining audit proofs, which will allow Chrome to verify
that SCTs it receives are trustworthy and that logs are behaving correctly.
BUG=612439
Committed: https://crrev.com/59b6ea2217dbc10400b6a9d433ad13c91bb6b7c2
Review-Url: https://codereview.chromium.org/2066553002
Cr-Original-Commit-Position: refs/heads/master@{#403798}
Cr-Commit-Position: refs/heads/master@{#404199}
diff --git a/components/certificate_transparency/BUILD.gn b/components/certificate_transparency/BUILD.gn
index 15cb99f..3bba1a1 100644
--- a/components/certificate_transparency/BUILD.gn
+++ b/components/certificate_transparency/BUILD.gn
@@ -6,6 +6,8 @@
sources = [
"ct_policy_manager.cc",
"ct_policy_manager.h",
+ "log_dns_client.cc",
+ "log_dns_client.h",
"log_proof_fetcher.cc",
"log_proof_fetcher.h",
"pref_names.cc",
@@ -18,6 +20,7 @@
deps = [
"//base",
+ "//components/base32",
"//components/prefs",
"//components/safe_json",
"//components/url_formatter",
@@ -31,6 +34,7 @@
testonly = true
sources = [
"ct_policy_manager_unittest.cc",
+ "log_dns_client_unittest.cc",
"log_proof_fetcher_unittest.cc",
"single_tree_tracker_unittest.cc",
]
@@ -41,6 +45,7 @@
"//components/prefs:test_support",
"//components/safe_json:test_support",
"//net:test_support",
+ "//testing/gmock",
"//testing/gtest",
]
}
diff --git a/components/certificate_transparency/DEPS b/components/certificate_transparency/DEPS
index b3d5f275..575b1501 100644
--- a/components/certificate_transparency/DEPS
+++ b/components/certificate_transparency/DEPS
@@ -1,7 +1,9 @@
include_rules = [
+ "+components/base32",
"+components/prefs",
"+components/safe_json",
"+components/url_formatter",
"+components/url_matcher",
+ "+crypto",
"+net",
]
diff --git a/components/certificate_transparency/log_dns_client.cc b/components/certificate_transparency/log_dns_client.cc
new file mode 100644
index 0000000..77da9f8
--- /dev/null
+++ b/components/certificate_transparency/log_dns_client.cc
@@ -0,0 +1,294 @@
+// Copyright 2016 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/certificate_transparency/log_dns_client.h"
+
+#include <sstream>
+
+#include "base/bind.h"
+#include "base/location.h"
+#include "base/logging.h"
+#include "base/strings/string_number_conversions.h"
+#include "base/strings/string_util.h"
+#include "base/threading/thread_task_runner_handle.h"
+#include "base/time/time.h"
+#include "components/base32/base32.h"
+#include "crypto/sha2.h"
+#include "net/base/net_errors.h"
+#include "net/cert/merkle_audit_proof.h"
+#include "net/dns/dns_client.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/dns_response.h"
+#include "net/dns/dns_transaction.h"
+#include "net/dns/record_parsed.h"
+#include "net/dns/record_rdata.h"
+
+namespace certificate_transparency {
+
+namespace {
+
+bool ParseTxtResponse(const net::DnsResponse& response, std::string* txt) {
+  DCHECK(txt);
+
+  // Only one record is read from the response. Its creation time is
+  // irrelevant because the record is discarded as soon as its TXT strings
+  // have been copied out, so a default-constructed ("null") time is passed.
+  net::DnsRecordParser parser = response.Parser();
+  auto parsed_record = net::RecordParsed::CreateFrom(&parser, base::Time());
+  auto txt_record =
+      parsed_record ? parsed_record->rdata<net::TxtRecordRdata>() : nullptr;
+  if (!txt_record)
+    return false;
+
+  // The record's character-strings are concatenated with no separator to
+  // recover the complete payload.
+  *txt = base::JoinString(txt_record->texts(), "");
+  return true;
+}
+
+bool ParseLeafIndex(const net::DnsResponse& response, uint64_t* index) {
+  DCHECK(index);
+
+  // The leaf index is sent as a decimal string in a TXT record; anything
+  // StringToUint64 rejects makes this return false.
+  std::string decimal_index;
+  const bool got_txt = ParseTxtResponse(response, &decimal_index);
+  return got_txt && base::StringToUint64(decimal_index, index);
+}
+
+bool ParseAuditPath(const net::DnsResponse& response,
+                    net::ct::MerkleAuditProof* proof) {
+  DCHECK(proof);
+
+  std::string payload;
+  if (!ParseTxtResponse(response, &payload))
+    return false;
+  // Require a whole, non-zero number of SHA-256 nodes; accepting an empty
+  // payload would let the caller's re-query loop run forever.
+  const size_t node_count = payload.size() / crypto::kSHA256Length;
+  if (node_count == 0 ||
+      node_count * crypto::kSHA256Length != payload.size())
+    return false;
+
+  for (size_t n = 0; n < node_count; ++n) {
+    const size_t offset = n * crypto::kSHA256Length;
+    proof->nodes.push_back(payload.substr(offset, crypto::kSHA256Length));
+  }
+  return true;
+}
+
+} // namespace
+
+LogDnsClient::LogDnsClient(std::unique_ptr<net::DnsClient> dns_client,
+                           const net::BoundNetLog& net_log)
+    : dns_client_(std::move(dns_client)), net_log_(net_log),
+      weak_ptr_factory_(this) {
+  // Every query is issued through |dns_client_|, so it must be non-null.
+  CHECK(dns_client_);
+}
+
+LogDnsClient::~LogDnsClient() = default;
+
+void LogDnsClient::QueryLeafIndex(base::StringPiece domain_for_log,
+                                  base::StringPiece leaf_hash,
+                                  const LeafIndexCallback& callback) {
+  // Failures are posted as tasks so |callback| never runs synchronously here.
+  auto post_error = [&callback](int net_error) {
+    base::ThreadTaskRunnerHandle::Get()->PostTask(
+        FROM_HERE, base::Bind(callback, net_error, 0));
+  };
+  if (domain_for_log.empty() || leaf_hash.size() != crypto::kSHA256Length) {
+    post_error(net::Error::ERR_INVALID_ARGUMENT);
+    return;
+  }
+  // A 32-byte hash always encodes to 52 base32 characters without padding.
+  const std::string encoded_leaf_hash =
+      base32::Base32Encode(leaf_hash, base32::Base32EncodePolicy::OMIT_PADDING);
+  DCHECK_EQ(encoded_leaf_hash.size(), 52u);
+  net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory();
+  if (!factory) {
+    post_error(net::Error::ERR_NAME_RESOLUTION_FAILED);
+    return;
+  }
+
+  // Leaf hashes are looked up at "<base32 hash>.hash.<domain_for_log>.".
+  std::ostringstream qname;
+  qname << encoded_leaf_hash << ".hash." << domain_for_log << ".";
+  std::unique_ptr<net::DnsTransaction> dns_transaction =
+      factory->CreateTransaction(
+          qname.str(), net::dns_protocol::kTypeTXT,
+          base::Bind(&LogDnsClient::QueryLeafIndexComplete,
+                     weak_ptr_factory_.GetWeakPtr()),
+          net_log_);
+  dns_transaction->Start();
+  leaf_index_queries_.push_back({std::move(dns_transaction), callback});
+}
+
+// Possible optimisation: every expected query could be sent up front instead
+// of one at a time. A response holds at most 7 audit path nodes, so a proof
+// of 20 nodes could be fetched with 3 simultaneous queries (nodes 0-6, 7-13
+// and 14-19). At present only the first query is sent, and the number of
+// nodes it returns decides where the next one starts. Sending them all at
+// once would add complexity: gaps in the proof, caused by the server
+// returning fewer nodes than anticipated, would need detecting, and the
+// proof would have to be shared between simultaneous DNS transactions
+// rather than owned by whichever one is in flight.
+void LogDnsClient::QueryAuditProof(base::StringPiece domain_for_log,
+                                   uint64_t leaf_index,
+                                   uint64_t tree_size,
+                                   const AuditProofCallback& callback) {
+  // A leaf index outside the tree, or an empty domain, cannot be queried.
+  if (leaf_index >= tree_size || domain_for_log.empty()) {
+    base::ThreadTaskRunnerHandle::Get()->PostTask(
+        FROM_HERE,
+        base::Bind(callback, net::Error::ERR_INVALID_ARGUMENT, nullptr));
+    return;
+  }
+
+  // The proof is filled in batch by batch, beginning with node 0.
+  // TODO(robpercival): Once a "tree_size" field is added to MerkleAuditProof,
+  // pass |tree_size| to QueryAuditProofNodes using that.
+  std::unique_ptr<net::ct::MerkleAuditProof> proof(
+      new net::ct::MerkleAuditProof);
+  proof->leaf_index = leaf_index;
+  QueryAuditProofNodes(std::move(proof), domain_for_log, tree_size,
+                       /* node_index */ 0, callback);
+}
+
+void LogDnsClient::QueryLeafIndexComplete(net::DnsTransaction* transaction,
+                                          int net_error,
+                                          const net::DnsResponse* response) {
+  // Locate the in-flight query that owns |transaction|.
+  auto query_iterator = leaf_index_queries_.begin();
+  while (query_iterator != leaf_index_queries_.end() &&
+         query_iterator->transaction.get() != transaction) {
+    ++query_iterator;
+  }
+  if (query_iterator == leaf_index_queries_.end()) {
+    NOTREACHED();
+    return;
+  }
+
+  // Move the query out before erasing its list entry: the local copy keeps
+  // |transaction| alive until this method returns, which presumably also
+  // keeps |response| valid (verify that the transaction owns it).
+  const Query<LeafIndexCallback> query = std::move(*query_iterator);
+  leaf_index_queries_.erase(query_iterator);
+
+  // A missing response without an error should not happen, but if it does it
+  // is reported as an invalid response.
+  if (net_error == net::OK && response == nullptr)
+    net_error = net::ERR_INVALID_RESPONSE;
+
+  // Every outcome is delivered through the same posted task; failures always
+  // carry a leaf index of 0.
+  uint64_t leaf_index = 0;
+  if (net_error == net::OK && !ParseLeafIndex(*response, &leaf_index)) {
+    // Discard any value ParseLeafIndex may have stored before it failed, so
+    // the callback sees 0 exactly as it does for other failures.
+    net_error = net::ERR_DNS_MALFORMED_RESPONSE;
+    leaf_index = 0;
+  }
+
+  base::ThreadTaskRunnerHandle::Get()->PostTask(
+      FROM_HERE, base::Bind(query.callback, net_error, leaf_index));
+}
+
+void LogDnsClient::QueryAuditProofNodes(
+    std::unique_ptr<net::ct::MerkleAuditProof> proof,
+    base::StringPiece domain_for_log,
+    uint64_t tree_size,
+    uint64_t node_index,
+    const AuditProofCallback& callback) {
+  // Preconditions that should be guaranteed internally by this class.
+  DCHECK(proof);
+  DCHECK(!domain_for_log.empty());
+  DCHECK_LT(proof->leaf_index, tree_size);
+  DCHECK_LT(node_index,
+            net::ct::CalculateAuditPathLength(proof->leaf_index, tree_size));
+
+  net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory();
+  if (factory == nullptr) {
+    base::ThreadTaskRunnerHandle::Get()->PostTask(
+        FROM_HERE,
+        base::Bind(callback, net::Error::ERR_NAME_RESOLUTION_FAILED, nullptr));
+    return;
+  }
+
+  std::ostringstream qname;
+  qname << node_index << "." << proof->leaf_index << "." << tree_size
+        << ".tree." << domain_for_log << ".";
+  // Bind an owned copy of |domain_for_log|: the caller's string may be gone
+  // by the time the asynchronous DNS transaction completes.
+  net::DnsTransactionFactory::CallbackType transaction_callback =
+      base::Bind(&LogDnsClient::QueryAuditProofNodesComplete,
+                 weak_ptr_factory_.GetWeakPtr(), base::Passed(std::move(proof)),
+                 domain_for_log.as_string(), tree_size);
+  std::unique_ptr<net::DnsTransaction> dns_transaction =
+      factory->CreateTransaction(qname.str(), net::dns_protocol::kTypeTXT,
+                                 transaction_callback, net_log_);
+  dns_transaction->Start();
+  audit_proof_queries_.push_back({std::move(dns_transaction), callback});
+}
+
+void LogDnsClient::QueryAuditProofNodesComplete(
+    std::unique_ptr<net::ct::MerkleAuditProof> proof,
+    base::StringPiece domain_for_log,
+    uint64_t tree_size,
+    net::DnsTransaction* transaction,
+    int net_error,
+    const net::DnsResponse* response) {
+  // Preconditions that should be guaranteed internally by this class.
+  DCHECK(proof);
+  DCHECK(!domain_for_log.empty());
+
+  auto query_iterator =
+      std::find_if(audit_proof_queries_.begin(), audit_proof_queries_.end(),
+                   [transaction](const Query<AuditProofCallback>& query) {
+                     return query.transaction.get() == transaction;
+                   });
+  if (query_iterator == audit_proof_queries_.end()) {
+    NOTREACHED();
+    return;
+  }
+  // The local |query| keeps |transaction| alive until this method returns.
+  const Query<AuditProofCallback> query = std::move(*query_iterator);
+  audit_proof_queries_.erase(query_iterator);
+
+  // No response without an error shouldn't happen; report it as invalid.
+  if (response == nullptr && net_error == net::OK)
+    net_error = net::ERR_INVALID_RESPONSE;
+
+  if (net_error != net::OK) {
+    base::ThreadTaskRunnerHandle::Get()->PostTask(
+        FROM_HERE, base::Bind(query.callback, net_error, nullptr));
+    return;
+  }
+
+  const uint64_t audit_path_length =
+      net::ct::CalculateAuditPathLength(proof->leaf_index, tree_size);
+  proof->nodes.reserve(audit_path_length);
+
+  // Unparseable responses are malformed, as are ones that take the proof past
+  // |audit_path_length| nodes (more than a proof for this leaf can contain).
+  if (!ParseAuditPath(*response, proof.get()) ||
+      proof->nodes.size() > audit_path_length) {
+    base::ThreadTaskRunnerHandle::Get()->PostTask(
+        FROM_HERE,
+        base::Bind(query.callback, net::ERR_DNS_MALFORMED_RESPONSE, nullptr));
+    return;
+  }
+  // Keep querying, from the first missing node, until the proof is complete.
+  const uint64_t audit_path_nodes_received = proof->nodes.size();
+  if (audit_path_nodes_received < audit_path_length) {
+    QueryAuditProofNodes(std::move(proof), domain_for_log, tree_size,
+                         audit_path_nodes_received, query.callback);
+    return;
+  }
+  base::ThreadTaskRunnerHandle::Get()->PostTask(
+      FROM_HERE,
+      base::Bind(query.callback, net::OK, base::Passed(std::move(proof))));
+}
+
+} // namespace certificate_transparency
diff --git a/components/certificate_transparency/log_dns_client.h b/components/certificate_transparency/log_dns_client.h
new file mode 100644
index 0000000..e00cefd
--- /dev/null
+++ b/components/certificate_transparency/log_dns_client.h
@@ -0,0 +1,120 @@
+// Copyright 2016 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef COMPONENTS_CERTIFICATE_TRANSPARENCY_LOG_DNS_CLIENT_H_
+#define COMPONENTS_CERTIFICATE_TRANSPARENCY_LOG_DNS_CLIENT_H_
+
+#include <stdint.h>
+
+#include <list>
+#include <string>
+
+#include "base/callback.h"
+#include "base/macros.h"
+#include "base/strings/string_piece.h"
+#include "net/log/net_log.h"
+
+namespace net {
+class DnsClient;
+class DnsResponse;
+class DnsTransaction;
+namespace ct {
+struct MerkleAuditProof;
+} // namespace ct
+} // namespace net
+
+namespace certificate_transparency {
+
+// Queries Certificate Transparency (CT) log servers via DNS.
+// All queries are performed asynchronously.
+// For more information, see
+// https://github.com/google/certificate-transparency-rfcs/blob/master/dns/draft-ct-over-dns.md.
+class LogDnsClient {
+ public:
+  // Invoked when a leaf index query completes.
+  // If an error occurred, |net_error| will be a net::Error code, otherwise it
+  // will be net::OK and |leaf_index| will be the leaf index that was received.
+  using LeafIndexCallback =
+      base::Callback<void(int net_error, uint64_t leaf_index)>;
+  // Invoked when an audit proof query completes.
+  // If an error occurred, |net_error| will be a net::Error code, otherwise it
+  // will be net::OK and |proof| will be the audit proof that was received.
+  // The log ID of |proof| will not be set, as that is not known by this class,
+  // but the leaf index will be set.
+  using AuditProofCallback =
+      base::Callback<void(int net_error,
+                          std::unique_ptr<net::ct::MerkleAuditProof> proof)>;
+
+  // Creates a log client that will take ownership of |dns_client| and use it
+  // to perform DNS queries. Queries will be logged to |net_log|.
+  LogDnsClient(std::unique_ptr<net::DnsClient> dns_client,
+               const net::BoundNetLog& net_log);
+  virtual ~LogDnsClient();
+
+  // Queries a CT log to discover the index of the leaf with |leaf_hash|.
+  // The log is identified by |domain_for_log|, which is the DNS name used as a
+  // suffix for all queries.
+  // The |leaf_hash| is the SHA-256 hash of a Merkle tree leaf in that log.
+  // The |callback| is invoked when the query is complete, or an error occurs.
+  void QueryLeafIndex(base::StringPiece domain_for_log,
+                      base::StringPiece leaf_hash,
+                      const LeafIndexCallback& callback);
+
+  // Queries a CT log to retrieve an audit proof for the leaf at |leaf_index|.
+  // The size of the CT log tree must be provided in |tree_size|.
+  // The log is identified by |domain_for_log|, which is the DNS name used as a
+  // suffix for all queries.
+  // The |callback| is invoked when the query is complete, or an error occurs.
+  void QueryAuditProof(base::StringPiece domain_for_log,
+                       uint64_t leaf_index,
+                       uint64_t tree_size,
+                       const AuditProofCallback& callback);
+
+ private:
+  // Handles completion of a DNS transaction started by QueryLeafIndex().
+  void QueryLeafIndexComplete(net::DnsTransaction* transaction,
+                              int net_error,
+                              const net::DnsResponse* response);
+
+  // Queries a CT log for part of an audit |proof|, starting from the node at
+  // |node_index|. The CT log may return up to 7 nodes (the maximum that will
+  // fit in a DNS UDP packet), which will be appended to |proof->nodes|.
+  void QueryAuditProofNodes(std::unique_ptr<net::ct::MerkleAuditProof> proof,
+                            base::StringPiece domain_for_log,
+                            uint64_t tree_size,
+                            uint64_t node_index,
+                            const AuditProofCallback& callback);
+
+  // Handles completion of a DNS transaction started by QueryAuditProofNodes().
+  void QueryAuditProofNodesComplete(
+      std::unique_ptr<net::ct::MerkleAuditProof> proof,
+      base::StringPiece domain_for_log,
+      uint64_t tree_size,
+      net::DnsTransaction* transaction,
+      int net_error,
+      const net::DnsResponse* response);
+
+  // A DNS query that is in flight.
+  template <typename CallbackType>
+  struct Query {
+    std::unique_ptr<net::DnsTransaction> transaction;
+    CallbackType callback;
+  };
+
+  // Used to perform DNS queries.
+  std::unique_ptr<net::DnsClient> dns_client_;
+  // Passed to the DNS client for logging.
+  net::BoundNetLog net_log_;
+  // Leaf index queries that haven't completed yet.
+  std::list<Query<LeafIndexCallback>> leaf_index_queries_;
+  // Audit proof queries that haven't completed yet.
+  std::list<Query<AuditProofCallback>> audit_proof_queries_;
+  // Vends weak_ptrs for DNS callbacks; declared last so it is destroyed first.
+  base::WeakPtrFactory<LogDnsClient> weak_ptr_factory_;
+
+  DISALLOW_COPY_AND_ASSIGN(LogDnsClient);
+};
+
+} // namespace certificate_transparency
+#endif // COMPONENTS_CERTIFICATE_TRANSPARENCY_LOG_DNS_CLIENT_H_
diff --git a/components/certificate_transparency/log_dns_client_unittest.cc b/components/certificate_transparency/log_dns_client_unittest.cc
new file mode 100644
index 0000000..d8f11da
--- /dev/null
+++ b/components/certificate_transparency/log_dns_client_unittest.cc
@@ -0,0 +1,746 @@
+// Copyright 2016 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/certificate_transparency/log_dns_client.h"
+
+#include <algorithm>
+#include <numeric>
+#include <utility>
+#include <vector>
+
+#include "base/big_endian.h"
+#include "base/macros.h"
+#include "base/message_loop/message_loop.h"
+#include "base/run_loop.h"
+#include "base/sys_byteorder.h"
+#include "base/test/test_timeouts.h"
+#include "crypto/sha2.h"
+#include "net/base/net_errors.h"
+#include "net/cert/merkle_audit_proof.h"
+#include "net/cert/merkle_tree_leaf.h"
+#include "net/cert/signed_certificate_timestamp.h"
+#include "net/dns/dns_client.h"
+#include "net/dns/dns_config_service.h"
+#include "net/dns/dns_protocol.h"
+#include "net/log/net_log.h"
+#include "net/socket/socket_test_util.h"
+#include "net/test/gtest_util.h"
+#include "testing/gmock/include/gmock/gmock.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace certificate_transparency {
+namespace {
+
+using ::testing::IsNull;
+using ::testing::NotNull;
+using net::test::IsError;
+using net::test::IsOk;
+
+constexpr char kLeafHash[] =  // 32 bytes: the size of a SHA-256 leaf hash.
+    "\x1f\x25\xe1\xca\xba\x4f\xf9\xb8\x27\x24\x83\x0f\xca\x60\xe4\xc2"
+    "\xbe\xa8\xc3\xa9\x44\x1c\x27\xb0\xb4\x3e\x6a\x96\x94\xc7\xb8\x04";
+
+// Fake RNG that always yields |min|, so every DNS query ID should be 0.
+int FakeRandInt(int min, int max) {
+  static_cast<void>(max);  // Deliberately unused.
+  return min;
+}
+
+std::vector<char> CreateDnsTxtRequest(base::StringPiece qname) {
+  // Builds the wire form of a recursion-desired TXT/IN query for |qname|, as
+  // the DNS client under test is expected to send it.
+  std::string encoded_qname;
+  EXPECT_TRUE(net::DNSDomainFromDot(qname, &encoded_qname));
+
+  net::dns_protocol::Header header = {};
+  header.flags = base::HostToNet16(net::dns_protocol::kFlagRD);
+  header.qdcount = base::HostToNet16(1);
+
+  // The question is the encoded name followed by 16-bit QTYPE and QCLASS.
+  const size_t kQTypeAndQClassSize = 2 * sizeof(uint16_t);
+  std::vector<char> request(sizeof(header) + encoded_qname.size() +
+                            kQTypeAndQClassSize);
+  base::BigEndianWriter writer(request.data(), request.size());
+  EXPECT_TRUE(writer.WriteBytes(&header, sizeof(header)));
+  EXPECT_TRUE(writer.WriteBytes(encoded_qname.data(), encoded_qname.size()));
+  EXPECT_TRUE(writer.WriteU16(net::dns_protocol::kTypeTXT));
+  EXPECT_TRUE(writer.WriteU16(net::dns_protocol::kClassIN));
+  // Every byte of the buffer must have been written.
+  EXPECT_EQ(0, writer.remaining());
+  return request;
+}
+
+std::vector<char> CreateDnsTxtResponse(const std::vector<char>& request,
+                                       base::StringPiece answer) {
+  // Fixed part of the answer record: 2-byte name pointer, TYPE, CLASS, 4-byte
+  // TTL and RDLENGTH (12 bytes in total), followed by |answer| as the RDATA.
+  const size_t answer_record_size = 12 + answer.size();
+  constexpr uint32_t kTtlSeconds = 86400;
+
+  std::vector<char> response(request);  // Starts as a copy of the query.
+  response.resize(request.size() + answer_record_size);
+  // Turn the query header into a response header announcing one answer.
+  auto* header = reinterpret_cast<net::dns_protocol::Header*>(response.data());
+  header->flags |= base::HostToNet16(net::dns_protocol::kFlagResponse);
+  header->ancount = base::HostToNet16(1);
+
+  base::BigEndianWriter writer(response.data() + request.size(),
+                               answer_record_size);
+  // The owner name is a compression pointer (0xc0 prefix) to the question's
+  // name, which starts immediately after the header.
+  EXPECT_TRUE(writer.WriteU8(0xc0));
+  EXPECT_TRUE(writer.WriteU8(sizeof(*header)));
+  EXPECT_TRUE(writer.WriteU16(net::dns_protocol::kTypeTXT));
+  EXPECT_TRUE(writer.WriteU16(net::dns_protocol::kClassIN));
+  EXPECT_TRUE(writer.WriteU32(kTtlSeconds));
+  EXPECT_TRUE(writer.WriteU16(answer.size()));
+  EXPECT_TRUE(writer.WriteBytes(answer.data(), answer.size()));
+  EXPECT_EQ(0, writer.remaining());
+  return response;
+}
+
+std::vector<char> CreateDnsErrorResponse(const std::vector<char>& request,
+                                         uint8_t rcode) {
+  // Error responses carry no answer records, so ANCOUNT is left as copied from
+  // |request| (0 for CreateDnsTxtRequest) rather than claiming an answer.
+  std::vector<char> response(request);
+  net::dns_protocol::Header* header =
+      reinterpret_cast<net::dns_protocol::Header*>(response.data());
+  // Mark the message as a response and put |rcode| in the low flag bits.
+  header->flags |= base::HostToNet16(net::dns_protocol::kFlagResponse | rcode);
+  return response;
+}
+
+std::vector<std::string> GetSampleAuditProof(size_t length) {
+  // Node |i| holds the byte (i + j + 1) % 128 at every byte position j, so
+  // the bytes within a node are all distinct and each node differs from the
+  // previous one, letting tests check that a proof is reassembled in order.
+  std::vector<std::string> audit_proof;
+  audit_proof.reserve(length);
+  for (size_t i = 0; i < length; ++i) {
+    std::string node;
+    node.reserve(crypto::kSHA256Length);
+    for (size_t j = 0; j < crypto::kSHA256Length; ++j)
+      node.push_back(static_cast<char>((i + j + 1) % 128));
+    audit_proof.push_back(std::move(node));
+  }
+  return audit_proof;
+}
+
+// Captures the values passed to a LeafIndexCallback; WaitUntilRun() spins a
+// RunLoop until the callback has fired.
+class MockLeafIndexCallback {
+ public:
+  bool called() const { return called_; }
+  int net_error() const { return net_error_; }
+  uint64_t leaf_index() const { return leaf_index_; }
+
+  // Must be invoked at most once; quits the RunLoop used by WaitUntilRun().
+  void Run(int net_error, uint64_t leaf_index) {
+    EXPECT_TRUE(!called_);
+    called_ = true;
+    net_error_ = net_error;
+    leaf_index_ = leaf_index;
+    run_loop_.Quit();
+  }
+
+  LogDnsClient::LeafIndexCallback AsCallback() {
+    return base::Bind(&MockLeafIndexCallback::Run, base::Unretained(this));
+  }
+  void WaitUntilRun() { run_loop_.Run(); }
+
+ private:
+  bool called_ = false;
+  int net_error_ = 0;
+  uint64_t leaf_index_ = 0;
+  base::RunLoop run_loop_;
+};
+
+// Captures the result passed to an AuditProofCallback, taking ownership of
+// the proof; WaitUntilRun() spins a RunLoop until the callback has fired.
+class MockAuditProofCallback {
+ public:
+  bool called() const { return called_; }
+  int net_error() const { return net_error_; }
+  const net::ct::MerkleAuditProof* proof() const { return proof_.get(); }
+
+  // Must be invoked at most once; quits the RunLoop used by WaitUntilRun().
+  void Run(int net_error, std::unique_ptr<net::ct::MerkleAuditProof> proof) {
+    EXPECT_TRUE(!called_);
+    called_ = true;
+    net_error_ = net_error;
+    proof_ = std::move(proof);
+    run_loop_.Quit();
+  }
+
+  LogDnsClient::AuditProofCallback AsCallback() {
+    return base::Bind(&MockAuditProofCallback::Run, base::Unretained(this));
+  }
+  void WaitUntilRun() { run_loop_.Run(); }
+
+ private:
+  bool called_ = false;
+  int net_error_ = 0;
+  std::unique_ptr<net::ct::MerkleAuditProof> proof_;
+  base::RunLoop run_loop_;
+};
+
+// Owns everything that has to outlive a mock socket. Mock{Read,Write},
+// SequencedSocketData and MockClientSocketFactory neither copy nor take
+// ownership of their arguments, so the payloads, the MockRead/MockWrite
+// entries and the SequencedSocketData all live here instead, with lifetimes
+// tied to this object.
+class MockSocketData {
+ public:
+  // Expects |write| to be written, then returns |read| to the reader.
+  MockSocketData(const std::vector<char>& write, const std::vector<char>& read)
+      : expected_write_payload_(write),
+        expected_read_payload_(read),
+        expected_write_(net::SYNCHRONOUS,
+                        expected_write_payload_.data(),
+                        expected_write_payload_.size(),
+                        0),
+        expected_reads_{net::MockRead(net::ASYNC,
+                                      expected_read_payload_.data(),
+                                      expected_read_payload_.size(),
+                                      1),
+                        eof_},
+        socket_data_(expected_reads_, 2, &expected_write_, 1) {}
+
+  // Expects |write| to be written, then fails the read with |net_error|.
+  MockSocketData(const std::vector<char>& write, int net_error)
+      : expected_write_payload_(write),
+        expected_write_(net::SYNCHRONOUS,
+                        expected_write_payload_.data(),
+                        expected_write_payload_.size(),
+                        0),
+        expected_reads_{net::MockRead(net::ASYNC, net_error, 1), eof_},
+        socket_data_(expected_reads_, 2, &expected_write_, 1) {}
+
+  // Expects |write| to be written, then never delivers a response.
+  explicit MockSocketData(const std::vector<char>& write)
+      : expected_write_payload_(write),
+        expected_write_(net::SYNCHRONOUS,
+                        expected_write_payload_.data(),
+                        expected_write_payload_.size(),
+                        0),
+        expected_reads_{net::MockRead(net::SYNCHRONOUS, net::ERR_IO_PENDING, 1),
+                        eof_},
+        socket_data_(expected_reads_, 2, &expected_write_, 1) {}
+  // Override the I/O mode of the write, or of the first read (not |eof_|).
+  void SetWriteMode(net::IoMode mode) { expected_write_.mode = mode; }
+  void SetReadMode(net::IoMode mode) { expected_reads_[0].mode = mode; }
+  // Registers |socket_data_| with |socket_factory|, which does not own it.
+  void AddToFactory(net::MockClientSocketFactory* socket_factory) {
+    socket_factory->AddSocketDataProvider(&socket_data_);
+  }
+
+ private:
+  // Prevents read overruns and makes a socket timeout the default behaviour.
+  static const net::MockRead eof_;
+
+  const std::vector<char> expected_write_payload_;
+  const std::vector<char> expected_read_payload_;
+  // The expected write. Points into |expected_write_payload_|, declared above.
+  net::MockWrite expected_write_;
+  // The expected read (data, if any, points into |expected_read_payload_|)
+  // followed by |eof_| to catch further, unexpected reads. Member order
+  // matters: each of these references members declared above it.
+  net::MockRead expected_reads_[2];
+  net::SequencedSocketData socket_data_;
+
+  DISALLOW_COPY_AND_ASSIGN(MockSocketData);
+};
+
+// Sequence number 2: after the expected write (0) and read (1).
+const net::MockRead MockSocketData::eof_(net::SYNCHRONOUS, net::ERR_IO_PENDING,
+                                         2);
+
+class LogDnsClientTest : public ::testing::TestWithParam<net::IoMode> {
+ protected:
+  LogDnsClientTest() {
+    // Use an invalid nameserver address. This prevents the tests accidentally
+    // sending real DNS queries. The mock sockets don't care that the address
+    // is invalid.
+    dns_config_.nameservers.push_back(net::IPEndPoint());
+    // Don't attempt retransmissions - just fail.
+    dns_config_.attempts = 1;
+    // This ensures timeouts are long enough for memory tests.
+    dns_config_.timeout = TestTimeouts::action_timeout();
+    // Simplify testing - don't require random numbers for the source port.
+    // This means our FakeRandInt function should only be called to get query
+    // IDs.
+    dns_config_.randomize_ports = false;
+  }
+
+  // Expects a TXT query for |qname| and answers it with DNS error |rcode|.
+  void ExpectRequestAndErrorResponse(base::StringPiece qname, uint8_t rcode) {
+    std::vector<char> request = CreateDnsTxtRequest(qname);
+    AddMockSocket(request, CreateDnsErrorResponse(request, rcode));
+  }
+
+  // Expects a TXT query for |qname|; the socket read fails with |net_error|.
+  void ExpectRequestAndSocketError(base::StringPiece qname, int net_error) {
+    AddMockSocket(CreateDnsTxtRequest(qname), net_error);
+  }
+
+  // Expects a TXT query for |qname| that never receives a response.
+  void ExpectRequestAndTimeout(base::StringPiece qname) {
+    AddMockSocket(CreateDnsTxtRequest(qname));
+    // Shorten the DNS timeout so that timeout tests finish quickly.
+    dns_config_.timeout = TestTimeouts::tiny_timeout();
+  }
+
+  // Expects a TXT query for |qname| and answers it with |leaf_index|.
+  void ExpectLeafIndexRequestAndResponse(base::StringPiece qname,
+                                         base::StringPiece leaf_index) {
+    // Prepend size to leaf_index to create the query answer (rdata)
+    ASSERT_LE(leaf_index.size(), 0xFFul);  // size must fit into a single byte
+    std::string answer = leaf_index.as_string();
+    answer.insert(answer.begin(), static_cast<char>(leaf_index.size()));
+
+    ExpectRequestAndResponse(qname, answer);
+  }
+
+  // Expects a TXT query for |qname| and answers it with the given nodes.
+  void ExpectAuditProofRequestAndResponse(
+      base::StringPiece qname,
+      std::vector<std::string>::const_iterator audit_path_start,
+      std::vector<std::string>::const_iterator audit_path_end) {
+    // Join nodes in the audit path into a single string.
+    std::string proof =
+        std::accumulate(audit_path_start, audit_path_end, std::string());
+
+    // Prepend size to proof to create the query answer (rdata)
+    ASSERT_LE(proof.size(), 0xFFul);  // size must fit into a single byte
+    proof.insert(proof.begin(), static_cast<char>(proof.size()));
+
+    ExpectRequestAndResponse(qname, proof);
+  }
+
+  // Queries a fresh LogDnsClient and blocks until |callback| has run.
+  void QueryLeafIndex(base::StringPiece log_domain,
+                      base::StringPiece leaf_hash,
+                      MockLeafIndexCallback* callback) {
+    std::unique_ptr<net::DnsClient> dns_client = CreateDnsClient();
+    LogDnsClient log_client(std::move(dns_client), net::BoundNetLog());
+
+    log_client.QueryLeafIndex(log_domain, leaf_hash, callback->AsCallback());
+    callback->WaitUntilRun();
+  }
+
+  // Queries a fresh LogDnsClient and blocks until |callback| has run.
+  void QueryAuditProof(base::StringPiece log_domain,
+                       uint64_t leaf_index,
+                       uint64_t tree_size,
+                       MockAuditProofCallback* callback) {
+    std::unique_ptr<net::DnsClient> dns_client = CreateDnsClient();
+    LogDnsClient log_client(std::move(dns_client), net::BoundNetLog());
+
+    log_client.QueryAuditProof(log_domain, leaf_index, tree_size,
+                               callback->AsCallback());
+    callback->WaitUntilRun();
+  }
+
+ private:
+  std::unique_ptr<net::DnsClient> CreateDnsClient() {
+    std::unique_ptr<net::DnsClient> client =
+        net::DnsClient::CreateClientForTesting(nullptr, &socket_factory_,
+                                               base::Bind(&FakeRandInt));
+    client->SetConfig(dns_config_);
+    return client;
+  }
+
+  void ExpectRequestAndResponse(base::StringPiece qname,
+                                base::StringPiece answer) {
+    std::vector<char> request = CreateDnsTxtRequest(qname);
+    AddMockSocket(request, CreateDnsTxtResponse(request, answer));
+  }
+
+  // Creates a MockSocketData from |args|, applies the test's IoMode to its
+  // read, registers it with |socket_factory_| and keeps it alive.
+  template <typename... Args>
+  void AddMockSocket(Args&&... args) {
+    mock_socket_data_.emplace_back(
+        new MockSocketData(std::forward<Args>(args)...));
+    mock_socket_data_.back()->SetReadMode(GetParam());
+    mock_socket_data_.back()->AddToFactory(&socket_factory_);
+  }
+
+  net::DnsConfig dns_config_;
+  base::MessageLoopForIO message_loop_;
+  std::vector<std::unique_ptr<MockSocketData>> mock_socket_data_;
+  net::MockClientSocketFactory socket_factory_;
+};
+
+TEST_P(LogDnsClientTest, QueryLeafIndex) {
+  MockLeafIndexCallback callback;
+  // The query name is the base32-encoded kLeafHash under "hash.ct.test.".
+  ExpectLeafIndexRequestAndResponse(
+      "D4S6DSV2J743QJZEQMH4UYHEYK7KRQ5JIQOCPMFUHZVJNFGHXACA.hash.ct.test.",
+      "123456");
+  QueryLeafIndex("ct.test", kLeafHash, &callback);
+  ASSERT_TRUE(callback.called());
+  EXPECT_THAT(callback.net_error(), IsOk());
+  EXPECT_THAT(callback.leaf_index(), 123456);
+}
+
+TEST_P(LogDnsClientTest, QueryLeafIndexReportsThatLogDomainDoesNotExist) {
+  MockLeafIndexCallback callback;
+  // An NXDOMAIN answer surfaces as ERR_NAME_NOT_RESOLVED, with index 0.
+  ExpectRequestAndErrorResponse(
+      "D4S6DSV2J743QJZEQMH4UYHEYK7KRQ5JIQOCPMFUHZVJNFGHXACA.hash.ct.test.",
+      net::dns_protocol::kRcodeNXDOMAIN);
+  QueryLeafIndex("ct.test", kLeafHash, &callback);
+  ASSERT_TRUE(callback.called());
+  EXPECT_THAT(callback.net_error(), IsError(net::ERR_NAME_NOT_RESOLVED));
+  EXPECT_THAT(callback.leaf_index(), 0);
+}
+
+TEST_P(LogDnsClientTest, QueryLeafIndexReportsServerFailure) {
+  MockLeafIndexCallback callback;
+  // A SERVFAIL answer surfaces as ERR_DNS_SERVER_FAILED, with index 0.
+  ExpectRequestAndErrorResponse(
+      "D4S6DSV2J743QJZEQMH4UYHEYK7KRQ5JIQOCPMFUHZVJNFGHXACA.hash.ct.test.",
+      net::dns_protocol::kRcodeSERVFAIL);
+  QueryLeafIndex("ct.test", kLeafHash, &callback);
+  ASSERT_TRUE(callback.called());
+  EXPECT_THAT(callback.net_error(), IsError(net::ERR_DNS_SERVER_FAILED));
+  EXPECT_THAT(callback.leaf_index(), 0);
+}
+
+TEST_P(LogDnsClientTest, QueryLeafIndexReportsServerRefusal) {
+  MockLeafIndexCallback callback;
+  // A REFUSED answer also surfaces as ERR_DNS_SERVER_FAILED, with index 0.
+  ExpectRequestAndErrorResponse(
+      "D4S6DSV2J743QJZEQMH4UYHEYK7KRQ5JIQOCPMFUHZVJNFGHXACA.hash.ct.test.",
+      net::dns_protocol::kRcodeREFUSED);
+  QueryLeafIndex("ct.test", kLeafHash, &callback);
+  ASSERT_TRUE(callback.called());
+  EXPECT_THAT(callback.net_error(), IsError(net::ERR_DNS_SERVER_FAILED));
+  EXPECT_THAT(callback.leaf_index(), 0);
+}
+
+ // A purely alphabetic answer ("foo") is rejected as a malformed response.
+ TEST_P(LogDnsClientTest,
+ QueryLeafIndexReportsMalformedResponseIfLeafIndexIsNotNumeric) {
+ const char kQname[] =
+ "D4S6DSV2J743QJZEQMH4UYHEYK7KRQ5JIQOCPMFUHZVJNFGHXACA.hash.ct.test.";
+ ExpectLeafIndexRequestAndResponse(kQname, "foo");
+
+ MockLeafIndexCallback leaf_index_callback;
+ QueryLeafIndex("ct.test", kLeafHash, &leaf_index_callback);
+
+ ASSERT_TRUE(leaf_index_callback.called());
+ EXPECT_THAT(leaf_index_callback.net_error(),
+ IsError(net::ERR_DNS_MALFORMED_RESPONSE));
+ EXPECT_THAT(leaf_index_callback.leaf_index(), 0);
+ }
+
+ // A decimal-point answer ("123456.0") is not a valid integer leaf index.
+ TEST_P(LogDnsClientTest,
+ QueryLeafIndexReportsMalformedResponseIfLeafIndexIsFloatingPoint) {
+ const char kQname[] =
+ "D4S6DSV2J743QJZEQMH4UYHEYK7KRQ5JIQOCPMFUHZVJNFGHXACA.hash.ct.test.";
+ ExpectLeafIndexRequestAndResponse(kQname, "123456.0");
+
+ MockLeafIndexCallback leaf_index_callback;
+ QueryLeafIndex("ct.test", kLeafHash, &leaf_index_callback);
+
+ ASSERT_TRUE(leaf_index_callback.called());
+ EXPECT_THAT(leaf_index_callback.net_error(),
+ IsError(net::ERR_DNS_MALFORMED_RESPONSE));
+ EXPECT_THAT(leaf_index_callback.leaf_index(), 0);
+ }
+
+ // An empty answer is rejected as a malformed response.
+ TEST_P(LogDnsClientTest,
+ QueryLeafIndexReportsMalformedResponseIfLeafIndexIsEmpty) {
+ const char kQname[] =
+ "D4S6DSV2J743QJZEQMH4UYHEYK7KRQ5JIQOCPMFUHZVJNFGHXACA.hash.ct.test.";
+ ExpectLeafIndexRequestAndResponse(kQname, "");
+
+ MockLeafIndexCallback leaf_index_callback;
+ QueryLeafIndex("ct.test", kLeafHash, &leaf_index_callback);
+
+ ASSERT_TRUE(leaf_index_callback.called());
+ EXPECT_THAT(leaf_index_callback.net_error(),
+ IsError(net::ERR_DNS_MALFORMED_RESPONSE));
+ EXPECT_THAT(leaf_index_callback.leaf_index(), 0);
+ }
+
+ // Leading junk before the digits ("foo123456") must not be tolerated.
+ TEST_P(LogDnsClientTest,
+ QueryLeafIndexReportsMalformedResponseIfLeafIndexHasNonNumericPrefix) {
+ const char kQname[] =
+ "D4S6DSV2J743QJZEQMH4UYHEYK7KRQ5JIQOCPMFUHZVJNFGHXACA.hash.ct.test.";
+ ExpectLeafIndexRequestAndResponse(kQname, "foo123456");
+
+ MockLeafIndexCallback leaf_index_callback;
+ QueryLeafIndex("ct.test", kLeafHash, &leaf_index_callback);
+
+ ASSERT_TRUE(leaf_index_callback.called());
+ EXPECT_THAT(leaf_index_callback.net_error(),
+ IsError(net::ERR_DNS_MALFORMED_RESPONSE));
+ EXPECT_THAT(leaf_index_callback.leaf_index(), 0);
+ }
+
+ // Trailing junk after the digits ("123456foo") must not be tolerated.
+ TEST_P(LogDnsClientTest,
+ QueryLeafIndexReportsMalformedResponseIfLeafIndexHasNonNumericSuffix) {
+ const char kQname[] =
+ "D4S6DSV2J743QJZEQMH4UYHEYK7KRQ5JIQOCPMFUHZVJNFGHXACA.hash.ct.test.";
+ ExpectLeafIndexRequestAndResponse(kQname, "123456foo");
+
+ MockLeafIndexCallback leaf_index_callback;
+ QueryLeafIndex("ct.test", kLeafHash, &leaf_index_callback);
+
+ ASSERT_TRUE(leaf_index_callback.called());
+ EXPECT_THAT(leaf_index_callback.net_error(),
+ IsError(net::ERR_DNS_MALFORMED_RESPONSE));
+ EXPECT_THAT(leaf_index_callback.leaf_index(), 0);
+ }
+
+ // No DNS exchange is set up: an empty log domain is rejected up front.
+ TEST_P(LogDnsClientTest, QueryLeafIndexReportsInvalidArgIfLogDomainIsEmpty) {
+ MockLeafIndexCallback leaf_index_callback;
+ QueryLeafIndex("", kLeafHash, &leaf_index_callback);
+
+ ASSERT_TRUE(leaf_index_callback.called());
+ EXPECT_THAT(leaf_index_callback.net_error(),
+ IsError(net::ERR_INVALID_ARGUMENT));
+ EXPECT_THAT(leaf_index_callback.leaf_index(), 0);
+ }
+
+ // No DNS exchange is set up: a null log domain is rejected up front.
+ TEST_P(LogDnsClientTest, QueryLeafIndexReportsInvalidArgIfLogDomainIsNull) {
+ MockLeafIndexCallback leaf_index_callback;
+ QueryLeafIndex(nullptr, kLeafHash, &leaf_index_callback);
+
+ ASSERT_TRUE(leaf_index_callback.called());
+ EXPECT_THAT(leaf_index_callback.net_error(),
+ IsError(net::ERR_INVALID_ARGUMENT));
+ EXPECT_THAT(leaf_index_callback.leaf_index(), 0);
+ }
+
+ // A leaf hash that is not a valid hash ("foo") is rejected without a query.
+ TEST_P(LogDnsClientTest, QueryLeafIndexReportsInvalidArgIfLeafHashIsInvalid) {
+ MockLeafIndexCallback leaf_index_callback;
+ QueryLeafIndex("ct.test", "foo", &leaf_index_callback);
+
+ ASSERT_TRUE(leaf_index_callback.called());
+ EXPECT_THAT(leaf_index_callback.net_error(),
+ IsError(net::ERR_INVALID_ARGUMENT));
+ EXPECT_THAT(leaf_index_callback.leaf_index(), 0);
+ }
+
+ // An empty leaf hash is rejected without a query.
+ TEST_P(LogDnsClientTest, QueryLeafIndexReportsInvalidArgIfLeafHashIsEmpty) {
+ MockLeafIndexCallback leaf_index_callback;
+ QueryLeafIndex("ct.test", "", &leaf_index_callback);
+
+ ASSERT_TRUE(leaf_index_callback.called());
+ EXPECT_THAT(leaf_index_callback.net_error(),
+ IsError(net::ERR_INVALID_ARGUMENT));
+ EXPECT_THAT(leaf_index_callback.leaf_index(), 0);
+ }
+
+ // A null leaf hash is rejected without a query.
+ TEST_P(LogDnsClientTest, QueryLeafIndexReportsInvalidArgIfLeafHashIsNull) {
+ MockLeafIndexCallback leaf_index_callback;
+ QueryLeafIndex("ct.test", nullptr, &leaf_index_callback);
+
+ ASSERT_TRUE(leaf_index_callback.called());
+ EXPECT_THAT(leaf_index_callback.net_error(),
+ IsError(net::ERR_INVALID_ARGUMENT));
+ EXPECT_THAT(leaf_index_callback.leaf_index(), 0);
+ }
+
+ // A socket-level failure is passed through to the callback unchanged.
+ TEST_P(LogDnsClientTest, QueryLeafIndexReportsSocketError) {
+ const char kQname[] =
+ "D4S6DSV2J743QJZEQMH4UYHEYK7KRQ5JIQOCPMFUHZVJNFGHXACA.hash.ct.test.";
+ ExpectRequestAndSocketError(kQname, net::ERR_CONNECTION_REFUSED);
+
+ MockLeafIndexCallback leaf_index_callback;
+ QueryLeafIndex("ct.test", kLeafHash, &leaf_index_callback);
+
+ ASSERT_TRUE(leaf_index_callback.called());
+ EXPECT_THAT(leaf_index_callback.net_error(),
+ IsError(net::ERR_CONNECTION_REFUSED));
+ EXPECT_THAT(leaf_index_callback.leaf_index(), 0);
+ }
+
+ // A query that never gets an answer is reported as ERR_DNS_TIMED_OUT.
+ TEST_P(LogDnsClientTest, QueryLeafIndexReportsTimeout) {
+ const char kQname[] =
+ "D4S6DSV2J743QJZEQMH4UYHEYK7KRQ5JIQOCPMFUHZVJNFGHXACA.hash.ct.test.";
+ ExpectRequestAndTimeout(kQname);
+
+ MockLeafIndexCallback leaf_index_callback;
+ QueryLeafIndex("ct.test", kLeafHash, &leaf_index_callback);
+
+ ASSERT_TRUE(leaf_index_callback.called());
+ EXPECT_THAT(leaf_index_callback.net_error(),
+ IsError(net::ERR_DNS_TIMED_OUT));
+ EXPECT_THAT(leaf_index_callback.leaf_index(), 0);
+ }
+
+ // A 20-node audit proof is assembled from several consecutive DNS queries,
+ // each starting at the index of the first node not yet received.
+ TEST_P(LogDnsClientTest, QueryAuditProof) {
+ const std::vector<std::string> audit_proof = GetSampleAuditProof(20);
+
+ // It should require 3 queries to collect the entire audit proof, as there is
+ // only space for 7 nodes per UDP packet.
+ ExpectAuditProofRequestAndResponse("0.123456.999999.tree.ct.test.",
+ audit_proof.begin(),
+ audit_proof.begin() + 7);
+ ExpectAuditProofRequestAndResponse("7.123456.999999.tree.ct.test.",
+ audit_proof.begin() + 7,
+ audit_proof.begin() + 14);
+ ExpectAuditProofRequestAndResponse("14.123456.999999.tree.ct.test.",
+ audit_proof.begin() + 14,
+ audit_proof.end());
+
+ MockAuditProofCallback callback;
+ QueryAuditProof("ct.test", 123456, 999999, &callback);
+ ASSERT_TRUE(callback.called());
+ EXPECT_THAT(callback.net_error(), IsOk());
+ ASSERT_THAT(callback.proof(), NotNull());
+ EXPECT_THAT(callback.proof()->leaf_index, 123456);
+ // TODO: tree_size is currently NOT verified. The assertion
+ // EXPECT_THAT(callback.proof()->tree_size, 999999);
+ // was left commented out without explanation - presumably because the
+ // proof type has no tree_size field at this revision. Re-enable it once
+ // that field exists.
+ EXPECT_THAT(callback.proof()->nodes, audit_proof);
+ }
+
+ // Responses carrying fewer nodes than a packet could hold must still be
+ // stitched together: each follow-up query starts at the first missing node.
+ TEST_P(LogDnsClientTest, QueryAuditProofHandlesResponsesWithShortAuditPaths) {
+ const std::vector<std::string> audit_proof = GetSampleAuditProof(20);
+
+ // Make some of the responses contain fewer proof nodes than they can hold.
+ ExpectAuditProofRequestAndResponse("0.123456.999999.tree.ct.test.",
+ audit_proof.begin(),
+ audit_proof.begin() + 1);
+ ExpectAuditProofRequestAndResponse("1.123456.999999.tree.ct.test.",
+ audit_proof.begin() + 1,
+ audit_proof.begin() + 3);
+ ExpectAuditProofRequestAndResponse("3.123456.999999.tree.ct.test.",
+ audit_proof.begin() + 3,
+ audit_proof.begin() + 6);
+ ExpectAuditProofRequestAndResponse("6.123456.999999.tree.ct.test.",
+ audit_proof.begin() + 6,
+ audit_proof.begin() + 10);
+ ExpectAuditProofRequestAndResponse("10.123456.999999.tree.ct.test.",
+ audit_proof.begin() + 10,
+ audit_proof.begin() + 13);
+ ExpectAuditProofRequestAndResponse("13.123456.999999.tree.ct.test.",
+ audit_proof.begin() + 13,
+ audit_proof.end());
+
+ MockAuditProofCallback callback;
+ QueryAuditProof("ct.test", 123456, 999999, &callback);
+ ASSERT_TRUE(callback.called());
+ EXPECT_THAT(callback.net_error(), IsOk());
+ ASSERT_THAT(callback.proof(), NotNull());
+ EXPECT_THAT(callback.proof()->leaf_index, 123456);
+ // TODO: tree_size is currently NOT verified. The assertion
+ // EXPECT_THAT(callback.proof()->tree_size, 999999);
+ // was left commented out without explanation - presumably because the
+ // proof type has no tree_size field at this revision. Re-enable it once
+ // that field exists.
+ EXPECT_THAT(callback.proof()->nodes, audit_proof);
+ }
+
+ // NXDOMAIN on the first audit-proof query yields ERR_NAME_NOT_RESOLVED and no
+ // proof.
+ TEST_P(LogDnsClientTest, QueryAuditProofReportsThatLogDomainDoesNotExist) {
+ const char kFirstQname[] = "0.123456.999999.tree.ct.test.";
+ ExpectRequestAndErrorResponse(kFirstQname,
+ net::dns_protocol::kRcodeNXDOMAIN);
+
+ MockAuditProofCallback proof_callback;
+ QueryAuditProof("ct.test", 123456, 999999, &proof_callback);
+
+ ASSERT_TRUE(proof_callback.called());
+ EXPECT_THAT(proof_callback.net_error(), IsError(net::ERR_NAME_NOT_RESOLVED));
+ EXPECT_THAT(proof_callback.proof(), IsNull());
+ }
+
+ // SERVFAIL on the first audit-proof query yields ERR_DNS_SERVER_FAILED.
+ TEST_P(LogDnsClientTest, QueryAuditProofReportsServerFailure) {
+ const char kFirstQname[] = "0.123456.999999.tree.ct.test.";
+ ExpectRequestAndErrorResponse(kFirstQname,
+ net::dns_protocol::kRcodeSERVFAIL);
+
+ MockAuditProofCallback proof_callback;
+ QueryAuditProof("ct.test", 123456, 999999, &proof_callback);
+
+ ASSERT_TRUE(proof_callback.called());
+ EXPECT_THAT(proof_callback.net_error(), IsError(net::ERR_DNS_SERVER_FAILED));
+ EXPECT_THAT(proof_callback.proof(), IsNull());
+ }
+
+ // REFUSED on the first audit-proof query is reported as ERR_DNS_SERVER_FAILED.
+ TEST_P(LogDnsClientTest, QueryAuditProofReportsServerRefusal) {
+ const char kFirstQname[] = "0.123456.999999.tree.ct.test.";
+ ExpectRequestAndErrorResponse(kFirstQname,
+ net::dns_protocol::kRcodeREFUSED);
+
+ MockAuditProofCallback proof_callback;
+ QueryAuditProof("ct.test", 123456, 999999, &proof_callback);
+
+ ASSERT_TRUE(proof_callback.called());
+ EXPECT_THAT(proof_callback.net_error(), IsError(net::ERR_DNS_SERVER_FAILED));
+ EXPECT_THAT(proof_callback.proof(), IsNull());
+ }
+
+ // A proof node that is one byte shorter than a SHA-256 hash (31 vs 32
+ // bytes) makes the whole response malformed.
+ TEST_P(LogDnsClientTest,
+ QueryAuditProofReportsResponseMalformedIfNodeTooShort) {
+ const std::string short_node(31, 'a');
+ const std::vector<std::string> proof_nodes(1, short_node);
+
+ ExpectAuditProofRequestAndResponse("0.123456.999999.tree.ct.test.",
+ proof_nodes.begin(), proof_nodes.end());
+
+ MockAuditProofCallback proof_callback;
+ QueryAuditProof("ct.test", 123456, 999999, &proof_callback);
+
+ ASSERT_TRUE(proof_callback.called());
+ EXPECT_THAT(proof_callback.net_error(),
+ IsError(net::ERR_DNS_MALFORMED_RESPONSE));
+ EXPECT_THAT(proof_callback.proof(), IsNull());
+ }
+
+ // A proof node that is one byte longer than a SHA-256 hash (33 vs 32 bytes)
+ // makes the whole response malformed.
+ TEST_P(LogDnsClientTest, QueryAuditProofReportsResponseMalformedIfNodeTooLong) {
+ const std::string long_node(33, 'a');
+ const std::vector<std::string> proof_nodes(1, long_node);
+
+ ExpectAuditProofRequestAndResponse("0.123456.999999.tree.ct.test.",
+ proof_nodes.begin(), proof_nodes.end());
+
+ MockAuditProofCallback proof_callback;
+ QueryAuditProof("ct.test", 123456, 999999, &proof_callback);
+
+ ASSERT_TRUE(proof_callback.called());
+ EXPECT_THAT(proof_callback.net_error(),
+ IsError(net::ERR_DNS_MALFORMED_RESPONSE));
+ EXPECT_THAT(proof_callback.proof(), IsNull());
+ }
+
+ // A response containing zero proof nodes is treated as malformed.
+ TEST_P(LogDnsClientTest, QueryAuditProofReportsResponseMalformedIfEmpty) {
+ const std::vector<std::string> no_nodes;
+
+ ExpectAuditProofRequestAndResponse("0.123456.999999.tree.ct.test.",
+ no_nodes.begin(), no_nodes.end());
+
+ MockAuditProofCallback proof_callback;
+ QueryAuditProof("ct.test", 123456, 999999, &proof_callback);
+
+ ASSERT_TRUE(proof_callback.called());
+ EXPECT_THAT(proof_callback.net_error(),
+ IsError(net::ERR_DNS_MALFORMED_RESPONSE));
+ EXPECT_THAT(proof_callback.proof(), IsNull());
+ }
+
+ // No DNS exchange is set up: an empty log domain is rejected up front.
+ TEST_P(LogDnsClientTest, QueryAuditProofReportsInvalidArgIfLogDomainIsEmpty) {
+ MockAuditProofCallback proof_callback;
+ QueryAuditProof("", 123456, 999999, &proof_callback);
+
+ ASSERT_TRUE(proof_callback.called());
+ EXPECT_THAT(proof_callback.net_error(), IsError(net::ERR_INVALID_ARGUMENT));
+ EXPECT_THAT(proof_callback.proof(), IsNull());
+ }
+
+ // No DNS exchange is set up: a null log domain is rejected up front.
+ TEST_P(LogDnsClientTest, QueryAuditProofReportsInvalidArgIfLogDomainIsNull) {
+ MockAuditProofCallback proof_callback;
+ QueryAuditProof(nullptr, 123456, 999999, &proof_callback);
+
+ ASSERT_TRUE(proof_callback.called());
+ EXPECT_THAT(proof_callback.net_error(), IsError(net::ERR_INVALID_ARGUMENT));
+ EXPECT_THAT(proof_callback.proof(), IsNull());
+ }
+
+ // A leaf index equal to the tree size is out of range and rejected up front.
+ TEST_P(LogDnsClientTest,
+ QueryAuditProofReportsInvalidArgIfLeafIndexEqualToTreeSize) {
+ MockAuditProofCallback proof_callback;
+ QueryAuditProof("ct.test", 123456, 123456, &proof_callback);
+
+ ASSERT_TRUE(proof_callback.called());
+ EXPECT_THAT(proof_callback.net_error(), IsError(net::ERR_INVALID_ARGUMENT));
+ EXPECT_THAT(proof_callback.proof(), IsNull());
+ }
+
+ // A leaf index beyond the tree size is out of range and rejected up front.
+ TEST_P(LogDnsClientTest,
+ QueryAuditProofReportsInvalidArgIfLeafIndexGreaterThanTreeSize) {
+ MockAuditProofCallback proof_callback;
+ QueryAuditProof("ct.test", 999999, 123456, &proof_callback);
+
+ ASSERT_TRUE(proof_callback.called());
+ EXPECT_THAT(proof_callback.net_error(), IsError(net::ERR_INVALID_ARGUMENT));
+ EXPECT_THAT(proof_callback.proof(), IsNull());
+ }
+
+ // A socket-level failure on the first query is passed through unchanged.
+ TEST_P(LogDnsClientTest, QueryAuditProofReportsSocketError) {
+ const char kFirstQname[] = "0.123456.999999.tree.ct.test.";
+ ExpectRequestAndSocketError(kFirstQname, net::ERR_CONNECTION_REFUSED);
+
+ MockAuditProofCallback proof_callback;
+ QueryAuditProof("ct.test", 123456, 999999, &proof_callback);
+
+ ASSERT_TRUE(proof_callback.called());
+ EXPECT_THAT(proof_callback.net_error(),
+ IsError(net::ERR_CONNECTION_REFUSED));
+ EXPECT_THAT(proof_callback.proof(), IsNull());
+ }
+
+ // An unanswered first query is reported as ERR_DNS_TIMED_OUT with no proof.
+ TEST_P(LogDnsClientTest, QueryAuditProofReportsTimeout) {
+ const char kFirstQname[] = "0.123456.999999.tree.ct.test.";
+ ExpectRequestAndTimeout(kFirstQname);
+
+ MockAuditProofCallback proof_callback;
+ QueryAuditProof("ct.test", 123456, 999999, &proof_callback);
+
+ ASSERT_TRUE(proof_callback.called());
+ EXPECT_THAT(proof_callback.net_error(), IsError(net::ERR_DNS_TIMED_OUT));
+ EXPECT_THAT(proof_callback.proof(), IsNull());
+ }
+
+ // Runs every LogDnsClientTest case twice, with the fixture's read mode
+ // (GetParam()) set first to ASYNC and then to SYNCHRONOUS.
+ INSTANTIATE_TEST_CASE_P(
+ ReadMode,
+ LogDnsClientTest,
+ ::testing::Values(net::IoMode::ASYNC, net::IoMode::SYNCHRONOUS));
+
+} // namespace
+} // namespace certificate_transparency