将节点的值定义为其输入中1的个数。观察发现,每次修改的都是一条链上的值:如果是将某个节点的值从1改成0,那么就是要将这个节点到其祖先中离该节点最近的值非2的节点的值全部减一,如果是将某个节点的值从0改成1,那么就是要将这个节点到其祖先中离该节点最近的值非1的节点的值全部加1。
所以维护splay表示的链中最后一个值非1和值非2的节点的位置即可。修改节点x的值的时候,先access(x)
,然后将最后一个值非1(或者非2)的节点(记为y)splay到根,然后对y节点做单点修改,对y的右子树做区间加。有一个特殊情况,就是找不到这样的节点y,这说明从x一直到根都是1(或者2),这时直接对整条链做区间加就好了,而且此时根的输出必然会翻转。
注意,这题中,LCT并不需要做换根操作。
这题告诉我们,LCT要灵活应用,不要只记得link
和cut
以及区间操作之类的标准化的用法。灵活运用splay
和access
更强大。
代码:
#include <iostream>
#include <cstring>
#include <cassert>
#include <cmath>
using namespace std;
#define MAXN 500011
struct LCT {
int c[MAXN][2], fa[MAXN], sta[MAXN];
// val[i] is the number of 1 in the input of node i.
// not_1[i] is the last node whose value is not 1 in the current subtree.
int val[MAXN], not_1[MAXN], not_2[MAXN], lazy_add[MAXN];
inline int& ls(int rt) {
return c[rt][0];
}
inline int& rs(int rt) {
return c[rt][1];
}
inline bool not_splay_rt(int x) {
return ls(fa[x]) == x || rs(fa[x]) == x;
}
inline int side(int x) {
return x == rs(fa[x]);
}
void Init() {
// Initially every node is a tree by itself.
// memset all to 0.
}
inline void push_add(int x, unsigned int c) {
(not_1[x], not_2[x]);
swap[x] += c;
val[x] += c;
lazy_add}
inline void pushdown(int x) {
if (lazy_add[x]) {
if (ls(x))
(ls(x), lazy_add[x]);
push_addif (rs(x))
(rs(x), lazy_add[x]);
push_add[x] = 0;
lazy_add}
}
inline void pushup(int x) {
[x] = not_1[rs(x)] ? not_1[rs(x)] :
not_1(val[x] != 1 ? x : not_1[ls(x)]);
[x] = not_2[rs(x)] ? not_2[rs(x)] :
not_2(val[x] != 2 ? x : not_2[ls(x)]);
}
// s[x] is not updated
void __rotate_up(int x) {
int y = fa[x], z = fa[y], side_x = side(x), w = c[x][side_x ^ 1];
[x] = z;
faif (not_splay_rt(y))
[z][side(y)] = x;
cif (w)
[w] = y;
fa[y][side_x] = w;
c[y] = x;
fa[x][side_x ^ 1] = y;
c(y);
pushup}
// s[x] is not updated
void __splay(int x) {
int y = x, top = 0;
while(1) {
[++top] = y;
staif (!not_splay_rt(y))
break;
= fa[y];
y }
int to = fa[y];
while (top)
(sta[top--]);
pushdownwhile (fa[x] != to) {
int y = fa[x];
if (fa[y] != to)
(side(x) == side(y) ? y : x);
__rotate_up(x);
__rotate_up}
}
inline void splay(int x) {
(x);
__splay(x);
pushup}
void access(int x) {
int ori_x = x;
for (int w = 0; x; w = x, x = fa[x]) {
(x);
__splay(x) = w;
rs(x);
pushup}
(ori_x);
splay}
inline void link_new(int rt, int x) {
// If simply fa[x] = y, the complexity might be wrong.
(rt);
access(x);
access[x] = rt;
fa(rt) = x;
ls(rt); // Might be unnecessary
pushup}
};
int main() {
static LCT lct;
int n, q;
static int fa[MAXN * 3], refcnt[MAXN], sta[MAXN];
static bool val[MAXN * 3];
int top = 0;
("%d", &n);
scanffor (int i = 1; i <= n; ++i) {
int x1, x2, x3;
("%d%d%d", &x1, &x2, &x3);
scanf[x1] = fa[x2] = fa[x3] = i;
fa[i] = 3;
refcnt}
for (int i = n + 1; i <= 3 * n + 1; ++i) {
int v;
("%d", &v);
scanf[i] = v;
valif (v)
++lct.val[fa[i]];
if (--refcnt[fa[i]] == 0) {
[++top] = fa[i];
sta}
}
while (top) {
int i = sta[top--];
if (lct.val[i] > 1) {
++lct.val[fa[i]];
}
if (fa[i])
.link_new(i, fa[i]);
lctif (--refcnt[fa[i]] == 0) {
[++top] = fa[i];
sta}
}
int rt = sta[1];
bool rt_val = (lct.val[rt] > 1);
("%d", &q);
scanfwhile (q--) {
int i;
("%d", &i);
scanfint x = fa[i];
.access(x);
lctint add = (val[i] == 0 ? 1 : -1);
[i] ^= 1;
valint y = (add > 0 ? lct.not_1 : lct.not_2)[x];
if (y == 0) {
.push_add(x, add);
lct^= 1;
rt_val } else {
.splay(y);
lct.val[y] += add;
lct.push_add(lct.rs(y), add);
lct.pushup(y);
lct}
("%d\n", rt_val);
printf}
return 0;
}