2920.收集所有金币可获得的最大积分
tree dp, https://leetcode.cn/problems/maximum-points-after-collecting-coins-from-all-nodes/
有一棵由 n 个节点组成的无向树,以 0 为根节点,节点编号从 0 到 n - 1 。给你一个长度为 n - 1 的二维 整数 数组 edges ,其中 edges[i] = [ai, bi] 表示在树上的节点 ai 和 bi 之间存在一条边。另给你一个下标从 0 开始、长度为 n 的数组 coins 和一个整数 k ,其中 coins[i] 表示节点 i 处的金币数量。
从根节点开始,你必须收集所有金币。要想收集节点上的金币,必须先收集该节点的祖先节点上的金币。
节点 i 上的金币可以用下述方法之一进行收集:
- 收集所有金币,得到共计
coins[i] - k点积分。如果coins[i] - k是负数,你将会失去abs(coins[i] - k)点积分。 - 收集所有金币,得到共计
floor(coins[i] / 2)点积分。如果采用这种方法,节点i子树中所有节点j的金币数coins[j]将会减少至floor(coins[j] / 2)。
返回收集 所有 树节点的金币之后可以获得的最大积分。
示例 1:

输入:edges = [[0,1],[1,2],[2,3]], coins = [10,10,3,3], k = 5
输出:11
解释:
使用第一种方法收集节点 0 上的所有金币。总积分 = 10 - 5 = 5 。
使用第一种方法收集节点 1 上的所有金币。总积分 = 5 + (10 - 5) = 10 。
使用第二种方法收集节点 2 上的所有金币。所以节点 3 上的金币将会变为 floor(3 / 2) = 1 ,总积分 = 10 + floor(3 / 2) = 11 。
使用第二种方法收集节点 3 上的所有金币。总积分 = 11 + floor(1 / 2) = 11.
可以证明收集所有节点上的金币能获得的最大积分是 11 。示例 2:

输入:edges = [[0,1],[0,2]], coins = [8,4,4], k = 0
输出:16
解释:
使用第一种方法收集所有节点上的金币,因此,总积分 = (8 - 0) + (4 - 0) + (4 - 0) = 16 。提示:
n == coins.length2 <= n <= 1050 <= coins[i] <= 104edges.length == n - 10 <= edges[i][0], edges[i][1] < n0 <= k <= 104
树形dp,还挺有意思。defaultdict特别有用,defaultdict(list)就是邻接表。另外,类似lru_cache的装饰器可以自己写,例如:
python
def memoize(f):
memo = {}
def helper(x, y, z):
if (x, y, z) not in memo:
memo[x, y, z] = f(x, y, z)
return memo[x, y, z]
return helper
@memoize
def dfs(node, parent, halved_times):python
class Solution:
def maximumPoints(self, edges: List[List[int]], coins: List[int], k: int) -> int:
# 构建邻接表表示的图
graph = defaultdict(list)
for u, v in edges:
graph[u].append(v)
graph[v].append(u)
# 记忆化装饰器
def memoize(f):
memo = {}
def helper(x, y, z):
if (x, y, z) not in memo:
memo[x, y, z] = f(x, y, z)
return memo[x, y, z]
return helper
@memoize
def dfs(node, parent, halved_times):
# 如果该子树已经经过多次减半,直接返回0以避免过小的数值影响结果
if halved_times >= 14: # log2(10^4) < 14, 因为coins[i] <= 10^4
return 0
# 减半操作应用于当前节点
current_coins = coins[node] // (2 ** halved_times)
# 两种选择:不减半和减半
option1 = current_coins - k
option2 = current_coins // 2
for child in graph[node]:
if child != parent: # 避免回溯到父节点
option1 += dfs(child, node, halved_times)
option2 += dfs(child, node, halved_times + 1)
return max(option1, option2)
# 从根节点开始DFS遍历
return dfs(0, None, 0)python
class Solution:
def maximumPoints(self, edges: List[List[int]], coins: List[int], k: int) -> int:
# 构建邻接表表示的图
graph = defaultdict(list)
for u, v in edges:
graph[u].append(v)
graph[v].append(u)
n = len(coins)
# dp数组,dp[node][halved_times]表示从node开始,经过halved_times次减半后的最大积分
dp = [[-1] * 14 for _ in range(n)]
def dfs(node, parent, halved_times):
if halved_times >= 14: # 防止过多减半导致数值过小
return 0
if dp[node][halved_times] != -1:
return dp[node][halved_times]
current_coins = floor(coins[node] / (2 ** halved_times))
# 不减半当前节点金币的选择
option1 = current_coins - k
# 减半当前节点金币的选择
option2 = floor(current_coins / 2)
for child in graph[node]:
if child == parent:
continue
option1 += dfs(child, node, halved_times)
option2 += dfs(child, node, halved_times + 1)
# 记录两种选择中的最大值
dp[node][halved_times] = max(option1, option2)
return dp[node][halved_times]
# 从根节点0开始DFS遍历
return dfs(0, None, 0)