T30201: 旅行售货商问题
bitmask dp, http://cs101.openjudge.cn/practice/30201/
一个国家有 n 个城市,每两个城市之间都开设有航班,从城市 i 到城市 j 的航班价格为 cost[i, j] ,而且往、返航班的价格相同。
售货商要从一个城市出发,途径每个城市 1 次(且每个城市只能经过 1 次),最终返回出发地,而且他的交通工具只有航班,请求出他旅行的最小开销。
输入
输入的第 1 行是一个正整数 n (3 <= n <= 18) 然后有 n 行,每行有 n 个正整数,构成一个 n * n 的矩阵,矩阵的第 i 行第 j 列为城市 i 到城市 j 的航班价格。1 <= cost[i,j] <= 10^4
输出
输出数据为一个正整数 m,表示旅行售货商的最小开销
样例输入
4
0 4 1 3
4 0 2 1
1 2 0 5
3 1 5 0样例输出
7提示:dp, dfs
来源:2025fall-cs101 yan
除了用状压dp,还可以有三个层次的优化:I/O 与数据结构优化(基础)、逻辑剪枝(中级)以及Python 特性优化(高级)。
详细解读优化的三个层次
1. I/O 与数据结构优化(基础)
- sys.stdin.read: 原代码在循环里用 input(),每次都要处理缓冲区。用 read().split() 将所有数据一次性读入内存并切分,配合迭代器 iter() 和 next(),是 Python 刷题处理大量数据的标准做法。
- 列表 vs 字典: 原代码用 d = {} 存图。字典查询需要哈希计算,而列表(数组)是基于内存偏移量的直接访问。在高频访问场景下,列表要快得多。
2.逻辑剪枝(中级)
- range(1, size, 2): 这是一个非常漂亮的逻辑剪枝。因为我们规定从城市 0 出发,所以任何合法的状态 mask,其二进制最低位(第0位)必须是 1。这意味着 mask 必然是奇数。直接让循环步长为 2,循环次数瞬间减少一半。
- range(1, n): 内层循环寻找下一个城市 v 时,不需要考虑 0。因为 TSP 规定中间过程不走重复路,0 是起点,只有在遍历完所有节点后,计算 ans 时才考虑回到 0。
3.Python 特性优化(高级 - 提速关键)
- 去处 min() 函数: 这是 Python 算法题优化的“杀手锏”。
dp[new][k] = min(dp[new][k], val)看起来很简洁。但 Python 的函数调用栈开销很大。在千万级别的循环中,这会严重拖慢速度。改成if val < dp[new][k]: dp[new][k] = val虽然代码多了两行,但运行速度会有质的飞跃。- 局部变量缓存:
curr_dist = dp[mask][u]。在 Python 中,访问dp[mask][u]需要两次列表索引操作。把它存为局部变量 curr_dist,后续计算只读这个变量,减少了索引查找开销。
运行时间:3867ms
import sys
def solve():
# 1. 快速 I/O
input = sys.stdin.read
data = input().split()
iterator = iter(data)
try:
n = int(next(iterator))
except StopIteration:
return
# 2. 使用二维列表代替字典,并预处理为整数
# 直接读取 n*n 个数据构建矩阵
dist = []
for _ in range(n):
row = [int(next(iterator)) for _ in range(n)]
dist.append(row)
# 3. 初始化 DP
# 状态上限 1<<n
size = 1 << n
inf = float("inf")
# 这里的 dp[i][j] 表示:状态掩码为 i,当前停留在 j 城市
dp = [[inf] * n for _ in range(size)]
# 起点固定为 0,初始状态掩码为 1 (二进制 ...001),花费 0
dp[1][0] = 0
# 4. 逻辑优化:只遍历奇数 mask
# 因为起点 0 永远在路径中,所以 mask 的第 0 位永远是 1,即 mask 永远是奇数
for i in range(1, size, 2):
# 优化:如果当前状态 i 下,所有结尾可能都不可达,直接跳过(但在 TSP 稠密图中较少见)
for j in range(n):
# 如果状态 i 下不可能停在 j,跳过
# 或者 if not (i >> j) & 1: continue (隐含条件,通常由 inf 判断即可)
if dp[i][j] == inf:
continue
curr_dist = dp[i][j]
# 5. 内层循环优化
# k 从 1 开始,因为中间过程不可能回到起点 0
for k in range(1, n):
# 如果 k 已经在状态 i 中,跳过
if (i >> k) & 1:
continue
new_mask = i | (1 << k)
new_cost = curr_dist + dist[j][k]
# 6. 移除 min() 函数调用,手动比较更快
if new_cost < dp[new_mask][k]:
dp[new_mask][k] = new_cost
# 7. 计算最终回路
ans = inf
# 最终状态必然是 (1<<n) - 1,即全 1
final_mask = size - 1
# 此时停留在 i,需要从 i 回到 0
for i in range(1, n):
cost = dp[final_mask][i] + dist[i][0]
if cost < ans:
ans = cost
print(ans)
if __name__ == '__main__':
solve()36行第2个for循环,
for j in range(n):这一层循环的意思是:枚举当前所在的“终点城市”。为了彻底理解,需要结合 DP 状态定义来看。
1. 核心含义
dp[i][j]的定义是:
i:经过了哪些城市(状态/集合)。j:目前停留在哪个城市。第 2 个循环
for j in range(n)就是在遍历这个 “目前停留的城市”。2. 为什么要遍历它?(举例说明)
假设有 3 个城市:0, 1, 2。 外层循环
i到了状态111(二进制),表示{0, 1, 2}这三个城市都去过了。虽然大家都去过这三个城市,但“怎么走的”以及“最后停在哪里”是不同的,这对下一步去哪里至关重要。
- 情况 A:路径是
。
- 此时
j = 2(当前在 2 号城市)。- 如果你下一步要去 3 号城市,路费是
distance[2][3]。- 情况 B:路径是
。
- 此时
j = 1(当前在 1 号城市)。- 如果你下一步要去 3 号城市,路费是
distance[1][3]。结论: 只知道“去过哪些城市(
i)”是不够的,必须知道“现在脚踩在哪个城市(j)”,才能算出“去下一个城市(k)”的距离。3. 代码逻辑链条
这个循环的作用起到了承上启下的连接作用:
承上(检查合法性):
pythonif dp[i][j] == float("inf"): continue不是所有城市都能作为当前状态的终点。比如,如果集合
i里根本没有城市j,或者从起点根本走不到j,这个状态就是无效的,直接跳过。启下(状态转移):
pythondp[newi][k] = min(..., dp[i][j] + d[j][k])这里用到了
j。我们要从 “当前点j” 走到 “下一个点k”。如果没有这层j的循环,我们就不知道这笔路费d[j][k]里的起点是谁。总结
第 2 个循环就是在问: “在已经去过集合
i的所有方案中,如果我们最后停在了城市0、或者城市1、……或者城市n-1,分别会怎么样?”
Q:为什么在openjudge.cn上提交代码,套在函数中的程序,运行更快?
把程序套在函数里面,就不超时了。因为:局部变量(函数内部定义的变量)存储在 局部命名空间(local namespace) 中,通过 索引直接访问(LOAD_FAST 指令),速度非常快。 全局变量(模块级别定义的变量)存储在 全局字典(globals dict) 中,每次访问都需要 哈希查找(LOAD_GLOBAL 指令),开销更大。 @李睿安
运行时间:4135ms
pythondef solve(): n = int(input().strip()) cost = [] for _ in range(n): row = list(map(int, input().split())) cost.append(row) # 如果只有1个城市?但题目保证 n>=3 INF = float('inf') # dp[mask][i]: mask 是已访问的城市集合,i 是当前所在城市(0 <= i < n) # mask 是一个整数,bit j 为1 表示城市 j 已访问 total_masks = 1 << n dp = [[INF] * n for _ in range(total_masks)] # 起点设为城市0 dp[1][0] = 0 # 只访问了城市0,当前在0,花费0 # 遍历所有状态 for mask in range(1, total_masks, 2): for u in range(n): if dp[mask][u] == INF: continue # 尝试从 u 到未访问的城市 v for v in range(n): if mask & (1 << v): continue # v 已访问,跳过 new_mask = mask | (1 << v) new_cost = dp[mask][u] + cost[u][v] if new_cost < dp[new_mask][v]: dp[new_mask][v] = new_cost # 所有城市都访问完的状态是 (1 << n) - 1 full_mask = total_masks - 1 ans = INF for i in range(1, n): # 从其他城市回到起点0 if dp[full_mask][i] != INF: ans = min(ans, dp[full_mask][i] + cost[i][0]) print(ans) if __name__ == '__main__': solve()运行时间:6252ms
pythonn = int(input().strip()) cost = [] for _ in range(n): row = list(map(int, input().split())) cost.append(row) # 如果只有1个城市?但题目保证 n>=3 INF = float('inf') # dp[mask][i]: mask 是已访问的城市集合,i 是当前所在城市(0 <= i < n) # mask 是一个整数,bit j 为1 表示城市 j 已访问 total_masks = 1 << n dp = [[INF] * n for _ in range(total_masks)] # 起点设为城市0 dp[1][0] = 0 # 只访问了城市0,当前在0,花费0 # 遍历所有状态 for mask in range(1, total_masks, 2): for u in range(n): if dp[mask][u] == INF: continue # 尝试从 u 到未访问的城市 v for v in range(n): if mask & (1 << v): continue # v 已访问,跳过 new_mask = mask | (1 << v) new_cost = dp[mask][u] + cost[u][v] if new_cost < dp[new_mask][v]: dp[new_mask][v] = new_cost # 所有城市都访问完的状态是 (1 << n) - 1 full_mask = total_masks - 1 ans = INF for i in range(1, n): # 从其他城市回到起点0 if dp[full_mask][i] != INF: ans = min(ans, dp[full_mask][i] + cost[i][0]) print(ans)
【李欣珂 25物院】思路:上学期因为对位运算不熟悉所以没有仔细研究这题,对这题一知半解,这学期仔细了解了位运算,只要想到位来标记城市是否访问的思路,那状态方程就比较经典了。(本来用了cache结果发现oj不支持,上学期写搜索的时候居然没发现)
import sys
from functools import lru_cache
def solve():
input_data = sys.stdin.read().split()
if not input_data:
return
n = int(input_data[0])
cost = []
idx = 1
for i in range(n):
cost.append([int(x) for x in input_data[idx: idx + n]])
idx += n
@lru_cache(None)
def dfs(state, u):
if state == (1 << n) - 1:
return cost[u][0]
res = float('inf')
for v in range(n):
if (state & (1 << v)) == 0:
res = min(res, dfs(state | (1 << v), v) + cost[u][v])
return res
ans = dfs(1, 0)
print(ans)
if __name__ == '__main__':
solve()