Implementation of single-module-per-logon-session-per-profile implementation for Chrome Frame.

BUG=61383
TEST=Run IE with version X of CF loaded. Register version Y of CF. Run a second IE process, notice that version Y when loaded defers to version X.


Review URL: http://codereview.chromium.org/4144008

git-svn-id: svn://svn.chromium.org/chrome/trunk/src@65236 0039d316-1c4b-4281-b951-d872f2087c98
diff --git a/base/shared_memory.h b/base/shared_memory.h
index 79ea8fd..719eb69d 100644
--- a/base/shared_memory.h
+++ b/base/shared_memory.h
@@ -43,6 +43,13 @@
  public:
   SharedMemory();
 
+#if defined(OS_WIN)
+  // Similar to the default constructor, except that this allows for
+  // calling Lock() to acquire the named mutex before either Create or Open
+  // are called on Windows.
+  explicit SharedMemory(const std::wstring& name);
+#endif
+
   // Create a new SharedMemory object from an existing, open
   // shared memory file.
   SharedMemory(SharedMemoryHandle handle, bool read_only);
@@ -165,6 +172,12 @@
   // across Mac and Linux.
   void Lock();
 
+#if defined(OS_WIN)
+  // A Lock() implementation with a timeout. Returns true if the Lock() has
+  // been acquired, false if the timeout was reached.
+  bool Lock(uint32 timeout_ms);
+#endif
+
   // Releases the shared memory lock.
   void Unlock();
 
diff --git a/base/shared_memory_win.cc b/base/shared_memory_win.cc
index a0b2a5aa..5f293fc2 100644
--- a/base/shared_memory_win.cc
+++ b/base/shared_memory_win.cc
@@ -17,6 +17,15 @@
       lock_(NULL) {
 }
 
+SharedMemory::SharedMemory(const std::wstring& name)
+    : mapped_file_(NULL),
+      memory_(NULL),
+      read_only_(false),
+      created_size_(0),
+      lock_(NULL),
+      name_(name) {
+}
+
 SharedMemory::SharedMemory(SharedMemoryHandle handle, bool read_only)
     : mapped_file_(handle),
       memory_(NULL),
@@ -188,6 +197,10 @@
 }
 
 void SharedMemory::Lock() {
+  Lock(INFINITE);
+}
+
+bool SharedMemory::Lock(uint32 timeout_ms) {
   if (lock_ == NULL) {
     std::wstring name = name_;
     name.append(L"lock");
@@ -195,10 +208,13 @@
     DCHECK(lock_ != NULL);
     if (lock_ == NULL) {
       DLOG(ERROR) << "Could not create mutex" << GetLastError();
-      return;  // there is nothing good we can do here.
+      return false;  // there is nothing good we can do here.
     }
   }
-  WaitForSingleObject(lock_, INFINITE);
+  DWORD result = WaitForSingleObject(lock_, timeout_ms);
+
+  // Return false for WAIT_ABANDONED, WAIT_TIMEOUT or WAIT_FAILED.
+  return (result == WAIT_OBJECT_0);
 }
 
 void SharedMemory::Unlock() {
diff --git a/chrome_frame/chrome_frame.gyp b/chrome_frame/chrome_frame.gyp
index 1584f8e2..d31e990 100644
--- a/chrome_frame/chrome_frame.gyp
+++ b/chrome_frame/chrome_frame.gyp
@@ -114,6 +114,7 @@
         'test/exception_barrier_unittest.cc',
         'test/html_util_unittests.cc',
         'test/http_negotiate_unittest.cc',
+        'test/module_utils_test.cc',
         'test/policy_settings_unittest.cc',
         'test/simulate_input.h',
         'test/simulate_input.cc',
diff --git a/chrome_frame/chrome_tab.cc b/chrome_frame/chrome_tab.cc
index b81be69..c31e8ece 100644
--- a/chrome_frame/chrome_tab.cc
+++ b/chrome_frame/chrome_tab.cc
@@ -212,26 +212,25 @@
     logging::InitLogging(NULL, logging::LOG_ONLY_TO_SYSTEM_DEBUG_LOG,
                         logging::LOCK_LOG_FILE, logging::DELETE_OLD_LOG_FILE);
 
-    if (!DllRedirector::RegisterAsFirstCFModule()) {
-      // We are not the first ones in, get the module who registered first.
-      HMODULE original_module = DllRedirector::GetFirstCFModule();
-      DCHECK(original_module != NULL)
-          << "Could not get first CF module handle.";
-      HMODULE this_module = reinterpret_cast<HMODULE>(&__ImageBase);
-      if (original_module != this_module) {
-        // Someone else was here first, try and get a pointer to their
-        // DllGetClassObject export:
-        g_dll_get_class_object_redir_ptr =
-            DllRedirector::GetDllGetClassObjectPtr(original_module);
-        DCHECK(g_dll_get_class_object_redir_ptr != NULL)
-            << "Found CF module with no DllGetClassObject export.";
-      }
+    DllRedirector* dll_redirector = Singleton<DllRedirector>::get();
+    DCHECK(dll_redirector);
+
+    if (!dll_redirector->RegisterAsFirstCFModule()) {
+      // Someone else was here first, try and get a pointer to their
+      // DllGetClassObject export:
+      g_dll_get_class_object_redir_ptr =
+          dll_redirector->GetDllGetClassObjectPtr();
+      DCHECK(g_dll_get_class_object_redir_ptr != NULL)
+          << "Found CF module with no DllGetClassObject export.";
     }
 
     // Enable ETW logging.
     logging::LogEventProvider::Initialize(kChromeFrameProvider);
   } else if (reason == DLL_PROCESS_DETACH) {
-    DllRedirector::UnregisterAsFirstCFModule();
+    DllRedirector* dll_redirector = Singleton<DllRedirector>::get();
+    DCHECK(dll_redirector);
+
+    dll_redirector->UnregisterAsFirstCFModule();
     g_patch_helper.UnpatchIfNeeded();
     delete g_exit_manager;
     g_exit_manager = NULL;
diff --git a/chrome_frame/module_utils.cc b/chrome_frame/module_utils.cc
index 62f07b81..cca852a 100644
--- a/chrome_frame/module_utils.cc
+++ b/chrome_frame/module_utils.cc
@@ -5,83 +5,197 @@
 #include "chrome_frame/module_utils.h"
 
 #include <atlbase.h>
+#include "base/file_path.h"
+#include "base/file_version_info.h"
 #include "base/logging.h"
+#include "base/path_service.h"
+#include "base/shared_memory.h"
+#include "base/utf_string_conversions.h"
+#include "base/version.h"
 
-const wchar_t kBeaconWindowClassName[] =
-    L"ChromeFrameBeaconWindowClass826C5D01-E355-4b23-8AC2-40650E0B7843";
+const char kSharedMemoryName[] = "ChromeFrameVersionBeacon";
+const uint32 kSharedMemorySize = 128;
+const uint32 kSharedMemoryLockTimeoutMs = 1000;
 
 // static
-ATOM DllRedirector::atom_ = 0;
+DllRedirector::DllRedirector() : first_module_handle_(NULL) {
+  // TODO(robertshield): Correctly construct the profile name here. Also allow
+  // for overrides to be taken from the environment.
+  shared_memory_.reset(new base::SharedMemory(ASCIIToWide(kSharedMemoryName)));
+}
+
+DllRedirector::DllRedirector(const char* shared_memory_name)
+    : shared_memory_name_(shared_memory_name), first_module_handle_(NULL) {
+  // TODO(robertshield): Correctly construct the profile name here. Also allow
+  // for overrides to be taken from the environment.
+  shared_memory_.reset(new base::SharedMemory(ASCIIToWide(shared_memory_name)));
+}
+
+DllRedirector::~DllRedirector() {
+  if (first_module_handle_) {
+    if (first_module_handle_ != reinterpret_cast<HMODULE>(&__ImageBase)) {
+      FreeLibrary(first_module_handle_);
+    }
+    first_module_handle_ = NULL;
+  }
+  UnregisterAsFirstCFModule();
+}
 
 bool DllRedirector::RegisterAsFirstCFModule() {
-  // This would imply that this module had already registered a window class
-  // which should never happen.
-  if (atom_) {
-    NOTREACHED();
+  DCHECK(first_module_handle_ == NULL);
+
+  // Build our own file version outside of the lock:
+  scoped_ptr<Version> our_version(GetCurrentModuleVersion());
+
+  // We sadly can't use the autolock here since we want to have a timeout.
+  // Be careful not to return while holding the lock. Also, attempt to do as
+  // little as possible while under this lock.
+  bool lock_acquired = shared_memory_->Lock(kSharedMemoryLockTimeoutMs);
+
+  if (!lock_acquired) {
+    // We couldn't get the lock in a reasonable amount of time, so fall
+    // back to loading our current version. We return true to indicate that the
+    // caller should not attempt to delegate to an already loaded version.
+    dll_version_.swap(our_version);
+    first_module_handle_ = reinterpret_cast<HMODULE>(&__ImageBase);
     return true;
   }
 
-  WNDCLASSEX wnd_class = {0};
-  wnd_class.cbSize = sizeof(WNDCLASSEX);
-  wnd_class.style = CS_GLOBALCLASS;
-  wnd_class.hCursor = LoadCursor(NULL, IDC_ARROW);
-  wnd_class.lpszClassName = kBeaconWindowClassName;
+  bool created_beacon = true;
+  bool result = shared_memory_->CreateNamed(shared_memory_name_.c_str(),
+                                            false,  // open_existing
+                                            kSharedMemorySize);
 
-  HMODULE this_module = reinterpret_cast<HMODULE>(&__ImageBase);
-  wnd_class.lpfnWndProc = reinterpret_cast<WNDPROC>(this_module);
+  if (!result) {
+    created_beacon = false;
 
-  atom_ = RegisterClassEx(&wnd_class);
-  return (atom_ != 0);
+    // We failed to create the shared memory segment, suggesting it may already
+    // exist: try to create it read-only.
+    result = shared_memory_->Open(shared_memory_name_.c_str(),
+                                  true /* read_only */);
+  }
+
+  if (result) {
+    // Map in the whole thing.
+    result = shared_memory_->Map(0);
+    DCHECK(shared_memory_->memory());
+
+    if (result) {
+      // Either write our own version number or read it in if it was already
+      // present in the shared memory section.
+      if (created_beacon) {
+        dll_version_.swap(our_version);
+
+        lstrcpynA(reinterpret_cast<char*>(shared_memory_->memory()),
+                  dll_version_->GetString().c_str(),
+                  std::min(kSharedMemorySize,
+                           dll_version_->GetString().length() + 1));
+
+        // Mark ourself as the first module in.
+        first_module_handle_ = reinterpret_cast<HMODULE>(&__ImageBase);
+      } else {
+        char buffer[kSharedMemorySize] = {0};
+        memcpy(buffer, shared_memory_->memory(), kSharedMemorySize - 1);
+        dll_version_.reset(Version::GetVersionFromString(buffer));
+
+        if (!dll_version_.get() || dll_version_->Equals(*our_version.get())) {
+          // If we either couldn't parse a valid version out of the shared
+          // memory or we did parse a version and it is the same as our own,
+          // then pretend we're first in to avoid trying to load any other DLLs.
+          dll_version_.reset(our_version.release());
+          first_module_handle_ = reinterpret_cast<HMODULE>(&__ImageBase);
+          created_beacon = true;
+        }
+      }
+    } else {
+      NOTREACHED() << "Failed to map in version beacon.";
+    }
+  } else {
+    NOTREACHED() << "Could not create file mapping for version beacon, gle: "
+                 << ::GetLastError();
+  }
+
+  // Matching Unlock.
+  shared_memory_->Unlock();
+
+  return created_beacon;
 }
 
 void DllRedirector::UnregisterAsFirstCFModule() {
-  if (atom_) {
-    UnregisterClass(MAKEINTATOM(atom_), NULL);
-    atom_ = NULL;
+  if (base::SharedMemory::IsHandleValid(shared_memory_->handle())) {
+    bool lock_acquired = shared_memory_->Lock(kSharedMemoryLockTimeoutMs);
+    if (lock_acquired) {
+      // Free our handles. The last closed handle SHOULD result in it being
+      // deleted.
+      shared_memory_->Close();
+      shared_memory_->Unlock();
+    }
   }
 }
 
-HMODULE DllRedirector::GetFirstCFModule() {
-  WNDCLASSEX wnd_class = {0};
-  HMODULE oldest_module = NULL;
-  if (GetClassInfoEx(GetModuleHandle(NULL), kBeaconWindowClassName,
-                     &wnd_class)) {
-    oldest_module = reinterpret_cast<HMODULE>(wnd_class.lpfnWndProc);
-    // Handle older versions that store module pointer in a class info.
-    // TODO(amit): Remove this in future versions.
-    if (reinterpret_cast<HMODULE>(DefWindowProc) == oldest_module) {
-      WNDCLASSEX wnd_class = {0};
-      HWND hwnd = CreateWindow(kBeaconWindowClassName, L"temp_window",
-                               WS_POPUP, 0, 0, 0, 0, NULL, NULL, NULL, NULL);
-      DCHECK(IsWindow(hwnd));
-      if (hwnd) {
-        oldest_module = reinterpret_cast<HMODULE>(GetClassLongPtr(hwnd, 0));
-        DestroyWindow(hwnd);
-      }
-    }
-  }
-  return oldest_module;
-}
+LPFNGETCLASSOBJECT DllRedirector::GetDllGetClassObjectPtr() {
+  HMODULE first_module_handle = GetFirstModule();
 
-LPFNGETCLASSOBJECT DllRedirector::GetDllGetClassObjectPtr(HMODULE module) {
-  LPFNGETCLASSOBJECT proc_ptr = NULL;
-  HMODULE temp_handle = 0;
-  // Increment the module ref count while we have an pointer to its
-  // DllGetClassObject function.
-  if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
-                        reinterpret_cast<LPCTSTR>(module),
-                        &temp_handle)) {
-    proc_ptr = reinterpret_cast<LPFNGETCLASSOBJECT>(
-        GetProcAddress(temp_handle, "DllGetClassObject"));
-    if (!proc_ptr) {
-      FreeLibrary(temp_handle);
-      LOG(ERROR) << "Module Scan: Couldn't get address of "
-                 << "DllGetClassObject: "
-                 << GetLastError();
-    }
-  } else {
-    LOG(ERROR) << "Module Scan: Could not increment module count: "
-               << GetLastError();
+  LPFNGETCLASSOBJECT proc_ptr = reinterpret_cast<LPFNGETCLASSOBJECT>(
+      GetProcAddress(first_module_handle, "DllGetClassObject"));
+  if (!proc_ptr) {
+    DLOG(ERROR) << "DllRedirector: Could get address of DllGetClassObject "
+                   "from first loaded module, GLE: "
+                << GetLastError();
+    // Oh boink, the first module we loaded was somehow bogus, make ourselves
+    // the first module again.
+    first_module_handle = reinterpret_cast<HMODULE>(&__ImageBase);
   }
   return proc_ptr;
 }
+
+Version* DllRedirector::GetCurrentModuleVersion() {
+  scoped_ptr<FileVersionInfo> file_version_info(
+      FileVersionInfo::CreateFileVersionInfoForCurrentModule());
+  DCHECK(file_version_info.get());
+
+  Version* current_version = NULL;
+  if (file_version_info.get()) {
+     current_version = Version::GetVersionFromString(
+        file_version_info->file_version());
+    DCHECK(current_version);
+  }
+
+  return current_version;
+}
+
+HMODULE DllRedirector::GetFirstModule() {
+  DCHECK(dll_version_.get())
+      << "Error: Did you call RegisterAsFirstCFModule() first?";
+
+  if (first_module_handle_ == NULL) {
+    first_module_handle_ = LoadVersionedModule(dll_version_.get());
+    if (!first_module_handle_) {
+      first_module_handle_ = reinterpret_cast<HMODULE>(&__ImageBase);
+    }
+  }
+
+  return first_module_handle_;
+}
+
+HMODULE DllRedirector::LoadVersionedModule(Version* version) {
+  DCHECK(version);
+
+  FilePath module_path;
+  PathService::Get(base::DIR_MODULE, &module_path);
+  DCHECK(!module_path.empty());
+
+  FilePath module_name = module_path.BaseName();
+  module_path = module_path.DirName()
+                           .Append(ASCIIToWide(version->GetString()))
+                           .Append(module_name);
+
+  HMODULE hmodule = LoadLibrary(module_path.value().c_str());
+  if (hmodule == NULL) {
+    DLOG(ERROR) << "Could not load reported module version "
+                << version->GetString();
+  }
+
+  return hmodule;
+}
+
diff --git a/chrome_frame/module_utils.h b/chrome_frame/module_utils.h
index 4cdf4b85..0da5472 100644
--- a/chrome_frame/module_utils.h
+++ b/chrome_frame/module_utils.h
@@ -8,33 +8,82 @@
 #include <ObjBase.h>
 #include <windows.h>
 
+#include "base/basictypes.h"
+#include "base/scoped_ptr.h"
+#include "base/shared_memory.h"
+#include "base/singleton.h"
+
+// Forward
+class Version;
+
+// A singleton class that provides a facility to register the version of the
+// current module as the only version that should be loaded system-wide. If
+// this module is not the first instance loaded in the system, then the version
+// that loaded first will be delegated to. This makes a few assumptions:
+//  1) That different versions of the module this code is in reside in
+//     neighbouring versioned directories, e.g.
+//       C:\foo\bar\1.2.3.4\my_module.dll
+//       C:\foo\bar\1.2.3.5\my_module.dll
+//  2) That the instance of this class will outlive the module that may be
+//     delegated to. That is to say, that this class only guarantees that the
+//     module is loaded as long as this instance is active.
+//  3) The module this is compiled into is built with version info.
 class DllRedirector {
  public:
-  // Attempts to register a window class under a well known name and appends to
-  // its extra data a handle to the current module. Will fail if the window
-  // class is already registered. This is intended to be called from DllMain
-  // under PROCESS_ATTACH.
-  static bool DllRedirector::RegisterAsFirstCFModule();
+  virtual ~DllRedirector();
+
+  // Attempts to register this Chrome Frame version as the first loaded version
+  // on the system. If this succeeds, return true. If it fails, it returns
+  // false meaning that there is another version already loaded somewhere and
+  // the caller should delegate to that version instead.
+  bool DllRedirector::RegisterAsFirstCFModule();
 
   // Unregisters the well known window class if we registered it earlier.
   // This is intended to be called from DllMain under PROCESS_DETACH.
-  static void DllRedirector::UnregisterAsFirstCFModule();
-
-  // Helper function that extracts the HMODULE parameter from our well known
-  // window class.
-  static HMODULE GetFirstCFModule();
+  void DllRedirector::UnregisterAsFirstCFModule();
 
   // Helper function to return the DllGetClassObject function pointer from
   // the given module. On success, the return value is non-null and module
   // will have had its reference count incremented.
-  static LPFNGETCLASSOBJECT GetDllGetClassObjectPtr(HMODULE module);
+  LPFNGETCLASSOBJECT GetDllGetClassObjectPtr();
 
- private:
-  // Use this to keep track of whether or not we have registered the window
-  // class in this module.
-  static ATOM atom_;
+ protected:
+  DllRedirector();
+  friend struct DefaultSingletonTraits<DllRedirector>;
+
+  // Constructor used for tests.
+  explicit DllRedirector(const char* shared_memory_name);
+
+  // Returns an HMODULE to the version of the module that should be loaded.
+  virtual HMODULE GetFirstModule();
+
+  // Returns the version of the current module or NULL if none can be found.
+  // The caller must free the Version.
+  virtual Version* GetCurrentModuleVersion();
+
+  // Attempt to load the specified version dll. Finds it by walking up one
+  // directory from our current module's location, then appending the newly
+  // found version number. The Version class in base will have ensured that we
+  // actually have a valid version and not e.g. ..\..\..\..\MyEvilFolder\.
+  virtual HMODULE LoadVersionedModule(Version* version);
+
+  // Shared memory segment that contains the version beacon.
+  scoped_ptr<base::SharedMemory> shared_memory_;
+
+  // The current version of the DLL to be loaded.
+  scoped_ptr<Version> dll_version_;
+
+  // The handle to the first version of this module that was loaded. This
+  // may refer to the current module, or another version of the same module
+  // that we go and load.
+  HMODULE first_module_handle_;
+
+  // Used for tests to override the name of the shared memory segment.
+  std::string shared_memory_name_;
 
   friend class ModuleUtilsTest;
+
+  DISALLOW_COPY_AND_ASSIGN(DllRedirector);
 };
 
 #endif  // CHROME_FRAME_MODULE_UTILS_H_
diff --git a/chrome_frame/test/module_utils_test.cc b/chrome_frame/test/module_utils_test.cc
new file mode 100644
index 0000000..4bb16ce
--- /dev/null
+++ b/chrome_frame/test/module_utils_test.cc
@@ -0,0 +1,292 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "chrome_frame/module_utils.h"
+
+#include "base/scoped_handle.h"
+#include "base/shared_memory.h"
+#include "base/utf_string_conversions.h"
+#include "base/version.h"
+#include "gtest/gtest.h"
+
+extern "C" IMAGE_DOS_HEADER __ImageBase;
+
+const char kMockVersionString[] = "42.42.42.42";
+const char kMockVersionString2[] = "133.33.33.7";
+
+const HMODULE kMockModuleHandle = reinterpret_cast<HMODULE>(42);
+const HMODULE kMockModuleHandle2 = reinterpret_cast<HMODULE>(43);
+
+const char kTestVersionBeaconName[] = "DllRedirectorTestVersionBeacon";
+const uint32 kSharedMemorySize = 128;
+
+// The maximum amount of time we are willing to let a test that Waits timeout
+// before failing.
+const uint32 kWaitTestTimeout = 20000;
+
+using base::win::ScopedHandle;
+
+class MockDllRedirector : public DllRedirector {
+ public:
+  explicit MockDllRedirector(const char* beacon_name)
+      : DllRedirector(beacon_name) {}
+
+  virtual HMODULE LoadVersionedModule() {
+    return kMockModuleHandle;
+  }
+
+  virtual Version* GetCurrentModuleVersion() {
+    return Version::GetVersionFromString(kMockVersionString);
+  }
+
+  virtual HMODULE GetFirstModule() {
+    return DllRedirector::GetFirstModule();
+  }
+
+  Version* GetFirstModuleVersion() {
+    // Lazy man's copy.
+    return Version::GetVersionFromString(dll_version_->GetString());
+  }
+
+  base::SharedMemory* shared_memory() {
+    return shared_memory_.get();
+  }
+};
+
+class MockDllRedirector2 : public MockDllRedirector {
+ public:
+  explicit MockDllRedirector2(const char* beacon_name)
+      : MockDllRedirector(beacon_name) {}
+
+  virtual HMODULE LoadVersionedModule() {
+    return kMockModuleHandle2;
+  }
+
+  virtual Version* GetCurrentModuleVersion() {
+    return Version::GetVersionFromString(kMockVersionString2);
+  }
+};
+
+class DllRedirectorTest : public testing::Test {
+ public:
+  virtual void SetUp() {
+    shared_memory_.reset(new base::SharedMemory);
+    mock_version_.reset(Version::GetVersionFromString(kMockVersionString));
+    mock_version2_.reset(Version::GetVersionFromString(kMockVersionString2));
+  }
+
+  virtual void TearDown() {
+    CloseBeacon();
+  }
+
+  void CreateVersionBeacon(const std::string& name,
+                           const std::string& version_string) {
+    // Abort the test if we can't create and map a new named memory object.
+    EXPECT_TRUE(shared_memory_->CreateNamed(name, false,
+                                            kSharedMemorySize));
+    EXPECT_TRUE(shared_memory_->Map(0));
+    EXPECT_TRUE(shared_memory_->memory());
+
+    if (shared_memory_->memory()) {
+      memcpy(shared_memory_->memory(),
+             version_string.c_str(),
+             std::min(kSharedMemorySize, version_string.length() + 1));
+    }
+  }
+
+  // Opens the named beacon and returns the version.
+  Version* OpenAndReadVersionFromBeacon(const std::string& name) {
+    // Abort the test if we can't open and map the named memory object.
+    EXPECT_TRUE(shared_memory_->Open(name, true /* read_only */));
+    EXPECT_TRUE(shared_memory_->Map(0));
+    EXPECT_TRUE(shared_memory_->memory());
+
+    char buffer[kSharedMemorySize] = {0};
+    memcpy(buffer, shared_memory_->memory(), kSharedMemorySize - 1);
+    return Version::GetVersionFromString(buffer);
+  }
+
+  void CloseBeacon() {
+    shared_memory_->Close();
+  }
+
+  // Shared memory segment that contains the version beacon.
+  scoped_ptr<base::SharedMemory> shared_memory_;
+  scoped_ptr<Version> mock_version_;
+  scoped_ptr<Version> mock_version2_;
+};
+
+TEST_F(DllRedirectorTest, RegisterAsFirstModule) {
+  scoped_ptr<MockDllRedirector> redirector(
+      new MockDllRedirector(kTestVersionBeaconName));
+  EXPECT_TRUE(redirector->RegisterAsFirstCFModule());
+
+  base::SharedMemory* redirector_memory = redirector->shared_memory();
+  char buffer[kSharedMemorySize] = {0};
+  memcpy(buffer, redirector_memory->memory(), kSharedMemorySize - 1);
+  scoped_ptr<Version> redirector_version(Version::GetVersionFromString(buffer));
+  ASSERT_TRUE(redirector_version.get());
+  EXPECT_TRUE(redirector_version->Equals(*mock_version_.get()));
+  redirector_memory = NULL;
+
+  scoped_ptr<Version> memory_version(
+      OpenAndReadVersionFromBeacon(kTestVersionBeaconName));
+  ASSERT_TRUE(memory_version.get());
+  EXPECT_TRUE(redirector_version->Equals(*memory_version.get()));
+  CloseBeacon();
+
+  redirector.reset();
+  EXPECT_FALSE(shared_memory_->Open(kTestVersionBeaconName, true));
+}
+
+TEST_F(DllRedirectorTest, SecondModuleLoading) {
+  scoped_ptr<MockDllRedirector> first_redirector(
+      new MockDllRedirector(kTestVersionBeaconName));
+  EXPECT_TRUE(first_redirector->RegisterAsFirstCFModule());
+
+  scoped_ptr<MockDllRedirector2> second_redirector(
+      new MockDllRedirector2(kTestVersionBeaconName));
+  EXPECT_FALSE(second_redirector->RegisterAsFirstCFModule());
+
+  scoped_ptr<Version> first_redirector_version(
+      first_redirector->GetFirstModuleVersion());
+  scoped_ptr<Version> second_redirector_version(
+      second_redirector->GetFirstModuleVersion());
+
+  EXPECT_TRUE(
+      second_redirector_version->Equals(*first_redirector_version.get()));
+  EXPECT_TRUE(
+      second_redirector_version->Equals(*mock_version_.get()));
+}
+
+// This test ensures that the beacon remains alive as long as there is a single
+// module that used it to determine its version still loaded.
+TEST_F(DllRedirectorTest, TestBeaconOwnershipHandoff) {
+  scoped_ptr<MockDllRedirector> first_redirector(
+      new MockDllRedirector(kTestVersionBeaconName));
+  EXPECT_TRUE(first_redirector->RegisterAsFirstCFModule());
+
+  scoped_ptr<MockDllRedirector2> second_redirector(
+      new MockDllRedirector2(kTestVersionBeaconName));
+  EXPECT_FALSE(second_redirector->RegisterAsFirstCFModule());
+
+  scoped_ptr<Version> first_redirector_version(
+      first_redirector->GetFirstModuleVersion());
+  scoped_ptr<Version> second_redirector_version(
+      second_redirector->GetFirstModuleVersion());
+
+  EXPECT_TRUE(
+      second_redirector_version->Equals(*first_redirector_version.get()));
+  EXPECT_TRUE(
+      second_redirector_version->Equals(*mock_version_.get()));
+
+  // Clear out the first redirector. The second, still holding a reference
+  // to the shared memory should ensure that the beacon stays alive.
+  first_redirector.reset();
+
+  scoped_ptr<MockDllRedirector2> third_redirector(
+      new MockDllRedirector2(kTestVersionBeaconName));
+  EXPECT_FALSE(third_redirector->RegisterAsFirstCFModule());
+
+  scoped_ptr<Version> third_redirector_version(
+      third_redirector->GetFirstModuleVersion());
+
+  EXPECT_TRUE(
+      third_redirector_version->Equals(*second_redirector_version.get()));
+  EXPECT_TRUE(
+      third_redirector_version->Equals(*mock_version_.get()));
+
+  // Now close all remaining redirectors, which should destroy the beacon.
+  second_redirector.reset();
+  third_redirector.reset();
+
+  // Now create a fourth, expecting that this time it should be the first in.
+  scoped_ptr<MockDllRedirector2> fourth_redirector(
+      new MockDllRedirector2(kTestVersionBeaconName));
+  EXPECT_TRUE(fourth_redirector->RegisterAsFirstCFModule());
+
+  scoped_ptr<Version> fourth_redirector_version(
+      fourth_redirector->GetFirstModuleVersion());
+
+  EXPECT_TRUE(
+      fourth_redirector_version->Equals(*mock_version2_.get()));
+}
+
+struct LockSquattingThreadParams {
+  ScopedHandle is_squatting;
+  ScopedHandle time_to_die;
+};
+
+DWORD WINAPI LockSquattingThread(void* in_params) {
+  LockSquattingThreadParams* params =
+      reinterpret_cast<LockSquattingThreadParams*>(in_params);
+  DCHECK(params);
+
+  // Grab the lock for the shared memory region and hold onto it.
+  base::SharedMemory squatter(ASCIIToWide(kTestVersionBeaconName));
+  base::SharedMemoryAutoLock squatter_lock(&squatter);
+
+  // Notify our caller that we're squatting.
+  BOOL ret = ::SetEvent(params->is_squatting);
+  DCHECK(ret);
+
+  // And then wait to be told to shut down.
+  DWORD result = ::WaitForSingleObject(params->time_to_die, kWaitTestTimeout);
+  EXPECT_EQ(WAIT_OBJECT_0, result);
+
+  return 0;
+}
+
+// Test that the Right Thing happens when someone else is holding onto the
+// beacon lock and not letting go. (The Right Thing being that the redirector
+// assumes that it is the right version and doesn't attempt to use the shared
+// memory region.)
+TEST_F(DllRedirectorTest, LockSquatting) {
+  scoped_ptr<MockDllRedirector> first_redirector(
+      new MockDllRedirector(kTestVersionBeaconName));
+  EXPECT_TRUE(first_redirector->RegisterAsFirstCFModule());
+
+  LockSquattingThreadParams params;
+  params.is_squatting.Set(::CreateEvent(NULL, FALSE, FALSE, NULL));
+  params.time_to_die.Set(::CreateEvent(NULL, FALSE, FALSE, NULL));
+  DWORD tid = 0;
+  ScopedHandle lock_squat_thread(
+      ::CreateThread(NULL, 0, LockSquattingThread, &params, 0, &tid));
+
+  // Make sure the squatter has started squatting.
+  DWORD wait_result = ::WaitForSingleObject(params.is_squatting,
+                                            kWaitTestTimeout);
+  EXPECT_EQ(WAIT_OBJECT_0, wait_result);
+
+  scoped_ptr<MockDllRedirector2> second_redirector(
+      new MockDllRedirector2(kTestVersionBeaconName));
+  EXPECT_TRUE(second_redirector->RegisterAsFirstCFModule());
+
+  scoped_ptr<Version> second_redirector_version(
+      second_redirector->GetFirstModuleVersion());
+  EXPECT_TRUE(
+      second_redirector_version->Equals(*mock_version2_.get()));
+
+  // Shut down the squatting thread.
+  DWORD ret = ::SetEvent(params.time_to_die);
+  DCHECK(ret);
+
+  wait_result = ::WaitForSingleObject(lock_squat_thread, kWaitTestTimeout);
+  EXPECT_EQ(WAIT_OBJECT_0, wait_result);
+}
+
+TEST_F(DllRedirectorTest, BadVersionNumber) {
+  std::string bad_version("I am not a version number");
+  CreateVersionBeacon(kTestVersionBeaconName, bad_version);
+
+  // The redirector should fail to read the version number and defer to
+  // its own version.
+  scoped_ptr<MockDllRedirector> first_redirector(
+      new MockDllRedirector(kTestVersionBeaconName));
+  EXPECT_TRUE(first_redirector->RegisterAsFirstCFModule());
+
+  HMODULE first_module = first_redirector->GetFirstModule();
+  EXPECT_EQ(reinterpret_cast<HMODULE>(&__ImageBase), first_module);
+}
+