Skip to content

T30201: 旅行售货商问题

bitmask dp, http://cs101.openjudge.cn/practice/30201/

一个国家有 n 个城市,每两个城市之间都开设有航班,从城市 i 到城市 j 的航班价格为 cost[i, j] ,而且往、返航班的价格相同。

售货商要从一个城市出发,途径每个城市 1 次(且每个城市只能经过 1 次),最终返回出发地,而且他的交通工具只有航班,请求出他旅行的最小开销。

输入

输入的第 1 行是一个正整数 n (3 <= n <= 18) 然后有 n 行,每行有 n 个正整数,构成一个 n * n 的矩阵,矩阵的第 i 行第 j 列为城市 i 到城市 j 的航班价格。1 <= cost[i,j] <= 10^4

输出

输出数据为一个正整数 m,表示旅行售货商的最小开销

样例输入

4
0 4 1 3
4 0 2 1
1 2 0 5
3 1 5 0

样例输出

7

提示:dp, dfs

来源:2025fall-cs101 yan

除了用状压dp,还可以有三个层次的优化:I/O 与数据结构优化(基础)、逻辑剪枝(中级)以及Python 特性优化(高级)。

详细解读优化的三个层次

1. I/O 与数据结构优化(基础)

  • sys.stdin.read: 原代码在循环里用 input(),每次都要处理缓冲区。用 read().split() 将所有数据一次性读入内存并切分,配合迭代器 iter() 和 next(),是 Python 刷题处理大量数据的标准做法。
  • 列表 vs 字典: 原代码用 d = {} 存图。字典查询需要哈希计算,而列表(数组)是基于内存偏移量的直接访问。在高频访问场景下,列表要快得多。

2.逻辑剪枝(中级)

  • range(1, size, 2): 这是一个非常漂亮的逻辑剪枝。因为我们规定从城市 0 出发,所以任何合法的状态 mask,其二进制最低位(第0位)必须是 1。这意味着 mask 必然是奇数。直接让循环步长为 2,循环次数瞬间减少一半
  • range(1, n): 内层循环寻找下一个城市 v 时,不需要考虑 0。因为 TSP 规定中间过程不走重复路,0 是起点,只有在遍历完所有节点后,计算 ans 时才考虑回到 0。

3.Python 特性优化(高级 - 提速关键)

  • 去处 min() 函数: 这是 Python 算法题优化的“杀手锏”。dp[new][k] = min(dp[new][k], val) 看起来很简洁。但 Python 的函数调用栈开销很大。在千万级别的循环中,这会严重拖慢速度。改成 if val < dp[new][k]: dp[new][k] = val 虽然代码多了两行,但运行速度会有质的飞跃。
  • 局部变量缓存: curr_dist = dp[mask][u]。在 Python 中,访问 dp[mask][u] 需要两次列表索引操作。把它存为局部变量 curr_dist,后续计算只读这个变量,减少了索引查找开销。

运行时间:3867ms

python
import sys

def solve():
    # 1. 快速 I/O
    input = sys.stdin.read
    data = input().split()
    iterator = iter(data)
    
    try:
        n = int(next(iterator))
    except StopIteration:
        return

    # 2. 使用二维列表代替字典,并预处理为整数
    # 直接读取 n*n 个数据构建矩阵
    dist = []
    for _ in range(n):
        row = [int(next(iterator)) for _ in range(n)]
        dist.append(row)

    # 3. 初始化 DP
    # 状态上限 1<<n
    size = 1 << n
    inf = float("inf")
    # 这里的 dp[i][j] 表示:状态掩码为 i,当前停留在 j 城市
    dp = [[inf] * n for _ in range(size)]
    
    # 起点固定为 0,初始状态掩码为 1 (二进制 ...001),花费 0
    dp[1][0] = 0

    # 4. 逻辑优化:只遍历奇数 mask
    # 因为起点 0 永远在路径中,所以 mask 的第 0 位永远是 1,即 mask 永远是奇数
    for i in range(1, size, 2):
        # 优化:如果当前状态 i 下,所有结尾可能都不可达,直接跳过(但在 TSP 稠密图中较少见)
        
        for j in range(n):
            # 如果状态 i 下不可能停在 j,跳过
            # 或者 if not (i >> j) & 1: continue (隐含条件,通常由 inf 判断即可)
            if dp[i][j] == inf:
                continue
            
            curr_dist = dp[i][j]
            
            # 5. 内层循环优化
            # k 从 1 开始,因为中间过程不可能回到起点 0
            for k in range(1, n):
                # 如果 k 已经在状态 i 中,跳过
                if (i >> k) & 1:
                    continue
                
                new_mask = i | (1 << k)
                new_cost = curr_dist + dist[j][k]
                
                # 6. 移除 min() 函数调用,手动比较更快
                if new_cost < dp[new_mask][k]:
                    dp[new_mask][k] = new_cost

    # 7. 计算最终回路
    ans = inf
    # 最终状态必然是 (1<<n) - 1,即全 1
    final_mask = size - 1
    
    # 此时停留在 i,需要从 i 回到 0
    for i in range(1, n):
        cost = dp[final_mask][i] + dist[i][0]
        if cost < ans:
            ans = cost

    print(ans)

if __name__ == '__main__':
    solve()

36行第2个for循环,for j in range(n):

这一层循环的意思是:枚举当前所在的“终点城市”。为了彻底理解,需要结合 DP 状态定义来看。

1. 核心含义

dp[i][j] 的定义是:

  • i:经过了哪些城市(状态/集合)。
  • j目前停留在哪个城市

第 2 个循环 for j in range(n) 就是在遍历这个 “目前停留的城市”

2. 为什么要遍历它?(举例说明)

假设有 3 个城市:0, 1, 2。 外层循环 i 到了状态 111(二进制),表示 {0, 1, 2} 这三个城市都去过了。

虽然大家都去过这三个城市,但“怎么走的”以及“最后停在哪里”是不同的,这对下一步去哪里至关重要。

  • 情况 A:路径是 012
    • 此时 j = 2(当前在 2 号城市)。
    • 如果你下一步要去 3 号城市,路费是 distance[2][3]
  • 情况 B:路径是 021
    • 此时 j = 1(当前在 1 号城市)。
    • 如果你下一步要去 3 号城市,路费是 distance[1][3]

结论: 只知道“去过哪些城市(i)”是不够的,必须知道“现在脚踩在哪个城市(j)”,才能算出“去下一个城市(k)”的距离。

3. 代码逻辑链条

这个循环的作用起到了承上启下的连接作用:

  1. 承上(检查合法性)

    python
    if dp[i][j] == float("inf"):
        continue

    不是所有城市都能作为当前状态的终点。比如,如果集合 i 里根本没有城市 j,或者从起点根本走不到 j,这个状态就是无效的,直接跳过。

  2. 启下(状态转移)

    python
    dp[newi][k] = min(..., dp[i][j] + d[j][k])

    这里用到了 j。我们要从 “当前点 j 走到 “下一个点 k。如果没有这层 j 的循环,我们就不知道这笔路费 d[j][k] 里的起点是谁。

总结

第 2 个循环就是在问: “在已经去过集合 i 的所有方案中,如果我们最后停在了城市 0、或者城市 1、……或者城市 n-1,分别会怎么样?”

Q:为什么在openjudge.cn上提交代码,套在函数中的程序,运行更快?

把程序套在函数里面,就不超时了。因为:局部变量(函数内部定义的变量)存储在 局部命名空间(local namespace) 中,通过 索引直接访问(LOAD_FAST 指令),速度非常快。 全局变量(模块级别定义的变量)存储在 全局字典(globals dict) 中,每次访问都需要 哈希查找(LOAD_GLOBAL 指令),开销更大。 @李睿安

运行时间:4135ms

python
def solve():
    n = int(input().strip())
    cost = []
    for _ in range(n):
        row = list(map(int, input().split()))
        cost.append(row)

    # 如果只有1个城市?但题目保证 n>=3
    INF = float('inf')
    # dp[mask][i]: mask 是已访问的城市集合,i 是当前所在城市(0 <= i < n)
    # mask 是一个整数,bit j 为1 表示城市 j 已访问
    total_masks = 1 << n
    dp = [[INF] * n for _ in range(total_masks)]

    # 起点设为城市0
    dp[1][0] = 0  # 只访问了城市0,当前在0,花费0

    # 遍历所有状态
    for mask in range(1, total_masks, 2):
        for u in range(n):
            if dp[mask][u] == INF:
                continue
            # 尝试从 u 到未访问的城市 v
            for v in range(n):
                if mask & (1 << v):
                    continue  # v 已访问,跳过
                new_mask = mask | (1 << v)
                new_cost = dp[mask][u] + cost[u][v]
                if new_cost < dp[new_mask][v]:
                    dp[new_mask][v] = new_cost

    # 所有城市都访问完的状态是 (1 << n) - 1
    full_mask = total_masks - 1
    ans = INF
    for i in range(1, n):  # 从其他城市回到起点0
        if dp[full_mask][i] != INF:
            ans = min(ans, dp[full_mask][i] + cost[i][0])

    print(ans)

if __name__ == '__main__':
    solve()

运行时间:6252ms

python
n = int(input().strip())
cost = []
for _ in range(n):
    row = list(map(int, input().split()))
    cost.append(row)

# 如果只有1个城市?但题目保证 n>=3
INF = float('inf')
# dp[mask][i]: mask 是已访问的城市集合,i 是当前所在城市(0 <= i < n)
# mask 是一个整数,bit j 为1 表示城市 j 已访问
total_masks = 1 << n
dp = [[INF] * n for _ in range(total_masks)]

# 起点设为城市0
dp[1][0] = 0  # 只访问了城市0,当前在0,花费0

# 遍历所有状态
for mask in range(1, total_masks, 2):
    for u in range(n):
        if dp[mask][u] == INF:
            continue
        # 尝试从 u 到未访问的城市 v
        for v in range(n):
            if mask & (1 << v):
                continue  # v 已访问,跳过
            new_mask = mask | (1 << v)
            new_cost = dp[mask][u] + cost[u][v]
            if new_cost < dp[new_mask][v]:
                dp[new_mask][v] = new_cost

# 所有城市都访问完的状态是 (1 << n) - 1
full_mask = total_masks - 1
ans = INF
for i in range(1, n):  # 从其他城市回到起点0
    if dp[full_mask][i] != INF:
        ans = min(ans, dp[full_mask][i] + cost[i][0])

print(ans)

【李欣珂 25物院】思路:上学期因为对位运算不熟悉所以没有仔细研究这题,对这题一知半解,这学期仔细了解了位运算,只要想到位来标记城市是否访问的思路,那状态方程就比较经典了。(本来用了cache结果发现oj不支持,上学期写搜索的时候居然没发现)

python
import sys
from functools import lru_cache 


def solve():
    input_data = sys.stdin.read().split()
    if not input_data:
        return

    n = int(input_data[0])

    cost = []
    idx = 1
    for i in range(n):
        cost.append([int(x) for x in input_data[idx: idx + n]])
        idx += n

    @lru_cache(None)
    def dfs(state, u):
        if state == (1 << n) - 1:
            return cost[u][0]

        res = float('inf')

        for v in range(n):
            if (state & (1 << v)) == 0:
                res = min(res, dfs(state | (1 << v), v) + cost[u][v])

        return res

    ans = dfs(1, 0)
    print(ans)


if __name__ == '__main__':
    solve()