Skip to content

4.寻找两个正序数组的中位数

binary search, https://leetcode.cn/problems/median-of-two-sorted-arrays/

给定两个大小分别为 mn 的正序(从小到大)数组 nums1nums2。请你找出并返回这两个正序数组的 中位数

算法的时间复杂度应该为 O(log (m+n))

示例 1:

输入:nums1 = [1,3], nums2 = [2]
输出:2.00000
解释:合并数组 = [1,2,3] ,中位数 2

示例 2:

输入:nums1 = [1,2], nums2 = [3,4]
输出:2.50000
解释:合并数组 = [1,2,3,4] ,中位数 (2 + 3) / 2 = 2.5

提示:

  • nums1.length == m
  • nums2.length == n
  • 0 <= m <= 1000
  • 0 <= n <= 1000
  • 1 <= m + n <= 2000
  • -10^6 <= nums1[i], nums2[i] <= 10^6

这个问题要求在合并两个有序数组后找到中位数,并且时间复杂度要求是 O(log(m+n))。直接合并数组并排序的方式的时间复杂度为 O((m+n)log(m+n)),这显然不符合要求。为了达到要求的时间复杂度,我们可以通过二分查找来解决。

思路:

  1. 中位数的定义
    • 如果合并后的数组长度是奇数,中位数就是中间那个元素。
    • 如果合并后的数组长度是偶数,中位数是中间两个元素的平均值。
  2. 分治法
    • 我们可以将问题转化为在两个有序数组中寻找一个合适的分割点,使得:
      • 左侧部分包含的元素数量等于右侧部分(或比右侧部分多一个元素)。
      • 左侧部分的所有元素都不大于右侧部分的所有元素。
  3. 二分查找
    • 通过对较小数组进行二分查找来优化时间复杂度。我们设定数组 nums1nums2 中较小的那个作为主数组。
    • nums1 中选择一个分割点,然后根据该分割点在 nums2 中选择对应的分割点,确保左右两边的元素满足中位数的定义。

代码实现:

python
def findMedianSortedArrays(nums1, nums2):
    # 确保 nums1 是较小的数组
    if len(nums1) > len(nums2):
        nums1, nums2 = nums2, nums1
    
    m, n = len(nums1), len(nums2)
    
    # 二分查找 nums1
    left, right = 0, m
    
    while left <= right:
        # 在 nums1 中选择一个分割点
        partition1 = (left + right) // 2
        partition2 = (m + n + 1) // 2 - partition1
        
        # 获取分割点左右两侧的值
        maxLeft1 = float('-inf') if partition1 == 0 else nums1[partition1 - 1]
        minRight1 = float('inf') if partition1 == m else nums1[partition1]
        
        maxLeft2 = float('-inf') if partition2 == 0 else nums2[partition2 - 1]
        minRight2 = float('inf') if partition2 == n else nums2[partition2]
        
        # 判断是否找到了正确的分割点
        if maxLeft1 <= minRight2 and maxLeft2 <= minRight1:
            # 找到分割点
            if (m + n) % 2 == 0:
                return (max(maxLeft1, maxLeft2) + min(minRight1, minRight2)) / 2
            else:
                return max(maxLeft1, maxLeft2)
        elif maxLeft1 > minRight2:
            # 如果 maxLeft1 大于 minRight2,说明 partition1 需要向左移
            right = partition1 - 1
        else:
            # 如果 maxLeft2 大于 minRight1,说明 partition1 需要向右移
            left = partition1 + 1

解释:

  1. 确保二分查找在较小的数组上进行:
    • 我们首先确保 nums1 是较小的数组,这样可以减少二分查找的次数。
  2. 二分查找的核心:
    • nums1 中选择一个分割点 partition1,然后在 nums2 中根据 partition1 计算对应的分割点 partition2
    • 我们保证:
      • nums1 左侧的元素不大于 nums2 右侧的元素。
      • nums2 左侧的元素不大于 nums1 右侧的元素。
  3. 中位数的计算:
    • 如果合并后的数组长度是奇数,返回 max(maxLeft1, maxLeft2),即左侧最大的元素。
    • 如果合并后的数组长度是偶数,返回 (max(maxLeft1, maxLeft2) + min(minRight1, minRight2)) / 2,即左侧最大元素和右侧最小元素的平均值。

时间复杂度:

  • 二分查找的时间复杂度是 O(logmin(m,n)),因此整体时间复杂度是 O(logmin(m,n))

空间复杂度:

  • 只使用了常数级别的额外空间,空间复杂度是 O(1)。
python
from typing import List
class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        n1, n2 = len(nums1), len(nums2)
        if n1 > n2:
            nums1, nums2, n1, n2 = nums2, nums1, n2, n1
        if n2 == 0:
            raise ValueError

        left, right, half_len = 0, n1, (n1 + n2 + 1) // 2
        while left <= right:
            i = (left + right) // 2
            j = half_len - i
            if i < n1 and nums2[j - 1] > nums1[i]:
                left = i + 1
            elif i > 0 and nums1[i - 1] > nums2[j]:
                right = i
            else:
                max_of_left = 0
                if i == 0:
                    max_of_left = nums2[j - 1]
                elif j == 0:
                    max_of_left = nums1[i - 1]
                else:
                    max_of_left = max(nums1[i - 1], nums2[j - 1])

                if (n1 + n2) % 2 == 1:
                    return max_of_left

                min_of_right = 0
                if i == n1:
                    min_of_right = nums2[j]
                elif j == n2:
                    min_of_right = nums1[i]
                else:
                    min_of_right = min(nums1[i], nums2[j])

                return (max_of_left + min_of_right) / 2.0

        raise ValueError

if __name__ == "__main__":
    sol = Solution()
    print(sol.findMedianSortedArrays([1, 3], [2]))  # 2.0
    print(sol.findMedianSortedArrays([1, 2], [3, 4]))  # 2.5