3478.选出和最大的K个元素
heap, https://leetcode.cn/problems/choose-k-elements-with-maximum-sum/
给你两个整数数组,nums1 和 nums2,长度均为 n,以及一个正整数 k 。
对从 0 到 n - 1 每个下标 i ,执行下述操作:
- 找出所有满足
nums1[j]小于nums1[i]的下标j。 - 从这些下标对应的
nums2[j]中选出 至多k个,并 最大化 这些值的总和作为结果。
返回一个长度为 n 的数组 answer ,其中 answer[i] 表示对应下标 i 的结果。
示例 1:
输入:nums1 = [4,2,1,5,3], nums2 = [10,20,30,40,50], k = 2
输出:[80,30,0,80,50]
解释:
- 对于
i = 0:满足nums1[j] < nums1[0]的下标为[1, 2, 4],选出其中值最大的两个,结果为50 + 30 = 80。 - 对于
i = 1:满足nums1[j] < nums1[1]的下标为[2],只能选择这个值,结果为30。 - 对于
i = 2:不存在满足nums1[j] < nums1[2]的下标,结果为0。 - 对于
i = 3:满足nums1[j] < nums1[3]的下标为[0, 1, 2, 4],选出其中值最大的两个,结果为50 + 30 = 80。 - 对于
i = 4:满足nums1[j] < nums1[4]的下标为[1, 2],选出其中值最大的两个,结果为30 + 20 = 50。
示例 2:
输入:nums1 = [2,2,2,2], nums2 = [3,1,2,3], k = 1
输出:[0,0,0,0]
解释:由于 nums1 中的所有元素相等,不存在满足条件 nums1[j] < nums1[i],所有位置的结果都是 0 。
提示:
n == nums1.length == nums2.length1 <= n <= 10^51 <= nums1[i], nums2[i] <= 10^61 <= k <= n
class Solution:
def findMaxSum(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
a = sorted((x, y, i) for i, (x, y) in enumerate(zip(nums1, nums2)))
n = len(a)
ans = [0] * n
h = []
s = 0
for i, (x, y, idx) in enumerate(a):
ans[idx] = ans[a[i - 1][2]] if i and x == a[i - 1][0] else s
s += y
if len(h) < k:
heappush(h, y)
else:
s -= heappushpop(h, y)
return ans【叶靖 信管系】
import heapq
class Solution:
def findMaxSum(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
n = len(nums1)
result = [0] * n
pairs = sorted((num, i) for i, num in enumerate(nums1))
min_heap = []
total_sum = 0
j = 0
for value, i in pairs:
while j < n and pairs[j][0] < value:
_, idx = pairs[j]
heapq.heappush(min_heap, nums2[idx])
total_sum += nums2[idx]
if len(min_heap) > k:
total_sum -= heapq.heappop(min_heap)
j += 1
result[i] = total_sum
return result思路:观察数据范围得知复杂度为nums1排序之后就只需要遍历一遍就可以解决。
发现heappushpop似乎会比先heappush再heappop快。
# 张景天 物理学院
class Solution:
def findMaxSum(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
import heapq
indexs = sorted(enumerate(nums1), key=lambda x: x[1])
heap = [0] * k
max_sum = [0] * len(nums1)
j = 0
s = 0
for i in range(len(indexs)):
while indexs[j][1] < indexs[i][1]:
s += nums2[indexs[j][0]]
s -= heapq.heappushpop(heap, nums2[indexs[j][0]])
j += 1
max_sum[indexs[i][0]] = s
return max_sumfrom typing import List
import heapq
class Solution:
def findMaxSum(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
n = len(nums1)
res = [0] * n
# 构造元组 (nums1[i], i, nums2[i]) 并按 nums1 升序排序
arr = [(nums1[i], i, nums2[i]) for i in range(n)]
arr.sort(key=lambda x: x[0])
# 使用最小堆维护之前(即 nums1 较小的)的 nums2 值,堆内最多保存 k 个数
heap = []
heap_sum = 0
i = 0
while i < n:
cur_val = arr[i][0]
j = i
# 对于同一组(nums1 值相同),先将当前堆中记录的和赋值给答案
while j < n and arr[j][0] == cur_val:
_, idx, _ = arr[j]
res[idx] = heap_sum
j += 1
# 再把这一组的元素加入堆中(加入后会用于后续 nums1 更大的位置)
while i < j:
_, idx, value = arr[i]
heapq.heappush(heap, value)
heap_sum += value
if len(heap) > k:
# 如果堆内超过 k 个,则移除最小值
heap_sum -= heapq.heappop(heap)
i += 1
return res
if __name__ == "__main__":
sol = Solution()
print(sol.findMaxSum([4, 2, 1, 5, 3], [10, 20, 30, 40, 50], 2)) # [80, 30, 0, 80, 50]说明
- 排序与分组 先将所有下标及对应的
(nums1, index, nums2)组成元组并按nums1升序排序。对于相同的nums1值,先不把它们加入堆(这样可以防止同一组内互相影响,因为要求严格<),而是先将当前堆中记录的累计和赋值给答案。 - 维护堆与累计和 使用一个最小堆来维护前面所有满足条件的
nums2值,并且保证堆中最多保留k个最大的数。每加入一个新值后,如果堆大小超过k,就弹出堆顶最小值,并调整累计和。 - 结果映射 最后,答案数组
res中每个位置记录的就是对应下标的最大和。
让我详细拆解讲解这个算法的思路!虽然代码里有排序、堆等操作,但逻辑其实是围绕两大核心目标:
🚀 1. 如何找出所有满足条件的下标 j?
题目的要求是:
对于每个下标
i,找到所有满足nums1[j] < nums1[i]的j。为了快速找到符合条件的
j,我们采取排序 + 线性扫描的方法:
- 首先,我们把所有的
(nums1[i], i, nums2[i])组合成一个三元组数组arr。- 然后按照
nums1升序排序。排序后的效果是:
- 较小的
nums1会出现在前面,也就是说,如果我们正在处理nums1[i],它左边的那些值必然都比它小!因此,排序完之后,我们只需要:
- 从左往右遍历一次数组,就可以保证:
当我们在第i个位置时,左边的数自动满足nums1[j] < nums1[i]。- 这样就避免了暴力双重循环去检查所有组合,大大提升效率!
💡 2. 如何挑选 k 个最大的 nums2[j] 并求和?
我们接下来解决第二个任务:
从满足条件的 j 中挑选最多 k 个 nums2[j],并让这些值的总和最大化。
- 由于我们是从左往右扫描,每次处理到一个新位置
i时,把所有满足nums1[j] < nums1[i]的值的nums2[j]加入到堆中。- 为什么用最小堆?
因为我们想维持最大的 k 个值:
- 每次往堆里加一个新的
nums2[j]- 如果堆的大小超过 k,就把堆顶最小的元素弹出去,确保堆里始终是最大的 k 个值
- 堆顶元素是最小值,因此堆的和就是当前前 k 大的 nums2[j] 的和!
🌟 处理相同的 nums1[i]
还有个关键点:
如果多个 nums1[i] 值相同怎么办?
假如我们同时有多个
nums1[i]等于 2,它们互相之间不应影响彼此,因为条件是严格小于。
- 例如:对于
[2, 2, 2, 2],无论哪一个 2,它左边都没有比自己小的数,所以结果都是 0。怎么实现这一点呢?
- 我们在遍历排序后的数组时,使用一个
while循环,一次性处理所有相同的nums1[i]。- 先计算答案,再更新堆:
- 避免同一组的数影响彼此。
- 只有在处理完这一组之后,才把它们加入堆中,为下一个更大的值做准备。
🏎️ 时间复杂度分析
- 排序:O(n log n)
- 遍历:O(n)
- 堆操作:O(log k)(每次插入或弹出堆顶)
最终复杂度是:
[ O(n \log n) ] 这比暴力解法的 (O(n^2)) 快了好几个数量级!💥🎯 代码逻辑总结
排序
- 按照
nums1[i]升序排序三元组(nums1[i], i, nums2[i])遍历处理
- 用堆维护“左边所有比当前值小的数对应的 nums2”
- 遇到新的一组值,先把当前堆的和保存到结果里
- 再把这组值加入堆,为后续更大的数准备数据
结果更新
- 每次遇到更大的
nums1[i],答案直接来自堆的和- 如果堆大小超出 k,就移除最小值,保持最大的 k 个元素