Skip to content

T3655.区间乘法查询后的异或 II

分治,逆元,费马小定理,https://leetcode.cn/problems/xor-after-range-multiplication-queries-ii/

给你一个长度为 n 的整数数组 nums 和一个大小为 q 的二维整数数组 queries,其中 queries[i] = [li, ri, ki, vi]

对于每个查询,需要按以下步骤依次执行操作:

  • 设定 idx = li

  • idx <= ri 时:

    • 更新:nums[idx] = (nums[idx] * vi) % (10^9 + 7)
    • idx += ki

    在处理完所有查询后,返回数组 nums 中所有元素的 按位异或 结果。

示例 1:

输入: nums = [1,1,1], queries = [[0,2,1,4]]

输出: 4

解释:

  • 唯一的查询 [0, 2, 1, 4] 将下标 0 到下标 2 的每个元素乘以 4。
  • 数组从 [1, 1, 1] 变为 [4, 4, 4]
  • 所有元素的异或为 4 ^ 4 ^ 4 = 4

示例 2:

输入: nums = [2,3,1,5,4], queries = [[1,4,2,3],[0,2,1,2]]

输出: 31

解释:

  • 第一个查询 [1, 4, 2, 3] 将下标 1 和 3 的元素乘以 3,数组变为 [2, 9, 1, 15, 4]
  • 第二个查询 [0, 2, 1, 2] 将下标 0、1 和 2 的元素乘以 2,数组变为 [4, 18, 2, 15, 4]
  • 所有元素的异或为 4 ^ 18 ^ 2 ^ 15 ^ 4 = 31

提示:

  • 1 <= n == nums.length <= 10^5
  • 1 <= nums[i] <= 10^9
  • 1 <= q == queries.length <= 10^5
  • queries[i] = [li, ri, ki, vi]
  • 0 <= li <= ri < n
  • 1 <= ki <= n
  • 1 <= vi <= 10^5

力扣官方题解,链接:https://leetcode.cn/problems/xor-after-range-multiplication-queries-ii/solutions/3941260/qu-jian-cheng-fa-cha-xun-hou-de-yi-huo-i-wifp/

方法:根号分治 + 差分

思路与算法

最朴素的想法是对每个查询,直接模拟,逐个去乘。单次查询时间复杂度就是 O(n),q 次查询总计 O(nq),规模约 1010。必然超时。问题症结在当 k 很小时,一次查询会触及大量元素,代价很高。

注意到步长 k 对复杂度的影响是截然不同的,我们可以按 k 和 n 的大小关系,将查询分为两类,分别用最适合的方法处理:

  • kn 时,每次查询最多触及 nkn 个元素,暴力可以接受。时间复杂度为 O(qn)
  • k<n 时,单次查询可以触及很多元素,暴力就扛不住了。

针对较小步长 (k<n),我们把查询按 k 值分组。相同的 k 可以一起处理。因为相同的 k 影响到的下标构成的网格是一样的,比如说 k=3 的所有查询里面,影响的下标都形如 l,l+k,l+2k,,它们都接受步长 3 跳跃的。

这样一来,我们固定了 k 之后,对于每个查询 [l,r,v],我们需要把 l,l+k,l+2k, 上的元素都乘上 v,这本质上是一个区间乘,但不是直接乘在原数组的连续区间,而是在步长为 k 的子序列上。

我们构造一个差分数组 dif,dif 的数组初始化为 1,处理查询 [l,r,v] 时,找到最后一个需要处理的元素的下标的下一个位置,设为 R(举个例子,对于查询 [2,7,3] 来讲,最后一个处理的元素下标为 5,那么 R 等于 8)。将 dif[l] = dif[l] * vdif[R] = dif[R] * v^{-1}。其中 v1 是 v 在模 M=109+7 意义下的逆元,可以使用费马小定理求解 vM2 得到。这样单次查询的处理时间复杂度就是 O(logM)

最后我们从前往后遍历 dif 数组,令 dif[i] = dif[i] * dif[i - k],即得到当前 k 下每个位置对于所有查询的累积量,然后再更新到原数组中,时间复杂度 O(n)。对于较小步长的查询处理,总体的时间复杂度为 O(nn+qlogA)

还有一个问题是 R 如何计算?查询影响到的最后一个下标是 l+rlkk,因此 R=l+(rlk+1)k。R 最大值是 n+k,为了方便模拟,我们申请 dif 数组大小为 n+n

python
   class Solution:
    def xorAfterQueries(self, nums: List[int], queries: List[List[int]]) -> int:
        mod = 10**9 + 7
        n = len(nums)
        T = int(n ** 0.5)

        groups = [[] for _ in range(T)]
        for l, r, k, v in queries:
            if k < T:
                groups[k].append((l, r, v))
            else:
                for i in range(l, r + 1, k):
                    nums[i] = nums[i] * v % mod

        dif = [1] * (n + T)
        for k in range(1, T):
            if not groups[k]:
                continue
            dif[:] = [1] * len(dif)
            for l, r, v in groups[k]:
                dif[l] = dif[l] * v % mod
                R = ((r - l) // k + 1) * k + l
                dif[R] = dif[R] * pow(v, mod - 2, mod) % mod

            for i in range(k, n):
                dif[i] = dif[i] * dif[i - k] % mod
            for i in range(n):
                nums[i] = nums[i] * dif[i] % mod

        res = 0
        for x in nums:
            res ^= x
        return res

复杂度分析

  • 时间复杂度:O((n+q)n+qlogM)
  • 空间复杂度:O(n+q)

针对这个问题,我们可以观察到每个查询操作实际上是针对数组 nums 的一个算术级数(Arithmetic Progression)位置上的元素进行乘法。由于乘法满足结合律和交换律,我们可以离线处理所有查询。

核心思路:

  1. 分组处理 (Grouping by Step and Remainder): 每个查询 [li, ri, ki, vi] 影响的是下标满足 idxli(modki)liidxri 的元素。我们可以按照步长 k 和余数 rem=li(modk) 将查询进行分组。

  2. 差分数组/扫描线 (Difference Array / Sweep-line): 对于每一组具有相同 (k,rem) 的查询,我们将对应的算术级数看作一个独立的 1D 数组。查询就是在该数组的区间 [l_idx,r_idx] 上进行乘法。 使用差分数组的思想,我们在位置 l_idx 乘上 vi,在 r_idx+1 乘上 vi 的逆元(因为 109+7 是质数且 vi1,逆元始终存在)。 遍历该算术级数的所有变动点(即 l_idxr_idx+1),在两个变动点之间的所有位置,累积乘积(curr_multiplier)保持不变。

  3. 复杂度分析:

    • 时间复杂度: 这种方法的总更新次数受限于 kmin(n,Qknk),其中 Qk 是步长为 k 的查询数量。可以证明该复杂度为 O(nq+qlogq)。在 n,q=105 的量级下,这大约是 3107 次操作,在 Python 中通过紧凑的循环可以运行。
    • 空间复杂度: O(n+q),用于存储数组、分组查询和逆元。

    Python 代码实现:

python
import collections

class Solution:
    def xorAfterQueries(self, nums: List[int], queries: List[List[int]]) -> int:
        n = len(nums)
        MOD = 1000000007
        
        # 将查询按步长 k 和起始位置的余数 rem 进行分组
        queries_by_k_rem = collections.defaultdict(list)
        for l, r, k, v in queries:
            if v == 1:
                continue
            queries_by_k_rem[(k, l % k)].append((l, r, v))
        
        # 预计算所有出现过的乘数 v 的模逆元
        memo_inv = {}
        for k_rem, qs in queries_by_k_rem.items():
            for l, r, v in qs:
                if v not in memo_inv:
                    # Python 3.8+ pow(v, -1, MOD) 效率很高
                    memo_inv[v] = pow(v, -1, MOD)
                    
        # 遍历每一组算术级数
        for (k, rem), qs in queries_by_k_rem.items():
            # points 存储该级数中乘积发生变化的位置索引
            # key 为级数中的偏移量 j (即 index = rem + j * k)
            points = {}
            for l, r, v in qs:
                l_idx = (l - rem) // k
                r_idx = (r - rem) // k
                
                # 在 l_idx 处开始生效
                points[l_idx] = (points.get(l_idx, 1) * v) % MOD
                # 在 r_idx + 1 处失效(乘以逆元)
                inv_v = memo_inv[v]
                points[r_idx + 1] = (points.get(r_idx + 1, 1) * inv_v) % MOD
            
            # 按级数中的索引排序变动点
            sorted_keys = sorted(points.keys())
            curr_multiplier = 1
            
            # 在两个相邻变动点之间的区间,累积乘积是固定的
            for i in range(len(sorted_keys) - 1):
                p_start = sorted_keys[i]
                p_end = sorted_keys[i+1]
                curr_multiplier = (curr_multiplier * points[p_start]) % MOD
                
                # 如果当前累积乘积不为 1,则更新区间内的原数组元素
                if curr_multiplier != 1:
                    start_idx = rem + p_start * k
                    # 只有当起始下标在数组范围内时才进行更新
                    if start_idx < n:
                        stop_idx = rem + p_end * k
                        # 使用 Python 原生 range 的 step 优化内部循环速度
                        for idx in range(start_idx, stop_idx, k):
                            nums[idx] = (nums[idx] * curr_multiplier) % MOD
                            
        # 最后返回所有元素的按位异或结果
        res = 0
        for x in nums:
            res ^= x
        return res