#include <algorithm>
#include <cstring>
// Short aliases for the unsigned integer types used throughout the file.
using uint = unsigned int;
using uint64 = unsigned long long;
// Capacity of the global tables: 2^21 plus 5 spare slots.
constexpr uint Max_size = (1u << 21) | 5u;
// Modulus of all arithmetic, and the base passed to Power() when the
// twiddle table is built.
constexpr uint Mod = 998244353;
constexpr uint g = 3;
// Conditionally subtracts 2 * Mod once: values below 2 * Mod pass through
// unchanged, anything else is lowered by 2 * Mod.
inline uint norm_2(const uint x) {
    const uint bound = Mod * 2;
    if (x < bound)
        return x;
    return x - bound;
}
// Conditionally subtracts Mod once: values below Mod pass through unchanged,
// anything else is lowered by Mod.
inline uint norm(const uint x) {
    if (x < Mod)
        return x;
    return x - Mod;
}
// Thin wrapper around a raw unsigned value; the free operators below treat
// it as a residue modulo Mod.
struct Z {
    // Default construction does not initialise v.
    Z() { }
    // Implicit on purpose: the arithmetic operators return raw uint values
    // and rely on this conversion.
    Z(const uint value) : v(value) { }

    uint v;
};
// Adds the raw values and applies a single conditional subtraction of Mod
// (see norm); the result is below Mod when both operands are.
inline Z operator+(const Z x1, const Z x2) {
    const uint sum = x1.v + x2.v;
    return Z(norm(sum));
}
// Subtraction: Mod is added before subtracting so the intermediate value
// does not go negative for operands below Mod, then norm is applied.
inline Z operator-(const Z x1, const Z x2) {
    const uint shifted = x1.v + Mod - x2.v;
    return Z(norm(shifted));
}
// Negation: zero maps to zero, any other value v maps to Mod - v.
inline Z operator-(const Z x) {
    if (x.v == 0)
        return Z(0);
    return Z(Mod - x.v);
}
// Multiplication: the product is formed in 64 bits and reduced with %.
inline Z operator*(const Z x1, const Z x2) {
    const uint64 wide = static_cast<uint64>(x1.v) * x2.v;
    return Z(static_cast<uint>(wide % Mod));
}
// In-place addition, defined through operator+.
inline Z &operator+=(Z &x1, const Z x2) {
    x1 = x1 + x2;
    return x1;
}
// In-place subtraction, defined through binary operator-.
inline Z &operator-=(Z &x1, const Z x2) {
    x1 = x1 - x2;
    return x1;
}
// In-place multiplication, defined through operator*.
inline Z &operator*=(Z &x1, const Z x2) {
    x1 = x1 * x2;
    return x1;
}
// Binary exponentiation: returns Base raised to Exp using the Z operators.
// Exp is consumed bit by bit from the least significant end; the loop ends
// when Exp becomes zero, so Exp is expected to be non-negative.
Z Power(Z Base, int Exp) {
    Z acc(1u);
    while (Exp) {
        if (Exp & 1)
            acc *= Base;
        Base *= Base;
        Exp >>= 1;
    }
    return acc;
}
// Length chosen by the most recent init_w call (a power of two, at least 2).
int size;
// Table written by init_w and read by the transforms. Each entry pairs a
// value w with a companion constant for the mult_Shoup* helpers; init_w
// seeds the companion as floor(w * 2^32 / Mod).
std::pair<uint, uint> w_p[Max_size];
// Multiplies x by the constant described by y_p = {w, w'} without a
// division: the quotient is estimated as the high 32 bits of x * w', and
// x * w - quotient * Mod is returned, computed in wrapping 32-bit
// arithmetic. No final reduction is applied here (mult_Shoup adds it).
inline uint mult_Shoup_2(const uint x, const std::pair<uint, uint> y_p) {
    const uint64 wide = static_cast<uint64>(x) * y_p.second;
    const uint quotient = static_cast<uint>(wide >> 32);
    return x * y_p.first - quotient * Mod;
}
// Same product as mult_Shoup_2, followed by one conditional subtraction of
// Mod (norm).
inline uint mult_Shoup(const uint x, const std::pair<uint, uint> y_p) {
    const uint lazy = mult_Shoup_2(x, y_p);
    return norm(lazy);
}
// Quotient counterpart of mult_Shoup: returns the estimated quotient of
// x * y_p.first by Mod, bumped by one when the matching remainder estimate
// is still at least Mod.
inline uint mult_Shoup_q(const uint x, const std::pair<uint, uint> y_p) {
    const uint64 wide = static_cast<uint64>(x) * y_p.second;
    const uint quotient = static_cast<uint>(wide >> 32);
    const uint remainder = x * y_p.first - quotient * Mod;
    return remainder >= Mod ? quotient + 1 : quotient;
}
void init_w(const int n) {
for (size = 2; size < n; size <<= 1)
;
uint pr = Power(g, (Mod - 1) / size).v;
std::pair<uint, uint> pr_p = {pr, (static_cast<uint64>(pr) << 32) / Mod};
uint pr_r = (static_cast<uint64>(pr) << 32) % Mod;
size >>= 1;
w_p[size] = {1, (1ULL << 32) / Mod};
#define compute(r_p, b_p)\
do\
{\
uint x = b_p.first;\
uint64 p = static_cast<uint64>(x) * pr_p.second;\
uint q = p >> 32;\
r_p = {norm(x * pr_p.first - q * Mod), static_cast<uint>(p) + mult_Shoup_q(pr_r, b_p)};\
} while (0)
if (size <= 4) {
for (int i = 1; i != size; ++i)
compute(w_p[size + i], w_p[size + i - 1]);
} else {
for (int i = 1; i != 8; ++i)
compute(w_p[size + i], w_p[size + i - 1]);
pr_p = w_p[size + 4], pr_r = -pr_p.second * Mod;
for (int i = 8; i != size; i += 4) {
compute(w_p[size + i + 0], w_p[size + i - 4]);
compute(w_p[size + i + 1], w_p[size + i - 3]);
compute(w_p[size + i + 2], w_p[size + i - 2]);
compute(w_p[size + i + 3], w_p[size + i - 1]);
}
}
for (int i = size - 1; i; --i)
w_p[i] = w_p[i * 2];
size <<= 1;
}
// In-place forward transform of _A[0 .. L) using the table w_p.
//
// Outputs are only reduced lazily (norm_2 / mult_Shoup_2); DFT_fr applies
// the final norm. L is expected to be a power of two: the stage loop halves
// its width until it reaches exactly 4. w_p is read at indices below L, so
// init_w must have filled it beforehand.
// NOTE(review): the output ordering appears to be the one IDFT_fr consumes
// (no explicit reordering happens here) - confirm before using the raw
// spectrum elsewhere.
void DFT_fr_2(Z _A[], const int L) {
    if (L == 1)
        return;
    uint *const data = reinterpret_cast<uint *>(_A);

    // (a, b) <- (a + b, a - b), each passed through norm_2.
    const auto plain = [](uint &a, uint &b) {
        const uint lo = a, hi = b;
        a = norm_2(lo + hi);
        b = norm_2(lo + Mod * 2 - hi);
    };
    if (L == 2) {
        plain(data[0], data[1]);
        return;
    }

    // (a, b) <- (a + b, (a - b) * w); the difference is multiplied via
    // mult_Shoup_2.
    const auto twisted = [](uint &a, uint &b, const std::pair<uint, uint> &w) {
        const uint lo = a, hi = b;
        a = norm_2(lo + hi);
        b = mult_Shoup_2(lo + Mod * 2 - hi, w);
    };
    if (L == 4) {
        plain(data[0], data[2]);
        twisted(data[1], data[3], w_p[3]);
        plain(data[0], data[1]);
        plain(data[2], data[3]);
        return;
    }

    // Wide stages (half-width above 4), four butterflies per iteration.
    for (int half = L >> 1; half != 4; half >>= 1) {
        const std::pair<uint, uint> *const tw = w_p + half;
        for (int start = 0; start != L; start += half << 1) {
            uint *const left = data + start;
            uint *const right = left + half;
            for (int j = 0; j != half; j += 4) {
                twisted(left[j + 0], right[j + 0], tw[j + 0]);
                twisted(left[j + 1], right[j + 1], tw[j + 1]);
                twisted(left[j + 2], right[j + 2], tw[j + 2]);
                twisted(left[j + 3], right[j + 3], tw[j + 3]);
            }
        }
    }

    // Last three stages (half-widths 4, 2, 1) fused per block of eight.
    for (int start = 0; start != L; start += 8) {
        uint *const blk = data + start;
        plain(blk[0], blk[4]);
        twisted(blk[1], blk[5], w_p[5]);
        twisted(blk[2], blk[6], w_p[6]);
        twisted(blk[3], blk[7], w_p[7]);
        plain(blk[0], blk[2]);
        twisted(blk[1], blk[3], w_p[3]);
        plain(blk[4], blk[6]);
        twisted(blk[5], blk[7], w_p[3]);
        plain(blk[0], blk[1]);
        plain(blk[2], blk[3]);
        plain(blk[4], blk[5]);
        plain(blk[6], blk[7]);
    }
}
// Forward transform with a final reduction: runs DFT_fr_2 and then passes
// every one of the L outputs through norm.
void DFT_fr(Z _A[], const int L) {
    DFT_fr_2(_A, L);
    for (int i = 0; i != L; ++i) {
        Z &cell = _A[i];
        cell.v = norm(cell.v);
    }
}
// In-place inverse counterpart of DFT_fr_2 on _A[0 .. L), including the
// division by L; every output is passed through norm at the end.
//
// L is expected to be a power of two (the stage loop doubles its width from
// 8 until it equals L). w_p is read at indices below L, so init_w must have
// filled it beforehand.
void IDFT_fr(Z _A[], const int L) {
    if (L == 1)
        return;
    uint *const data = reinterpret_cast<uint *>(_A);

    // (a, b) <- (a + b, a - b) after norm_2 on both inputs; the outputs
    // themselves are not reduced here.
    const auto plain = [](uint &a, uint &b) {
        const uint lo = norm_2(a);
        const uint hi = norm_2(b);
        a = lo + hi;
        b = lo + Mod * 2 - hi;
    };
    if (L == 2) {
        plain(data[0], data[1]);
        // Halve: reduce fully, make the value even by adding Mod if needed,
        // then divide by two.
        for (int i = 0; i != 2; ++i) {
            uint v = norm(norm_2(data[i]));
            if (v & 1)
                v += Mod;
            data[i] = v / 2;
        }
        return;
    }

    // (a, b) <- (a + b * w, a - b * w) with the product from mult_Shoup_2.
    const auto twisted = [](uint &a, uint &b, const std::pair<uint, uint> &w) {
        const uint lo = norm_2(a);
        const uint hi = mult_Shoup_2(b, w);
        a = lo + hi;
        b = lo + Mod * 2 - hi;
    };
    // Divides v by 2^shift: first adds the multiple of Mod selected by
    // (-v & mask), which clears the low bits when Mod is 1 modulo 2^shift,
    // then shifts and applies norm.
    const auto rescale = [](uint &v, const uint mask, const int shift) {
        const uint64 lift = -v & mask;
        v = norm(static_cast<uint>((v + lift * Mod) >> shift));
    };
    if (L == 4) {
        plain(data[0], data[1]);
        plain(data[2], data[3]);
        plain(data[0], data[2]);
        twisted(data[1], data[3], w_p[3]);
        std::swap(data[1], data[3]);
        for (int i = 0; i != L; ++i)
            rescale(data[i], 3, 2);
        return;
    }

    // First three stages (half-widths 1, 2, 4) fused per block of eight.
    for (int start = 0; start != L; start += 8) {
        uint *const blk = data + start;
        plain(blk[0], blk[1]);
        plain(blk[2], blk[3]);
        plain(blk[4], blk[5]);
        plain(blk[6], blk[7]);
        plain(blk[0], blk[2]);
        twisted(blk[1], blk[3], w_p[3]);
        plain(blk[4], blk[6]);
        twisted(blk[5], blk[7], w_p[3]);
        plain(blk[0], blk[4]);
        twisted(blk[1], blk[5], w_p[5]);
        twisted(blk[2], blk[6], w_p[6]);
        twisted(blk[3], blk[7], w_p[7]);
    }

    // Remaining stages, four butterflies per iteration.
    for (int half = 8; half != L; half <<= 1) {
        const std::pair<uint, uint> *const tw = w_p + half;
        for (int start = 0; start != L; start += half << 1) {
            uint *const left = data + start;
            uint *const right = left + half;
            for (int j = 0; j != half; j += 4) {
                twisted(left[j + 0], right[j + 0], tw[j + 0]);
                twisted(left[j + 1], right[j + 1], tw[j + 1]);
                twisted(left[j + 2], right[j + 2], tw[j + 2]);
                twisted(left[j + 3], right[j + 3], tw[j + 3]);
            }
        }
    }

    // Reverse everything except slot 0, then divide every entry by L.
    std::reverse(data + 1, data + L);
    const int log_len = __builtin_ctz(L);
    for (int i = 0; i != L; ++i)
        rescale(data[i], L - 1, log_len);
}
// Global scratch buffers used by poly_multiply: the operands are copied
// here, transformed in place, and the product is read back from A.
Z A[Max_size], B[Max_size];
void poly_multiply(uint _A[], int N, uint _B[], int M, uint _C[]) {
memcpy(reinterpret_cast<uint *>(A), _A, (N + 1) * sizeof(uint));
memcpy(reinterpret_cast<uint *>(B), _B, (M + 1) * sizeof(uint));
int L;
for (L = 2; L <= N + M; L <<= 1)
;
init_w(L);
DFT_fr_2(A, L), DFT_fr_2(B, L);
for (int i = 0; i != L; ++i)
A[i] *= B[i];
IDFT_fr(A, L);
memcpy(_C, reinterpret_cast<uint *>(A), (N + M + 1) * sizeof(uint));
}
// Online-judge result log pasted after the program (not part of the code):
// Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// Testcase #1 | 112.117 ms | 39 MB + 656 KB | Accepted | Score: 100 | Show more |