#include<bits/stdc++.h>
using namespace std;
// Shorthand macros and type aliases used throughout the rest of this file.
// NOTE: `int` is redefined to `long long` from here on, so every later `int`
// (including the loop variables introduced by rep/per and the element types of
// vi/pi) is a 64-bit signed integer.
#define int long long
// Inclusive ascending loop: i = j, j+1, ..., k.
#define rep(i,j,k) for(int i=(j);i<=(k);i++)
// Inclusive descending loop: i = j, j-1, ..., k.
#define per(i,j,k) for(int i=(j);i>=(k);i--)
#define mp make_pair
#define pb push_back
#define fi first
#define se second
// Because of the `int` macro above these are vector<long long> and
// pair<long long, long long>.
typedef vector<int> vi;
typedef pair<int,int> pi;
// Number-theoretic transform modulo 998244353 (= 119 * 2^23 + 1, primitive
// root 3) and polynomial multiplication built on top of it.
//
// The namespace is self-contained: it spells out `long long` and `std::`
// instead of relying on file-level shorthand macros. Under the file's
// `#define int long long` its types are exactly the ones the rest of the file
// expects (std::vector<long long> == vi).
//
// Usage: ntt::conv(a, b) returns the coefficients of a*b modulo `mod`.
// ntt::init() builds the twiddle table; conv() calls it lazily, and repeated
// calls to init() are cheap no-ops.
namespace ntt {

const long long LG = 21;             // largest supported transform has 2^LG points
const long long S = 1LL << LG;       // twiddle-table size
const long long mod = 998244353;
const long long i2 = (mod + 1) / 2;  // inverse of 2 modulo mod

// Reduces x from [0, 2*mod) into [0, mod) with one conditional subtraction.
inline long long md(long long x) {
    return x >= mod ? x - mod : x;
}

// Maps any 64-bit value (negative or >= mod) into [0, mod).
inline long long norm_mod(long long x) {
    x %= mod;
    return x < 0 ? x + mod : x;
}

// Returns x^y modulo `mod` by binary exponentiation (y >= 0). With the default
// exponent mod-2 this is the modular inverse of x (Fermat), for x != 0 mod p.
long long qpow(long long x, long long y = mod - 2) {
    x = norm_mod(x);  // keeps x*x below 2^63 even for an unreduced base
    long long res = 1;
    while (y) {
        if (y % 2) res = res * x % mod;
        x = x * x % mod;
        y /= 2;
    }
    return res;
}

// Twiddle table: for each level i with half-length len = 2^i,
// P[len + k] = w^k for k in [0, len), where w = 3^((mod-1) >> (i+1)).
// P[0] is never used.
long long P[S];

// Fills P. Idempotent: only the first call does work.
void init() {
    static bool ready = false;
    if (ready) return;
    for (long long i = 0; i < LG; ++i) {
        const long long stp = qpow(3, (mod - 1) >> (i + 1));
        P[1LL << i] = 1;
        for (long long j = (1LL << i) + 1; j < (2LL << i); ++j) {
            P[j] = P[j - 1] * stp % mod;
        }
    }
    ready = true;
}

// Forward transform (decimation in frequency) on a[0 .. 2^lg), in place.
// Entries must already lie in [0, mod). The output is in the permuted order
// that DIT below expects; it is only meant to be consumed by DIT.
void DIF(std::vector<long long> &a, long long lg) {
    const long long U = 1LL << lg;
    assert(lg >= 0 && lg <= LG);
    assert(a.size() >= static_cast<std::size_t>(U));
    for (long long i = lg - 1; i >= 0; --i) {
        const long long len = 1LL << i;
        for (long long j = 0; j < U; j += len * 2) {
            long long idx = len;  // twiddles for this level start at P[len]
            for (long long k = j; k < j + len; ++k) {
                const long long A = a[k];
                const long long B = a[k + len];
                a[k] = md(A + B);
                a[k + len] = (A - B + mod) * P[idx++] % mod;
            }
        }
    }
}

// Inverse of DIF: consumes DIF-ordered values, applies the butterflies with
// the same forward twiddles, scales by 2^-lg, and reverses a[1..] (which turns
// the forward-root transform into the inverse one). Requires a.size() == 2^lg.
void DIT(std::vector<long long> &a, long long lg) {
    const long long U = 1LL << lg;
    assert(lg >= 0 && lg <= LG);
    assert(a.size() == static_cast<std::size_t>(U));
    for (long long i = 0; i < lg; ++i) {
        const long long len = 1LL << i;
        for (long long j = 0; j < U; j += len * 2) {
            long long idx = len;
            for (long long k = j; k < j + len; ++k) {
                const long long A = a[k];
                const long long B = a[k + len] * P[idx++] % mod;
                a[k] = md(A + B);
                a[k + len] = md(A - B + mod);
            }
        }
    }
    const long long iv = qpow(i2, lg);  // 2^-lg
    for (long long &x : a) x = x * iv % mod;
    if (a.size() > 1) std::reverse(a.begin() + 1, a.end());
}

// Returns the coefficients of a*b modulo `mod`; the result has
// a.size() + b.size() - 1 entries. Inputs may hold any 64-bit values, which
// are reduced into [0, mod) first. Returns an empty vector if either input is
// empty. The padded length must not exceed 2^LG.
std::vector<long long> conv(std::vector<long long> a, std::vector<long long> b) {
    if (a.empty() || b.empty()) return {};
    init();
    const long long len = static_cast<long long>(a.size() + b.size() - 1);
    // Smallest lg with 2^lg >= len; lg == 0 when len == 1 (the original
    // __lg(len - 1) was undefined for that case).
    long long lg = 0;
    while ((1LL << lg) < len) ++lg;
    assert(lg <= LG && "product too long for the twiddle table");
    for (long long &x : a) x = norm_mod(x);
    for (long long &x : b) x = norm_mod(x);
    const std::size_t U = static_cast<std::size_t>(1LL << lg);
    a.resize(U);
    b.resize(U);
    DIF(a, lg);
    DIF(b, lg);
    for (std::size_t i = 0; i < U; ++i) a[i] = a[i] * b[i] % mod;
    DIT(a, lg);
    a.resize(static_cast<std::size_t>(len));
    return a;
}

}  // namespace ntt
using ntt::conv;
void poly_multiply(unsigned *a, signed n, unsigned *b, signed m, unsigned *c){
ntt::init();
vi A(n+1), B(m+1), C;
rep(i,0,n) A[i] = a[i];
rep(i,0,m) B[i] = b[i];
C = conv(A, B);
rep(i,0,n+m) c[i] = C[i];
}
Compilation | N/A | N/A | Compile OK | Score: N/A | 显示更多 |
Testcase #1 | 173.41 ms | 86 MB + 188 KB | Accepted | Score: 100 | 显示更多 |