580C. Kefa and Park

dfs and similar/graphs/trees, 1500, https://codeforces.com/contest/580/problem/C

Kefa decided to celebrate his first big salary by going to the restaurant.

He lives by an unusual park. The park is a rooted tree consisting of n vertices with the root at vertex 1. Vertex 1 also contains Kefa's house. Unfortunately for our hero, the park also contains cats. Kefa has already found out which vertices have cats in them.

The leaf vertices of the park contain restaurants. Kefa wants to choose a restaurant where he will go, but unfortunately he is very afraid of cats, so there is no way he will go to the restaurant if the path from the restaurant to his house contains more than m consecutive vertices with cats.

Your task is to help Kefa count the number of restaurants where he can go.

Input

The first line contains two integers, n and m (2 ≤ n ≤ 10^5^, 1 ≤ m ≤ n) — the number of vertices of the tree and the maximum number of consecutive vertices with cats that is still ok for Kefa.

The second line contains n integers a~1~, a~2~, ..., a~n~, where each a~i~ either equals to 0 (then vertex i has no cat), or equals to 1 (then vertex i has a cat).

The next n - 1 lines contain the edges of the tree in the format "x~i~ y~i~" (without the quotes) (1 ≤ x~i~, y~i~ ≤ n, x~i~ ≠ y~i~), where x~i~ and y~i~ are the vertices of the tree, connected by an edge.

It is guaranteed that the given set of edges specifies a tree.

Output

A single integer — the number of distinct leaves of a tree the path to which from Kefa's home contains at most m consecutive vertices with cats.

Examples

input

4 1
1 1 0 0
1 2
1 3
1 4

output

2

input

7 1
1 0 1 1 0 0 0
1 2
1 3
2 4
2 5
3 6
3 7

output

2

Note

Let us remind you that a tree is a connected graph on n vertices and n - 1 edges. A rooted tree is a tree with a special vertex called the root. In a rooted tree, among any two vertices connected by an edge, one vertex is a parent (the one closer to the root), and the other one is a child. A vertex is called a leaf if it has no children.

Note to the first sample test:

img

The vertices containing cats are marked red. The restaurants are at vertices 2, 3, 4. Kefa can't go only to the restaurant located at vertex 2.

Note to the second sample test:

img

The restaurants are located at vertices 4, 5, 6, 7. Kefa can't go to restaurants 6, 7.
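
To make the "more than m consecutive vertices with cats" condition concrete, the sketch below computes the longest run of consecutive cat vertices on each root-to-leaf path of the second sample. The helper name `max_cat_run` and the hand-written `paths` table are illustrative assumptions, not part of the original statement.

```python
def max_cat_run(path, cats):
    """Longest run of consecutive cat vertices along `path` (1-based vertex ids)."""
    best = run = 0
    for v in path:
        run = run + 1 if cats[v - 1] else 0  # a cat-free vertex resets the run
        best = max(best, run)
    return best

# Second sample: m = 1, cats on vertices 1, 3 and 4.
m = 1
cats = [1, 0, 1, 1, 0, 0, 0]
# Root-to-leaf paths written out by hand (assumption: read off the sample edges).
paths = {4: [1, 2, 4], 5: [1, 2, 5], 6: [1, 3, 6], 7: [1, 3, 7]}

runs = {leaf: max_cat_run(path, cats) for leaf, path in paths.items()}
reachable = [leaf for leaf, run in runs.items() if run <= m]
print(runs)
print(reachable)
```

With m = 1, only the paths to leaves 4 and 5 stay within the limit, which agrees with the note above.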

Iterative DFS with an explicit stack: the DFS is simulated with a stack instead of recursive calls, which avoids Python's recursion depth limit. 359 ms, AC.

python
import sys
from collections import defaultdict

def count_restaurants(n, m, cats, edges):
    # Build the adjacency list of the tree
    tree = defaultdict(list)
    for x, y in edges:
        tree[x].append(y)
        tree[y].append(x)

    # Iterative DFS (explicit stack)
    stack = [(1, 0)]  # (current vertex, consecutive cats so far)
    visited = [False] * (n + 1)
    visited[1] = True
    leaf_count = 0

    while stack:
        node, consecutive_cats = stack.pop()
        
        if cats[node - 1]:
            consecutive_cats += 1
        else:
            consecutive_cats = 0

        if consecutive_cats > m:
            continue

        is_leaf = True
        for neighbor in tree[node]:
            if not visited[neighbor]:
                visited[neighbor] = True
                stack.append((neighbor, consecutive_cats))
                is_leaf = False
        
        if is_leaf:  # this vertex is a leaf
            leaf_count += 1

    return leaf_count

# Read the input
def main():
    n, m = map(int, sys.stdin.readline().split())
    cats = list(map(int, sys.stdin.readline().split()))
    edges = [tuple(map(int, sys.stdin.readline().split())) for _ in range(n - 1)]
    
    print(count_restaurants(n, m, cats, edges))

if __name__ == "__main__":
    main()

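Recursive DFS: runtime error on test 35. Raising the limit with sys.setrecursionlimit does not enlarge the interpreter's underlying C stack, so very deep recursion on a deep tree can presumably still crash.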

python
import sys

sys.setrecursionlimit(10 ** 6)


def count_restaurants(n, m, cats, edges):
    from collections import defaultdict

    # Build the adjacency list of the tree
    tree = defaultdict(list)
    for x, y in edges:
        tree[x].append(y)
        tree[y].append(x)

    # Track the vertices that have been visited
    visited = [False] * (n + 1)

    def dfs(node, consecutive_cats):
        if cats[node - 1]:
            consecutive_cats += 1
        else:
            consecutive_cats = 0

        if consecutive_cats > m:
            return 0

        visited[node] = True
        is_leaf = True
        count = 0

        for neighbor in tree[node]:
            if not visited[neighbor]:
                is_leaf = False
                count += dfs(neighbor, consecutive_cats)

        return count if not is_leaf else 1

    return dfs(1, 0)


# Read the input
def main():
    n, m = map(int, input().split())
    cats = list(map(int, input().split()))
    edges = [tuple(map(int, input().split())) for _ in range(n - 1)]

    print(count_restaurants(n, m, cats, edges))


if __name__ == "__main__":
    main()

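BFS (breadth-first search), iterating with a deque as the queue instead of relying on recursion. This avoids the recursion depth limit while still traversing large inputs efficiently. 374 ms, AC.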

python
import sys
from collections import defaultdict, deque

def count_restaurants(n, m, cats, edges):
    # Build the adjacency list of the tree
    tree = defaultdict(list)
    for x, y in edges:
        tree[x].append(y)
        tree[y].append(x)

    # Iterative BFS (using a queue)
    queue = deque([(1, 0)])  # (current vertex, consecutive cats so far)
    visited = [False] * (n + 1)
    visited[1] = True
    leaf_count = 0

    while queue:
        node, consecutive_cats = queue.popleft()
        
        if cats[node - 1]:
            consecutive_cats += 1
        else:
            consecutive_cats = 0

        if consecutive_cats > m:
            continue

        is_leaf = True
        for neighbor in tree[node]:
            if not visited[neighbor]:
                visited[neighbor] = True
                queue.append((neighbor, consecutive_cats))
                is_leaf = False
        
        if is_leaf:  # this vertex is a leaf
            leaf_count += 1

    return leaf_count

# Read the input
def main():
    n, m = map(int, sys.stdin.readline().split())
    cats = list(map(int, sys.stdin.readline().split()))
    edges = [tuple(map(int, sys.stdin.readline().split())) for _ in range(n - 1)]
    
    print(count_restaurants(n, m, cats, edges))

if __name__ == "__main__":
    main()

2020fall-cs101, 林逸云 (Yiyun LIN).

python
# https://codeforces.com/contest/580/problem/C
# 2020fall-cs101, Yiyun LIN
cat = dict()
graph = dict()
visited = set()
queue = [1]
res = 0
n,m = map(int, input().split())
a = list(map(int, input().split()))
for _ in range(n-1):
    x,y = map(int, input().split())
    if x not in graph.keys():
        graph[x] = []
    if y not in graph.keys():
        graph[y] = []
    graph[x].append(y)
    graph[y].append(x)
    
cat[1] = a[0]
while len(queue)>0:
    x = queue.pop(0)
    visited.add(x)
    if cat[x]>m:
        continue
    b=0
    for k in graph[x]:
        if k not in visited:
            if a[k-1]==1:
                cat[k] = cat[x] + 1
            else:
                cat[k] = 0
            queue.append(k)
            b=1
    if b==0:
        res+=1
print(res)

The standard approach is a graph traversal. For BFS, see: https://www.codespeedy.com/breadth-first-search-algorithm-in-python/

python
# https://codeforces.com/contest/580/problem/C
n,m = [int(i) for i in input().split()]
cat = [0]+[int(i) for i in input().split()]
d = {}
t = 1
for i in range(n-1):
    x,y = [int(_) for _ in input().split()]
    # d.setdefault(x,[]).append(y)
    try:
        d[x].append(y)
    except:
        d[x] = [y]
    # d.setdefault(y,[]).append(x)    
    try:
        d[y].append(x)
    except:
        d[y] = [x]

rec = [(1,0,1)]
cnt = 0
while len(rec) != 0:
    i,c,prev = rec.pop()
    if cat[i]:
        c += 1
    else:
        c = 0
    if c > m:
        continue
    if i != 1 and len(d[i]) == 1:
        cnt += 1
        continue
    for j in d[i]:
        if j == prev:
            continue
        rec.append((j,c,i))
print(cnt)

This code is a DFS that got AC on 580C. It uses yield and a decorator to turn call-stack recursion into recursion driven by an explicit stack. https://codeforces.com/contest/580/submission/55059869

Adapted from: "C. Kefa and Park - Runtime Error in Python3", https://codeforces.com/blog/entry/67372. The author writes: "Abuse yield in python to create stackless recursion (https://codeforces.com/contest/580/submission/55059869). I made this thing a while back in order to do deep recursion in python and it has been working pretty nicely."

For background on Python function decorators, see https://www.runoob.com/w3cnote/python-func-decorators.html; the "2 篇笔记" (2 notes) section there is the easiest part to follow.

python
# Not my code
# testing https://codeforces.com/contest/580/submission/55056991 code
# with my bootstrapped recursion
 
# My magical way of doing recursion in python. This
# isn't the fastest, but at least it works.
from types import GeneratorType
def bootstrap(func, stack=[]):
    def wrapped_function(*args, **kwargs):
        if stack:
            return func(*args, **kwargs)
        else:
            call = func(*args, **kwargs)
            while True:
                if type(call) is GeneratorType:
                    stack.append(call)
                    call = next(call)
                else:
                    stack.pop()
                    if not stack:
                        break
                    call = stack[-1].send(call)
            return call
 
    return wrapped_function
 
@bootstrap
def go_child(node, path_cat, parent):
    # count the consecutive cats ending at this node
    m_cat = 0
    if cat[node] == 0:
        m_cat = 0
    else:
        m_cat = path_cat + cat[node]
    # too many cats
    if m_cat > m:
        yield 0
    isLeaf = True
    sums = 0
    # traverse edges belongs to node
    for j in e[node]:
        # ignore parent edge
        if j == parent:
            continue
        # node has child
        isLeaf = False
        # dfs from child of node
        sums += yield go_child(j, m_cat, node)
    # achievable leaf
    if isLeaf:
        yield 1
    # return achievable leaves count
    yield sums
 
 
n, m = map(int, input().split())
cat = list(map(int, input().split()))
e = [[] for i in range(n)]
 
for i in range(n - 1):
    v1, v2 = map(int, input().split())
    # store undirected edge
    e[v1 - 1].append(v2 - 1)
    e[v2 - 1].append(v1 - 1)
# dfs from root
print(go_child(0, 0, -1))
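
The core idea of the decorator above can be shown in a smaller form: each "recursive call" is a generator that yields its child call, and a driver loop keeps the pending generators in a Python list instead of on the interpreter's call stack. The names `trampoline` and `depth` are hypothetical, and this is a simplified sketch of the same idea (results are passed back with `return` rather than `yield`), not the code from the linked submission.

```python
def trampoline(gen):
    """Run nested generators using an explicit list as the call stack."""
    stack = [gen]
    value = None  # value sent into the generator on top of the stack
    while stack:
        try:
            child = stack[-1].send(value)
        except StopIteration as stop:
            stack.pop()          # this "call" has returned
            value = stop.value   # hand its return value to the caller
        else:
            stack.append(child)  # a "recursive call": descend into the child
            value = None         # a fresh generator must be started with None
    return value


def depth(node, children):
    """Height of the subtree at `node`, written as a generator-style recursion."""
    best = 0
    for child in children[node]:
        d = yield depth(child, children)  # "call" depth(child) and get its result
        best = max(best, d)
    return best + 1


# A chain of 100000 vertices: far deeper than the default recursion limit.
chain = {i: [i + 1] for i in range(1, 100_000)}
chain[100_000] = []
result = trampoline(depth(1, chain))
print(result)
```

Because only the list `stack` grows, the depth is bounded by available memory rather than by the recursion limit.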

2021fall-cs101, 李博熙 (LI Boxi). This problem requires a hand-built stack/queue for the DFS/BFS.

python
# 2021fall-cs101,LI Boxi, manually construct stack
# https://codeforces.com/contest/580/problem/C
n,m=map(int,input().split())
cat=[-1]+list(map(int,input().split()))
dict0={}
for i in range(1,n+1):
    dict0[i]=[]
for i in range(n-1):
    x,y=map(int,input().split())
    dict0[x].append(y)
    dict0[y].append(x)
ans=set()
def dfs():
    stack=[]
    stack.append([1,0,[True]*(n+1)])#(position,cat0,status)
    while len(stack)>0:
        a=stack.pop()
        position=a[0]
        cat0=a[1]
        prev_status=a[2]
        prev_status[position]=False
        if cat[position]==1:
            cat0+=1
        else:
            cat0=0
        if cat0>m:
            continue
        if len(dict0[position])==1 and position!=1:
            ans.add(position)
            continue
        for w in dict0[position]:
            if prev_status[w]==True:
                stack.append([w,cat0,prev_status])
                
dfs()
print(len(ans))

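2021fall-cs101, 和沛淼.

I tried both DFS and BFS on this problem; both pass, and the code is not hard to write. The hard part is keeping the running time under control. At first I kept exceeding the time limit no matter what I tried; once I changed the container of visited vertices (visited) from a list to a set, it passed immediately. (In an exam this would be a nasty trap.)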
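
The list-versus-set remark comes down to the cost of the `in` test: a list is scanned element by element, while a set uses hashing. The snippet below is a rough sketch of that difference; the size `n` and the repeat count are arbitrary assumptions, and the measured times depend on the machine.

```python
import timeit

n = 20_000
visited_list = list(range(n))
visited_set = set(visited_list)
probe = n - 1  # last element: the worst case for a linear scan of the list

# Time 200 membership tests against each container.
t_list = timeit.timeit(lambda: probe in visited_list, number=200)
t_set = timeit.timeit(lambda: probe in visited_set, number=200)

print(f"list: {t_list:.6f} s, set: {t_set:.6f} s")
```

With up to 10^5 vertices, one list membership test per neighbor adds up to quadratic work in the worst case, which is a plausible reason for the time-limit failures described above.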

python
# https://codeforces.com/contest/580/problem/C 
n,m=map(int,input().split())
L=list(map(int,input().split()))
s=dict()

for i in range(n):
    s.update({i+1:[]})
    
for _ in range(n-1):
    a,b=map(int,input().split())
    s[a].append(b)
    s[b].append(a)
    
ans=0
def dfs(q):
    visited=set()
    global ans
    while q!=[]:
        g=q.pop()
        now=g[0]
        t=g[1]
        if now not in visited:
            visited.add(now)
            if L[now-1]==1:
                t-=1
            else:
                t=m
            if t>=0:
                if now!=1 and len(s[now])==1:
                    ans+=1
                else:
                    for i in s[now]:
                        q.append([i,t])
    return
dfs([[1,m]])
print(ans)

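2021fall-cs101, 刘宇堃. At first I did not notice that the input is guaranteed to be a tree and wrote the solution for a general graph, where even a visited vertex may have to be walked again (if a new route reaches it with fewer cats).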

python
# https://codeforces.com/contest/580/problem/C
n, m = map(int, input().split())
*l0, = map(int, input().split())
l1 = [0 for i in range(n)]
ans = 0
graph = {}
for _ in range(n-1):
    a, b = map(int, input().split())
    if a not in graph:
        graph[a] = [b]
    else:
        graph[a].append(b)
    if b not in graph:
        graph[b] = [a]
    else:
        graph[b].append(a)
queue = [1]
l1[0] = l0[0]
vis = {1}
while queue:
    vertex = queue.pop(0)
    if len(graph[vertex]) > 1 or vertex == 1:
        for x in graph[vertex]:
            a = [0, l1[vertex-1]+1][l0[x-1]]
            if x not in vis and a <= m:
                queue.append(x)
                vis.add(x)
                l1[x-1] = a
    else:
        ans += 1
print(ans)

2021fall-cs101, 潘逸阳.

python
n, m = map(int, input().split())
cat = [int(x) for x in input().split()]
cats = [cat[0]] + [0] * (n - 1)
link = [[] for _ in range(n)]
visit = [False] * n
for j in range(n - 1):
    a, b = (int(x) - 1 for x in input().split())
    link[b].append(a)
    link[a].append(b)
out = 0
queue = [0]
while queue:
    i = queue.pop(0)
    visit[i] = True
    if cats[i] > m: continue
    if i and len(link[i]) == 1:
        out += 1
        continue
    for j in link[i]:
        if visit[j]: continue
        if cat[j]: cats[j] = cats[i] + 1
        queue.append(j)
print(out)

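Breadth-first search. In fact, changing pop(0) on line 13 to pop() turns it into depth-first search. I only worked this out after a long time of rewriting, once my recursive DFS had overflowed the stack on test 35. This is the DFS before the rewrite: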

python
def dfs(x, cats, no):
    global out
    if cats > m:
        return
    if x and len(link[x]) == 1:
        out += 1
        return
    for y in link[x] - no:
        if cat[y] == 0: dfs(y, 0, {x})
        else: dfs(y, cats + 1, {x})
n, m = map(int, input().split())
cat = [int(x) for x in input().split()]
link = [set() for _ in range(n)]
for j in range(n - 1):
    a, b = (int(x) - 1 for x in input().split())
    link[b].add(a)
    link[a].add(b)
out = 0
dfs(0, cat[0], set())
print(out)

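Then, after hearing the teacher say that we need to build the stack ourselves, I started rewriting, and it gradually turned into this: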

python
n, m = map(int, input().split())
cat = [int(x) for x in input().split()]
cats = [cat[0]] + [0] * (n - 1)
link = [set() for _ in range(n)]
visit = set()
for j in range(n - 1):
    a, b = (int(x) - 1 for x in input().split())
    link[b].add(a)
    link[a].add(b)
out = 0
stack = [0]
while stack:
    i = stack.pop()
    visit.add(i)
    if cats[i] > m: continue
    if i and len(link[i]) == 1:
        out += 1
        continue
    for j in link[i] - visit:
        stack.append(j)
        if cat[j]: cats[j] = cats[i] + 1
print(out)

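It really is exactly the same. Only after writing the hand-built stack do I feel I truly understand what backtracking means, and how BFS and DFS relate to each other. Also, the DFS runs faster than the BFS here, and I do not know why.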

https://stackoverflow.com/questions/47222855/in-what-sense-is-dfs-faster-than-bfs

Memory requirements: The stack size is bound by the depth whereas the queue size is bound by the width. For a balanced binary tree with n nodes, that means the stack size would be log(n) but the queue size would be O(n). Note that an explicit queue might not be needed for a BFS in all cases -- for instance, in an array embedded binary tree, it should be possible to compute the next index instead.

Speed: I don't think that's true. For a full search, both cases visit all the nodes without significant extra overhead. If the search can be aborted when a matching element is found, BFS should typically be faster if the searched element is typically higher up in the search tree because it goes level by level. DFS might be faster if the searched element is typically relatively deep and finding one of many is sufficient.
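
One practical difference in the list-based versions above is that `list.pop(0)` shifts every remaining element and so costs time proportional to the list length, while `list.pop()` and `deque.popleft()` take constant time; this may account for part of the observed gap. The sketch below uses one `collections.deque` for both orders, so the only difference between DFS and BFS is which end is popped. The function name `count_restaurants` follows the earlier solutions, but the `order` parameter is an assumption added for illustration.

```python
from collections import deque


def count_restaurants(n, m, cats, edges, order="dfs"):
    """Count reachable leaves; `order` selects which end of the deque is popped."""
    adj = [[] for _ in range(n + 1)]
    for x, y in edges:
        adj[x].append(y)
        adj[y].append(x)

    # Each entry: (vertex, parent, consecutive cats before this vertex)
    frontier = deque([(1, 0, 0)])
    take = frontier.pop if order == "dfs" else frontier.popleft  # both O(1)
    reachable = 0
    while frontier:
        v, parent, run = take()
        run = run + 1 if cats[v - 1] else 0
        if run > m:
            continue  # everything below v is cut off
        children = [w for w in adj[v] if w != parent]
        if not children:
            reachable += 1  # a leaf reached within the limit
        for w in children:
            frontier.append((w, v, run))
    return reachable


sample_edges = [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]
sample_cats = [1, 0, 1, 1, 0, 0, 0]
print(count_restaurants(7, 1, sample_cats, sample_edges, "dfs"))
print(count_restaurants(7, 1, sample_cats, sample_edges, "bfs"))
```

Both orders visit the same vertices on a tree, so they return the same count; tracking the parent replaces the visited container entirely.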

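2021fall-cs101, 欧阳韵妍.

Approach: as described in the comments. Key points: (1) the input contains edges given in the reverse direction, so every edge must be stored in both directions; (2) the BFS uses a while loop and needs a queue set up in advance. The queue holds the information of one level (the breadth), and the next level is processed only after the current one is finished. Take this picture, for example: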

img

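The 3 queues are, in order: [1], [2,3], [4,5,6,7]. The vertices of the next level are found by collecting all children of the parent vertices in the queue. Also keep a visited set to check whether a vertex has already been walked (for example, once 2-1 has been handled, 1-2 must not be handled again); this visited set is also the key to detecting that an end point has been reached. visited must be a set(), not a list(), otherwise the solution times out! (I remember an assignment long ago that also timed out with list() and got AC after switching to set().)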
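
The level-by-level queues described above can be reproduced with a short sketch. It assumes the picture shows the tree of the second sample; the function name `bfs_levels` is hypothetical.

```python
def bfs_levels(n, edges, root=1):
    """Return the vertices of a tree grouped by BFS level, starting at `root`."""
    adj = [[] for _ in range(n + 1)]
    for x, y in edges:
        adj[x].append(y)  # store every edge in both directions
        adj[y].append(x)

    visited = {root}      # a set keeps the membership test cheap
    level = [root]
    levels = []
    while level:
        levels.append(level)
        next_level = []
        for v in level:
            for w in adj[v]:
                if w not in visited:  # skips the edge back to the parent
                    visited.add(w)
                    next_level.append(w)
        level = next_level
    return levels


# Assumption: the tree in the picture is the second sample of the problem.
edges = [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]
levels = bfs_levels(7, edges)
print(levels)
```

For this tree the three levels are [1], [2, 3] and [4, 5, 6, 7], matching the queues listed above.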

python
# https://codeforces.com/contest/580/problem/C
# Annotated version. Note: comments must not contain Chinese when submitting to CF.
n,m = map(int,input().split())
cat1 = list(map(int,input().split()))
tree = {}
cat2 = [0]*(n+1) # number of consecutive cats ending at each vertex
cat2[1]=cat1[0] # record first whether vertex 1 has a cat; later counts build on it
for i in range(n-1):
    x,y = map(int,input().split())
    if x in tree.keys():
        tree[x].append(y)
    else:
        tree[x]=[y]
    if y in tree.keys(): # edges may be given in reverse (e.g. the only edge in test 8 is "2 1")
        tree[y].append(x)
    else:
        tree[y]=[x] 
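queue = [1] # start from vertex 1 and look at its children
restaurant = 0
visited = set() # a list here would time out
while queue!=[]:
    a = queue.pop() # a is the parent vertex (pop() takes the last element, so the order is in fact depth-first)
    visited.add(a)
    if cat2[a]>m: # too many cats: drop every branch below this parent
        continue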
        continue
    flag = 1
    for i in tree[a]: # i is a child vertex
        if i not in visited: # this vertex has not been walked yet
            flag = 0
            if cat1[i-1]==1:
                cat2[i]=cat2[a]+1
            queue.append(i)
    if flag: # at an end vertex every neighbor has already been visited, so flag was never set to 0
        restaurant+=1
print(restaurant)