#include <algorithm>
#include <array>
#include <cassert>
#include <climits>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>
#include <immintrin.h>
#pragma GCC optimize("inline")
#pragma GCC target("avx2")
namespace hint
{
using Float32 = float;
using Float64 = double;
using Complex32 = std::complex<Float32>;
using Complex64 = std::complex<Float64>;
// Working-set budget in bytes. NTT32AVX derives L1_LEN = L1_BYTE / (2 * sizeof(uint32_t))
// from it, which both caps the transform length and sizes its twiddle tables.
// NOTE: despite the name, 1 << 24 bytes (16 MiB) is far larger than any real L1 data
// cache; tune it to your hardware if cache locality matters.
constexpr size_t L1_BYTE = size_t(1) << 24;
// Secondary budget in bytes (1 << 20 = 1 MiB), used for NTT32AVX::L2_LEN.
// NOTE: it is currently SMALLER than L1_BYTE, the reverse of a real L1/L2 hierarchy.
constexpr size_t L2_BYTE = size_t(1) << 20;
constexpr Float64 HINT_PI = 3.141592653589793238462643;
constexpr Float64 HINT_2PI = HINT_PI * 2;
// Returns a value of type T whose low `bits` bits are all set, i.e. 2^bits - 1.
// Valid for 0 <= bits <= bit width of T; bits <= 0 yields 0 instead of
// shifting by a negative amount (undefined behaviour).
template <typename T>
constexpr T all_one(int bits)
{
    if (bits <= 0)
    {
        return T(0);
    }
    // Build 2^(bits-1) and compute 2 * half - 1 as (half - 1) + half, so that
    // bits == width of T never needs the out-of-range shift T(1) << width.
    const T half = T(1) << (bits - 1);
    return half - 1 + half;
}
// Leading zeros: number of zero bits above the most significant set bit of x,
// counted over the full width of IntTy. Returns the bit width of IntTy for x == 0.
// Works for integral types of any width (signed values are inspected through their
// unsigned bit pattern). The previous version only handled types of at most 32 bits,
// so e.g. unsigned long long on LP64, or size_t on platforms where it is not the
// same type as uint64_t, silently produced garbage.
template <typename IntTy>
constexpr int hint_clz(IntTy x)
{
    using UintTy = typename std::make_unsigned<IntTy>::type;
    UintTy bits = static_cast<UintTy>(x);
    int res = int(sizeof(IntTy) * CHAR_BIT);
    // Binary search for the highest set bit: halve the window each step and
    // move into the upper half whenever it is non-empty.
    for (int shift = res / 2; shift > 0; shift /= 2)
    {
        if ((bits >> shift) != 0)
        {
            res -= shift;
            bits >>= shift;
        }
    }
    // bits is now 0 or 1; a remaining 1 is the leading set bit itself.
    return res - int(bits);
}
// Leading zeros of a 64-bit value: if the upper 32-bit half has any set bit the
// answer comes from that half alone; otherwise it is 32 plus the leading zeros of
// the lower half (so x == 0 gives 32 + hint_clz(uint32_t(0))).
constexpr int hint_clz(uint64_t x)
{
    const uint32_t upper = uint32_t(x >> 32);
    return upper != 0 ? hint_clz(upper) : 32 + hint_clz(uint32_t(x));
}
// Integer bit length: number of bits needed to represent x, i.e. the position of
// its highest set bit plus one. Returns 0 for x == 0.
template <typename IntTy>
constexpr int hint_bit_length(IntTy x)
{
    constexpr int width = sizeof(IntTy) * CHAR_BIT;
    return x == 0 ? 0 : width - hint_clz(x);
}
// Integer log2: floor(log2(x)), the index of the highest set bit of x.
// For x == 0 this evaluates to -1 (highest index minus the full width).
template <typename IntTy>
constexpr int hint_log2(IntTy x)
{
    constexpr int top_bit = int(sizeof(IntTy) * CHAR_BIT) - 1;
    return top_bit - hint_clz(x);
}
// Trailing zeros: index of the least significant set bit of x.
// Returns 32 for x == 0, following the std::countr_zero convention. The mask
// search alone would report 31, which is indistinguishable from x == 0x80000000.
constexpr int hint_ctz(uint32_t x)
{
    if (x == 0)
    {
        return 32;
    }
    x &= (0u - x); // keep only the lowest set bit
    // Each mask covers the positions whose index has a 0 in one particular bit;
    // every hit clears that bit from the running answer, which starts at 31.
    int r = 31;
    if (x & 0x0000FFFF)
    {
        r -= 16;
    }
    if (x & 0x00FF00FF)
    {
        r -= 8;
    }
    if (x & 0x0F0F0F0F)
    {
        r -= 4;
    }
    if (x & 0x33333333)
    {
        r -= 2;
    }
    if (x & 0x55555555)
    {
        r -= 1;
    }
    return r;
}
// Trailing zeros of a 64-bit value: answered by the lower 32-bit half when it has
// any set bit, otherwise 32 plus the trailing zeros of the upper half.
constexpr int hint_ctz(uint64_t x)
{
    const uint32_t lower = uint32_t(x);
    return lower != 0 ? hint_ctz(lower) : 32 + hint_ctz(uint32_t(x >> 32));
}
// Fast power: m^n by binary exponentiation (square-and-multiply), with no modular
// reduction. T needs *= and construction from 1; T1 is a non-negative integer.
template <typename T, typename T1>
constexpr T qpow(T m, T1 n)
{
    T acc = 1;
    for (; n > 0; n >>= 1)
    {
        if ((n & 1) != 0)
        {
            acc *= m; // this bit of the exponent contributes the current square
        }
        m *= m;
    }
    return acc;
}
// Fast power with mod: returns m^n mod `mod` for a non-negative base m.
// The base is reduced first, so any m is accepted, as long as (mod - 1)^2 fits in T
// (use a 64-bit T for moduli near 2^32). Returns 0 when mod == 1.
// Previously the unreduced base could overflow in m * m, and mod == 1 returned 1.
template <typename T, typename T1>
constexpr T qpow(T m, T1 n, T mod)
{
    T result = T(1) % mod;
    m %= mod;
    while (n > 0)
    {
        if ((n & 1) != 0)
        {
            result *= m;
            result %= mod;
        }
        m *= m;
        m %= mod;
        n >>= 1;
    }
    return result;
}
// Largest power of two that is not larger than n, for n > 0.
// Returns 0 for n == 0, since no power of two is <= 0 (the previous version
// returned 1 there, which is larger than n).
template <typename T>
constexpr T int_floor2(T n)
{
    if (n == 0)
    {
        return 0;
    }
    constexpr int bits = sizeof(n) * CHAR_BIT;
    // Smear the highest set bit into every lower position: n becomes 2^k - 1.
    for (int i = 1; i < bits; i *= 2)
    {
        n |= (n >> i);
    }
    return (n >> 1) + 1;
}
// Smallest power of two that is not smaller than n.
// Returns 1 for n == 0 (the previous version returned 0, which is not a power
// of two). The result wraps around to 0 when it does not fit in T.
template <typename T>
constexpr T int_ceil2(T n)
{
    if (n == 0)
    {
        return 1;
    }
    constexpr int bits = sizeof(n) * CHAR_BIT;
    n--; // so that exact powers of two map to themselves
    // Smear the highest set bit downwards: n becomes 2^k - 1.
    for (int i = 1; i < bits; i *= 2)
    {
        n |= (n >> i);
    }
    return n + 1;
}
// Half adder: returns (x + y) modulo 2^width and sets cf to the carry out.
template <typename UintTy>
constexpr UintTy add_half(UintTy x, UintTy y, bool &cf)
{
    const UintTy sum = x + y;
    cf = sum < y; // a wrapped sum is smaller than either operand
    return sum;
}
// Half subtractor: returns (x - y) modulo 2^width and sets bf to the borrow out.
template <typename UintTy>
constexpr UintTy sub_half(UintTy x, UintTy y, bool &bf)
{
    const UintTy diff = x - y;
    bf = diff > x; // a wrapped difference exceeds the minuend
    return diff;
}
// Full adder: returns (x + y + cf) modulo 2^width; cf is the carry in on entry
// and receives the carry out.
template <typename UintTy>
constexpr UintTy add_carry(UintTy x, UintTy y, bool &cf)
{
    UintTy acc = x + cf;
    const bool carry_in_overflow = acc < x; // only when x is all ones and cf was set
    acc += y;
    cf = carry_in_overflow || (acc < y);
    return acc;
}
// Full subtractor: returns (x - y - bf) modulo 2^width; bf is the borrow in on
// entry and receives the borrow out.
template <typename UintTy>
constexpr UintTy sub_borrow(UintTy x, UintTy y, bool &bf)
{
    const UintTy after_borrow_in = x - bf;
    const bool borrow_in_underflow = after_borrow_in > x; // only when x == 0 and bf was set
    const UintTy diff = after_borrow_in - y;
    bf = borrow_in_underflow || (diff > after_borrow_in);
    return diff;
}
// Extended Euclid: returns g = gcd(a, b) and sets x, y so that a * x + b * y == g.
// Iterative form of the classic recursion. It walks the same quotient sequence and
// composes the same 2x2 update matrices, so it yields the same Bezout coefficients.
template <typename IntTy>
constexpr IntTy exgcd(IntTy a, IntTy b, IntTy &x, IntTy &y)
{
    // Invariant: a == x0 * a_in + y0 * b_in and b == x1 * a_in + y1 * b_in.
    IntTy x0 = 1, x1 = 0;
    IntTy y0 = 0, y1 = 1;
    while (b != 0)
    {
        const IntTy q = a / b;
        const IntTy rem = a - q * b;
        a = b;
        b = rem;
        const IntTy next_x = x0 - q * x1;
        x0 = x1;
        x1 = next_x;
        const IntTy next_y = y0 - q * y1;
        y0 = y1;
        y1 = next_y;
    }
    x = x0;
    y = y0;
    return a;
}
// Modular inverse: returns n^-1 mod `mod` via the extended Euclidean algorithm,
// shifted into [0, mod) with a single correction step.
// Assumes gcd(n, mod) == 1; this is not checked, and otherwise the result is not
// an inverse.
template <typename IntTy>
constexpr IntTy mod_inv(IntTy n, IntTy mod)
{
    IntTy coef = 0, unused = 0;
    exgcd(IntTy(n % mod), mod, coef, unused);
    if (coef < 0)
    {
        return coef + mod;
    }
    return coef >= mod ? coef - mod : coef;
}
// Returns n^-1 mod 2^pow using Newton (Hensel) iteration x <- x * (2 - n * x),
// which doubles the number of correct low bits every step.
// Meaningful for 1 <= pow <= 64 (larger pow behaves like 64). Only odd n are
// invertible modulo 2^pow. For even n or pow <= 0 this returns 0; the previous
// version looped forever on even n.
constexpr uint64_t inv_mod2pow(uint64_t n, int pow)
{
    if (pow <= 0 || (n & 1) == 0)
    {
        return 0;
    }
    const uint64_t mask = pow >= 64 ? ~uint64_t(0) : (uint64_t(1) << pow) - 1;
    uint64_t xn = 1, t = n & mask; // t tracks n * xn mod 2^pow
    while (t != 1)
    {
        xn = (xn * (2 - t));
        t = (xn * n) & mask;
    }
    return xn & mask;
}
namespace simd
{
// Thin value wrapper around an AVX2 256-bit integer register (__m256i), viewed either
// as eight 32-bit lanes or as four 64-bit lanes. Arithmetic is lane-wise and returns a
// new Int256; only load/loadu modify the object.
class Int256
{
public:
    // All lanes zero.
    Int256() : data(_mm256_setzero_si256()) {}
    Int256(__m256i data) : data(data) {}
    // Broadcasts `data` into all eight 32-bit lanes.
    Int256(int data) : data(_mm256_set1_epi32(data)) {}
    // Lane-wise 32-bit add/sub (wrap-around).
    Int256 add32(Int256 input) const
    {
        return _mm256_add_epi32(data, input.data);
    }
    Int256 sub32(Int256 input) const
    {
        return _mm256_sub_epi32(data, input.data);
    }
    // Lane-wise 64-bit add/sub (wrap-around).
    Int256 add64(Int256 input) const
    {
        return _mm256_add_epi64(data, input.data);
    }
    Int256 sub64(Int256 input) const
    {
        return _mm256_sub_epi64(data, input.data);
    }
    // Lane-wise 32-bit min/max, unsigned (U32) or signed (I32).
    Int256 minU32(Int256 input) const
    {
        return _mm256_min_epu32(data, input.data);
    }
    Int256 maxU32(Int256 input) const
    {
        return _mm256_max_epu32(data, input.data);
    }
    Int256 minI32(Int256 input) const
    {
        return _mm256_min_epi32(data, input.data);
    }
    Int256 maxI32(Int256 input) const
    {
        return _mm256_max_epi32(data, input.data);
    }
    // Multiplies the low unsigned 32 bits of each 64-bit lane (32-bit lanes 0, 2, 4, 6),
    // producing four full 64-bit products.
    Int256 mullo32To64(Int256 input) const
    {
        return _mm256_mul_epu32(data, input.data);
    }
    // Keeps 32-bit lanes 0, 2, 4, 6 and zeroes the odd lanes, i.e. zero-extends the
    // even lanes into the four 64-bit lanes.
    Int256 evenEle32() const
    {
        return blend32<0b10101010>(*this, Int256{});
    }
    // Logical shift of every 64-bit lane by N bits.
    template <int N>
    Int256 lShift64() const
    {
        return _mm256_slli_epi64(data, N);
    }
    template <int N>
    Int256 rShift64() const
    {
        return _mm256_srli_epi64(data, N);
    }
    // 32-bit lane i of the result comes from b when bit i of M is set, else from a.
    template <int M>
    static Int256 blend32(Int256 a, Int256 b)
    {
        return _mm256_blend_epi32(a.data, b.data, M);
    }
    // Unaligned / 32-byte-aligned loads and stores of all 256 bits.
    template <typename T>
    void loadu(const T *p)
    {
        data = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(p));
    }
    template <typename T>
    void load(const T *p)
    {
        data = _mm256_load_si256(reinterpret_cast<const __m256i *>(p));
    }
    template <typename T>
    void storeu(T *p) const
    {
        _mm256_storeu_si256(reinterpret_cast<__m256i *>(p), data);
    }
    template <typename T>
    void store(T *p) const
    {
        _mm256_store_si256(reinterpret_cast<__m256i *>(p), data);
    }
    operator __m256i() const
    {
        return data;
    }
    // Value of 32-bit lane i (0 <= i < 8). Goes through memory so that i may be a
    // runtime value: _mm256_extract_epi32 needs a compile-time immediate index and
    // fails to compile (e.g. at -O0) when given a variable.
    uint32_t nthU32(size_t i) const
    {
        assert(i < 8);
        alignas(32) uint32_t lanes[8];
        store(lanes);
        return lanes[i];
    }
    // Value of 64-bit lane i (0 <= i < 4); same reasoning as nthU32. This also avoids
    // _mm256_extract_epi64, which is unavailable on 32-bit targets.
    uint64_t nthU64(size_t i) const
    {
        assert(i < 4);
        alignas(32) uint64_t lanes[4];
        store(lanes);
        return lanes[i];
    }
    // Debug printing of all lanes as "[a,b,...]" followed by a newline.
    void printU32() const
    {
        std::cout << "[" << nthU32(0) << "," << nthU32(1)
                  << "," << nthU32(2) << "," << nthU32(3)
                  << "," << nthU32(4) << "," << nthU32(5)
                  << "," << nthU32(6) << "," << nthU32(7) << "]" << std::endl;
    }
    void printU64() const
    {
        std::cout << "[" << nthU64(0) << "," << nthU64(1)
                  << "," << nthU64(2) << "," << nthU64(3) << "]" << std::endl;
    }
private:
    __m256i data;
};
}
namespace modint
{
// Montgomery arithmetic modulo an odd modulus `mod`, with R = 2^32, in scalar
// (uint32_t) and 8-lane AVX2 (Ui32X8) flavours.
// Many operations are lazy: their results lie in [0, 2*mod) or [0, 4*mod) rather
// than [0, mod). norm / norm2 / redc / *Norm* bring values back into range.
// Preconditions (asserted): mod is odd and mod < 2^30, so that 4*mod fits in 32
// bits and products of two values < 2*mod stay below 2^32 * mod.
class Montgomery32
{
public:
    using Ui32X8 = hint::simd::Int256;
    // Precomputes R mod m, R^2 mod m and -m^-1 mod R (scalar and broadcast copies).
    // The initializer list follows the member declaration order.
    Montgomery32(uint32_t m) : modx8(m), mod2x8(m * 2), mod(m)
    {
        assert(m % 2 == 1);              // Montgomery reduction needs gcd(m, 2^32) == 1
        assert(m < (uint32_t(1) << 30)); // lazy [0, 4*m) values must fit in 32 bits
        uint32_t inv = inv_mod2pow(mod, 32);
        r = (uint64_t(1) << 32) % m;
        r2 = uint64_t(r) * r % m;
        ninv = (uint64_t(1) << 32) - inv; // -m^-1 mod 2^32
        assert(inv * mod == 1);
        assert(inv + ninv == 0);
        r2x8 = Ui32X8(r2);
        ninvx8 = Ui32X8(ninv);
    }
    // ---- scalar ----
    // x -> x * R mod m, lazy result in [0, 2*mod).
    uint32_t toMontgomery(uint32_t x) const
    {
        return redcLazy(uint64_t(x) * r2);
    }
    // x * R^-1 mod m, fully reduced to [0, mod).
    uint32_t fromMontgomery(uint32_t x) const
    {
        return redc(x);
    }
    // Lazy add: no reduction (< 4*mod when x, y < 2*mod).
    uint32_t add(uint32_t x, uint32_t y) const
    {
        return x + y;
    }
    // Lazy sub: x - y + 2*mod, in (0, 4*mod) when x, y < 2*mod.
    uint32_t sub(uint32_t x, uint32_t y) const
    {
        return x - y + mod * 2;
    }
    // Add / sub for x, y < 2*mod with the result kept in [0, 2*mod).
    uint32_t addNorm2(uint32_t x, uint32_t y) const
    {
        return norm2(x + y);
    }
    uint32_t subNorm2(uint32_t x, uint32_t y) const
    {
        y = x - y;
        return y > x ? y + mod * 2 : y; // wrapped (x < y): add 2*mod back
    }
    // [0, 2*mod) -> [0, mod).
    uint32_t norm(uint32_t x) const
    {
        return x >= mod ? x - mod : x;
    }
    // [0, 4*mod) -> [0, 2*mod).
    uint32_t norm2(uint32_t x) const
    {
        return x >= mod * 2 ? x - mod * 2 : x;
    }
    // x * y * R^-1 mod m; lazy [0, 2*mod) result for mul, [0, mod) for mulNorm.
    uint32_t mul(uint32_t x, uint32_t y) const
    {
        return redcLazy(uint64_t(x) * y);
    }
    uint32_t mulNorm(uint32_t x, uint32_t y) const
    {
        return redc(uint64_t(x) * y);
    }
    // Montgomery reduction: for x < 2^32 * mod returns x * R^-1 mod m in [0, 2*mod).
    uint32_t redcLazy(uint64_t x) const
    {
        uint32_t prod = uint32_t(x) * ninv; // makes prod * mod + x divisible by 2^32
        return (uint64_t(prod) * mod + x) >> 32;
    }
    uint32_t redc(uint64_t x) const
    {
        return norm(redcLazy(x));
    }
    // Inverse of a Montgomery-form value via Fermat's little theorem (x^(mod-2));
    // only valid when mod is prime. The result is in Montgomery form, lazy.
    uint32_t inv(uint32_t x) const
    {
        return pow(x, mod - 2);
    }
    // x^index for Montgomery-form x; the result is in Montgomery form, lazy.
    template <typename T, typename Ti>
    T pow(T x, Ti index) const
    {
        T res = montOne();
        while (true)
        {
            if (index & 1)
            {
                res = mul(res, x);
            }
            index >>= 1;
            if (index == 0)
            {
                break;
            }
            x = mul(x, x);
        }
        return res;
    }
    // Montgomery form of 1, i.e. R mod m.
    uint32_t montOne() const
    {
        return r;
    }
    // R^2 mod m, i.e. the Montgomery form of R.
    uint32_t montR() const
    {
        return r2;
    }
    uint32_t getMod() const
    {
        return mod;
    }
    // ---- 8 lanes (AVX2), same contracts as the scalar versions ----
    Ui32X8 toMontgomery(Ui32X8 x) const
    {
        return mul(x, r2x8);
    }
    Ui32X8 fromMontgomery(Ui32X8 x) const
    {
        return redc(x.evenEle32(), x.rShift64<32>());
    }
    Ui32X8 add(Ui32X8 x, Ui32X8 y) const
    {
        return x.add32(y);
    }
    Ui32X8 sub(Ui32X8 x, Ui32X8 y) const
    {
        return x.sub32(y).add32(mod2x8);
    }
    Ui32X8 addNorm2(Ui32X8 x, Ui32X8 y) const
    {
        return norm2(x.add32(y));
    }
    Ui32X8 subNorm2(Ui32X8 x, Ui32X8 y) const
    {
        return negNorm2(x.sub32(y));
    }
    // Branch-free conditional subtraction: when x < mod, x - mod wraps to a huge
    // value and the unsigned min keeps x.
    Ui32X8 norm(Ui32X8 x) const
    {
        Ui32X8 dif = x.sub32(modx8);
        return x.minU32(dif);
    }
    Ui32X8 norm2(Ui32X8 x) const
    {
        Ui32X8 dif = x.sub32(mod2x8);
        return x.minU32(dif);
    }
    // Fixes a wrapped (negative) difference by adding 2*mod; in-range values keep x.
    Ui32X8 negNorm2(Ui32X8 x) const
    {
        Ui32X8 sum = x.add32(mod2x8);
        return x.minU32(sum);
    }
    Ui32X8 mul(Ui32X8 x, Ui32X8 y) const
    {
        // 64-bit products of the even 32-bit lanes, then of the odd lanes.
        Ui32X8 prod_even = x.mullo32To64(y);
        Ui32X8 prod_odd = x.rShift64<32>().mullo32To64(y.rShift64<32>());
        return redcLazy(prod_even, prod_odd);
    }
    // e / o hold 64-bit values for the even / odd 32-bit result lanes. Each is reduced
    // like the scalar redcLazy. The even results land in the low halves (after the
    // shift), the odd results are already in the high halves, and one blend
    // interleaves them.
    Ui32X8 redcLazy(Ui32X8 e, Ui32X8 o) const
    {
        Ui32X8 prod0 = e.mullo32To64(ninvx8);
        Ui32X8 prod1 = o.mullo32To64(ninvx8);
        prod0 = prod0.mullo32To64(modx8).add64(e);
        prod1 = prod1.mullo32To64(modx8).add64(o);
        prod0 = prod0.rShift64<32>();
        return Ui32X8::blend32<0b10101010>(prod0, prod1);
    }
    Ui32X8 redc(Ui32X8 e, Ui32X8 o) const
    {
        return norm(redcLazy(e, o));
    }
private:
    Ui32X8 modx8, ninvx8, mod2x8, r2x8;
    uint32_t mod, ninv, r, r2;
};
}
namespace transform
{
namespace ntt
{
using namespace simd;
using namespace modint;
namespace radix2_avx
{
// ROOT^((mod - 1) / DEN * NUM) as a compile-time ModIntType constant. With ROOT a
// primitive root this is the NUM-th power of a primitive DEN-th root of unity.
template <uint32_t ROOT, typename ModIntType, int DEN, int NUM>
constexpr ModIntType root_unity_power()
{
    return qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / DEN * NUM);
}
// n * W_4^1
template <uint32_t ROOT, typename ModIntType, typename T>
inline T mul_w41(const T &n)
{
    constexpr ModIntType W_4_1 = root_unity_power<ROOT, ModIntType, 4, 1>();
    return n * T(W_4_1);
}
// n * W_8^1
template <uint32_t ROOT, typename ModIntType, typename T>
inline T mul_w81(const T &n)
{
    constexpr ModIntType W_8_1 = root_unity_power<ROOT, ModIntType, 8, 1>();
    return n * T(W_8_1);
}
// n * W_8^3
template <uint32_t ROOT, typename ModIntType, typename T>
inline T mul_w83(const T &n)
{
    constexpr ModIntType W_8_3 = root_unity_power<ROOT, ModIntType, 8, 3>();
    return n * T(W_8_3);
}
// Radix-2 DIT butterfly (a, b) -> (a + w*b, a - w*b) using ModIntType's lazy
// add/sub; `a` is first brought back into range with largeNorm().
// in: in_out0 < 4p, in_out1 < 4p; out: in_out0 < 4p, in_out1 < 4p
template <typename ModIntType>
inline void dit_butterfly2(ModIntType &in_out0, ModIntType &in_out1, ModIntType omega)
{
    auto base = in_out0.largeNorm();
    auto twiddled = in_out1 * omega;
    in_out0 = base.add(twiddled);
    in_out1 = base.sub(twiddled);
}
// DIT butterfly whose output range is chosen by a tag.
// std::false_type: lazy add/sub;  in: in_out0 < 2p, in_out1 < 4p; out: both < 4p
template <typename ModIntType>
inline void dit_butterfly2_i24(ModIntType &in_out0, ModIntType &in_out1, ModIntType omega, std::false_type)
{
    auto base = in_out0;
    auto twiddled = in_out1 * omega;
    in_out0 = base.add(twiddled);
    in_out1 = base.sub(twiddled);
}
// std::true_type: operator+ / operator-;  in: in_out0 < 2p, in_out1 < 4p; out: both < 2p
template <typename ModIntType>
inline void dit_butterfly2_i24(ModIntType &in_out0, ModIntType &in_out1, ModIntType omega, std::true_type)
{
    auto base = in_out0;
    auto twiddled = in_out1 * omega;
    in_out0 = base + twiddled;
    in_out1 = base - twiddled;
}
// Two DIT layers on four values. The first layer pairs (0,1) and (2,3) with
// omega_last; the second pairs (0,2) with omega0 and (1,3) with omega1.
// in: in_out0 < 2p, in_out1 < 4p, in_out2 < 2p, in_out3 < 4p
// out: all four < 2p when OUT2P, otherwise < 4p
template <bool OUT2P, typename ModIntType>
inline void dit_butterfly2_i2424_2layer(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                        ModIntType omega0, ModIntType omega1, ModIntType omega_last)
{
    using SecondLayerTag = std::integral_constant<bool, OUT2P>;
    dit_butterfly2_i24(in_out0, in_out1, omega_last, std::true_type{});
    dit_butterfly2_i24(in_out2, in_out3, omega_last, std::false_type{});
    dit_butterfly2_i24(in_out0, in_out2, omega0, SecondLayerTag{});
    dit_butterfly2_i24(in_out1, in_out3, omega1, SecondLayerTag{});
}
// As dit_butterfly2_i2424_2layer<true>, but only the first three outputs are produced:
// in_out0, in_out1, in_out2 (< 2p per the original notes). The (1,3) pair contributes
// only its sum, and in_out3 is left holding a first-layer intermediate.
template <typename ModIntType>
inline void dit_butterfly2_2layer_out3(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                       ModIntType omega0, ModIntType omega1, ModIntType omega_last)
{
    dit_butterfly2_i24(in_out0, in_out1, omega_last, std::true_type{});
    dit_butterfly2_i24(in_out2, in_out3, omega_last, std::false_type{});
    dit_butterfly2_i24(in_out0, in_out2, omega0, std::true_type{});
    auto odd_twiddled = in_out3 * omega1;
    in_out1 = in_out1 + odd_twiddled;
}
// As above, but only the first two outputs (in_out0, in_out1) are produced: both
// second-layer pairs contribute only their sums. in_out2 and in_out3 are left holding
// first-layer intermediates.
template <typename ModIntType>
inline void dit_butterfly2_2layer_out2(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                       ModIntType omega0, ModIntType omega1, ModIntType omega_last)
{
    dit_butterfly2_i24(in_out0, in_out1, omega_last, std::true_type{});
    dit_butterfly2_i24(in_out2, in_out3, omega_last, std::false_type{});
    auto even_twiddled = in_out2 * omega0;
    in_out0 = in_out0 + even_twiddled;
    auto odd_twiddled = in_out3 * omega1;
    in_out1 = in_out1 + odd_twiddled;
}
// Radix-2 DIF butterfly (a, b) -> (a + b, (a - b) * w). Both results are computed
// before either output is written; the sum is brought back into range with largeNorm().
// in: in_out0 < 2p, in_out1 < 2p; out: in_out0 < 2p, in_out1 < 2p
template <typename ModIntType>
inline void dif_butterfly2(ModIntType &in_out0, ModIntType &in_out1, ModIntType omega)
{
    auto total = in_out0.add(in_out1);
    auto delta = in_out0.sub(in_out1);
    in_out0 = total.largeNorm();
    in_out1 = delta * omega;
}
// Two DIF layers on four values: first (0,2) with omega0 and (1,3) with omega1,
// then (0,1) and (2,3) with omega_last.
// in/out: all four values < 2p
template <typename ModIntType>
inline void dif_butterfly2_2layer(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                  ModIntType omega0, ModIntType omega1, ModIntType omega_last)
{
    // Layer 1: stride-2 pairs, each with its own twiddle.
    dif_butterfly2(in_out0, in_out2, omega0);
    dif_butterfly2(in_out1, in_out3, omega1);
    // Layer 2: adjacent pairs share omega_last.
    dif_butterfly2(in_out0, in_out1, omega_last);
    dif_butterfly2(in_out2, in_out3, omega_last);
}
// Same two layers when inputs 2 and 3 are known to be zero: the first layer reduces
// to copies (in_out0, in_out1 unchanged) and plain products by omega0 / omega1.
// The incoming in_out2 / in_out3 are ignored and overwritten.
// in: in_out0 < 2p, in_out1 < 2p; out: all four < 2p
template <typename ModIntType>
inline void dif_butterfly2_2layer_in2(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                      ModIntType omega0, ModIntType omega1, ModIntType omega_last)
{
    in_out2 = in_out0 * omega0;
    in_out3 = in_out1 * omega1;
    dif_butterfly2(in_out0, in_out1, omega_last);
    dif_butterfly2(in_out2, in_out3, omega_last);
}
// Same two layers when only input 0 is non-zero. Every output is then in_out0 times
// a twiddle product, and in_out0 itself is left as is.
// The incoming in_out1..in_out3 are ignored and overwritten.
template <typename ModIntType>
inline void dif_butterfly2_2layer_in1(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                      ModIntType omega0, ModIntType omega1, ModIntType omega_last)
{
    const auto &source = in_out0;
    in_out2 = source * omega0;
    in_out1 = source * omega_last;
    in_out3 = in_out2 * omega_last;
}
// template <typename ModIntType, uint32_t ROOT>
// static auto omegax8(size_t ntt_len, int factor, size_t begin = 0, bool inv = false)
// {
// using ModIntX8 = MontInt32X8<ModIntType>;
// alignas(32) ModIntType w_arr[8]{};
// ModIntType unit(qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / ntt_len * factor));
// if (inv)
// {
// unit = unit.inv();
// }
// ModIntType w(qpow(unit, begin));
// for (auto &&i : w_arr)
// {
// i = w;
// w = w * unit;
// }
// return ModIntX8(w_arr);
// }
// Radix-2 NTT modulo a 32-bit prime using Montgomery32 arithmetic.
//  - dif: forward transform, natural-order input -> bit-reversed-order output,
//    twiddles from `table` (powers of root).
//  - dit: inverse transform, bit-reversed-order input -> natural-order output,
//    twiddles from `itable` (powers of root^-1), WITHOUT the 1/ntt_len scaling.
// Data stays in plain (non-Montgomery) form: mont.mul(x, w * R) == x * w, and the
// tables hold Montgomery-form twiddles.
// Inputs must be < 2*mod. dif outputs are < 2*mod; dit outputs are fully reduced (< mod).
// ntt_len must be a power of two no larger than L1_LEN (asserted).
// NOTE: each table stores L1_LEN / 2 Ints inline (4 MiB each with the current
// L1_BYTE), so avoid placing an NTT32AVX object on the stack.
struct NTT32AVX
{
    using Int = uint32_t;
    static constexpr size_t L1_LEN = L1_BYTE / (2 * sizeof(Int));
    static constexpr size_t L2_LEN = L2_BYTE / (2 * sizeof(Int));
    static constexpr int LOG_L1_LEN = hint_log2(L1_LEN);
    // Twiddle table indexed by butterfly-block number: entry k is the twiddle used by
    // the k-th block of every layer (root powers in bit-reversed order).
    // Entries 0 and 2^j are computed in the constructor; init() fills the rest as
    // products of those, on demand.
    template <int LOG_LEN>
    class BinRevTable
    {
    public:
        static constexpr size_t LEN = size_t(1) << LOG_LEN;
        static constexpr size_t TABLE_LEN = LEN / 2;
        BinRevTable(Int root, const Montgomery32 &mont, bool init_all) : cur_len(2)
        {
            table[0] = mont.montOne();
            root = mont.toMontgomery(root);
            for (size_t i = 1; i < TABLE_LEN; i *= 2)
            {
                // root^((mod - 1) / (4 * i)): a primitive (4 * i)-th root of unity
                // when root generates the multiplicative group.
                table[i] = getOmega(LEN, LEN / 4 / i, root, mont);
            }
            if (init_all)
            {
                init(LEN, mont);
            }
        }
        // Makes entries [0, min(len, LEN) / 2) valid. cur_len is the filled prefix
        // length and always a power of two, and it never shrinks. Previously it was
        // overwritten with n: a short transform shrank it, and init(1) set it to 0,
        // after which the `begin *= 2` loop below never terminated.
        void init(size_t len, const Montgomery32 &mont)
        {
            const size_t n = std::min(len, LEN) / 2;
            size_t begin = cur_len;
            for (; begin < n; begin *= 2)
            {
                // Entries [begin, 2 * begin) = table[begin] * entries [0, begin).
                const Int unit = table[begin];
                for (size_t i = begin + 1; i < begin * 2; i++)
                {
                    table[i] = mont.mul(unit, table[i - begin]);
                }
            }
            cur_len = begin;
        }
        Int getRevOmega(size_t i) const
        {
            return table[i];
        }
        // root^((mod - 1) / n * index) for a Montgomery-form root (lazy result).
        static Int getOmega(size_t n, size_t index, uint32_t root, const Montgomery32 &mont)
        {
            return mont.pow(root, (mont.getMod() - 1) / n * index);
        }
    private:
        Int table[TABLE_LEN];
        size_t cur_len;
    };
    Montgomery32 mont;
    BinRevTable<LOG_L1_LEN> table;
    BinRevTable<LOG_L1_LEN> itable;
    // root: generator used for the forward transform; mod: the prime modulus.
    // init_all == false defers filling the tables until dif/dit first need them.
    NTT32AVX(Int root, Int mod, bool init_all = true) : mont(mod), table(root, mont, init_all), itable(mod_inv<int64_t>(root, mod), mont, init_all) {}
    // Inverse transform (unscaled). in: values < 2*mod, bit-reversed order;
    // out: values < mod, natural order.
    void dit(Int in_out[], size_t ntt_len)
    {
        assert(ntt_len <= L1_LEN);
        if (ntt_len < 2)
        {
            // No butterflies run, but the "< mod" output contract must still hold.
            if (ntt_len == 1)
            {
                in_out[0] = mont.norm(in_out[0]);
            }
            return;
        }
        itable.init(ntt_len, mont);
        for (size_t rank = 2; rank < ntt_len; rank *= 2)
        {
            size_t gap = rank / 2, omega_index = 1;
            // Block 0 has twiddle 1: plain add/sub.
            for (size_t i = 0; i < gap; i++)
            {
                Int x = in_out[i], y = in_out[gap + i];
                in_out[i] = mont.addNorm2(x, y);
                in_out[gap + i] = mont.subNorm2(x, y);
            }
            // Remaining blocks: Gentleman-Sande butterfly, twiddle after the subtraction.
            for (auto it = in_out + rank; it < in_out + ntt_len; it += rank, omega_index++)
            {
                const Int omega = itable.getRevOmega(omega_index);
                for (size_t j = 0; j < gap; j++)
                {
                    Int x = it[j], y = it[gap + j];
                    it[j] = mont.addNorm2(x, y);
                    it[gap + j] = mont.mul(mont.subNorm2(x, y), omega);
                }
            }
        }
        // Last layer (single block, twiddle 1) also performs the final reduction.
        for (size_t i = 0; i < ntt_len / 2; i++)
        {
            Int x = in_out[i], y = in_out[ntt_len / 2 + i];
            in_out[i] = mont.norm(mont.addNorm2(x, y));
            in_out[ntt_len / 2 + i] = mont.norm(mont.subNorm2(x, y));
        }
    }
    // Forward transform. in: values < 2*mod, natural order;
    // out: values < 2*mod, bit-reversed order.
    void dif(Int in_out[], size_t ntt_len)
    {
        assert(ntt_len <= L1_LEN);
        table.init(ntt_len, mont);
        for (size_t rank = ntt_len; rank >= 2; rank /= 2)
        {
            size_t gap = rank / 2, omega_index = 1;
            // Block 0 has twiddle 1: plain add/sub.
            for (size_t i = 0; i < gap; i++)
            {
                Int x = in_out[i], y = in_out[gap + i];
                in_out[i] = mont.addNorm2(x, y);
                in_out[gap + i] = mont.subNorm2(x, y);
            }
            // Remaining blocks: Cooley-Tukey butterfly, twiddle before the add/sub.
            for (auto it = in_out + rank; it < in_out + ntt_len; it += rank, omega_index++)
            {
                const Int omega = table.getRevOmega(omega_index);
                for (size_t j = 0; j < gap; j++)
                {
                    Int x = it[j], y = mont.mul(it[gap + j], omega);
                    it[j] = mont.addNorm2(x, y);
                    it[gap + j] = mont.subNorm2(x, y);
                }
            }
        }
    }
    // void difL1X2(Int in_out1[], Int in_out2[], size_t ntt_len)
    // {
    // }
    // void ditL1(Int in_out[], size_t ntt_len, size_t rank)
    // {
    //     for (; rank <= ntt_len; rank *= 4)
    //     {
    //         size_t gap = rank / 4;
    //         for (size_t i = 0; i < ntt_len; i += rank)
    //         {
    //         }
    //     }
    // }
    // void conv32(Int in1_out[], Int in2[])
    // {
    // }
    // void conv64(Int in1_out[], Int in2[])
    // {
    // }
    // void convL1(Int in1_out[], Int in2[], size_t ntt_len)
    // {
    //     assert(ntt_len <= L1_LEN && ntt_len >= 32);
    //     difL1X2(in1_out, in2, ntt_len);
    //     conv32(in1_out, in2);
    //     ditL1(in1_out, ntt_len);
    // }
};
}
}
}
}
void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c)
{
using namespace hint;
using namespace transform;
using namespace ntt::radix2_avx;
static NTT32AVX ntt(3, 998244353, false);
const size_t conv_len = n + m + 1, ntt_len = int_ceil2(conv_len);
auto ntt_a = new uint32_t[ntt_len];
auto ntt_b = new uint32_t[ntt_len];
std::memcpy(ntt_a, a, (n + 1) * sizeof(unsigned));
std::memcpy(ntt_b, b, (m + 1) * sizeof(unsigned));
std::memset(ntt_a + n + 1, 0, (ntt_len - n - 1) * sizeof(unsigned));
std::memset(ntt_b + m + 1, 0, (ntt_len - m - 1) * sizeof(unsigned));
ntt.dif(ntt_a, ntt_len);
ntt.dif(ntt_b, ntt_len);
uint32_t len_inv_r = ntt.mont.toMontgomery(ntt_len);
len_inv_r = ntt.mont.inv(len_inv_r);
len_inv_r = ntt.mont.mul(len_inv_r, ntt.mont.montR());
for (size_t i = 0; i < ntt_len; i++)
{
uint32_t n = ntt.mont.mul(ntt_a[i], ntt_b[i]);
ntt_a[i] = ntt.mont.mul(n, len_inv_r);
}
ntt.dit(ntt_a, ntt_len);
std::memcpy(c, ntt_a, conv_len * sizeof(uint32_t));
delete[] ntt_a;
delete[] ntt_b;
}
#include "stopwatch.hpp"
void test_convolution()
{
int m, n;
// std::cin >> m >> n;
int len1 = 1 << 22, len2 = len1;
unsigned *a = new unsigned[len1];
unsigned *b = new unsigned[len2];
unsigned *c = new unsigned[len1 + len2 - 1]{};
uint64_t ele = 5;
for (size_t i = 0; i < len1; i++)
{
// scanf("%d", &a[i]);
a[i] = 2;
}
for (size_t i = 0; i < len2; i++)
{
// scanf("%d", &b[i]);
b[i] = 5;
}
StopWatch w(1000);
w.start();
poly_multiply(a, len1 - 1, b, len2 - 1, c);
w.stop();
std::cout << w.duration() << "ms" << std::endl;
for (size_t i = 0; i < len1 + len2 - 1; i++)
{
// std::cout << c[i] << " ";
}
delete[] a;
delete[] b;
delete[] c;
}
// Online-judge status pasted into the file: Compilation | N/A | N/A | Compile Error | Score: N/A | Show more |