$\newcommand{\O}{\mathrm{O}}$

Splay Tree をベースにした Set(順序付き集合)の実装. 処理の内容は こちら の実装と同じ.
時間計算量: 各クエリ ならし $\O(\log n)$
template<class _Key> class SetIterator;

// Ordered set of unique keys (a subset of the std::set interface) backed by a
// splay tree. Keys are compared only with operator<.
//
// Invariants:
//  * A sentinel "header" node is allocated lazily (confirm_header) and always
//    stays in the tree as its largest element. end() points at it and its key
//    is never compared.
//  * _M_start points at the smallest node (the header when the set is empty).
//  * _M_root is the current root. Every query splays the deepest node it
//    visited (or the node it found) to the root.
template<class _Key> class Set {
public:
    using iterator = SetIterator<_Key>;

private:
    struct node {
        const _Key _M_key;
        node *_M_left, *_M_right, *_M_parent;
        node(_Key&& key) noexcept
            : _M_key(move(key)), _M_left(nullptr), _M_right(nullptr), _M_parent(nullptr){}
        inline bool isRoot() const noexcept { return !_M_parent; }
        // Lift this node one level above its parent. `right` must be true
        // exactly when this node is the parent's left child (right rotation).
        void rotate(const bool right){
            node *p = _M_parent, *g = p->_M_parent;
            if(right){
                if((p->_M_left = _M_right)) _M_right->_M_parent = p;
                _M_right = p, p->_M_parent = this;
            }else{
                if((p->_M_right = _M_left)) _M_left->_M_parent = p;
                _M_left = p, p->_M_parent = this;
            }
            if(!(_M_parent = g)) return;
            if(g->_M_left == p) g->_M_left = this;
            else if(g->_M_right == p) g->_M_right = this;
        }
    };
    friend SetIterator<_Key>;
    size_t _M_node_count;
    node *_M_root, *_M_header, *_M_start;

    // Allocate the sentinel header on first use. Its key is value-initialised
    // (it is never compared), so no indeterminate value is read for scalar keys.
    inline void confirm_header(){
        if(!_M_header) _M_root = _M_header = _M_start = new node(_Key());
    }
    // Bring u to the root with zig / zig-zig / zig-zag steps and return it.
    node *splay(node *u){
        while(!u->isRoot()){
            node *p = u->_M_parent, *gp = p->_M_parent;
            const bool u_is_left = (u == p->_M_left);
            if(p->isRoot()){
                u->rotate(u_is_left);                         // zig
            }else if(u_is_left == (p == gp->_M_left)){
                p->rotate(u_is_left), u->rotate(u_is_left);   // zig-zig
            }else{
                u->rotate(u_is_left), u->rotate(!u_is_left);  // zig-zag
            }
        }
        return _M_root = u;
    }
    // In-order successor of ver.
    static node *increment(node *ver){
        if(ver->_M_right){
            ver = ver->_M_right;
            while(ver->_M_left) ver = ver->_M_left;
        }else{
            node *nx = ver->_M_parent;
            while(ver == nx->_M_right) ver = nx, nx = nx->_M_parent;
            ver = nx;
        }
        return ver;
    }
    // In-order predecessor of ver.
    static node *decrement(node *ver){
        if(ver->_M_left){
            ver = ver->_M_left;
            while(ver->_M_right) ver = ver->_M_right;
        }else{
            node *nx = ver->_M_parent;
            while(ver == nx->_M_left) ver = nx, nx = nx->_M_parent;
            ver = nx;
        }
        return ver;
    }
    // Join two detached trees where every key in ver1 is smaller than every key
    // in ver2. ver2 is never null because it always contains the header. The
    // minimum of ver2 becomes the new root and is returned.
    node *join(node *ver1, node *ver2){
        while(ver2->_M_left) ver2 = ver2->_M_left;
        splay(ver2)->_M_left = ver1;
        // With no left part the erased node was the minimum, so ver2 is the new start.
        return ver1 ? (ver1->_M_parent = ver2) : (_M_start = ver2);
    }
    // Node holding `key`, or the header when absent. On a miss the last visited
    // node is still splayed, so unsuccessful searches keep the amortized bound.
    node *_find(const _Key& key){
        confirm_header();
        node *cur = nullptr, *nx = _M_root;
        do {
            cur = nx;
            if(cur == _M_header || key < cur->_M_key) nx = cur->_M_left;
            else if(cur->_M_key < key) nx = cur->_M_right;
            else return splay(cur);
        }while(nx);
        splay(cur);
        return _M_header;
    }
    // Insert `key` if absent. Returns the node holding it, splayed to the root.
    template<typename Key>
    node *_insert(Key&& key){
        confirm_header();
        node *cur = nullptr, *nx = _M_root;
        do {
            cur = nx;
            if(cur == _M_header || key < cur->_M_key) nx = cur->_M_left;
            else if(cur->_M_key < key) nx = cur->_M_right;
            else return splay(cur);  // already present
        }while(nx);
        // Choose the side before `key` may be moved from.
        const bool as_left = (cur == _M_header || key < cur->_M_key);
        _Key new_key = forward<Key>(key);
        nx = new node(move(new_key));
        nx->_M_parent = cur;
        if(as_left){
            cur->_M_left = nx;
            if(cur == _M_start) _M_start = nx;  // new minimum
        }else{
            cur->_M_right = nx;
        }
        ++_M_node_count;
        return splay(nx);
    }
    template<typename... Args>
    node *_emplace(Args&&... args){
        return _insert(_Key(forward<Args>(args)...));
    }
    // Remove root_ver, which must be the current root (callers splay it first)
    // and must not be the header. Returns the node that followed it.
    node *_erase(node *root_ver){
        confirm_header();
        assert(root_ver != _M_header);
        if(root_ver->_M_left) root_ver->_M_left->_M_parent = nullptr;
        if(root_ver->_M_right) root_ver->_M_right->_M_parent = nullptr;
        node *res = join(root_ver->_M_left, root_ver->_M_right);
        delete root_ver;
        --_M_node_count;
        return res;
    }
    // Erase `key` if present. Returns the number of removed elements (0 or 1).
    size_t _erase(const _Key& key){
        node *ver = _find(key);  // splays the found node to the root
        if(ver == _M_header) return 0;
        _erase(ver);
        return 1;
    }
    // First node whose key is >= key (strict == false) or > key (strict == true),
    // or the header if there is none. The deepest visited node is splayed.
    node *_bound(const _Key& key, const bool strict){
        confirm_header();
        node *cur = nullptr, *nx = _M_root, *res = _M_header;
        do {
            cur = nx;
            // The header compares greater than every key.
            const bool go_left = cur == _M_header
                || (strict ? key < cur->_M_key : !(cur->_M_key < key));
            if(go_left){
                // Each later candidate lies in this node's left subtree, so it
                // is always smaller: the last one seen is the answer.
                res = cur;
                nx = cur->_M_left;
            }else{
                nx = cur->_M_right;
            }
        }while(nx);
        splay(cur);
        return res;
    }
    node *_lower_bound(const _Key& key){ return _bound(key, false); }
    node *_upper_bound(const _Key& key){ return _bound(key, true); }
    // Free every node of the subtree rooted at `cur` (which may be null) without
    // recursion. A splay tree can degenerate into a path of length n, and a
    // recursive walk would then overflow the call stack. Left children are
    // rotated up until none remain, then nodes are freed along the right spine.
    static void clear_dfs(node *cur) noexcept {
        while(cur){
            if(node *l = cur->_M_left){
                cur->_M_left = l->_M_right;
                l->_M_right = cur;
                cur = l;
            }else{
                node *r = cur->_M_right;
                delete cur;
                cur = r;
            }
        }
    }
    // Free all nodes (header included) and return to the default-constructed state.
    void release() noexcept {
        clear_dfs(_M_root);
        _M_node_count = 0;
        _M_root = _M_header = _M_start = nullptr;
    }

public:
    Set() noexcept : _M_node_count(0), _M_root(nullptr), _M_header(nullptr), _M_start(nullptr){}
    Set(const Set&) = delete;
    // Steal another's tree. `another` is left empty (size 0, no nodes).
    Set(Set&& another) noexcept
        : _M_node_count(another._M_node_count), _M_root(another._M_root),
          _M_header(another._M_header), _M_start(another._M_start){
        another._M_node_count = 0;
        another._M_root = another._M_header = another._M_start = nullptr;
    }
    Set& operator=(const Set&) = delete;
    // Free the current tree, then steal another's. `another` is left empty.
    Set& operator=(Set&& another) noexcept {
        if(this != &another){
            release();
            _M_node_count = another._M_node_count;
            _M_root = another._M_root, _M_header = another._M_header, _M_start = another._M_start;
            another._M_node_count = 0;
            another._M_root = another._M_header = another._M_start = nullptr;
        }
        return *this;
    }
    ~Set(){ release(); }
    // Print the keys in ascending order, each followed by a space.
    friend ostream& operator<< (ostream& os, Set& st) noexcept {
        for(auto& val : st) os << val << " ";
        return os;
    }
    size_t size() const noexcept { return _M_node_count; }
    bool empty() const noexcept { return size() == 0; }
    iterator begin() noexcept { return confirm_header(), iterator(_M_start); }
    iterator end() noexcept { return confirm_header(), iterator(_M_header); }
    void clear(){
        release();
        confirm_header();
    }
    iterator find(const _Key& key){ return iterator(_find(key)); }
    size_t count(const _Key& key){
        // _find may allocate the header, so _M_header must be read afterwards.
        node *found = _find(key);
        return found != _M_header;
    }
    iterator insert(const _Key& key){ return iterator(_insert(key)); }
    iterator insert(_Key&& key){ return iterator(_insert(move(key))); }
    template<typename... Args>
    iterator emplace(Args&&... args){ return iterator(_emplace(forward<Args>(args)...)); }
    size_t erase(const _Key& key){ return _erase(key); }
    // Erase the element at itr (must not be end()). Returns the following position.
    iterator erase(const iterator& itr){ return iterator(_erase(splay(itr.node_ptr))); }
    iterator lower_bound(const _Key& key){ return iterator(_lower_bound(key)); }
    iterator upper_bound(const _Key& key){ return iterator(_upper_bound(key)); }
};

// Bidirectional, read-only iterator over Set<_Key> in ascending key order.
// It holds a raw node pointer and is invalidated when that node is erased.
template<class _Key>
class SetIterator {
private:
    friend Set<_Key>;
    typename Set<_Key>::node *node_ptr;
    SetIterator(typename Set<_Key>::node *st) noexcept : node_ptr(st){}
public:
    // Public so that std::iterator_traits (std::distance, std::next, ...) sees them.
    using iterator_category = bidirectional_iterator_tag;
    using value_type = _Key;
    using difference_type = ptrdiff_t;
    using pointer = const _Key*;
    using reference = const _Key&;

    SetIterator() noexcept : node_ptr(nullptr){}
    SetIterator(const SetIterator& itr) noexcept : node_ptr(itr.node_ptr){}
    SetIterator& operator=(const SetIterator& itr) & noexcept {
        node_ptr = itr.node_ptr;
        return *this;
    }
    reference operator*() const { return node_ptr->_M_key; }
    pointer operator->() const { return &(node_ptr->_M_key); }
    SetIterator& operator++() noexcept {
        node_ptr = Set<_Key>::increment(node_ptr);
        return *this;
    }
    // Postfix: advance *this and return its previous position.
    SetIterator operator++(int) noexcept {
        SetIterator old(*this);
        ++*this;
        return old;
    }
    SetIterator& operator--() noexcept {
        node_ptr = Set<_Key>::decrement(node_ptr);
        return *this;
    }
    // Postfix: step *this back and return its previous position.
    SetIterator operator--(int) noexcept {
        SetIterator old(*this);
        --*this;
        return old;
    }
    bool operator==(const SetIterator& itr) const noexcept { return node_ptr == itr.node_ptr; }
    bool operator!=(const SetIterator& itr) const noexcept { return !(*this == itr); }
};
AOJ: Set - Set: Delete の提出コード