Skip to content

T1622.奇妙序列

math, segment tree, https://leetcode.cn/problems/fancy-sequence/

请你实现三个 API appendaddAllmultAll 来实现奇妙序列。

请实现 Fancy 类 :

  • Fancy() 初始化一个空序列对象。
  • void append(val) 将整数 val 添加在序列末尾。
  • void addAll(inc) 将所有序列中的现有数值都增加 inc
  • void multAll(m) 将序列中的所有现有数值都乘以整数 m
  • int getIndex(idx) 得到下标为 idx 处的数值(下标从 0 开始),并将结果对 10^9 + 7 取余。如果下标大于等于序列的长度,请返回 -1

示例:

输入:
["Fancy", "append", "addAll", "append", "multAll", "getIndex", "addAll", "append", "multAll", "getIndex", "getIndex", "getIndex"]
[[], [2], [3], [7], [2], [0], [3], [10], [2], [0], [1], [2]]
输出:
[null, null, null, null, null, 10, null, null, null, 26, 34, 20]

解释:
Fancy fancy = new Fancy();
fancy.append(2);   // 奇妙序列:[2]
fancy.addAll(3);   // 奇妙序列:[2+3] -> [5]
fancy.append(7);   // 奇妙序列:[5, 7]
fancy.multAll(2);  // 奇妙序列:[5*2, 7*2] -> [10, 14]
fancy.getIndex(0); // 返回 10
fancy.addAll(3);   // 奇妙序列:[10+3, 14+3] -> [13, 17]
fancy.append(10);  // 奇妙序列:[13, 17, 10]
fancy.multAll(2);  // 奇妙序列:[13*2, 17*2, 10*2] -> [26, 34, 20]
fancy.getIndex(0); // 返回 26
fancy.getIndex(1); // 返回 34
fancy.getIndex(2); // 返回 20

提示:

  • 1 <= val, inc, m <= 100
  • 0 <= idx <= 10^5
  • 总共最多会有 10^5 次对 appendaddAllmultAllgetIndex 的调用。

灵茶山艾府 链接:https://leetcode.cn/problems/fancy-sequence/solutions/3917656/lan-geng-xin-deng-jie-zhuan-hua-pythonja-csvl/

MOD = 1_000_000_007

class Fancy:
    def __init__(self):
        self.vals = []
        self.add = 0
        self.mul = 1

    def append(self, val: int) -> None:
        self.vals.append((val - self.add) * pow(self.mul, -1, MOD) % MOD)

    def addAll(self, inc: int) -> None:
        self.add += inc

    def multAll(self, m: int) -> None:
        self.mul = self.mul * m % MOD
        self.add = self.add * m % MOD

    def getIndex(self, idx: int) -> int:
        if idx >= len(self.vals):
            return -1
        return (self.vals[idx] * self.mul + self.add) % MOD

这道题的核心挑战在于如何在 O(1)O(logN) 的时间内处理全序列的加法和乘法操作。如果每次 addAllmultAll 都遍历整个序列,时间复杂度会达到 O(N2),在 105 的数据规模下会超时。

解题思路:线性变换映射

我们可以维护一个全局的线性变换y=Ax+B

  1. 状态维护

    • A:当前的全局乘积因子(初始为 1)。
    • B:当前的全局增量偏移(初始为 0)。
    • MOD=109+7
  2. 操作逻辑

    • addAll(inc):所有的数都加 inc。 新的结果变为 (Ax+B)+inc=Ax+(B+inc)。 只需更新:B=(B+inc)(modMOD)
    • multAll(m):所有的数都乘 m。 新的结果变为 m(Ax+B)=(mA)x+(mB)。 只需更新:A=(Am)(modMOD)B=(Bm)(modMOD)
    • append(val):我们需要将 val 存入数组,使得之后应用当前的 (A,B) 时,能还原回 val。 即求解方程:Ax+Bval(modMOD)。 解得:x(valB)A1(modMOD)。 这里 A1A 在模 MOD 意义下的逆元。因为 MOD 是质数,可以使用费马小定理通过 pow(A, MOD - 2, MOD) 计算。
    • getIndex(idx):返回 (Anums[idx]+B)(modMOD)

代码实现

python
class Fancy:

    def __init__(self):
        self.nums = []
        self.a = 1  # 全局乘法因子
        self.b = 0  # 全局加法因子
        self.MOD = 10**9 + 7

    def append(self, val: int) -> None:
        # 我们存入一个 x,使得 (x * a + b) % MOD == val
        # x = (val - b) * inv(a) % MOD
        # 使用费马小定理求逆元:inv(a) = a^(MOD-2) % MOD
        inv_a = pow(self.a, self.MOD - 2, self.MOD)
        self.nums.append(((val - self.b) * inv_a) % self.MOD)

    def addAll(self, inc: int) -> None:
        # (x * a + b) + inc = x * a + (b + inc)
        self.b = (self.b + inc) % self.MOD

    def multAll(self, m: int) -> None:
        # (x * a + b) * m = x * (a * m) + (b * m)
        self.a = (self.a * m) % self.MOD
        self.b = (self.b * m) % self.MOD

    def getIndex(self, idx: int) -> int:
        if idx >= len(self.nums):
            return -1
        # 计算当前值:(原始存入的x * 当前a + 当前b) % MOD
        return (self.nums[idx] * self.a + self.b) % self.MOD

复杂度分析

  • __init__: O(1)
  • append: O(logMOD)。主要耗时在计算模逆元(快速幂)。由于 MOD 是固定的,且 m,val1a 永远不会是 MOD 的倍数,因此逆元始终存在。
  • addAll: O(1)
  • multAll: O(1)
  • getIndex: O(1)
  • 空间复杂度: O(N),用于存储序列元素。

这种方法巧妙地利用了线性变换的复合性质,将对整个数组的操作转化为了对两个全局变量的操作。

如果“线性变换”数学推导让你觉得抽象,我们换一种结构化的思路:线段树(Segment Tree)

虽然这道题的操作永远是针对“整个序列”(即前缀 [0,n1]),使用线段树有点“大材小用”,但它是一种非常通用的解决区间修改区间查询的工具。

线段树的核心思路

线段树的每一个节点维护一个区间的状态。对于这道题,我们关注两个懒惰标记(Lazy Tag)

  1. mul:该区间需要乘的倍数(初始为 1)。
  2. add:该区间需要加的数值(初始为 0)。

更新规则(重要): 当我们对一个区间先乘以 m,再加上 inc 时,节点的状态遵循 y=xmul+add

  • 如果再执行 multAll(m): 新的值 = (xmul+add)m=x(mulm)+(addm)。 所以:mul = mul * m, add = add * m
  • 如果再执行 addAll(inc): 新的值 = (xmul+add)+inc=xmul+(add+inc)。 所以:add = add + inc

Python 代码实现(线段树版)

注意:在 Python 中,为了防止 10^5 规模下的递归导致超时(TLE),我们通常使用数组模拟的线段树,并尽量减少操作。

python
class Fancy:
    def __init__(self):
        self.MAX_N = 100005
        self.n = 0
        # tree 只在叶子节点存值,其实这道题只需要标记即可
        # mul_tag[i] 表示第 i 个节点对应的区间需要乘的系数
        self.mul_tag = [1] * (4 * self.MAX_N)
        # add_tag[i] 表示第 i 个节点对应的区间需要加的增量
        self.add_tag = [0] * (4 * self.MAX_N)
        self.nums = []
        self.MOD = 10**9 + 7

    def _push_down(self, node):
        """下传懒惰标记"""
        m = self.mul_tag[node]
        a = self.add_tag[node]
        if m == 1 and a == 0:
            return
        
        # 更新左右子节点
        for child in [node * 2, node * 2 + 1]:
            # 子节点的旧值 y = x * m1 + a1
            # 应用父节点标记后:(x * m1 + a1) * m + a = x * (m1 * m) + (a1 * m + a)
            self.mul_tag[child] = (self.mul_tag[child] * m) % self.MOD
            self.add_tag[child] = (self.add_tag[child] * m + a) % self.MOD
        
        # 重置当前节点标记
        self.mul_tag[node] = 1
        self.add_tag[node] = 0

    def _update(self, node, start, end, l, r, m, a):
        """区间更新:将 [l, r] 区间乘以 m 加上 a"""
        if l <= start and end <= r:
            self.mul_tag[node] = (self.mul_tag[node] * m) % self.MOD
            self.add_tag[node] = (self.add_tag[node] * m + a) % self.MOD
            return
        
        self._push_down(node)
        mid = (start + end) // 2
        if l <= mid:
            self._update(node * 2, start, mid, l, r, m, a)
        if r > mid:
            self._update(node * 2 + 1, mid + 1, end, l, r, m, a)

    def _point_update(self, node, start, end, idx, val):
        """点更新:在 idx 位置插入初始值 val"""
        if start == end:
            self.mul_tag[node] = 1
            self.add_tag[node] = val # 这里借用 add_tag 存初始值
            return
        
        self._push_down(node)
        mid = (start + end) // 2
        if idx <= mid:
            self._point_update(node * 2, start, mid, idx, val)
        else:
            self._point_update(node * 2 + 1, mid + 1, end, idx, val)

    def _query(self, node, start, end, idx):
        """点查询:获取 idx 位置的当前值"""
        if start == end:
            return self.add_tag[node]
        
        self._push_down(node)
        mid = (start + end) // 2
        if idx <= mid:
            return self._query(node * 2, start, mid, idx)
        else:
            return self._query(node * 2 + 1, mid + 1, end, idx)

    def append(self, val: int) -> None:
        # 在当前末尾插入新数
        self._point_update(1, 0, self.MAX_N - 1, self.n, val)
        self.n += 1

    def addAll(self, inc: int) -> None:
        if self.n > 0:
            # 更新范围 [0, n-1]
            self._update(1, 0, self.MAX_N - 1, 0, self.n - 1, 1, inc)

    def multAll(self, m: int) -> None:
        if self.n > 0:
            # 更新范围 [0, n-1]
            self._update(1, 0, self.MAX_N - 1, 0, self.n - 1, m, 0)

    def getIndex(self, idx: int) -> int:
        if idx >= self.n:
            return -1
        return self._query(1, 0, self.MAX_N - 1, idx)

线段树版与数学版的区别

  1. 为什么数学版更快? 在这道题中,所有的修改操作都是针对整个当前长度的序列(从 0 到 n1)。这意味着我们其实只需要维护线段树的根节点的标记即可。线段树是在处理“只修改中间某一段”时才显示出威力。
  2. 核心逻辑是一样的: 你会发现线段树代码里的 mul = mul * madd = add * m + a,本质上就是我之前解释的线性变换复合公式。
  3. 如何理解 append
    • 在数学版里,我们为了不影响旧数据,对新数据做了“逆运算”。
    • 在线段树版里,我们不需要逆运算。因为 append 是在一个全新的位置插入数据,那个位置之前的 muladd 标记还是初始状态(1 和 0),所以直接存进去就行,它只会受到“未来”操作的影响。

总结: 如果你觉得数学推导难,线段树提供了一个物理模型:把数组分成一段段,每一段都贴上两个标签(乘法标、加法标),每次操作时,如果是针对整段的,就只改标签;如果查具体某个数,就把路上的标签全算一遍。