Skip to content

3478.选出和最大的K个元素

heap, https://leetcode.cn/problems/choose-k-elements-with-maximum-sum/

给你两个整数数组,nums1nums2,长度均为 n,以及一个正整数 k

对从 0n - 1 每个下标 i ,执行下述操作:

  • 找出所有满足 nums1[j] 小于 nums1[i] 的下标 j
  • 从这些下标对应的 nums2[j] 中选出 至多 k 个,并 最大化 这些值的总和作为结果。

返回一个长度为 n 的数组 answer ,其中 answer[i] 表示对应下标 i 的结果。

示例 1:

输入:nums1 = [4,2,1,5,3], nums2 = [10,20,30,40,50], k = 2

输出:[80,30,0,80,50]

解释:

  • 对于 i = 0 :满足 nums1[j] < nums1[0] 的下标为 [1, 2, 4] ,选出其中值最大的两个,结果为 50 + 30 = 80
  • 对于 i = 1 :满足 nums1[j] < nums1[1] 的下标为 [2] ,只能选择这个值,结果为 30
  • 对于 i = 2 :不存在满足 nums1[j] < nums1[2] 的下标,结果为 0
  • 对于 i = 3 :满足 nums1[j] < nums1[3] 的下标为 [0, 1, 2, 4] ,选出其中值最大的两个,结果为 50 + 30 = 80
  • 对于 i = 4 :满足 nums1[j] < nums1[4] 的下标为 [1, 2] ,选出其中值最大的两个,结果为 30 + 20 = 50

示例 2:

输入:nums1 = [2,2,2,2], nums2 = [3,1,2,3], k = 1

输出:[0,0,0,0]

解释:由于 nums1 中的所有元素相等,不存在满足条件 nums1[j] < nums1[i],所有位置的结果都是 0 。

提示:

  • n == nums1.length == nums2.length
  • 1 <= n <= 10^5
  • 1 <= nums1[i], nums2[i] <= 10^6
  • 1 <= k <= n
python
class Solution:
    def findMaxSum(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
        a = sorted((x, y, i) for i, (x, y) in enumerate(zip(nums1, nums2)))
        n = len(a)
        ans = [0] * n
        h = []
        s = 0
        for i, (x, y, idx) in enumerate(a):
            ans[idx] = ans[a[i - 1][2]] if i and x == a[i - 1][0] else s
            s += y
            if len(h) < k:
                heappush(h, y)
            else:
                s -= heappushpop(h, y)
        return ans

【叶靖 信管系】

python
import heapq

class Solution:
    def findMaxSum(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
        n = len(nums1)
        result = [0] * n
        pairs = sorted((num, i) for i, num in enumerate(nums1))
        min_heap = []
        total_sum = 0
        j = 0
        
        for value, i in pairs:
            while j < n and pairs[j][0] < value:
                _, idx = pairs[j]
                heapq.heappush(min_heap, nums2[idx])
                total_sum += nums2[idx]
                if len(min_heap) > k:
                    total_sum -= heapq.heappop(min_heap)
                j += 1
            result[i] = total_sum
        
        return result

思路:观察数据范围得知复杂度为O(nlogn),故考虑使用堆维护最大的k个元素,另外还需要维护和值,否则有无法接受的O(nk)额外开销。初始对nums1排序之后就只需要遍历一遍就可以解决。

发现heappushpop似乎会比先heappushheappop快。

python
# 张景天 物理学院
class Solution:
    def findMaxSum(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
        import heapq
        indexs = sorted(enumerate(nums1), key=lambda x: x[1])
        heap = [0] * k
        max_sum = [0] * len(nums1)
        j = 0
        s = 0
        for i in range(len(indexs)):
            while indexs[j][1] < indexs[i][1]:
                s += nums2[indexs[j][0]]
                s -= heapq.heappushpop(heap, nums2[indexs[j][0]])
                j += 1
            max_sum[indexs[i][0]] = s
        return max_sum
python
from typing import List
import heapq


class Solution:
    def findMaxSum(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
        n = len(nums1)
        res = [0] * n
        # 构造元组 (nums1[i], i, nums2[i]) 并按 nums1 升序排序
        arr = [(nums1[i], i, nums2[i]) for i in range(n)]
        arr.sort(key=lambda x: x[0])

        # 使用最小堆维护之前(即 nums1 较小的)的 nums2 值,堆内最多保存 k 个数
        heap = []
        heap_sum = 0

        i = 0
        while i < n:
            cur_val = arr[i][0]
            j = i
            # 对于同一组(nums1 值相同),先将当前堆中记录的和赋值给答案
            while j < n and arr[j][0] == cur_val:
                _, idx, _ = arr[j]
                res[idx] = heap_sum
                j += 1
            # 再把这一组的元素加入堆中(加入后会用于后续 nums1 更大的位置)
            while i < j:
                _, idx, value = arr[i]
                heapq.heappush(heap, value)
                heap_sum += value
                if len(heap) > k:
                    # 如果堆内超过 k 个,则移除最小值
                    heap_sum -= heapq.heappop(heap)
                i += 1
        return res


if __name__ == "__main__":
    sol = Solution()
    print(sol.findMaxSum([4, 2, 1, 5, 3], [10, 20, 30, 40, 50], 2))  # [80, 30, 0, 80, 50]

说明

  1. 排序与分组 先将所有下标及对应的 (nums1, index, nums2) 组成元组并按 nums1 升序排序。对于相同的 nums1值,先不把它们加入堆(这样可以防止同一组内互相影响,因为要求严格 <),而是先将当前堆中记录的累计和赋值给答案。
  2. 维护堆与累计和 使用一个最小堆来维护前面所有满足条件的 nums2 值,并且保证堆中最多保留 k 个最大的数。每加入一个新值后,如果堆大小超过 k,就弹出堆顶最小值,并调整累计和。
  3. 结果映射 最后,答案数组 res 中每个位置记录的就是对应下标的最大和。

让我详细拆解讲解这个算法的思路!虽然代码里有排序、堆等操作,但逻辑其实是围绕两大核心目标

🚀 1. 如何找出所有满足条件的下标 j?

题目的要求是:

对于每个下标 i,找到所有满足 nums1[j] < nums1[i]j

为了快速找到符合条件的 j,我们采取排序 + 线性扫描的方法:

  • 首先,我们把所有的 (nums1[i], i, nums2[i]) 组合成一个三元组数组 arr
  • 然后按照 nums1 升序排序

排序后的效果是:

  • 较小的 nums1 会出现在前面,也就是说,如果我们正在处理 nums1[i],它左边的那些值必然都比它小!

因此,排序完之后,我们只需要:

  • 从左往右遍历一次数组,就可以保证:
    当我们在第 i 个位置时,左边的数自动满足 nums1[j] < nums1[i]
  • 这样就避免了暴力双重循环去检查所有组合,大大提升效率!

💡 2. 如何挑选 k 个最大的 nums2[j] 并求和?

我们接下来解决第二个任务:

从满足条件的 j 中挑选最多 k 个 nums2[j],并让这些值的总和最大化。

  • 由于我们是从左往右扫描,每次处理到一个新位置 i 时,把所有满足 nums1[j] < nums1[i] 的值的 nums2[j] 加入到堆中。
  • 为什么用最小堆?
    因为我们想维持最大的 k 个值
    • 每次往堆里加一个新的 nums2[j]
    • 如果堆的大小超过 k,就把堆顶最小的元素弹出去,确保堆里始终是最大的 k 个值
  • 堆顶元素是最小值,因此堆的和就是当前前 k 大的 nums2[j] 的和

🌟 处理相同的 nums1[i]

还有个关键点:

如果多个 nums1[i] 值相同怎么办?

假如我们同时有多个 nums1[i] 等于 2,它们互相之间不应影响彼此,因为条件是严格小于

  • 例如:对于 [2, 2, 2, 2],无论哪一个 2,它左边都没有比自己小的数,所以结果都是 0。

怎么实现这一点呢?

  • 我们在遍历排序后的数组时,使用一个 while 循环,一次性处理所有相同的 nums1[i]
  • 先计算答案,再更新堆
    • 避免同一组的数影响彼此。
    • 只有在处理完这一组之后,才把它们加入堆中,为下一个更大的值做准备。

🏎️ 时间复杂度分析

  • 排序:O(n log n)
  • 遍历:O(n)
  • 堆操作:O(log k)(每次插入或弹出堆顶)

最终复杂度是:
[ O(n \log n) ] 这比暴力解法的 (O(n^2)) 快了好几个数量级!💥


🎯 代码逻辑总结

  1. 排序

    • 按照 nums1[i] 升序排序三元组 (nums1[i], i, nums2[i])
  2. 遍历处理

    • 用堆维护“左边所有比当前值小的数对应的 nums2”
    • 遇到新的一组值,先把当前堆的和保存到结果里
    • 再把这组值加入堆,为后续更大的数准备数据
  3. 结果更新

    • 每次遇到更大的 nums1[i],答案直接来自堆的和
    • 如果堆大小超出 k,就移除最小值,保持最大的 k 个元素