Skip to content

P6192 【模板】最小斯坦纳树

bitmask/state_compression dp, dijkstra , https://www.luogu.com.cn/problem/P6192

给定一个包含 n 个结点和 m 条带权边的无向连通图 G=(V,E)

再给定包含 k 个结点的点集 S,选出 G 的子图 G=(V,E),使得:

  1. SV

  2. G 为连通图;

  3. E 中所有边的权值和最小。

你只需要求出 E 中所有边的权值和。

输入

第一行:三个整数 n,m,k,表示 G 的结点数、边数和 S 的大小。

接下来 m 行:每行三个整数 u,v,w,表示编号为 u,v 的点之间有一条权值为 w 的无向边。

接下来一行:k 个互不相同的正整数,表示 S 的元素。

输出

第一行:一个整数,表示 E 中边权和的最小值。

输入输出样例 #1

输入 #1

7 7 4
1 2 3
2 3 2
4 3 9
2 6 2
4 5 3
6 5 2
7 6 4
2 4 7 5

输出 #1

11

说明/提示:【样例解释】

样例中给出的图如下图所示,红色点为 S 中的元素,红色边为 E 的元素,此时 E 中所有边的权值和为 2+2+3+4=11,达到最小值。


【数据范围】:对于 100% 的数据,1n100,  1m500,  1k10,  1u,vn,  1w106。保证给出的无向图连通,但 可能 存在重边和自环。

斯坦纳树问题是组合优化问题,与最小生成树相似,是最短网络的一种。最小生成树是在给定的点集和边中寻求最短网络使所有点连通。而最小斯坦纳树允许在给定点外增加额外的点,使生成的最短网络开销最小。

最小斯坦纳树问题的核心挑战在于:需要连接指定的 K 个“关键点”,但为了让总边权最小,可以经过图中任何其他的点。

由于 K 的范围很小(通常 10),使用 状态压缩动态规划 (State Compression DP) 结合 最短路算法 (Dijkstra) 来解决。

1. 核心思路

定义 dp[i][mask]

  • i:当前连通结构的“根”节点(或者说当前考虑到的节点)。
  • mask:一个二进制数,表示当前已经连接了哪些关键点。例如 mask = 5 (二进制 101) 表示连接了第 0 个和第 2 个关键点。

状态转移分为两个阶段:

阶段 A:基于节点内部的子集拆分(合并) 对于同一个点 i,我们可以通过合并两个互补的子集来得到 mask

dp[i][mask]=min(dp[i][mask],dp[i][sub]+dp[i][masksub])

这里 submask 的子集。

阶段 B:基于边权在节点间扩展 这本质上是在更新“为了连接同样的点集,从点 j 走到点 i 是否更优”:

dp[i][mask]=min(dp[i][mask],dp[j][mask]+weight(i,j))

这部分看起来非常像最短路。由于边权为正,对每一个固定的 mask 运行一次 Dijkstra


2. Python 实现

使用 heapq 优化 Dijkstra,并尽量减少不必要的计算。

python
import heapq
import sys

def solve():
    # 快速读入
    input = sys.stdin.read().split()
    if not input: return
    
    idx = 0
    n = int(input[idx]); idx += 1
    m = int(input[idx]); idx += 1
    k = int(input[idx]); idx += 1
    
    # 邻接表
    adj = [[] for _ in range(n + 1)]
    for _ in range(m):
        u = int(input[idx]); idx += 1
        v = int(input[idx]); idx += 1
        w = int(input[idx]); idx += 1
        adj[u].append((v, w))
        adj[v].append((u, w))
        
    # 关键点
    terminals = []
    for _ in range(k):
        terminals.append(int(input[idx]))
        idx += 1
        
    # INF 定义
    INF = float('inf')
    
    # dp[mask][i] 表示连接状态为 mask,且以 i 为当前点的最小代价
    # mask 范围是 0 到 2^k - 1
    dp = [[INF] * (n + 1) for _ in range(1 << k)]
    
    # 初始化:每个关键点连接自己状态的代价为 0
    for i in range(k):
        dp[1 << i][terminals[i]] = 0
        
    # 状态压缩 DP
    for mask in range(1, 1 << k):
        # 阶段 A: 子集枚举转移 (同一个点内合并两个子集)
        # sub = (sub - 1) & mask 是枚举子集的技巧
        sub = (mask - 1) & mask
        while sub > 0:
            for i in range(1, n + 1):
                dp[mask][i] = min(dp[mask][i], dp[sub][i] + dp[mask ^ sub][i])
            sub = (sub - 1) & mask
            
        # 阶段 B: Dijkstra 算法(节点间通过边权转移)
        pq = []
        for i in range(1, n + 1):
            if dp[mask][i] != INF:
                heapq.heappush(pq, (dp[mask][i], i))
        
        while pq:
            d, u = heapq.heappop(pq)
            if d > dp[mask][u]:
                continue
            for v, w in adj[u]:
                if dp[mask][v] > dp[mask][u] + w:
                    dp[mask][v] = dp[mask][u] + w
                    heapq.heappush(pq, (dp[mask][v], v))
                    
    # 最终答案:连接了所有关键点 (mask = 2^k - 1) 的任意一个点作为根的最小值
    ans = min(dp[(1 << k) - 1][1:])
    print(ans)

if __name__ == "__main__":
    solve()

3. 代码关键点解析

  1. 子集枚举优化sub = (mask - 1) & mask 这行代码可以在 O(3K) 的总时间复杂度内遍历所有 mask 及其子集,而不是 O(4K)
  2. Dijkstra 的必要性: 在 mask 内部,通过 dp[sub] + dp[mask ^ sub] 得到的可能只是零散的更新。需要通过边把这些代价“传导”到其他节点。由于这涉及到求最短路径,所以使用 Dijkstra。
  3. Python 性能备注
    • 斯坦纳树的复杂度较高(3KN)。
    • 如果遇到超时(TLE),可以尝试将 Dijkstra 换成 SPFA(在 Python 中有时表现更好),或者使用 sys.stdin.readline 加快输入。

4. 总结

这个算法体现了“斯坦纳树”和“最小生成树”的区别:

  • MST:直接在 N 个点中连边,Greedy (Kruskal/Prim) 即可。
  • SMT:选择一部分关键点。通过状压记录“连接了哪些点”,通过 DP 寻找“最佳中转路径”。

最小斯坦纳树 (SMT) 的状压 DP 确实和 旅行商问题 (TSP) 的状压 DP 在本质上是“同门师兄弟”。

它们都属于状态压缩 DP (Bitmask DP),核心思想都是用一个二进制数字(mask)来代表“我已经搞定了哪些点”。

但它们之间有几个关键的区别,理解了这些区别,你就真正掌握了这两种算法。


1. 相同点:状态的定义

在两个算法中,dp[mask][i] 的定义几乎一模一样:

  • mask: 一个二进制数,表示当前任务的进度(TSP 是去过了哪些城,SMT 是连接了哪些关键点)。
  • i: 当前停留的位置(或者说当前的“根”)。

2. 不同点:转移的逻辑(核心区别)

这是最关键的地方。TSP 是在走“一条线”,而 SMT 是在修“一棵树”。

(1) TSP 的转移:线性扩展

在 TSP 中,你就像在画一条不间断的线。你每走一步,就多去一个城市。

  • 逻辑当前在 j + 走向下个城市 k = 更新包含 k 的状态
  • 代码体现new_mask = i | (1 << k)
  • 形状

(2) SMT 的转移:分支合并 + 扩展

在 SMT 中,因为是“树”结构,所以可能有两个分叉在某个点汇合

  • 逻辑 A(合并):我在点 i 处,我手里有一半目标点的清单,你手里有另一半。我们两个一碰头,就把所有的清单都凑齐了。
    • 代码体现dp[mask][i] = min(dp[mask][i], dp[sub][i] + dp[mask ^ sub][i])
  • 逻辑 B(扩展):我在点 i 处,我带着我的清单走到隔壁点 j
    • 代码体现:这部分和 TSP 很像,但因为可能存在环或者复杂的路径,通常用 DijkstraSPFA 来跑,而不是简单的循环。
  • 形状:像树枝一样汇聚:

3. 对比总结表

特性旅行商问题 (TSP)最小斯坦纳树 (SMT)
目标走过所有点的最短环路连接指定关键点的最短树
Mask 含义已经访问过的所有城市已经连通的关键目标点
合并操作无(只能一个点一个点地往下走)(两个子连通块可以在一个点合并)
移动操作简单 dist[j][k] 累加使用 Dijkstra/SPFA 进行松弛
复杂度O(n22n)O(n3k+mlogn2k) (k 是关键点数)

4. 针对你提供的 TSP 优化技巧,SMT 是否适用?

你提到的那三个层次的优化,在 SMT 中同样非常重要:

  1. I/O 优化:完全适用。处理复杂的图论题,sys.stdin.read().split() 是 Python 选手的救命稻草。
  2. 逻辑剪枝
    • TSP 中可以用 range(1, size, 2) 是因为起点固定。
    • SMT 中,通常会预处理所有关键点。如果你当前 mask 下的代价已经超过了已知最小代价,或者某个 mask 在当前点根本不合法,也可以直接 continue
  3. Python 特性优化 (去掉 min)
    • 极其适用! 在 SMT 的子集遍历循环 sub = (mask - 1) & mask 内部,如果使用 min(),速度会慢得惊人。改成 if new_cost < dp[mask][i]: dp[mask][i] = new_cost 能让代码快出好几倍。

总结建议

如果你已经完全理解了 TSP 的这段代码:

  1. 你已经掌握了 mask 的移动位运算
  2. 你想学 SMT,只需要额外学习“如何在同一个点合并两个状态”(即 dp[mask][i] = dp[sub][i] + dp[mask^sub][i])以及“如何用 Dijkstra 更新 mask”

TSP 像是在跑马拉松(一个人跑到底),而斯坦纳树像是在搞接力赛或者修水管网(多条线可以汇聚到一个水泵站)。