[Codechef CCC]Hit the Coconuts
斜率优化DP好题。

题意: \(N\) 个椰子, 第 \(i\) 个需要敲 \(a_i\) 下才能打开,问最坏情况下最少要敲多少下才能打开 \(k\) 个椰子。

不妨设 \(a_i \le a_{i + 1}\)

要敲开一个椰子,现在有两种方法:

  1. 随便拿起一个椰子不停地敲,最多要 \(a_n\) 次。
  2. 有目的性地想要敲开某个椰子 \(a_i\),可以发现最优的策略就是每个椰子都敲 \(a_i\) 次,最多要敲 \(a_i\times(n - i + 1)\) 次,所有情况下取最小值。

第二种情况实际上包含了第一种。

那么要敲开 \(k\) 个椰子,也有两种方法:

  1. 每次都随便拿一个敲开,最多要 \(a_n + a_{n - 1}+ \dots + a_{n - k + 1}\) 次。
  2. \(k\) 个椰子 \(a_{b_1}, a_{b_2}, \dots, a_{b_k}(b_i < b_{i + 1})\),有目的性地敲开它们。发现最优策略下最多要敲 \(\sum_{i = 1}^k (a_{b_i} - a_{b_i - 1}) \times (n - b_i + 1)\) 次,其中 \(a_{b_0} = 0\),所有情况下取最小值。

第二种情况也包含了第一种。

因此考虑第二种情况的求解,设 \(dp[i][k]\) 表示前 \(i\) 个椰子敲开 \(k\) 个(包括第 \(i\) 个椰子)的最少次数。

\[ dp[i][k] = \min\{dp[j][k - 1] + (a_i - a_j) \times (n - i + 1)\} \]

这显然是一个斜率优化的形式,化一下式子(第二维省略):

\[ dp[i] = \min\{a_j \times (i - n - 1) + dp[j]\} + a_i \times (n - i + 1) \]

设直线 \(L_j(x) = a_j \cdot x + dp[j]\),求 \(\min L_j(i - n - 1)\),斜率和横坐标都是递增的,可以用栈维护一个上凸壳,复杂度 \(O(nk + n\log n)\)

#include <bits/stdc++.h>
#ifdef LOCAL
#define dbg(args...) std::cerr << "\033[32;1m" << #args << " -> ", err(args)
#else
#define dbg(...)
#endif
inline void err() { std::cerr << "\033[0m\n"; }
template<class T, class... U>
inline void err(const T &x, const U &... a) { std::cerr << x << ' '; err(a...); }
template <class T>
inline void readInt(T &w) {
  char c, p = 0;
  while (!isdigit(c = getchar())) p = c == '-';
  for (w = c & 15; isdigit(c = getchar());) w = w * 10 + (c & 15);
  if (p) w = -w;
}
template <class T, class... U>
inline void readInt(T &w, U &... a) { readInt(w), readInt(a...); }
template <class T, class U>
inline bool smin(T &x, const U &y) { return y < x ? x = y, 1 : 0; }
template <class T, class U>
inline bool smax(T &x, const U &y) { return x < y ? x = y, 1 : 0; }

typedef long long LL;
typedef std::pair<int, int> PII;

constexpr int N(1005);
int n, k, a[N];
LL dp[N];
struct Line {
  int k; LL b;
  Line(int k = 0, LL b = 0): k(k), b(b) {}
  inline LL func(int x) { return 1LL * k * x + b; }
  inline bool check(const Line &p, const Line &q) const {
    return (b - p.b) * (k - q.k) > (b - q.b) * (k - p.k);
  }
} q[N];
int main() {
  int t; readInt(t);
  while (t--) {
    readInt(n, k);
    for (int i = 1; i <= n; i++) readInt(a[i]);
    std::sort(a + 1, a + 1 + n);
    for (int i = 1; i <= n; i++) dp[i] = 1e9;
    for (int c = 0, r; c < k; c++) {
      q[r = 1] = Line(a[c], dp[c]);
      for (int i = c + 1; i <= n; i++) {
        int x = i - n - 1;
        while (r > 1 && q[r].func(x) >= q[r - 1].func(x)) r--;
        Line now = Line(a[i], dp[i]);
        dp[i] = q[r].func(x) - 1LL * a[i] * x;
        while (r > 1 && !now.check(q[r], q[r - 1])) r--;
        q[++r] = now;
      }
    }
    LL ans = dp[k];
    for (int i = k + 1; i <= n; i++) smin(ans, dp[i]);
    printf("%lld\n", ans);
  }
  return 0;
}

最后修改于 2021-08-13