2021牛客多校6
题号 标题 团队的状态
A Contracting Convex Hull 未通过
B Defend Ponyville 通过
C Delete Edges 通过
D Gambling Monster 通过
E Growing Tree 通过
F Hamburger Steak 通过
G Hasse Diagram 通过
H Hopping Rabbit 通过
I Intervals on the Ring 通过
J Defend Your Country 通过
K Starch Cat 通过

Contracting Convex Hull

Defend Ponyville

// Author:  HolyK
// Created: Fri Aug 13 20:54:58 2021
#include <bits/stdc++.h>
/// Assigns y to x when y < x. Returns true iff x was changed.
template <class T, class U>
inline bool smin(T &x, const U &y) {
  const bool better = y < x;
  if (better) x = y;
  return better;
}
/// Assigns y to x when x < y. Returns true iff x was changed.
template <class T, class U>
inline bool smax(T &x, const U &y) {
  const bool better = x < y;
  if (better) x = y;
  return better;
}

using LL = long long;
using PII = std::pair<int, int>;
// Modulus used by every modular helper below.
constexpr int P = 998244353;
/// x <- x + y (mod P), assuming both operands are already in [0, P).
inline void inc(int &x, int y) {
  x = x + y >= P ? x + y - P : x + y;
}
/// x <- x - y (mod P), assuming both operands are already in [0, P).
inline void dec(int &x, int y) {
  x = x - y < 0 ? x - y + P : x - y;
}
/// Plain remainder x % P (keeps the sign of x).
inline int mod(LL x) {
  return static_cast<int>(x % P);
}
/// x^k mod P by square-and-multiply; the default exponent P - 2 yields the
/// modular inverse because P is prime.
int fpow(int x, int k = P - 2) {
  int result = 1;
  while (k) {
    if (k & 1) result = 1LL * result * x % P;
    x = 1LL * x * x % P;
    k >>= 1;
  }
  return result;
}
/// Integer modulo P. The constructor normalises by a single +/- P, so it
/// expects arguments in (-P, 2P); every operation keeps x in [0, P).
struct Z {
  int x;
  Z(int v = 0) : x(v) {
    if (x < 0) {
      x += P;
    } else if (x >= P) {
      x -= P;
    }
  }
  /// Multiplicative inverse (asserts the value is nonzero).
  Z inv() const {
    assert(x);
    return Z(fpow(x));
  }
  /// this^k.
  Z power(int k) const {
    return Z(fpow(x, k));
  }
  Z &operator+=(const Z &r) {
    x += r.x;
    if (x >= P) x -= P;
    return *this;
  }
  Z &operator-=(const Z &r) {
    x -= r.x;
    if (x < 0) x += P;
    return *this;
  }
  Z &operator*=(const Z &r) {
    x = static_cast<int>(1LL * x * r.x % P);
    return *this;
  }
  Z &operator/=(const Z &r) {
    x = static_cast<int>(1LL * x * fpow(r.x) % P);
    return *this;
  }
  Z operator+(const Z &r) const {
    Z out = *this;
    out += r;
    return out;
  }
  Z operator-(const Z &r) const {
    Z out = *this;
    out -= r;
    return out;
  }
  Z operator*(const Z &r) const {
    Z out = *this;
    out *= r;
    return out;
  }
  Z operator/(const Z &r) const {
    Z out = *this;
    out /= r;
    return out;
  }
  Z operator-() const {
    return Z(P - x);
  }
  operator int() const {
    return x;
  }
};

// using Z = ModInt<P>;

// Computes the polynomial det(xI + A) of the square matrix A modulo P.
// Returns its coefficients in increasing degree: result[k] is the
// coefficient of x^k, result has n + 1 entries, and result[n] == 1.
//
// Phase 1 reduces A (a by-value copy) to upper Hessenberg form using
// similarity transforms: a row swap is paired with the same column swap, and
// "row i -= c * row j+1" is paired with "column j+1 += c * column i". This
// leaves det(xI + A) unchanged.
// Phase 2 expands det(xI + H) with the usual O(n^3) Hessenberg recurrence:
// h[i] is the polynomial for the leading i x i block.
auto charPoly(std::vector<std::vector<Z>> a) {
  
  int n = a.size();
  for (int j = 0; j < n - 2; j++) {
    // Bring a row with a nonzero entry in column j (below the diagonal) to row j + 1.
    for (int i = j + 1; i < n; i++) {
      if (!a[i][j]) continue;
      std::swap(a[i], a[j + 1]);
      for (int k = 0; k < n; k++) {
        std::swap(a[k][i], a[k][j + 1]);
      }
      break;
    }
    // Zero out column j below the subdiagonal, then apply the inverse column op.
    if (a[j + 1][j]) {
      auto inv = a[j + 1][j].inv();
      for (int i = j + 2; i < n; i++) {
        auto c = a[i][j] * inv;
        for (int k = 0; k < n; k++) a[i][k] -= a[j + 1][k] * c;
        for (int k = 0; k < n; k++) a[k][j + 1] += a[k][i] * c;
      }
    }
  }
  std::vector<std::vector<Z>> h(n + 1);
  h[0] = {1};
  for (int i = 0; i < n; i++) {
    Z prod = 1;
    h[i + 1].resize(h[i].size() + 1);
    // Diagonal term: (x + a[i][i]) * h[i].
    for (int j = 0; j < h[i].size(); j++) {
      h[i + 1][j + 1] = h[i][j];
      h[i + 1][j] += h[i][j] * a[i][i];
    }
    // Terms that use a[i - j - 1][i] together with the chain of subdiagonal
    // entries below it.
    for (int j = 0; j < i; j++) {
      prod *= -a[i - j][i - j - 1];
      auto c = a[i - j - 1][i] * prod;
      for (int k = 0; k < h[i - j - 1].size(); k++) {
        h[i + 1][k] += h[i - j - 1][k] * c;
      }
    }
  }
  return h[n];
}

int main() {
  // freopen("06.in", "r", stdin);
  // freopen("t.out", "w", stdout);
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int n, m, t, k;
  std::cin >> n >> m >> t >> k;
  std::vector a(n, std::vector<Z>(n)), b = a;
  auto out = [&](auto a) {
    // for (int i = 0; i < n; i++) {
    //   for (int j = 0; j < n; j++) {
    //     std::cerr << a[i][j] << " \n"[j + 1 == n];
    //   }
    // }
  };
  while (m--) {
    int x, y, a1, b1, a2, b2;
    std::cin >> x >> y >> a1 >> b1 >> a2 >> b2;
    x--, y--;
    Z p = Z(a1) / Z(b1);
    Z q = Z(a2) / Z(b2);
    p -= q;
    a[x][x] += p;
    a[y][y] += p;
    a[x][y] -= p;
    a[y][x] -= p;
    b[x][x] += q;
    b[y][y] += q;
    b[x][y] -= q;
    b[y][x] -= q;
  }
  n--;
  if (!n) return puts("0"), 0;
  a.pop_back();
  for (auto &v : a) v.pop_back();
  b.pop_back();
  for (auto &v : b) v.pop_back();
  out(a);
  Z prod(1);
  for (int i = 0, j; i < n; i++) {
    for (j = i; j < n; j++) {
      if (!a[j][i]) continue;
      if (i != j) {
        std::swap(a[i], a[j]);
        std::swap(b[i], b[j]);
        prod *= -1;
       }
      break;
    }
    if (j == n) return puts("0"), 0;
    prod *= a[i][i];
    Z inv = a[i][i].inv();
    for (int j = i; j < n; j++) a[i][j] *= inv;
    for (int j = 0; j < n; j++) b[i][j] *= inv;
    for (j++; j < n; j++) {
      if (!a[j][i]) continue;
      Z x = a[j][i];
      for (int k = i; k < n; k++) a[j][k] -= x * a[i][k];
      for (int k = 0; k < n; k++) b[j][k] -= x * b[i][k];
    }
  }
  for (int i = n - 1; i; i--) {
    for (int j = 0; j < i; j++) {
      if (!a[j][i]) continue;
      for (int k = 0; k < n; k++) b[j][k] -= a[j][i] * b[i][k];
      a[j][i] = 0;
    }
  }
  
  out(a);
  out(b);
  // for (auto &x : b) for (auto &y : x) y = -y;
  // std::cerr << "prod " << prod.val() << "\n";
  auto c = charPoly(b);
  for (auto &x : c) x *= prod;
  Z ans = c[0] * Z(t), v = k = fpow(k);
  // assert(c.size() == n);
  for (int i = 1; i < c.size(); i++, v *= k) {
    ans += k == 1 ? c[i] * Z(t) : c[i] * (v.power(t) - Z(1)) / (v - Z(1));
  }
  std::cout << ans << "\n";
  return 0;
}

Delete Edges

Gambling Monster

分治 FWT：按 m 从大到小求 f，用 CDQ 分治配合异或卷积（FWT）处理转移。

// Author:  HolyK
// Created: Mon Aug  2 14:51:08 2021
#include <bits/stdc++.h>
/// Assigns y to x when y < x. Returns true iff x was changed.
template <class T, class U>
inline bool smin(T &x, const U &y) {
  if (!(y < x)) return false;
  x = y;
  return true;
}
/// Assigns y to x when x < y. Returns true iff x was changed.
template <class T, class U>
inline bool smax(T &x, const U &y) {
  if (!(x < y)) return false;
  x = y;
  return true;
}

using LL = long long;
using PII = std::pair<int, int>;
// Modulus for all arithmetic in this solution (10^9 + 7).
constexpr int P = 1000000007;
/// x <- x + y (mod P) for operands in [0, P); returns x so calls can be chained.
inline int &inc(int &x, int y) {
  if ((x += y) >= P) x -= P;
  return x;
}
/// (x + y) mod P for operands in [0, P).
inline int sum(int x, int y) {
  const int s = x + y;
  return s >= P ? s - P : s;
}
/// Plain remainder x % P (keeps the sign of x).
inline int mod(LL x) {
  return static_cast<int>(x % P);
}
/// x^k mod P; the default exponent P - 2 yields the inverse since P is prime.
int fpow(int x, int k = P - 2) {
  int result = 1;
  while (k) {
    if (k & 1) result = 1LL * result * x % P;
    x = 1LL * x * x % P;
    k >>= 1;
  }
  return result;
}
using Poly = std::vector<int>;
/// Skeleton for transform-based convolutions. conv() applies the forward
/// transform to both operands, multiplies them pointwise modulo P, and applies
/// the inverse transform. The base transforms are no-ops, so FwtBase itself
/// computes a plain pointwise product; subclasses override fwt/ifwt.
class FwtBase {
 public:
  FwtBase() = default;
  /// Convolves two equally sized sequences (asserted).
  Poly conv(Poly a, Poly b) {
    assert(a.size() == b.size());
    fwt(a);
    fwt(b);
    for (std::size_t i = 0; i < a.size(); ++i) {
      a[i] = static_cast<int>(1LL * a[i] * b[i] % P);
    }
    ifwt(a);
    return a;
  }
  virtual ~FwtBase() = default;

 private:
  virtual void fwt(Poly &) {}
  virtual void ifwt(Poly &) {}
};

class FwtXor: public FwtBase {
  void fwt(Poly &a) {
    int n = a.size();
    assert((n & n - 1) == 0);
    for (int k = 1; k < n; k <<= 1)
      for (int i = 0; i < n; i += k << 1)
        for (int j = 0, x, y; j < k; j++) {
          x = a[i + j], y = a[i + j + k];
          inc(a[i + j], y);
          a[i + j + k] = sum(x, P - y);
        }
  }
  void ifwt(Poly &a) {
    int n = a.size();
    auto shift = [](int &x) { x = x & 1 ? x + P >> 1 : x >> 1; };
    for (int k = n >> 1; k; k >>= 1)
      for (int i = 0; i < n; i += k << 1)
        for (int j = 0, x, y; j < k; j++) {
          x = a[i + j], y = a[i + j + k];
          shift(inc(a[i + j], y));
          shift(a[i + j + k] = sum(x, P - y));
        }
  }
} fwt_xor;

// p[x]: normalised probability of x; s[x]: prefix sums of p; f[x]: value for state x.
int n, p[1 << 16], s[1 << 16], f[1 << 16];
// Reads one test case (n, then weights p[0..n-1]) and prints f[0], where
//   f[i] = (1 + sum_{j > i} f[j] * p[i ^ j]) / sum_{j > i} p[i ^ j]
// (the O(n^2) reference loop is kept commented out below). All i are handled
// together with CDQ divide and conquer over the bits of the index plus XOR
// convolution.
// NOTE(review): assumes n is a power of two (std::__lg(n) and the bit-range
// sums rely on it); confirm against the problem constraints.
void solve() {
  std::cin >> n;
  for (int i = 0; i < n; i++) {
    std::cin >> p[i];
    s[i] = p[i];
    f[i] = 0;
  }
  for (int i = 1; i < n; i++) {
    inc(s[i], s[i - 1]);
  }
  // Normalise the weights (and prefix sums) to probabilities modulo P.
  int inv = fpow(s[n - 1]);
  for (int i = 0; i < n; i++) {
    p[i] = mod(1LL * p[i] * inv);
    s[i] = mod(1LL * s[i] * inv);
  }
  // Reference O(n^2) version of the recurrence:
  // for (int i = n - 1; i >= 0; i--) {
  //   int sum = 0;
  //   for (int j = i + 1; j < n; j++) {
  //     inc(sum, p[i ^ j]);
  //     f[i] = mod(f[i] + 1LL * f[j] * p[i ^ j]);
  //   }
  //   f[i] = (f[i] + 1LL) * fpow(sum) % P;
  // }
  int d = std::__lg(n);
  for (int m = n - 1; m >= 0; m--) {
    // sum = sum_{j > m} p[m ^ j]. m ^ x > m exactly when the highest set bit
    // of x is 0 in m, and s[2q - 1] - s[q - 1] sums p over [q, 2q).
    // (The loop variable p shadows the global array p here.)
    int sum = 0;
    for (int i = d - 1, p = n >> 1; i >= 0; i--, p >>= 1) {
      if (!(m & p)) {
        inc(sum, s[p * 2 - 1]);
        inc(sum, P - s[p - 1]);
      }
    }
    // f[m] already holds sum_{j > m} f[j] * p[m ^ j], pushed down by earlier CDQ steps.
    f[m] = mod(1LL * fpow(sum) * (f[m] + 1));
    if (!m) break;
    // [m, r) (length lowbit(m)) is now final. Push its contributions to the
    // sibling block [l, m) with one XOR convolution of length r - l. Indices
    // are taken relative to l, which does not change i ^ j.
    int l = m & m - 1, r = m + (m & -m);
    Poly a(f + l, f + r), b(p, p + (r - l));
    // Zero the not-yet-final half [l, m) so only [m, r) contributes.
    for (int i = 0; i < r - m; i++) a[i] = 0;
    a = fwt_xor.conv(a, b);
    for (int i = 0; i < r - m; i++) {
      inc(f[l + i], a[i]);
    }
  }
  std::cout << f[0] << "\n";
}
/// Reads the number of test cases and solves each one.
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int cases;
  std::cin >> cases;
  for (; cases; --cases) solve();
  return 0;
}

Growing Tree

重量平衡树维护括号序，可以用旋转 treap 或者替罪羊树实现。

关于重量平衡树，是否可以用非旋转 treap 实现，以及是否可以可持久化，有待探讨。

// Author:  HolyK
// Created: Fri Aug  6 10:11:54 2021
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
/// Ordered set with order statistics (order_of_key / find_by_order), built on
/// the GNU pb_ds red-black tree.
template <class T, class Comp = std::less<T>>
using Tree = __gnu_pbds::tree<T, __gnu_pbds::null_type, Comp,
                              __gnu_pbds::rb_tree_tag,
                              __gnu_pbds::tree_order_statistics_node_update>;

/// Assigns y to x when y < x. Returns true iff x was changed.
template <class T, class U>
inline bool smin(T &x, const U &y) {
  const bool better = y < x;
  if (better) x = y;
  return better;
}
/// Assigns y to x when x < y. Returns true iff x was changed.
template <class T, class U>
inline bool smax(T &x, const U &y) {
  const bool better = x < y;
  if (better) x = y;
  return better;
}

using LL = long long;
using PII = std::pair<int, int>;

// Capacity for vertices / colours (the node pool holds 2 * N nodes).
constexpr int N = 500005;
// Treap node carrying an integer label `v` that increases with in-order
// position, so two nodes can be ordered by comparing labels alone. solve()
// uses this to keep a bracket sequence (l[i] ... r[i] per vertex).
struct Node {
  Node *ch[2], *fa;  // children (0 = left, 1 = right) and parent
  int rnd;           // treap priority; ins() rotates so that rnd forms a min-heap
  LL v, l;           // label, and the offset used to place this node relative to its parent
  // Sets child d to p and, if p is non-null, makes this node p's parent.
  void sch(int d, Node *p) {
    ch[d] = p;
    if (p) p->fa = this;
  }
  // 1 if this node is its parent's right child, 0 if left. Requires fa != nullptr.
  int dir() const {
    return this == fa->ch[1];
  }
  // Rotates this node above its parent, preserving in-order. Labels are not
  // updated here; the caller relabels afterwards (ins() calls label()).
  void rotate() {
    Node *f = fa, *ff = f->fa;
    int k = dir();
    if (ff) ff->ch[f->dir()] = this;
    fa = ff;
    f->sch(k, ch[!k]), sch(!k, f);
  }
  // Recomputes labels for this whole subtree, top down. A root gets
  // v = l = 2^60. A child gets half its parent's l and sits at parent.v - l
  // (left child) or parent.v + l (right child), so each subtree stays on its
  // side of its parent's label.
  // NOTE(review): labels stay distinct only while the tree depth stays well
  // below 60; this relies on the treap remaining shallow.
  void label() {
    if (fa) {
      l = fa->l / 2;
      v = fa->v + l * (dir() ? 1 : -1);
    } else {
      v = l = 1LL << 60;
    }
    
    if (ch[0]) ch[0]->label();
    if (ch[1]) ch[1]->label();
  }
} t[N * 2], *cur;  // node pool and the next free slot (reset per test case)

/// Takes the next free node from the pool (advancing `cur`), clears its
/// links, and draws a fresh random treap priority for it.
Node *newNode() {
  static std::mt19937 rng(19260817);
  Node *node = cur++;
  node->fa = nullptr;
  node->ch[0] = nullptr;
  node->ch[1] = nullptr;
  node->rnd = rng();
  return node;
}
// Allocates a new node and inserts it immediately after `o` in in-order. It
// becomes o's right child if o has none; otherwise it becomes the left child
// of the leftmost node in o's right subtree. The node is then rotated up
// while its priority is smaller than its parent's, and finally its (new)
// subtree is relabelled. Returns the new node.
Node *ins(Node *o) {
  Node *p = newNode();
  if (!o->ch[1]) {
    o->sch(1, p);
  } else {
    o = o->ch[1];
    while (o->ch[0]) o = o->ch[0];
    o->sch(0, p);
  }
  // Restore the min-heap property on rnd.
  while (p->fa && p->rnd < p->fa->rnd) {
    p->rotate();
  }
  // Rotations only rearranged nodes inside p's final subtree, so relabelling
  // that subtree keeps every label consistent with in-order position.
  p->label();
  return p;
}
struct Comp {
  bool operator()(Node *const &a, Node *const &b) const {
    return a->v < b->v;
  }
};
// s[y]: opening-bracket nodes of colour-y vertices, ordered by label.
// Relabelling keeps the relative order of labels, so these sets stay valid.
// l[i] / r[i]: opening / closing bracket nodes of vertex i; c[i]: colour of vertex i.
Tree<Node*, Comp> s[N];
Node *l[N], *r[N];
int c[N];
// One test case. Vertex 1 starts with colour c[1], followed by m operations
// whose fields are XOR-decoded with the previous answer:
//   1 x y: create vertex n + 1 with colour y, its bracket pair placed directly
//          after l[x] (i.e. inside x's subtree);
//   2 x y: print the number of colour-y vertices whose opening bracket lies in
//          [l[x], r[x]), i.e. in x's subtree including x.
// Node allocation order here determines the RNG draws in newNode().
void solve() {
  cur = t;
  l[1] = newNode();
  l[1]->label();
  r[1] = ins(l[1]);
  int m;
  std::cin >> c[1] >> m;
  s[c[1]].insert(l[1]);
  int ans = 0, n = 1;
  while (m--) {
    int opt, x, y;
    std::cin >> opt >> x >> y;
    opt ^= ans, x ^= ans, y ^= ans;
    if (opt == 1) {
      l[++n] = ins(l[x]);
      r[n] = ins(l[n]);
      c[n] = y;
      s[y].insert(l[n]);
    } else if (opt == 2) {
      ans = s[y].order_of_key(r[x]) - s[y].order_of_key(l[x]);
      std::cout << ans << "\n";
    } else {
      assert(false);
    }
  }
  // Empty only the colour sets that were used, ready for the next test case.
  for (int i = 1; i <= n; i++) {
    s[c[i]].clear();
  }
}
/// Reads the number of test cases and solves each one.
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int cases;
  std::cin >> cases;
  for (; cases; --cases) solve();
  return 0;
}

Hamburger Steak

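二分答案后贪心。实际上答案可以直接取 \(\max(\max_i a_i, \lceil \sum_i a_i / m \rceil)\)（代码即如此实现），然后按顺序贪心往锅里填：放不下的那块拆到下一口锅的开头。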

// Author:  HolyK
// Created: Mon Aug  2 12:51:09 2021
#include <bits/stdc++.h>
/// Assigns y to x when y < x. Returns true iff x was changed.
template <class T, class U>
inline bool smin(T &x, const U &y) {
  if (!(y < x)) return false;
  x = y;
  return true;
}
/// Assigns y to x when x < y. Returns true iff x was changed.
template <class T, class U>
inline bool smax(T &x, const U &y) {
  if (!(x < y)) return false;
  x = y;
  return true;
}

using LL = long long;
using PII = std::pair<int, int>;

// Capacity for the steak cooking times a[1..n].
constexpr int N = 100005;
int a[N];
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int n, m;
  std::cin >> n >> m;
  LL ans = 0, sum = 0;
  
  for (int i = 1; i <= n; i++) {
    std::cin >> a[i];
    sum += a[i];
    smax(ans, a[i]);
  }
  smax(ans, (sum + m - 1) / m);
  LL v = 0;
  int id = 1;
  for (int i = 1; i <= n; i++) {
    
    if (v + a[i] <= ans) {
      std::cout << 1 << " " << id << " " << v << " " << v + a[i] << "\n";
      v += a[i];
    } else {
      std::cout << 2 << " " << id + 1 << " " << 0 << " " << v + a[i] - ans << " " << id << " " << v << " " << ans << "\n";
      id++;
      v = v + a[i] - ans;
    }
    if (v == ans) id++, v = 0;
  }
  return 0;
}

Hasse Diagram

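\[ f(n) = (e + 1) f\left(\frac{n}{p^e}\right) + e \operatorname{d}\left(\frac{n}{p^e}\right) \]

其中 \(p\) 是 \(n\) 的质因子，\(p^e \parallel n\)（即 \(p^e \mid n\) 且 \(p^{e+1} \nmid n\)），且 \(f(1) = 0\)。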

\(\operatorname{d}(n)\) 表示因数个数。

min_25筛同时筛 \(f, d\)

// Author:  HolyK
// Created: Sat Aug  7 08:42:22 2021
#include <bits/stdc++.h>
/// Assigns y to x when y < x. Returns true iff x was changed.
template <class T, class U>
inline bool smin(T &x, const U &y) {
  if (!(y < x)) return false;
  x = y;
  return true;
}
/// Assigns y to x when x < y. Returns true iff x was changed.
template <class T, class U>
inline bool smax(T &x, const U &y) {
  if (!(x < y)) return false;
  x = y;
  return true;
}

using LL = long long;
using PII = std::pair<int, int>;
using u32 = uint32_t;
using u64 = uint64_t;
// Modulus for the final answer.
constexpr u32 P = 1145140019u;
/// x <- x + y with a single conditional subtraction of P (32-bit wrapping sum).
inline void inc(u32 &x, u32 y) {
  const u32 s = x + y;
  x = s >= P ? s - P : s;
}
/// Reduces a 64-bit value modulo P.
inline u32 mod(u64 x) {
  return static_cast<u32>(x % P);
}
/// x^k mod P by square-and-multiply (default exponent P - 2, which gives the
/// modular inverse if P is prime).
u32 fpow(u32 x, u32 k = P - 2) {
  u32 result = 1;
  while (k) {
    if (k & 1) result = 1LL * result * x % P;
    x = 1LL * x * x % P;
    k >>= 1;
  }
  return result;
}
constexpr int N = 200005;
// np[i]: i is composite; primes[1..cnt]: primes found by sieve();
// g0[k]: prime count for val[k] (mod P), filled by solve().
bool np[N];
u32 primes[N], cnt, g0[N];
/// Linear (Euler) sieve over [2, n]: every composite is marked exactly once,
/// by its smallest prime factor. Appends the primes to primes[1..cnt].
void sieve(u32 n) {
  for (u32 i = 2; i <= n; ++i) {
    if (!np[i]) primes[++cnt] = i;
    for (u32 j = 1; j <= cnt; ++j) {
      const u32 composite = i * primes[j];
      if (composite > n) break;
      np[composite] = true;
      if (i % primes[j] == 0) break;
    }
  }
}
// val[k]: k-th distinct value of floor(n / i); id1 / id2: reverse index of val.
u64 val[N];
int id1[N], id2[N];
// Reads n and prints sum_{y=1}^{n} f(y) mod P, where f(1) = 0 and, for a
// prime p with p^e exactly dividing y,
//   f(y) = (e + 1) f(y / p^e) + e d(y / p^e),  d = number of divisors.
// Uses a min_25-style sieve: g0 holds prime counts for every distinct value
// floor(n / i), and cal() sums (f, d) recursively by smallest prime factor.
void solve() {
  u64 n;
  std::cin >> n;
  int m = 0;
  // Enumerate the distinct values floor(n / i) in decreasing order into val[].
  // Small values are indexed directly by id1; a large value v is indexed by
  // id2[n / v].
  // NOTE(review): the 1e5 split assumes sqrt(n) <= 1e5 (n <= 1e10), matching
  // sieve(1e5) in main(); confirm against the constraints.
  for (u64 i = 1, j; i <= n; i = j + 1) {
    val[++m] = n / i;
    j = n / val[m];
    g0[m] = (val[m] - 1) % P;  // start by counting every integer in [2, v]
    if (val[m] <= 1e5) {
      id1[val[m]] = m;
    } else {
      id2[j] = m;
    }
  }
  // Sieve step: g(v) -= g(v / p) - g(p - 1), where g(p - 1) = i - 1.
  // Afterwards g0[k] is the number of primes <= val[k] (mod P).
  for (u32 i = 1; i <= cnt; i++) {
    u64 limit = u64(primes[i]) * primes[i];
    if (limit > n) break;
    for (u32 j = 1; val[j] >= limit; j++) {
      u64 x = val[j] / primes[i];
      int k = x <= 1e5 ? id1[x] : id2[n / x];
      g0[j] = mod(g0[j] + (P - g0[k]) + i - 1);
    }
  }

  // cal(x, i) returns (sum f(y), sum d(y)) over 2 <= y <= x whose smallest
  // prime factor is >= primes[i]. The x == 1 case (only reached when n == 1)
  // returns (f(1), d(1)) = (0, 1).
  auto cal = [&](auto rec, u64 x, u32 i) -> std::pair<u32, u32> {
    if (x == 1) return {0, 1};
    if (primes[i] > x) return {0, 0};
    int k = x <= 1e5 ? id1[x] : id2[n / x];
    // Primes in [primes[i], x]: f = 1 and d = 2 each.
    u32 sum_f = mod(g0[k] + P - i + 1), sum_d = sum_f * 2 % P;
    for (u32 j = i; j <= cnt && u64(primes[j]) * primes[j] <= x; j++) {
      // y = p^k * rest with rest > 1 (via rec), plus y = p^(k+1) itself,
      // for which f = k + 1 and d = k + 2.
      for (u64 k = 1, s = primes[j]; s * primes[j] <= x; k++, s *= primes[j]) {
        auto [f, d] = rec(rec, x / s, j + 1);
        sum_f = mod(sum_f + k + 1 + (k + 1) * f + k * d);
        sum_d = mod(sum_d + k + 2 + (k + 1) * d);
      }
    }
    return {sum_f, sum_d};
  };
  std::cout << cal(cal, n, 1).first << "\n";
}
/// Sieves primes up to 1e5 once, then answers each test case.
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  sieve(100000);
  int cases;
  std::cin >> cases;
  for (; cases; --cases) solve();
  return 0;
}

Hopping Rabbit

线段树维护矩形并。

// Author:  HolyK
// Created: Mon Aug  2 13:55:48 2021
#include <bits/stdc++.h>
/// Assigns y to x when y < x. Returns true iff x was changed.
template <class T, class U>
inline bool smin(T &x, const U &y) {
  const bool better = y < x;
  if (better) x = y;
  return better;
}
/// Assigns y to x when x < y. Returns true iff x was changed.
template <class T, class U>
inline bool smax(T &x, const U &y) {
  const bool better = x < y;
  if (better) x = y;
  return better;
}

using LL = long long;
using PII = std::pair<int, int>;

constexpr int N = 100005;
#define ls o << 1
#define rs o << 1 | 1
// Segment tree over half-open column ranges [l, r) with range add:
// min[o] is the minimum cover count in node o's range, and tag[o] is an
// addition still pending for o's children.
int min[N << 2], tag[N << 2];
/// Recomputes node o's minimum from its children.
void pushup(int o) {
  min[o] = min[ls] < min[rs] ? min[ls] : min[rs];
}
/// Adds z to every value under node o (lazily for its children).
void add(int o, int z) {
  tag[o] += z;
  min[o] += z;
}
/// Forwards node o's pending addition to its children.
void pushdown(int o) {
  if (!tag[o]) return;
  const int pending = tag[o];
  tag[o] = 0;
  add(ls, pending);
  add(rs, pending);
}
/// Adds z on [x, y) inside node o, which covers [l, r).
void update(int o, int l, int r, int x, int y, int z) {
  if (x > l || r > y) {
    const int mid = (l + r) >> 1;
    pushdown(o);
    if (x < mid) update(ls, l, mid, x, y, z);
    if (mid < y) update(rs, mid, r, x, y, z);
    pushup(o);
    return;
  }
  add(o, z);
}
/// Returns the leftmost position in [l, r) whose value is 0; the caller
/// checks that such a position exists (min[1] == 0).
int ask(int o, int l, int r) {
  while (r - l != 1) {
    const int mid = (l + r) >> 1;
    pushdown(o);
    if (min[ls]) {
      o = rs;
      l = mid;
    } else {
      o = ls;
      r = mid;
    }
  }
  return l;
}
/// Folds every trap rectangle [a, c) x [b, d) onto the m x m torus (the rabbit
/// jumps by m), then sweeps columns 0..m-1 while a segment tree over rows
/// tracks cover counts. Prints the first uncovered cell, or NO if none exists.
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int n, m;
  std::cin >> n >> m;
  // g[x]: events at column x, each {row_begin, row_end, +1 or -1}.
  std::vector<std::vector<std::array<int, 3>>> g(m);
  // Maps any coordinate, including negative ones, into [0, m).
  auto reduce = [m](int v) { return (v % m + m) % m; };
  for (int i = 0; i < n; i++) {
    int a, b, c, d;
    std::cin >> a >> b >> c >> d;
    // A trap at least m wide along an axis covers every residue on that axis.
    // Reducing both ends mod m would shrink it (e.g. [0, m + 1) -> [0, 1)), so
    // such spans are mapped to [0, 0), which the code below treats as the
    // whole axis. The width is computed in 64 bits because it can exceed INT_MAX.
    const bool fullX = 1LL * c - a >= m;
    const bool fullY = 1LL * d - b >= m;
    a = fullX ? 0 : reduce(a);
    c = fullX ? 0 : reduce(c);
    b = fullY ? 0 : reduce(b);
    d = fullY ? 0 : reduce(d);
    g[a].push_back({b, d, 1});
    g[c].push_back({b, d, -1});
    // Wrapping (or full) column range: it is also active from column 0.
    if (a >= c) {
      g[0].push_back({b, d, 1});
    }
  }
  for (int i = 0; i < m; i++) {
    for (auto [x, y, z] : g[i]) {
      if (x < y) {
        update(1, 0, m, x, y, z);
      } else {
        // Wrapping (or full) row range: cover everything, then undo [y, x).
        std::swap(x, y);
        update(1, 0, m, 0, m, z);
        if (x < y) update(1, 0, m, x, y, -z);
      }
    }
    if (min[1] == 0) {
      std::cout << "YES\n";
      std::cout << i << " " << ask(1, 0, m) << "\n";
      return 0;
    }
  }
  std::cout << "NO\n";
  return 0;
}

Intervals on the Ring

\[ \overline{\bigcup A_i} = \bigcap \overline{A_i} \]

// Author:  HolyK
// Created: Mon Aug  2 12:05:21 2021
#include <bits/stdc++.h>
/// Assigns y to x when y < x. Returns true iff x was changed.
template <class T, class U>
inline bool smin(T &x, const U &y) {
  if (!(y < x)) return false;
  x = y;
  return true;
}
/// Assigns y to x when x < y. Returns true iff x was changed.
template <class T, class U>
inline bool smax(T &x, const U &y) {
  if (!(x < y)) return false;
  x = y;
  return true;
}

using LL = long long;
using PII = std::pair<int, int>;

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int t;
  std::cin >> t;
  while (t--) {
    int n, m;
    std::cin >> n >> m;
    std::vector<PII> a(m);
    for (int i = 0; i < m; i++) {
      std::cin >> a[i].first >> a[i].second;
    }
    std::sort(a.begin(), a.end());
    std::cout << m << "\n";
    for (int i = 0; i < m; i++) {
      std::cout << a[(i + 1) % m].first << " " << a[i].second << "\n";
    }
  }
  return 0;
}

Defend Your Country

// Author:  HolyK
// Created: Mon Aug  2 16:17:15 2021
#include <bits/stdc++.h>
/// Assigns y to x when y < x. Returns true iff x was changed.
template <class T, class U>
inline bool smin(T &x, const U &y) {
  const bool better = y < x;
  if (better) x = y;
  return better;
}
/// Assigns y to x when x < y. Returns true iff x was changed.
template <class T, class U>
inline bool smax(T &x, const U &y) {
  const bool better = x < y;
  if (better) x = y;
  return better;
}

using LL = long long;
using PII = std::pair<int, int>;

constexpr int N(1e6 + 5);
int n, m, a[N];
std::vector<int> g[N];
LL ans, sum, val[N];
int in[N], low[N], cnt, siz[N];
/// Tarjan-style DFS from x (in = discovery time, low = low-link).
/// siz[] stores the parity of DFS-subtree sizes, and val[] stores subtree
/// sums of a. Each child y with low[y] >= in[x] becomes its own component
/// when x is removed. If none of those components has odd size, x is a
/// candidate and ans is raised to sum - 2 * a[x]. The caller only runs this
/// when n is odd, so the component left on x's parent side is then even too.
/// NOTE(review): recursion depth can reach n (up to 1e6); on deep graphs this
/// may overflow the default stack.
void dfs(int x) {
  in[x] = low[x] = ++cnt;
  bool evenSplit = true;
  siz[x] = 1;
  val[x] = a[x];
  for (int y : g[x]) {
    if (!in[y]) {
      dfs(y);
      siz[x] ^= siz[y];
      smin(low[x], low[y]);
      val[x] += val[y];
      if (low[y] >= in[x] && (siz[y] & 1)) {
        evenSplit = false;
      }
    } else {
      smin(low[x], in[y]);
    }
  }
  if (evenSplit) {
    // 2LL: a[x] * 2 in int arithmetic overflows for a[x] > INT_MAX / 2.
    smax(ans, sum - 2LL * a[x]);
  }
}
/// One test case. With an even number of vertices the whole sum is printed.
/// Otherwise dfs(1) picks the best single vertex to negate.
void solve() {
  std::cin >> n >> m;
  sum = 0;
  for (int i = 1; i <= n; ++i) {
    std::cin >> a[i];
    sum += a[i];
    g[i].clear();
    in[i] = 0;
    siz[i] = 0;
  }
  while (m--) {
    int u, v;
    std::cin >> u >> v;
    g[u].push_back(v);
    g[v].push_back(u);
  }
  if (n % 2 != 0) {
    ans = -sum;
    cnt = 0;
    dfs(1);
    std::cout << ans << "\n";
  } else {
    std::cout << sum << "\n";
  }
}
/// Reads the number of test cases and solves each one.
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int cases;
  std::cin >> cases;
  for (; cases; --cases) solve();
  return 0;
}

Starch Cat

点分树处理中心到每个点的dp值。

随机数据暴力求lca很快。

// Author:  HolyK
// Created: Fri Aug  6 16:44:46 2021
#include <bits/stdc++.h>
/// Assigns y to x when y < x. Returns true iff x was changed.
template <class T, class U>
inline bool smin(T &x, const U &y) {
  if (!(y < x)) return false;
  x = y;
  return true;
}
/// Assigns y to x when x < y. Returns true iff x was changed.
template <class T, class U>
inline bool smax(T &x, const U &y) {
  if (!(x < y)) return false;
  x = y;
  return true;
}

using LL = long long;
using PII = std::pair<int, int>;
// Modulus for the accumulated answer.
constexpr int P = 998244353;

/// Query generator: a 32-bit xorshift state mixed with the previous answer.
struct Rand {
  unsigned int n, seed;
  Rand(unsigned int n, unsigned int seed) : n(n), seed(seed) {}
  /// Advances the xorshift state and returns (state ^ lastans) % n + 1.
  int get(long long lastans) {
    unsigned int state = seed;
    state ^= state << 13;
    state ^= state >> 17;
    state ^= state << 5;
    seed = state;
    const long long mixed = static_cast<long long>(state) ^ lastans;
    return static_cast<int>(mixed % static_cast<long long>(n) + 1);
  }
};

constexpr int N = 500005;
// a[x]: vertex weight; siz[x]: subtree size from the latest getSize() call.
int n, a[N], siz[N];

// g: tree adjacency lists; vis[x]: x was already used as a centroid.
std::vector<int> g[N];
bool vis[N];
/// Computes siz[] for the subtree of x (parent p), skipping removed vertices.
void getSize(int x, int p) {
  siz[x] = 1;
  for (int child : g[x]) {
    if (child != p && !vis[child]) {
      getSize(child, x);
      siz[x] += siz[child];
    }
  }
}
/// Starting from x (parent p), repeatedly steps into the first child whose
/// subtree holds at least half of the s vertices; returns where it stops.
int getRoot(int x, int p, int s) {
  for (;;) {
    bool moved = false;
    for (int child : g[x]) {
      if (child != p && !vis[child] && siz[child] * 2 >= s) {
        p = x;
        x = child;
        moved = true;
        break;
      }
    }
    if (!moved) return x;
  }
}
// dp[x][c][t]: best total weight of a set of pairwise non-adjacent vertices on
// the path from the current centroid down to x, where c says whether the
// centroid is chosen and t says whether x is chosen.
// f[d][x][c]: max over t of dp[x][c][t], recorded at decomposition level d.
// anc[d][x]: centroid owning x at level d.
// NOTE(review): f alone is 20 * N * 2 * 8 bytes (about 160 MB); check the memory limit.
LL f[20][N][2], dp[N][2][2];
int anc[20][N];
// Walks the current component from x (parent p), recording level-d data for
// centroid r and extending dp to each unvisited neighbour y. A chosen y
// requires its predecessor x not to be chosen.
void dfs(int x, int p, int d, int r) {
  anc[d][x] = r;
  f[d][x][0] = std::max(dp[x][0][0], dp[x][0][1]);
  f[d][x][1] = std::max(dp[x][1][0], dp[x][1][1]);
  for (int y : g[x]) {
    if (y == p || vis[y]) continue;
    dp[y][0][0] = std::max(dp[x][0][0], dp[x][0][1]);
    dp[y][0][1] = dp[x][0][0] + a[y];
    dp[y][1][0] = std::max(dp[x][1][0], dp[x][1][1]);
    dp[y][1][1] = dp[x][1][0] + a[y];
    dfs(y, x, d, r);
  }
}
/// Centroid decomposition. Finds the centroid c of x's component, fills
/// anc/f at level d from c, removes c, and recurses into each remaining
/// neighbouring component at level d + 1.
void solve(int x, int d) {
  getSize(x, 0);
  const int c = getRoot(x, 0, siz[x]);
  // Base states at the centroid, indexed [centroid chosen][current vertex chosen].
  dp[c][0][0] = 0;
  dp[c][1][1] = a[c];
  dp[c][0][1] = -1e18;
  dp[c][1][0] = -1e18;
  dfs(c, 0, d, c);
  vis[c] = true;
  for (int y : g[c]) {
    if (!vis[y]) solve(y, d + 1);
  }
}
/// Best weight of a set of pairwise non-adjacent vertices on the tree path
/// x..y. The function finds the deepest decomposition level i whose centroid
/// p = anc[i][x] is shared by x and y; the path passes through p. It then
/// joins the two halves, subtracting a[p] once when p is chosen on both sides.
LL query(int x, int y) {
  if (x == y) return a[x];
  // anc has kLevels rows, so level i + 1 must stay below kLevels.
  constexpr int kLevels = sizeof anc / sizeof anc[0];
  for (int i = 0; i + 1 < kLevels; i++) {
    if (anc[i + 1][x] != anc[i + 1][y]) {
      int p = anc[i][x];
      return std::max(f[i][x][1] + f[i][y][1] - a[p], f[i][x][0] + f[i][y][0]);
    }
  }
  // Unreachable for distinct vertices while the decomposition depth stays
  // below kLevels; a value-returning function must not fall off its end.
  assert(false);
  return -1;
}
int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  int n, m, seed;
  std::cin >> n >> m >> seed;
  for (int i = 1; i <= n; i++) {
    std::cin >> a[i];
  }
  for (int i = 2, x; i <= n; i++) {
    std::cin >> x;
    g[x].push_back(i);
    g[i].push_back(x);
  }
  solve(1, 0);
  LL lastans = 0, ans = 0;
  Rand rand(n, seed);
  for (int i = 0; i < m; i++) {
    int u = rand.get(lastans);
    int v = rand.get(lastans);
    int x = rand.get(lastans);
    lastans = query(u, v);
    ans = (ans + lastans % P * x) % P;
  }
  std::cout << ans << "\n";
  return 0;
}

最后修改于 2021-08-19