简介
Splay 维护平衡的方式就是每访问到一个节点, 都把它旋转到根. 这个把节点 x 旋转到根的过程就叫做 Splay.
维护变量
1
|
int rt, fa[_], ch[_][2], val[_], num[_], sum[_];
|
rt
: 根节点.
fa[u]
: 节点 u 的父节点.
ch[u][0/1]
: 节点 u 的左右儿子.
val[u]
: 节点 u 的权值.
num[u]
: 权值为 val[u]
的元素个数.
sum[u]
: 节点 u 的子树 num
之和.
主要操作
Update
1
|
void upd(int x) { sum[x] = sum[ch[x][0]] + sum[ch[x][1]] + num[x]; }
|
Rotate
和 Treap 的 Rotate 没啥区别.
1
2
3
4
5
6
7
8
|
void Rotate(int x) {
int y = fa[x], t = get(x);
ch[y][t] = ch[x][!t], fa[ch[x][!t]] = ch[x][!t] ? y : 0; // modify ch[x][!t]
fa[x] = fa[y], ch[fa[x]][get(y)] = fa[x] ? x : 0; // modify x
ch[x][!t] = y, fa[y] = x; // modify y
upd(y), upd(x);
}
|
Splay
设 get(u) = 0/1
表示 u 是 fa[u] 的左/右儿子.
为了保证平衡, 若有 get(u) == get(fa[u])
, 则要先 Rotate(fa[u]) 后再 Rotate(u).
1
2
3
4
5
6
7
|
void Splay(int x) {
while (fa[x]) {
if (fa[fa[x]] and get(fa[x]) == get(x)) Rotate(fa[x]);
Rotate(x);
}
rt = x;
}
|
Insert
一直往下找, 若有 val[u] == w
, 则 ++num[u]
; 若到了空节点, 则新建节点. 最后 Splay.
1
2
3
4
5
6
7
|
void Ins(int &u, int f, int w) {
if (!u) { u = ++tot, fa[u] = f, val[u] = w, num[u] = sum[u] = 1; Splay(u); return; }
if (val[u] == w) { ++num[u], ++sum[u]; Splay(u); return; }
Ins(ch[u][val[u] < w], u, w);
}
void Ins(int w) { Ins(rt, 0, w); }
|
多写一个是为了方便调用.
Merge
设合并的两棵树为 A, B, 则需满足 A 的最大值小于 B 的最小值.
若有一棵树为空, 则将另外一棵树的根节点设为根.
否则将 A 中的最大值 Splay 到根, 然后将 B 设为 A 的右子树. (记得更新相关信息).
1
2
3
4
5
6
7
|
void Merge(int x, int y) {
if (!x) { rt = y, fa[y] = 0; return; } //**
rt = x, fa[x] = 0;
while (ch[x][1]) x = ch[x][1];
Splay(x);
ch[rt][1] = y, fa[y] = rt, upd(rt); //**
}
|
(//**
是容易写错的地方.)
Delete
若有 num[u] > 1
, 则 --num[u]
.
否则将 Splay(u)
, 然后合并 u 的两棵子树.
1
2
3
4
5
|
void Del(int w) {
Find(rt, w);
if (num[rt] > 1) --num[rt];
else Merge(ch[rt][0], ch[rt][1]);
}
|
其他
找排名, 找第 k 小 / 大, 找前驱, 找后继. 都和 Treap 差不多.
代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
#include <cstdio>
#include <iostream>
using namespace std;
const int _ = 1e5 + 7;
struct SPLAY {
#define get(x) (x == ch[fa[x]][1])
int rt, fa[_], ch[_][2], num[_], sum[_], val[_], tot;
void upd(int x) { sum[x] = sum[ch[x][0]] + sum[ch[x][1]] + num[x]; }
void Rotate(int x) {
int y = fa[x], t = get(x);
ch[y][t] = ch[x][!t], fa[ch[x][!t]] = ch[x][!t] ? y : 0; // modify ch[x][!t]
fa[x] = fa[y], ch[fa[x]][get(y)] = fa[x] ? x : 0; // modify x
ch[x][!t] = y, fa[y] = x; // modify y
upd(y), upd(x);
}
void Splay(int x) {
while (fa[x]) {
if (fa[fa[x]] and get(fa[x]) == get(x)) Rotate(fa[x]);
Rotate(x);
}
rt = x;
}
void Ins(int &u, int f, int w) {
if (!u) { u = ++tot, fa[u] = f, val[u] = w, num[u] = sum[u] = 1; Splay(u); return; }
if (val[u] == w) { ++num[u], ++sum[u]; Splay(u); return; }
Ins(ch[u][val[u] < w], u, w);
}
void Ins(int w) { Ins(rt, 0, w); }
void Find(int u, int w) {
if (val[u] == w) { Splay(u); return; }
Find(ch[u][val[u] < w], w);
}
void Merge(int x, int y) {
if (!x) { rt = y, fa[y] = 0; return; } //**
rt = x, fa[x] = 0;
while (ch[x][1]) x = ch[x][1];
Splay(x);
ch[rt][1] = y, fa[y] = rt, upd(rt); //**
}
void Del(int w) {
Find(rt, w);
if (num[rt] > 1) --num[rt];
else Merge(ch[rt][0], ch[rt][1]);
}
int rk(int w) {
int u = rt, res = 0;
while (u) {
if (val[u] == w) return res + sum[ch[u][0]] + 1;
else if (val[u] < w) res += sum[ch[u][0]] + num[u], u = ch[u][1];
else u = ch[u][0];
}
return res + 1;
}
int kth(int res) {
int u = rt;
while (res) {
if (sum[ch[u][0]] < res and sum[ch[u][0]] + num[u] >= res) return val[u];
else if (res <= sum[ch[u][0]]) u = ch[u][0];
else res -= sum[ch[u][0]] + num[u], u = ch[u][1];
}
return val[u];
}
int pre(int w) {
Ins(w);
int u = ch[rt][0];
while (ch[u][1]) u = ch[u][1];
Del(w); //**
return val[u];
}
int suf(int w) {
Ins(w);
int u = ch[rt][1];
while (ch[u][0]) u = ch[u][0];
Del(w); //**
return val[u];
}
} S;
int n;
int main() {
cin >> n;
for (int i = 1, ty, x; i <= n; ++i) {
scanf("%d%d", &ty, &x);
if (ty == 1) S.Ins(x);
if (ty == 2) S.Del(x);
if (ty == 3) printf("%d\n", S.rk(x));
if (ty == 4) printf("%d\n", S.kth(x));
if (ty == 5) printf("%d\n", S.pre(x));
if (ty == 6) printf("%d\n", S.suf(x));
}
return 0;
}
|