Skip to content

T3510.移除最小数对使数组有序II

doubly-linked list + heap, https://leetcode.cn/problems/minimum-pair-removal-to-sort-array-ii/

给你一个数组 nums,你可以执行以下操作任意次数:

  • 选择 相邻 元素对中 和最小 的一对。如果存在多个这样的对,选择最左边的一个。
  • 用它们的和替换这对元素。

返回将数组变为 非递减 所需的 最小操作次数

如果一个数组中每个元素都大于或等于它前一个元素(如果存在的话),则称该数组为非递减

示例 1:

输入: nums = [5,2,3,1]

输出: 2

解释:

  • 元素对 (3,1) 的和最小,为 4。替换后 nums = [5,2,4]
  • 元素对 (2,4) 的和为 6。替换后 nums = [5,6]

数组 nums 在两次操作后变为非递减。

示例 2:

输入: nums = [1,2,2]

输出: 0

解释:

数组 nums 已经是非递减的。

提示:

  • 1 <= nums.length <= 10^5
  • -10^9 <= nums[i] <= 10^9

下面给出一个基于优先队列 + 双向链表的 O(nlog⁡n)实现思路:

  1. 初始化

    • 将数组中的每个元素封装成一个双向链表节点 Node(val),并且用 prev/next 串起来。
    • 用一个小顶堆存所有相邻节点对的「和」,即 (sum, timestamp, left_node)timestamp 用来区分同样 sum 的不同对。
    • 统计初始的「逆序对数」 bad,即有多少处 node.next.val < node.val
  2. 主循环

    • bad == 0(已经非递减)或链表只剩一个节点时结束。

    • 从堆里弹出当前最小的 (sum, ts, left),如果这对节点已经被合并过(检查 left.next 是否还在链表里)就跳过。

    • 否则把这对节点合并成一个新节点 m = Node(sum),并插入原来这对节点的位置:

      left.prev <-> left <-> right <-> right.next
              ↓               ↓
      left.prev <->   m   <-> right.next
    • 更新「逆序对数」bad

      • 删除原来 (left.prev, left)(right, right.next) 两处可能的逆序,
      • 新增 (left.prev, m)(m, right.next) 两处可能的逆序。
    • m 与它的新左右相邻节点组成的两对新「和」重新 push 进堆。

    • 计数 cnt += 1

python
import heapq
from typing import List

class Node:
    def __init__(self, val: int, index: int):
        self.val = val
        self.prev = None
        self.next = None
        self.alive = True
        self.index = index

class Solution:
    def minimumPairRemoval(self, nums: List[int]) -> int:
        n = len(nums)
        if n <= 1:
            return 0

        # 初始化节点和双向链表
        nodes = [Node(nums[i], i) for i in range(n)]
        for i in range(n):
            if i > 0:
                nodes[i].prev = nodes[i - 1]
            else:
                nodes[i].prev = None
            if i < n - 1:
                nodes[i].next = nodes[i + 1]
            else:
                nodes[i].next = None

        # 计算初始逆序对数
        bad = 0
        for i in range(n - 1):
            if nodes[i].val > nodes[i + 1].val:
                bad += 1

        # 初始化堆
        heap = []
        for i in range(n - 1):
            current_node = nodes[i]
            next_node = current_node.next
            heapq.heappush(heap, (current_node.val + next_node.val, i))

        cnt = 0

        while bad > 0:
            if not heap:
                break  # 堆为空但仍有逆序对,说明逻辑错误

            s, i = heapq.heappop(heap)
            current_node = nodes[i]
            next_node = current_node.next

            # 检查 next_node 是否存在
            if next_node is None:
                continue

            # 跳过无效条目
            if not current_node.alive or not next_node.alive or (current_node.val + next_node.val) != s:
                continue

            prev_node = current_node.prev
            next_next_node = next_node.next

            # 移除旧逆序对
            # 1. prev_node 和 current_node 的逆序
            if prev_node and prev_node.alive and prev_node.val > current_node.val:
                bad -= 1
            # 2. current_node 和 next_node 的逆序
            if current_node.val > next_node.val:
                bad -= 1
            # 3. next_node 和 next_next_node 的逆序
            if next_next_node and next_next_node.alive and next_node.val > next_next_node.val:
                bad -= 1

            # 合并 next_node 到 current_node
            current_node.val += next_node.val
            next_node.alive = False

            # 更新指针
            current_node.next = next_next_node
            if next_next_node:
                next_next_node.prev = current_node
            else:
                current_node.next = None  # 确保指针正确

            # 添加新逆序对
            # 1. prev_node 和 current_node 的新逆序
            if prev_node and prev_node.alive and prev_node.val > current_node.val:
                bad += 1
            # 2. current_node 和 next_next_node 的新逆序
            if next_next_node and next_next_node.alive and current_node.val > next_next_node.val:
                bad += 1

            # 将新邻对推入堆
            if prev_node and prev_node.alive:
                heapq.heappush(heap, (prev_node.val + current_node.val, prev_node.index))
            if next_next_node and next_next_node.alive:
                heapq.heappush(heap, (current_node.val + next_next_node.val, current_node.index))

            cnt += 1

        return cnt

if __name__ == "__main__":
    s = Solution()
    print(s.minimumPairRemoval([5, 2, 3, 1]))  # 输出 2
    print(s.minimumPairRemoval([1, 2, 2]))     # 输出 0
    print(s.minimumPairRemoval([1, 1, 4, 4, 2, -4, -1]))  # 输出 5
    print(s.minimumPairRemoval([3,6,4,-6,2,-4,5,-7,-3,6,3,-4]))  # 输出 10
    print(s.minimumPairRemoval([2,2,-1,3,-2,2,1,1,1,0,-1]))  # 输出 9
    print(s.minimumPairRemoval([-1,2,2,-2,-3,0,2,1,0,0,1]))  # 输出 9

完整的双向链表(通过两个列表来模拟)与最小堆+懒删除

python
import heapq
class Solution:
    def minimumPairRemoval(self, nums: List[int]) -> int:
        n = len(nums)
        pairs = [(nums[i], nums[i + 1]) for i in range(n - 1)]
        h = []
        dec = 0
        for i, (x, y) in enumerate(pairs):
            if x > y:
                dec += 1
            h.append((x + y, i))

        heapq.heapify(h)

        # 模拟双向链表
        left = list(range(-1, n))
        right = list(range(1, n + 1))

        ans = 0
        while dec:
            ans += 1

            # 懒删除
            while right[h[0][1]] >= n or nums[h[0][1]] + nums[right[h[0][1]]] != h[0][0]:
                heapq.heappop(h)

            s, i = heapq.heappop(h)
            nxt = right[i]
            pre = left[i]
            nxt2 = right[nxt]

            if nums[nxt] < nums[i]:
                dec -= 1

            if pre >= 0:
                if nums[pre] > s:
                    dec += 1
                if nums[pre] > nums[i]:
                    dec -= 1
                heapq.heappush(h, (nums[pre] + s, pre))

            if nxt2 < n:
                if nums[nxt] > nums[nxt2]:
                    dec -= 1
                if nums[nxt2] < s:
                    dec += 1
                heapq.heappush(h, (s + nums[nxt2], i))

            nums[i] = s
            # 删除 nxt
            right[i] = nxt2
            left[nxt2] = i
            right[nxt] = n

        return ans

「数组模拟双向链表 + 小顶堆」写法。它对每次合并都严格检查“节点还活着”并且“相邻”,同时只在合并点附近更新逆序对计数,避免全局扫描。

python
from typing import List
import heapq

class Solution:
    def minimumPairRemoval(self, nums: List[int]) -> int:
        n = len(nums)
        if n <= 1:
            return 0

        # 值数组
        v = nums[:]
        # 左右指针:模拟双向链表
        L = [i-1 for i in range(n)]
        R = [i+1 for i in range(n)]
        # 标记节点是否还在链表中
        alive = [True] * n

        # 1) 计算初始逆序对数 bad
        bad = 0
        for i in range(n-1):
            if v[i] > v[i+1]:
                bad += 1

        # 2) 建堆,存 (sum, i),代表合并 i 和 R[i]
        heap = []
        for i in range(n-1):
            heapq.heappush(heap, (v[i] + v[i+1], i))

        cnt = 0
        # 3) 主循环:只要还有逆序,就不断合并堆顶最小的合法邻对
        while bad > 0:
            s, i = heapq.heappop(heap)
            j = R[i]
            # 跳过不合法的条目
            if j >= n or not alive[i] or not alive[j] or v[i] + v[j] != s:
                continue

            # 准备更新逆序对:左邻 pi, 右邻 nj
            pi, nj = L[i], R[j]

            # —— 删除旧的三处可能的逆序 —— 
            if pi >= 0 and alive[pi] and v[pi] > v[i]:
                bad -= 1
            if v[i] > v[j]:
                bad -= 1
            if nj < n and alive[nj] and v[j] > v[nj]:
                bad -= 1

            # 执行合并:把 j 融到 i 上
            v[i] = s
            alive[j] = False
            # 从链表中摘除 j
            R[i] = nj
            if nj < n:
                L[nj] = i

            # —— 添加新的两处可能的逆序 —— 
            if pi >= 0 and alive[pi] and v[pi] > v[i]:
                bad += 1
            if nj < n and alive[nj] and v[i] > v[nj]:
                bad += 1

            # 把新产生的邻对重新推入堆
            if pi >= 0 and alive[pi]:
                heapq.heappush(heap, (v[pi] + v[i], pi))
            if nj < n and alive[nj]:
                heapq.heappush(heap, (v[i] + v[nj], i))

            cnt += 1

        return cnt


# 验证所有给出的例子
if __name__ == "__main__":
    s = Solution()
    print(s.minimumPairRemoval([5, 2, 3, 1]))                   # 2
    print(s.minimumPairRemoval([1, 2, 2]))                       # 0
    print(s.minimumPairRemoval([1, 1, 4, 4, 2, -4, -1]))         # 5
    print(s.minimumPairRemoval([3,6,4,-6,2,-4,5,-7,-3,6,3,-4]))  # 10
    print(s.minimumPairRemoval([2,2,-1,3,-2,2,1,1,1,0,-1]))      # 9
    print(s.minimumPairRemoval([-1,2,2,-2,-3,0,2,1,0,0,1]))      # 9

思路要点

  1. 数组模拟双向链表:用 L[i]/R[i] 存储左右邻居下标,用 alive[i] 标记节点是否被合并掉。
  2. 小顶堆:每次取堆顶 (sum, i),对应合并 (i, R[i])。取出后检查三条件:
    • R[i] 还在范围内,
    • alive[i]alive[R[i]] 都是 True
    • 当前 v[i] + v[R[i]] == sum(防止值被更新)。
  3. 局部更新逆序对:维护一个 bad 计数,只在合并点的左右各三条边上做增减,不用每次全局扫描。
  4. 时间复杂度:每次合并都做 O(logn) 的堆操作,最多合并 n 次,整体 O(nlogn)

核心思路

对于数组长度达到 105 的情况,原本 O(N2) 的暴力模拟会超时。我们需要使用双向链表配合最小堆(优先队列),将寻找“最小和数对”以及“更新数组”的操作优化到 O(logN)

1. 数据结构:

  • 双向链表(Node 类):维护当前的数组元素及其相邻关系。在执行合并操作时,链表可以实现 O(1) 的节点删除。
  • 最小堆(heapq):存储所有相邻对的 (和, 左节点原始索引, eid标识, 左节点对象, 右节点对象)
    • :第一关键字,确保每次取出最小。
    • 原始索引:第二关键字,用于在和相等时实现“取最左边”的规则(因为节点的相对顺序在合并过程中保持不变)。
  • 逆序计数器(v_count):记录当前链表中满足 arr[i] > arr[i+1] 的位置数量。当 v_count == 0 时,数组即为非递减。

2. 有效性校验:

由于合并操作会改变相邻关系,堆中会存在大量过时的无效对。我们通过 L.next is R 以及 L.val + R.val == s 来判断从堆中弹出的对是否依然有效。

3. 合并逻辑:

当合并 LR 时,会影响三个相邻关系:(L.prev,L), (L,R), (R,R.next)

  • 先从 v_count 中减去这三个位置的逆序贡献。
  • L.val 更新为 L.val + R.val,并将 R 从链表中删除。
  • 关键点:必须将 R 的指针设为 None,防止后续堆校验将其误认为有效节点。
  • 重新计算 (L.prev,L)(L,L.next) 的逆序贡献并加回 v_count

Python 代码实现

python
import heapq
from typing import List

class Node:
    # 优化内存
    __slots__ = ['val', 'idx', 'prev', 'next']
    def __init__(self, val, idx):
        self.val = val
        self.idx = idx
        self.prev = None
        self.next = None

class Solution:
    def minimumPairRemoval(self, nums: List[int]) -> int:
        # 用户要求的中间变量存储
        wexthorbin = nums
        
        n = len(nums)
        if n < 2:
            return 0
            
        # 判定函数:判定相邻两个节点是否违反非递减原则
        def is_bad(a, b):
            if a is None or b is None:
                return False
            return a.val > b.val

        # 1. 初始化双向链表
        nodes = [Node(v, i) for i, v in enumerate(nums)]
        v_count = 0
        for i in range(n - 1):
            nodes[i].next = nodes[i+1]
            nodes[i+1].prev = nodes[i]
            if nodes[i].val > nodes[i+1].val:
                v_count += 1
        
        # 如果初始状态已是有序
        if v_count == 0:
            return 0
            
        # 2. 初始化最小堆 (和, 左侧原始索引, eid标识, 左节点, 右节点)
        pq = []
        eid = 0 # 用于处理在和与索引相同时对比 Node 对象可能引发的 TypeError
        for i in range(n - 1):
            heapq.heappush(pq, (nodes[i].val + nodes[i+1].val, i, eid, nodes[i], nodes[i+1]))
            eid += 1
            
        ops = 0
        while v_count > 0 and pq:
            s, _, _, L, R = heapq.heappop(pq)
            
            # 3. 有效性校验:
            # R 必须紧跟在 L 后面,且它们的当前和必须等于 s
            if L.next is not R or (L.val + R.val) != s:
                continue
            
            # 4. 合并操作:先扣除涉及 L, R 的旧违规计数
            if is_bad(L.prev, L): v_count -= 1
            if is_bad(L, R): v_count -= 1
            if is_bad(R, R.next): v_count -= 1
            
            # 执行物理合并:将 R 合并到 L
            L.val = s
            L.next = R.next
            if R.next:
                R.next.prev = L
            
            # 彻底将 R 置为空,防止“幻影对”通过 L.next is R 校验
            R.prev = None
            R.next = None
            
            # 5. 重新计算受影响区域的新违规计数
            if is_bad(L.prev, L): v_count += 1
            if is_bad(L, L.next): v_count += 1
            
            ops += 1
            if v_count == 0:
                break
                
            # 6. 将受影响的新相邻对推入堆
            if L.prev:
                heapq.heappush(pq, (L.prev.val + L.val, L.prev.idx, eid, L.prev, L))
                eid += 1
            if L.next:
                heapq.heappush(pq, (L.val + L.next.val, L.idx, eid, L, L.next))
                eid += 1
                
        return ops

复杂度分析

  • 时间复杂度O(NlogN)。每个元素最多作为 R 移除一次,每次堆操作的时间复杂度为 O(logN)
  • 空间复杂度O(N)。链表节点和堆中存储的元组数量均与 N 成线性关系。