C26B:Doomsday Cumulonimbus
http://poj.openjudge.cn/practice/C26B/
C++
#include <bits/stdc++.h>
using namespace std;
// NTT-friendly prime modulus used for all arithmetic in this file.
static constexpr int MOD = 998244353;
// Generator passed to mod_pow() in ntt() to obtain the roots of unity.
static constexpr int G = 3;
// Largest first index the base series are sized for (see build_base_series()).
static constexpr int MAXA = 60000;
// Size of the factorial / inverse-factorial tables (see build_factorials()).
static constexpr int MAXM = 10000000;
// Block width of the decomposition.
static constexpr int B = 1500;
// Number of blocks; per-block tables are indexed 1..NB.
static constexpr int NB = MAXA / B;
// stirling2() selects the block with ceil(K / B); if B did not divide MAXA the
// truncating division above would leave the last block without tables.
static_assert(MAXA % B == 0, "block width B must divide MAXA");
// Computes a^e modulo MOD by binary exponentiation.
//
// The base is reduced into [0, MOD) first, so negative bases give a
// non-negative result and large bases cannot overflow the 64-bit products.
// An exponent that is zero or negative yields 1 (the loop never runs).
int mod_pow(long long a, long long e) {
    a %= MOD;
    if (a < 0) a += MOD;
    long long r = 1;
    while (e > 0) {
        if (e & 1) r = r * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return static_cast<int>(r);
}
// In-place number-theoretic transform of `a` modulo MOD.
//
// a.size() is expected to be a power of two (the butterfly loops below halve
// and double lengths). With invert == true the inverse roots are used and the
// result is scaled by the modular inverse of the length.
void ntt(vector<int>& a, bool invert) {
    const int n = static_cast<int>(a.size());

    // Bit-reversal permutation: `rev` tracks the bit-reversed value of `idx`.
    int rev = 0;
    for (int idx = 1; idx < n; ++idx) {
        int mask = n >> 1;
        while (rev & mask) {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if (idx < rev) swap(a[idx], a[rev]);
    }

    // Butterfly passes over spans of length 2, 4, ..., n.
    for (int span = 2; span <= n; span <<= 1) {
        int root = mod_pow(G, (MOD - 1) / span);
        if (invert) root = mod_pow(root, MOD - 2);
        const int half = span >> 1;
        for (int base = 0; base < n; base += span) {
            long long twiddle = 1;
            for (int k = 0; k < half; ++k) {
                const int lo = a[base + k];
                const int hi = static_cast<int>(a[base + k + half] * twiddle % MOD);
                int sum = lo + hi;
                if (sum >= MOD) sum -= MOD;
                int diff = lo - hi;
                if (diff < 0) diff += MOD;
                a[base + k] = sum;
                a[base + k + half] = diff;
                twiddle = twiddle * root % MOD;
            }
        }
    }

    if (invert) {
        const int inv_n = mod_pow(n, MOD - 2);
        for (int& value : a) {
            value = static_cast<int>(1LL * value * inv_n % MOD);
        }
    }
}
// Returns the first `need` coefficients of A * Bv modulo MOD.
//
// The result always has exactly max(need, 0) entries; coefficients the product
// does not reach stay 0. Only the first `need` entries of each operand are read.
// Small products use direct convolution, larger ones go through ntt().
vector<int> multiply_trunc(const vector<int>& A, const vector<int>& Bv, int need) {
    // Guard before allocating: a negative `need` would otherwise be converted
    // to a huge unsigned size by the vector constructor.
    if (need <= 0) return {};
    vector<int> c(need, 0);
    if (A.empty() || Bv.empty()) return c;

    const int nA = min(static_cast<int>(A.size()), need);
    const int nB = min(static_cast<int>(Bv.size()), need);

    // One operand is a constant: plain scaling of the other one.
    if (nA == 1) {
        const long long s = A[0];
        for (int i = 0; i < nB; ++i) {
            c[i] = static_cast<int>(s * Bv[i] % MOD);
        }
        return c;
    }
    if (nB == 1) {
        const long long s = Bv[0];
        for (int i = 0; i < nA; ++i) {
            c[i] = static_cast<int>(s * A[i] % MOD);
        }
        return c;
    }

    // Schoolbook convolution for small inputs.
    if (1LL * nA * nB <= 20000) {
        for (int i = 0; i < nA; ++i) {
            const long long ai = A[i];
            const int lim = min(nB, need - i);
            for (int j = 0; j < lim; ++j) {
                c[i + j] = static_cast<int>((c[i + j] + ai * Bv[j]) % MOD);
            }
        }
        return c;
    }

    // NTT-based convolution; sz is the first power of two that holds the full product.
    int sz = 1;
    while (sz < nA + nB - 1) sz <<= 1;
    vector<int> a(sz, 0), b(sz, 0);
    for (int i = 0; i < nA; ++i) a[i] = A[i];
    for (int i = 0; i < nB; ++i) b[i] = Bv[i];
    ntt(a, false);
    ntt(b, false);
    for (int i = 0; i < sz; ++i) {
        a[i] = static_cast<int>(1LL * a[i] * b[i] % MOD);
    }
    ntt(a, true);
    const int lim = min(need, nA + nB - 1);
    for (int i = 0; i < lim; ++i) c[i] = a[i];
    return c;
}
// Returns the first n coefficients of the power-series inverse of f modulo MOD.
//
// Uses the Newton step g <- g * (2 - f * g), doubling the number of correct
// coefficients each round. The result has exactly n entries (empty for n <= 0).
// If f is empty or f[0] is 0 modulo MOD no inverse exists and the result is
// all zeros, which is also what the iteration itself produces for f[0] == 0.
vector<int> poly_inv(const vector<int>& f, int n) {
    if (n <= 0) return {};
    if (f.empty()) return vector<int>(n, 0);  // avoid reading f[0] out of range

    vector<int> g(1, mod_pow(f[0], MOD - 2));
    for (int len = 1; len < n; len <<= 1) {
        const int m = min(len << 1, n);  // number of coefficients after this round
        int sz = 1;
        while (sz < 2 * m) sz <<= 1;
        vector<int> F(sz, 0), Gv(sz, 0);
        const int flen = min(static_cast<int>(f.size()), m);
        for (int i = 0; i < flen; ++i) F[i] = f[i];
        for (int i = 0; i < static_cast<int>(g.size()); ++i) Gv[i] = g[i];
        ntt(F, false);
        ntt(Gv, false);
        // Pointwise g * (2 - f * g) in the transformed domain.
        for (int i = 0; i < sz; ++i) {
            const int fg = static_cast<int>(1LL * F[i] * Gv[i] % MOD);
            int val = 2 - fg;
            if (val < 0) val += MOD;
            Gv[i] = static_cast<int>(1LL * Gv[i] * val % MOD);
        }
        ntt(Gv, true);
        g.assign(Gv.begin(), Gv.begin() + m);
    }
    return g;
}
// suffixPoly[b]: for block b, the coefficient lists of the partial products
// built in build_suffix_and_block_poly(), stored back to back; the list with
// r factors has r + 1 entries and starts at offset r * (r + 1) / 2.
vector<vector<int>> suffixPoly;
// blockPoly[b]: the product of all B linear factors of block b (B + 1
// coefficients). Released at the end of build_base_series().
vector<vector<int>> blockPoly;
// baseSeries[b]: truncated running product of poly_inv(blockPoly[1..b]),
// filled by build_base_series() and read by stirling2().
vector<vector<int>> baseSeries;
// fact[i] = i! mod MOD and ifact[i] = its modular inverse, for 0 <= i <= MAXM.
vector<int> fact, ifact;
void build_suffix_and_block_poly() {
suffixPoly.assign(NB + 1, {});
blockPoly.assign(NB + 1, {});
for (int b = 1; b <= NB; ++b) {
int K0 = b * B;
int C = B;
int total = C * (C + 1) / 2;
suffixPoly[b].resize(total);
vector<int> poly(C + 1, 0);
poly[0] = 1;
int pos = 0;
for (int r = 0; r < C; ++r) {
if (r > 0) {
int factor = K0 - r + 1;
poly[r] = 0;
for (int e = r; e >= 1; --e) {
int v = poly[e] - (int)(1LL * factor * poly[e - 1] % MOD);
if (v < 0) v += MOD;
poly[e] = v;
}
}
for (int e = 0; e <= r; ++e) {
suffixPoly[b][pos++] = poly[e];
}
}
int factor = K0 - C + 1;
vector<int> R(C + 1, 0);
for (int e = 0; e < C; ++e) R[e] = poly[e];
R[C] = 0;
for (int e = C; e >= 1; --e) {
int v = R[e] - (int)(1LL * factor * R[e - 1] % MOD);
if (v < 0) v += MOD;
R[e] = v;
}
blockPoly[b] = move(R);
}
}
void build_base_series() {
baseSeries.assign(NB + 1, {});
vector<int> H(1, 1);
for (int b = 1; b <= NB; ++b) {
int K0 = b * B;
int L = MAXA - K0 + B;
int need = L + 1;
vector<int> invR = poly_inv(blockPoly[b], need);
vector<int> Hcut(min((int)H.size(), need));
for (int i = 0; i < (int)Hcut.size(); ++i) {
Hcut[i] = H[i];
}
H = multiply_trunc(Hcut, invR, need);
baseSeries[b] = H;
}
vector<vector<int>>().swap(blockPoly);
}
// Fills fact[i] = i! mod MOD and ifact[i] = (i!)^(-1) mod MOD for 0 <= i <= MAXM.
// Only one modular inverse is computed; the others follow from
// ifact[i - 1] = ifact[i] * i.
void build_factorials() {
    fact.assign(MAXM + 1, 1);
    ifact.assign(MAXM + 1, 1);

    long long running = 1;
    for (int i = 1; i <= MAXM; ++i) {
        running = running * i % MOD;
        fact[i] = static_cast<int>(running);
    }

    long long inverse = mod_pow(fact[MAXM], MOD - 2);
    ifact[MAXM] = static_cast<int>(inverse);
    for (int i = MAXM; i > 0; --i) {
        inverse = inverse * i % MOD;
        ifact[i - 1] = static_cast<int>(inverse);
    }
}
int stirling2(int N, int K) {
if (K < 0 || K > N) return 0;
if (K == 0) return N == 0 ? 1 : 0;
int b = (K + B - 1) / B;
int K0 = b * B;
int r = K0 - K;
int d = N - K;
const vector<int>& F = baseSeries[b];
const vector<int>& C = suffixPoly[b];
int start = r * (r + 1) / 2;
int lim = min(r, d);
const int* cp = C.data() + start;
long long res = 0;
for (int e = 0; e <= lim; ++e) {
res += 1LL * cp[e] * F[d - e] % MOD;
}
return (int)(res % MOD);
}
// Answers one decoded query (n, m, k).
//
// With N = n - 1, K = k - 1, M = m - 1 the result is
//   stirling2(N, K) * fact[M] * ifact[M - K] * M^(MOD - 1 - N)   (mod MOD),
// where the last factor is the modular inverse of M^N. Returns 0 when K is
// negative or exceeds N or M. M is used as an index into fact/ifact without a
// range check.
int solve_one(int n, int m, int k) {
    const int N = n - 1;
    const int K = k - 1;
    const int M = m - 1;

    const bool feasible = (K >= 0) && (K <= N) && (K <= M);
    if (!feasible) return 0;

    const int s = stirling2(N, K);
    if (s == 0) return 0;

    const long long falling = 1LL * fact[M] * ifact[M - K] % MOD;
    long long invPow = 1;
    if (N != 0) invPow = mod_pow(M, MOD - 1 - N);
    return static_cast<int>(s * falling % MOD * invPow % MOD);
}
// Entry point: reads the number of queries, builds all tables, then answers
// the queries online. Each of the three input values of a query is XOR-ed with
// the previous answer (0 for the first query) before use.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int q = 0;
    // Nothing to answer: skip the expensive precomputation.
    if (!(cin >> q) || q <= 0) return 0;

    // Order matters: build_base_series() consumes blockPoly.
    build_suffix_and_block_poly();
    build_base_series();
    build_factorials();

    int last = 0;
    for (; q > 0; --q) {
        long long a, b, c;
        if (!(cin >> a >> b >> c)) break;  // input ended early
        const int n = static_cast<int>(a ^ last);
        const int m = static_cast<int>(b ^ last);
        const int k = static_cast<int>(c ^ last);
        const int ans = solve_one(n, m, k);
        cout << ans << '\n';
        last = ans;
    }
    return 0;
}
// Summary: the C++ program above, written for POJ/OpenJudge problem C26B, uses
// NTT (number-theoretic transform), polynomial inversion and block (square-root)
// decomposition to compute Stirling numbers of the second kind over a large range.
由于 Python 的执行速度远低于 C++,直接翻译 NTT 可能在时限内无法完成,因此转换时做了以下优化:
- 使用 `sys.stdin.read().split()` 快速读取所有输入。
- 使用 `array.array` 替代列表以节省大量内存。
- 对 NTT 的核心循环进行优化(虽然 Python 的 NTT 仍然较慢)。
- 在计算 $10^7$ 级别的阶乘时使用紧凑的循环。
以下是转换后的 Python 代码:
python
import sys
from array import array
# 设置递归深度
sys.setrecursionlimit(200000)
def solve(max_a=60000, max_m=10000000, block_size=1500):
    """Read all queries from stdin, answer them online, and write the answers to stdout.

    Input is a query count followed by three integers per query; each integer
    is XOR-ed with the previous answer (0 for the first query) before use.

    Args:
        max_a: largest first index the base series are sized for.
        max_m: size of the factorial / inverse-factorial tables.
        block_size: block width of the decomposition.

    The defaults reproduce the original hard-coded sizes; smaller values give
    the same answers for queries that stay inside the reduced tables.
    """
    # Parameters
    MOD = 998244353
    G = 3
    MAXA = max_a
    MAXM = max_m
    B = block_size  # block width
    # Ceiling division so the last (possibly partial) block also gets tables;
    # identical to MAXA // B whenever B divides MAXA.
    NB = (MAXA + B - 1) // B

    # Fast input
    input_data = sys.stdin.read().split()
    if not input_data:
        return
    q_idx = 0
    num_queries = int(input_data[q_idx])
    q_idx += 1

    # Factorials modulo MOD
    fact = array('I', [1]) * (MAXM + 1)
    for i in range(1, MAXM + 1):
        fact[i] = (fact[i - 1] * i) % MOD

    # Modular exponentiation helper
    def mod_pow(a, e):
        return pow(a, e, MOD)

    # Inverse factorials: one modular inverse, then ifact[i - 1] = ifact[i] * i
    ifact = array('I', [1]) * (MAXM + 1)
    ifact[MAXM] = mod_pow(fact[MAXM], MOD - 2)
    for i in range(MAXM, 0, -1):
        ifact[i - 1] = (ifact[i] * i) % MOD

    # In-place NTT; len(a) is expected to be a power of two
    def ntt(a, invert):
        n = len(a)
        # Bit-reversal permutation
        j = 0
        for i in range(1, n):
            bit = n >> 1
            while j & bit:
                j ^= bit
                bit >>= 1
            j ^= bit
            if i < j:
                a[i], a[j] = a[j], a[i]
        length = 2
        while length <= n:
            wlen = pow(G, (MOD - 1) // length, MOD)
            if invert:
                wlen = pow(wlen, MOD - 2, MOD)
            half = length >> 1
            # Precompute the twiddle powers of this level
            w_powers = [1] * half
            for k in range(1, half):
                w_powers[k] = (w_powers[k - 1] * wlen) % MOD
            for i in range(0, n, length):
                for k in range(half):
                    u = a[i + k]
                    v = (a[i + k + half] * w_powers[k]) % MOD
                    a[i + k] = (u + v) % MOD
                    a[i + k + half] = (u - v + MOD) % MOD
            length <<= 1
        if invert:
            inv_n = mod_pow(n, MOD - 2)
            for i in range(n):
                a[i] = (a[i] * inv_n) % MOD

    # First `need` coefficients of A * Bv; the result always has max(need, 0) entries
    def multiply_trunc(A, Bv, need):
        if need <= 0 or not A or not Bv:
            return array('I', [0] * need)
        nA, nB = min(len(A), need), min(len(Bv), need)
        if nA * nB <= 2000:  # schoolbook convolution for small inputs
            res = [0] * need
            for i in range(nA):
                ai = A[i]
                for j in range(min(nB, need - i)):
                    res[i + j] = (res[i + j] + ai * Bv[j]) % MOD
            return array('I', res)
        sz = 1
        while sz < nA + nB - 1:
            sz <<= 1
        fa = list(A[:nA]) + [0] * (sz - nA)
        fb = list(Bv[:nB]) + [0] * (sz - nB)
        ntt(fa, False)
        ntt(fb, False)
        for i in range(sz):
            fa[i] = (fa[i] * fb[i]) % MOD
        ntt(fa, True)
        out = fa[:need]
        # Pad with zeros when the product is shorter than `need`
        if len(out) < need:
            out.extend([0] * (need - len(out)))
        return array('I', out)

    # First n coefficients of 1 / f, via the Newton step g <- g * (2 - f * g)
    def poly_inv(f, n):
        g = [mod_pow(f[0], MOD - 2)]
        length = 1
        while length < n:
            length <<= 1
            sz = length << 1
            flen = min(len(f), length)
            F = list(f[:flen]) + [0] * (sz - flen)
            Gv = g + [0] * (sz - len(g))
            ntt(F, False)
            ntt(Gv, False)
            for i in range(sz):
                F[i] = Gv[i] * (2 - F[i] * Gv[i] % MOD + MOD) % MOD
            ntt(F, True)
            g = F[:min(length, n)]
        return array('I', g)

    # Per-block polynomials
    suffixPoly = []  # one packed array per block
    blockPoly = []
    for b in range(NB + 1):
        if b == 0:
            # Index 0 is unused; blocks are numbered from 1
            suffixPoly.append(array('I'))
            blockPoly.append(array('I'))
            continue
        K0, C = b * B, B
        poly = [0] * (C + 1)
        poly[0] = 1
        current_suffix = array('I')
        for r in range(C):
            if r > 0:
                # Multiply by (1 - (K0 - r + 1) * x)
                factor = (K0 - r + 1) % MOD
                for e in range(r, 0, -1):
                    poly[e] = (poly[e] - factor * poly[e - 1]) % MOD
            # The r-factor product starts at offset r * (r + 1) // 2
            for e in range(r + 1):
                current_suffix.append(poly[e] % MOD)
        suffixPoly.append(current_suffix)
        # One more factor gives the full block polynomial
        factor = (K0 - C + 1) % MOD
        R = list(poly[:C]) + [0]
        for e in range(C, 0, -1):
            R[e] = (R[e] - factor * R[e - 1]) % MOD
        blockPoly.append(array('I', [x % MOD for x in R]))

    # Base series: running product of the inverted block polynomials
    baseSeries = [array('I')] * (NB + 1)
    H = array('I', [1])
    for b in range(1, NB + 1):
        K0 = b * B
        need = MAXA - K0 + B + 1
        invR = poly_inv(blockPoly[b], need)
        H = multiply_trunc(H[:min(len(H), need)], invR, need)
        baseSeries[b] = H

    def get_stirling2(N, K):
        if K < 0 or K > N:
            return 0
        if K == 0:
            return 1 if N == 0 else 0
        b = (K + B - 1) // B
        K0 = b * B
        r, d = K0 - K, N - K
        F = baseSeries[b]
        C = suffixPoly[b]
        start = r * (r + 1) // 2
        res = 0
        lim = min(r, d)
        for e in range(lim + 1):
            res = (res + C[start + e] * F[d - e]) % MOD
        return res

    # Online queries: every input value is XOR-ed with the previous answer
    last = 0
    results = []
    for _ in range(num_queries):
        a = int(input_data[q_idx]) ^ last
        b = int(input_data[q_idx + 1]) ^ last
        c = int(input_data[q_idx + 2]) ^ last
        q_idx += 3
        N, M, K = a - 1, b - 1, c - 1
        if K < 0 or K > N or K > M:
            ans = 0
        else:
            s = get_stirling2(N, K)
            if s == 0:
                ans = 0
            else:
                falling = (fact[M] * ifact[M - K]) % MOD
                invPow = 1 if N == 0 else mod_pow(M, MOD - 1 - N)
                ans = (s * falling % MOD * invPow) % MOD
        results.append(str(ans))
        last = ans
    sys.stdout.write('\n'.join(results) + '\n')
if __name__ == '__main__':
    solve()

# Key points:
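分块斯特林数计算:
- 该算法核心在于:$\sum_{N \ge K} S(N, K)\, x^{N-K} = \prod_{i=1}^{K} \frac{1}{1 - i x}$。
- 利用分块,将 $K$ 靠近某个预计算好的块终点 $K_0 = bB$(即不小于 $K$ 的最近的 $B$ 的倍数)。
- 通过 $\prod_{i=1}^{K} \frac{1}{1 - i x} = \Big(\prod_{i=K+1}^{K_0} (1 - i x)\Big) \cdot \prod_{i=1}^{K_0} \frac{1}{1 - i x}$,将求大范围斯特林数转化为一个基础多项式(baseSeries)与一个低次多项式(suffixPoly,次数小于 $B$)的乘积。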
内存优化:
`array.array('I', ...)` 使用 4 字节无符号整数,比 Python 列表(每个元素至少 28 字节)节省约 7 倍内存。这对于存储 $10^7$ 规模的阶乘和大量的分块多项式至关重要。
性能瓶颈:
- 阶乘循环:Python 的 `for` 循环处理 $10^7$ 次计算约需 1-1.5 秒。
- NTT:Python 无法像 C++ 那样高效利用 NTT,但好在 `poly_inv` 只在预处理执行约 40 次。
- 在线解码:通过 `last` 变量记录上一次答案,符合题目强制在线的要求。
分块大小 (B):
- 代码中设 $B = 1500$。如果遇到超时,可尝试调小 $B$(如 800-1000)以平衡预处理和查询的时间,但注意这会增加 NB,从而增加 baseSeries 的 NTT 处理次数。