net: implement the beginnings of HSTS pinning

(Based on a patch by Chris Evans.)

This change doesn't yet actually get the pinning information from the HSTS
header, but all the infrastructure is in place.

BUG=none
TEST=none

Review URL: http://codereview.chromium.org/6835033

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@81584 0039d316-1c4b-4281-b951-d872f2087c98
diff --git a/net/base/transport_security_state.cc b/net/base/transport_security_state.cc
index 44dce0e..8f578c2f 100644
--- a/net/base/transport_security_state.cc
+++ b/net/base/transport_security_state.cc
@@ -9,6 +9,7 @@
 #include "base/json/json_writer.h"
 #include "base/logging.h"
 #include "base/memory/scoped_ptr.h"
+#include "base/sha1.h"
 #include "base/sha2.h"
 #include "base/string_number_conversions.h"
 #include "base/string_tokenizer.h"
@@ -32,6 +33,9 @@
   if (canonicalized_host.empty())
     return;
 
+  // TODO(cevans) -- we likely want to permit a host to override a built-in,
+  // for at least the case where the override is stricter (i.e. includes
+  // subdomains, or includes certificate pinning).
   bool temp;
   if (IsPreloadedSTS(canonicalized_host, true, &temp))
     return;
@@ -313,6 +317,18 @@
         continue;
     }
 
+    ListValue* pins = new ListValue;
+    for (std::vector<SHA1Fingerprint>::const_iterator
+         j = i->second.public_key_hashes.begin();
+         j != i->second.public_key_hashes.end(); ++j) {
+      std::string hash_str(reinterpret_cast<const char*>(j->data),
+                           sizeof(j->data));
+      std::string b64;
+      base::Base64Encode(hash_str, &b64);
+      pins->Append(new StringValue("sha1/" + b64));
+    }
+    state->Set("public_key_hashes", pins);
+
     toplevel.Set(HashedDomainToExternalString(i->first), state);
   }
 
@@ -350,6 +366,26 @@
       continue;
     }
 
+    ListValue* pins_list = NULL;
+    std::vector<SHA1Fingerprint> public_key_hashes;
+    if (state->GetList("public_key_hashes", &pins_list)) {
+      size_t num_pins = pins_list->GetSize();
+      for (size_t i = 0; i < num_pins; ++i) {
+        std::string type_and_base64;
+        std::string hash_str;
+        SHA1Fingerprint hash;
+        if (pins_list->GetString(i, &type_and_base64) &&
+            type_and_base64.find("sha1/") == 0 &&
+            base::Base64Decode(
+                type_and_base64.substr(5, type_and_base64.size() - 5),
+                &hash_str) &&
+            hash_str.size() == base::SHA1_LENGTH) {
+          memcpy(hash.data, hash_str.data(), sizeof(hash.data));
+          public_key_hashes.push_back(hash);
+        }
+      }
+    }
+
     DomainState::Mode mode;
     if (mode_string == "strict") {
       mode = DomainState::MODE_STRICT;
@@ -381,14 +417,17 @@
     }
 
     std::string hashed = ExternalStringToHashedDomain(*i);
-    if (hashed.empty())
+    if (hashed.empty()) {
+      dirtied = true;
       continue;
+    }
 
     DomainState new_state;
     new_state.mode = mode;
     new_state.created = created_time;
     new_state.expiry = expiry_time;
     new_state.include_subdomains = include_subdomains;
+    new_state.public_key_hashes = public_key_hashes;
     enabled_hosts_[hashed] = new_state;
   }
 
@@ -524,4 +563,41 @@
   return false;
 }
 
+// Renders |hashes| as a comma-separated list of base64-encoded strings.
+static std::string HashesToBase64String(
+    const std::vector<net::SHA1Fingerprint>& hashes) {
+  typedef std::vector<net::SHA1Fingerprint>::const_iterator HashIter;
+  std::vector<std::string> encoded;
+  for (HashIter it = hashes.begin(); it != hashes.end(); ++it) {
+    const std::string raw(reinterpret_cast<const char*>(it->data),
+                          sizeof(it->data));
+    std::string b64;
+    base::Base64Encode(raw, &b64);
+    encoded.push_back(b64);
+  }
+  return JoinString(encoded, ',');
+}
+
+// Returns true if nothing is pinned or if any of |hashes| is a pinned key;
+// otherwise logs both hash lists and returns false.
+bool TransportSecurityState::DomainState::IsChainOfPublicKeysPermitted(
+    const std::vector<net::SHA1Fingerprint>& hashes) {
+  typedef std::vector<net::SHA1Fingerprint>::const_iterator HashIter;
+  if (public_key_hashes.empty())
+    return true;
+  // A single overlap between the presented and the pinned hashes suffices.
+  for (HashIter seen = hashes.begin(); seen != hashes.end(); ++seen) {
+    for (HashIter pin = public_key_hashes.begin();
+         pin != public_key_hashes.end(); ++pin) {
+      if (seen->Equals(*pin))
+        return true;
+    }
+  }
+
+  LOG(ERROR) << "Rejecting public key chain for domain " << domain
+             << ". Validated chain: " << HashesToBase64String(hashes)
+             << ", expected: " << HashesToBase64String(public_key_hashes);
+  return false;
+}
+
 }  // namespace
diff --git a/net/base/transport_security_state.h b/net/base/transport_security_state.h
index e7705f5..33bae61 100644
--- a/net/base/transport_security_state.h
+++ b/net/base/transport_security_state.h
@@ -8,11 +8,13 @@
 
 #include <map>
 #include <string>
+#include <vector>
 
 #include "base/basictypes.h"
 #include "base/gtest_prod_util.h"
 #include "base/memory/ref_counted.h"
 #include "base/time.h"
+#include "net/base/x509_cert_types.h"
 
 namespace net {
 
@@ -50,10 +52,18 @@
           include_subdomains(false),
           preloaded(false) { }
 
+    // IsChainOfPublicKeysPermitted takes a set of public key hashes and
+    // returns true if either of the following holds:
+    //   1) |public_key_hashes| is empty, i.e. no public keys have been pinned.
+    //   2) |hashes| and |public_key_hashes| share at least one hash.
+    bool IsChainOfPublicKeysPermitted(
+        const std::vector<SHA1Fingerprint>& hashes);
+
     Mode mode;
     base::Time created;  // when this host entry was first created
     base::Time expiry;  // the absolute time (UTC) when this record expires
     bool include_subdomains;  // subdomains included?
+    std::vector<SHA1Fingerprint> public_key_hashes;  // optional; permitted keys
 
     // The follow members are not valid when stored in |enabled_hosts_|.
     bool preloaded;  // is this a preloaded entry?
diff --git a/net/base/transport_security_state_unittest.cc b/net/base/transport_security_state_unittest.cc
index 9823072..d2db7f9e 100644
--- a/net/base/transport_security_state_unittest.cc
+++ b/net/base/transport_security_state_unittest.cc
@@ -501,4 +501,37 @@
   EXPECT_FALSE(state->IsEnabledForHost(&domain_state, kLongName, true));
 }
 
+// Checks pin matching on a DomainState and that pins survive serialisation.
+TEST_F(TransportSecurityStateTest, PublicKeyHashes) {
+  scoped_refptr<TransportSecurityState> state(new TransportSecurityState);
+
+  TransportSecurityState::DomainState domain_state;
+  EXPECT_FALSE(state->IsEnabledForHost(&domain_state, "example.com", false));
+  std::vector<SHA1Fingerprint> hashes;
+  EXPECT_TRUE(domain_state.IsChainOfPublicKeysPermitted(hashes));
+
+  SHA1Fingerprint hash;
+  memset(hash.data, '1', sizeof(hash.data));
+  domain_state.public_key_hashes.push_back(hash);
+  EXPECT_FALSE(domain_state.IsChainOfPublicKeysPermitted(hashes));
+  hashes.push_back(hash);
+  EXPECT_TRUE(domain_state.IsChainOfPublicKeysPermitted(hashes));
+  hashes[0].data[0] = '2';
+  EXPECT_FALSE(domain_state.IsChainOfPublicKeysPermitted(hashes));
+
+  domain_state.expiry = base::Time::Now() + base::TimeDelta::FromSeconds(1000);
+  state->EnableHost("example.com", domain_state);
+  std::string ser;
+  EXPECT_TRUE(state->Serialise(&ser));
+  bool dirty = false;
+  EXPECT_TRUE(state->Deserialise(ser, &dirty));
+  // Read into a fresh DomainState so stale pins cannot satisfy the checks;
+  // ASSERT on the size because indexing [0] on an empty vector is undefined.
+  TransportSecurityState::DomainState reloaded;
+  EXPECT_TRUE(state->IsEnabledForHost(&reloaded, "example.com", false));
+  ASSERT_EQ(1u, reloaded.public_key_hashes.size());
+  EXPECT_EQ(0, memcmp(reloaded.public_key_hashes[0].data, hash.data,
+                      sizeof(hash.data)));
+}
+
 }  // namespace net
diff --git a/net/url_request/url_request_http_job.cc b/net/url_request/url_request_http_job.cc
index 3afa977..7166e5f0 100644
--- a/net/url_request/url_request_http_job.cc
+++ b/net/url_request/url_request_http_job.cc
@@ -653,6 +653,25 @@
   // Clear the IO_PENDING status
   SetStatus(URLRequestStatus());
 
+  // Take care of any mandates for public key pinning.
+  // TODO(agl): we might have an issue here where a request for foo.example.com
+  // merges into a SPDY connection to www.example.com, and gets a different
+  // certificate.
+  const SSLInfo& ssl_info = transaction_->GetResponseInfo()->ssl_info;
+  if (result == OK &&
+      ssl_info.is_valid() &&
+      context_->transport_security_state()) {
+    TransportSecurityState::DomainState domain_state;
+    if (context_->transport_security_state()->IsEnabledForHost(
+            &domain_state,
+            request_->url().host(),
+            IsSNIAvailable(context_)) &&
+        ssl_info.is_issued_by_known_root &&
+        !domain_state.IsChainOfPublicKeysPermitted(ssl_info.public_key_hashes)){
+      result = ERR_CERT_INVALID;
+    }
+  }
+
   if (result == OK) {
     SaveCookiesAndNotifyHeadersComplete();
   } else if (ShouldTreatAsCertificateError(result)) {