dfs(int u) { int b = bro[u], s = son[u]; if (s != 0) dfs(s); if (b != 0) dfs(b); for (k = 1; k <= m; k ++) f[u][k] = f[b][k]; for (α = 0; α <= m; α ++) { for (β = 0; β <= m; β ++) { k = α + β + 1; f[u][k] = max { f[s][α] + f[b][β] + A[u]; f[u][k]; }; } } }
1 2
dfs(son[0]); cout << f[son[0]][m];
很不幸,==它还是$O(n^3)$的。==(证明略)
优化:(将$k$改为$sz[s], sz[b]$)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
dfs(int u) { int b = bro[u], s = son[u]; f[u,0] = 0; f[u,1] = a[u]; // 初始化 if (s != 0) dfs(s); if (b != 0) dfs(b); sz[u] = sz[s] + sz[b] + 1; // 算子树大小 for (k = 1; k <= m; k ++) f[u][k] = f[b][k]; // 划水行为(递推公式第一行) for (α = 0; α <= sz[s]; α ++) { for (β = 0; β <= sz[b]; β ++) { k = α + β + 1; f[u][k] = max { f[s][α] + f[b][β] + A[u]; f[u][k]; }; } } }
对于时间复杂度的证明:
非常好证明以下不等式:
在这里我们记$|s| = sz(s),~|b|=sz(b)$,所以时间复杂度:
故时间复杂度为$O(n^2)$。
将方法(一)的$O(n^3)$算法变为$O(n^2)$算法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
voiddfs(int u){ sz[u] = 1; for (int i = son[u]; i != 0; i = bro[i]) { dfs(i); sz[u] += sz[i]; for (int k = sz[u]; k >= 1; k --) { for (α = k - 1; α >= 0; α --) { f[u][k] = max{ f[u][k]; f[u][k-α] + f[i][α]; }; } } } }
ll n, k, ecnt = 0, head[maxn], f[maxn][maxn], sz[maxn];
voidaddEdge(int u, int v, int w){ g[ecnt].to = v; g[ecnt].w = w; g[ecnt].next = head[u]; head[u] = ecnt ++; }
voiddfs(int u, int p){ sz[u] = 1; f[u][0] = f[u][1] = 0; // 子树是一个一个添加(树型背包的第一种添加子树的方法) for (int e = head[u]; e != -1; e = g[e].next) { int v = g[e].to; if (v != p) { dfs(v, u); sz[u] += sz[v]; for (int x = sz[u]; x >= 0; x --) { for (int y = sz[v]; y >= 0; y --) { // 特判:这个节点可行(其子树的节点够) if (f[u][x] != -1) { /* val表示: (新的子树中的黑点个数×除此之外的所有黑点个数+新的子树中的白点个数×除此之外的白点个数) ==> 必将经过下面那条边的次数 × (这些所有次数必定将要进过的一条边:就是u,v之间的边,边权为w[e]) = 这个新的子树添加时,与u连接的这条边对答案的贡献 */ ll val = g[e].w * (y * (k - y) + (sz[v] - y) * ((n - k) - (sz[v] - y))); /* 所以下面的式子可以转换为更加易于理解的式子: f[u][x] = max{f[u][x], f[u][x-k] + f[v][k] + val}, 其中:0 ≤ k ≤ x; v为u的一个儿子 我们不难发现,节点是在dfs的过程中,一 个 一 个 添加的,所以 1. f[u][x-k]表示的是之前(没有添加现在子树的时候),以u为根的子树有x-k个点染黑对答案的贡献 2. f[v][k]表示的是以v为根的子树(由于dfs的顺序,此时这个子树是完全的),有k个节点被染黑对答案,到v的贡献 3. val表示的是为了补全(2)当中的那些染黑的节点,在u -> v的贡献 */ f[u][x + y] = max(f[u][x + y], f[u][x] + f[v][y] + val); } } } } } }
intmain(){ memset(head, -1, sizeof(head)); memset(f, -1, sizeof(f)); cin >> n >> k; for (int i = 1; i < n; i ++) { int from, to, dis; cin >> from >> to >> dis; addEdge(from, to, dis); addEdge(to, from, dis); } dfs(1, 0); cout << f[1][k] << endl; return0; }