Skip to content

P1257 平面上的最接近点对

https://www.luogu.com.cn/problem/P1257

给定平面上 n 个点,找出其中的一对点的距离,使得在这 n 个点的所有点对中,该距离为所有点对中最小的。

输入格式

第一行一个整数 n,表示点的个数。

接下来 n 行,每行两个整数 x,y ,表示一个点的行坐标和列坐标。

输出格式

仅一行,一个实数,表示最短距离,四舍五入保留 4 位小数。

样例 #1

样例输入 #1

3
1 1
1 2
2 2

样例输出 #1

1.0000

提示

数据规模与约定

对于 100% 的数据,保证 1n1040x,y109

python
import math
from functools import cmp_to_key
import sys
sys.setrecursionlimit(1000000)

MAXN = 200010
INF = float("inf")

class Node:
    def __init__(self, x=0.0, y=0.0):
        self.x = x
        self.y = y

class KDTree:
    def __init__(self):
        self.n = 0
        self.s = [Node() for _ in range(MAXN)]  # 点集
        self.lc = [0] * MAXN  # 左子树
        self.rc = [0] * MAXN  # 右子树
        self.d = [0] * MAXN   # 划分维度
        self.L = [0.0] * MAXN
        self.R = [0.0] * MAXN
        self.D = [0.0] * MAXN
        self.U = [0.0] * MAXN
        self.ans = INF

    def dist(self, a, b):
        """计算两点之间的欧几里得距离平方"""
        return (self.s[a].x - self.s[b].x) ** 2 + (self.s[a].y - self.s[b].y) ** 2

    def maintain(self, x):
        """维护边界矩形"""
        self.L[x] = self.R[x] = self.s[x].x
        self.D[x] = self.U[x] = self.s[x].y
        if self.lc[x]:
            lc = self.lc[x]
            self.L[x] = min(self.L[x], self.L[lc])
            self.R[x] = max(self.R[x], self.R[lc])
            self.D[x] = min(self.D[x], self.D[lc])
            self.U[x] = max(self.U[x], self.U[lc])
        if self.rc[x]:
            rc = self.rc[x]
            self.L[x] = min(self.L[x], self.L[rc])
            self.R[x] = max(self.R[x], self.R[rc])
            self.D[x] = min(self.D[x], self.D[rc])
            self.U[x] = max(self.U[x], self.U[rc])

    def cmp1(self, a, b):
        return -1 if a.x < b.x else (1 if a.x > b.x else 0)

    def cmp2(self, a, b):
        return -1 if a.y < b.y else (1 if a.y > b.y else 0)

    def build(self, l, r):
        """构建 KD 树"""
        if l > r:
            return 0
        if l == r:
            self.maintain(l)
            return l

        mid = (l + r) >> 1
        avx = avy = vax = vay = 0.0  # 平均值与方差
        for i in range(l, r + 1):
            avx += self.s[i].x
            avy += self.s[i].y
        avx /= (r - l + 1)
        avy /= (r - l + 1)
        for i in range(l, r + 1):
            vax += (self.s[i].x - avx) ** 2
            vay += (self.s[i].y - avy) ** 2

        if vax >= vay:
            self.d[mid] = 1
            self.s[l:r + 1] = sorted(self.s[l:r + 1], key=cmp_to_key(self.cmp1))
        else:
            self.d[mid] = 2
            self.s[l:r + 1] = sorted(self.s[l:r + 1], key=cmp_to_key(self.cmp2))

        self.lc[mid] = self.build(l, mid - 1)
        self.rc[mid] = self.build(mid + 1, r)
        self.maintain(mid)
        return mid

    def f(self, a, b):
        """计算点 a 到矩形 b 的最短距离"""
        ret = 0
        if self.L[b] > self.s[a].x:
            ret += (self.L[b] - self.s[a].x) ** 2
        if self.R[b] < self.s[a].x:
            ret += (self.s[a].x - self.R[b]) ** 2
        if self.D[b] > self.s[a].y:
            ret += (self.D[b] - self.s[a].y) ** 2
        if self.U[b] < self.s[a].y:
            ret += (self.s[a].y - self.U[b]) ** 2
        return ret

    def query(self, l, r, x):
        """查询点 x 的最近邻点"""
        if l > r:
            return
        mid = (l + r) >> 1
        if mid != x:
            self.ans = min(self.ans, self.dist(x, mid))
        if l == r:
            return

        dist_l = self.f(x, self.lc[mid]) if self.lc[mid] else INF
        dist_r = self.f(x, self.rc[mid]) if self.rc[mid] else INF

        if dist_l < dist_r:
            if dist_l < self.ans:
                self.query(l, mid - 1, x)
            if dist_r < self.ans:
                self.query(mid + 1, r, x)
        else:
            if dist_r < self.ans:
                self.query(mid + 1, r, x)
            if dist_l < self.ans:
                self.query(l, mid - 1, x)

    def solve(self):
        """主函数逻辑"""
        root = self.build(1, self.n)
        for i in range(1, self.n + 1):
            self.query(1, self.n, i)
        return math.sqrt(self.ans)

# 示例用法
if __name__ == "__main__":
    kd_tree = KDTree()
    kd_tree.n = int(input())
    for i in range(1, kd_tree.n + 1):
        x, y = map(float, input().split())
        kd_tree.s[i] = Node(x, y)

    result = kd_tree.solve()
    print(f"{result:.4f}")