P6192 【模板】最小斯坦纳树
bitmask/state_compression dp, dijkstra , https://www.luogu.com.cn/problem/P6192
给定一个包含
再给定包含
; 为连通图; 中所有边的权值和最小。
你只需要求出
输入
第一行:三个整数
接下来
接下来一行:
输出
第一行:一个整数,表示
输入输出样例 #1
输入 #1
7 7 4
1 2 3
2 3 2
4 3 9
2 6 2
4 5 3
6 5 2
7 6 4
2 4 7 5输出 #1
11说明/提示:【样例解释】
样例中给出的图如下图所示,红色点为

【数据范围】:对于
斯坦纳树问题是组合优化问题,与最小生成树相似,是最短网络的一种。最小生成树是在给定的点集和边中寻求最短网络使所有点连通。而最小斯坦纳树允许在给定点外增加额外的点,使生成的最短网络开销最小。
最小斯坦纳树问题的核心挑战在于:需要连接指定的
由于
1. 核心思路
定义 dp[i][mask]:
- i:当前连通结构的“根”节点(或者说当前考虑到的节点)。
- mask:一个二进制数,表示当前已经连接了哪些关键点。例如
mask = 5(二进制101) 表示连接了第 0 个和第 2 个关键点。
状态转移分为两个阶段:
阶段 A:基于节点内部的子集拆分(合并) 对于同一个点 mask:
这里
阶段 B:基于边权在节点间扩展 这本质上是在更新“为了连接同样的点集,从点
这部分看起来非常像最短路。由于边权为正,对每一个固定的 mask 运行一次 Dijkstra。
2. Python 实现
使用 heapq 优化 Dijkstra,并尽量减少不必要的计算。
import heapq
import sys
def solve():
# 快速读入
input = sys.stdin.read().split()
if not input: return
idx = 0
n = int(input[idx]); idx += 1
m = int(input[idx]); idx += 1
k = int(input[idx]); idx += 1
# 邻接表
adj = [[] for _ in range(n + 1)]
for _ in range(m):
u = int(input[idx]); idx += 1
v = int(input[idx]); idx += 1
w = int(input[idx]); idx += 1
adj[u].append((v, w))
adj[v].append((u, w))
# 关键点
terminals = []
for _ in range(k):
terminals.append(int(input[idx]))
idx += 1
# INF 定义
INF = float('inf')
# dp[mask][i] 表示连接状态为 mask,且以 i 为当前点的最小代价
# mask 范围是 0 到 2^k - 1
dp = [[INF] * (n + 1) for _ in range(1 << k)]
# 初始化:每个关键点连接自己状态的代价为 0
for i in range(k):
dp[1 << i][terminals[i]] = 0
# 状态压缩 DP
for mask in range(1, 1 << k):
# 阶段 A: 子集枚举转移 (同一个点内合并两个子集)
# sub = (sub - 1) & mask 是枚举子集的技巧
sub = (mask - 1) & mask
while sub > 0:
for i in range(1, n + 1):
dp[mask][i] = min(dp[mask][i], dp[sub][i] + dp[mask ^ sub][i])
sub = (sub - 1) & mask
# 阶段 B: Dijkstra 算法(节点间通过边权转移)
pq = []
for i in range(1, n + 1):
if dp[mask][i] != INF:
heapq.heappush(pq, (dp[mask][i], i))
while pq:
d, u = heapq.heappop(pq)
if d > dp[mask][u]:
continue
for v, w in adj[u]:
if dp[mask][v] > dp[mask][u] + w:
dp[mask][v] = dp[mask][u] + w
heapq.heappush(pq, (dp[mask][v], v))
# 最终答案:连接了所有关键点 (mask = 2^k - 1) 的任意一个点作为根的最小值
ans = min(dp[(1 << k) - 1][1:])
print(ans)
if __name__ == "__main__":
solve()3. 代码关键点解析
- 子集枚举优化:
sub = (mask - 1) & mask这行代码可以在的总时间复杂度内遍历所有 mask及其子集,而不是。 - Dijkstra 的必要性: 在
mask内部,通过dp[sub] + dp[mask ^ sub]得到的可能只是零散的更新。需要通过边把这些代价“传导”到其他节点。由于这涉及到求最短路径,所以使用 Dijkstra。 - Python 性能备注:
- 斯坦纳树的复杂度较高(
)。 - 如果遇到超时(TLE),可以尝试将 Dijkstra 换成 SPFA(在 Python 中有时表现更好),或者使用
sys.stdin.readline加快输入。
- 斯坦纳树的复杂度较高(
4. 总结
这个算法体现了“斯坦纳树”和“最小生成树”的区别:
- MST:直接在
个点中连边,Greedy (Kruskal/Prim) 即可。 - SMT:选择一部分关键点。通过状压记录“连接了哪些点”,通过 DP 寻找“最佳中转路径”。
最小斯坦纳树 (SMT) 的状压 DP 确实和 旅行商问题 (TSP) 的状压 DP 在本质上是“同门师兄弟”。
它们都属于状态压缩 DP (Bitmask DP),核心思想都是用一个二进制数字(mask)来代表“我已经搞定了哪些点”。
但它们之间有几个关键的区别,理解了这些区别,你就真正掌握了这两种算法。
1. 相同点:状态的定义
在两个算法中,
dp[mask][i]的定义几乎一模一样:
- mask: 一个二进制数,表示当前任务的进度(TSP 是去过了哪些城,SMT 是连接了哪些关键点)。
- i: 当前停留的位置(或者说当前的“根”)。
2. 不同点:转移的逻辑(核心区别)
这是最关键的地方。TSP 是在走“一条线”,而 SMT 是在修“一棵树”。
(1) TSP 的转移:线性扩展
在 TSP 中,你就像在画一条不间断的线。你每走一步,就多去一个城市。
- 逻辑:
当前在 j+走向下个城市 k=更新包含 k 的状态。- 代码体现:
new_mask = i | (1 << k)。- 形状:
(2) SMT 的转移:分支合并 + 扩展
在 SMT 中,因为是“树”结构,所以可能有两个分叉在某个点汇合。
- 逻辑 A(合并):我在点
处,我手里有一半目标点的清单,你手里有另一半。我们两个一碰头,就把所有的清单都凑齐了。
- 代码体现:
dp[mask][i] = min(dp[mask][i], dp[sub][i] + dp[mask ^ sub][i])- 逻辑 B(扩展):我在点
处,我带着我的清单走到隔壁点 。
- 代码体现:这部分和 TSP 很像,但因为可能存在环或者复杂的路径,通常用 Dijkstra 或 SPFA 来跑,而不是简单的循环。
- 形状:像树枝一样汇聚:
3. 对比总结表
特性 旅行商问题 (TSP) 最小斯坦纳树 (SMT) 目标 走过所有点的最短环路 连接指定关键点的最短树 Mask 含义 已经访问过的所有城市 已经连通的关键目标点 合并操作 无(只能一个点一个点地往下走) 有(两个子连通块可以在一个点合并) 移动操作 简单 dist[j][k]累加使用 Dijkstra/SPFA进行松弛复杂度 (k 是关键点数) 4. 针对你提供的 TSP 优化技巧,SMT 是否适用?
你提到的那三个层次的优化,在 SMT 中同样非常重要:
- I/O 优化:完全适用。处理复杂的图论题,
sys.stdin.read().split()是 Python 选手的救命稻草。- 逻辑剪枝:
- TSP 中可以用
range(1, size, 2)是因为起点固定。- SMT 中,通常会预处理所有关键点。如果你当前
mask下的代价已经超过了已知最小代价,或者某个mask在当前点根本不合法,也可以直接continue。- Python 特性优化 (去掉
min):
- 极其适用! 在 SMT 的子集遍历循环
sub = (mask - 1) & mask内部,如果使用min(),速度会慢得惊人。改成if new_cost < dp[mask][i]: dp[mask][i] = new_cost能让代码快出好几倍。总结建议
如果你已经完全理解了 TSP 的这段代码:
- 你已经掌握了
mask的移动 和 位运算。- 你想学 SMT,只需要额外学习“如何在同一个点合并两个状态”(即
dp[mask][i] = dp[sub][i] + dp[mask^sub][i])以及“如何用 Dijkstra 更新 mask”。TSP 像是在跑马拉松(一个人跑到底),而斯坦纳树像是在搞接力赛或者修水管网(多条线可以汇聚到一个水泵站)。