跳转至

依赖背包

依赖背包通常可以使用树形dp解决。

例题 #1 有线电视网

题目描述

某收费有线电视网计划转播一场重要的足球比赛。他们的转播网和用户终端构成一棵树状结构,这棵树的根结点位于足球比赛的现场,树叶为各个用户终端,其他中转站为该树的内部节点。

从转播站到转播站以及从转播站到所有用户终端的信号传输费用都是已知的,一场转播的总费用等于传输信号的费用总和。

现在每个用户都准备了一笔费用想观看这场精彩的足球比赛,有线电视网有权决定给哪些用户提供信号而不给哪些用户提供信号。

写一个程序找出一个方案使得有线电视网在不亏本的情况下使观看转播的用户尽可能多。

输入格式

输入文件的第一行包含两个用空格隔开的整数 \(N\)\(M\),其中 \(2 \le N \le 3000\)\(1 \le M \le N-1\)\(N\) 为整个有线电视网的结点总数,\(M\) 为用户终端的数量。

第一个转播站即树的根结点编号为 \(1\),其他的转播站编号为 \(2\)\(N-M\),用户终端编号为 \(N-M+1\)\(N\)

接下来的 \(N-M\) 行每行表示—个转播站的数据,第 \(i+1\) 行表示第 \(i\) 个转播站的数据,其格式如下:

\(K \ \ A_1 \ \ C_1 \ \ A_2 \ \ C_2 \ \ \ldots \ \ A_k \ \ C_k\)

\(K\) 表示该转播站下接 \(K\) 个结点(转播站或用户),每个结点对应一对整数 \(A\)\(C\)\(A\) 表示结点编号,\(C\) 表示从当前转播站传输信号到结点 \(A\) 的费用。最后一行依次表示所有用户为观看比赛而准备支付的钱数。单次传输成本和用户愿意交的费用均不超过 10。

https://paste.ubuntu.com/p/xN9NRmYvpY/

输出格式

输出文件仅一行,包含一个整数,表示上述问题所要求的最大用户数。

思路

用树形dp来解释分组背包就很简单了。设\(f_{i,j}\)为在子树i中选择j个用户满足,可以得到的最大利润。为了保证被满足的用户和i之间的路线全部打通,在转移中很显然可以保证。

\(f_{i,j}=\max(f_{i,j-k}+f_{son(i),k}-cost(i,son(i)))\)

最后从大到小找到第一个≥0的\(f_{1,i}\)即可。


/*
CB Ntsc111
*/

#include <bits/stdc++.h>
using namespace std;

#define ull unsigned int
#define pii pair<int, int>
#define pf to
#define ps second
#define int long long

#define err cerr << "Error"
#define rd read()

#define ot write
#define nl putchar('\n')
int read() {
  int xx = 0, ff = 1;
  char ch = getchar();
  while (ch < '0' || ch > '9') {
    if (ch == '-')
      ff = -1;
    ch = getchar();
  }
  while (ch >= '0' && ch <= '9')
    xx = xx * 10 + (ch - '0'), ch = getchar();
  return xx * ff;
}
void write(int out) {
  if (out < 0)
    putchar('-'), out = -out;
  if (out > 9)
    write(out / 10);
  putchar(out % 10 + '0');
}

const int M = 800;
const int mxxlog = 10;
int MOD = 1e9 + 57;
const int N = 1 + 4e3 + 3;
const int INF = 1e9;
const double eps = 1e-10;

int n, m, to[N], nxt[N], w[N], v[N], c[N];
int f[N][N];
int cnt = 0;
void add(int a, int b, int c) {
  v[++cnt] = b;
  w[cnt] = c;
  nxt[cnt] = to[a];
  to[a] = cnt;
}

int dfs(int u) {
  if (u > n - m) {
    f[u][1] = c[u];
    return 1;
  }
  int sum = 0;
  for (int i = to[u]; i != -1; i = nxt[i]) {
    int x = dfs(v[i]);
    sum += x;
    for (int j = sum; j; j--)
      for (int k = 1; k <= x; k++)
        f[u][j] = max(f[u][j], f[u][j - k] + f[v[i]][k] - w[i]);
  }
  return sum;
}

signed main() {
  n = rd, m = rd;
  for (int i = 1; i <= n; i++)
    to[i] = -1;
  for (int u = 1, son; u <= n - m; u++) {
    son = rd;
    for (int j = 1, v, w; j <= son; j++) {
      v = rd, w = rd;
      add(u, v, w);
    }
  }
  for (int u = n - m + 1; u <= n; u++)
    c[u] = rd;
  memset(f, -60, sizeof(f));
  for (int u = 1; u <= n; u++)
    f[u][0] = 0;
  dfs(1);

  for (int i = m; i; i--)
    if (f[1][i] >= 0) {
      printf("%lld", i);
      break;
    }

  return 0;
}