Skip to content

2179.统计数组中好三元组数目

Binary Indexed Tree, Segment Tree, Divide and Conquer, Merge Sort, https://leetcode.cn/problems/count-good-triplets-in-an-array/description/

给你两个下标从 0 开始且长度为 n 的整数数组 nums1nums2 ,两者都是 [0, 1, ..., n - 1]排列

好三元组 指的是 3互不相同 的值,且它们在数组 nums1nums2 中出现顺序保持一致。换句话说,如果我们将 pos1v 记为值 vnums1 中出现的位置,pos2v 为值 vnums2 中的位置,那么一个好三元组定义为 0 <= x, y, z <= n - 1 ,且 pos1x < pos1y < pos1zpos2x < pos2y < pos2z 都成立的 (x, y, z)

请你返回好三元组的 总数目

示例 1:

输入:nums1 = [2,0,1,3], nums2 = [0,1,2,3]
输出:1
解释:
总共有 4 个三元组 (x,y,z) 满足 pos1x < pos1y < pos1z ,分别是 (2,0,1) ,(2,0,3) ,(2,1,3) 和 (0,1,3) 。
这些三元组中,只有 (0,1,3) 满足 pos2x < pos2y < pos2z 。所以只有 1 个好三元组。

示例 2:

输入:nums1 = [4,0,1,3,2], nums2 = [4,1,0,2,3]
输出:4
解释:总共有 4 个好三元组 (4,0,3) ,(4,0,2) ,(4,1,3) 和 (4,1,2) 。

提示:

  • n == nums1.length == nums2.length
  • 3 <= n <= 10^5
  • 0 <= nums1[i], nums2[i] <= n - 1
  • nums1nums2[0, 1, ..., n - 1] 的排列。

下面给出一种基于「树状数组(Fenwick Tree)」的 O(nlogn) 解法,思路如下:

  1. 先构造一个长度为 n 的数组 pos2,其中 pos2[v] 表示值 (v) 在 nums2 中的位置。
  2. nums1 中的每个值 (v) 替换为 pos2[v],得到一个新数组 A。此时,问题就变成了:在数组 A 中,统计下标 i<j<kA[i]<A[j]<A[k] 的三元组数。
  3. 对于每个中间点 j,我们想知道:
    • 左侧比它小的元素个数:L[j]=|{i<j:A[i]<A[j]}|
    • 右侧比它大的元素个数:R[j]=|{k>j:A[k]>A[j]}|
      那么以 j 为中间点的递增三元组数就是 L[j]×R[j]
  4. 我们只要对所有 j 求和即可。

用两次 Fenwick 树就能在线地统计「前缀中小于某值的个数」和「后缀中大于某值的个数」:

  • 第一次从左到右扫,Fenwick 维护「已经遍历过的 A[i] 的频次」,查询时得到 L[j]
  • 第二次从右到左扫,重新清空 Fenwick,维护「已经遍历过的 A[k] 的频次」,查询时得到 R[j]

下面是完整代码:

python
from typing import List

class Fenwick:
    """1-based Fenwick Tree for point update & prefix sum."""
    def __init__(self, n: int):
        self.n = n
        self.fw = [0] * (n + 1)
    def update(self, i: int, v: int):
        """add v at position i (1-based)"""
        while i <= self.n:
            self.fw[i] += v
            i += i & -i
    def query(self, i: int) -> int:
        """sum of [1..i] (1-based)"""
        s = 0
        while i > 0:
            s += self.fw[i]
            i -= i & -i
        return s

class Solution:
    def goodTriplets(self, nums1: List[int], nums2: List[int]) -> int:
        n = len(nums1)
        # 1. 记录 nums2 中每个值的位置
        pos2 = [0] * n
        for i, v in enumerate(nums2):
            pos2[v] = i
        
        # 2. 用 nums2 的位置替换 nums1 中的值
        A = [pos2[v] for v in nums1]
        # 为 Fenwick 做 1-based 转换
        A = [x + 1 for x in A]  # 现在 A 中元素范围是 [1..n]
        
        # 3. 计算 L[j]
        fw = Fenwick(n)
        L = [0] * n
        for j in range(n):
            aj = A[j]
            # 前缀中 < aj 的个数 = query(aj-1)
            L[j] = fw.query(aj - 1)
            # 把 aj 加入 Fenwick
            fw.update(aj, 1)
        
        # 4. 计算 R[j]
        fw = Fenwick(n)
        R = [0] * n
        for j in range(n-1, -1, -1):
            aj = A[j]
            # 后缀中 > aj 的个数 = 已遍历总数 - query(aj)
            seen = (n-1 - j)  # 已经加到树里的元素个数
            cnt_le = fw.query(aj)  # <= aj
            R[j] = seen - cnt_le
            fw.update(aj, 1)
        
        # 5. 累加答案
        ans = 0
        for j in range(n):
            ans += L[j] * R[j]
        return ans

# --- 简单测试 ---
if __name__ == "__main__":
    sol = Solution()
    print(sol.goodTriplets([2,0,1,3], [0,1,2,3]))  # 输出 1
    print(sol.goodTriplets([4,0,1,3,2], [4,1,0,2,3]))  # 输出 4

复杂度分析:

  • 构造 pos2AO(n)
  • 两次 Fenwick 扫描,每次更新/查询 O(logn),总计 O(nlogn)
  • 总体 O(nlogn),可处理 n 高达 105 的规模。

这样就能高效地统计所有“好三元组”的数量。