#include <algorithm>
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>
//Fixed-width integer aliases used throughout the file.
using i32 = int;
using u32 = unsigned int;
using i64 = long long;
using u64 = unsigned long long;
//Montgomery arithmetic space
namespace Montgo{
//32-bit Montgomery reducer with R = 2^32.
//Members: Mod is the odd modulus, Mod2 = 2 * Mod, Inv = Mod^-1 mod 2^32, Neg_Inv = -Inv mod 2^32,
//R2 = R^2 mod Mod, R3 = Montgomery product of R2 with itself (congruent to R^3 mod Mod).
struct Mont32{
    u32 Mod, Mod2, Inv, Neg_Inv, R2, R3;
    //Montgomery reduction (REDC): returns a value congruent to x * 2^-32 (mod Mod).
    //Assuming x + m * Mod does not overflow 64 bits, the result lies in [x / 2^32, x / 2^32 + Mod).
    constexpr u32 reduce (u64 x)const{return (x + u64(u32(x) * Neg_Inv) * Mod) >> 32;}
    //Montgomery product: congruent to x * y * 2^-32 (mod Mod).
    constexpr u32 mul(u32 x,u32 y)const{return reduce(u64(x) * y);}
    //n must be odd (so that it is coprime to 2^32) and should be below 2^30 to prevent overflow.
    constexpr Mont32(u32 n):Mod(n), Mod2(n << 1), Inv(n), Neg_Inv(), R2((-u64(n)) % n), R3(){
        //Newton iteration for n^-1 mod 2^32: for odd n, n * n == 1 (mod 8), so Inv = n is correct
        //to 3 bits and every step doubles the number of correct bits.
        for (int i = 0; i < 5; ++i){Inv *= 2 - n * Inv;}
        Neg_Inv = -Inv, R3 = mul(R2, R2);
    }
    //Enter Montgomery form: congruent to x * 2^32 (mod Mod).
    constexpr u32 In(u32 x)const{return mul(x, R2);}
    //Enter Montgomery form twice: congruent to x * 2^64 (mod Mod).
    constexpr u32 In_In(u32 x)const{return mul(x, R3);}
    //Leave Montgomery form: the canonical residue in [0, Mod) congruent to x * 2^-32.
    //For a 32-bit x, reduce(x) lies in [0, Mod], so one conditional subtraction fully reduces it.
    //(Computing only (u64(x * Neg_Inv) * Mod) >> 32 loses the carry produced by adding x in REDC,
    //which makes the result one too small for every nonzero x.)
    constexpr u32 Out(u32 x)const{
        const u32 r = reduce(x);
        return r >= Mod ? r - Mod : r;
    }
};
}
//The field Z (a commutative division ring): the integers modulo the prime `mod`.
//Elements of Z are stored in Montgomery form and kept in the lazy range [0, mod2).
namespace field_Z{
constexpr u32 mod = 998244353;
//mod * 2
constexpr u32 mod2 = mod * 2;
//Montgomery context for `mod`
constexpr Montgo::Mont32 Space(mod);
using Z = u32;
//Enter the field Z: maps an ordinary integer to its Montgomery form (result in [0, mod2)).
constexpr Z InZ(u32 x){return Space.In(x);}
//Leave the field Z: returns the canonical residue in [0, mod) represented by x.
//One Montgomery reduction of a 32-bit word yields a value in [0, mod] congruent to x * 2^-32;
//the conditional subtraction folds `mod` back to 0. This holds for any 32-bit x, including
//the redundant range [mod, mod2).
constexpr u32 OutZ(Z x){
    const u32 r = Space.reduce(x);
    return r >= mod ? r - mod : r;
}
//Zero modulo mod
constexpr Z zero_Z(0);
//One modulo mod (Montgomery form)
constexpr Z one_Z(Space.In(1u));
//Marker for "does not exist modulo mod". Note: no further arithmetic should be performed on it.
constexpr Z not_exist_Z = -1;
//True when x lies in the lazy range [0, mod2) used for elements of Z.
constexpr bool isgood(Z x){return x < mod2;}
//Basic arithmetic on Z
namespace calc{
    //Maps [0, 2 * mod) to [0, mod).
    constexpr u32 shrink(u32 x){return x >= mod ? x - mod : x;}
    //Maps [0, 2 * mod2) to [0, mod2).
    constexpr u32 shrink2(u32 x){return x >= mod2 ? x - mod2 : x;}
    //Undoes an underflow: if x wrapped "negative" (top bit set), adds mod back.
    constexpr u32 dilate(u32 x){return x >> 31 ? x + mod : x;}
    //Same as dilate, but adds mod2.
    constexpr u32 dilate2(u32 x){return x >> 31 ? x + mod2 : x;}
    //Montgomery product; for x, y in [0, mod2) the result is in [0, 2 * mod).
    constexpr Z mulZ(Z x, Z y){return Space.mul(x, y);}
    //a^b by binary exponentiation.
    constexpr Z powZ(Z a,u32 b){
        Z r(one_Z);
        while(b){
            if(b & 1){r = mulZ(r, a);}
            a = mulZ(a, a), b >>= 1;
        }
        return r;
    }
    //x + y, result in [0, mod2) for x, y in [0, mod2).
    constexpr Z addZ(Z x, Z y){return dilate2(x + y - mod2);}
    //x - y, result in [0, mod2) for x, y in [0, mod2).
    constexpr Z subZ(Z x, Z y){return dilate2(x - y);}
    //Montgomery product fully reduced to [0, mod) (requires mulZ(x, y) < 2 * mod).
    constexpr Z mulZ_strict(Z x, Z y){return dilate(mulZ(x, y) - mod);}
    //Value congruent to x * 2^64 (mod mod): the Montgomery form of the Montgomery form of x.
    constexpr Z In_InZ(u32 x){return Space.In_In(x);}
    //Multiplicative inverse via Fermat's little theorem (x^(mod - 2)); x must be nonzero modulo mod.
    constexpr Z invZ(Z x){return powZ(x, mod - 2);}
}
}
//Polynomial core
namespace poly{
//Polynomial core::basic support for polynomials
namespace poly_base{
//Basic support::bring in the field that coefficients live in -- Z
using namespace field_Z;
//Smallest power of two that is >= x; returns 1 for x <= 1.
//Written as a loop so that x == 1 is well-defined (a std::__lg(x - 1) formulation evaluates
//lg(0), which is undefined) and so that no compiler-internal helper is needed.
//Precondition: x <= 2^30, otherwise the result does not fit in int.
inline constexpr int bit_ceil(int x){
    int p = 1;
    while(p < x){p <<= 1;}
    return p;
}
//Basic support::groundwork for the fast number-theoretic transform
namespace fast_number_theoretic_transform_base{
    //Bring in arithmetic on Z
    using namespace field_Z::calc;
    //mod = 2 ^ 23 * 7 * 17 + 1, so power-of-two transforms up to length 2^mp2 are supported
    constexpr int mp2(23);
    //The primitive root is 3 (stored in Montgomery form)
    constexpr Z _g(InZ(3));
    //Root-of-unity table: t[mp2] = G^((mod - 1) >> mp2) and each lower entry is the square of the
    //next, so t[i] is a primitive 2^i-th root of unity when G is a primitive root.
    struct P_R_Tab{
        Z t[mp2 + 1];
        constexpr P_R_Tab(Z G):t(){
            t[mp2] = powZ(G, (mod - 1) >> mp2);
            for(int i = mp2 - 1; ~i; --i){t[i] = mulZ(t[i+1], t[i+1]);}
        }
        constexpr Z operator [] (int i) const {return t[i];}
    };
    //Roots built from g (forward transform) and from g^-1 (inverse transform)
    constexpr P_R_Tab __g(_g),__g_I(invZ(_g));
    //Current capacity of the twiddle tables
    int size_W(0);
    //Wn: forward twiddles; Wn_I: inverse twiddles. Both live in one allocation of 2 * size_W
    //elements (Wn_I == Wn + size_W). For each power of two i, entries [i, 2i) hold w^0 .. w^(i-1)
    //for a primitive 2i-th root of unity w.
    Z *Wn(nullptr), *Wn_I(nullptr);
    //Precompute twiddle factors for transforms of length lim. You do not need to (and should not)
    //call this: NTT/INTT call it automatically. The tables only grow; lim must be a power of two
    //no larger than 2^mp2.
    void _ntt_init(int lim){
        if(lim > size_W){
            if(Wn != nullptr){
                delete[] Wn;
            }
            size_W = lim, Wn = new Z[2 * lim], Wn_I = Wn + lim;
            Wn[0] = Wn[1] = Wn_I[0] = Wn_I[1] = one_Z;
            for(int i = 2, R = 2, i2 = 4; i < lim; i <<= 1, ++R, i2 <<= 1){
                Z g_w(__g[R]), g_w_I(__g_I[R]);
                for(int k = i; k < i2; k += 2){
                    //Both tables are fully reduced into [0, mod). The butterflies multiply a twiddle
                    //by an operand that can reach 4 * mod (NTT: x + (mod2 - a); INTT: the unreduced
                    //sum/difference stages). With a twiddle below mod the Montgomery product stays
                    //below 2 * mod = mod2, which the next stage relies on; mulZ alone only bounds
                    //the twiddle by 2 * mod, for which that guarantee does not hold.
                    Wn[k + 1] = mulZ_strict(Wn[k] = Wn[k >> 1], g_w);
                    Wn_I[k + 1] = mulZ_strict(Wn_I[k] = Wn_I[k >> 1], g_w_I);
                }
            }
        }
    }
}
}
using namespace poly_base;
}
namespace poly{
//Polynomial core::NTT (decimation-in-frequency) and INTT (decimation-in-time), built on the transposition principle
namespace fast_number_theoretic_transform_core{
//Bring in the NTT groundwork (twiddle tables, arithmetic on Z)
using namespace fast_number_theoretic_transform_base;
//Fast number-theoretic transform (DIF), in place on A[0, lim).
//lim must be a power of two and at least 4: the last two stages are hand-unrolled for blocks of 4.
//Element values are expected in [0, mod2) (the butterfly forms mod2 - a). No reordering pass is
//performed, so the output is left in bit-reversed order, which INTT below consumes directly.
void NTT(Z* A, int lim){
_ntt_init(lim);
//FLY(o): one DIF butterfly on top = A[j+k+o], bottom = a[j+k+o] (a == A + i), twiddle wn[k+o]:
//  top    <- top + bottom, folded into [0, mod2) by dilate2
//  bottom <- (top - bottom) * twiddle, computed as top + (mod2 - bottom) so it never underflows
#define FLY(o) {Z x(A[j + k + o]), y(mod2 - a[j + k + o]);a[j + k + o] = mulZ(x + y, wn[k + o]), A[j + k + o] = dilate2(x - y);}
//Stages with half-length i >= 4: blocks of length R = 2i, butterflies unrolled 4 at a time.
for(int i(lim >> 1), R(lim); i >= 4; i >>= 1, R >>= 1){
Z *wn(Wn + i), *a(A + i);
for(int j = 0; j < lim; j += R){
for(int k = 0; k < i; k+=4){
FLY(0)FLY(1)FLY(2)FLY(3)
}
}
}
//i == 2: blocks of 4, two butterflies each (k is fixed at 0 for the macro).
{
Z *wn(Wn + 2), *a = A + 2;
constexpr int k = 0;
for(int j = 0; j < lim; j += 4){
FLY(0)FLY(1)
}
}
#undef FLY
//i == 1: the twiddle is 1, so each butterfly is a plain sum/difference folded into [0, mod2).
{
for(int j = 0; j < lim; j += 4){
{
Z x = A[j + 0], y = A[j + 1];
A[j + 1] = dilate2(x - y), A[j + 0] = shrink2(x + y);
}
{
Z x = A[j + 2], y = A[j + 3];
A[j + 3] = dilate2(x - y), A[j + 2] = shrink2(x + y);
}
}
}
}
//Inverse fast number-theoretic transform (DIT), in place on A[0, lim); takes NTT's bit-reversed order.
//lim must be a power of two and at least 4. The result is scaled by 1/lim.
//fixes selects a low-cost correction (*= R):
//  fixes == false: multiplies by InZ(1/lim) with mulZ, so Montgomery-form input stays in Montgomery
//                  form (values in the lazy range);
//  fixes == true:  multiplies by In_InZ(1/lim) with mulZ_strict, which contributes an extra factor of
//                  R = 2^32 and yields values in [0, mod).
//(mod - (mod - 1) / lim is 1/lim modulo mod because (mod - 1) / lim * lim == mod - 1 == -1.)
template<bool fixes = false>void INTT(Z* A, int lim){
_ntt_init(lim);
//i == 1: twiddle 1; sums/differences are left unreduced (below 2 * mod2).
{
for(int j = 0; j < lim; j += 4){
{
Z x(A[j + 0]), y(A[j + 1]);
A[j + 0] = x + y, A[j + 1] = x - y + mod2;
}
{
Z x(A[j + 2]), y(A[j + 3]);
A[j + 2] = x + y, A[j + 3] = x - y + mod2;
}
}
}
//FLY(o): one DIT butterfly. The top operand is first folded from [0, 2 * mod2) into [0, mod2), the
//bottom is multiplied by the inverse twiddle, then:
//  bottom <- top - bottom * twiddle (+ mod2 to stay non-negative)
//  top    <- top + bottom * twiddle
//Both results are left unreduced (below 2 * mod2).
#define FLY(o) {Z x(dilate2(A[j + k + o] - mod2)), y(mulZ(a[j + k + o], wn[k + o]));a[j + k + o] = x - y + mod2, A[j + k + o] = x + y;}
//i == 2: blocks of 4, two butterflies each (k is fixed at 0 for the macro).
{
Z *wn(Wn_I + 2), *a = A + 2;
constexpr int k = 0;
for(int j = 0;j < lim; j += 4){
FLY(0)FLY(1)
}
}
//Stages with half-length i >= 4: blocks of length R = 2i, butterflies unrolled 4 at a time.
for(int i(4), R(8); i < lim; i <<= 1, R <<= 1){
Z *wn(Wn_I + i), *a(A + i);
for(int j = 0; j < lim; j += R){
for(int k = 0; k < i; k += 4){
FLY(0)FLY(1)FLY(2)FLY(3)
}
}
}
#undef FLY
//Final scaling by 1/lim (see the fixes description above).
if constexpr (fixes){
Z invt(shrink(In_InZ(mod - ((mod - 1) / lim))));
for(int i = 0; i < lim; ++i){
A[i] = mulZ_strict(A[i], invt);
}
}
else{
Z invt(shrink(InZ(mod - ((mod - 1) / lim))));
for(int i = 0; i < lim; ++i){
A[i] = mulZ(A[i], invt);
}
}
}
}
using fast_number_theoretic_transform_core::NTT;
using fast_number_theoretic_transform_core::INTT;
//Pointwise product: A[i] = mulZ(A[i], B[i]) for i in [0, n) (Montgomery product). B is not modified.
void dot(Z* A, int n, Z* B){
for(int i = 0; i < n; ++i){
A[i] = calc::mulZ(A[i], B[i]);
}
}
//Cyclic convolution of length lim, written to A. Both A and B are transformed in place (B is left in
//NTT form). fixes is forwarded to INTT and selects the (*= R) correction.
template<bool fixes = false>void Conv(Z* A, int lim, Z* B){
NTT(A, lim), NTT(B, lim), dot(A, lim, B), INTT<fixes>(A, lim);
}
//Polynomial multiplication: A (degree n) times B (degree m); the product's n + m + 1 coefficients end
//up in A[0, n + m]. Both buffers must hold at least max(bit_ceil(n + m + 1), 4) elements. With
//clr == true the entries past each input are zeroed first; otherwise the caller must have cleared them.
//B is overwritten. Because Conv<true> is used, plain (non-Montgomery) coefficients in give plain
//coefficients in [0, mod) out.
//NOTE(review): for n == m == 0 this evaluates bit_ceil(1); confirm bit_ceil is well-defined there.
template<bool clr = true>void Mul(Z* A, int n, Z* B, int m){
int lim(std::max<int>(bit_ceil(n + m + 1), 4));
if constexpr(clr){
std::fill(A + n + 1, A + lim, zero_Z), std::fill(B + m + 1, B + lim, zero_Z);
}
Conv<true>(A, lim, B);
}
}
#include <chrono>
//Simple stopwatch: construction or start() records a time point, and stop() prints the milliseconds
//elapsed since then to a stream (std::clog by default). stop() does not reset the timer.
//NOTE(review): std::chrono::system_clock can jump (NTP/manual adjustments); std::chrono::steady_clock
//is the usual choice for measuring intervals, but `lst` is a public member, so its type is kept.
struct Timer{
    //Label printed in the report; the string constructor stores it with a trailing space.
    std::string str;
    //Time point of the last construction/start().
    std::chrono::system_clock::time_point lst;
    Timer():str(), lst(std::chrono::system_clock::now()){
    }
    Timer(const std::string &s):str(s + ' '), lst(std::chrono::system_clock::now()){
    }
    //Restarts the measurement from now.
    void start(){
        lst = std::chrono::system_clock::now();
    }
    //Prints the elapsed time since the last construction/start() to outf, in milliseconds with
    //six decimal places.
    void stop(std::ostream& outf = std::clog){
        const std::chrono::duration<long double, std::milli> tott(std::chrono::system_clock::now() - lst);
        outf << "\nThe timer " << str << "stopped.\n";
        char bbuf[20];
        std::snprintf(bbuf, sizeof bbuf, "%.6Lf", tott.count());
        outf << "It passed by " << std::string(bbuf) << "ms until stop." << std::endl;
    }
};
void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c)
{
int lim = poly::bit_ceil(n + m + 1);
u32 *f = new u32[lim];
u32 *g = new u32[lim];
std::copy(a, a + n + 1, f), std::copy(b, b + m + 1, g);
poly::Mul(f,n,g,m);
std::copy(f,f + n + m + 1,c);
delete []f;
delete []g;
}
// Judge report (appears to be pasted from an online judge's submission page):
// Compilation | N/A | N/A | Compile OK | Score: N/A | 显示更多 |
// Testcase #1 | 100.668 ms | 39 MB + 680 KB | Accepted | Score: 100 | 显示更多 |