给定一棵 \(n\) 个点的无根树,每条边都有标记0或1。现在有 \(m\) 次操作,每次操作将第 \(id\) 条边的标记异或1。在每次操作后输出当前树上最长的有偶数个1的路径长度。
\(n,m \leq 5 \times 10^5\)。
动态DP
设 \(dp[x][0/1]\) 表示 \(x\) 向下有偶数/奇数个1的最长路径,\(dp[x][2]\) 表示 \(x\) 子树内偶数个1的直径。
容易写出转移(\(val_y\) 表示 \(y\) 的父边的权值): \[ \begin{aligned} dp'[x][0] &= \max\{dp[x][0], dp[y][val_y] + 1\}\\ dp'[x][1] &= \max\{dp[x][1], dp[y][val_y \oplus 1] + 1\}\\ dp'[x][2] &= \max\{dp[x][0] + dp[y][val_y] + 1, dp[x][1] + dp[y][val_y \oplus 1] + 1, dp[x][2],dp[y][2]\} \end{aligned} \] 改写成 \(x\) 从重儿子 \(y\) 转移的矩阵(这里以 \(val_y = 0\) 为例): \[ \begin{bmatrix} dp'[x][0]\\ dp'[x][1]\\ dp'[x][2]\\ 0 \end{bmatrix} = \begin{bmatrix} 1 & -\infty & -\infty & dp[x][0]\\ -\infty & 1 & -\infty & dp[x][1]\\ dp[x][0] + 1 & dp[x][1] + 1 & 0 & dp[x][2]\\ -\infty & -\infty & -\infty & 0 \end{bmatrix} \begin{bmatrix} dp[y][0]\\ dp[y][1]\\ dp[y][2]\\ 0 \end{bmatrix} \] 上面的 \(dp[x]\) 是指排除掉 \(x\) 重儿子的答案,\(dp'[x]\) 是指完整的 \(x\) 的答案。
\(dp[x]\) 可以直接用 `std::multiset` 维护：对每个节点，分别存下它所有虚儿子（轻儿子）\(y\) 贡献的 \(dp[y][val_y] + 1\)、\(dp[y][val_y \oplus 1] + 1\) 和 \(dp[y][2]\)，取各自的最大值和次大值即可快速求出。
复杂度是 \(O(4^3 m \log n + m \log^2n)\),能跑过 \(5 \times 10^5\) 也是奇迹。
细节：在 `multiset` 中删除 \(-\infty\) 时不能直接查找对应的值，因为插入时的 \(-\infty\) 与之后由矩阵乘积重新算出的 \(-\infty\)（已经加上了有限的数）未必相等；应直接删除 `multiset` 中最小的那个值（具体见函数 `void erase(std::multiset<int>&, int)`）。
#include <bits/stdc++.h>
#ifdef LOCAL
#define dbg(args...) std::cerr << "\033[32;1m" << #args << " -> ", err(args)
#else
#define dbg(...)
#endif
// Finish a debug line: reset the terminal colour and emit a newline.
inline void err() { std::cerr << "\033[0m\n"; }
// Print each argument to stderr followed by a single space, then finish the line.
template <class T, class... U>
inline void err(const T &x, const U &... a) {
  std::cerr << x << ' ';
  ((std::cerr << a << ' '), ...);
  err();
}
// Read one (optionally negative) integer from stdin, skipping any non-digit
// characters before it; a '-' immediately preceding the digits makes it negative.
// If end of input is reached before any digit, w is set to 0.
template <class T>
inline void readInt(T &w) {
  int c;  // int, not char: must hold EOF and be a valid argument for isdigit
  bool neg = false;
  while (!isdigit(c = getchar())) {
    if (c == EOF) {  // a char-typed c would loop forever here
      w = 0;
      return;
    }
    neg = c == '-';
  }
  for (w = c & 15; isdigit(c = getchar());) w = w * 10 + (c & 15);
  if (neg) w = -w;
}
// Read several integers in order.
template <class T, class... U>
inline void readInt(T &w, U &... a) { readInt(w), readInt(a...); }
// Lower x to y when y is smaller; returns whether x changed.
template <class T, class U>
inline bool smin(T &x, const U &y) {
  if (!(y < x)) return false;
  x = y;
  return true;
}
// Raise x to y when y is larger; returns whether x changed.
template <class T, class U>
inline bool smax(T &x, const U &y) {
  if (!(x < y)) return false;
  x = y;
  return true;
}
// Short aliases for frequently used types.
using LL = long long;
using PII = std::pair<int, int>;
using Matrix = std::array<std::array<int, 4>, 4>;
// Max-plus product: r[i][j] = max over k of (a[i][k] + b[k][j]).
// Callers use about -INF (INF = 1e9) for impossible transitions. Assuming every
// entry stays >= about -1e9, the sum of two such entries (about -2e9) still fits in int.
Matrix operator*(const Matrix &a, const Matrix &b) {
  Matrix r;
  for (int i = 0; i < 4; i++)
    for (int j = 0; j < 4; j++) {
      int best = a[i][0] + b[0][j];
      for (int k = 1; k < 4; k++) best = std::max(best, a[i][k] + b[k][j]);
      r[i][j] = best;
    }
  return r;
}
constexpr int N(5e5 + 5), INF(1e9);
// val[y]: 0/1 weight of the edge between y and fa[y];
// fa / son / siz: parent, heavy child (0 if none) and subtree size.
int n, m, val[N], fa[N], son[N], siz[N];
std::vector<int> g[N];
// Fill fa[], siz[] and son[] for the subtree rooted at x.
// NOTE(review): recursion depth equals the tree height (up to n), so a
// path-like tree may need an enlarged stack.
void dfs1(int x) {
  siz[x] = 1;
  for (int y : g[x]) {
    if (y == fa[x]) continue;
    fa[y] = x;
    dfs1(y);
    siz[x] += siz[y];
    if (siz[y] > siz[son[x]]) son[x] = y;
  }
}
struct Node {
*ls, *rs, *fa;
Node , sum;
Matrix valinline void pushup() {
= ls ? ls->sum * val : val;
sum if (rs) sum = sum * rs->sum;
}
} t[N];
// vir[x][0] / vir[x][1] / vir[x][2]: for every light child y of x, the values
// f[y][val[y]] + 1, f[y][!val[y]] + 1 and f[y][2]; dfs2 also inserts the
// sentinels 0, -INF and 0 for the empty path at x.
std::multiset<int> vir[N][3];
// Rebuild t[x].val: the max-plus matrix mapping the heavy child's column
// (dp0, dp1, dp2, 0) to x's column, using the light-child data in vir[x].
void updateVal(int x) {
  auto &v = t[x].val;
  auto p0 = vir[x][0].rbegin(), p1 = vir[x][1].rbegin();
  // k: weight of the edge to the heavy child (val[0] is never assigned, so 0 for leaves).
  int k = val[son[x]], m0 = *p0, m1 = *p1;
  v[0][k] = 1, v[0][!k] = v[0][2] = -INF, v[0][3] = m0;
  v[1][!k] = 1, v[1][k] = v[1][2] = -INF, v[1][3] = m1;
  v[2][k] = m0 + 1, v[2][!k] = m1 + 1, v[2][2] = 0;
  // Best even path through x using two light branches: even+even or odd+odd.
  v[2][3] = vir[x][0].size() > 1 ? std::max(m0 + *++p0, m1 + *++p1) : 0;
  smax(v[2][3], *vir[x][2].rbegin());
  v[3][0] = v[3][1] = v[3][2] = -INF, v[3][3] = 0;
}
int f[N][3];
void dfs2(int x) {
[x][0] = f[x][2] = 0, f[x][1] = -INF;
f[x][0].insert(0), vir[x][1].insert(-INF), vir[x][2].insert(0);
virfor (int y : g[x]) {
if (y == fa[x]) continue;
(y);
dfs2(f[x][2], f[y][2]);
smax(f[x][2], f[x][0] + f[y][val[y]] + 1);
smax(f[x][2], f[x][1] + f[y][!val[y]] + 1);
smax(f[x][0], f[y][val[y]] + 1);
smax(f[x][1], f[y][!val[y]] + 1);
smaxif (y == son[x]) continue;
[x][0].insert(f[y][val[y]] + 1);
vir[x][1].insert(f[y][!val[y]] + 1);
vir[x][2].insert(f[y][2]);
vir}
(x);
updateVal}
int top[N], s[N], sum[N];
* build(int l, int r) {
Nodeif (l == r) return t[s[l]].pushup(), t + s[l];
int m = l, tot = sum[r] + sum[l - 1];
while (m < r && sum[m] << 1 < tot) m++;
*o = t + s[m];
Node if (l < m) o->ls = build(l, m - 1), o->ls->fa = o;
if (m < r) o->rs = build(m + 1, r), o->rs->fa = o;
->pushup();
oreturn o;
}
*root;
Node void dfs3(int x, int tp) {
[x] = tp;
topif (x == tp) {
int m = 0;
for (int i = x; i; i = son[i]) {
[++m] = i;
s[m] = sum[m - 1] + siz[i] - siz[son[i]];
sum}
if (x > 1)
(1, m)->fa = t + fa[x];
buildelse
= build(1, m);
root }
if (!son[x]) return;
(son[x], tp);
dfs3for (int y : g[x]) {
if (y == fa[x] || y == son[x]) continue;
(y, y);
dfs3}
}
// Remove one occurrence of x from s.
// A negative x stands for "-infinity". The value recomputed from matrix
// products need not equal the one originally inserted, so the smallest element
// is removed instead of searching for x.
inline void erase(std::multiset<int> &s, int x) {
  if (x < 0) {
    assert(!s.empty() && *s.begin() < 0);
    s.erase(s.begin());
  } else {
    auto it = s.find(x);
    assert(it != s.end());  // erasing end() would be undefined behaviour
    s.erase(it);
  }
}
void work(int x, int y) {
assert(fa[y] == x);
*o = t + x;
Node int &v = val[y];
if (son[x] == y) {
^= 1;
v } else {
*p = t + y;
Node while (p->fa != o) p = p->fa;
(vir[x][0], p->sum[v][3] + 1);
erase(vir[x][1], p->sum[!v][3] + 1);
erase^= 1;
v [x][0].insert(p->sum[v][3] + 1);
vir[x][1].insert(p->sum[!v][3] + 1);
vir}
for (updateVal(x); o->fa; o = o->fa) {
if (o->fa->ls != o && o->fa->rs != o) {
= o->fa - t, y = val[top[o - t]];
x (vir[x][0], o->sum[y][3] + 1);
erase(vir[x][1], o->sum[!y][3] + 1);
erase(vir[x][2], o->sum[2][3]);
erase->pushup();
o[x][0].insert(o->sum[y][3] + 1);
vir[x][1].insert(o->sum[!y][3] + 1);
vir[x][2].insert(o->sum[2][3]);
vir(x);
updateVal} else {
->pushup();
o}
}
->pushup();
o}
struct Edge {
int x, y, z;
} e[N];
int main() {
(n);
readIntfor (int i = 1; i < n; i++) {
auto &[x, y, z] = e[i];
(x, y, z);
readInt[x].push_back(y), g[y].push_back(x);
g}
(1);
dfs1for (int i = 1; i < n; i++) {
auto &[x, y, z] = e[i];
if (fa[y] != x) std::swap(x, y);
assert(fa[y] == x);
[y] = z;
val}
(1);
dfs2(1, 1);
dfs3(m);
readIntwhile (m--) {
int i; readInt(i);
(e[i].x, e[i].y);
work("%d\n", root->sum[2][3]);
printf}
return 0;
}
Top Tree
旧的Top Tree 写法
#include <bits/stdc++.h>
// Debug print: green printf-style message on stderr, then reset the colour.
#define dbg(...) (std::cerr << "\033[32;1m", fprintf(stderr, __VA_ARGS__), std::cerr << "\033[0m")
// Lower x to y when y is smaller; returns whether x changed.
template <class T, class U>
inline bool smin(T &x, const U &y) {
  if (!(y < x)) return false;
  x = y;
  return true;
}
// Raise x to y when y is larger; returns whether x changed.
template <class T, class U>
inline bool smax(T &x, const U &y) {
  if (!(x < y)) return false;
  x = y;
  return true;
}
typedef long long LL;
typedef std::pair<int, int> PII;
// Capacity of the static splay stack (n <= 5e5 plus slack).
constexpr int N = 500005;
// Summary of a top-tree cluster.
//   c       : parity of the number of 1-edges on the cluster path
//   len     : number of edges on the cluster path
//   d[0][p] : best path with parity p starting at the top endpoint
//   d[1][p] : same, starting at the bottom endpoint
//   r[p]    : longest path with parity p anywhere in the cluster
// -1e9 means "no such path". A default-constructed Info is a single vertex.
// NOTE(review): a full top-to-bottom traversal enters the d values only through
// len/c when clusters are combined.
struct Info {
  int c, len, d[2][2], r[2];
  Info() : c(0), len(0) {
    d[0][0] = 0, d[0][1] = -1e9;
    d[1][0] = 0, d[1][1] = -1e9;
    r[0] = r[1] = -1e9;
  }
  // Swap the roles of the two endpoints (used when a path is reversed).
  void reverse() {
    std::swap(d[0][0], d[1][0]);
    std::swap(d[0][1], d[1][1]);
  }
  // *this = a followed by b along the path (a on top). The arguments are taken
  // by value so either one may alias *this.
  Info &compress(Info a, Info b) {
    c = a.c ^ b.c;
    len = a.len + b.len;
    d[0][0] = std::max(a.d[0][0], a.len + b.d[0][a.c]);
    d[0][1] = std::max(a.d[0][1], a.len + b.d[0][!a.c]);
    d[1][0] = std::max(b.d[1][0], b.len + a.d[1][b.c]);
    d[1][1] = std::max(b.d[1][1], b.len + a.d[1][!b.c]);
    // Join a's bottom-side path with b's top-side path; equal parities give even.
    r[0] = std::max(
        {a.r[0], b.r[0], a.d[1][0] + b.d[0][0], a.d[1][1] + b.d[0][1]});
    r[1] = std::max(
        {a.r[1], b.r[1], a.d[1][0] + b.d[0][1], a.d[1][1] + b.d[0][0]});
    return *this;
  }
  // *this = cluster b with cluster a hanging from b's top endpoint.
  // The arguments are taken by value so either one may alias *this.
  Info &rake(Info a, Info b) {
    c = b.c;
    len = b.len;
    d[0][0] = std::max(a.d[0][0], b.d[0][0]);
    d[0][1] = std::max(a.d[0][1], b.d[0][1]);
    d[1][0] = std::max(b.d[1][0], b.len + a.d[0][b.c]);
    d[1][1] = std::max(b.d[1][1], b.len + a.d[0][!b.c]);
    r[0] = std::max(
        {a.r[0], b.r[0], a.d[0][0] + b.d[0][0], a.d[0][1] + b.d[0][1]});
    r[1] = std::max(
        {a.r[1], b.r[1], a.d[0][0] + b.d[0][1], a.d[0][1] + b.d[0][0]});
    return *this;
  }
};
struct SplayTree {
static SplayTree *null;
*ch[3], *fa;
SplayTree bool rev;
;
Info info() : rev(false), info() {
SplayTreestatic bool init = true;
if (init) {
= false;
init = new SplayTree;
null ->ch[0] = null->ch[1] = null->ch[2] = null->fa = null;
null}
[0] = ch[1] = ch[2] = fa = null;
ch}
bool notRoot() const { return fa->ch[0] == this || fa->ch[1] == this; }
void reverse() {
^= 1;
rev std::swap(ch[0], ch[1]);
.reverse();
info}
bool dir() const { return fa->ch[1] == this; }
void sch(int d, SplayTree *c) { ch[d] = c, c->fa = this; }
void rotate() {
*p = fa;
SplayTree bool d = dir();
= p->fa;
fa if (fa != null) {
->ch[fa->ch[2] == p ? 2 : p->dir()] = this;
fa}
->sch(d, ch[!d]), sch(!d, p);
p->pushup();
p}
void splay() {
static SplayTree *s[N], **t;
= s;
t for (auto x = this; x->notRoot(); x = x->fa) *++t = x->fa;
while (t != s) (*t--)->pushdown();
for (pushdown(); notRoot(); rotate()) {
if (fa->notRoot()) {
(fa->dir() == dir() ? fa : this)->rotate();
}
}
();
pushup}
virtual void pushup() {}
virtual void pushdown() {}
virtual ~SplayTree() {}
};
*SplayTree::null;
SplayTree
struct RakeTree : SplayTree {
void pushup() {
.rake(ch[0]->info, Info().rake(ch[2]->info, ch[1]->info));
info}
};
struct CompressTree : SplayTree {
void pushup() {
.compress(ch[0]->info, Info().rake(ch[2]->info, ch[1]->info));
info// dbg("vt pushup!\n");
}
void pushdown() {
if (rev) {
= false;
rev [0]->reverse(), ch[1]->reverse();
ch}
}
void access() {
();
splayif (ch[1] != null) {
auto r = new RakeTree;
->sch(0, ch[2]), r->sch(2, ch[1]), r->pushup();
r[1] = null, sch(2, r), pushup();
ch}
for (; fa != null; rotate()) {
->splay();
fa*m = fa, *p = m->fa;
SplayTree assert(p != null);
->splay();
pif (p->ch[1] != null) {
->sch(2, p->ch[1]);
m->pushup();
m} else {
if (m->ch[0] == null) {
->sch(2, m->ch[1]);
p} else if (m->ch[1] == null) {
->sch(2, m->ch[0]);
p} else {
auto x = m->ch[0];
->fa = null;
xwhile (x->pushdown(), x->ch[1] != null) x = x->ch[1];
->splay();
x->sch(2, x);
p->sch(1, m->ch[1]);
x->pushup();
x}
delete m;
}
->sch(1, this);
p->pushup();
p}
();
pushup}
void evert() { access(), reverse(), pushdown(); }
};
// Add the tree edge x - y with parity c.
// After x->access() x has no lower path (x->ch[1] == null); after y->evert()
// y has no upper path (y->ch[0] == null). The new edge node becomes y's left
// child, so it stays a leaf; it must, because its pushup() is a no-op.
void link(CompressTree *x, CompressTree *y, int c) {
  x->access(), y->evert();
  auto e = new SplayTree;
  e->info.len = 1, e->info.c = c;
  x->sch(1, y), y->sch(0, e);
  y->pushup(), x->pushup();
}
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int n, m;
  std::cin >> n;
  // Vertices are 1-indexed; t[0] is unused but also creates SplayTree::null.
  CompressTree *t = new CompressTree[n + 1];
  assert(SplayTree::null);
  std::vector<int> u(n - 1), v(n - 1);
  for (int i = 0, c; i < n - 1; i++) {
    std::cin >> u[i] >> v[i] >> c;
    link(t + u[i], t + v[i], c);
  }
  std::cin >> m;
  while (m--) {
    int i;
    std::cin >> i;
    --i;
    // Root the tree at u[i] and expose the path to v[i]. Edge nodes stay leaves,
    // so the edge between them is expected at t[u[i]].ch[1].
    t[u[i]].evert(), t[v[i]].access();
    t[u[i]].ch[1]->info.c ^= 1;
    t[u[i]].pushup();
    t[v[i]].pushup();
    std::cout << t[v[i]].info.r[0] << '\n';
  }
  return 0;
}
其他解法
比如线段树、括号序列,留坑。
最后修改于 2021-08-13