Skip to content

M973.最接近原点的 K 个点

kd tree, heap, https://leetcode.cn/problems/k-closest-points-to-origin/

给定一个数组 points ,其中 points[i] = [xi, yi] 表示 X-Y 平面上的一个点,并且是一个整数 k ,返回离原点 (0,0) 最近的 k 个点。

这里,平面上两点之间的距离是 欧几里德距离√(x1 - x2)2 + (y1 - y2)2 )。

你可以按 任何顺序 返回答案。除了点坐标的顺序之外,答案 确保唯一 的。

示例 1:

img
输入:points = [[1,3],[-2,2]], k = 1
输出:[[-2,2]]
解释: 
(1, 3) 和原点之间的距离为 sqrt(10),
(-2, 2) 和原点之间的距离为 sqrt(8),
由于 sqrt(8) < sqrt(10),(-2, 2) 离原点更近。
我们只需要距离原点最近的 K = 1 个点,所以答案就是 [[-2,2]]。

示例 2:

输入:points = [[3,3],[5,-1],[-2,4]], k = 2
输出:[[3,3],[-2,4]]
(答案 [[-2,4],[3,3]] 也会被接受。)

提示:

  • 1 <= k <= points.length <= 10^4
  • -10^4 < xi, yi < 10^4
  • 题目大意:给定一组坐标,求离原点最近的 K 个点。
  • KD树联系:这是标准的 k-NN 问题。虽然通常用大顶堆(Heap)解决,但用 kd 树进行空间索引在数据量巨大且有多次查询时更有优势。

针对力扣 973 题,虽然最常用的解法是 大顶堆 (Heap)快速选择 (QuickSelect),但既然我们正在学习 kd,我将为你展示如何利用 kd 树来解决这个问题,并对比其他主流解法。


解法一:kd 树实现(空间索引法)

kd 树中找 K 个最近邻,与找 1 个最近邻的区别在于:我们需要维护一个大小为 K 的大顶堆,用来保存当前找到的最近的 K 个点。

python
import heapq

class KdNode:
    def __init__(self, point, split, left=None, right=None):
        self.point = point
        self.split = split
        self.left = left
        self.right = right

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        # 1. 构建 KD 树
        def build_tree(data, depth):
            if not data:
                return None
            axis = depth % 2 # 2维空间,只有 0 和 1
            data.sort(key=lambda x: x[axis])
            mid = len(data) // 2
            return KdNode(
                point=data[mid],
                split=axis,
                left=build_tree(data[:mid], depth + 1),
                right=build_tree(data[mid+1:], depth + 1)
            )
        
        root = build_tree(points, 0)
        
        # 2. 搜索 K 个最近邻
        # heap 中存储的是 (-距离, 点),因为 Python 的 heapq 是小顶堆,
        # 用负数可以模拟大顶堆,方便弹出距离最远的点
        self.max_heap = []
        
        def get_dist(p):
            return p[0]**2 + p[1]**2 # 省略开方以优化性能
            
        def search(node):
            if not node:
                return
            
            dist = get_dist(node.point)
            axis = node.split
            
            # 维护大小为 k 的大顶堆
            if len(self.max_heap) < k:
                heapq.heappush(self.max_heap, (-dist, node.point))
            elif dist < -self.max_heap[0][0]:
                heapq.heapreplace(self.max_heap, (-dist, node.point))
            
            # 决定搜索方向
            target_val = 0 # 因为是离原点(0,0)的距离,目标点值永远是0
            if target_val < node.point[axis]:
                near, far = node.left, node.right
            else:
                near, far = node.right, node.left
            
            search(near)
            
            # 剪枝:如果另一侧还可能存在更近的点,则进入另一侧
            # 这里的“墙”是 node.point[axis],目标点在 0
            # 距离墙的平方距离是 (node.point[axis] - 0)^2
            if len(self.max_heap) < k or (node.point[axis]**2 < -self.max_heap[0][0]):
                search(far)
                
        search(root)
        return [p for d, p in self.max_heap]

解法二:大顶堆(最稳健解法)

这是面试中最推荐的解法。遍历所有点,维护一个大小为 K 的大顶堆。时间复杂度 O(NlogK)

python
import heapq

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        heap = [] # 大顶堆
        for x, y in points:
            dist = -(x**2 + y**2) # 取负数变大顶堆
            if len(heap) < k:
                heapq.heappush(heap, (dist, x, y))
            elif dist > heap[0][0]:
                heapq.heapreplace(heap, (dist, x, y))
        
        return [[x, y] for d, x, y in heap]

解法三:快速选择(QuickSelect,理论最优)

类似快排的分治思想,不需要完全排序,只需把最小的 K 个数放到数组左边。时间复杂度平均 O(N)

python
import random

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        def get_dist(p):
            return p[0]**2 + p[1]**2
        
        def quick_select(left, right, k):
            if left >= right: return
            
            pivot_idx = random.randint(left, right)
            points[left], points[pivot_idx] = points[pivot_idx], points[left]
            
            pivot_dist = get_dist(points[left])
            i = left
            for j in range(left + 1, right + 1):
                if get_dist(points[j]) < pivot_dist:
                    i += 1
                    points[i], points[j] = points[j], points[i]
            
            points[left], points[i] = points[i], points[left]
            
            # i 是 pivot 最终所在的位置
            num_left = i - left + 1
            if k < num_left:
                quick_select(left, i - 1, k)
            elif k > num_left:
                quick_select(i + 1, right, k - num_left)

        quick_select(0, len(points) - 1, k)
        return points[:k]

总结与对比

方法时间复杂度优点缺点
排序O(NlogN)代码最简单。性能一般,处理了不必要的排序。
大顶堆O(NlogK)适合流式数据(海量数据无法一次存入内存)。需要额外 O(K) 空间。
快速选择O(N)理论最快。稍微复杂,若基准点选不好可能退化。
kd建树 O(NlogN) 查询 O(KlogN)多次查询时极快。适合空间检索。仅单次查询时,建树开销太大。

针对力扣这道题:

  • 如果你只需要查询一次,请使用快速选择大顶堆
  • 如果你有一个包含 100 万个点的地图,用户会频繁点击不同位置查询最近的点,请务必使用 kd