Skip to content

T3013.将数组分成最小总代价的子数组 II

sliding window, heap, https://leetcode.cn/problems/divide-an-array-into-subarrays-with-minimum-cost-ii/

给你一个下标从 0 开始长度为 n 的整数数组 nums 和两个 整数 kdist

一个数组的 代价 是数组中的 第一个 元素。比方说,[1,2,3] 的代价为 1[3,4,1] 的代价为 3

你需要将 nums 分割成 k连续且互不相交 的子数组,满足 第二 个子数组与第 k 个子数组中第一个元素的下标距离 不超过 dist 。换句话说,如果你将 nums 分割成子数组 nums[0..(i1 - 1)], nums[i1..(i2 - 1)], ..., nums[ik-1..(n - 1)] ,那么它需要满足 ik-1 - i1 <= dist

请你返回这些子数组的 最小 总代价。

示例 1:

输入:nums = [1,3,2,6,4,2], k = 3, dist = 3
输出:5
解释:将数组分割成 3 个子数组的最优方案是:[1,3] ,[2,6,4] 和 [2] 。这是一个合法分割,因为 ik-1 - i1 等于 5 - 2 = 3 ,等于 dist 。总代价为 nums[0] + nums[2] + nums[5] ,也就是 1 + 2 + 2 = 5 。
5 是分割成 3 个子数组的最小总代价。

示例 2:

输入:nums = [10,1,2,2,2,1], k = 4, dist = 3
输出:15
解释:将数组分割成 4 个子数组的最优方案是:[10] ,[1] ,[2] 和 [2,2,1] 。这是一个合法分割,因为 ik-1 - i1 等于 3 - 1 = 2 ,小于 dist 。总代价为 nums[0] + nums[1] + nums[2] + nums[3] ,也就是 10 + 1 + 2 + 2 = 15 。
分割 [10] ,[1] ,[2,2,2] 和 [1] 不是一个合法分割,因为 ik-1 和 i1 的差为 5 - 1 = 4 ,大于 dist 。
15 是分割成 4 个子数组的最小总代价。

示例 3:

输入:nums = [10,8,18,9], k = 3, dist = 1
输出:36
解释:将数组分割成 4 个子数组的最优方案是:[10] ,[8] 和 [18,9] 。这是一个合法分割,因为 ik-1 - i1 等于 2 - 1 = 1 ,等于 dist 。总代价为 nums[0] + nums[1] + nums[2] ,也就是 10 + 8 + 18 = 36 。
分割 [10] ,[8,18] 和 [9] 不是一个合法分割,因为 ik-1 和 i1 的差为 3 - 1 = 2 ,大于 dist 。
36 是分割成 3 个子数组的最小总代价。

提示:

  • 3 <= n <= 10^5
  • 1 <= nums[i] <= 10^9
  • 3 <= k <= n
  • k - 2 <= dist <= n - 2

为了最小化总代价,我们需要将数组分成 k 个连续子数组。代价是每个子数组第一个元素的总和。 由于第一个子数组总是从下标 0 开始,因此 nums[0] 总是包含在总代价中。 我们需要在范围 [1,n1] 中选择 k1 个下标作为剩余 k1 个子数组的起始位置。设这些下标为 i1,i2,,ik1。 题目限制条件为 ik1i1dist。这意味着这 k1 个下标必须位于一个长度为 dist+1 的窗口内。

核心思路:

  1. 我们要在数组 nums[1n1] 中滑动一个长度为 dist+1 的窗口。
  2. 对于每一个窗口,我们从中选出 k1 个最小的元素,它们的和加上 nums[0] 就是该窗口下的最小代价。
  3. 我们遍历所有可能的窗口,并维护所有窗口代价的全局最小值。
  4. 为了高效地维护窗口内前 m=k1 个最小元素的和,我们可以使用树状数组(BIT)配合离散化,利用树状数组上的二分倍增(Binary Lifting)O(logn) 时间内找到前 m 小元素的和。

代码实现:

python
import math
from typing import List

class Solution:
    def minimumCost(self, nums: List[int], k: int, dist: int) -> int:
        n = len(nums)
        m = k - 1
        # 我们在 nums[1:] 中选择 k-1 个起始下标
        arr = nums[1:]
        
        # 1. 离散化:处理 nums[i] 过大的问题,映射到 [1, U] 范围
        sorted_unique = sorted(list(set(arr)))
        U = len(sorted_unique)
        rank = {v: i + 1 for i, v in enumerate(sorted_unique)}
        
        # 2. 树状数组维护频率和数值总和
        # t_cnt[i] 维护离散化后索引 i 对应的频率前缀和
        # t_sum[i] 维护离散化后索引 i 对应的 (数值 * 频率) 前缀和
        t_cnt = [0] * (U + 1)
        t_sum = [0] * (U + 1)
        
        # 预计算用于树状数组二分倍增的步长
        bit_len = U.bit_length()
        bit_powers = [1 << i for i in range(bit_len - 1, -1, -1)]
        
        min_window_sum = float('inf')
        
        # 3. 滑动窗口
        for i in range(n - 1):
            # 添加当前元素 arr[i] 到树状数组
            val = arr[i]
            idx = rank[val]
            cur = idx
            while cur <= U:
                t_cnt[cur] += 1
                t_sum[cur] += val
                cur += cur & (-cur)
            
            # 如果窗口大小超过 dist + 1,移除窗口左侧失效元素
            if i > dist:
                val_r = arr[i - dist - 1]
                idx_r = rank[val_r]
                cur = idx_r
                while cur <= U:
                    t_cnt[cur] -= 1
                    t_sum[cur] -= val_r
                    cur += cur & (-cur)
            
            # 当窗口达到指定大小(或滑动过程中)进行查询
            if i >= dist:
                # 4. 树状数组二分倍增:寻找前 m 个最小元素的和
                pos = 0
                curr_cnt = 0
                curr_sum = 0
                for p in bit_powers:
                    next_pos = pos + p
                    if next_pos <= U:
                        cnt = t_cnt[next_pos]
                        if curr_cnt + cnt < m:
                            pos = next_pos
                            curr_cnt += cnt
                            curr_sum += t_sum[pos]
                
                # 补足剩余所需的元素个数(来自下一个离散化数值)
                # 注意:pos 是已经完全包含的离散化索引数量,下一个索引是 pos+1
                # 对应 sorted_unique 的下标是 pos
                s_total = curr_sum + (m - curr_cnt) * sorted_unique[pos]
                if s_total < min_window_sum:
                    min_window_sum = s_total
        
        return nums[0] + int(min_window_sum)

复杂度分析:

  • 时间复杂度O(nlogn)。排序和建立映射表需要 O(nlogn);滑动窗口进行 n 次操作,每次树状数组更新和二分倍增查询均为 O(logn)
  • 空间复杂度O(n)。用于存储离散化映射、树状数组以及窗口内的统计信息。

这个问题的核心可以通俗地理解为:

在数组中找一个长度为 dist + 1 的“滑动窗口”,在这个窗口里选出最小的 k-1 个数字。我们要让这些数字的和尽可能小。

代码看不懂是因为之前的版本使用了“树状数组+离散化”,那是一种极其高效但比较抽象的写法。为了让你理解,我们换一种更直观的思路:使用两个集合(双有序集合)来维护窗口内的信息。


1. 核心思路图解

想象你有一个长度为 dist + 1 的窗口在数组中滑动:

  • 我们把窗口里的数字分成两部分:
    1. 左半部分 (L):保存窗口内最小的 k-1 个数字。
    2. 右半部分 (R):保存窗口内剩下的其他数字。
  • 我们需要实时维护 L 里的数字总和

当窗口向右移动时:

  1. 进来一个新数字
    • 先把新数放进 L。
    • 如果 L 的数字太多了(超过 k-1 个),就把 L 里最大的那个踢出去,扔进 R。
  2. 离开一个旧数字
    • 如果这个旧数字在 L 里,把它删掉,然后从 R 里挑一个最小的补进 L。
    • 如果旧数字在 R 里,直接从 R 里删掉即可。

2. 直观的代码实现 (使用 SortedList)

在 Python 中,我们可以使用 sortedcontainers 库里的 SortedList(LeetCode 环境内置了该库),它可以像数组一样自动排序,并且支持快速添加和删除。

python
from sortedcontainers import SortedList

class Solution:
    def minimumCost(self, nums: List[int], k: int, dist: int) -> int:
        n = len(nums)
        m = k - 1  # 我们需要在窗口内选出 m 个最小的数
        
        # L 存储窗口内最小的 m 个数, R 存储剩下的数
        L = SortedList()
        R = SortedList()
        current_l_sum = 0 # 实时记录 L 中所有数字的和
        
        # 辅助函数:将数字 x 加入窗口
        def add(x):
            nonlocal current_l_sum
            L.add(x)
            current_l_sum += x
            # 如果 L 里的数超过了 m 个,把最大的移到 R
            if len(L) > m:
                max_val = L.pop() # 弹出 L 中最大的
                current_l_sum -= max_val
                R.add(max_val)

        # 辅助函数:将数字 x 从窗口中移除
        def remove(x):
            nonlocal current_l_sum
            if x in R:
                R.remove(x)
            else:
                L.remove(x)
                current_l_sum -= x
                # L 少了一个数,从 R 里拿一个最小的来补位
                if R:
                    min_val = R.pop(0) # 弹出 R 中最小的
                    L.add(min_val)
                    current_l_sum += min_val

        # --- 正式逻辑 ---
        
        # 1. 初始化:第一个合法的窗口是 nums[1 : dist+2]
        # (因为题目规定第二个子数组起始下标 i1 和最后一个 i_k-1 距离不超过 dist)
        for i in range(1, dist + 2):
            add(nums[i])
        
        # 初始窗口的代价 (nums[0] 是固定要加的)
        ans = nums[0] + current_l_sum
        
        # 2. 开始滑动窗口
        # 窗口左边界是 i-dist, 右边界是 i
        for i in range(dist + 2, n):
            remove(nums[i - dist - 1]) # 移除滑出窗口的左侧元素
            add(nums[i])               # 加入新进入窗口的右侧元素
            
            # 更新全局最小代价
            ans = min(ans, nums[0] + current_l_sum)
            
        return ans

3. 重点难点拆解

  1. 为什么是 k-1 个数? 题目要把数组分成 k 个子数组。第一个子数组的代价固定是 nums[0]。剩下还要找 k-1 个子数组的开头。这 k-1 个开头的下标只要落在长度为 dist+1 的范围内即可。

  2. 窗口的范围是多少?

    • 第一个子数组开头下标永远是 0
    • 第二个子数组开头下标 i1 最小可以是 1
    • 最后一个子数组开头下标 ik1 最大满足 ik1i1dist
    • 所以,如果我们固定一个窗口 [i1,i1+dist],在这个窗口里选出的所有下标对应的数字中,最小的 k1 个就是我们要的。
  3. 时间复杂度:

    • SortedList 的添加和删除操作都是 O(log(窗口大小))
    • 我们一共遍历数组一次(n 次操作)。
    • 总时间复杂度是 O(nlogn),对于 105 的数据量来说非常快。

总结

这道题本质上是一个滑动窗口内求前 K 小元素和的问题。通过维护两个有序集合,我们可以很方便地在窗口移动时,实时更新这 K 小元素的总和。