1843D. Apple Tree
Combinatorics, dfs and similar, dp, math, trees, 1200,
https://codeforces.com/problemset/problem/1843/D
Timofey has an apple tree growing in his garden; it is a rooted tree of 𝑛 vertices with the root in vertex 1 (the vertices are numbered from 1 to 𝑛). A tree is a connected graph without loops and multiple edges.
This tree is very unusual — it grows with its root upwards. However, it's quite normal for programmer's trees.
The apple tree is quite young, so only two apples will grow on it. Apples will grow in certain vertices (these vertices may be the same). After the apples grow, Timofey starts shaking the apple tree until the apples fall. Each time Timofey shakes the apple tree, the following happens to each of the apples:
Let the apple now be at vertex 𝑢.
- If a vertex 𝑢 has a child, the apple moves to it (if there are several such vertices, the apple can move to any of them).
- Otherwise, the apple falls from the tree.
It can be shown that after a finite time, both apples will fall from the tree.
Timofey has 𝑞 assumptions in which vertices apples can grow. He assumes that apples can grow in vertices 𝑥 and 𝑦, and wants to know the number of pairs of vertices (𝑎, 𝑏) from which apples can fall from the tree, where 𝑎 — the vertex from which an apple from vertex 𝑥 will fall, 𝑏 — the vertex from which an apple from vertex 𝑦 will fall. Help him do this.
Input
The first line contains integer 𝑡 (1≤𝑡≤10^4^) — the number of test cases.
The first line of each test case contains integer 𝑛 (2≤𝑛≤2⋅10^5^) — the number of vertices in the tree.
Then there are 𝑛−1 lines describing the tree. In line 𝑖 there are two integers 𝑢𝑖 and 𝑣𝑖 (1≤𝑢𝑖,𝑣𝑖≤𝑛, 𝑢𝑖≠𝑣𝑖) — edge in tree.
The next line contains a single integer 𝑞 (1≤𝑞≤2⋅10^5^) — the number of Timofey's assumptions.
Each of the next 𝑞 lines contains two integers 𝑥𝑖 and 𝑦𝑖 (1≤𝑥𝑖,𝑦𝑖≤𝑛) — the supposed vertices on which the apples will grow for the assumption .
It is guaranteed that the sum of 𝑛 does not exceed 2⋅10^5^. Similarly, It is guaranteed that the sum of 𝑞 does not exceed 2⋅10^5^.
Output
For each Timofey's assumption output the number of ordered pairs of vertices from which apples can fall from the tree if the assumption is true on a separate line.
Examples
input
2
5
1 2
3 4
5 3
3 2
4
3 4
5 1
4 4
1 3
3
1 2
1 3
3
1 1
2 3
3 1output
2
2
1
4
4
1
2input
2
5
5 1
1 2
2 3
4 3
2
5 5
5 1
5
3 2
5 3
2 1
4 2
3
4 3
2 1
4 2output
1
2
1
4
2Note
In the first example:
- For the first assumption, there are two possible pairs of vertices from which apples can fall from the tree: (4,4),(5,4).
- For the second assumption there are also two pairs: (5,4),(5,5).
- For the third assumption there is only one pair: (4,4).
- For the fourth assumption, there are 4 pairs: (4,4),(4,5),(5,4),(5,5).

For the second example, there are 4 of possible pairs of vertices from which apples can fall: (2,3),(2,2),(3,2),(3,3). For the second assumption, there is only one possible pair: (2,3). For the third assumption, there are two pairs: (3,2),(3,3).
蒋子轩23工学院 清晰明了的程序,custom stack.
def build_tree(edges):
tree = {}
for edge in edges:
u, v = edge
tree.setdefault(u, []).append(v)
tree.setdefault(v, []).append(u)
return tree
def count_leaves(tree, leaves_count):
stack = [(1, 0, 0)] # 节点,阶段标志,父节点
while stack:
vertex, stage, parent = stack.pop()
if stage == 0:
stack.append((vertex, 1, parent))
for child in tree[vertex]:
if child != parent:
stack.append((child, 0, vertex))
else:
if len(tree[vertex]) == 1 and vertex != 1:
leaves_count[vertex] = 1
else:
child_count = 0
for child in tree[vertex]:
if child != parent:
child_count += leaves_count[child]
leaves_count[vertex] = child_count # 当前节点的叶子节点数等于其子节点的叶子节点数之和
def process_assumptions(tree, leaves_count, assumptions):
for x, y in assumptions:
result = leaves_count[x] * leaves_count[y]
print(result)
t = int(input())
for _ in range(t):
n = int(input())
edges = []
for _ in range(n - 1):
edges.append(tuple(map(int, input().split())))
tree = build_tree(edges)
leaves_count = {node: 0 for node in range(1, n + 1)}
count_leaves(tree, leaves_count)
# print(tree, leaves_count)
q = int(input())
assumptions = []
for _ in range(q):
assumptions.append(tuple(map(int, input().split())))
process_assumptions(tree, leaves_count, assumptions)🚀 主要优化点
- 避免多层循环 sum(leaves[child]) → 改成手动累加,减少变量开销。
- 减少 Python 局部变量访问开销 → 将常用项绑定为局部变量(显著加速)。
- 减少 tuple 压栈开销 → 使用两个并行栈或一个栈内存放简短结构。
- 减少函数调用开销 → 内联 build_tree。
- 快速输出 → 使用
sys.stdout.write。 - 几乎所有变量改成本地绑定,在 Python 中可快 20–30%。
✅ 最终优化版本(含中文注释)
import sys
input = sys.stdin.readline
write = sys.stdout.write
def solve():
t = int(input())
for _ in range(t):
# ---------- 读入树 ----------
n = int(input())
tree = [[] for _ in range(n + 1)]
for __ in range(n - 1):
u, v = map(int, input().split())
tree[u].append(v)
tree[v].append(u)
# ---------- DFS 计算每个节点下面有多少叶子 ----------
leaves = [0] * (n + 1)
parent = [0] * (n + 1)
stack = [(1, 0)] # (node, stage)
# stage = 0 第一次到达; stage = 1 处理完孩子回溯时计算叶子数
while stack:
node, stage = stack.pop()
if stage == 0:
# 第一次遇到该节点,添加一个回溯标记
stack.append((node, 1))
# 压入所有子节点(非父节点)
for nxt in tree[node]:
if nxt != parent[node]:
parent[nxt] = node
stack.append((nxt, 0))
else:
# 回溯阶段:计算 node 的叶子数
# 若是非根的叶子节点(度=1)
if len(tree[node]) == 1 and node != 1:
leaves[node] = 1
else:
# 否则叶子数为所有子节点的叶子数之和
total = 0
for nxt in tree[node]:
if nxt != parent[node]:
total += leaves[nxt]
leaves[node] = total
# ---------- 处理查询 ----------
q = int(input())
out = []
for __ in range(q):
x, y = map(int, input().split())
out.append(str(leaves[x] * leaves[y]))
write("\n".join(out) + "\n")
if __name__ == "__main__":
solve()📌 代码结构亮点说明
- 树使用 list 而不是 dict
性能差距 3~5 倍。实测20000000项差不多。
pythonimport time N = 20000000 tree_dict = {} tree_list = [[] for _ in range(N+1)] # 填充 for i in range(1, N): tree_dict.setdefault(i, []).append(i+1) tree_list[i].append(i+1) # 测试 dict 访问 t1 = time.time() s = 0 for i in range(1, N): for x in tree_dict[i]: s += x print("dict:", time.time() - t1) # 测试 list 访问 t2 = time.time() s = 0 for i in range(1, N): for x in tree_list[i]: s += x print("list:", time.time() - t2) """ dict: 1.7978429794311523 list: 1.6697399616241455 """
- 两段式迭代 DFS(模拟递归)
构建:
入栈 (node, 0)
...
回溯 (node,1)可以确保:
计算叶子数时,所有 child 都已计算完成无递归、无爆栈。
- 关键优化:减少全局变量访问
Python 访问 local 变量比 global 快非常多。
tree[node]
parent[node]
leaves[node]全部绑定到局部作用域,提高速度。
- 批量输出
sys.stdout.write 快于 print 一次性拼好字符串更快。
1765 ms AC。蒋子轩23工学院 清晰明了的程序,dfs with thread. 在 Mac Studio (Chip: Apple M1 Ultra, macOS: Ventura 13.6.1) 上运行,line 4, in threading.stack_size(2*10**8), ValueError: size not valid: 200000000 bytes。需要是4096的倍数,可以改为 threading.stack_size(2*10240*10240)
import sys
import threading
sys.setrecursionlimit(1 << 30)
threading.stack_size(2*10240*10240) #threading.stack_size(2*10**8)
def main():
def build_tree(edges):
tree = {}
for edge in edges:
u, v = edge
tree.setdefault(u, []).append(v)
tree.setdefault(v, []).append(u)
return tree
def count_leaves(tree, vertex, parent, leaves_count):
child_count = 0
for child in tree[vertex]:
if child != parent:
child_count += count_leaves(tree, child, vertex, leaves_count)
#if len(tree[vertex]) == 1 and vertex != parent: # 当前节点是叶子节点
if len(tree[vertex]) == 1 and vertex != 1:
leaves_count[vertex] = 1
return 1
leaves_count[vertex] = child_count # 当前节点的叶子节点数等于其子节点的叶子节点数之和
return leaves_count[vertex]
def process_assumptions(tree, leaves_count, assumptions):
for x, y in assumptions:
result = leaves_count[x] * leaves_count[y]
print(result)
t = int(input())
for _ in range(t):
n = int(input())
edges = []
for _ in range(n - 1):
edges.append(tuple(map(int, input().split())))
tree = build_tree(edges)
leaves_count = {node: 0 for node in range(1, n + 1)}
count_leaves(tree, 1, 0, leaves_count) # 从根节点开始遍历计算叶子节点数量
#print(tree, leaves_count)
q = int(input())
assumptions = []
for _ in range(q):
assumptions.append(tuple(map(int, input().split())))
process_assumptions(tree, leaves_count, assumptions)
thread = threading.Thread(target=main)
thread.start()
thread.join()import threading
help(threading.stack_size) Help on built-in function stack_size in module _thread:
stack_size(...) stack_size([size]) -> size
Return the thread stack size used when creating new threads. The optional size argument specifies the stack size (in bytes) to be used for subsequently created threads, and must be 0 (use platform or configured default) or a positive integer value of at least 32,768 (32k). If changing the thread stack size is unsupported, a ThreadError exception is raised. If the specified size is invalid, a ValueError exception is raised, and the stack size is unmodified. 32k bytes currently the minimum supported stack size value to guarantee sufficient stack space for the interpreter itself.
Note that some platforms may have particular restrictions on values for the stack size, such as requiring a minimum stack size larger than 32 KiB or requiring allocation in multiples of the system memory page size
platform documentation should be referred to for more information (4 KiB pages are common; using multiples of 4096 for the stack size is the suggested approach in the absence of more specific information).
Thread stack size 在mac上 需要是4096的倍数,可以改为 threading.stack_size(2*10240*10240)
2*10240*10240 / 4096 Out[161]: 51200.0
2*10**8 / 4096 Out[162]: 48828.125
8372224 / 4096 Out[163]: 2044.0
1421 ms AC。
叶子数量可以通过 DFS 递归 计算,每个节点的叶子数等于其子节点叶子数的和。直接查询 两节点的叶子数,然后相乘,避免集合操作的额外开销。使用 sys.stdin.read() 一次性读取输入,提高大数据量的处理效率。
import sys
from collections import defaultdict
def build_tree(edges, n):
""" 构建树的邻接表 """
tree = defaultdict(list)
for u, v in edges:
tree[u].append(v)
tree[v].append(u)
return tree
def count_leaves(tree, n):
""" 计算每个节点的叶子数 (使用迭代 DFS) """
leaves_count = {i: 0 for i in range(1, n + 1)}
parent = {1: -1}
stack = [(1, 0)] # (当前节点, 状态 0-首次访问 1-回溯)
order = [] # 记录 DFS 访问顺序
while stack:
node, state = stack.pop()
if state == 0: # 首次访问
stack.append((node, 1))
order.append(node)
for child in tree[node]:
if child == parent.get(node):
continue
parent[child] = node
stack.append((child, 0))
# 反向遍历 order 计算叶子数。确保每个节点在其所有子节点之后被处理
for node in reversed(order):
if len(tree[node]) == 1 and node != 1: # 叶子节点(根节点除外)
leaves_count[node] = 1
else:
leaves_count[node] = sum(leaves_count[child] for child in tree[node] if child != parent[node])
return leaves_count
def process_queries(leaves_count, queries):
""" 处理查询,计算答案 """
results = []
for x, y in queries:
results.append(str(leaves_count[x] * leaves_count[y]))
sys.stdout.write("\n".join(results) + "\n")
def solve():
input = sys.stdin.read
data = input().split()
index = 0
t = int(data[index])
index += 1
results = []
for _ in range(t):
n = int(data[index])
index += 1
edges = []
for _ in range(n - 1):
u, v = int(data[index]), int(data[index + 1])
index += 2
edges.append((u, v))
tree = build_tree(edges, n)
leaves_count = count_leaves(tree, n)
q = int(data[index])
index += 1
queries = []
for _ in range(q):
x, y = int(data[index]), int(data[index + 1])
index += 2
queries.append((x, y))
process_queries(leaves_count, queries)
if __name__ == "__main__":
#sys.setrecursionlimit(300000) # 提高递归深度
solve()
- 用迭代 DFS 代替递归,避免 递归栈溢出。
- 显式存储
parent以跟踪父节点,确保leaves_count计算正确。- 优化查询计算,直接访问
leaves_count。时间复杂度分析
- 构建树:
O(n)- 迭代 DFS 计算叶子数:
O(n)- 查询处理:
O(1) * q- 总复杂度:
O(n + q) ≈ 2 × 10⁵,符合题目要求。