Skip to content

02442: Sequence

heap, merge, http://cs101.openjudge.cn/practice/02442/

中文版是 http://cs101.openjudge.cn/practice/06648/

Given m sequences, each contains n non-negative integer. Now we may select one number from each sequence to form a sequence with m integers. It's clear that we may get n ^ m this kind of sequences. Then we can calculate the sum of numbers in each sequence, and get n ^ m values. What we need is the smallest n sums. Could you help us?

输入

The first line is an integer T, which shows the number of test cases, and then T test cases follow. The first line of each case contains two integers m, n (0 < m <= 100, 0 < n <= 2000). The following m lines indicate the m sequence respectively. No integer in the sequence is greater than 10000.

输出

For each test case, print a line with the smallest n sums in increasing order, which is separated by a space.

样例输入

1
2 3
1 2 3
2 2 3

样例输出

3 3 4

来源

POJ Monthly,Guang Lin

这个问题是一个经典的“K个最小和”问题。给定 m 个序列,每个序列有 n 个数,我们需要从每个序列中选出一个数组成新的序列并求和,在所有 nm 种可能的组合中,找出最小的 n 个和。

解题思路

  1. 两两合并: 由于总组合数 nm 非常巨大(如 2000100),直接生成是不现实的。我们可以采用两两合并的策略。先找前两个序列合并后的最小 n 个和,将结果作为一个新的“序列”,再与第三个序列合并,以此类推。
  2. 双序列合并算法: 假设有两个已排序的序列 AB,大小均为 n。我们想求它们组合出的 n2 个和中的前 n 小。
    • 首先对 AB 进行排序。
    • 使用最小堆(Min-Heap)来维护候选和。
    • 初始时,将 (A[i]+B[0],i,0) 放入堆中(其中 i[0,n1])。
    • 每次从堆顶取出最小的和 s=A[i]+B[j],这就是当前的第 k 小和。
    • 取出后,如果 j+1<n,则将下一个候选和 A[i]+B[j+1] 放入堆中。
    • 重复上述过程 n 次,即可得到 AB 合并后的前 n 个最小和。
  3. 复杂度分析
    • 总共有 m1 次合并操作。
    • 每次合并涉及 n 次堆操作,每次堆操作复杂度为 O(logn)
    • 总时间复杂度为 O(mnlogn)
    • 给定 m=100,n=2000,总操作次数约 2×106,在 Python 中 3000ms 的限制下完全可行。
  4. 内存限制
    • 64MB 限制要求我们不能一次性读入所有数据。使用生成器或逐个读取的方式可以有效控制内存。

Python 代码实现

python
import sys
import heapq

def get_ints():
    """从标准输入流中逐词读取整数,节省内存。"""
    for line in sys.stdin:
        for word in line.split():
            yield word

def solve():
    ints_gen = get_ints()

    try:
        token = next(ints_gen)
    except StopIteration:
        return

    # 测试用例数量
    t_cases = int(token)
    
    for _ in range(t_cases):
        try:
            m = int(next(ints_gen))
            n = int(next(ints_gen))
        except StopIteration:
            break
            
        # 读取第一个序列并排序
        res = []
        for i in range(n):
            res.append(int(next(ints_gen)))
        res.sort()
        
        # 依次合并剩余的 m-1 个序列
        for _ in range(m - 1):
            row = []
            for i in range(n):
                row.append(int(next(ints_gen)))
            row.sort()
            
            # 使用最小堆合并当前结果 res 和新序列 row
            # 堆中存储: (和, row序列的索引, res序列的值)
            h = [(res[i] + row[0], 0, res[i]) for i in range(n)]
            heapq.heapify(h)
            
            new_res = [0] * n
            for k in range(n):
                curr_sum, row_idx, res_val = h[0]
                new_res[k] = curr_sum
                
                if row_idx + 1 < n:
                    # 如果 row 序列还没到头,将该 res 值对应的下一个 row 值组合入堆
                    heapq.heapreplace(h, (res_val + row[row_idx + 1], row_idx + 1, res_val))
                # else:
                #     # 如果 row 到头了,弹出堆顶
                #     heapq.heappop(h)
            
            # 更新 res 为合并后的前 n 个最小和
            res = new_res
            
        # 按照题目格式输出最小的 n 个和
        sys.stdout.write(" ".join(map(str, res)) + "\n")

if __name__ == "__main__":
    solve()

关键点说明

  • heapq.heapify: 将列表原地转换为堆,时间复杂度 O(n)
  • heapq.heapreplace: 弹出堆顶并压入新元素,比先 poppush 更高效。
  • 内存管理: 通过 get_ints 生成器读取数据,确保即使输入文件很大也不会因内存溢出(MLE)而导致程序崩溃。
  • 排序: 虽然堆可以处理乱序,但算法前提是两序列之一有序(本代码中两个都排序了以确保逻辑最简),这样能保证每次从堆中弹出的必然是当前组合中最小的。

利用堆合并的方法依次求解两序列的最小 n 个和,从而逐步合并 m 个序列,避免枚举所有 n^m 种组合。

python
import sys
import heapq

def merge(arr1, arr2, n):
    """
    将两个有序数组 arr1 和 arr2 合并,求出所有组合中最小的 n 个和
    使用堆来进行合并搜索
    """
    heap = []
    visited = set()
    # 初始候选项:(arr1[0]+arr2[0], 0, 0)
    heapq.heappush(heap, (arr1[0] + arr2[0], 0, 0))
    visited.add((0, 0))
    result = []
    while len(result) < n:
        s, i, j = heapq.heappop(heap)
        result.append(s)
        # 如果 arr1 中的下一个数存在,尝试加入候选项
        if i + 1 < n and (i + 1, j) not in visited:
            heapq.heappush(heap, (arr1[i + 1] + arr2[j], i + 1, j))
            visited.add((i + 1, j))
        # 如果 arr2 中的下一个数存在,尝试加入候选项
        if j + 1 < n and (i, j + 1) not in visited:
            heapq.heappush(heap, (arr1[i] + arr2[j + 1], i, j + 1))
            visited.add((i, j + 1))
    return result

def main():
    input_data = sys.stdin.read().split()
    it = iter(input_data)
    T = int(next(it))
    results = []
    for _ in range(T):
        m = int(next(it))
        n = int(next(it))
        # 读取第一个序列,并排序
        current = sorted(int(next(it)) for _ in range(n))
        # 依次与后续的 m-1 个序列合并
        for _ in range(m - 1):
            seq = sorted(int(next(it)) for _ in range(n))
            current = merge(current, seq, n)
        results.append(" ".join(map(str, current)))
    sys.stdout.write("\n".join(results))

if __name__ == "__main__":
    main()

代码说明

  • merge 函数
    该函数接受两个有序数组 arr1arr2,利用最小堆依次寻找组合中最小的 n 个和。我们用 visited 集合避免重复放入堆中。

  • 主函数
    先读取测试用例数 T,再依次处理每个测试用例。每个测试用例中,首先将第一个序列排序作为初始的结果,再依次将后续序列与当前结果进行合并。最终输出最小的 n 个和。

    该算法利用堆优化,每次合并时间复杂度约为 O(n log n),适合 m 与 n 的题目范围。

参考链接:https://blog.csdn.net/liuwei_nefu/article/details/5645528 题意是 给出 m组数,每组 n个数 然后从m组中 每组选出一个进行求和 ,然后取其中前n小的数输出。 选择的总数自然是 n的m次方,暴力法自然是超时的。

一个简单的思路是,从第一组到第m组依次处理。 首先第一组的n个数自然是最小的n个数, 然后这n个数和第二组的n个组进行组合,形成n×n个数,保留其前n个数,再处理第三组,依次类推直到第m组。

为什么保留前n个数就可以了呢? 我们以第一组和第二组例所得的n×n个数为例,假设保留n+1个数,且这第n+1个数加上第三组的某个数x的和 在下一步中需要保留(即在下一步操作中属于前n小的数之一),然而前n个数中的任意一个数+x < 第n+1个数+x ,此时得出矛盾, 由此可知,每次处理后的n×n个数中只需保留前n个数即可

python
import heapq

t = int(input())
for _ in range(t):
    m, n = map(int, input().split())
    seq1 = sorted(map(int, input().split()))
    for _ in range(m - 1):
        seq2 = sorted(map(int, input().split()))

        # 使用最小堆存储可能的最小和以及对应的索引
        min_heap = [(seq1[i] + seq2[0], i, 0) for i in range(n)]
        heapq.heapify(min_heap)
        result = []
        for _ in range(n):
            current_sum, i, j = heapq.heappop(min_heap)
            result.append(current_sum)
            if j + 1 < len(seq2):
                heapq.heappush(min_heap, (seq1[i] + seq2[j + 1], i, j + 1))
        seq1 = result
    print(*seq1)

补充候选:一旦 (seq1[i] + seq2[j]) 被选中,唯一的竞争对手可能就是 (seq1[i] + seq2[j+1])。