Skip to content

C26B:Doomsday Cumulonimbus

http://poj.openjudge.cn/practice/C26B/

C++
#include <bits/stdc++.h>
using namespace std;

// NTT-friendly prime 998244353 = 119 * 2^23 + 1 and its primitive root 3.
static const int MOD = 998244353;
static const int G = 3;

// Sizes of the precomputed tables: the Stirling tables are sized for
// N, K <= MAXA, and the factorial tables cover M <= MAXM.
static const int MAXA = 60000;
static const int MAXM = 10000000;

// Block size of the Stirling decomposition and the number of blocks.
static const int B = 1500;
static const int NB = MAXA / B;
// stirling2() maps K in (0, MAXA] to block ceil(K / B). If B did not divide
// MAXA, K near MAXA would land in block NB + 1, past the end of the tables.
static_assert(MAXA % B == 0, "B must divide MAXA");

// Returns a^e mod MOD by binary exponentiation (e <= 0 yields 1).
// The base is first reduced into [0, MOD), so negative bases and bases
// >= MOD are handled correctly and a * a can never overflow long long.
int mod_pow(long long a, long long e) {
    a %= MOD;
    if (a < 0) a += MOD;

    long long r = 1;
    while (e > 0) {
        if (e & 1) r = r * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return (int)r;
}

// In-place iterative radix-2 number-theoretic transform modulo MOD.
// a.size() must be a power of two (every caller pads to one).
// With invert == true it computes the inverse transform, including the
// final scaling by 1/n.
void ntt(vector<int>& a, bool invert) {
    const int n = static_cast<int>(a.size());

    // Reorder elements into bit-reversed index order.
    for (int i = 1, rev = 0; i < n; ++i) {
        int mask = n >> 1;
        while (rev & mask) {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if (i < rev) swap(a[i], a[rev]);
    }

    // Butterfly stages of doubling span.
    for (int span = 2; span <= n; span *= 2) {
        int root = mod_pow(G, (MOD - 1) / span);
        if (invert) root = mod_pow(root, MOD - 2);
        const int half = span / 2;

        for (int base = 0; base < n; base += span) {
            long long tw = 1;  // current twiddle factor root^k
            for (int k = 0; k < half; ++k) {
                int& lo = a[base + k];
                int& hi = a[base + k + half];

                const int even = lo;
                const int odd = static_cast<int>(hi * tw % MOD);

                int sum = even + odd;
                if (sum >= MOD) sum -= MOD;
                int diff = even - odd;
                if (diff < 0) diff += MOD;

                lo = sum;
                hi = diff;
                tw = tw * root % MOD;
            }
        }
    }

    if (invert) {
        const long long scale = mod_pow(n, MOD - 2);
        for (int& x : a) x = static_cast<int>(x * scale % MOD);
    }
}

// Returns the first `need` coefficients of the product A * Bv (mod MOD),
// zero-padded to exactly `need` entries. A non-positive `need` yields an
// empty vector. Scalar operands and small products are multiplied directly;
// larger products go through ntt().
vector<int> multiply_trunc(const vector<int>& A, const vector<int>& Bv, int need) {
    // Check before sizing the result: vector<int>(need) with a negative
    // need would convert to a huge size_t and throw.
    if (need <= 0) return {};

    vector<int> c(need, 0);
    if (A.empty() || Bv.empty()) return c;

    // Coefficients at index >= need cannot affect the kept prefix.
    const int nA = min((int)A.size(), need);
    const int nB = min((int)Bv.size(), need);

    // Scalar times polynomial.
    if (nA == 1) {
        const long long s = A[0];
        for (int i = 0; i < nB; ++i) c[i] = (int)(s * Bv[i] % MOD);
        return c;
    }
    if (nB == 1) {
        const long long s = Bv[0];
        for (int i = 0; i < nA; ++i) c[i] = (int)(s * A[i] % MOD);
        return c;
    }

    // Schoolbook convolution beats three transforms for small inputs.
    if (1LL * nA * nB <= 20000) {
        for (int i = 0; i < nA; ++i) {
            const long long ai = A[i];
            const int lim = min(nB, need - i);
            for (int j = 0; j < lim; ++j) {
                c[i + j] = (int)((c[i + j] + ai * Bv[j]) % MOD);
            }
        }
        return c;
    }

    int sz = 1;
    while (sz < nA + nB - 1) sz <<= 1;

    vector<int> a(sz, 0);
    vector<int> b(sz, 0);
    copy(A.begin(), A.begin() + nA, a.begin());
    copy(Bv.begin(), Bv.begin() + nB, b.begin());

    ntt(a, false);
    ntt(b, false);
    for (int i = 0; i < sz; ++i) {
        a[i] = (int)(1LL * a[i] * b[i] % MOD);
    }
    ntt(a, true);

    const int lim = min(need, nA + nB - 1);
    copy(a.begin(), a.begin() + lim, c.begin());
    return c;
}

vector<int> poly_inv(const vector<int>& f, int n) {
    vector<int> g(1, mod_pow(f[0], MOD - 2));

    for (int len = 1; len < n; len <<= 1) {
        int m = min(len << 1, n);

        int sz = 1;
        while (sz < 2 * m) sz <<= 1;

        vector<int> F(sz, 0), Gv(sz, 0);

        int flen = min((int)f.size(), m);
        for (int i = 0; i < flen; ++i) F[i] = f[i];
        for (int i = 0; i < (int)g.size(); ++i) Gv[i] = g[i];

        ntt(F, false);
        ntt(Gv, false);

        for (int i = 0; i < sz; ++i) {
            int fg = (int)(1LL * F[i] * Gv[i] % MOD);
            int val = 2 - fg;
            if (val < 0) val += MOD;
            Gv[i] = (int)(1LL * Gv[i] * val % MOD);
        }

        ntt(Gv, true);
        g.assign(Gv.begin(), Gv.begin() + m);
    }

    return g;
}

// Per-block tables, indexed by block number b = 1..NB (index 0 unused),
// with K0 = b * B the last index covered by block b.

// suffixPoly[b]: rows r = 0..B-1 stored back to back. Row r starts at offset
// r*(r+1)/2, has r+1 entries, and holds prod_{j=K0-r+1}^{K0} (1 - j x).
vector<vector<int>> suffixPoly;
// blockPoly[b]: prod_{i=K0-B+1}^{K0} (1 - i x). Released after
// build_base_series() has consumed it.
vector<vector<int>> blockPoly;
// baseSeries[b]: first MAXA - K0 + B + 1 coefficients of
// 1 / prod_{i=1}^{K0} (1 - i x).
vector<vector<int>> baseSeries;

// Factorials and inverse factorials modulo MOD for 0..MAXM.
vector<int> fact;
vector<int> ifact;

// Multiplies, in place and modulo MOD, the polynomial stored in p[0..deg]
// by (1 - c*x). p[deg + 1] must exist and hold 0 on entry; it receives the
// new leading coefficient.
static void mul_one_minus_cx(vector<int>& p, int deg, int c) {
    for (int e = deg + 1; e >= 1; --e) {
        int v = p[e] - static_cast<int>(1LL * c * p[e - 1] % MOD);
        if (v < 0) v += MOD;
        p[e] = v;
    }
}

// Fills suffixPoly[b] and blockPoly[b] for every block b = 1..NB
// (K0 = b * B):
//   suffixPoly[b] - rows r = 0..B-1 stored back to back (row r at offset
//                   r*(r+1)/2, r+1 entries), row r = prod_{j=K0-r+1}^{K0} (1 - j x);
//   blockPoly[b]  - the B+1 coefficients of prod_{i=K0-B+1}^{K0} (1 - i x).
// Index 0 of both tables is left empty.
void build_suffix_and_block_poly() {
    suffixPoly.assign(NB + 1, {});
    blockPoly.assign(NB + 1, {});

    for (int b = 1; b <= NB; ++b) {
        const int K0 = b * B;
        const int width = B;

        vector<int>& rows = suffixPoly[b];
        rows.resize(width * (width + 1) / 2);
        auto out = rows.begin();

        // running = prod_{j=K0-r+1}^{K0} (1 - j x), gaining one factor per row.
        vector<int> running(width + 1, 0);
        running[0] = 1;

        for (int r = 0; r < width; ++r) {
            if (r > 0) {
                running[r] = 0;
                mul_one_minus_cx(running, r - 1, K0 - r + 1);
            }
            out = copy(running.begin(), running.begin() + r + 1, out);
        }

        // The last row holds B-1 factors; one more completes the whole block.
        vector<int> whole(running.begin(), running.begin() + width);
        whole.push_back(0);
        mul_one_minus_cx(whole, width - 1, K0 - width + 1);
        blockPoly[b] = move(whole);
    }
}

// Fills baseSeries[b] (b = 1..NB) with the first MAXA - b*B + B + 1
// coefficients of 1 / prod_{i=1}^{b*B} (1 - i x). Each block's series is
// the previous one times the inverse of blockPoly[b]. Afterwards blockPoly
// is released, since it is no longer needed.
void build_base_series() {
    baseSeries.assign(NB + 1, {});

    vector<int> prefix{1};  // running series; starts as the constant 1
    for (int b = 1; b <= NB; ++b) {
        const int K0 = b * B;
        const int need = (MAXA - K0 + B) + 1;

        const vector<int> inv = poly_inv(blockPoly[b], need);

        // Later blocks need fewer coefficients, so drop the excess first.
        const int keep = min(static_cast<int>(prefix.size()), need);
        prefix.resize(keep);

        prefix = multiply_trunc(prefix, inv, need);
        baseSeries[b] = prefix;
    }

    // Swap with an empty temporary to actually free the memory.
    vector<vector<int>>{}.swap(blockPoly);
}

// Fills fact[i] = i! and ifact[i] = (i!)^{-1} modulo MOD for i = 0..MAXM.
// Only one modular exponentiation is used: the inverse of MAXM! is
// propagated downward via (i-1)!^{-1} = i!^{-1} * i.
void build_factorials() {
    fact.assign(MAXM + 1, 1);
    ifact.assign(MAXM + 1, 1);

    long long running = 1;
    for (int i = 1; i <= MAXM; ++i) {
        running = running * i % MOD;
        fact[i] = static_cast<int>(running);
    }

    ifact[MAXM] = mod_pow(fact[MAXM], MOD - 2);
    for (int i = MAXM; i > 0; --i) {
        ifact[i - 1] = static_cast<int>(1LL * ifact[i] * i % MOD);
    }
}

int stirling2(int N, int K) {
    if (K < 0 || K > N) return 0;
    if (K == 0) return N == 0 ? 1 : 0;

    int b = (K + B - 1) / B;
    int K0 = b * B;

    int r = K0 - K;
    int d = N - K;

    const vector<int>& F = baseSeries[b];
    const vector<int>& C = suffixPoly[b];

    int start = r * (r + 1) / 2;
    int lim = min(r, d);

    const int* cp = C.data() + start;

    long long res = 0;

    for (int e = 0; e <= lim; ++e) {
        res += 1LL * cp[e] * F[d - e] % MOD;
    }

    return (int)(res % MOD);
}

// Answers one decoded query. With N = n-1, M = m-1 and K = k-1 it returns
// S(N, K) * M(M-1)...(M-K+1) * M^(MOD-1-N) mod MOD, where the last factor is
// M^(-N) by Fermat's little theorem. It returns 0 when K < 0, K > N or K > M.
// Table bounds (N <= MAXA, M <= MAXM) are not checked here.
int solve_one(int n, int m, int k) {
    const int N = n - 1;
    const int M = m - 1;
    const int K = k - 1;

    const bool feasible = K >= 0 && K <= N && K <= M;
    if (!feasible) return 0;

    const int s = stirling2(N, K);
    if (s == 0) return 0;

    // Falling factorial M!/(M-K)!.
    const long long falling = 1LL * fact[M] * ifact[M - K] % MOD;
    const long long invPow =
        (N == 0) ? 1LL : static_cast<long long>(mod_pow(M, MOD - 1 - N));

    return static_cast<int>(s * falling % MOD * invPow % MOD);
}

// Reads the query count, builds all tables, then answers the queries
// online. Each query's three numbers are XOR-decoded with the previous
// answer (initially 0) before being passed to solve_one().
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int q;
    cin >> q;

    // Precomputation; build_base_series() consumes blockPoly.
    build_suffix_and_block_poly();
    build_base_series();
    build_factorials();

    int prev = 0;
    while (q--) {
        long long a, b, c;
        cin >> a >> b >> c;

        const int n = static_cast<int>(a ^ prev);
        const int m = static_cast<int>(b ^ prev);
        const int k = static_cast<int>(c ^ prev);

        prev = solve_one(n, m, k);
        cout << prev << '\n';
    }
    return 0;
}

以上是针对 POJ/OpenJudge C26B 题目的 C++ 代码。该程序利用 NTT(数论变换)、多项式求逆以及分块(Square Root Decomposition)技术,高效计算大范围的第二类斯特林数 $S(n,k)$。

由于 Python 的执行速度远低于 C++,直接翻译 NTT 可能在 N=60,000 的规模下超时。为了在 Python 中尽可能提高效率,我们需要:

  1. 使用 sys.stdin.read().split() 快速读取所有输入。
  2. 使用 array.array 替代列表以节省大量内存。
  3. 对 NTT 的核心循环进行优化(虽然 Python 的 NTT 仍然较慢)。
  4. 在计算 $10^7$ 级别的阶乘时使用紧凑的循环。

以下是转换后的 Python 代码:

python
import sys
from array import array

# Raise CPython's recursion limit. Nothing in this script recurses, so this
# is purely defensive.
sys.setrecursionlimit(2 * 10 ** 5)

def solve():
    """Read every query from stdin, answer them online, print one answer per line.

    Input is a query count q followed by q triples (a, b, c). Each triple is
    decoded by XOR-ing it with the previous answer (initially 0) to obtain
    n, m, k. With N = n - 1, M = m - 1 and K = k - 1 the answer is
    S(N, K) * M*(M-1)*...*(M-K+1) * M^(MOD-1-N) mod 998244353, where S is the
    Stirling number of the second kind. The answer is 0 when K < 0, K > N
    or K > M. Empty input produces no output.
    """
    # Parameters (kept identical to the C++ version).
    MOD = 998244353
    G = 3             # primitive root of MOD, used by the NTT
    MAXA = 60000      # Stirling tables are sized for N, K <= MAXA
    MAXM = 10000000   # factorial tables cover M <= MAXM
    B = 1500          # block size; must divide MAXA
    NB = MAXA // B

    # Fast input: read and tokenize all of stdin at once.
    input_data = sys.stdin.read().split()
    if not input_data:
        return
    q_idx = 0
    num_queries = int(input_data[q_idx])
    q_idx += 1

    def mod_pow(a, e):
        return pow(a, e, MOD)

    # Factorials and inverse factorials mod MOD for 0..MAXM. A running local
    # accumulator avoids re-reading the array element on every iteration.
    fact = array('I', [1]) * (MAXM + 1)
    acc = 1
    for i in range(1, MAXM + 1):
        acc = acc * i % MOD
        fact[i] = acc

    ifact = array('I', [1]) * (MAXM + 1)
    acc = mod_pow(fact[MAXM], MOD - 2)
    ifact[MAXM] = acc
    for i in range(MAXM, 0, -1):
        acc = acc * i % MOD
        ifact[i - 1] = acc

    def ntt(a, invert):
        """In-place iterative radix-2 NTT of list `a` (len(a) a power of two).

        With invert=True the inverse transform is computed, including the
        1/n scaling.
        """
        n = len(a)
        # Bit-reversal permutation.
        j = 0
        for i in range(1, n):
            bit = n >> 1
            while j & bit:
                j ^= bit
                bit >>= 1
            j ^= bit
            if i < j:
                a[i], a[j] = a[j], a[i]

        length = 2
        while length <= n:
            wlen = pow(G, (MOD - 1) // length, MOD)
            if invert:
                wlen = pow(wlen, MOD - 2, MOD)

            half = length >> 1
            # Twiddle factors for this stage, shared by every butterfly group.
            w_powers = [1] * half
            for k in range(1, half):
                w_powers[k] = w_powers[k - 1] * wlen % MOD

            for i in range(0, n, length):
                for k in range(half):
                    u = a[i + k]
                    v = a[i + k + half] * w_powers[k] % MOD
                    a[i + k] = (u + v) % MOD
                    a[i + k + half] = (u - v) % MOD
            length <<= 1

        if invert:
            inv_n = mod_pow(n, MOD - 2)
            for i in range(n):
                a[i] = a[i] * inv_n % MOD

    def multiply_trunc(A, Bv, need):
        """Return the first `need` coefficients of A * Bv as an array('I').

        The result always has exactly max(need, 0) entries (zero-padded),
        matching the C++ multiply_trunc.
        """
        if need <= 0:
            return array('I')
        if not A or not Bv:
            return array('I', [0]) * need

        nA, nB = min(len(A), need), min(len(Bv), need)
        if nA * nB <= 2000:
            # Small inputs: schoolbook convolution.
            res = [0] * need
            for i in range(nA):
                ai = A[i]
                for j in range(min(nB, need - i)):
                    res[i + j] = (res[i + j] + ai * Bv[j]) % MOD
            return array('I', res)

        sz = 1
        while sz < nA + nB - 1:
            sz <<= 1
        fa = list(A[:nA]) + [0] * (sz - nA)
        fb = list(Bv[:nB]) + [0] * (sz - nB)
        ntt(fa, False)
        ntt(fb, False)
        for i in range(sz):
            fa[i] = fa[i] * fb[i] % MOD
        ntt(fa, True)

        res = fa[:need]
        # The transform size can be smaller than `need`; pad so that callers
        # can always index up to need - 1.
        if len(res) < need:
            res.extend([0] * (need - len(res)))
        return array('I', res)

    def poly_inv(f, n):
        """Return 1/f truncated to n coefficients (array('I')) by Newton iteration.

        f[0] must be invertible modulo MOD.
        """
        g = [mod_pow(f[0], MOD - 2)]
        length = 1
        while length < n:
            length <<= 1
            sz = length << 1
            F = list(f[:min(len(f), length)]) + [0] * (sz - min(len(f), length))
            Gv = g + [0] * (sz - len(g))
            ntt(F, False)
            ntt(Gv, False)
            for i in range(sz):
                # g_new = g * (2 - f * g), evaluated pointwise.
                F[i] = Gv[i] * (2 - F[i] * Gv[i] % MOD + MOD) % MOD
            ntt(F, True)
            g = F[:min(length, n)]
        return array('I', g)

    # Per-block tables (index 0 unused). For block b with K0 = b * B:
    #   suffixPoly[b] concatenates rows r = 0..B-1; row r (offset r*(r+1)//2,
    #     r+1 entries) holds prod_{j=K0-r+1}^{K0} (1 - j x);
    #   blockPoly[b] holds prod_{i=K0-B+1}^{K0} (1 - i x).
    suffixPoly = [array('I')]
    blockPoly = [array('I')]

    for b in range(1, NB + 1):
        K0, C = b * B, B
        poly = [0] * (C + 1)
        poly[0] = 1
        current_suffix = array('I')

        for r in range(C):
            if r > 0:
                # Multiply the running product by (1 - (K0 - r + 1) x).
                factor = (K0 - r + 1) % MOD
                for e in range(r, 0, -1):
                    poly[e] = (poly[e] - factor * poly[e - 1]) % MOD
            current_suffix.extend(poly[:r + 1])

        suffixPoly.append(current_suffix)

        # One more factor (1 - (K0 - C + 1) x) completes the whole block.
        factor = (K0 - C + 1) % MOD
        R = poly[:C] + [0]
        for e in range(C, 0, -1):
            R[e] = (R[e] - factor * R[e - 1]) % MOD
        blockPoly.append(array('I', R))

    # baseSeries[b] = first MAXA - K0 + B + 1 coefficients of
    # 1 / prod_{i=1}^{K0} (1 - i x), built incrementally block by block.
    baseSeries = [array('I') for _ in range(NB + 1)]
    H = array('I', [1])
    for b in range(1, NB + 1):
        K0 = b * B
        need = MAXA - K0 + B + 1
        invR = poly_inv(blockPoly[b], need)
        H = multiply_trunc(H[:need], invR, need)
        baseSeries[b] = H

    def get_stirling2(N, K):
        """S(N, K) mod MOD = [x^{N-K}] baseSeries[b] * (row K0-K of suffixPoly[b])."""
        if K < 0 or K > N:
            return 0
        if K == 0:
            return 1 if N == 0 else 0
        b = (K + B - 1) // B
        K0 = b * B
        r, d = K0 - K, N - K
        F = baseSeries[b]
        C = suffixPoly[b]
        start = r * (r + 1) // 2
        res = 0
        lim = min(r, d)
        for e in range(lim + 1):
            res = (res + C[start + e] * F[d - e]) % MOD
        return res

    # Online queries: each triple is XOR-decoded with the previous answer.
    last = 0
    results = []
    for _ in range(num_queries):
        a = int(input_data[q_idx]) ^ last
        b = int(input_data[q_idx + 1]) ^ last
        c = int(input_data[q_idx + 2]) ^ last
        q_idx += 3

        N, M, K = a - 1, b - 1, c - 1
        if K < 0 or K > N or K > M:
            ans = 0
        else:
            s = get_stirling2(N, K)
            if s == 0:
                ans = 0
            else:
                falling = fact[M] * ifact[M - K] % MOD
                invPow = 1 if N == 0 else mod_pow(M, MOD - 1 - N)
                ans = s * falling % MOD * invPow % MOD

        results.append(str(ans))
        last = ans

    sys.stdout.write('\n'.join(results) + '\n')

# Run the solver only when executed as a script, not when imported.
if '__main__' == __name__:
    solve()

关键点说明:

  1. 分块斯特林数计算

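    • 该算法核心在于:$S(n,k)=[x^{n-k}]\prod_{i=1}^{k}\frac{1}{1-ix}$,即 $S(n,k)$ 是 $\prod_{i=1}^{k}(1-ix)^{-1}$ 展开式中 $x^{n-k}$ 的系数。
    • 利用分块:取 $b=\lceil k/B\rceil$,把 $k$ 对齐到不小于它的预计算块终点 $k_0=bB$(满足 $0\le k_0-k<B$)。
    • 通过 $\frac{1}{\prod_{i=1}^{k}(1-ix)}=\frac{1}{\prod_{i=1}^{k_0}(1-ix)}\cdot\prod_{j=k+1}^{k_0}(1-jx)$,将求大范围斯特林数转化为一个基础幂级数(baseSeries)与一个低次多项式(suffixPoly,次数 $k_0-k<B$)的乘积;查询时只需对两者做不超过 $B$ 项的卷积,取出单个系数即可。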
  2. 内存优化

    • array.array('I', ...) 使用 C 的 unsigned int(主流平台上为 4 字节)紧凑存储。Python 列表的每个元素则需要一个 8 字节指针,外加一个约 28 字节的 int 对象,因此 array 可节省数倍内存。这对于存储 $10^7$ 规模的阶乘表和大量的分块多项式至关重要。
  3. 性能瓶颈

    • 阶乘循环:阶乘与逆阶乘各需要一个 $10^7$ 次迭代的 Python for 循环,在 CPython 中合计通常需要数秒。
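    • NTT:Python 无法像 C++ 那样高效地执行 NTT。poly_inv 在预处理中被调用 NB = 40 次,但每次都要在长度可达约 $6\times10^4$ 的序列上做多轮 NTT。再加上约 $NB\cdot B^2/2\approx4.5\times10^7$ 次运算的 suffixPoly 构造,纯 Python 的预处理仍可能远超常见时限。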
    • 在线解码:通过 last 变量记录上一次答案,符合题目强制在线的要求。
  4. 分块大小 (B)

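    • 代码中设 B=1500。如果遇到超时,可尝试调小 B(如 800 或 1000)。这会减少 suffixPoly 的构造量($O(\text{MAXA}\cdot B)$)和单次查询的 $O(B)$ 开销,但会增加 NB,从而增加 baseSeries 中 poly_inv/NTT 的调用次数。注意 B 必须整除 MAXA,否则 NB = MAXA // B 个块无法覆盖到 MAXA,查询时会越界。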