#pragma GCC target("avx2")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("inline")
#include <algorithm>
#include <array>
#include <chrono>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <new>
#include <utility>
#include <immintrin.h>
/// Fixed-size array whose storage is aligned to 32 bytes, so that its
/// contents can be accessed with aligned 256-bit loads and stores.
///
/// @tparam T    element type
/// @tparam LEN  number of elements
template <typename T, size_t LEN>
class AlignAry
{
private:
    alignas(32) T ary[LEN];

public:
    /// Value-initialises every element. Leaving the array uninitialised would
    /// expose indeterminate values at run time and makes a constexpr instance
    /// ill-formed for trivial T before C++20.
    constexpr AlignAry() : ary{} {}

    /// Unchecked element access.
    constexpr T &operator[](size_t index)
    {
        return ary[index];
    }
    constexpr const T &operator[](size_t index) const
    {
        return ary[index];
    }

    /// Pointer to the first element (32-byte aligned).
    constexpr T *data()
    {
        return ary;
    }
    constexpr const T *data() const
    {
        return ary;
    }

    /// Iteration support; the const overloads allow range-for over a const array.
    constexpr T *begin()
    {
        return ary;
    }
    constexpr T *end()
    {
        return begin() + LEN;
    }
    constexpr const T *begin() const
    {
        return ary;
    }
    constexpr const T *end() const
    {
        return begin() + LEN;
    }

    /// Reinterprets the storage as an array of Ty. The caller is responsible
    /// for Ty being layout-compatible with the stored elements.
    template <typename Ty>
    Ty *cast_ptr()
    {
        return reinterpret_cast<Ty *>(ary);
    }
    template <typename Ty>
    const Ty *cast_ptr() const
    {
        return reinterpret_cast<const Ty *>(ary);
    }
};
namespace hint
{
// Rounds n down to a power of two: smears the highest set bit into every
// lower position, then keeps only that top bit. Yields 1 for n == 0.
template <typename T>
constexpr T int_floor2(T n)
{
    constexpr int width = 8 * sizeof(T);
    int shift = 1;
    while (shift < width)
    {
        n |= (n >> shift);
        shift *= 2;
    }
    return (n >> 1) + 1;
}
// Rounds n up to a power of two: after decrementing, the highest set bit is
// smeared into every lower position and one is added back.
// A value that is already a power of two is returned unchanged.
template <typename T>
constexpr T int_ceil2(T n)
{
    constexpr int width = 8 * sizeof(T);
    n--;
    int shift = 1;
    while (shift < width)
    {
        n |= (n >> shift);
        shift *= 2;
    }
    return n + 1;
}
// Returns the number whose lowest `bits` binary digits are all 1, i.e. 2^bits - 1.
// `bits` may be anything from 0 up to the width of T; a non-positive `bits`
// yields 0 instead of performing a shift by a negative amount.
template <typename T>
constexpr T all_one(int bits)
{
    if (bits <= 0)
    {
        return T(0);
    }
    // Built as (2^(bits-1) - 1) + 2^(bits-1) so that the full-width case never
    // has to form 2^bits, which would not fit in T.
    T tmp = T(1) << (bits - 1);
    return tmp - 1 + tmp;
}
// Integer log2: index of the highest set bit of n (0 when n is 0 or 1).
// Binary search over the bit positions: at every step the upper half of the
// remaining window is tested and, when non-empty, shifted down.
template <typename UintTy>
constexpr int hint_log2(UintTy n)
{
    constexpr int width = 8 * sizeof(UintTy);
    constexpr UintTy upper_half = all_one<UintTy>(width / 2) << (width / 2);
    UintTy window = upper_half;
    int exponent = 0;
    for (int step = width / 2; step > 0;)
    {
        if ((n & window) != 0)
        {
            exponent += step;
            n >>= step;
        }
        step /= 2;
        window >>= step;
    }
    return exponent;
}
// Extended Euclidean algorithm.
// Returns g = gcd(a, b) and stores coefficients in x and y such that
// a * x + b * y == g.
template <typename IntTy>
constexpr IntTy exgcd(IntTy a, IntTy b, IntTy &x, IntTy &y)
{
    if (b != 0)
    {
        const IntTy quotient = a / b;
        // Solve for (b, a mod b) with the roles of x and y swapped, then
        // fold the quotient back into y.
        const IntTy divisor = exgcd(b, a - quotient * b, y, x);
        y -= quotient * x;
        return divisor;
    }
    x = 1;
    y = 0;
    return a;
}
// Modular inverse of n modulo mod, computed with the extended Euclidean
// algorithm. The Bezout coefficient of n is shifted by one multiple of mod
// when it falls outside [0, mod).
// NOTE(review): gcd(n, mod) is not checked; the result is only meaningful
// when n and mod are coprime.
template <typename IntTy>
constexpr IntTy mod_inv(IntTy n, IntTy mod)
{
    n %= mod;
    IntTy coeff_n = 0, coeff_mod = 0;
    exgcd(n, mod, coeff_n, coeff_mod);
    if (coeff_n < 0)
    {
        coeff_n += mod;
        return coeff_n;
    }
    if (coeff_n >= mod)
    {
        coeff_n -= mod;
    }
    return coeff_n;
}
// Counts trailing zero bits: x ^ (x - 1) keeps the lowest set bit of x and
// every bit below it, so its integer log2 is the index of that lowest set bit.
template <typename IntTy>
constexpr int hint_ctz(IntTy x)
{
    const auto low_run = x ^ (x - 1);
    return hint_log2(low_run);
}
// Returns n^-1 mod 2^pow, computed by Newton iteration
// (x <- x * (2 - n * x), which doubles the number of correct low bits).
//
// Only odd n are invertible modulo a power of two. For an even n, or for
// pow <= 0, no inverse can be produced and 0 is returned; without this guard
// the loop below would never reach t == 1 and would not terminate.
// pow is expected to be at most 64.
constexpr uint64_t inv_mod2pow(uint64_t n, int pow)
{
    if (pow <= 0 || (n & 1) == 0)
    {
        return 0;
    }
    const uint64_t mask = all_one<uint64_t>(pow);
    uint64_t xn = 1, t = n & mask;
    // t tracks n * xn mod 2^pow; the inverse is found once it equals 1.
    while (t != 1)
    {
        xn = (xn * (2 - t));
        t = (xn * n) & mask;
    }
    return xn & mask;
}
// Generic fast exponentiation (square-and-multiply): returns m raised to the
// n-th power using only T's operator*. T must be constructible from 1.
template <typename T>
constexpr T qpow(T m, uint32_t n)
{
    T result = 1;
    for (; n != 0; n >>= 1)
    {
        const bool bit_set = (n & 1) != 0;
        if (bit_set)
        {
            result = result * m;
        }
        m = m * m;
    }
    return result;
}
// Namespace for FFT and FFT-like transforms
namespace hint_transform
{
// Two-point butterfly performed in place:
// (sum, diff) <- (sum + diff, sum - diff).
template <typename T>
inline void transform2(T &sum, T &diff)
{
    const T lhs = sum;
    const T rhs = diff;
    sum = lhs + rhs;
    diff = lhs - rhs;
}
namespace hint_ntt
{
/// Scalar modular integer stored in Montgomery form with R = 2^32.
///
/// Arithmetic is lazily reduced: addMont, subMont and mulMont take and return
/// representatives in [0, 2 * MOD); only redc() / toInt() bring a value down
/// to [0, MOD). Because sums of two such representatives must fit into 32
/// bits, MOD has to be smaller than 2^30.
///
/// @tparam MOD  odd modulus below 2^30
template <uint32_t MOD>
struct MontInt32
{
    static constexpr uint64_t R = uint64_t(1) << 32;
    static constexpr uint32_t R_MASK = R - 1;
    /// MOD^-1 mod 2^32.
    static constexpr uint32_t MOD_INV = inv_mod2pow(MOD, 32);
    /// -MOD^-1 mod 2^32, the multiplier used by the reduction.
    static constexpr uint32_t MOD_INV_NEG = R - MOD_INV;
    static constexpr uint32_t MOD2 = MOD * 2;
    // hint_log2(MOD) < 30 means MOD < 2^30. The previous bound (<= 30) let
    // 31-bit moduli through, for which m + n in addMont overflows uint32_t.
    static_assert(hint_log2(MOD) < 30, "MOD can't be larger than 30 bits");
    static_assert(uint32_t(MOD_INV * MOD) == 1, "Montgomery32 modulus is not correct");

    /// Montgomery representation of the value.
    uint32_t data;

    constexpr MontInt32() : data(0) {}
    /// Converts an ordinary integer into Montgomery form.
    constexpr MontInt32(uint32_t n) : data(toMont(n)) {}

    /// Returns n * R mod MOD, in [0, MOD).
    static constexpr uint32_t toMont(uint32_t n)
    {
        return (uint64_t(n) << 32) % MOD;
    }
    /// Montgomery reduction without the final conditional subtraction:
    /// returns (input + q * MOD) / R where q makes the low 32 bits vanish.
    static constexpr uint32_t redcLazy(uint64_t input)
    {
        // The 32-bit product wraps on purpose: only q mod 2^32 is needed.
        uint64_t n = uint32_t(input) * MOD_INV_NEG;
        n = n * MOD + input;
        return n >> 32;
    }
    /// Montgomery reduction followed by one conditional subtraction of MOD.
    static constexpr uint32_t redc(uint64_t input)
    {
        uint32_t n = redcLazy(input);
        return n < MOD ? n : n - MOD;
    }
    /// Converts a Montgomery representation back to an ordinary integer.
    static constexpr uint32_t toInt(uint32_t n)
    {
        return redc(n);
    }
    /// Lazy addition: subtracts 2 * MOD once when the sum reaches it.
    static constexpr uint32_t addMont(uint32_t m, uint32_t n)
    {
        n = m + n;
        return n < MOD * 2 ? n : n - MOD * 2;
    }
    /// Lazy subtraction: adds 2 * MOD back when m - n wrapped around.
    static constexpr uint32_t subMont(uint32_t m, uint32_t n)
    {
        n = m - n;
        return n > m ? n + MOD * 2 : n;
    }
    /// Lazy Montgomery product of two representations.
    static constexpr uint32_t mulMont(uint32_t m, uint32_t n)
    {
        return redcLazy(uint64_t(m) * n);
    }

    /// Replaces the stored value with the Montgomery form of n.
    constexpr void fromInt(uint32_t n)
    {
        data = toMont(n);
    }
    /// Ordinary integer value of this element.
    constexpr uint32_t toInt() const
    {
        return toInt(data);
    }
    constexpr operator uint32_t() const
    {
        return toInt();
    }

    constexpr MontInt32 operator+(MontInt32 rhs) const
    {
        rhs.data = addMont(data, rhs.data);
        return rhs;
    }
    constexpr MontInt32 operator-(MontInt32 rhs) const
    {
        rhs.data = subMont(data, rhs.data);
        return rhs;
    }
    constexpr MontInt32 operator*(MontInt32 rhs) const
    {
        rhs.data = mulMont(data, rhs.data);
        return rhs;
    }
    constexpr MontInt32 &operator+=(const MontInt32 &rhs)
    {
        data = addMont(data, rhs.data);
        return *this;
    }
    constexpr MontInt32 &operator-=(const MontInt32 &rhs)
    {
        data = subMont(data, rhs.data);
        return *this;
    }
    constexpr MontInt32 &operator*=(const MontInt32 &rhs)
    {
        data = mulMont(data, rhs.data);
        return *this;
    }

    /// The modulus this type works with.
    static constexpr uint32_t mod()
    {
        return MOD;
    }
};
// Eight Montgomery-form residues modulo MOD packed into one 256-bit AVX2
// register, one value per 32-bit lane. The element-wise arithmetic mirrors
// MontInt32<MOD>. Comparison operators use signed 32-bit compares and return
// per-lane masks (all bits set where the relation holds, zero elsewhere).
template <uint32_t MOD>
struct MontInt32X8
{
using MontInt = MontInt32<MOD>;
__m256i data;
// All eight lanes zero.
MontInt32X8() : data(_mm256_setzero_si256()) {}
// Broadcasts the Montgomery representation of x to every lane.
MontInt32X8(MontInt x) : data(_mm256_set1_epi32(x.data)) {}
// Lane-by-lane construction from raw 32-bit values; x0 ends up in lane 0.
MontInt32X8(int32_t x0, int32_t x1, int32_t x2, int32_t x3, int32_t x4, int32_t x5, int32_t x6, int32_t x7)
{
data = _mm256_set_epi32(x7, x6, x5, x4, x3, x2, x1, x0);
}
// Wraps a raw register without any conversion.
MontInt32X8(__m256i rhs) : data(rhs) {}
// Loads 32 bytes from p with an unaligned load.
template <typename T>
MontInt32X8(const T *p)
{
loadu(p);
}
static constexpr uint32_t mod()
{
return MOD;
}
// Broadcast constants used by the reductions below.
static MontInt32X8 zeroX8()
{
return _mm256_setzero_si256();
}
// MOD in every lane.
static MontInt32X8 modX8()
{
return _mm256_set1_epi32(mod());
}
// MOD - 1 in every lane.
static MontInt32X8 mod1X8()
{
constexpr uint32_t MOD1 = mod() - 1;
return _mm256_set1_epi32(MOD1);
}
// 2 * MOD in every lane.
static MontInt32X8 mod2X8()
{
constexpr uint32_t MOD2 = mod() * 2;
return _mm256_set1_epi32(MOD2);
}
// MontInt::MOD_INV_NEG in every lane.
static MontInt32X8 modNX8()
{
constexpr uint32_t MOD_INV_NEG = MontInt::MOD_INV_NEG;
return _mm256_set1_epi32(MOD_INV_NEG);
}
// 2^32 mod MOD in every lane.
static MontInt32X8 RX8()
{
constexpr uint32_t R = (uint64_t(1) << 32) % mod();
return _mm256_set1_epi32(R);
}
// Multiplies the low 32 bits of each 64-bit lane (unsigned), giving four 64-bit products.
MontInt32X8 mul64(MontInt32X8 rhs) const
{
return _mm256_mul_epu32(data, rhs.data);
}
// Bit shifts applied to each 64-bit lane.
template <int N>
MontInt32X8 lShift64() const
{
return _mm256_slli_epi64(data, N);
}
template <int N>
MontInt32X8 rShift64() const
{
return _mm256_srli_epi64(data, N);
}
// Byte shifts applied independently to each 128-bit half.
template <int N>
MontInt32X8 lShiftByte128() const
{
return _mm256_bslli_epi128(data, N);
}
template <int N>
MontInt32X8 rShiftByte128() const
{
return _mm256_bsrli_epi128(data, N);
}
// Per-lane select: bit i of N set takes 32-bit lane i from b, otherwise from a.
template <int N>
static MontInt32X8 blend(MontInt32X8 a, MontInt32X8 b)
{
return _mm256_blend_epi32(a.data, b.data, N);
}
// Selects/combines the 128-bit halves of a and b according to control byte N.
template <int N>
static MontInt32X8 permute2X128(MontInt32X8 a, MontInt32X8 b)
{
return _mm256_permute2x128_si256(a.data, b.data, N);
}
// Keeps the even-indexed lanes and zeroes the odd ones: a,b,c,d -> a,0,c,0
MontInt32X8 evenElements() const
{
return blend<0b10101010>(data, zeroX8());
}
// Keeps the odd-indexed lanes and zeroes the even ones: a,b,c,d -> 0,b,0,d
MontInt32X8 oddElements() const
{
return blend<0b01010101>(data, zeroX8());
}
// Full 32x32 -> 64 bit products of all eight lane pairs.
// first holds the products of the even-indexed lanes, second those of the
// odd-indexed lanes (both operands are shifted down by 32 bits first).
std::pair<MontInt32X8, MontInt32X8> mul32X32To64(MontInt32X8 rhs) const
{
return std::make_pair(mul64(rhs), rShift64<32>().mul64(rhs.rShift64<32>()));
}
// Montgomery reduction of eight 64-bit values without the final conditional
// subtraction. even64 / odd64 carry the 64-bit inputs belonging to the even /
// odd 32-bit result lanes. For each input x this forms
// (x + (low32(x) * MOD_INV_NEG mod 2^32) * MOD) and keeps its upper 32 bits.
static MontInt32X8 montRedcLazy(MontInt32X8 even64, MontInt32X8 odd64)
{
MontInt32X8 p0 = even64.mul64(modNX8());
MontInt32X8 p1 = odd64.mul64(modNX8());
// mul64 only reads the low 32 bits of p0 / p1, which is exactly q mod 2^32.
p0 = p0.mul64(modX8()).rawAdd64(even64).template rShift64<32>();
// For the odd lanes the upper 32 bits are already in the right position, so no shift.
p1 = p1.mul64(modX8()).rawAdd64(odd64);
return blend<0b10101010>(p0, p1);
}
// Montgomery reduction followed by one conditional subtraction of MOD.
static MontInt32X8 montRedc(MontInt32X8 even64, MontInt32X8 odd64)
{
return montRedcLazy(even64, odd64).largeNorm();
}
// Treats each lane as an ordinary integer n and returns n * 2^32 mod MOD,
// computed lane by lane through a scalar loop over a temporary buffer.
MontInt32X8 toMont() const
{
alignas(32) uint32_t temp[8];
store(temp);
for (auto &&i : temp)
{
i = (uint64_t(i) << 32) % MOD;
}
return MontInt32X8(temp);
}
// Converts every lane from Montgomery form back to an ordinary integer by
// reducing the lane value itself (zero-extended to 64 bits).
MontInt32X8 toInt() const
{
MontInt32X8 e = evenElements();
MontInt32X8 o = rShift64<32>();
return montRedc(e, o);
}
// Subtracts MOD from every lane that is greater than MOD - 1 (signed compare).
MontInt32X8 largeNorm() const
{
MontInt32X8 sub = (*this > mod1X8()) & modX8();
return rawSub(sub);
}
// Adds MOD to every lane that is negative when read as a signed integer.
MontInt32X8 smallNorm() const
{
MontInt32X8 add = (zeroX8() > *this) & modX8();
return rawAdd(add);
}
// Adds 2 * MOD to every lane that is negative when read as a signed integer.
MontInt32X8 smallNorm2() const
{
MontInt32X8 add = (zeroX8() > *this) & mod2X8();
return rawAdd(add);
}
// Addition with a single conditional subtraction of MOD.
MontInt32X8 addMont(MontInt32X8 rhs) const
{
return rawAdd(rhs).largeNorm();
}
// Subtraction with a single conditional addition of MOD.
MontInt32X8 subMont(MontInt32X8 rhs) const
{
return rawSub(rhs).smallNorm();
}
// Montgomery product with the final conditional subtraction of MOD.
MontInt32X8 mulMont(MontInt32X8 rhs) const
{
auto mulhl = mul32X32To64(rhs);
return montRedc(mulhl.first, mulhl.second);
}
// Lazy addition: computes a + b - 2 * MOD and adds 2 * MOD back where that went negative.
MontInt32X8 addMont2(MontInt32X8 rhs) const
{
return rawAdd(rhs).rawSub(mod2X8()).smallNorm2();
}
// Lazy subtraction: computes a - b and adds 2 * MOD where that went negative.
MontInt32X8 subMont2(MontInt32X8 rhs) const
{
return rawSub(rhs).smallNorm2();
}
// Lazy Montgomery product (no final conditional subtraction).
MontInt32X8 mulMont2(MontInt32X8 rhs) const
{
auto mulhl = mul32X32To64(rhs);
return montRedcLazy(mulhl.first, mulhl.second);
}
// The arithmetic operators map to the lazy variants above.
MontInt32X8 operator+(MontInt32X8 rhs) const
{
return addMont2(rhs);
}
MontInt32X8 operator-(MontInt32X8 rhs) const
{
return subMont2(rhs);
}
MontInt32X8 operator*(MontInt32X8 rhs) const
{
return mulMont2(rhs);
}
// Plain wrapping add/sub on 32-bit lanes, without any modular correction.
MontInt32X8 rawAdd(MontInt32X8 rhs) const
{
return _mm256_add_epi32(data, rhs.data);
}
MontInt32X8 rawSub(MontInt32X8 rhs) const
{
return _mm256_sub_epi32(data, rhs.data);
}
// Plain wrapping add/sub on 64-bit lanes.
MontInt32X8 rawAdd64(MontInt32X8 rhs) const
{
return _mm256_add_epi64(data, rhs.data);
}
MontInt32X8 rawSub64(MontInt32X8 rhs) const
{
return _mm256_sub_epi64(data, rhs.data);
}
// Signed per-lane comparisons; the result is a lane mask, not a bool.
MontInt32X8 operator>(MontInt32X8 n) const
{
return _mm256_cmpgt_epi32(data, n.data);
}
MontInt32X8 operator<(MontInt32X8 n) const
{
return n > *this;
}
MontInt32X8 operator==(MontInt32X8 n) const
{
return _mm256_cmpeq_epi32(data, n.data);
}
// Bitwise operations on the whole register.
MontInt32X8 operator&(MontInt32X8 n) const
{
return _mm256_and_si256(data, n.data);
}
MontInt32X8 operator|(MontInt32X8 n) const
{
return _mm256_or_si256(data, n.data);
}
MontInt32X8 operator^(MontInt32X8 n) const
{
return _mm256_xor_si256(data, n.data);
}
// Broadcasts the raw value n to every lane (no Montgomery conversion).
void set1(int32_t n)
{
data = _mm256_set1_epi32(n);
}
// Unaligned 32-byte load from p.
template <typename T>
void loadu(const T *p)
{
data = _mm256_loadu_si256((const __m256i *)p);
}
// Aligned 32-byte load; p must be 32-byte aligned.
template <typename T>
void load(const T *p)
{
data = _mm256_load_si256((const __m256i *)p);
// data = *reinterpret_cast<const __m256i *>(p);
}
// Unaligned 32-byte store to p.
template <typename T>
void storeu(T *p) const
{
_mm256_storeu_si256((__m256i *)p, data);
}
// Aligned 32-byte store; p must be 32-byte aligned.
template <typename T>
void store(T *p) const
{
_mm256_store_si256((__m256i *)p, data);
// *reinterpret_cast<__m256i *>(p) = data;
}
// Debug helpers: print the raw register contents to std::cout as eight
// signed 32-bit, eight unsigned 32-bit or four unsigned 64-bit values.
void printI32() const
{
alignas(32) int32_t v[8];
store(v);
std::cout << "[" << v[0] << "," << v[1]
<< "," << v[2] << "," << v[3]
<< "," << v[4] << "," << v[5]
<< "," << v[6] << "," << v[7] << "]" << std::endl;
}
void printU32() const
{
alignas(32) uint32_t v[8];
store(v);
std::cout << "[" << v[0] << "," << v[1]
<< "," << v[2] << "," << v[3]
<< "," << v[4] << "," << v[5]
<< "," << v[6] << "," << v[7] << "]" << std::endl;
}
void printU64() const
{
alignas(32) uint64_t v[4];
store(v);
std::cout << "[" << v[0] << "," << v[1]
<< "," << v[2] << "," << v[3] << "]" << std::endl;
}
};
namespace split_radix_avx
{
// Multiplies n by ROOT^((mod - 1) / 4), the twiddle factor called W_4_1
// throughout this file. The factor is evaluated at compile time.
template <uint32_t ROOT, typename ModIntType>
inline ModIntType mul_w41(ModIntType n)
{
    constexpr auto exponent = (ModIntType::mod() - 1) / 4;
    constexpr ModIntType quarter_turn = qpow(ModIntType(ROOT), exponent);
    return n * quarter_turn;
}
// Multiplies n by ROOT^((mod - 1) / 8), the twiddle factor called W_8_1
// throughout this file. The factor is evaluated at compile time.
template <uint64_t ROOT, typename ModIntType>
inline ModIntType mul_w81(ModIntType n)
{
    constexpr auto exponent = (ModIntType::mod() - 1) / 8;
    constexpr ModIntType eighth_turn = qpow(ModIntType(ROOT), exponent);
    return n * eighth_turn;
}
// Multiplies n by ROOT^(3 * ((mod - 1) / 8)), the twiddle factor called
// W_8_3 throughout this file. The factor is evaluated at compile time.
template <uint64_t ROOT, typename ModIntType>
inline ModIntType mul_w83(ModIntType n)
{
    constexpr auto exponent = ((ModIntType::mod() - 1) / 8) * 3;
    constexpr ModIntType three_eighths_turn = qpow(ModIntType(ROOT), exponent);
    return n * three_eighths_turn;
}
template <size_t LEN, uint32_t MOD, uint32_t ROOT>
struct NTTShort
{
static constexpr size_t ntt_len = LEN;
static constexpr size_t half_len = ntt_len / 2;
static constexpr size_t quarter_len = ntt_len / 4;
static constexpr size_t octant_len = ntt_len / 8;
static constexpr size_t rank = quarter_len;
static constexpr int log_len = hint_log2(ntt_len);
using ModIntX8 = MontInt32X8<MOD>;
using ModIntType = typename ModIntX8::MontInt;
using HalfNTT = NTTShort<half_len, MOD, ROOT>;
using QuarterNTT = NTTShort<quarter_len, MOD, ROOT>;
using TableType = AlignAry<ModIntType, quarter_len>;
static constexpr TableType getNTTTable(int factor)
{
ModIntType root = qpow(ModIntType(ROOT), ((ModIntType::mod() - 1) / LEN) * factor);
ModIntType omega(1);
TableType res;
for (auto &&i : res)
{
i = omega;
omega *= root;
}
return res;
}
static TableType table1;
static TableType table3;
static constexpr uint64_t mod()
{
return ModIntType::mod();
}
static constexpr uint64_t root()
{
return ROOT;
}
static void dit(ModIntType in_out[])
{
QuarterNTT::dit(in_out + half_len + quarter_len);
QuarterNTT::dit(in_out + half_len);
HalfNTT::dit(in_out);
ModIntX8 omega1(&table1[0]), omega3(&table3[0]);
for (auto it = in_out, it1 = &table1[0], it3 = &table3[0]; it < in_out + quarter_len; it += 8, it1 += 8, it3 += 8)
{
omega1.load(it1), omega3.load(it3);
ModIntX8 temp0, temp1, temp2, temp3;
temp2.load(&it[rank * 2]);
temp3.load(&it[rank * 3]);
temp2 = temp2 * omega1;
temp3 = temp3 * omega3;
transform2(temp2, temp3);
constexpr ModIntType W_4_1 = qpow(ModIntType(root()), (ModIntType::mod() - 1) / 4);
temp3 = temp3 * ModIntX8(W_4_1);
temp0.load(&it[0]);
temp1.load(&it[rank]);
(temp0 + temp2).store(&it[0]);
(temp1 + temp3).store(&it[rank]);
(temp0 - temp2).store(&it[rank * 2]);
(temp1 - temp3).store(&it[rank * 3]);
}
}
static void dif(ModIntType in_out[])
{
// constexpr ModIntType u1 = qpow(ModIntType(ROOT), (MOD - 1) / ntt_len * 8);
// constexpr ModIntType u3 = qpow(u1, 3);
ModIntX8 omega1(&table1[0]), omega3(&table3[0]);
for (auto it = in_out, it1 = &table1[0], it3 = &table3[0]; it < in_out + quarter_len; it += 8, it1 += 8, it3 += 8)
{
omega1.load(it1), omega3.load(it3);
ModIntX8 temp0, temp1, temp2, temp3;
temp0.load(&it[0]);
temp1.load(&it[rank]);
temp2.load(&it[rank * 2]);
temp3.load(&it[rank * 3]);
(temp0 + temp2).store(&it[0]);
(temp1 + temp3).store(&it[rank]);
temp2 = temp0 - temp2;
temp3 = temp1 - temp3;
constexpr ModIntType W_4_1 = qpow(ModIntType(root()), (ModIntType::mod() - 1) / 4);
temp3 = temp3 * ModIntX8(W_4_1);
transform2(temp2, temp3);
(temp2 * omega1).store(&it[rank * 2]);
(temp3 * omega3).store(&it[rank * 3]);
// omega1 = omega1 * ModIntX8(u1);
// omega3 = omega3 * ModIntX8(u3);
}
HalfNTT::dif(in_out);
QuarterNTT::dif(in_out + half_len);
QuarterNTT::dif(in_out + half_len + quarter_len);
}
static void dit(ModIntType in_out[], size_t len)
{
if (len < LEN)
{
HalfNTT::dit(in_out, len);
return;
}
dit(in_out);
}
static void dif(ModIntType in_out[], size_t len)
{
if (len < LEN)
{
HalfNTT::dif(in_out, len);
return;
}
dif(in_out);
}
};
template <size_t LEN, uint32_t MOD, uint32_t ROOT>
typename NTTShort<LEN, MOD, ROOT>::TableType NTTShort<LEN, MOD, ROOT>::table1 = NTTShort<LEN, MOD, ROOT>::getNTTTable(1);
template <size_t LEN, uint32_t MOD, uint32_t ROOT>
typename NTTShort<LEN, MOD, ROOT>::TableType NTTShort<LEN, MOD, ROOT>::table3 = NTTShort<LEN, MOD, ROOT>::getNTTTable(3);
// Length-0 specialisation: terminates the template recursion; every
// transform is a no-op.
template <uint32_t MOD, uint32_t ROOT>
struct NTTShort<0, MOD, ROOT>
{
    using ModIntType = MontInt32<MOD>;

    static void dit(ModIntType *) {}
    static void dif(ModIntType *) {}
    static void dit(ModIntType *, size_t) {}
    static void dif(ModIntType *, size_t) {}
};
// Length-1 specialisation: a single element is its own transform, so every
// function is a no-op.
template <uint32_t MOD, uint32_t ROOT>
struct NTTShort<1, MOD, ROOT>
{
    using ModIntType = MontInt32<MOD>;

    static void dit(ModIntType *) {}
    static void dif(ModIntType *) {}
    static void dit(ModIntType *, size_t) {}
    static void dif(ModIntType *, size_t) {}
};
template <uint32_t MOD, uint32_t ROOT>
struct NTTShort<2, MOD, ROOT>
{
using ModIntType = MontInt32<MOD>;
static void dit(ModIntType in_out[])
{
transform2(in_out[0], in_out[1]);
}
static void dif(ModIntType in_out[])
{
transform2(in_out[0], in_out[1]);
}
static void dit(ModIntType in_out[], size_t len)
{
if (len < 2)
{
return;
}
dit(in_out);
}
static void dif(ModIntType in_out[], size_t len)
{
if (len < 2)
{
return;
}
dif(in_out);
}
};
template <uint32_t MOD, uint32_t ROOT>
struct NTTShort<4, MOD, ROOT>
{
using ModIntType = MontInt32<MOD>;
static void dit(ModIntType in_out[])
{
auto temp0 = in_out[0];
auto temp1 = in_out[1];
auto temp2 = in_out[2];
auto temp3 = in_out[3];
transform2(temp0, temp1);
transform2(temp2, temp3);
temp3 = mul_w41<ROOT>(temp3);
in_out[0] = temp0 + temp2;
in_out[1] = temp1 + temp3;
in_out[2] = temp0 - temp2;
in_out[3] = temp1 - temp3;
}
static void dif(ModIntType in_out[])
{
auto temp0 = in_out[0];
auto temp1 = in_out[1];
auto temp2 = in_out[2];
auto temp3 = in_out[3];
transform2(temp0, temp2);
transform2(temp1, temp3);
temp3 = mul_w41<ROOT>(temp3);
in_out[0] = temp0 + temp1;
in_out[1] = temp0 - temp1;
in_out[2] = temp2 + temp3;
in_out[3] = temp2 - temp3;
}
static void dit(ModIntType in_out[], size_t len)
{
if (len < 4)
{
NTTShort<2, MOD, ROOT>::dit(in_out, len);
return;
}
dit(in_out);
}
static void dif(ModIntType in_out[], size_t len)
{
if (len < 4)
{
NTTShort<2, MOD, ROOT>::dif(in_out, len);
return;
}
dif(in_out);
}
};
template <uint32_t MOD, uint32_t ROOT>
struct NTTShort<8, MOD, ROOT>
{
using ModIntX8 = MontInt32X8<MOD>;
using ModIntType = typename ModIntX8::MontInt;
static void dit(ModIntType in_out[])
{
static constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
static constexpr ModIntType w2 = qpow(w1, 2);
static constexpr ModIntType w3 = qpow(w1, 3);
auto temp0 = in_out[0];
auto temp1 = in_out[1];
auto temp2 = in_out[2];
auto temp3 = in_out[3];
auto temp4 = in_out[4];
auto temp5 = in_out[5];
auto temp6 = in_out[6];
auto temp7 = in_out[7];
transform2(temp0, temp1);
transform2(temp2, temp3);
transform2(temp4, temp5);
transform2(temp6, temp7);
temp3 = mul_w41<ROOT>(temp3);
temp7 = mul_w41<ROOT>(temp7);
transform2(temp0, temp2);
transform2(temp1, temp3);
transform2(temp4, temp6);
transform2(temp5, temp7);
temp5 = temp5 * w1;
temp6 = temp6 * w2;
temp7 = temp7 * w3;
in_out[0] = temp0 + temp4;
in_out[1] = temp1 + temp5;
in_out[2] = temp2 + temp6;
in_out[3] = temp3 + temp7;
in_out[4] = temp0 - temp4;
in_out[5] = temp1 - temp5;
in_out[6] = temp2 - temp6;
in_out[7] = temp3 - temp7;
}
static void dif(ModIntType in_out[])
{
static constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
static constexpr ModIntType w2 = qpow(w1, 2);
static constexpr ModIntType w3 = qpow(w1, 3);
auto temp0 = in_out[0];
auto temp1 = in_out[1];
auto temp2 = in_out[2];
auto temp3 = in_out[3];
auto temp4 = in_out[4];
auto temp5 = in_out[5];
auto temp6 = in_out[6];
auto temp7 = in_out[7];
transform2(temp0, temp4);
transform2(temp1, temp5);
transform2(temp2, temp6);
transform2(temp3, temp7);
temp5 = temp5 * w1;
temp6 = temp6 * w2;
temp7 = temp7 * w3;
transform2(temp0, temp2);
transform2(temp1, temp3);
transform2(temp4, temp6);
transform2(temp5, temp7);
temp3 = mul_w41<ROOT>(temp3);
temp7 = mul_w41<ROOT>(temp7);
in_out[0] = temp0 + temp1;
in_out[1] = temp0 - temp1;
in_out[2] = temp2 + temp3;
in_out[3] = temp2 - temp3;
in_out[4] = temp4 + temp5;
in_out[5] = temp4 - temp5;
in_out[6] = temp6 + temp7;
in_out[7] = temp6 - temp7;
}
static void dit(ModIntType in_out[], size_t len)
{
if (len < 8)
{
NTTShort<4, MOD, ROOT>::dit(in_out, len);
return;
}
dit(in_out);
}
static void dif(ModIntType in_out[], size_t len)
{
if (len < 8)
{
NTTShort<4, MOD, ROOT>::dif(in_out, len);
return;
}
dif(in_out);
}
};
// Length-16 specialisation working on two 8-lane registers (elements 0..7
// and 8..15). The small butterflies are done inside the registers with
// shifts, blends and 128-bit permutes; the helper stages are also reused by
// the length-32 specialisation.
template <uint32_t MOD, uint32_t ROOT>
struct NTTShort<16, MOD, ROOT>
{
using ModIntX8 = MontInt32X8<MOD>;
using ModIntType = typename ModIntX8::MontInt;
// W_16_k = ROOT^(k * (mod - 1) / 16), the twiddles of the final length-16 stage.
static constexpr ModIntType W_16_1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 16);
static constexpr ModIntType W_16_2 = qpow(W_16_1, 2);
static constexpr ModIntType W_16_3 = qpow(W_16_1, 3);
static constexpr ModIntType W_16_4 = qpow(W_16_1, 4);
static constexpr ModIntType W_16_5 = qpow(W_16_1, 5);
static constexpr ModIntType W_16_6 = qpow(W_16_1, 6);
static constexpr ModIntType W_16_7 = qpow(W_16_1, 7);
// Four independent two-point butterflies on the adjacent lane pairs of one
// register: every pair (a, b) becomes (a + b, a - b), lazily reduced.
// The sum has 2 * MOD subtracted up front so that smallNorm2() can fix both
// the sum and the difference with the same "add 2 * MOD if negative" step.
static ModIntX8 transform2X4(ModIntX8 in)
{
ModIntX8 temp1 = in.template rShift64<32>(); // b, 0
ModIntX8 temp2 = in.template lShift64<32>(); // 0, a
temp1 = in.rawSub(ModIntX8::mod2X8()).rawAdd(temp1); // a + b ,X
temp2 = temp2.rawSub(in); // X, a - b
return ModIntX8::template blend<0b10101010>(temp1, temp2).smallNorm2();
}
// Decimation-in-time length-4 transforms on the four groups of four lanes
// held in A and B (two groups per register).
static void dit4X4(ModIntX8 &A, ModIntX8 &B)
{
constexpr ModIntType W_4_1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 4);
alignas(32) constexpr ModIntType w_arr[8]{ModIntType(1), W_4_1, ModIntType(1), W_4_1, ModIntType(1), W_4_1, ModIntType(1), W_4_1};
ModIntX8 temp0, temp1, temp2, temp3, omega;
temp0 = transform2X4(A); // A
temp1 = transform2X4(B); // B
// Regroup so that the first halves of all groups sit in temp0 and the second halves in temp1.
temp2 = temp0.template rShiftByte128<8>(); // A2,A3,X,X
temp3 = temp1.template lShiftByte128<8>(); // X,X,B0,B1
temp0 = ModIntX8::template blend<0b11001100>(temp0, temp3); // A0,A1,B0,B1
temp1 = ModIntX8::template blend<0b11001100>(temp2, temp1); // A2,A3,B2,B3
omega.load(w_arr);
temp1 = temp1 * omega; // (A2,A3,B2,B3)*w
temp2 = temp0 + temp1; // A0,A1,B0,B1
temp3 = temp0 - temp1; // A2,A3,B2,B3
// Undo the regrouping.
temp0 = temp2.template rShiftByte128<8>(); // B0,B1,X,X
temp1 = temp3.template lShiftByte128<8>(); // X,X,A2,A3
A = ModIntX8::template blend<0b11001100>(temp2, temp1); // A0,A1,A2,A3
B = ModIntX8::template blend<0b11001100>(temp0, temp3); // B0,B1,B2,B3
}
// Decimation-in-time length-8 transforms, one on A and one on B.
static void dit8X2(ModIntX8 &A, ModIntX8 &B)
{
constexpr ModIntType W_8_1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
constexpr ModIntType W_8_2 = qpow(W_8_1, 2);
constexpr ModIntType W_8_3 = qpow(W_8_1, 3);
alignas(32) constexpr ModIntType w_arr[8]{ModIntType(1), W_8_1, W_8_2, W_8_3, ModIntType(1), W_8_1, W_8_2, W_8_3};
dit4X4(A, B);
ModIntX8 temp0, temp1, temp2, temp3, omega;
// Gather the low 128-bit halves of A and B into temp0 and the high halves into temp1.
temp0 = ModIntX8::template permute2X128<0x20>(A, B); // A0,B0
temp1 = ModIntX8::template permute2X128<0x31>(A, B); // A1,B1
omega.load(w_arr);
temp1 = temp1 * omega;
temp2 = temp0 + temp1; // A0,B0
temp3 = temp0 - temp1; // A1,B1
A = ModIntX8::template permute2X128<0x20>(temp2, temp3); // A0,A1
B = ModIntX8::template permute2X128<0x31>(temp2, temp3); // B0,B1
}
// Decimation-in-time length-16 transform on registers:
// temp0 holds elements 0..7, temp1 holds elements 8..15.
static void dit(ModIntX8 &temp0, ModIntX8 &temp1)
{
alignas(32) constexpr ModIntType w_arr[8]{ModIntType(1), W_16_1, W_16_2, W_16_3, W_16_4, W_16_5, W_16_6, W_16_7};
ModIntX8 omega;
omega.load(w_arr);
dit8X2(temp0, temp1);
temp1 = temp1 * omega;
transform2(temp0, temp1);
}
// Decimation-in-time length-16 transform in memory; in_out must be 32-byte
// aligned because aligned loads and stores are used.
static void dit(ModIntType in_out[])
{
ModIntX8 temp0, temp1;
temp0.load(&in_out[0]), temp1.load(&in_out[8]);
dit(temp0, temp1);
temp0.store(&in_out[0]);
temp1.store(&in_out[8]);
}
// Decimation-in-frequency counterpart of dit4X4: the same steps in reverse order.
static void dif4X4(ModIntX8 &A, ModIntX8 &B)
{
constexpr ModIntType W_4_1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 4);
alignas(32) constexpr ModIntType w_arr[8]{ModIntType(1), W_4_1, ModIntType(1), W_4_1, ModIntType(1), W_4_1, ModIntType(1), W_4_1};
ModIntX8 temp0, temp1, temp2, temp3, omega;
temp2 = A.template rShiftByte128<8>(); // A2,A3,X,X
temp3 = B.template lShiftByte128<8>(); // X,X,B0,B1
temp0 = ModIntX8::template blend<0b11001100>(A, temp3); // A0,A1,B0,B1
temp1 = ModIntX8::template blend<0b11001100>(temp2, B); // A2,A3,B2,B3
temp2 = temp0 + temp1; // A0,A1,B0,B1
temp3 = temp0 - temp1; // A2,A3,B2,B3
omega.load(w_arr);
temp3 = temp3 * omega; // (A2,A3,B2,B3)*w
temp0 = temp2.template rShiftByte128<8>(); // B0,B1,X,X
temp1 = temp3.template lShiftByte128<8>(); // X,X,A2,A3
temp2 = ModIntX8::template blend<0b11001100>(temp2, temp1); // A0,A1,A2,A3
temp3 = ModIntX8::template blend<0b11001100>(temp0, temp3); // B0,B1,B2,B3
A = transform2X4(temp2); // A
B = transform2X4(temp3); // B
}
// Decimation-in-frequency counterpart of dit8X2.
static void dif8X2(ModIntX8 &A, ModIntX8 &B)
{
constexpr ModIntType W_8_1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
constexpr ModIntType W_8_2 = qpow(W_8_1, 2);
constexpr ModIntType W_8_3 = qpow(W_8_1, 3);
alignas(32) constexpr ModIntType w_arr[8]{ModIntType(1), W_8_1, W_8_2, W_8_3, ModIntType(1), W_8_1, W_8_2, W_8_3};
ModIntX8 temp0, temp1, temp2, temp3, omega;
temp0 = ModIntX8::template permute2X128<0x20>(A, B); // A0,B0
temp1 = ModIntX8::template permute2X128<0x31>(A, B); // A1,B1
temp2 = temp0 + temp1; // A0,B0
temp3 = temp0 - temp1; // A1,B1
omega.load(w_arr);
temp3 = temp3 * omega;
A = ModIntX8::template permute2X128<0x20>(temp2, temp3); // A0,A1
B = ModIntX8::template permute2X128<0x31>(temp2, temp3); // B0,B1
dif4X4(A, B);
}
// Decimation-in-frequency length-16 transform on registers.
static void dif(ModIntX8 &temp0, ModIntX8 &temp1)
{
alignas(32) constexpr ModIntType w_arr[8]{ModIntType(1), W_16_1, W_16_2, W_16_3, W_16_4, W_16_5, W_16_6, W_16_7};
ModIntX8 omega;
omega.load(w_arr);
transform2(temp0, temp1);
temp1 = temp1 * omega;
dif8X2(temp0, temp1);
}
// Decimation-in-frequency length-16 transform in memory; in_out must be
// 32-byte aligned because aligned loads and stores are used.
static void dif(ModIntType in_out[])
{
ModIntX8 temp0, temp1;
temp0.load(&in_out[0]), temp1.load(&in_out[8]);
dif(temp0, temp1);
temp0.store(&in_out[0]);
temp1.store(&in_out[8]);
}
// Run-time dispatch: shorter lengths are forwarded to the length-8 case.
static void dit(ModIntType in_out[], size_t len)
{
if (len < 16)
{
NTTShort<8, MOD, ROOT>::dit(in_out, len);
return;
}
dit(in_out);
}
static void dif(ModIntType in_out[], size_t len)
{
if (len < 16)
{
NTTShort<8, MOD, ROOT>::dif(in_out, len);
return;
}
dif(in_out);
}
};
template <uint32_t MOD, uint32_t ROOT>
struct NTTShort<32, MOD, ROOT>
{
using ModIntX8 = MontInt32X8<MOD>;
using ModIntType = typename ModIntX8::MontInt;
using NTT16 = NTTShort<16, MOD, ROOT>;
using TableType = AlignAry<ModIntType, 8>;
static constexpr TableType getNTTTable(int factor)
{
ModIntType root = qpow(ModIntType(ROOT), ((ModIntType::mod() - 1) / 32) * factor);
ModIntType omega(1);
TableType res;
for (auto &&i : res)
{
i = omega;
omega *= root;
}
return res;
}
static void dit(ModIntType in_out[])
{
constexpr TableType w1_arr = getNTTTable(1);
constexpr TableType w3_arr = getNTTTable(3);
ModIntX8 temp0, temp1, temp2, temp3, omega1, omega3;
omega1.load(w1_arr.data());
omega3.load(w3_arr.data());
temp0.load(&in_out[0]);
temp1.load(&in_out[8]);
temp2.load(&in_out[16]);
temp3.load(&in_out[24]);
NTT16::dit(temp0, temp1);
NTT16::dit8X2(temp2, temp3);
{
temp2 = temp2 * omega1;
temp3 = temp3 * omega3;
transform2(temp2, temp3);
constexpr ModIntType W_4_1 = qpow(ModIntType(ROOT), (MOD - 1) / 4);
temp3 = temp3 * ModIntX8(W_4_1);
transform2(temp0, temp2);
transform2(temp1, temp3);
}
temp0.store(&in_out[0]);
temp1.store(&in_out[8]);
temp2.store(&in_out[16]);
temp3.store(&in_out[24]);
}
static void dif(ModIntType in_out[])
{
constexpr TableType w1_arr = getNTTTable(1);
constexpr TableType w3_arr = getNTTTable(3);
ModIntX8 temp0, temp1, temp2, temp3, omega1, omega3;
omega1.load(w1_arr.data());
omega3.load(w3_arr.data());
temp0.load(&in_out[0]);
temp1.load(&in_out[8]);
temp2.load(&in_out[16]);
temp3.load(&in_out[24]);
{
transform2(temp0, temp2);
transform2(temp1, temp3);
constexpr ModIntType W_4_1 = qpow(ModIntType(ROOT), (MOD - 1) / 4);
temp3 = temp3 * ModIntX8(W_4_1);
transform2(temp2, temp3);
temp2 = temp2 * omega1;
temp3 = temp3 * omega3;
}
NTT16::dif(temp0, temp1);
NTT16::dif8X2(temp2, temp3);
temp0.store(&in_out[0]);
temp1.store(&in_out[8]);
temp2.store(&in_out[16]);
temp3.store(&in_out[24]);
}
static void dit(ModIntType in_out[], size_t len)
{
if (len < 32)
{
NTT16::dit(in_out, len);
return;
}
dit(in_out);
}
static void dif(ModIntType in_out[], size_t len)
{
if (len < 32)
{
NTT16::dif(in_out, len);
return;
}
dif(in_out);
}
};
// Run-time-length number-theoretic transform modulo MOD built on ROOT.
// Lengths up to LONG_THRESHOLD are dispatched to the NTTShort templates;
// longer lengths are split recursively into one half-length, one
// quarter-length and two eighth-length transforms joined by a vectorised
// combining pass. All buffers are processed in place and accessed with
// aligned 256-bit loads/stores, so they must be 32-byte aligned.
// NOTE(review): ROOT is presumably a primitive root of the prime MOD; only
// ROOT < MOD and the existence of its inverse are checked here.
template <uint32_t MOD, uint32_t ROOT>
struct NTT
{
static constexpr uint32_t mod()
{
return MOD;
}
static constexpr uint32_t root()
{
return ROOT;
}
// Modular inverse of root() modulo mod(); used as the root of the inverse transform.
static constexpr uint32_t iroot()
{
return mod_inv<int64_t>(root(), mod());
}
// Compile-time sanity check: root() * iroot() must be 1 modulo mod().
static constexpr bool selfCheck()
{
uint64_t n = root();
n *= uint64_t(iroot());
n %= uint64_t(mod());
return n == uint64_t(1);
}
static_assert(root() < mod(), "ROOT must be smaller than MOD");
static_assert(selfCheck(), "IROOT * ROOT % MOD must be 1");
// Number of significant bits of the modulus.
static constexpr int mod_bits = hint_log2(mod()) + 1;
// Number of trailing zero bits of mod() - 1, i.e. log2 of the longest supported transform.
static constexpr int max_log_len = hint_ctz(mod() - 1);
// Longest supported transform length, 2^max_log_len, capped so that the
// shift stays inside the width of size_t.
static constexpr size_t getMaxLen()
{
if (max_log_len < sizeof(size_t) * CHAR_BIT)
{
return size_t(1) << max_log_len;
}
return size_t(1) << (sizeof(size_t) * CHAR_BIT - 1);
}
static constexpr size_t ntt_max_len = getMaxLen();
// The same transform built on the inverse root.
using INTT = NTT<mod(), iroot()>;
// Lengths up to this value are handled by the compile-time NTTShort templates.
static constexpr size_t LONG_THRESHOLD = size_t(1) << 12;
using NTTTemplate = NTTShort<LONG_THRESHOLD, MOD, ROOT>;
using ModIntType = typename NTTTemplate::ModIntType;
using ModIntX8 = typename NTTTemplate::ModIntX8;
// Fixed twiddles: root^((mod-1)/4), root^((mod-1)/8) and the cube of the latter.
static constexpr ModIntType W_4_1 = qpow(ModIntType(root()), (mod() - 1) / 4);
static constexpr ModIntType W_8_1 = qpow(ModIntType(root()), (mod() - 1) / 8);
static constexpr ModIntType W_8_3 = qpow(W_8_1, 3);
// Broadcast of w^8 for w = root^((mod-1)/ntt_len * factor): multiplying a
// register of eight consecutive powers of w by it advances all lanes by eight positions.
static ModIntX8 unitx8(size_t ntt_len, int factor)
{
return ModIntX8(qpow(ModIntType(root()), (mod() - 1) / ntt_len * factor * 8));
}
// Register holding w^0 .. w^7 (lane k = w^k) for w = root^((mod-1)/ntt_len * factor).
static ModIntX8 omegax8(size_t ntt_len, int factor)
{
alignas(32) ModIntType w_arr[8]{};
ModIntType w(1), unit(qpow(ModIntType(root()), (mod() - 1) / ntt_len * factor));
for (auto &&i : w_arr)
{
i = w;
w = w * unit;
}
return ModIntX8(w_arr);
}
// Decimation-in-time transform of in_out, in place.
// ntt_len is rounded down to a power of two and clamped to ntt_max_len.
static void dit(ModIntType in_out[], size_t ntt_len)
{
ntt_len = std::min(int_floor2(ntt_len), ntt_max_len);
if (ntt_len <= LONG_THRESHOLD)
{
NTTTemplate::dit(in_out, ntt_len);
return;
}
size_t octant_len = ntt_len / 8;
// Sub-transforms: the last two eighths, the quarter starting at 4/8, then the first half.
dit(in_out + octant_len * 7, ntt_len / 8);
dit(in_out + octant_len * 6, ntt_len / 8);
dit(in_out + octant_len * 4, ntt_len / 4);
dit(in_out, ntt_len / 2);
// Twiddle registers for factors 1, 3 and 7 together with their per-iteration step (w^8).
const ModIntX8 unit1_x8 = unitx8(ntt_len, 1), unit3_x8 = unitx8(ntt_len, 3), unit7_x8 = unitx8(ntt_len, 7);
ModIntX8 omega1 = omegax8(ntt_len, 1), omega3 = omegax8(ntt_len, 3), omega7 = omegax8(ntt_len, 7);
// Combining pass: each iteration handles eight consecutive positions of every octant.
for (auto it = in_out; it < in_out + octant_len; it += 8)
{
{
ModIntX8 temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
temp0.load(&it[0]);
temp1.load(&it[octant_len]);
temp2.load(&it[octant_len * 2]);
temp3.load(&it[octant_len * 3]);
temp4.load(&it[octant_len * 4]);
temp5.load(&it[octant_len * 5]);
temp6.load(&it[octant_len * 6]);
temp7.load(&it[octant_len * 7]);
// Octants 4 and 5 both take omega1 here; octant 5 receives its extra W_8_1 further down.
temp4 = temp4 * omega1;
temp5 = temp5 * omega1;
temp6 = temp6 * omega3;
temp7 = temp7 * omega7;
transform2(temp6, temp7);
transform2(temp4, temp6);
temp6 = temp6 * ModIntX8(W_4_1);
temp7 = temp7 * ModIntX8(W_4_1);
transform2(temp5, temp7);
temp5 = temp5 * ModIntX8(W_8_1);
temp7 = temp7 * ModIntX8(W_8_3);
(temp0 + temp4).store(&it[0]);
(temp1 + temp5).store(&it[octant_len]);
(temp2 + temp6).store(&it[octant_len * 2]);
(temp3 + temp7).store(&it[octant_len * 3]);
(temp0 - temp4).store(&it[octant_len * 4]);
(temp1 - temp5).store(&it[octant_len * 5]);
(temp2 - temp6).store(&it[octant_len * 6]);
(temp3 - temp7).store(&it[octant_len * 7]);
}
// Advance every twiddle register by eight positions.
omega1 = omega1 * unit1_x8;
omega3 = omega3 * unit3_x8;
omega7 = omega7 * unit7_x8;
}
}
// Decimation-in-frequency transform of in_out, in place: the steps of dit()
// in reverse order (splitting pass first, then the sub-transforms).
// ntt_len is rounded down to a power of two and clamped to ntt_max_len.
static void dif(ModIntType in_out[], size_t ntt_len)
{
ntt_len = std::min(int_floor2(ntt_len), ntt_max_len);
if (ntt_len <= LONG_THRESHOLD)
{
NTTTemplate::dif(in_out, ntt_len);
return;
}
size_t octant_len = ntt_len / 8;
const ModIntX8 unit1_x8 = unitx8(ntt_len, 1), unit3_x8 = unitx8(ntt_len, 3), unit7_x8 = unitx8(ntt_len, 7);
ModIntX8 omega1 = omegax8(ntt_len, 1), omega3 = omegax8(ntt_len, 3), omega7 = omegax8(ntt_len, 7);
for (auto it = in_out; it < in_out + octant_len; it += 8)
{
{
ModIntX8 temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
temp0.load(&it[0]);
temp1.load(&it[octant_len]);
temp2.load(&it[octant_len * 2]);
temp3.load(&it[octant_len * 3]);
temp4.load(&it[octant_len * 4]);
temp5.load(&it[octant_len * 5]);
temp6.load(&it[octant_len * 6]);
temp7.load(&it[octant_len * 7]);
transform2(temp0, temp4);
transform2(temp1, temp5);
transform2(temp2, temp6);
transform2(temp3, temp7);
temp5 = temp5 * ModIntX8(W_8_1);
temp7 = temp7 * ModIntX8(W_8_3);
transform2(temp5, temp7);
temp6 = temp6 * ModIntX8(W_4_1);
temp7 = temp7 * ModIntX8(W_4_1);
transform2(temp4, temp6);
transform2(temp6, temp7);
(temp0).store(&it[0]);
(temp1).store(&it[octant_len]);
(temp2).store(&it[octant_len * 2]);
(temp3).store(&it[octant_len * 3]);
(temp4 * omega1).store(&it[octant_len * 4]);
(temp5 * omega1).store(&it[octant_len * 5]);
(temp6 * omega3).store(&it[octant_len * 6]);
(temp7 * omega7).store(&it[octant_len * 7]);
}
// Advance every twiddle register by eight positions.
omega1 = omega1 * unit1_x8;
omega3 = omega3 * unit3_x8;
omega7 = omega7 * unit7_x8;
}
// Sub-transforms: first half, the quarter starting at 4/8, then the last two eighths.
dif(in_out, octant_len * 4);
dif(in_out + octant_len * 4, octant_len * 2);
dif(in_out + octant_len * 6, octant_len);
dif(in_out + octant_len * 7, octant_len);
}
// Cyclic convolution of in1 and in2 of length ntt_len, written to out.
// Both inputs are transformed in place (and therefore overwritten), multiplied
// point-wise together with the scaling factor ntt_len^(mod-2), and the product
// is transformed back with INTT::dit. out may alias one of the inputs.
// NOTE(review): ntt_len^(mod-2) is the inverse of ntt_len only for a prime
// modulus, and dif()/dit() round ntt_len down to a power of two while the
// scaling factor uses it as given - callers presumably pass a power of two.
static void convolution(ModIntType in1[], ModIntType in2[], ModIntType out[], size_t ntt_len)
{
const ModIntType inv_len(qpow(ModIntType(ntt_len), mod() - 2));
dif(in1, ntt_len);
dif(in2, ntt_len);
// Fewer than eight elements cannot fill a register, so multiply them one by one.
if (ntt_len < 8)
{
for (size_t i = 0; i < ntt_len; i++)
{
out[i] = in1[i] * in2[i] * inv_len;
}
}
else
{
const ModIntX8 inv8(inv_len);
for (size_t i = 0; i < ntt_len; i += 8)
{
ModIntX8 temp0, temp1;
temp0.load(&in1[i]), temp1.load(&in2[i]);
(temp0 * temp1 * inv8).store(&out[i]);
}
}
INTT::dit(out, ntt_len);
}
};
};
}
}
}
using namespace std;
using namespace hint;
using namespace hint_transform::hint_ntt::split_radix_avx;
void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c)
{
size_t conv_len = m + n + 1, ntt_len = int_ceil2(conv_len);
using ntt = NTT<998244353, 3>;
using ModInt = ntt::ModIntType;
using ModIntX8 = ntt::ModIntX8;
ModInt *a_ntt = (ModInt *)_mm_malloc(ntt_len * sizeof(ModInt), 32);
ModInt *b_ntt = (ModInt *)_mm_malloc(ntt_len * sizeof(ModInt), 32);
std::fill(a_ntt + n + 1, a_ntt + ntt_len, ModInt{});
std::fill(b_ntt + m + 1, b_ntt + ntt_len, ModInt{});
std::copy(a, a + n + 1, a_ntt);
std::copy(b, b + m + 1, b_ntt);
ntt::convolution(a_ntt, b_ntt, a_ntt, ntt_len);
size_t rem_len = conv_len % 8, i = 0;
for (; i < conv_len - rem_len; i += 8)
{
ModIntX8 temp;
temp.load(&a_ntt[i]);
temp = temp.toInt();
temp.storeu(&c[i]);
}
for (; i < conv_len; i++)
{
c[i] = uint32_t(a_ntt[i]);
}
_mm_free(a_ntt);
_mm_free(b_ntt);
}
// Online-judge output residue (not part of the program): Compilation | N/A | N/A | Compile Error | Score: N/A | Show more |