C++进阶——封装红黑树实现map和set

目录

1、源码及框架分析

2、模拟实现map和set

2.1 复用的红黑树框架及Insert

2.2 iterator的实现

2.2.1 iterator的核心源码

2.2.2 iterator的实现思路

2.3 map支持[ ]

2.4 map和set的代码实现

2.4.1 MyMap.h

2.4.2 MySet.h

2.4.3 RBTree.h

2.4.4 Test.cpp


1、源码及框架分析

在SGI-STL30版本的源代码中,map和set的源代码位于map、set、stl_map.h、stl_set.h、stl_tree.h等几个头文件中。map和set的实现结构框架核心部分截取出来如下:

// set
#ifndef __SGI_STL_INTERNAL_TREE_H
#include <stl_tree.h>
#endif
#include <stl_set.h>
#include <stl_multiset.h>

// map
#ifndef __SGI_STL_INTERNAL_TREE_H
#include <stl_tree.h>
#endif
#include <stl_map.h>
#include <stl_multimap.h>

// stl_set.h
// Excerpt of the SGI-STL set: a wrapper around rb_tree in which the stored
// value type is the key type itself.
template <class Key, class Compare = less<Key>, class Alloc = alloc>
class set {
public:
    // typedefs:
    typedef Key key_type;
    typedef Key value_type;  // for set, key_type and value_type are both Key

private:
    // identity<value_type> is passed in rb_tree's third (KeyOfValue) slot.
    // NOTE(review): key_compare is not declared in this excerpt; presumably a
    // typedef of Compare in the full header -- confirm.
    typedef rb_tree<key_type, value_type,
                    identity<value_type>, key_compare, Alloc> rep_type;
    rep_type t;  // red-black tree representing set
};

// stl_map.h
// Excerpt of the SGI-STL map: a wrapper around rb_tree whose stored value is
// a pair<const Key, T>.
template <class Key, class T, class Compare = less<Key>,
 class Alloc = alloc>
class map {
public:
    // typedefs:
    typedef Key key_type;
    typedef T mapped_type;
    typedef pair<const Key, T> value_type;  // const Key: the key part is read-only

private:
    // select1st<value_type> is passed in rb_tree's third (KeyOfValue) slot.
    // NOTE(review): key_compare is not declared in this excerpt; presumably a
    // typedef of Compare in the full header -- confirm.
    typedef rb_tree<key_type, value_type,
                    select1st<value_type>, key_compare, Alloc> rep_type;
    rep_type t;  // red-black tree representing map
};

// stl_tree.h
// Value-independent part of a tree node: a colour plus parent/left/right links.
struct __rb_tree_node_base {
    typedef __rb_tree_color_type color_type;
    typedef __rb_tree_node_base* base_ptr;
    color_type color;
    base_ptr parent;
    base_ptr left;
    base_ptr right;
};

// stl_tree.h
// Excerpt of the tree shared by set and map. Key is the type used by erase and
// find, Value is the type stored in each node and accepted by insert_unique.
// KeyOfValue and Compare are the third and fourth template arguments supplied
// by set/map above.
template <class Key, class Value, class KeyOfValue, class Compare,
 class Alloc = alloc>
class rb_tree {
protected:
    typedef void* void_pointer;
    typedef __rb_tree_node_base* base_ptr;
    typedef __rb_tree_node<Value> rb_tree_node;
    typedef rb_tree_node* link_type;
    typedef Key key_type;
    typedef Value value_type;

public:
    // insert
    // NOTE(review): iterator and size_type are not declared in this excerpt.
    pair<iterator, bool> insert_unique(const value_type& x);

    // erase and find
    size_type erase(const key_type& x);
    iterator find(const key_type& x);

protected:
    size_type node_count; // keeps track of size of tree
    link_type header;
};

// Full tree node: the link/colour base plus the stored value.
template <class Value>
struct __rb_tree_node : public __rb_tree_node_base {
    typedef __rb_tree_node<Value>* link_type;
    Value value_field;  // the element held by this node
};

template <class Key, class Value, class KeyOfValue, class Compare,class Alloc = alloc>

在rb_tree的模板参数中:插入时使用Value,删除和查找时使用Key;KeyOfValue是一个仿函数,用来从Value中取出Key值。

2、模拟实现map和set

2.1 复用的红黑树框架及Insert

1. 这里相比源码调整一下:key参数就用K,value参数就用V,红黑树中的数据类型我们使用T。

2. 源码中pair的 < 比较同时比较了key和value,但是红黑树只需要比较key,所以MyMap和MySet各自实现了一个从T中取出key的仿函数KfromT,红黑树只用取出的key做比较。MySet的T本身就是key,它实现KfromT只是为了和MyMap复用同一份红黑树模板。

3. const保证了不能修改key

RBTree<K, pair<const K, V>, MapKfromT> _t;

RBTree<K, const K, SetKfromT> _t;

template<class K, class T, class KfromT>

class RBTree{};

// 源码中 pair 支持的 < 重载
//template <class T1, class T2>
//bool operator<(const pair<T1, T2>& lhs, const pair<T1, T2>& rhs) {
//    return lhs.first < rhs.first || (!(rhs.first < lhs.first) && lhs.second < rhs.second);
//}

// Mymap.h
namespace Lzc
{
	/// Minimal map front-end over RBTree: stores pair<const K, V> elements and
	/// forwards insertion to the tree.
	template<class K, class V>
	class MyMap
	{
		/// Key extractor given to RBTree: returns the key part of a stored pair.
		/// The call operator is const so it can also be invoked on a const
		/// functor object.
		struct MapKfromT
		{
			const K& operator()(const pair<const K, V>& kv) const
			{
				return kv.first;
			}
		};

	public:
		/// Inserts kv and returns whatever the tree's Insert reports.
		bool insert(const pair<const K, V>& kv)
		{
			return _t.Insert(kv);
		}

	private:
		RBTree<K, pair<const K, V>, MapKfromT> _t;
	};
}

// Myset.h
namespace Lzc
{
	/// Minimal set front-end over RBTree: the stored element is the key itself
	/// (held as const K).
	template<class K>
	class MySet
	{
		/// Key extractor given to RBTree: an element of a set is its own key.
		/// The call operator is const so it can also be invoked on a const
		/// functor object.
		struct SetKfromT
		{
			const K& operator()(const K& k) const
			{
				return k;
			}
		};

	public:
		/// Inserts k and returns whatever the tree's Insert reports.
		bool insert(const K& k)
		{
			return _t.Insert(k);
		}

	private:
		RBTree<K, const K, SetKfromT> _t;
	};
}

// RBTree.h
namespace Lzc
{
	/// Colour tag carried by every red-black tree node.
	enum Color
	{
		RED = 0,
		BLACK = 1
	};

	/// One tree node: the stored element plus child/parent links and a colour.
	template<class T>
	struct RBTreeNode
	{
		T _data;
		RBTreeNode<T>* _left = nullptr;
		RBTreeNode<T>* _right = nullptr;
		RBTreeNode<T>* _parent = nullptr;
		Color _col = RED; // a freshly built node starts out red

		/// Builds an unlinked red node holding a copy of data.
		RBTreeNode(const T& data)
			: _data(data)
		{ }
	};

	template<class K, class T, class KfromT>
	class RBTree
	{
		typedef RBTreeNode<T> Node;
	public:
		KfromT KfT;
		bool Insert(const T& data)
		{
			if (_root == nullptr)
			{
				_root = new Node(data);
				_root->_col = BLACK;
				return true;
			}

			Node* parent = nullptr;
			Node* cur = _root;
			while (cur)
			{
				if (KfT(data) > KfT(cur->_data))
				{
					parent = cur;
					cur = cur->_right;
				}
				else if (KfT(data) < KfT(cur->_data))
				{
					parent = cur;
					cur = cur->_left;
				}
				else
				{
					return false;
				}
			}

			cur = new Node(data);
			if (KfT(data) > KfT(parent->_data))
				parent->_right = cur;
			else
				parent->_left = cur;
			cur->_parent = parent;

			while (parent && parent->_col == RED)
			{
				Node* grandfather = parent->_parent;
				Node* uncle;
				if (parent == grandfather->_left)
				{
					//    g
					//  p   u
					uncle = grandfather->_right;
					if (uncle && uncle->_col == RED)
					{
						parent->_col = uncle->_col = BLACK;
						grandfather->_col = RED;

						cur = grandfather;
						parent = cur->_parent;
					}
					else
					{
						if (cur == parent->_left)
						{
							RotateR(grandfather);
							parent->_col = BLACK;
							grandfather->_col = RED;
						}
						else
						{
							RotateL(parent);
							RotateR(grandfather);
							cur->_col = BLACK;
							grandfather->_col = RED;
						}

						break;
					}
				}
				else
				{
					//    g
					//  u   p
					uncle = grandfather->_left;
					if (uncle && uncle->_col == RED)
					{
						parent->_col = uncle->_col = BLACK;
						grandfather->_col = RED;

						cur = grandfather;
						parent = cur->_parent;
					}
					else
					{
						if (cur == parent->_right)
						{
							RotateL(grandfather);
							parent->_col = BLACK;
							grandfather->_col = RED;
						}
						else
						{
							RotateR(parent);
							RotateL(grandfather);
							cur->_col = BLACK;
							grandfather->_col = RED;
						}

						break;
					}
				}
			}

			if (parent == nullptr)
				_root->_col = BLACK;

			return true;
		}

		void RotateR(Node* parent)
		{
			Node* pParent = parent->_parent;
			Node* subL = parent->_left;
			Node* subLR = subL->_right;

			parent->_left = subLR;
			if (subLR)
				subLR->_parent = parent;

			subL->_right = parent;
			parent->_parent = subL;
			subL->_parent = pParent;
			if (pParent == nullptr) // 当pParent == nullptr时,_root == parent
			{
				_root = subL;
			}
			else
			{
				if (pParent->_left == parent)
					pParent->_left = subL;
				else
					pParent->_right = subL;
			}
		}

		void RotateL(Node* parent)
		{
			Node* pParent = parent->_parent;
			Node* subR = parent->_right;
			Node* subRL = subR->_left;

			parent->_right = subRL;
			if (subRL)
				subRL->_parent = parent;

			subR->_left = parent;
			parent->_parent = subR;
			subR->_parent = pParent;
			if (pParent == nullptr)
				_root = subR;
			else
			{
				if (pParent->_left == parent)
					pParent->_left = subR;
				else
					pParent->_right = subR;
			}
		}

		Node* Find(const K& key)
		{
			Node* cur = _root;
			while (cur)
			{
				if (key > KfT(cur->_data))
					cur = cur->_right;
				else if (key < KfT(cur->_data))
					cur = cur->_left;
				else
					return cur;
			}
			return nullptr;
		}

		~RBTree()
		{
			Destroy(_root);
			_root = nullptr;
		}

		void Destroy(Node* root)
		{
			if (root == nullptr)
				return;
			Destroy(root->_left);
			Destroy(root->_right);
			delete root;
		}

	private:
		Node* _root = nullptr;
	};
}

2.2 iterator的实现

2.2.1 iterator的核心源码
// Node colour is stored as a bool: false means red, true means black.
typedef bool __rb_tree_color_type;
const __rb_tree_color_type __rb_tree_red = false;
const __rb_tree_color_type __rb_tree_black = true;

// Value-independent iterator core: holds a node pointer and knows how to step
// to the in-order successor (increment) or predecessor (decrement).
struct __rb_tree_base_iterator {
    typedef __rb_tree_node_base::base_ptr base_ptr;
    base_ptr node;  // node the iterator currently refers to

    // Step to the in-order successor.
    void increment() {
        if (node->right != 0) {
            // Successor is the left-most node of the right subtree.
            node = node->right;
            while (node->left != 0)
                node = node->left;
        } else {
            // No right subtree: climb while we are our parent's right child.
            base_ptr y = node->parent;
            while (node == y->right) {
                node = y;
                y = y->parent;
            }
            // NOTE(review): this guard presumably handles the climb ending at
            // the header node -- confirm against the full stl_tree.h.
            if (node->right != y)
                node = y;
        }
    }

    // Step to the in-order predecessor.
    void decrement() {
        // NOTE(review): a red node that is its own grandparent is presumably
        // the header node (i.e. end()); stepping back from it goes to
        // node->right -- confirm against the full stl_tree.h.
        if (node->color == __rb_tree_red && node->parent->parent == node) {
            node = node->right;
        } else if (node->left != 0) {
            // Predecessor is the right-most node of the left subtree.
            base_ptr y = node->left;
            while (y->right != 0)
                y = y->right;
            node = y;
        } else {
            // No left subtree: climb while we are our parent's left child.
            base_ptr y = node->parent;
            while (node == y->left) {
                node = y;
                y = y->parent;
            }
            node = y;
        }
    }
};

// Typed iterator layered on __rb_tree_base_iterator. Ref and Ptr select the
// reference/pointer types returned by operator* and operator->.
// NOTE(review): link_type and self are used below but not declared in this
// excerpt; they are presumably typedefs in the full header -- confirm.
template <class Value, class Ref, class Ptr>
struct __rb_tree_iterator : public __rb_tree_base_iterator {
    typedef Value value_type;
    typedef Ref reference;
    typedef Ptr pointer;
    typedef __rb_tree_iterator<Value, Value&, Value*> iterator;

    __rb_tree_iterator() {}
    __rb_tree_iterator(link_type x) { node = x; }
    // Construct from the mutable iterator type.
    __rb_tree_iterator(const iterator& it) { node = it.node; }

    // Access the value stored in the current node.
    reference operator*() const { return link_type(node)->value_field; }

#ifndef __SGI_STL_NO_ARROW_OPERATOR
    pointer operator->() const { return &(operator*()); }
#endif /* __SGI_STL_NO_ARROW_OPERATOR */

    self& operator++() {
        increment();
        return *this;
    }

    self& operator--() {
        decrement();
        return *this;
    }

    // Iterators compare equal when they refer to the same node.
    // NOTE(review): these two-argument operators are shown inside the struct
    // here; two-parameter comparison operators are normally namespace-scope
    // functions -- check their placement in the full stl_tree.h.
    inline bool operator==(const __rb_tree_base_iterator& x, const __rb_tree_base_iterator& y) {
        return x.node == y.node;
    }

    inline bool operator!=(const __rb_tree_base_iterator& x, const __rb_tree_base_iterator& y) {
        return x.node != y.node;
    }
};
2.2.2 iterator的实现思路

注意:源码有一个红色的头节点(作为end())指向_root;我们没有使用头节点,而是用nullptr做end()。

1. 整体思路与list的iterator一致:封装节点的指针;迭代器类模板多传Ref和Ptr两个参数,用一份模板同时实现iterator和const_iterator。

2. 重点是operator++和operator--的实现。operator++按中序遍历(左、中、右)的顺序走到下一个节点。

当左为空,表示左访问完了,访问中(其实只能访问中,给的节点就是访问完的中节点),

如果右不为空,(在右子树中进行,左中右,访问右子树的最左节点)。

如果右为空,(那么这个子树已经访问完了,如果这个子树是外面的右子树,那么外面一层的子树也访问完了,直到子树是外面一层的左子树,这时左子树访问完了,访问中。即使走到了最右节点,向上走,parent为nullptr,就返回nullptr)。

然后更新迭代器中的节点指针,返回*this。

operator--就是走右中左,基本相同。

3. begin()和end():begin()就给最左节点,end()给nullptr。但是,--end()呢?

所以给迭代器类模板增加一个成员变量_root(红黑树的根节点),--end()就可以是最右节点

4. const对象使用const迭代器,const迭代器不能修改数据,所以:

        typedef RBTreeIterator<T, T&, T*, Node*> Iterator;
        typedef RBTreeIterator<T, const T&, const T*, const Node*> ConstIterator;

	/// Bidirectional in-order iterator over the red-black tree.
	///
	/// @tparam T       element type stored in each node
	/// @tparam Ref     type returned by operator* (T& or const T&)
	/// @tparam Ptr     type returned by operator-> (T* or const T*)
	/// @tparam NodePtr node pointer type (Node* or const Node*)
	///
	/// end() is represented by _node == nullptr. _root is stored only so that
	/// --end() can find the right-most node.
	/// NOTE(review): _root is a copy taken when the iterator is created; it is
	/// not refreshed if the tree's root changes later -- confirm callers do not
	/// decrement an end() iterator obtained before a later insertion.
	template<class T, class Ref, class Ptr,class NodePtr>
	struct RBTreeIterator
	{
		typedef RBTreeNode<T> Node;
		typedef RBTreeIterator<T, Ref, Ptr, NodePtr> Self;
		typedef RBTreeIterator<T, T&, T*, Node*> Iterator;

		NodePtr _node; // current position; nullptr means end()
		NodePtr _root; // root of the tree, used only by --end()

		RBTreeIterator(NodePtr node, NodePtr root)
			:_node(node)
			, _root(root)
		{
		}

		// Acts as the copy constructor when Self is Iterator, and as the
		// Iterator -> ConstIterator conversion when Self is ConstIterator.
		RBTreeIterator(const Iterator& it)
			:_node(it._node)
			, _root(it._root)
		{
		}

		/// Moves to the in-order successor; becomes end() after the last node.
		Self& operator++()
		{
			// end() has no successor; fail loudly instead of dereferencing
			// nullptr, consistent with operator* and operator->.
			assert(_node);

			if (_node->_right)
			{
				// Successor is the left-most node of the right subtree.
				NodePtr cur = _node->_right;
				while (cur->_left)
				{
					cur = cur->_left;
				}
				_node = cur;
			}
			else
			{
				// Climb until we leave a left subtree; that ancestor is next.
				// Running off the top (parent == nullptr) yields end().
				NodePtr cur = _node;
				NodePtr parent = cur->_parent;
				while (parent && cur == parent->_right)
				{
					cur = parent;
					parent = cur->_parent;
				}
				_node = parent;
			}

			return *this;
		}

		/// Moves to the in-order predecessor; --end() yields the right-most node.
		Self& operator--()
		{
			// --end(): end() is nullptr, so the right-most node is found
			// by walking down from _root.
			if (_node == nullptr)
			{
				NodePtr MostRight = _root;
				while (MostRight && MostRight->_right)
				{
					MostRight = MostRight->_right;
				}
				_node = MostRight;
			}
			else if (_node->_left)
			{
				// Predecessor is the right-most node of the left subtree.
				NodePtr cur = _node->_left;
				while (cur->_right)
				{
					cur = cur->_right;
				}
				_node = cur;
			}
			else
			{
				// Climb until we leave a right subtree.
				NodePtr cur = _node;
				NodePtr parent = cur->_parent;
				while (parent && cur == parent->_left)
				{
					cur = parent;
					parent = cur->_parent;
				}
				_node = parent;
			}

			return *this;
		}

		Ref operator*() const
		{
			assert(_node);
			return _node->_data;
		}

		Ptr operator->() const
		{
			assert(_node);
			return &(_node->_data);
		}

		// Only the node is compared; _root does not take part in equality.
		bool operator!=(const Self& s) const
		{
			return _node != s._node;
		}

		bool operator==(const Self& s) const
		{
			return _node == s._node;
		}
	};

2.3 map支持[ ]

MyMap支持[ ]主要需要修改Insert返回值

修改RBtree中的insert返回值为pair<Iterator,bool> Insert(const T& data),

插入失败,就返回已存在的相同key所对应的value的引用

插入成功,就返回新插入的key所对应的value(默认值)的引用

2.4 MyMap和MySet的代码实现

2.4.1 MyMap.h
#pragma once
#include "RBTree.h"

namespace Lzc
{
	/// Ordered associative container mapping unique keys of type K to values of
	/// type V, implemented on top of RBTree<K, pair<const K, V>, MapKfromT>.
	template<class K, class V>
	class MyMap
	{
		/// Key extractor given to RBTree: returns the key part of a stored pair.
		/// The call operator must be const: the tree's const Find invokes the
		/// functor through a const tree, so a non-const operator() makes
		/// find() const fail to compile as soon as it is used.
		struct MapKfromT
		{
			const K& operator()(const pair<const K, V>& kv) const
			{
				return kv.first;
			}
		};

	public:
		typedef typename RBTree<K, pair<const K, V>, MapKfromT>::Iterator iterator;
		typedef typename RBTree<K, pair<const K, V>, MapKfromT>::ConstIterator const_iterator;

		/// Inserts kv if its key is not present yet.
		/// @return iterator to the element with that key, and whether a new
		///         element was inserted.
		pair<iterator, bool> insert(const pair<const K, V>& kv)
		{
			return _t.Insert(kv);
		}

		/// Returns a reference to the value mapped to k, first inserting a
		/// value-initialized V() if k is absent.
		V& operator[](const K& k)
		{
			iterator ret = _t.Insert({ k, V() }).first;
			return ret->second;
		}

		/// @return iterator to the element with the given key, or end().
		iterator find(const K& key)
		{
			return _t.Find(key);
		}

		const_iterator find(const K& key) const
		{
			return _t.Find(key);
		}

		iterator begin()
		{
			return _t.Begin();
		}

		iterator end()
		{
			return _t.End();
		}

		const_iterator begin() const
		{
			return _t.Begin();
		}

		const_iterator end() const
		{
			return _t.End();
		}

	private:
		RBTree<K, pair<const K, V>, MapKfromT> _t;
	};
}
2.4.2 MySet.h
#pragma once

#include "RBTree.h"

namespace Lzc
{
	/// Ordered container of unique keys of type K, implemented on top of
	/// RBTree<K, const K, SetKfromT>. Elements are stored as const K, so they
	/// cannot be modified through an iterator.
	template<class K>
	class MySet
	{
		/// Key extractor given to RBTree: an element of a set is its own key.
		/// The call operator must be const: the tree's const Find invokes the
		/// functor through a const tree, so a non-const operator() makes
		/// find() const fail to compile as soon as it is used.
		struct SetKfromT
		{
			const K& operator()(const K& key) const
			{
				return key;
			}
		};

	public:
		typedef typename RBTree<K, const K, SetKfromT>::Iterator iterator;
		typedef typename RBTree<K, const K, SetKfromT>::ConstIterator const_iterator;

		/// Inserts key if it is not present yet.
		/// @return iterator to the element with that key, and whether a new
		///         element was inserted.
		pair<iterator, bool> insert(const K& key)
		{
			return _t.Insert(key);
		}

		/// @return iterator to the element equal to key, or end().
		iterator find(const K& key)
		{
			return _t.Find(key);
		}

		const_iterator find(const K& key) const
		{
			return _t.Find(key);
		}

		iterator begin()
		{
			return _t.Begin();
		}

		iterator end()
		{
			return _t.End();
		}

		const_iterator begin() const
		{
			return _t.Begin();
		}

		const_iterator end() const
		{
			return _t.End();
		}

	private:
		RBTree<K, const K, SetKfromT> _t;
	};
}
2.4.3 RBTree.h
#pragma once

#include <iostream>
#include <assert.h>

using namespace std;

namespace Lzc
{
	/// Colour tag carried by every red-black tree node.
	enum Color
	{
		RED = 0,
		BLACK = 1
	};

	/// One tree node: the stored element plus child/parent links and a colour.
	template<class T>
	struct RBTreeNode
	{
		T _data;
		RBTreeNode<T>* _left = nullptr;
		RBTreeNode<T>* _right = nullptr;
		RBTreeNode<T>* _parent = nullptr;
		Color _col = RED; // a freshly built node starts out red

		/// Builds an unlinked red node holding a copy of data.
		RBTreeNode(const T& data)
			: _data(data)
		{ }
	};

	/// Bidirectional in-order iterator over the red-black tree.
	///
	/// @tparam T       element type stored in each node
	/// @tparam Ref     type returned by operator* (T& or const T&)
	/// @tparam Ptr     type returned by operator-> (T* or const T*)
	/// @tparam NodePtr node pointer type (Node* or const Node*)
	///
	/// end() is represented by _node == nullptr. _root is stored only so that
	/// --end() can find the right-most node.
	/// NOTE(review): _root is a copy taken when the iterator is created; it is
	/// not refreshed if the tree's root changes later -- confirm callers do not
	/// decrement an end() iterator obtained before a later insertion.
	template<class T, class Ref, class Ptr,class NodePtr>
	struct RBTreeIterator
	{
		typedef RBTreeNode<T> Node;
		typedef RBTreeIterator<T, Ref, Ptr, NodePtr> Self;
		typedef RBTreeIterator<T, T&, T*, Node*> Iterator;

		NodePtr _node; // current position; nullptr means end()
		NodePtr _root; // root of the tree, used only by --end()

		RBTreeIterator(NodePtr node, NodePtr root)
			:_node(node)
			, _root(root)
		{
		}

		// Acts as the copy constructor when Self is Iterator, and as the
		// Iterator -> ConstIterator conversion when Self is ConstIterator.
		RBTreeIterator(const Iterator& it)
			:_node(it._node)
			, _root(it._root)
		{
		}

		/// Moves to the in-order successor; becomes end() after the last node.
		Self& operator++()
		{
			// end() has no successor; fail loudly instead of dereferencing
			// nullptr.
			assert(_node);

			if (_node->_right)
			{
				// Successor is the left-most node of the right subtree.
				NodePtr cur = _node->_right;
				while (cur->_left)
				{
					cur = cur->_left;
				}
				_node = cur;
			}
			else
			{
				// Climb until we leave a left subtree; that ancestor is next.
				// Running off the top (parent == nullptr) yields end().
				NodePtr cur = _node;
				NodePtr parent = cur->_parent;
				while (parent && cur == parent->_right)
				{
					cur = parent;
					parent = cur->_parent;
				}
				_node = parent;
			}

			return *this;
		}

		/// Moves to the in-order predecessor; --end() yields the right-most node.
		Self& operator--()
		{
			// --end(): end() is nullptr, so the right-most node is found
			// by walking down from _root.
			if (_node == nullptr)
			{
				NodePtr MostRight = _root;
				while (MostRight && MostRight->_right)
				{
					MostRight = MostRight->_right;
				}
				_node = MostRight;
			}
			else if (_node->_left)
			{
				// Predecessor is the right-most node of the left subtree.
				NodePtr cur = _node->_left;
				while (cur->_right)
				{
					cur = cur->_right;
				}
				_node = cur;
			}
			else
			{
				// Climb until we leave a right subtree.
				NodePtr cur = _node;
				NodePtr parent = cur->_parent;
				while (parent && cur == parent->_left)
				{
					cur = parent;
					parent = cur->_parent;
				}
				_node = parent;
			}

			return *this;
		}

		// Dereferencing end() is a usage error; assert instead of reading
		// through nullptr.
		Ref operator*() const
		{
			assert(_node);
			return _node->_data;
		}

		Ptr operator->() const
		{
			assert(_node);
			return &(_node->_data);
		}

		// Only the node is compared; _root does not take part in equality.
		bool operator!=(const Self& s) const
		{
			return _node != s._node;
		}

		bool operator==(const Self& s) const
		{
			return _node == s._node;
		}
	};

	/// Red-black tree shared by MyMap and MySet.
	///
	/// @tparam K      key type accepted by Find
	/// @tparam T      stored element type (the key itself, or a key/value pair)
	/// @tparam KfromT functor type; calling an instance on an element yields
	///                that element's key
	///
	/// Keys are compared with operator< and operator>; an element whose key is
	/// already present is rejected. The tree owns its nodes: copying is a deep
	/// copy and the destructor deletes every node.
	template<class K, class T, class KfromT>
	class RBTree
	{
		typedef RBTreeNode<T> Node;
	public:
		typedef RBTreeIterator<T, T&, T*, Node*> Iterator;
		typedef RBTreeIterator<T, const T&, const T*, const Node*> ConstIterator;

		RBTree()
			:_root(nullptr)
		{ }

		/// Deep-copies every node of rbt, preserving shape and colours.
		RBTree(const RBTree& rbt)
			:_root(Copy(rbt._root, nullptr))
		{ }

		/// Copy-and-swap assignment: the previous nodes are released when tmp
		/// goes out of scope.
		RBTree& operator=(const RBTree& rbt)
		{
			if (this != &rbt)
			{
				RBTree tmp(rbt);
				swap(_root, tmp._root);
			}

			return *this;
		}

		~RBTree()
		{
			Destroy(_root);
			_root = nullptr;
		}

		/// @return iterator to the left-most (smallest-key) node, or End() when
		///         the tree is empty.
		Iterator Begin()
		{
			Node* cur = _root;
			while (cur && cur->_left)
			{
				cur = cur->_left;
			}
			return Iterator(cur, _root);
		}

		/// @return the past-the-end iterator (null node).
		Iterator End()
		{
			return Iterator(nullptr, _root);
		}

		ConstIterator Begin() const
		{
			const Node* cur = _root;
			while (cur && cur->_left)
			{
				cur = cur->_left;
			}
			return ConstIterator(cur, _root);
		}

		ConstIterator End() const
		{
			return ConstIterator(nullptr, _root);
		}

		KfromT KfT; // extracts the key from a stored element

		/// Inserts data unless an element with the same key already exists.
		/// @return iterator to the element holding that key, paired with true
		///         if a node was added or false if the key was already present.
		pair<Iterator, bool> Insert(const T& data)
		{
			// Empty tree: the new node becomes the black root.
			if (_root == nullptr)
			{
				_root = new Node(data);
				_root->_col = BLACK;
				return { Iterator(_root,_root),true };
			}

			// Binary-search-tree descent, comparing extracted keys only.
			Node* parent = nullptr;
			Node* cur = _root;
			while (cur)
			{
				if (KfT(data) > KfT(cur->_data))
				{
					parent = cur;
					cur = cur->_right;
				}
				else if (KfT(data) < KfT(cur->_data))
				{
					parent = cur;
					cur = cur->_left;
				}
				else
				{
					// Key already present: report the existing node.
					return { Iterator(cur,_root),false };
				}
			}

			// Attach the new (red) node below the last node visited.
			cur = new Node(data);
			Node* newnode = cur; // cur is reassigned while rebalancing
			if (KfT(data) > KfT(parent->_data))
				parent->_right = cur;
			else
				parent->_left = cur;
			cur->_parent = parent;

			// Repair red-red violations between cur and parent, moving upwards.
			// A red parent is never the root (the root is kept black), so
			// grandfather is non-null inside the loop.
			while (parent && parent->_col == RED)
			{
				Node* grandfather = parent->_parent;
				Node* uncle;
				if (parent == grandfather->_left)
				{
					//    g
					//  p   u
					uncle = grandfather->_right;
					if (uncle && uncle->_col == RED)
					{
						// Red uncle: recolour and continue from the grandfather.
						parent->_col = uncle->_col = BLACK;
						grandfather->_col = RED;

						cur = grandfather;
						parent = cur->_parent;
					}
					else
					{
						// Missing or black uncle: rotate, recolour, and stop.
						if (cur == parent->_left)
						{
							RotateR(grandfather);
							parent->_col = BLACK;
							grandfather->_col = RED;
						}
						else
						{
							RotateL(parent);
							RotateR(grandfather);
							cur->_col = BLACK;
							grandfather->_col = RED;
						}

						break;
					}
				}
				else
				{
					//    g
					//  u   p
					uncle = grandfather->_left;
					if (uncle && uncle->_col == RED)
					{
						parent->_col = uncle->_col = BLACK;
						grandfather->_col = RED;

						cur = grandfather;
						parent = cur->_parent;
					}
					else
					{
						if (cur == parent->_right)
						{
							RotateL(grandfather);
							parent->_col = BLACK;
							grandfather->_col = RED;
						}
						else
						{
							RotateR(parent);
							RotateL(grandfather);
							cur->_col = BLACK;
							grandfather->_col = RED;
						}

						break;
					}
				}
			}

			// The root must be black. Only the recolouring path can leave it
			// red (when it climbs past the root); in every other exit it is
			// already black, so setting it unconditionally is equivalent.
			_root->_col = BLACK;

			// Built after rebalancing so the iterator records the current root.
			return { Iterator(newnode,_root),true };
		}

		/// Right rotation around parent: its left child takes its place.
		void RotateR(Node* parent)
		{
			Node* pParent = parent->_parent;
			Node* subL = parent->_left;
			Node* subLR = subL->_right;

			parent->_left = subLR;
			if (subLR)
				subLR->_parent = parent;

			subL->_right = parent;
			parent->_parent = subL;
			subL->_parent = pParent;
			if (pParent == nullptr) // parent had no parent, i.e. it was _root
			{
				_root = subL;
			}
			else
			{
				if (pParent->_left == parent)
					pParent->_left = subL;
				else
					pParent->_right = subL;
			}
		}

		/// Left rotation around parent: its right child takes its place.
		void RotateL(Node* parent)
		{
			Node* pParent = parent->_parent;
			Node* subR = parent->_right;
			Node* subRL = subR->_left;

			parent->_right = subRL;
			if (subRL)
				subRL->_parent = parent;

			subR->_left = parent;
			parent->_parent = subR;
			subR->_parent = pParent;
			if (pParent == nullptr)
				_root = subR;
			else
			{
				if (pParent->_left == parent)
					pParent->_left = subR;
				else
					pParent->_right = subR;
			}
		}

		/// @return iterator to the element whose key equals key, or End().
		Iterator Find(const K& key)
		{
			Node* cur = _root;
			while (cur)
			{
				if (key > KfT(cur->_data))
					cur = cur->_right;
				else if (key < KfT(cur->_data))
					cur = cur->_left;
				else
					return Iterator(cur, _root);
			}
			return Iterator(nullptr, _root);
		}

		/// Const overload of Find.
		ConstIterator Find(const K& key) const
		{
			// The KfT member is const inside this function, so calling it
			// requires a const operator(). Key functors may declare a
			// non-const operator(); a local instance keeps this overload
			// compilable for them as well.
			KfromT keyOf;
			const Node* cur = _root;
			while (cur)
			{
				if (key > keyOf(cur->_data))
					cur = cur->_right;
				else if (key < keyOf(cur->_data))
					cur = cur->_left;
				else
					return ConstIterator(cur, _root);
			}
			return ConstIterator(nullptr, _root);
		}

	private:
		/// Recursively clones the subtree rooted at root; parent becomes the
		/// clone's _parent link.
		Node* Copy(Node* root,Node* parent)
		{
			if (root == nullptr)
				return nullptr;

			Node* newNode = new Node(root->_data);
			newNode->_col = root->_col;
			newNode->_left = Copy(root->_left,newNode);
			newNode->_right = Copy(root->_right,newNode);
			newNode->_parent = parent;

			return newNode;
		}

		/// Deletes the subtree rooted at root (children first, then root).
		void Destroy(Node* root)
		{
			if (root == nullptr)
				return;
			Destroy(root->_left);
			Destroy(root->_right);
			delete root;
		}

		Node* _root;
	};
}
2.4.4 Test.cpp
#include "RBTree.h"
#include "MySet.h"
#include "MyMap.h"

// Key extractor for a tree of plain ints: the element is its own key.
// operator() is const so the functor can also be invoked on a const object
// (for example through a const tree).
struct SetKfromT
{
    const int& operator()(const int& key) const
    {
        return key;
    }
};

void TestIterators() {
    // 1. 测试普通迭代器
    Lzc::RBTree<int, int, SetKfromT> tree;
    tree.Insert(10);
    tree.Insert(20);
    tree.Insert(5);
    tree.Insert(15);

    cout << "Testing regular iterator:" << endl;
    for (auto it = tree.Begin(); it != tree.End(); ++it) {
        cout << *it << " ";  // 应输出:5 10 15 20
        *it += 1;  // 测试通过普通迭代器修改元素
    }
    cout << endl;

    // 验证修改后的值
    cout << "After modification:" << endl;
    for (auto it = tree.Begin(); it != tree.End(); ++it) {
        cout << *it << " ";  // 应输出:6 11 16 21
    }
    cout << endl;

    // 2. 测试const迭代器
    const auto& constTree = tree;
    cout << "Testing const iterator:" << endl;
    for (auto it = constTree.Begin(); it != constTree.End(); ++it) {
        cout << *it << " ";  // 应输出:6 11 16 21
        // *it += 1;  // 这行应该编译失败,验证const迭代器不可修改
    }
    cout << endl;

    // 3. 测试iterator转const_iterator的兼容性
    cout << "Testing iterator to const_iterator conversion:" << endl;
    auto it = tree.Begin();
    Lzc::RBTree<int, int, SetKfromT>::ConstIterator cit = it;
    cout << *cit << endl;  // 应输出第一个元素6
    assert(*cit == *it);

    // 4. 测试反向遍历
    cout << "Testing reverse traversal using --:" << endl;
    auto rit = tree.End();
    --rit;  // 移动到最后一个元素
    for (; rit != tree.Begin(); --rit) {
        cout << *rit << " ";  // 应输出:21 16 11
    }
    cout << *rit << endl;  // 输出最后一个元素6

    cout << "All iterator tests passed!" << endl;
}

// Basic MyMap behaviour: insert, find, operator[] (insert and overwrite),
// and in-order traversal.
void TestMapBasic()
{
    Lzc::MyMap<int, string> map;
    map.insert({ 10, "A" });
    map.insert({ 20, "B" });
    map.insert({ 5, "C" });
    map.insert({ 15, "D" });

    // Lookup: every inserted key is found, a missing key is not.
    const int presentKeys[] = { 10, 20, 5, 15 };
    for (int key : presentKeys)
    {
        assert(map.find(key) != map.end());
    }
    assert(map.find(100) == map.end()); // key that was never inserted

    // operator[] inserts a new key ...
    map[30] = "E";
    assert(map.find(30) != map.end());
    assert(map[30] == "E");

    // ... and overwrites the value of an existing key.
    map[10] = "AA";
    assert(map[10] == "AA");

    // Traversal visits keys in ascending order.
    cout << "Map InOrder Traversal:" << endl;
    for (const auto& entry : map)
    {
        cout << entry.first << ":" << entry.second << " ";
    }
    cout << endl; // expected output: 5:C 10:AA 15:D 20:B 30:E

    cout << "TestMapBasic passed!" << endl;
}

void TestSetBasic()
{
    Lzc::MySet<int> set;
    set.insert(10);
    set.insert(20);
    set.insert(5);
    set.insert(15);

    // 测试查找
    assert(set.find(10) != set.end());
    assert(set.find(20) != set.end());
    assert(set.find(5) != set.end());
    assert(set.find(15) != set.end());
    assert(set.find(100) == set.end()); // 不存在的 key

    // 测试遍历(中序输出)
    cout << "Set InOrder Traversal:" << endl;
    for (auto it = set.begin(); it != set.end(); ++it)
    {
        cout << *it << " ";
    }
    cout << endl; // 应输出:5 10 15 20

    cout << "TestSetBasic passed!" << endl;
}

// Checks operator-- on MyMap iterators: stepping backwards from a found
// element down to begin(), and stepping back from end().
void TestIteratorDecrement()
{
    Lzc::MyMap<int, string> map;
    map.insert({ 10, "A" });
    map.insert({ 20, "B" });
    map.insert({ 5, "C" });
    map.insert({ 15, "D" });

    // Walk backwards from 20.
    auto it = map.find(20);
    assert(it != map.end());
    --it; // should now be at 15
    assert(it->first == 15);
    --it; // should now be at 10
    assert(it->first == 10);
    --it; // should now be at 5
    assert(it->first == 5);

    // 5 is the smallest key, so the iterator has reached begin().
    // Decrementing begin() is out of contract, so it is deliberately not done
    // here; only the position is verified.
    assert(it == map.begin());

    // --end() must land on the largest element.
    it = map.end();
    --it; // should now be at 20
    assert(it->first == 20);

    cout << "TestIteratorDecrement passed!" << endl;
}

void TestCopyAndAssignment()
{
    Lzc::MyMap<int, string> map1;
    map1.insert({ 10, "A" });
    map1.insert({ 20, "B" });
    map1.insert({ 5, "C" });

    // 测试拷贝构造
    Lzc::MyMap<int, string> map2(map1);
    assert(map2.find(10) != map2.end());
    assert(map2.find(20) != map2.end());
    assert(map2.find(5) != map2.end());

    // 测试赋值运算符
    Lzc::MyMap<int, string> map3;
    map3 = map1;
    assert(map3.find(10) != map3.end());
    assert(map3.find(20) != map3.end());
    assert(map3.find(5) != map3.end());

    // 修改原 map,确保深拷贝
    map1.insert({ 30, "D" });
    assert(map2.find(30) == map2.end()); // map2 不应受影响
    assert(map3.find(30) == map3.end()); // map3 不应受影响

    cout << "TestCopyAndAssignment passed!" << endl;
}

// An empty container must report begin() == end().
void TestEmptyContainer()
{
    {
        Lzc::MyMap<int, string> emptyMap;
        const bool mapIsEmptyRange = (emptyMap.begin() == emptyMap.end());
        assert(mapIsEmptyRange);
    }

    {
        Lzc::MySet<int> emptySet;
        const bool setIsEmptyRange = (emptySet.begin() == emptySet.end());
        assert(setIsEmptyRange);
    }

    cout << "TestEmptyContainer passed!" << endl;
}

#include <random>
void TestRandomData()
{
    Lzc::MyMap<int, int> map;
    const int N = 10000;

    // 插入 10000 个随机数
    random_device rd;
    mt19937 gen(rd());
    uniform_int_distribution<> dis(1, 100000);

    for (int i = 0; i < N; ++i)
    {
        int key = dis(gen);
        map[key] = key; // 使用 operator[]
    }

    // 检查是否能正确查找
    for (int i = 0; i < 100; ++i) // 随机抽查 100 个
    {
        int key = dis(gen);
        auto it = map.find(key);
        if (it != map.end())
        {
            assert(it->second == key);
        }
    }

    cout << "TestRandomData passed!" << endl;
}

// Runs every test suite in order, then reports overall success.
int main()
{
    typedef void (*TestFn)();
    const TestFn suites[] = {
        TestIterators,
        TestMapBasic,
        TestSetBasic,
        TestIteratorDecrement,
        TestCopyAndAssignment,
        TestEmptyContainer,
        TestRandomData,
    };

    for (TestFn runSuite : suites)
    {
        runSuite();
    }

    cout << "All tests passed!" << endl;
    return 0;
}
评论 30
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Lzc-c

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值