目录
1、源码及框架分析
SGI-STL 3.0 版本的源代码中,map 和 set 的实现分布在 map、set、stl_map.h、stl_set.h、stl_tree.h 等几个头文件中。下面把 map 和 set 实现框架的核心部分截取出来:
// set
#ifndef __SGI_STL_INTERNAL_TREE_H
#include <stl_tree.h>
#endif
#include <stl_set.h>
#include <stl_multiset.h>
// map
#ifndef __SGI_STL_INTERNAL_TREE_H
#include <stl_tree.h>
#endif
#include <stl_map.h>
#include <stl_multimap.h>
// stl_set.h
// Excerpt of SGI STL's set, showing only how set sits on top of rb_tree.
// A set element is its own key, so key_type and value_type are both Key.
template <class Key, class Compare = less<Key>, class Alloc = alloc>
class set {
public:
// typedefs:
typedef Key key_type;
typedef Key value_type;
private:
// The tree stores value_type (== Key). identity<value_type> is passed as the
// KeyOfValue functor; since the value is the key, it presumably returns its argument.
// NOTE(review): key_compare is not declared in this excerpt (presumably a typedef of Compare elsewhere in stl_set.h).
typedef rb_tree<key_type, value_type,
identity<value_type>, key_compare, Alloc> rep_type;
rep_type t; // red-black tree representing set
};
// stl_map.h
// Excerpt of SGI STL's map. Elements are pair<const Key, T>, so the key part
// of a stored element cannot be modified.
template <class Key, class T, class Compare = less<Key>,
class Alloc = alloc>
class map {
public:
// typedefs:
typedef Key key_type;
typedef T mapped_type;
typedef pair<const Key, T> value_type;
private:
// Same rb_tree as set, but the stored value is the whole pair and the
// KeyOfValue functor is select1st<value_type> (presumably returns pair.first).
// NOTE(review): key_compare is not declared in this excerpt (presumably a typedef of Compare elsewhere in stl_map.h).
typedef rb_tree<key_type, value_type,
select1st<value_type>, key_compare, Alloc> rep_type;
rep_type t; // red-black tree representing map
};
// stl_tree.h
// Value-independent part of an SGI rb-tree node: the colour and the three links.
// __rb_tree_node<Value> derives from it and adds the stored value.
struct __rb_tree_node_base {
typedef __rb_tree_color_type color_type;
typedef __rb_tree_node_base* base_ptr;
color_type color;
base_ptr parent;
base_ptr left;
base_ptr right;
};
// stl_tree.h
// Excerpt of SGI's rb_tree, the common implementation behind set and map.
//   Key        - type taken by find / erase
//   Value      - type stored in the nodes and taken by insert
//   KeyOfValue - functor that extracts the Key from a Value
//   Compare    - ordering on Key
// NOTE(review): iterator and size_type are used below, but their typedefs are not part of this excerpt.
template <class Key, class Value, class KeyOfValue, class Compare,
class Alloc = alloc>
class rb_tree {
protected:
typedef void* void_pointer;
typedef __rb_tree_node_base* base_ptr;
typedef __rb_tree_node<Value> rb_tree_node;
typedef rb_tree_node* link_type;
typedef Key key_type;
typedef Value value_type;
public:
// insert: takes a whole Value
pair<iterator, bool> insert_unique(const value_type& x);
// erase and find: take only a Key
size_type erase(const key_type& x);
iterator find(const key_type& x);
protected:
size_type node_count; // keeps track of size of tree
// The extra header node, which SGI uses as end().
link_type header;
};
// A complete SGI rb-tree node: the colour/link base plus the stored Value.
template <class Value>
struct __rb_tree_node : public __rb_tree_node_base {
typedef __rb_tree_node<Value>* link_type;
Value value_field;
};
// rb_tree's template parameter list, repeated for the explanation below (no definition follows).
template <class Key, class Value, class KeyOfValue, class Compare,class Alloc = alloc>
insert 的参数类型是 Value,erase 和 find 的参数类型是 Key;KeyOfValue 是一个仿函数,用于从 Value 中取出 Key。
2、模拟实现map和set
2.1 复用的红黑树框架及Insert
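1. 相比源码,这里调整一下命名:key 的类型参数用 K,value 的类型参数用 V,红黑树节点中存储的数据类型用 T。
2. 源码中 pair 的 < 比较会先比较 key,key 相等时还会比较 value,但红黑树只需要按 key 比较。所以 MyMap 和 MySet 各自实现一个仿函数 KfromT,用于从 T 中取出 key,红黑树再用取出的 key 进行比较。MySet 中 T 本身就是 key,实现 KfromT 只是为了和 MyMap 共用同一套红黑树代码。
3. MyMap 存储 pair<const K, V>,MySet 存储 const K,用 const 保证不能通过迭代器修改 key(修改 key 会破坏红黑树的有序性)。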
// The tree member each wrapper declares (shown here out of class context):
// MyMap stores pair<const K, V>, MySet stores const K; both look elements up by K.
RBTree<K, pair<const K, V>, MapKfromT> _t;
RBTree<K, const K, SetKfromT> _t;
// Tree template parameters: key type K, stored type T, key extractor KfromT.
template<class K, class T, class KfromT>
class RBTree{};
// 源码中 pair 支持的 < 重载
//template <class T1, class T2>
//bool operator<(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs) {
// return lhs.first < rhs.first || (!(rhs.first < lhs.first) && lhs.second < rhs.second);
//}
// Mymap.h
namespace Lzc
{
	// First version of MyMap: a thin wrapper around RBTree that stores
	// pair<const K, V> and supports insertion only.
	template<class K, class V>
	class MyMap
	{
		// Key extractor handed to RBTree: returns the key part of a stored pair.
		// operator() is const so the tree can also call it from const member functions.
		struct MapKfromT
		{
			const K& operator()(const pair<const K, V>& kv) const
			{
				return kv.first;
			}
		};
	public:
		// Inserts kv; returns whatever RBTree::Insert reports (false when the key already exists).
		bool insert(const pair<const K, V>& kv)
		{
			return _t.Insert(kv);
		}
	private:
		RBTree<K, pair<const K, V>, MapKfromT> _t;
	};
}
// Myset.h
namespace Lzc
{
	// First version of MySet: a thin wrapper around RBTree that stores const K
	// and supports insertion only.
	template<class K>
	class MySet
	{
		// Key extractor handed to RBTree: a set element is its own key.
		// operator() is const so the tree can also call it from const member functions.
		struct SetKfromT
		{
			const K& operator()(const K& k) const
			{
				return k;
			}
		};
	public:
		// Inserts k; returns whatever RBTree::Insert reports (false when k already exists).
		bool insert(const K& k)
		{
			return _t.Insert(k);
		}
	private:
		RBTree<K, const K, SetKfromT> _t;
	};
}
// RBTree.h
namespace Lzc
{
// Colour of a red-black tree node.
enum Color
{
	RED,
	BLACK
};

// One tree node: the stored value, the three links and the colour.
// A freshly constructed node is unlinked and RED.
template<class T>
struct RBTreeNode
{
	T _data;
	RBTreeNode<T>* _left = nullptr;
	RBTreeNode<T>* _right = nullptr;
	RBTreeNode<T>* _parent = nullptr;
	Color _col = RED;

	RBTreeNode(const T& data)
		: _data(data)
	{ }
};
template<class K, class T, class KfromT>
class RBTree
{
typedef RBTreeNode<T> Node;
public:
KfromT KfT;
bool Insert(const T& data)
{
if (_root == nullptr)
{
_root = new Node(data);
_root->_col = BLACK;
return true;
}
Node* parent = nullptr;
Node* cur = _root;
while (cur)
{
if (KfT(data) > KfT(cur->_data))
{
parent = cur;
cur = cur->_right;
}
else if (KfT(data) < KfT(cur->_data))
{
parent = cur;
cur = cur->_left;
}
else
{
return false;
}
}
cur = new Node(data);
if (KfT(data) > KfT(parent->_data))
parent->_right = cur;
else
parent->_left = cur;
cur->_parent = parent;
while (parent && parent->_col == RED)
{
Node* grandfather = parent->_parent;
Node* uncle;
if (parent == grandfather->_left)
{
// g
// p u
uncle = grandfather->_right;
if (uncle && uncle->_col == RED)
{
parent->_col = uncle->_col = BLACK;
grandfather->_col = RED;
cur = grandfather;
parent = cur->_parent;
}
else
{
if (cur == parent->_left)
{
RotateR(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
else
{
RotateL(parent);
RotateR(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break;
}
}
else
{
// g
// u p
uncle = grandfather->_left;
if (uncle && uncle->_col == RED)
{
parent->_col = uncle->_col = BLACK;
grandfather->_col = RED;
cur = grandfather;
parent = cur->_parent;
}
else
{
if (cur == parent->_right)
{
RotateL(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
else
{
RotateR(parent);
RotateL(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break;
}
}
}
if (parent == nullptr)
_root->_col = BLACK;
return true;
}
void RotateR(Node* parent)
{
Node* pParent = parent->_parent;
Node* subL = parent->_left;
Node* subLR = subL->_right;
parent->_left = subLR;
if (subLR)
subLR->_parent = parent;
subL->_right = parent;
parent->_parent = subL;
subL->_parent = pParent;
if (pParent == nullptr) // 当pParent == nullptr时,_root == parent
{
_root = subL;
}
else
{
if (pParent->_left == parent)
pParent->_left = subL;
else
pParent->_right = subL;
}
}
void RotateL(Node* parent)
{
Node* pParent = parent->_parent;
Node* subR = parent->_right;
Node* subRL = subR->_left;
parent->_right = subRL;
if (subRL)
subRL->_parent = parent;
subR->_left = parent;
parent->_parent = subR;
subR->_parent = pParent;
if (pParent == nullptr)
_root = subR;
else
{
if (pParent->_left == parent)
pParent->_left = subR;
else
pParent->_right = subR;
}
}
Node* Find(const K& key)
{
Node* cur = _root;
while (cur)
{
if (key > KfT(cur->_data))
cur = cur->_right;
else if (key < KfT(cur->_data))
cur = cur->_left;
else
return cur;
}
return nullptr;
}
~RBTree()
{
Destroy(_root);
_root = nullptr;
}
void Destroy(Node* root)
{
if (root == nullptr)
return;
Destroy(root->_left);
Destroy(root->_right);
delete root;
}
private:
Node* _root = nullptr;
};
}
2.2 iterator的实现
2.2.1 iterator的核心源码
// SGI stores a node's colour as a bool: red == false, black == true.
typedef bool __rb_tree_color_type;
const __rb_tree_color_type __rb_tree_red = false;
const __rb_tree_color_type __rb_tree_black = true;
// Non-template base of SGI's rb-tree iterators: holds the node pointer and
// implements in-order stepping.
// The comments below assume SGI's header layout (header->parent == root,
// root->parent == header, header->right == rightmost node). TODO confirm
// against the full stl_tree.h; that layout is not shown in this excerpt.
struct __rb_tree_base_iterator {
typedef __rb_tree_node_base::base_ptr base_ptr;
base_ptr node;
// ++: move to the in-order successor.
void increment() {
if (node->right != 0) {
// The successor is the leftmost node of the right subtree.
node = node->right;
while (node->left != 0)
node = node->left;
} else {
// Climb while node is a right child; the first ancestor reached from its left is the successor.
base_ptr y = node->parent;
while (node == y->right) {
node = y;
y = y->parent;
}
// Special case: if the root is the rightmost node, the loop stops with node == header
// and y == root. node must then stay on the header (end()).
if (node->right != y)
node = y;
}
}
// --: move to the in-order predecessor.
void decrement() {
// Under the layout above, only the header (end()) is red and is its own grandparent.
// --end() goes to the rightmost node.
if (node->color == __rb_tree_red && node->parent->parent == node) {
node = node->right;
} else if (node->left != 0) {
// The predecessor is the rightmost node of the left subtree.
base_ptr y = node->left;
while (y->right != 0)
y = y->right;
node = y;
} else {
// Climb while node is a left child; the first ancestor reached from its right is the predecessor.
base_ptr y = node->parent;
while (node == y->left) {
node = y;
y = y->parent;
}
node = y;
}
}
};
// Typed SGI rb-tree iterator. Ref/Ptr select iterator (Value&, Value*) or
// const_iterator (const Value&, const Value*) from the same template.
// NOTE(review): link_type and self are used below but are not declared in this excerpt.
template <class Value, class Ref, class Ptr>
struct __rb_tree_iterator : public __rb_tree_base_iterator {
typedef Value value_type;
typedef Ref reference;
typedef Ptr pointer;
typedef __rb_tree_iterator<Value, Value&, Value*> iterator;
__rb_tree_iterator() {}
__rb_tree_iterator(link_type x) { node = x; }
// Copy from iterator. For const_iterator this is the iterator -> const_iterator conversion.
__rb_tree_iterator(const iterator& it) { node = it.node; }
reference operator*() const { return link_type(node)->value_field; }
#ifndef __SGI_STL_NO_ARROW_OPERATOR
pointer operator->() const { return &(operator*()); }
#endif /* __SGI_STL_NO_ARROW_OPERATOR */
self& operator++() {
increment();
return *this;
}
self& operator--() {
decrement();
return *this;
}
// NOTE(review): a two-parameter operator==/operator!= cannot be a member function,
// so these would not compile inside the struct. They are presumably namespace-scope
// functions in the real stl_tree.h and were pasted inside the struct here.
inline bool operator==(const __rb_tree_base_iterator& x, const __rb_tree_base_iterator& y) {
return x.node == y.node;
}
inline bool operator!=(const __rb_tree_base_iterator& x, const __rb_tree_base_iterator& y) {
return x.node != y.node;
}
};
2.2.2 iterator的实现思路
注意:源码中有一个红色的头节点 header 作为 end(),header 的 parent 指向根节点 _root,right 指向最右节点(所以 --end() 可以直接得到最右节点)。我们的实现没有使用头节点,而是用 nullptr 作为 end()。
1. 整体思路与list的iterator一致,封装节点的指针,迭代器类模板多传Ref和Ptr两个参数,一份模板实现iterator和const_iterator。
2. 重点是 operator++ 和 operator-- 的实现。operator++ 按中序遍历(左、根、右)找下一个节点。迭代器当前指向的节点本身已经访问过,它的左子树也已经访问完了,所以:
如果当前节点的右子树不为空,下一个节点就是右子树中的最左节点(在右子树中继续按左、根、右访问)。
如果右子树为空,说明以当前节点为根的子树已经访问完了。这时沿着 parent 向上走:只要当前节点是其父节点的右孩子,就说明以父节点为根的子树也访问完了,继续向上;直到当前节点是其父节点的左孩子,这个父节点就是下一个要访问的节点。如果一直走到根都没有找到(parent 为 nullptr),说明最右节点也已经访问过了,结果为 nullptr,即 end()。
最后更新迭代器中的节点指针,返回 *this。
operator-- 与之对称,按右、根、左的顺序进行。
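3. begin() 返回最左节点,end() 返回 nullptr。但 end() 是 nullptr,--end() 无法从它找到最右节点。
所以给迭代器类模板增加一个成员变量 _root(红黑树的根节点):当 _node 为 nullptr 时,--end() 从 _root 一直向右走,得到最右节点。注意 _root 是创建迭代器时保存的副本,之后的插入如果通过旋转改变了根节点,旧迭代器中的 _root 就不再是真正的根。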
4. const 对象只能得到 const 迭代器,不能通过它修改数据,所以 ConstIterator 的 Ref、Ptr 和节点指针类型都要加上 const:
// Inside RBTree, one RBTreeIterator template yields both iterator kinds.
// ConstIterator uses const T&, const T* and const Node*, so *it cannot be modified.
typedef RBTreeIterator<T, T&, T*, Node*> Iterator;
typedef RBTreeIterator<T, const T&, const T*, const Node*> ConstIterator;
// Bidirectional in-order iterator over RBTree.
// Ref/Ptr/NodePtr select the mutable (T&, T*, Node*) or the const
// (const T&, const T*, const Node*) flavour from this one template.
// end() is represented by _node == nullptr. _root is kept so that --end() can
// find the rightmost node; it is a copy taken when the iterator was created.
template<class T, class Ref, class Ptr, class NodePtr>
struct RBTreeIterator
{
	typedef RBTreeNode<T> Node;
	typedef RBTreeIterator<T, Ref, Ptr, NodePtr> Self;
	typedef RBTreeIterator<T, T&, T*, Node*> Iterator;

	NodePtr _node;  // current node; nullptr means end()
	NodePtr _root;  // tree root when this iterator was created (used by --end())

	RBTreeIterator(NodePtr node, NodePtr root)
		: _node(node)
		, _root(root)
	{
	}

	// For Iterator this is the copy constructor; for ConstIterator it is the
	// implicit Iterator -> ConstIterator conversion.
	RBTreeIterator(const Iterator& it)
		: _node(it._node)
		, _root(it._root)
	{
	}

	// Advance to the in-order successor (nullptr after the largest element).
	Self& operator++()
	{
		assert(_node);  // incrementing end() is not allowed
		if (_node->_right)
		{
			// The successor is the leftmost node of the right subtree.
			NodePtr cur = _node->_right;
			while (cur->_left)
				cur = cur->_left;
			_node = cur;
		}
		else
		{
			// Climb while we are a right child (those ancestors are already visited);
			// the first ancestor reached from its left side is the successor.
			NodePtr cur = _node;
			NodePtr parent = cur->_parent;
			while (parent && cur == parent->_right)
			{
				cur = parent;
				parent = cur->_parent;
			}
			_node = parent;
		}
		return *this;
	}

	Self operator++(int)
	{
		Self tmp(*this);
		++*this;
		return tmp;
	}

	// Step to the in-order predecessor; --end() yields the rightmost node.
	Self& operator--()
	{
		if (_node == nullptr)
		{
			// end() has no node to start from, so walk right from the root
			// (stays nullptr for an empty tree).
			NodePtr mostRight = _root;
			while (mostRight && mostRight->_right)
				mostRight = mostRight->_right;
			_node = mostRight;
		}
		else if (_node->_left)
		{
			// The predecessor is the rightmost node of the left subtree.
			NodePtr cur = _node->_left;
			while (cur->_right)
				cur = cur->_right;
			_node = cur;
		}
		else
		{
			// Climb while we are a left child; the first ancestor reached
			// from its right side is the predecessor.
			NodePtr cur = _node;
			NodePtr parent = cur->_parent;
			while (parent && cur == parent->_left)
			{
				cur = parent;
				parent = cur->_parent;
			}
			_node = parent;
		}
		return *this;
	}

	Self operator--(int)
	{
		Self tmp(*this);
		--*this;
		return tmp;
	}

	// Dereferencing does not modify the iterator, so these are const members.
	Ref operator*() const
	{
		assert(_node);
		return _node->_data;
	}

	Ptr operator->() const
	{
		assert(_node);
		return &(_node->_data);
	}

	bool operator!=(const Self& s) const
	{
		return _node != s._node;
	}

	bool operator==(const Self& s) const
	{
		return _node == s._node;
	}
};
2.3 map支持[ ]
MyMap 要支持 [ ],主要需要修改 Insert 的返回值:把 RBTree 中 Insert 的返回值改为 pair<Iterator, bool>,即 pair<Iterator, bool> Insert(const T& data)。
插入失败(key 已存在)时,返回指向已有节点的迭代器,operator[] 返回该 key 对应的 value 的引用;
插入成功时,返回指向新节点的迭代器,operator[] 返回新插入的 value(默认值 V())的引用。
2.4 MyMap和MySet的代码实现
2.4.1 MyMap.h
#pragma once
#include "RBTree.h"
namespace Lzc
{
	// Ordered key -> value map built on RBTree. Elements are stored as
	// pair<const K, V>, so the key cannot be changed through an iterator.
	template<class K, class V>
	class MyMap
	{
		// Key extractor handed to RBTree: returns the key part of a stored pair.
		// operator() is const so the functor can be invoked from const member
		// functions of RBTree (e.g. the const overload of Find).
		struct MapKfromT
		{
			const K& operator()(const pair<const K, V>& kv) const
			{
				return kv.first;
			}
		};
	public:
		typedef typename RBTree<K, pair<const K, V>, MapKfromT>::Iterator iterator;
		typedef typename RBTree<K, pair<const K, V>, MapKfromT>::ConstIterator const_iterator;

		// Inserts kv unless its key already exists.
		// Returns {iterator to the element with key kv.first, whether kv was inserted}.
		pair<iterator, bool> insert(const pair<const K, V>& kv)
		{
			return _t.Insert(kv);
		}

		// Reference to the value for k; {k, V()} is inserted first if k is absent.
		V& operator[](const K& k)
		{
			iterator ret = _t.Insert({ k, V() }).first;
			return ret->second;
		}

		// Iterator to the element with key, or end() if there is none.
		iterator find(const K& key)
		{
			return _t.Find(key);
		}

		const_iterator find(const K& key) const
		{
			return _t.Find(key);
		}

		iterator begin()
		{
			return _t.Begin();
		}

		iterator end()
		{
			return _t.End();
		}

		const_iterator begin() const
		{
			return _t.Begin();
		}

		const_iterator end() const
		{
			return _t.End();
		}
	private:
		RBTree<K, pair<const K, V>, MapKfromT> _t;
	};
}
2.4.2 MySet.h
#pragma once
#include "RBTree.h"
namespace Lzc
{
	// Ordered set of unique keys built on RBTree. Elements are stored as const K,
	// so they cannot be modified through an iterator.
	template<class K>
	class MySet
	{
		// Key extractor handed to RBTree: a set element is its own key.
		// operator() is const so the functor can be invoked from const member
		// functions of RBTree (e.g. the const overload of Find).
		struct SetKfromT
		{
			const K& operator()(const K& key) const
			{
				return key;
			}
		};
	public:
		typedef typename RBTree<K, const K, SetKfromT>::Iterator iterator;
		typedef typename RBTree<K, const K, SetKfromT>::ConstIterator const_iterator;

		// Inserts key unless already present.
		// Returns {iterator to the element equal to key, whether key was inserted}.
		pair<iterator, bool> insert(const K& key)
		{
			return _t.Insert(key);
		}

		// Iterator to key, or end() if it is not in the set.
		iterator find(const K& key)
		{
			return _t.Find(key);
		}

		const_iterator find(const K& key) const
		{
			return _t.Find(key);
		}

		iterator begin()
		{
			return _t.Begin();
		}

		iterator end()
		{
			return _t.End();
		}

		const_iterator begin() const
		{
			return _t.Begin();
		}

		const_iterator end() const
		{
			return _t.End();
		}
	private:
		RBTree<K, const K, SetKfromT> _t;
	};
}
2.4.3 RBTree.h
#pragma once
#include <iostream>
#include <assert.h>
using namespace std;
namespace Lzc
{
// Colour of a red-black tree node.
enum Color
{
	RED,
	BLACK
};

// One tree node: the stored value, the three links and the colour.
// A freshly constructed node is unlinked and RED.
template<class T>
struct RBTreeNode
{
	T _data;
	RBTreeNode<T>* _left = nullptr;
	RBTreeNode<T>* _right = nullptr;
	RBTreeNode<T>* _parent = nullptr;
	Color _col = RED;

	RBTreeNode(const T& data)
		: _data(data)
	{ }
};
// Bidirectional in-order iterator over RBTree.
// Ref/Ptr/NodePtr select the mutable (T&, T*, Node*) or the const
// (const T&, const T*, const Node*) flavour from this one template.
// end() is represented by _node == nullptr. _root is kept so that --end() can
// find the rightmost node; it is a copy taken when the iterator was created.
template<class T, class Ref, class Ptr, class NodePtr>
struct RBTreeIterator
{
	typedef RBTreeNode<T> Node;
	typedef RBTreeIterator<T, Ref, Ptr, NodePtr> Self;
	typedef RBTreeIterator<T, T&, T*, Node*> Iterator;

	NodePtr _node;  // current node; nullptr means end()
	NodePtr _root;  // tree root when this iterator was created (used by --end())

	RBTreeIterator(NodePtr node, NodePtr root)
		: _node(node)
		, _root(root)
	{
	}

	// For Iterator this is the copy constructor; for ConstIterator it is the
	// implicit Iterator -> ConstIterator conversion.
	RBTreeIterator(const Iterator& it)
		: _node(it._node)
		, _root(it._root)
	{
	}

	// Advance to the in-order successor (nullptr after the largest element).
	Self& operator++()
	{
		assert(_node);  // incrementing end() is not allowed
		if (_node->_right)
		{
			// The successor is the leftmost node of the right subtree.
			NodePtr cur = _node->_right;
			while (cur->_left)
				cur = cur->_left;
			_node = cur;
		}
		else
		{
			// Climb while we are a right child (those ancestors are already visited);
			// the first ancestor reached from its left side is the successor.
			NodePtr cur = _node;
			NodePtr parent = cur->_parent;
			while (parent && cur == parent->_right)
			{
				cur = parent;
				parent = cur->_parent;
			}
			_node = parent;
		}
		return *this;
	}

	Self operator++(int)
	{
		Self tmp(*this);
		++*this;
		return tmp;
	}

	// Step to the in-order predecessor; --end() yields the rightmost node.
	Self& operator--()
	{
		if (_node == nullptr)
		{
			// end() has no node to start from, so walk right from the root
			// (stays nullptr for an empty tree).
			NodePtr mostRight = _root;
			while (mostRight && mostRight->_right)
				mostRight = mostRight->_right;
			_node = mostRight;
		}
		else if (_node->_left)
		{
			// The predecessor is the rightmost node of the left subtree.
			NodePtr cur = _node->_left;
			while (cur->_right)
				cur = cur->_right;
			_node = cur;
		}
		else
		{
			// Climb while we are a left child; the first ancestor reached
			// from its right side is the predecessor.
			NodePtr cur = _node;
			NodePtr parent = cur->_parent;
			while (parent && cur == parent->_left)
			{
				cur = parent;
				parent = cur->_parent;
			}
			_node = parent;
		}
		return *this;
	}

	Self operator--(int)
	{
		Self tmp(*this);
		--*this;
		return tmp;
	}

	// Dereferencing does not modify the iterator, so these are const members.
	Ref operator*() const
	{
		assert(_node);  // end() cannot be dereferenced
		return _node->_data;
	}

	Ptr operator->() const
	{
		assert(_node);
		return &(_node->_data);
	}

	bool operator!=(const Self& s) const
	{
		return _node != s._node;
	}

	bool operator==(const Self& s) const
	{
		return _node == s._node;
	}
};
template<class K, class T, class KfromT>
class RBTree
{
typedef RBTreeNode<T> Node;
public:
typedef RBTreeIterator<T, T&, T*, Node*> Iterator;
typedef RBTreeIterator<T, const T&, const T*, const Node*> ConstIterator;
RBTree()
:_root(nullptr)
{ }
RBTree(const RBTree& rbt)
{
_root = Copy(rbt._root, nullptr);
}
RBTree& operator=(const RBTree& rbt)
{
if (this != &rbt)
{
RBTree tmp(rbt);
swap(_root, tmp._root);
}
return *this;
}
~RBTree()
{
Destroy(_root);
_root = nullptr;
}
Iterator Begin()
{
Node* cur = _root;
while (cur && cur->_left)
{
cur = cur->_left;
}
return Iterator(cur, _root);
}
Iterator End()
{
return Iterator(nullptr, _root);
}
ConstIterator Begin() const
{
const Node* cur = _root;
while (cur && cur->_left)
{
cur = cur->_left;
}
return ConstIterator(cur, _root);
}
ConstIterator End() const
{
return ConstIterator(nullptr, _root);
}
KfromT KfT;
pair<Iterator, bool> Insert(const T& data)
{
if (_root == nullptr)
{
_root = new Node(data);
_root->_col = BLACK;
return { Iterator(_root,_root),true };
}
Node* parent = nullptr;
Node* cur = _root;
while (cur)
{
if (KfT(data) > KfT(cur->_data))
{
parent = cur;
cur = cur->_right;
}
else if (KfT(data) < KfT(cur->_data))
{
parent = cur;
cur = cur->_left;
}
else
{
return { Iterator(cur,_root),false };
}
}
cur = new Node(data);
Node* newnode = cur; // cur可能后面会更新
if (KfT(data) > KfT(parent->_data))
parent->_right = cur;
else
parent->_left = cur;
cur->_parent = parent;
while (parent && parent->_col == RED)
{
Node* grandfather = parent->_parent;
Node* uncle;
if (parent == grandfather->_left)
{
// g
// p u
uncle = grandfather->_right;
if (uncle && uncle->_col == RED)
{
parent->_col = uncle->_col = BLACK;
grandfather->_col = RED;
cur = grandfather;
parent = cur->_parent;
}
else
{
if (cur == parent->_left)
{
RotateR(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
else
{
RotateL(parent);
RotateR(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break;
}
}
else
{
// g
// u p
uncle = grandfather->_left;
if (uncle && uncle->_col == RED)
{
parent->_col = uncle->_col = BLACK;
grandfather->_col = RED;
cur = grandfather;
parent = cur->_parent;
}
else
{
if (cur == parent->_right)
{
RotateL(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
else
{
RotateR(parent);
RotateL(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break;
}
}
}
if (parent == nullptr)
_root->_col = BLACK;
return { Iterator(newnode,_root),true };
}
void RotateR(Node* parent)
{
Node* pParent = parent->_parent;
Node* subL = parent->_left;
Node* subLR = subL->_right;
parent->_left = subLR;
if (subLR)
subLR->_parent = parent;
subL->_right = parent;
parent->_parent = subL;
subL->_parent = pParent;
if (pParent == nullptr) // 当pParent == nullptr时,_root == parent
{
_root = subL;
}
else
{
if (pParent->_left == parent)
pParent->_left = subL;
else
pParent->_right = subL;
}
}
void RotateL(Node* parent)
{
Node* pParent = parent->_parent;
Node* subR = parent->_right;
Node* subRL = subR->_left;
parent->_right = subRL;
if (subRL)
subRL->_parent = parent;
subR->_left = parent;
parent->_parent = subR;
subR->_parent = pParent;
if (pParent == nullptr)
_root = subR;
else
{
if (pParent->_left == parent)
pParent->_left = subR;
else
pParent->_right = subR;
}
}
Iterator Find(const K& key)
{
Node* cur = _root;
while (cur)
{
if (key > KfT(cur->_data))
cur = cur->_right;
else if (key < KfT(cur->_data))
cur = cur->_left;
else
return Iterator(cur, _root);
}
return Iterator(nullptr, _root);
}
ConstIterator Find(const K& key) const
{
const Node* cur = _root;
while (cur)
{
if (key > KfT(cur->_data))
cur = cur->_right;
else if (key < KfT(cur->_data))
cur = cur->_left;
else
return ConstIterator(cur, _root);
}
return ConstIterator(nullptr, _root);
}
private:
Node* Copy(Node* root,Node* parent) // 后序
{
if (root == nullptr)
return nullptr;
Node* newNode = new Node(root->_data);
newNode->_col = root->_col;
newNode->_left = Copy(root->_left,newNode);
newNode->_right = Copy(root->_right,newNode);
newNode->_parent = parent;
return newNode;
}
void Destroy(Node* root)
{
if (root == nullptr)
return;
Destroy(root->_left);
Destroy(root->_right);
delete root;
}
Node* _root;
};
}
2.4.4 Test.cpp
#include "RBTree.h"
#include "MySet.h"
#include "MyMap.h"
// Identity key extractor for testing RBTree<int, int, ...> directly.
// operator() is const so the tree can also call it from const member functions.
struct SetKfromT
{
	const int& operator()(const int& key) const
	{
		return key;
	}
};
void TestIterators() {
// 1. 测试普通迭代器
Lzc::RBTree<int, int, SetKfromT> tree;
tree.Insert(10);
tree.Insert(20);
tree.Insert(5);
tree.Insert(15);
cout << "Testing regular iterator:" << endl;
for (auto it = tree.Begin(); it != tree.End(); ++it) {
cout << *it << " "; // 应输出:5 10 15 20
*it += 1; // 测试通过普通迭代器修改元素
}
cout << endl;
// 验证修改后的值
cout << "After modification:" << endl;
for (auto it = tree.Begin(); it != tree.End(); ++it) {
cout << *it << " "; // 应输出:6 11 16 21
}
cout << endl;
// 2. 测试const迭代器
const auto& constTree = tree;
cout << "Testing const iterator:" << endl;
for (auto it = constTree.Begin(); it != constTree.End(); ++it) {
cout << *it << " "; // 应输出:6 11 16 21
// *it += 1; // 这行应该编译失败,验证const迭代器不可修改
}
cout << endl;
// 3. 测试iterator转const_iterator的兼容性
cout << "Testing iterator to const_iterator conversion:" << endl;
auto it = tree.Begin();
Lzc::RBTree<int, int, SetKfromT>::ConstIterator cit = it;
cout << *cit << endl; // 应输出第一个元素6
assert(*cit == *it);
// 4. 测试反向遍历
cout << "Testing reverse traversal using --:" << endl;
auto rit = tree.End();
--rit; // 移动到最后一个元素
for (; rit != tree.Begin(); --rit) {
cout << *rit << " "; // 应输出:21 16 11
}
cout << *rit << endl; // 输出最后一个元素6
cout << "All iterator tests passed!" << endl;
}
void TestMapBasic()
{
Lzc::MyMap<int, string> map;
map.insert({ 10, "A" });
map.insert({ 20, "B" });
map.insert({ 5, "C" });
map.insert({ 15, "D" });
// 测试查找
assert(map.find(10) != map.end());
assert(map.find(20) != map.end());
assert(map.find(5) != map.end());
assert(map.find(15) != map.end());
assert(map.find(100) == map.end()); // 不存在的 key
// 测试 operator[]
map[30] = "E"; // 插入新键
assert(map.find(30) != map.end());
assert(map[30] == "E");
map[10] = "AA"; // 修改已有键
assert(map[10] == "AA");
// 测试遍历(中序输出)
cout << "Map InOrder Traversal:" << endl;
for (auto it = map.begin(); it != map.end(); ++it)
{
cout << it->first << ":" << it->second << " ";
}
cout << endl; // 应输出:5:C 10:AA 15:D 20:B 30:E
cout << "TestMapBasic passed!" << endl;
}
void TestSetBasic()
{
Lzc::MySet<int> set;
set.insert(10);
set.insert(20);
set.insert(5);
set.insert(15);
// 测试查找
assert(set.find(10) != set.end());
assert(set.find(20) != set.end());
assert(set.find(5) != set.end());
assert(set.find(15) != set.end());
assert(set.find(100) == set.end()); // 不存在的 key
// 测试遍历(中序输出)
cout << "Set InOrder Traversal:" << endl;
for (auto it = set.begin(); it != set.end(); ++it)
{
cout << *it << " ";
}
cout << endl; // 应输出:5 10 15 20
cout << "TestSetBasic passed!" << endl;
}
void TestIteratorDecrement()
{
Lzc::MyMap<int, string> map;
map.insert({ 10, "A" });
map.insert({ 20, "B" });
map.insert({ 5, "C" });
map.insert({ 15, "D" });
// 测试 -- 操作
auto it = map.find(20);
assert(it != map.end());
--it; // 应该指向 15
assert(it->first == 15);
--it; // 应该指向 10
assert(it->first == 10);
--it; // 应该指向 5
assert(it->first == 5);
--it; // 应该等于 begin(),再 -- 会未定义行为(不测试)
// 测试 --end()
it = map.end();
--it; // 应该指向最大的元素 20
assert(it->first == 20);
cout << "TestIteratorDecrement passed!" << endl;
}
void TestCopyAndAssignment()
{
Lzc::MyMap<int, string> map1;
map1.insert({ 10, "A" });
map1.insert({ 20, "B" });
map1.insert({ 5, "C" });
// 测试拷贝构造
Lzc::MyMap<int, string> map2(map1);
assert(map2.find(10) != map2.end());
assert(map2.find(20) != map2.end());
assert(map2.find(5) != map2.end());
// 测试赋值运算符
Lzc::MyMap<int, string> map3;
map3 = map1;
assert(map3.find(10) != map3.end());
assert(map3.find(20) != map3.end());
assert(map3.find(5) != map3.end());
// 修改原 map,确保深拷贝
map1.insert({ 30, "D" });
assert(map2.find(30) == map2.end()); // map2 不应受影响
assert(map3.find(30) == map3.end()); // map3 不应受影响
cout << "TestCopyAndAssignment passed!" << endl;
}
void TestEmptyContainer()
{
Lzc::MyMap<int, string> emptyMap;
assert(emptyMap.begin() == emptyMap.end()); // 空 map 的 begin == end
Lzc::MySet<int> emptySet;
assert(emptySet.begin() == emptySet.end()); // 空 set 的 begin == end
cout << "TestEmptyContainer passed!" << endl;
}
#include <random>
// Stress test: N random insertions through operator[], then verify that the
// in-order traversal is strictly increasing with every value equal to its key,
// and spot-check random lookups. The seed is printed so that a failing run
// can be reproduced by hard-coding it.
void TestRandomData()
{
	Lzc::MyMap<int, int> map;
	const int N = 10000;

	random_device rd;
	const unsigned int seed = rd();
	cout << "TestRandomData seed: " << seed << endl;
	mt19937 gen(seed);
	uniform_int_distribution<> dis(1, 100000);
	for (int i = 0; i < N; ++i)
	{
		int key = dis(gen);
		map[key] = key; // uses operator[]
	}

	// A correct search tree yields strictly increasing keys in order.
	bool first = true;
	int prev = 0;
	for (auto it = map.begin(); it != map.end(); ++it)
	{
		assert(first || prev < it->first);
		assert(it->second == it->first);
		prev = it->first;
		first = false;
	}

	// Random lookups: any key that is found must map to itself.
	for (int i = 0; i < 100; ++i)
	{
		int key = dis(gen);
		auto it = map.find(key);
		if (it != map.end())
		{
			assert(it->second == key);
		}
	}
	cout << "TestRandomData passed!" << endl;
}
// Runs every test in sequence; failures are reported through assert.
int main()
{
	void (*const tests[])() = {
		TestIterators,
		TestMapBasic,
		TestSetBasic,
		TestIteratorDecrement,
		TestCopyAndAssignment,
		TestEmptyContainer,
		TestRandomData,
	};
	for (auto run : tests)
		run();
	cout << "All tests passed!" << endl;
	return 0;
}