[PKUWC2018]猎人杀
容斥好题。

\(n\) 个人,每个人有个权值 \(a_i\)。进行 \(n\) 轮操作,每轮开枪杀死一个还活着的人,活人 \(i\) 被杀死的概率 \(p_i=\dfrac{a_i}{\sum_{x\ is\ alive}a_x}\) ,求最后死的那个人是 \(1\) 的概率。答案对 \(998244353\) 取模。

\(a_i > 0, 1 \le \sum a_i \le 10^5\)

设 \(s = \sum_{i=1}^na_i\) 为所有人的权值和,\(w=\sum_{x\ is\ alive}a_x\) 为当前活人的权值和,那么有 \[ \begin{aligned} p_i &= \frac{a_i}{w}\\ \frac wsp_i &= \frac {a_i}{s}\\ p_i &= \frac {a_i}{s} + \frac{s - w}s p_i \end{aligned} \] 最后的这个式子可以理解为,每次开枪的目标是所有活人和死人,如果打到活人就会杀死他,如果打到死人不算,重打。

这样变成进行无限轮,每轮的目标都是所有人,打到某个人的概率更容易表示。

然后考虑容斥。设 \(A = \{ 2, 3, \dots, n \}\),钦定 \(A\) 的一个大小为 \(k\) 的子集 \(S\) 中的人都在 \(1\) 之后被打死(其余人不作限制),并枚举 \(1\) 在第 \(r\) 轮被打中,即前 \(r - 1\) 轮都没有打到 \(1\) 和 \(S\) 中的人。 \[ \begin{aligned} ans &= \sum_{k = 0}^{n - 1} (-1)^k \sum_{S \subseteq A, |S| = k} \sum_{r = 1}^{\infty} \left(\frac{s - a_1 - \sum_{i \in S} a_i}{s}\right)^{r - 1} \cdot \frac{a_1}{s}\\ &= \sum_{k = 0}^{n - 1} (-1)^k \sum_{S \subseteq A, |S| = k} \frac{a_1}{s} \cdot \sum_{r = 0}^{\infty} \left(\frac{s - a_1 - \sum_{i \in S} a_i}{s}\right)^{r}\\ &= \sum_{k = 0}^{n - 1} (-1)^k \sum_{S \subseteq A, |S| = k} \frac{a_1}{s} \cdot \frac{1}{1 - \dfrac{s - a_1 - \sum_{i \in S} a_i}{s}}\\ &= \sum_{k = 0}^{n - 1} (-1)^k \sum_{S \subseteq A, |S| = k} \frac{a_1}{a_1 + \sum_{i \in S} a_i}\\ &= \sum_{k = 0}^{s - a_1} \frac{a_1}{a_1 + k} \times \sum_{S \subseteq A, \sum_{i \in S} a_i = k} (-1)^{|S|}\\ &= \sum_{k = 0}^{s - a_1} \frac{a_1}{a_1 + k} \times [x^k] \prod_{i=2}^n(1 - x^{a_i}) \end{aligned} \] 分治求 \(\prod_{i=2}^n(1 - x^{a_i})\) 即可。分治共 \(O(\log n)\) 层,每层的多项式总长度为 \(O(s)\),所以复杂度为 \(O(s \log s \log n)\)。

#include <bits/stdc++.h>
// Ends a debug line: emits the ANSI "reset attributes" sequence and a newline
// on stderr.
inline void err() { std::cerr << "\033[0m\n"; }

// Writes every argument to stderr, each followed by a single space, and then
// finishes the line through the zero-argument overload.
template <class Head, class... Tail>
inline void err(const Head &head, const Tail &... tail) {
  std::cerr << head << ' ';
  err(tail...);
}

// dbg(expr...) prints "<source text of the arguments> -> <their values>" on
// stderr, wrapped in the ANSI sequence "\033[32;1m", but only when LOCAL is
// defined. Otherwise it expands to nothing.
#ifdef LOCAL
#define dbg(...) std::cerr << "\033[32;1m" << #__VA_ARGS__ << " -> ", err(__VA_ARGS__)
#else
#define dbg(...)
#endif
// Reads one decimal integer from stdin into `w`.
//
// Every character before the first digit is skipped; the number is negated
// when the character immediately preceding the first digit is '-'. The first
// character after the digit run is consumed as well.
//
// If the input ends before any digit is found, `w` is set to 0 and the
// function returns instead of waiting forever for more input.
template <class T>
inline void readInt(T &w) {
  // getchar() returns int so that EOF is distinguishable from every valid
  // character; keeping it in an int also keeps the isdigit() argument valid.
  int c = std::getchar();
  bool negative = false;
  while (c != EOF && !std::isdigit(c)) {
    negative = (c == '-');
    c = std::getchar();
  }
  w = 0;
  if (c == EOF) return;
  for (; c != EOF && std::isdigit(c); c = std::getchar()) w = w * 10 + (c - '0');
  if (negative) w = -w;
}

// Reads several integers in argument order: readInt(x, y, z) is equivalent to
// readInt(x); readInt(y); readInt(z).
template <class T, class... U>
inline void readInt(T &w, U &... a) {
  readInt(w);
  readInt(a...);
}

// P is the modulus used for all arithmetic (998244353 = 119 * 2^23 + 1).
// G is the base from which the transform code derives its roots of unity.
constexpr int P = 998244353;
constexpr int G = 3;

// x <- x + y, with one conditional subtraction of P to bring it back below P.
inline void inc(int &x, int y) {
  x += y;
  if (x >= P) x -= P;
}

// Returns x + y, reduced by P once if the sum reaches P.
inline int sum(int x, int y) {
  const int total = x + y;
  return total >= P ? total - P : total;
}

// Returns x - y, with P added once if the difference is negative.
inline int sub(int x, int y) {
  const int diff = x - y;
  return diff < 0 ? diff + P : diff;
}

// Binary exponentiation: returns x^k mod P. The default exponent P - 2 gives
// the modular inverse of x (Fermat's little theorem) for x not divisible by P.
inline int fpow(int x, int k = P - 2) {
  int result = 1;
  while (k) {
    if (k & 1) result = 1LL * result * x % P;
    x = 1LL * x * x % P;
    k >>= 1;
  }
  return result;
}

// Polynomial multiplication modulo P using a number-theoretic transform.
namespace Polynomial {
// Coefficient vector, index i holding the coefficient of x^i (mod P).
using Polynom = std::vector<int>;

// Scratch table of twiddle factors; getOmega() rebuilds it for each layer.
std::vector<int> w;

// Fills w[0..k) with the powers base^0 .. base^(k-1), where
// base = G^((P-1)/(2k)). Requires k >= 1.
void getOmega(int k) {
  w.resize(k);
  w[0] = 1;
  int base = fpow(G, (P - 1) / (k << 1));
  for (int i = 1; i < k; i++) w[i] = 1LL * w[i - 1] * base % P;
}

// In-place forward transform of a[0..n), n a power of two (n == 0 is a no-op).
// Layers run from the widest butterfly (k = n/2) down to k = 1 and no index
// permutation is applied, so the output order is the one idft() expects.
void dft(int *a, int n) {
  assert((n & n - 1) == 0);
  for (int k = n >> 1; k; k >>= 1) {
    getOmega(k);
    for (int i = 0; i < n; i += k << 1) {
      for (int j = 0; j < k; j++) {
        int y = a[i + j + k];
        // 1LL keeps the difference and the product in 64 bits before reduction.
        a[i + j + k] = (1LL * a[i + j] - y + P) * w[j] % P;
        inc(a[i + j], y);
      }
    }
  }
}
void dft(Polynom &a) { dft(a.data(), a.size()); }

// In-place inverse of dft() on a[0..n), n a power of two.
// An empty range is left untouched; without the guard the 1/n computation
// below would divide by zero.
void idft(int *a, int n) {
  assert((n & n - 1) == 0);
  if (n == 0) return;
  for (int k = 1; k < n; k <<= 1) {
    getOmega(k);
    for (int i = 0; i < n; i += k << 1) {
      for (int j = 0; j < k; j++) {
        int x = a[i + j], y = 1LL * a[i + j + k] * w[j] % P;
        a[i + j] = sum(x, y), a[i + j + k] = sub(x, y);
      }
    }
  }
  // P - (P - 1) / n is the inverse of n modulo P when n divides P - 1:
  // n * (P - (P - 1) / n) = n * P - (P - 1), which is 1 modulo P.
  for (int i = 0, inv = P - (P - 1) / n; i < n; i++) a[i] = 1LL * a[i] * inv % P;
  // The butterflies above reuse the forward roots; reversing a[1..n) turns
  // the result into the inverse transform.
  std::reverse(a + 1, a + n);
}
void idft(Polynom &a) { idft(a.data(), a.size()); }

// Returns the product of a and b modulo P, with a.size() + b.size() - 1
// coefficients. If either operand is empty (the zero polynomial), the result
// is empty. When either operand has at most 8 coefficients the schoolbook
// double loop is used instead of the transform.
Polynom operator*(Polynom a, Polynom b) {
  if (a.empty() || b.empty()) return Polynom();
  int len = a.size() + b.size() - 1;
  if (a.size() <= 8 || b.size() <= 8) {
    Polynom c(len);
    for (unsigned i = 0; i < a.size(); i++)
      for (unsigned j = 0; j < b.size(); j++)
        c[i + j] = (c[i + j] + 1LL * a[i] * b[j]) % P;
    return c;
  }
  // Smallest power of two that is >= len, so the cyclic convolution does not
  // wrap around.
  int n = 1;
  while (n < len) n <<= 1;
  a.resize(n), b.resize(n);
  dft(a), dft(b);
  for (int i = 0; i < n; i++) a[i] = 1LL * a[i] * b[i] % P;
  idft(a);
  a.resize(len);
  return a;
}
} // namespace Polynomial

using Polynomial::Polynom;
using Polynomial::operator*;

// Capacity of the weight array (indices 1..n are used).
constexpr int N(1e5 + 5);
// n: number of people; a[i]: weight of person i; s: accumulated weight sum
// (main() reduces it to the sum of a[2..n] before calling calc()).
int n, a[N], s;

// Returns the product of (1 - x^{a[i]}) for i in [l, r], modulo P, computed by
// divide and conquer and truncated to at most s + 1 coefficients.
//
// An empty range (l > r) yields the constant polynomial 1, the empty product.
// This happens for calc(2, n) when n == 1; without this case the midpoint
// split would call calc(2, 1) again and recurse without end.
Polynom calc(int l, int r) {
  if (l > r) return Polynom{1};
  if (l == r) {
    // 1 - x^{a[l]}: coefficient 1 at degree 0 and -1 (stored as P - 1) at
    // degree a[l].
    Polynom ans(a[l] + 1);
    ans[0] = 1, ans.back() = P - 1;
    return ans;
  }
  int m = (l + r) >> 1;
  Polynom ans = calc(l, m) * calc(m + 1, r);
  // Coefficients above degree s are never read by the caller.
  if (ans.size() > static_cast<std::size_t>(s + 1)) ans.resize(s + 1);
  return ans;
}
int main() {
  readInt(n);
  for (int i = 1; i <= n; i++) readInt(a[i]), s += a[i];
  s -= a[1];
  auto p = calc(2, n);
  int ans = 0;
  for (int i = 0; i <= s; i++) ans = (ans + 1LL * fpow(a[1] + i) * p[i]) % P;
  ans = 1LL * ans * a[1] % P;
  printf("%d\n", ans);
  return 0;
}

最后修改于 2021-08-13