Skip to content

24686: 树的重量

http://cs101.openjudge.cn/practice/24686/

有一棵 k 层的满二叉树(一共有2k-1个节点,且从上到下从左到右依次编号为1, 2, ..., 2k-1),最开始每个节点的重量均为0。请编程实现如下两种操作:

1 x y:给以 x 为根的子树的每个节点的重量分别增加 y( y 是整数且绝对值不超过100)

2 x:查询(此时的)以 x 为根的子树的所有节点重量之和

输入

输入有n+1行。第一行是两个整数k, n,分别表示满二叉树的层数和操作的个数。接下来n行,每行形如1 x y或2 x,表示一个操作。

k<=15(即最多32767个节点),n<=50000。

输出

输出有若干行,对每个查询操作依次输出结果,每个结果占一行。

样例输入

3 7
1 2 1
2 4
1 6 3
2 1
1 3 -2
1 4 1
2 3

样例输出

1
6
-3

提示

可以通过对数计算某节点的深度:

import math

math.log2(x) #以小数形式返回x的对数值,注意x不能为0

满二叉树是一种特殊的二叉树,其中每个节点要么是叶子节点,要么有两个子节点。

变量k和n分别代表满二叉树的层数和操作的个数。f和g是两个列表,用于存储每个节点的权重和懒惰标记。dep列表用于存储每个节点的深度。

如果操作的长度为2,那么这是一个查询操作,需要计算以给定节点为根的子树的所有节点的权重之和。如果操作的长度为3,那么这是一个更新操作,需要更新以给定节点为根的子树的所有节点的权重。

初始化了三个列表 f, g, 和 dep 来存储关于树的信息。f 用于记录懒惰传播的值,g 可能是用于存储临时的累积更新,dep 存储每个节点的深度。tot 是树中节点的总数。

  • f could represent some aggregated value at each node (like a lazy propagation value).
  • g could represent some other value that needs to be propagated down the tree (potentially a modification that applies to all child nodes).

计算深度:从下到上计算每个节点的深度。对于满二叉树来说,如果一个节点编号为 i,则它的子节点编号为 2i2i + 1。深度是从最底层叶子节点开始反向计算的。

查询操作,首先获取根节点的权重,然后逐层向上,获取每一层父节点的权重,最后加上懒惰标记的权重。 查询操作(2 x):如果操作有两个数字,它是一个查询操作。它从根开始累积 f 中的值,沿着树向上移动直到达到节点 x。然后计算以 x 为根的子树中所有节点的重量之和,考虑到懒惰传播的值和直接更新的值,最后打印结果。

更新操作,首先更新给定节点的权重,然后逐层向上,更新每一层父节点的懒惰标记。 增加操作(1 x y):将 y 增加到 f 中对应节点 x 的值,并且将 wy 乘以以 x 为根的子树的节点总数)累积到 g 中对应节点 x 的父节点中。然后它继续沿树向上更新 g,直到根节点。

主要思想是使用懒惰标记来优化查询和更新操作的时间复杂度。

问:查询时候,是计算以 x 为根的子树的所有节点重量之和,为什么要向上找根节点,一路计算?

答:查询操作的目标是计算以x为根的子树的所有节点重量之和。这是通过向上找根节点并一路计算来实现的。这种方法的原因是,代码中的更新操作是延迟的,也就是说,当我们对一个区间进行更新操作时,并不立即更新区间中的所有元素,而是将更新的值存储在一个特定的数据结构中(在这个例子中是数组f和g)。然后,当我们进行查询操作时,我们需要检查这个区间是否有待更新的值,如果有,我们就需要在查询的过程中,一路向上找到根节点,将这些待更新的值加入到查询结果中。

具体来说,对于每一个节点u,我们都存储了一个值f[u],表示这个节点及其所有子节点需要增加的值。然后,当我们进行查询操作时,我们需要从目标节点开始,一路向上找到根节点,将这些待更新的值加入到查询结果中。这就是为什么我们在查询操作中需要一路向上找根节点的原因。

同时,我们还需要注意,由于我们的更新操作是延迟的,所以在查询操作中,我们还需要处理那些还没有被实际更新的节点。这就是为什么我们在查询操作中,除了加入f[u]之外,还需要加入g[u]。g[u]存储的是这个节点及其所有子节点由于之前的更新操作而增加的值,但是这些值还没有被实际加入到这些节点中。所以,在查询操作中,我们需要将这些值也加入到查询结果中。

python
k, n = [int(x) for x in input().split()]
f, g, dep = [], [], []
tot = (1 << k) - 1
for _ in range(tot+1):
    f.append(0)
    g.append(0)
    dep.append(0)
for i in range(tot, 0, -1):
    dep[i] = 1 if i * 2 > tot else dep[i * 2] + 1
for _ in range(n):
    a = [int(x) for x in input().split()]
    if len(a) == 2:
        u = a[1]
        s = f[1]
        while u != 1:
            s += f[u]
            u >>= 1
        ans = s * ((1 << dep[a[1]]) - 1) + g[a[1]]
        print(ans)
    elif len(a) == 3:
        u = a[1]
        w = a[2] * ((1 << dep[u]) - 1)
        f[u] += a[2]
        while u != 1:
            u >>= 1
            g[u] += w