[Codeforces1413F]Roads and Ramen

给定一棵 \(n\) 个点的无根树,每条边都有标记 0 或 1。现在有 \(m\) 次操作,每次操作将第 \(id\) 条边的标记异或 1。每次操作后,输出当前树上"标记为 1 的边数为偶数"的最长路径的长度(以边数计)。

\(n,m \leq 5 \times 10^5\)

动态DP

\(dp[x][0/1]\) 表示从 \(x\) 出发向下、含偶数/奇数个 1 的最长路径长度(不存在时为 \(-\infty\)),\(dp[x][2]\) 表示 \(x\) 子树内含偶数个 1 的最长路径长度(即限定偶数个 1 的"直径")。

容易写出转移(\(val_y\) 表示 \(y\) 的父边的权值): \[ \begin{aligned} dp'[x][0] &= \max\{dp[x][0], dp[y][val_y] + 1\}\\ dp'[x][1] &= \max\{dp[x][1], dp[y][val_y \oplus 1] + 1\}\\ dp'[x][2] &= \max\{dp[x][0] + dp[y][val_y] + 1, dp[x][1] + dp[y][val_y \oplus 1] + 1, dp[x][2],dp[y][2]\} \end{aligned} \] 改写成 \(x\) 从重儿子 \(y\) 转移的矩阵(这里以 \(val_y = 0\) 为例): \[ \begin{bmatrix} dp'[x][0]\\ dp'[x][1]\\ dp'[x][2]\\ 0 \end{bmatrix} = \begin{bmatrix} 1 & -\infty & -\infty & dp[x][0]\\ -\infty & 1 & -\infty & dp[x][1]\\ dp[x][0] + 1 & dp[x][1] + 1 & 0 & dp[x][2]\\ -\infty & -\infty & -\infty & 0 \end{bmatrix} \begin{bmatrix} dp[y][0]\\ dp[y][1]\\ dp[y][2]\\ 0 \end{bmatrix} \] 上面的 \(dp[x]\) 是指排除掉 \(x\) 重儿子的答案,\(dp'[x]\) 是指完整的 \(x\) 的答案。

\(dp[x]\) 考虑直接用 std::multiset 存一下每个节点虚儿子的 \(dp[y][0], dp[y][1], dp[y][2]\),取每个的最大和次大值即可快速求出。

复杂度是 \(O(4^3 m \log n + m \log^2n)\),能跑过 \(5 \times 10^5\) 也是奇迹。

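细节:在 multiset 中删除 \(-\infty\) 时不能直接查找对应的值——\(-\infty\) 参与矩阵运算后会变成 \(-\infty + k\) 之类的数,与当初插入的值不一定相等,find 可能找不到。此时应直接删除 multiset 中最小的那个值:所有负值都代表 \(-\infty\),删哪一个都不影响最大值和元素个数(具体见函数 void erase(std::multiset<int>&, int))。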

#include <bits/stdc++.h>
#ifndef LOCAL
#define dbg(...)
#else
#define dbg(args...) std::cerr << "\033[32;1m" << #args << " -> ", err(args)
#endif
// Ends a debug line: resets the terminal colour and writes a newline to stderr.
inline void err() { std::cerr << "\033[0m\n"; }
// Writes every argument to stderr, each followed by a single space, then ends
// the line through err().
template <class T, class... U>
inline void err(const T &x, const U &... a) {
  std::cerr << x << ' ';
  ((std::cerr << a << ' '), ...);
  err();
}
// Reads the next integer from stdin into w.
// Non-digit characters before the number are skipped; the number is negative
// only if the character immediately before its first digit is '-'.
// If EOF is reached before any digit appears, w is left unchanged (previously
// this case looped forever).
template <class T>
inline void readInt(T &w) {
  // int, not char: getchar() may return EOF, and std::isdigit requires a value
  // representable as unsigned char (or EOF). A negative char is undefined behaviour.
  int c;
  bool neg = false;
  while (!std::isdigit(c = std::getchar())) {
    if (c == EOF) return;
    neg = c == '-';
  }
  for (w = c & 15; std::isdigit(c = std::getchar());) w = w * 10 + (c & 15);
  if (neg) w = -w;
}
// Reads several integers in order, one per argument.
template <class T, class... U>
inline void readInt(T &w, U &... a) { readInt(w), readInt(a...); }
// Lowers x to y when y is strictly smaller; returns whether x changed.
template <class T, class U>
inline bool smin(T &x, const U &y) {
  if (!(y < x)) return false;
  x = y;
  return true;
}
// Raises x to y when y is strictly larger; returns whether x changed.
template <class T, class U>
inline bool smax(T &x, const U &y) {
  if (!(x < y)) return false;
  x = y;
  return true;
}

using LL = long long;
using PII = std::pair<int, int>;

// 4x4 matrix over the (max, +) semiring: "addition" is max, "multiplication" is +.
using Matrix = std::array<std::array<int, 4>, 4>;
// Max-plus product: r[i][j] = max over k of (a[i][k] + b[k][j]).
Matrix operator*(const Matrix &a, const Matrix &b) {
  Matrix r;
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 4; ++j) {
      int best = a[i][0] + b[0][j];
      for (int k = 1; k < 4; ++k) best = std::max(best, a[i][k] + b[k][j]);
      r[i][j] = best;
    }
  }
  return r;
}

constexpr int N = 500005;        // array capacity: n <= 5e5 plus slack for 1-based indices
constexpr int INF = 1000000000;  // -INF stands for "no such path" in the max-plus DP

int n, m;                   // vertex count, query count
int val[N];                 // val[y]: 0/1 label of the edge between fa[y] and y
int fa[N], son[N], siz[N];  // parent, heavy child (0 if none), subtree size

std::vector<int> g[N];  // adjacency lists

// Fills fa[], siz[] and son[] (heavy child) for the subtree rooted at x.
// Among equally large children, the first one in g[x] stays the heavy child.
void dfs1(int x) {
  siz[x] = 1;
  for (const int child : g[x]) {
    if (child != fa[x]) {
      fa[child] = x;
      dfs1(child);
      siz[x] += siz[child];
      if (siz[son[x]] < siz[child]) son[x] = child;
    }
  }
}

struct Node {
  Node *ls, *rs, *fa;
  Matrix val, sum;
  inline void pushup() {
    sum = ls ? ls->sum * val : val;
    if (rs) sum = sum * rs->sum;
  }
} t[N];
std::multiset<int> vir[N][3];
void updateVal(int x) {
  auto &v = t[x].val;
  auto p0 = vir[x][0].rbegin(), p1 = vir[x][1].rbegin();
  int k = val[son[x]], m0 = *p0, m1 = *p1;
  v[0][k] = 1, v[0][!k] = v[0][2] = -INF, v[0][3] = m0;
  v[1][!k] = 1, v[1][k] = v[1][2] = -INF, v[1][3] = m1;
  v[2][k] = m0 + 1, v[2][!k] = m1 + 1, v[2][2] = 0;
  v[2][3] = vir[x][0].size() > 1 ? std::max(m0 + *++p0, m1 + *++p1) : 0;
  smax(v[2][3], *vir[x][2].rbegin());
  v[3][0] = v[3][1] = v[3][2] = -INF, v[3][3] = 0;
}
int f[N][3];
void dfs2(int x) {
  f[x][0] = f[x][2] = 0, f[x][1] = -INF;
  vir[x][0].insert(0), vir[x][1].insert(-INF), vir[x][2].insert(0);
  for (int y : g[x]) {
    if (y == fa[x]) continue;
    dfs2(y);
    smax(f[x][2], f[y][2]);
    smax(f[x][2], f[x][0] + f[y][val[y]] + 1);
    smax(f[x][2], f[x][1] + f[y][!val[y]] + 1);
    smax(f[x][0], f[y][val[y]] + 1);
    smax(f[x][1], f[y][!val[y]] + 1);
    if (y == son[x]) continue;
    vir[x][0].insert(f[y][val[y]] + 1);
    vir[x][1].insert(f[y][!val[y]] + 1);
    vir[x][2].insert(f[y][2]);
  }
  updateVal(x);
}
int top[N];  // top[x]: first vertex of x's heavy chain
int s[N];    // scratch for dfs3/build: vertices of the current chain, top first, 1-based
int sum[N];  // scratch: prefix sums of light sizes (siz[v] - siz[son[v]]) along s[]
// Builds a BST over the chain segment s[l..r], rooting it at the weighted
// midpoint of the light sizes so that the tree stays globally balanced.
// Returns the root node; every node's sum is recomputed bottom-up.
Node* build(int l, int r) {
  if (l == r) {
    Node *leaf = t + s[l];
    leaf->pushup();
    return leaf;
  }
  const int total = sum[l - 1] + sum[r];
  int mid = l;
  while (mid < r && 2 * sum[mid] < total) ++mid;
  Node *cur = t + s[mid];
  if (l < mid) {
    cur->ls = build(l, mid - 1);
    cur->ls->fa = cur;
  }
  if (mid < r) {
    cur->rs = build(mid + 1, r);
    cur->rs->fa = cur;
  }
  cur->pushup();
  return cur;
}
Node *root;
void dfs3(int x, int tp) {
  top[x] = tp;
  if (x == tp) {
    int m = 0;
    for (int i = x; i; i = son[i]) {
      s[++m] = i;
      sum[m] = sum[m - 1] + siz[i] - siz[son[i]];
    }
    if (x > 1)
      build(1, m)->fa = t + fa[x];
    else
      root = build(1, m);
  }
  if (!son[x]) return;
  dfs3(son[x], tp);
  for (int y : g[x]) {
    if (y == fa[x] || y == son[x]) continue;
    dfs3(y, y);
  }
}
// Removes one occurrence of x from bag.
// Negative values in the vir sets stand for "-INF plus a small offset". The
// value recomputed at erase time may not equal the one that was inserted, so
// for x < 0 the smallest element (which must itself be negative) is removed.
inline void erase(std::multiset<int> &bag, int x) {
  std::multiset<int>::iterator victim;
  if (x >= 0) {
    victim = bag.find(x);
  } else {
    victim = bag.begin();
    assert(*victim < 0);
  }
  bag.erase(victim);
}
// Flips the 0/1 label of the tree edge (x, y), where x is y's parent, and
// refreshes every transition matrix and vir[] entry that depends on it,
// walking from x's node up to the root of the whole BST-of-chains structure.
void work(int x, int y) {
  assert(fa[y] == x);
  Node *o = t + x;
  int &v = val[y];
  if (son[x] == y) {
    // Heavy edge: its label is only read by updateVal(x) (as val[son[x]]),
    // which is called right below, so flipping it is enough here.
    v ^= 1;
  } else {
    // Light edge: y is the top of its own chain. Climb to that chain's BST
    // root (the only node of the chain whose fa points at x's node); its
    // sum[*][3] entries are the chain top's DP values.
    Node *p = t + y;
    while (p->fa != o) p = p->fa;
    // Replace y's contribution in x's even/odd sets: which of y's values counts
    // as "even from x" depends on the edge label. vir[x][2] stores y's subtree
    // best path, which does not involve this edge, so it is left untouched.
    erase(vir[x][0], p->sum[v][3] + 1);
    erase(vir[x][1], p->sum[!v][3] + 1);
    v ^= 1;
    vir[x][0].insert(p->sum[v][3] + 1);
    vir[x][1].insert(p->sum[!v][3] + 1);
  }
  // Rebuild x's matrix, then recompute products on the way up.
  for (updateVal(x); o->fa; o = o->fa) {
    if (o->fa->ls != o && o->fa->rs != o) {
      // o is the BST root of a chain, and o->fa is the tree parent of that
      // chain's top, so this step crosses a light edge. Note y is reused as the
      // label of that light edge. o->sum still holds the OLD product (o is
      // pushed up only after the erases), which is exactly what was inserted
      // into the parent's sets earlier.
      x = o->fa - t, y = val[top[o - t]];
      erase(vir[x][0], o->sum[y][3] + 1);
      erase(vir[x][1], o->sum[!y][3] + 1);
      erase(vir[x][2], o->sum[2][3]);
      o->pushup();
      vir[x][0].insert(o->sum[y][3] + 1);
      vir[x][1].insert(o->sum[!y][3] + 1);
      vir[x][2].insert(o->sum[2][3]);
      updateVal(x);
    } else {
      // Ordinary BST edge inside a single chain: just recompute the product.
      o->pushup();
    }
  }
  // The loop stops at the node with no fa, i.e. the BST root of vertex 1's
  // chain, which has not been pushed up yet.
  o->pushup();
}
// Input edge: endpoints x, y and 0/1 label z. main() reorders the endpoints
// so that x is the parent of y.
struct Edge {
  int x;
  int y;
  int z;
};
Edge e[N];
int main() {
  readInt(n);
  for (int i = 1; i < n; i++) {
    auto &[x, y, z] = e[i];
    readInt(x, y, z);
    g[x].push_back(y), g[y].push_back(x);
  }
  dfs1(1);
  for (int i = 1; i < n; i++) {
    auto &[x, y, z] = e[i];
    if (fa[y] != x) std::swap(x, y);
    assert(fa[y] == x);
    val[y] = z;
  }
  
  dfs2(1);
  dfs3(1, 1);
  readInt(m);
  while (m--) {
    int i; readInt(i);
    work(e[i].x, e[i].y);
    printf("%d\n", root->sum[2][3]);
  }
  return 0;
}

Top Tree

旧的Top Tree 写法

#include <bits/stdc++.h>
#define dbg(...)                                           \
  std::cerr << "\033[32;1m", fprintf(stderr, __VA_ARGS__), \
      std::cerr << "\033[0m"
// Lowers x to y when y is strictly smaller; returns whether x changed.
template <class T, class U>
inline bool smin(T &x, const U &y) {
  if (!(y < x)) return false;
  x = y;
  return true;
}
// Raises x to y when y is strictly larger; returns whether x changed.
template <class T, class U>
inline bool smax(T &x, const U &y) {
  if (!(x < y)) return false;
  x = y;
  return true;
}
using LL = long long;
using PII = std::pair<int, int>;

constexpr int N = 500005;

// Summary of a top-tree cluster; "parity" is the number of 1-labelled edges mod 2.
//   c        parity of the cluster path between its two boundary vertices
//   len      number of edges on that cluster path
//   d[e][p]  longest distance from boundary e (0 = top, 1 = bottom) into the
//            cluster along a path of parity p (-1e9 when none)
//   r[p]     longest path inside the cluster with parity p (-1e9 when none)
// A default Info is a single vertex: zero length, even distance 0 from both ends.
struct Info {
  int c, len, d[2][2], r[2];
  Info() : c(0), len(0) {
    for (int e = 0; e < 2; ++e) {
      d[e][0] = 0;
      d[e][1] = -1e9;
    }
    r[0] = -1e9;
    r[1] = -1e9;
  }
  // Swaps the roles of the two boundary vertices.
  void reverse() {
    for (int p = 0; p < 2; ++p) std::swap(d[0][p], d[1][p]);
  }
  // Concatenates a (above) and b (below): a's bottom boundary is b's top.
  // Arguments are taken by value, so a or b may alias *this.
  Info &compress(Info a, Info b) {
    c = a.c ^ b.c;
    len = a.len + b.len;
    for (int p = 0; p < 2; ++p) {
      // Crossing the whole of one part flips the parity needed in the other part.
      d[0][p] = std::max(a.d[0][p], a.len + b.d[0][p ^ a.c]);
      d[1][p] = std::max(b.d[1][p], b.len + a.d[1][p ^ b.c]);
      r[p] = std::max({a.r[p], b.r[p], a.d[1][0] + b.d[0][p],
                       a.d[1][1] + b.d[0][p ^ 1]});
    }
    return *this;
  }
  // Hangs cluster a off the top boundary of b; the result keeps b's path.
  // Arguments are taken by value, so a or b may alias *this.
  Info &rake(Info a, Info b) {
    c = b.c;
    len = b.len;
    for (int p = 0; p < 2; ++p) {
      d[0][p] = std::max(a.d[0][p], b.d[0][p]);
      d[1][p] = std::max(b.d[1][p], b.len + a.d[0][p ^ b.c]);
      r[p] = std::max({a.r[p], b.r[p], a.d[0][0] + b.d[0][p],
                       a.d[0][1] + b.d[0][p ^ 1]});
    }
    return *this;
  }
};


// Base splay node shared by the compress trees and rake trees of the top tree.
//   ch[0], ch[1]  splay children
//   ch[2]         extra "middle" child: for a compress node, the root of the
//                 rake tree hanging at it; for a rake node, the cluster it carries
//   fa            parent (splay parent, or the node whose ch[2] this is)
//   rev           lazy flag: this node's children were already swapped, but the
//                 reversal has not been pushed into them yet
// A single sentinel `null` replaces nullptr. It is created by the first
// constructed node and its links all point back to itself.
struct SplayTree {
  static SplayTree *null;
  SplayTree *ch[3], *fa;
  bool rev;
  Info info;
  SplayTree() : rev(false), info() {
    static bool init = true;
    if (init) {
      // The nested `new SplayTree` re-enters this constructor with init == false
      // while null is still nullptr, so its links are patched right after.
      init = false;
      null = new SplayTree;
      null->ch[0] = null->ch[1] = null->ch[2] = null->fa = null;
    }
    ch[0] = ch[1] = ch[2] = fa = null;
  }
  // True iff this node is a splay child (ch[0]/ch[1]) of fa. Hanging from fa's
  // middle slot ch[2] counts as being a splay root.
  bool notRoot() const { return fa->ch[0] == this || fa->ch[1] == this; }
  // Reverses the subtree: swaps children now, defers the rest via rev.
  void reverse() {
    rev ^= 1;
    std::swap(ch[0], ch[1]);
    info.reverse();
  }
  // Which splay child of fa this node is (true = right).
  bool dir() const { return fa->ch[1] == this; }
  // Sets child slot d to c and makes this node c's parent.
  void sch(int d, SplayTree *c) { ch[d] = c, c->fa = this; }
  // Rotates this node above its parent. The grandparent's link (a splay slot or
  // the middle slot ch[2]) is redirected to this node. Only the old parent is
  // pushed up here; callers push this node up later.
  void rotate() {
    SplayTree *p = fa;
    bool d = dir();
    fa = p->fa;
    if (fa != null) {
      fa->ch[fa->ch[2] == p ? 2 : p->dir()] = this;
    }
    p->sch(d, ch[!d]), sch(!d, p);
    p->pushup();
  }
  // Splays this node to the root of its splay tree, pushing pending reversals
  // down from the top first. The ancestor buffer grows on demand: a fixed N-slot
  // stack could overflow, because a compress splay tree holds both vertex and
  // edge nodes (about 2n) and a splay tree's depth can be linear in its size.
  void splay() {
    static std::vector<SplayTree *> path;
    path.clear();
    for (auto x = this; x->notRoot(); x = x->fa) path.push_back(x->fa);
    for (auto it = path.rbegin(); it != path.rend(); ++it) (*it)->pushdown();
    for (pushdown(); notRoot(); rotate()) {
      if (fa->notRoot()) {
        (fa->dir() == dir() ? fa : this)->rotate();
      }
    }
    pushup();
  }
  // Recomputes info from the children. No-op for plain nodes, which link()
  // uses as edge leaves.
  virtual void pushup() {}
  // Propagates lazy tags. No-op unless overridden.
  virtual void pushdown() {}
  // Virtual because nodes are deleted through SplayTree* (see CompressTree::access).
  virtual ~SplayTree() {}
};
SplayTree *SplayTree::null;

struct RakeTree : SplayTree {
  void pushup() {
    info.rake(ch[0]->info, Info().rake(ch[2]->info, ch[1]->info));
  }
};

// Node of a compress tree (one per tree vertex; edge leaves are plain SplayTree
// nodes inserted by link()). In-order along ch[0] -> this -> ch[1] is the path;
// ch[2] is the root of the rake tree of clusters hanging at this vertex.
struct CompressTree : SplayTree {
  // info = compress(upper part ch[0], rake(hanging clusters ch[2], lower part ch[1])).
  void pushup() {
    info.compress(ch[0]->info, Info().rake(ch[2]->info, ch[1]->info));
    // dbg("vt pushup!\n");
  }
  // Pushes a pending path reversal into both splay children.
  void pushdown() {
    if (rev) {
      rev = false;
      ch[0]->reverse(), ch[1]->reverse();
    }
  }
  // Makes the path from the root cluster's top boundary down to this vertex the
  // root path and turns this node into the root of the whole top tree.
  void access() {
    splay();
    // Cut off the part of the path below this vertex: wrap it in a new rake node
    // (carrying it in ch[2], with the old rake tree as ch[0]) and hang that here.
    if (ch[1] != null) {
      auto r = new RakeTree;
      r->sch(0, ch[2]), r->sch(2, ch[1]), r->pushup();
      ch[1] = null, sch(2, r), pushup();
    }
    // While this splay root still hangs below a rake node m of some vertex p:
    for (; fa != null; rotate()) {
      fa->splay();
      SplayTree *m = fa, *p = m->fa;
      assert(p != null);
      p->splay();
      if (p->ch[1] != null) {
        // p's current lower path takes this cluster's place inside m.
        m->sch(2, p->ch[1]);
        m->pushup();
      } else {
        // Nothing to swap in: remove m from p's rake tree by joining its
        // two splay subtrees, then free it.
        if (m->ch[0] == null) {
          p->sch(2, m->ch[1]);
        } else if (m->ch[1] == null) {
          p->sch(2, m->ch[0]);
        } else {
          // Detach the left subtree so splay() stays inside it, splay its
          // rightmost node to its root, and attach the right subtree there.
          auto x = m->ch[0];
          x->fa = null;
          while (x->pushdown(), x->ch[1] != null) x = x->ch[1];
          x->splay();
          p->sch(2, x);
          x->sch(1, m->ch[1]);
          x->pushup();
        }
        delete m;
      }
      // This cluster becomes p's lower path; the loop's rotate() then lifts
      // this node above p.
      p->sch(1, this);
      p->pushup();
    }
    pushup();
  }
  // Makes this vertex an endpoint (top) of the root path by reversing it.
  void evert() { access(), reverse(), pushdown(); }
};
// Adds the tree edge (x, y) with label c. After x->access() and y->evert(), y
// becomes x's lower path, and a fresh edge leaf (len 1, parity c) is placed as
// y's left splay child, so the path reads x, edge, y.
void link(CompressTree *x, CompressTree *y, int c) {
  x->access();
  y->evert();
  SplayTree *edge = new SplayTree;
  edge->info.len = 1;
  edge->info.c = c;
  x->sch(1, y);
  y->sch(0, edge);
  y->pushup();
  x->pushup();
}

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int n, m;
  std::cin >> n;
  // One compress node per vertex (1-based). Constructing the first one also
  // creates SplayTree::null.
  CompressTree *t = new CompressTree[n + 1];
  assert(SplayTree::null);
  std::vector<int> u(n - 1), v(n - 1);
  for (int i = 0; i + 1 < n; ++i) {
    int c;
    std::cin >> u[i] >> v[i] >> c;
    link(&t[u[i]], &t[v[i]], c);
  }
  std::cin >> m;
  while (m--) {
    int id;
    std::cin >> id;
    const int a = u[id - 1];
    const int b = v[id - 1];
    // Root the tree at a and expose the a..b path, whose only edge node is then
    // expected at t[a].ch[1].
    // NOTE(review): that position relies on the splay shape after evert/access;
    // confirm it always holds.
    t[a].evert();
    t[b].access();
    t[a].ch[1]->info.c ^= 1;
    t[a].pushup();
    t[b].pushup();
    std::cout << t[b].info.r[0] << "\n";
  }
  return 0;
}

其他解法

比如线段树、括号序列,留坑。


最后修改于 2021-08-13