Skip to content

T30830:地铁换乘(多组查询版)

倍增法,http://cs101.openjudge.cn/practice/30830/

B 市有 n 个地点,编号为 1∼n,可视为根节点为编号 t 的树,交通管理局在根节点上。

记每个节点的深度为其到 t 的边数,根节点 t 的深度为 0。

有 m 次询问,每次给出两个起点 p,q 与速度 v1,v2:

线路 A 施工队从 p 出发往 q 修,每天修 v1 条边;

线路 B 施工队从 q 出发往 p 修,每天修 v2 条边。

数据保证:

p 到 q 的路径长度 L 满足 Lmod(v1+v2)=0,

两队一定在某个整点、某个节点相遇,该节点为换乘站。

对每次询问,输出:相遇需要的天数、换乘站的深度。

数据范围:1 ≤ n, m ≤ 2×10^5, 1 ≤ t,p,q ≤ n, 1 ≤ v1,v2 ≤ 10^9, 1 ≤ u,v ≤ n,u≠v

保证输入构成一棵树。保证 p 到q 的距离L 满足 L mod (v1+v2) = 0。

保证相遇点一定在某个节点上,且相遇天数为整数。

输入

第一行两个正整数 n,t,代表地点个数、根节点编号。

接下来 n-1 行,每行两个正整数 u,v,表示两点连边。

接下来一行一个正整数 m,表示询问组数。

接下来 m 行,每行四个正整数 p,q,v1,v2。

输出

共 m 行,每行两个整数:相遇天数、换乘站深度。

样例输入

7 1
1 2
1 3
2 4
2 5
3 6
3 7
1
4 7 1 3

样例输出

1 1

提示

LCA(最近公共祖先) 和 倍增法(Binary Lifting)

来源

http://cs101.openjudge.cn/practice/30669/

python
import sys
input = sys.stdin.read
data = input().split()

def main():
    ptr = 0
    n, t = int(data[ptr]), int(data[ptr+1])
    ptr += 2

    # 建图
    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        u = int(data[ptr])
        v = int(data[ptr+1])
        adj[u].append(v)
        adj[v].append(u)
        ptr += 2

    # 倍增预处理
    LOG = 18
    depth = [0] * (n + 1)
    up = [[0] * LOG for _ in range(n + 1)]

    # DFS 初始化
    stack = [(t, 0, 0)]
    while stack:
        u, fa, d = stack.pop()
        depth[u] = d
        up[u][0] = fa
        for v in adj[u]:
            if v != fa:
                stack.append((v, u, d + 1))

    # 构建倍增表
    for j in range(1, LOG):
        for i in range(1, n + 1):
            up[i][j] = up[up[i][j-1]][j-1]

    # LCA
    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        # 对齐深度
        for j in range(LOG-1, -1, -1):
            if depth[u] - (1 << j) >= depth[v]:
                u = up[u][j]
        if u == v:
            return u
        for j in range(LOG-1, -1, -1):
            if up[u][j] != up[v][j]:
                u = up[u][j]
                v = up[v][j]
        return up[u][0]

    # 第 k 个祖先
    def kth_ancestor(u, k):
        for j in range(LOG-1, -1, -1):
            if k >= (1 << j):
                u = up[u][j]
                k -= (1 << j)
        return u

    # 读取查询数量 m
    m = int(data[ptr])
    ptr += 1

    # 处理 m 组查询
    res = []
    for _ in range(m):
        p = int(data[ptr])
        q = int(data[ptr+1])
        v1 = int(data[ptr+2])
        v2 = int(data[ptr+3])
        ptr += 4

        r = lca(p, q)
        L = (depth[p] - depth[r]) + (depth[q] - depth[r])
        days = L // (v1 + v2)
        s = v1 * days

        # 找相遇点
        if s <= depth[p] - depth[r]:
            meet = kth_ancestor(p, s)
        else:
            s2 = L - s
            meet = kth_ancestor(q, s2)

        res.append(f"{days} {depth[meet]}")
    
    print('\n'.join(res))

if __name__ == "__main__":
    main()