// blob: 273bda47fefacf050058d2b4cadd022147a58bd5 [file] [log] [blame]
// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <stdio.h>
#include <stdlib.h>
#include <set>
#include "base/file_util.h"
#include "base/logging.h"
#include "base/path_service.h"
#include "base/perftimer.h"
#include "base/string_util.h"
#include "base/test_file_util.h"
#include "chrome/browser/safe_browsing/safe_browsing_database.h"
#include "chrome/common/chrome_paths.h"
#include "chrome/common/sqlite_compiled_statement.h"
#include "chrome/common/sqlite_utils.h"
#include "testing/gtest/include/gtest/gtest.h"
// These tests are slow, especially the ones that create databases. So disable
// them by default.
//#define SAFE_BROWSING_DATABASE_TESTS_ENABLED
#ifdef SAFE_BROWSING_DATABASE_TESTS_ENABLED
namespace {
// Base class for a safebrowsing database. Derived classes can implement
// different types of tables to compare performance characteristics.
class Database {
 public:
  Database() : db_(NULL) {
  }

  // Virtual because the test fixture in this file holds instances through a
  // Database* and deletes them via that base pointer; without a virtual
  // destructor that delete would be undefined behaviour.
  virtual ~Database() {
    if (db_) {
      // Drop the cached compiled statements before closing the connection
      // they were prepared against.
      statement_cache_.Cleanup();
      sqlite3_close(db_);
      db_ = NULL;
    }
  }

  // Opens the sqlite file called |name| inside the temp directory.
  //
  // When |create| is true any existing file is deleted first and, after the
  // connection is opened, CreateTable() is invoked. When |create| is false
  // the file is evicted from the system cache before being opened (presumably
  // so that subsequent reads are not served from cached pages).
  //
  // Returns false if the temp directory cannot be resolved, the database
  // cannot be opened, or (when creating) CreateTable() fails.
  bool Init(const std::string& name, bool create) {
    // get an empty file for the test DB
    std::wstring filename;
    if (!PathService::Get(base::DIR_TEMP, &filename))
      return false;
    filename.push_back(FilePath::kSeparators[0]);
    filename.append(ASCIIToWide(name));

    if (create) {
      DeleteFile(filename.c_str());
    } else {
      DLOG(INFO) << "evicting " << name << " ...";
      file_util::EvictFileFromSystemCache(filename.c_str());
      DLOG(INFO) << "... evicted";
    }

    if (sqlite3_open(WideToUTF8(filename).c_str(), &db_) != SQLITE_OK)
      return false;

    statement_cache_.set_db(db_);

    if (!create)
      return true;

    return CreateTable();
  }

  // Creates the table(s) used by the concrete database layout.
  virtual bool CreateTable() = 0;

  // Stores |count| ints from |prefixes| under |host_key|.
  virtual bool Add(int host_key, int* prefixes, int count) = 0;

  // Looks up |host_key| and copies its stored data into |prefixes|, a buffer
  // of |size| bytes. |count| receives a value describing what was found; see
  // the concrete implementation for its exact meaning.
  virtual bool Read(int host_key, int* prefixes, int size, int* count) = 0;

  // Returns the number of rows stored.
  virtual int Count() = 0;

  // Returns a short label identifying the table layout; the fixture appends
  // it to the database file name.
  virtual std::string GetDBSuffix() = 0;

  // Raw connection accessor.
  sqlite3* db() { return db_; }

 protected:
  // The database connection.
  sqlite3* db_;

  // Cache of compiled statements for our database.
  SqliteStatementCache statement_cache_;
};
// Database layout with a plain, unindexed "hosts" table holding a host key
// and a blob of prefixes per row.
class SimpleDatabase : public Database {
 public:
  // Creates the "hosts" table. Returns false if a table with that name
  // already exists or the CREATE statement fails.
  virtual bool CreateTable() {
    if (DoesSqliteTableExist(db_, "hosts"))
      return false;

    return sqlite3_exec(db_, "CREATE TABLE hosts ("
                        "host INTEGER,"
                        "prefixes BLOB)",
                        NULL, NULL, NULL) == SQLITE_OK;
  }

  // Runs an INSERT OR REPLACE for |host_key|, storing the first |count| ints
  // of |prefixes| as the row's blob. Returns true when the statement
  // completes.
  virtual bool Add(int host_key, int* prefixes, int count) {
    SQLITE_UNIQUE_STATEMENT(statement, statement_cache_,
                            "INSERT OR REPLACE INTO hosts"
                            "(host,prefixes)"
                            "VALUES (?,?)");
    if (!statement.is_valid())
      return false;

    statement->bind_int(0, host_key);
    statement->bind_blob(1, prefixes, count*sizeof(int));
    return statement->step() == SQLITE_DONE;
  }

  // Looks up |host_key| and copies its prefixes blob into |prefixes|, a
  // buffer of |size| bytes.
  //
  // On a hit, |*count| is set to the value reported by column_bytes() for the
  // prefixes column (presumably the blob length in bytes -- note this is not
  // a number of ints). If no row matches, |*count| is set to -1 and the call
  // still returns true. Returns false if the statement is invalid, stepping
  // fails, or the blob does not fit in |size|.
  virtual bool Read(int host_key, int* prefixes, int size, int* count) {
    SQLITE_UNIQUE_STATEMENT(statement, statement_cache_,
                            "SELECT host, prefixes FROM hosts WHERE host=?");
    if (!statement.is_valid())
      return false;

    statement->bind_int(0, host_key);

    int rv = statement->step();
    if (rv == SQLITE_DONE) {
      // no hostkey found, not an error
      *count = -1;
      return true;
    }

    if (rv != SQLITE_ROW)
      return false;

    // Column 0 is |host|, column 1 is |prefixes|. Both the length and the
    // data must come from column 1; copying from column 0 would read the
    // integer host key using the blob's length.
    *count = statement->column_bytes(1);
    if (*count > size)
      return false;

    // Skip the copy for an empty blob so memcpy is never handed a null
    // source pointer.
    if (*count > 0)
      memcpy(prefixes, statement->column_blob(1), *count);
    return true;
  }

  // Returns the row count of "hosts", or -1 (after flagging a gtest failure)
  // if the COUNT query cannot be prepared or stepped.
  int Count() {
    SQLITE_UNIQUE_STATEMENT(statement, statement_cache_,
                            "SELECT COUNT(*) FROM hosts");
    if (!statement.is_valid()) {
      EXPECT_TRUE(false);
      return -1;
    }

    if (statement->step() != SQLITE_ROW) {
      EXPECT_TRUE(false);
      return -1;
    }

    return statement->column_int(0);
  }

  // Label appended to the database file name for this layout.
  std::string GetDBSuffix() {
    return "Simple";
  }
};
// Variant of SimpleDatabase whose "hosts" table declares the host key as the
// INTEGER PRIMARY KEY. Add/Read/Count are inherited unchanged.
class IndexedDatabase : public SimpleDatabase {
 public:
  // Issues the CREATE TABLE statement and reports whether sqlite accepted it.
  virtual bool CreateTable() {
    static const char kCreateHostsSql[] =
        "CREATE TABLE hosts ("
        "host INTEGER PRIMARY KEY,"
        "prefixes BLOB)";
    const int result = sqlite3_exec(db_, kCreateHostsSql, NULL, NULL, NULL);
    return result == SQLITE_OK;
  }

  // Label appended to the database file name for this layout.
  std::string GetDBSuffix() {
    return std::string("Indexed");
  }
};
// Variant of SimpleDatabase whose "hosts" table has a separate
// auto-incrementing id as primary key and a UNIQUE constraint on the host key.
class IndexedWithIDDatabase : public SimpleDatabase {
 public:
  // Issues the CREATE TABLE statement and reports whether sqlite accepted it.
  virtual bool CreateTable() {
    static const char kCreateHostsSql[] =
        "CREATE TABLE hosts ("
        "id INTEGER PRIMARY KEY AUTOINCREMENT,"
        "host INTEGER UNIQUE,"
        "prefixes BLOB)";
    const int result = sqlite3_exec(db_, kCreateHostsSql, NULL, NULL, NULL);
    return result == SQLITE_OK;
  }

  // Inserts (or replaces) the row for |host_key|, passing NULL for the id
  // column and storing the first |count| ints of |prefixes| as the blob.
  // Returns true when the statement runs to completion.
  virtual bool Add(int host_key, int* prefixes, int count) {
    SQLITE_UNIQUE_STATEMENT(statement, statement_cache_,
                            "INSERT OR REPLACE INTO hosts"
                            "(id,host,prefixes)"
                            "VALUES (NULL,?,?)");
    if (statement.is_valid()) {
      statement->bind_int(0, host_key);
      statement->bind_blob(1, prefixes, count * sizeof(int));
      const int step_result = statement->step();
      return step_result == SQLITE_DONE;
    }
    return false;
  }

  // Label appended to the database file name for this layout.
  std::string GetDBSuffix() {
    return std::string("IndexedWithID");
  }
};
}
class SafeBrowsing: public testing::Test {
protected:
// Get the test parameters from the test case's name.
virtual void SetUp() {
logging::InitLogging(
NULL, logging::LOG_ONLY_TO_SYSTEM_DEBUG_LOG,
logging::LOCK_LOG_FILE,
logging::DELETE_OLD_LOG_FILE);
const testing::TestInfo* const test_info =
testing::UnitTest::GetInstance()->current_test_info();
std::string test_name = test_info->name();
TestType type;
if (test_name.find("Write") != std::string::npos) {
type = WRITE;
} else if (test_name.find("Read") != std::string::npos) {
type = READ;
} else {
type = COUNT;
}
if (test_name.find("IndexedWithID") != std::string::npos) {
db_ = new IndexedWithIDDatabase();
} else if (test_name.find("Indexed") != std::string::npos) {
db_ = new IndexedDatabase();
} else {
db_ = new SimpleDatabase();
}
char multiplier_letter = test_name[test_name.size() - 1];
int multiplier = 0;
if (multiplier_letter == 'K') {
multiplier = 1000;
} else if (multiplier_letter == 'M') {
multiplier = 1000000;
} else {
NOTREACHED();
}
size_t index = test_name.size() - 1;
while (index != 0 && test_name[index] != '_')
index--;
DCHECK(index);
const char* count_start = test_name.c_str() + ++index;
int count = atoi(count_start);
int size = count * multiplier;
db_name_ = StringPrintf("TestSafeBrowsing");
db_name_.append(count_start);
db_name_.append(db_->GetDBSuffix());
ASSERT_TRUE(db_->Init(db_name_, type == WRITE));
if (type == WRITE) {
WriteEntries(size);
} else if (type == READ) {
ReadEntries(100);
} else {
CountEntries();
}
}
virtual void TearDown() {
delete db_;
}
// This writes the given number of entries to the database.
void WriteEntries(int count) {
int prefixes[4];
SQLTransaction transaction(db_->db());
transaction.Begin();
int inc = kint32max / count;
for (int i = 0; i < count; i++) {
int hostkey;
rand_s((unsigned int*)&hostkey);
ASSERT_TRUE(db_->Add(hostkey, prefixes, 1));
}
transaction.Commit();
}
// Read the given number of entries from the database.
void ReadEntries(int count) {
int prefixes[4];
int64 total_ms = 0;
for (int i = 0; i < count; ++i) {
int key;
rand_s((unsigned int*)&key);
PerfTimer timer;
int read;
ASSERT_TRUE(db_->Read(key, prefixes, sizeof(prefixes), &read));
int64 time_ms = timer.Elapsed().InMilliseconds();
total_ms += time_ms;
DLOG(INFO) << "Read in " << time_ms << " ms.";
}
DLOG(INFO) << db_name_ << " read " << count << " entries in average of " <<
total_ms/count << " ms.";
}
// Counts how many entries are in the database, which effectively does a full
// table scan.
void CountEntries() {
PerfTimer timer;
int count = db_->Count();
DLOG(INFO) << db_name_ << " counted " << count << " entries in " <<
timer.Elapsed().InMilliseconds() << " ms";
}
enum TestType {
WRITE,
READ,
COUNT,
};
private:
Database* db_;
std::string db_name_;
};
// The bodies below are intentionally empty: SafeBrowsing::SetUp() parses the
// test name (operation, table layout, entry count) and runs the benchmark.
// Read and Count tests open the database without creating it, so they depend
// on the file produced by the Write test with the same layout and count.

// Simple (unindexed) table, 100K entries.
TEST_F(SafeBrowsing, Write_100K) {
}
TEST_F(SafeBrowsing, Read_100K) {
}
// Table with the host key as primary key.
TEST_F(SafeBrowsing, WriteIndexed_100K) {
}
TEST_F(SafeBrowsing, ReadIndexed_100K) {
}
TEST_F(SafeBrowsing, WriteIndexed_250K) {
}
TEST_F(SafeBrowsing, ReadIndexed_250K) {
}
TEST_F(SafeBrowsing, WriteIndexed_500K) {
}
TEST_F(SafeBrowsing, ReadIndexed_500K) {
}
// Table with a separate auto-increment id and a unique host key.
// NOTE(review): the Read tests here are listed before the matching Write
// tests; presumably they rely on a database left behind by an earlier run --
// confirm.
TEST_F(SafeBrowsing, ReadIndexedWithID_250K) {
}
TEST_F(SafeBrowsing, WriteIndexedWithID_250K) {
}
TEST_F(SafeBrowsing, ReadIndexedWithID_500K) {
}
TEST_F(SafeBrowsing, WriteIndexedWithID_500K) {
}
// Row-count (full table scan) timings.
TEST_F(SafeBrowsing, CountIndexed_250K) {
}
TEST_F(SafeBrowsing, CountIndexed_500K) {
}
TEST_F(SafeBrowsing, CountIndexedWithID_250K) {
}
TEST_F(SafeBrowsing, CountIndexedWithID_500K) {
}
class SafeBrowsingDatabaseTest {
public:
SafeBrowsingDatabaseTest(const std::wstring& name) {
logging::InitLogging(
NULL, logging::LOG_ONLY_TO_SYSTEM_DEBUG_LOG,
logging::LOCK_LOG_FILE,
logging::DELETE_OLD_LOG_FILE);
PathService::Get(base::DIR_TEMP, &filename_);
filename_.push_back(FilePath::kSeparators[0]);
filename_.append(name);
}
void Create(int size) {
DeleteFile(filename_.c_str());
SafeBrowsingDatabase database;
database.set_synchronous();
EXPECT_TRUE(database.Init(filename_));
int chunk_id = 0;
int total_host_keys = size;
int host_keys_per_chunk = 100;
std::deque<SBChunk>* chunks = new std::deque<SBChunk>;
for (int i = 0; i < total_host_keys / host_keys_per_chunk; ++i) {
chunks->push_back(SBChunk());
chunks->back().chunk_number = ++chunk_id;
for (int j = 0; j < host_keys_per_chunk; ++j) {
SBChunkHost host;
rand_s((unsigned int*)&host.host);
host.entry = SBEntry::Create(SBEntry::ADD_PREFIX, 2);
host.entry->SetPrefixAt(0, 0x2425525);
host.entry->SetPrefixAt(1, 0x1536366);
chunks->back().hosts.push_back(host);
}
}
database.InsertChunks("goog-malware", chunks);
}
void Read(bool use_bloom_filter) {
int keys_to_read = 500;
file_util::EvictFileFromSystemCache(filename_.c_str());
SafeBrowsingDatabase database;
database.set_synchronous();
EXPECT_TRUE(database.Init(filename_));
PerfTimer total_timer;
int64 db_ms = 0;
int keys_from_db = 0;
for (int i = 0; i < keys_to_read; ++i) {
int key;
rand_s((unsigned int*)&key);
std::string url = StringPrintf("http://www.%d.com/blah.html", key);
std::string matching_list;
std::vector<SBPrefix> prefix_hits;
GURL gurl(url);
if (!use_bloom_filter || database.NeedToCheckUrl(gurl)) {
PerfTimer timer;
database.ContainsUrl(gurl, &matching_list, &prefix_hits);
int64 time_ms = timer.Elapsed().InMilliseconds();
DLOG(INFO) << "Read from db in " << time_ms << " ms.";
db_ms += time_ms;
keys_from_db++;
}
}
int64 total_ms = total_timer.Elapsed().InMilliseconds();
DLOG(INFO) << WideToASCII(file_util::GetFilenameFromPath(filename_)) <<
" read " << keys_to_read << " entries in " << total_ms << " ms. " <<
keys_from_db << " keys were read from the db, with average read taking " <<
db_ms / keys_from_db << " ms";
}
void BuildBloomFilter() {
file_util::EvictFileFromSystemCache(filename_.c_str());
file_util::Delete(SafeBrowsingDatabase::BloomFilterFilename(filename_), false);
PerfTimer total_timer;
SafeBrowsingDatabase database;
database.set_synchronous();
EXPECT_TRUE(database.Init(filename_));
int64 total_ms = total_timer.Elapsed().InMilliseconds();
DLOG(INFO) << WideToASCII(file_util::GetFilenameFromPath(filename_)) <<
" built bloom filter in " << total_ms << " ms.";
}
private:
std::wstring filename_;
};
// Populates the "SafeBrowsing100K" database with 100,000 host records.
TEST(SafeBrowsingDatabase, FillUp100K) {
  const int kHostRecords = 100000;
  SafeBrowsingDatabaseTest populator(L"SafeBrowsing100K");
  populator.Create(kHostRecords);
}
// Populates the "SafeBrowsing250K" database with 250,000 host records.
TEST(SafeBrowsingDatabase, FillUp250K) {
  const int kHostRecords = 250000;
  SafeBrowsingDatabaseTest populator(L"SafeBrowsing250K");
  populator.Create(kHostRecords);
}
// Populates the "SafeBrowsing500K" database with 500,000 host records.
TEST(SafeBrowsingDatabase, FillUp500K) {
  const int kHostRecords = 500000;
  SafeBrowsingDatabaseTest populator(L"SafeBrowsing500K");
  populator.Create(kHostRecords);
}
// Times 500 lookups against the 250K database without consulting the bloom
// filter and logs the result.
TEST(SafeBrowsingDatabase, ReadFrom250K) {
  const bool kUseBloomFilter = false;
  SafeBrowsingDatabaseTest reader(L"SafeBrowsing250K");
  reader.Read(kUseBloomFilter);
}
// Times 500 lookups against the 500K database without consulting the bloom
// filter and logs the result.
TEST(SafeBrowsingDatabase, ReadFrom500K) {
  const bool kUseBloomFilter = false;
  SafeBrowsingDatabaseTest reader(L"SafeBrowsing500K");
  reader.Read(kUseBloomFilter);
}
// Times 500 lookups against the 250K database, consulting the bloom filter
// before each database read, and logs the result.
TEST(SafeBrowsingDatabase, BloomReadFrom250K) {
  const bool kUseBloomFilter = true;
  SafeBrowsingDatabaseTest reader(L"SafeBrowsing250K");
  reader.Read(kUseBloomFilter);
}
// Times 500 lookups against the 500K database, consulting the bloom filter
// before each database read, and logs the result.
TEST(SafeBrowsingDatabase, BloomReadFrom500K) {
  const bool kUseBloomFilter = true;
  SafeBrowsingDatabaseTest reader(L"SafeBrowsing500K");
  reader.Read(kUseBloomFilter);
}
// Measures bloom filter construction time for the 250K database.
TEST(SafeBrowsingDatabase, BuildBloomFilter250K) {
  SafeBrowsingDatabaseTest builder(L"SafeBrowsing250K");
  builder.BuildBloomFilter();
}
// Measures bloom filter construction time for the 500K database.
TEST(SafeBrowsingDatabase, BuildBloomFilter500K) {
  SafeBrowsingDatabaseTest builder(L"SafeBrowsing500K");
  builder.BuildBloomFilter();
}
#endif