#pragma GCC target("avx2")
#pragma GCC optimize("O2")
#include <algorithm>
#include <array>
#include <chrono>
#include <complex>
#include <cstdint>
#include <cstring>
#include <ctime>
#include <future>
#include <iostream>
#include <utility>
#include <vector>
#include <immintrin.h>
// Fixed-size array of LEN elements whose storage is aligned to 4096 bytes.
// Elements are accessed with operator[], and data()/begin()/end() expose raw
// pointers. const overloads of begin()/end() allow range-for over a const array.
template <typename T, size_t LEN>
class AlignAry
{
private:
    alignas(4096) T ary[LEN];

public:
    // Performs no initialisation of its own. Class-type elements are
    // default-constructed; trivially-constructible elements are left as-is.
    constexpr AlignAry() {}
    constexpr T& operator[](size_t index)
    {
        return ary[index];
    }
    constexpr const T& operator[](size_t index) const
    {
        return ary[index];
    }
    // Number of elements (always LEN).
    static constexpr size_t size()
    {
        return LEN;
    }
    T* data()
    {
        return ary;
    }
    const T* data() const
    {
        return ary;
    }
    T* begin()
    {
        return ary;
    }
    T* end()
    {
        return ary + LEN;
    }
    const T* begin() const
    {
        return ary;
    }
    const T* end() const
    {
        return ary + LEN;
    }
    // Reinterprets the storage as a pointer to Ty (caller is responsible for
    // the type-punning being valid).
    template <typename Ty>
    Ty* cast_ptr()
    {
        return reinterpret_cast<Ty*>(ary);
    }
    template <typename Ty>
    const Ty* cast_ptr() const
    {
        return reinterpret_cast<const Ty*>(ary);
    }
};
namespace hint
{
// Largest power of two that does not exceed n (like C++20 std::bit_floor).
// Returns 0 for n == 0. Intended for unsigned integer types.
template <typename T>
constexpr T int_floor2(T n)
{
    if (n == 0)
    {
        // The smear below would otherwise produce (0 >> 1) + 1 == 1.
        return 0;
    }
    constexpr int bits = sizeof(n) * 8;
    // Copy the highest set bit into every lower position ...
    for (int i = 1; i < bits; i *= 2)
    {
        n |= (n >> i);
    }
    // ... then keep only that highest bit.
    return (n >> 1) + 1;
}
// Smallest power of two that is >= n (like C++20 std::bit_ceil).
// Returns 1 for n == 0. Intended for unsigned integer types; the result wraps
// to 0 if it is not representable in T.
template <typename T>
constexpr T int_ceil2(T n)
{
    if (n == 0)
    {
        // n-- would wrap to all ones and the function would return 0.
        return 1;
    }
    constexpr int bits = sizeof(n) * 8;
    n--;
    // Set every bit below the highest set bit of n - 1.
    for (int i = 1; i < bits; i *= 2)
    {
        n |= (n >> i);
    }
    return n + 1;
}
// Value whose lowest `bits` binary digits are all 1, i.e. 2^bits - 1.
// Returns 0 for bits <= 0. bits must not exceed the bit width of T.
template <typename T>
constexpr T all_one(int bits)
{
    if (bits <= 0)
    {
        // T(1) << (bits - 1) would be a shift by a negative amount (UB).
        return 0;
    }
    // Build 2^(bits-1) first, so bits == width of T never shifts by the full width.
    T tmp = T(1) << (bits - 1);
    return tmp - 1 + tmp;
}
// Integer floor(log2(n)) for unsigned n; returns 0 when n == 0.
// Binary search: test whether anything remains above the current half-width,
// and if so shift it down and accumulate the shift amount.
template <typename UintTy>
constexpr int hint_log2(UintTy n)
{
    int result = 0;
    for (int step = 4 * int(sizeof(UintTy)); step > 0; step >>= 1)
    {
        if ((n >> step) != 0)
        {
            result += step;
            n >>= step;
        }
    }
    return result;
}
// Returns n^-1 mod 2^pow using the Newton iteration x <- x * (2 - n * x).
// Each step doubles the number of correct low bits.
// An inverse exists only for odd n. For even n (or pow <= 0) this returns 0
// instead of looping forever, so callers' consistency checks fail cleanly.
// pow must be at most 64.
constexpr uint64_t inv_mod2pow(uint64_t n, int pow)
{
    if (pow <= 0 || (n & 1) == 0)
    {
        return 0;
    }
    // Mask of the low `pow` bits (avoid shifting by 64 when pow == 64).
    const uint64_t mask = pow >= 64 ? ~uint64_t(0) : (uint64_t(1) << pow) - 1;
    uint64_t xn = 1, t = n & mask;
    while (t != 1)
    {
        xn = xn * (2 - t);
        t = (xn * n) & mask;
    }
    return xn & mask;
}
// Exponentiation by squaring: returns m^n using only T's operator*.
// T must be constructible from the literal 1 (the multiplicative identity).
template <typename T>
constexpr T qpow(T m, uint32_t n)
{
    T acc = 1;
    for (; n != 0; n /= 2)
    {
        if (n % 2 == 1)
        {
            acc = acc * m;
        }
        m = m * m;
    }
    return acc;
}
// Element-wise product: out[i] = in1[i] * in2[i] for i in [0, len).
// The main loop is unrolled by four; the final len % 4 elements are handled
// one at a time. out may be the same array as in1 or in2, because each
// element is read before it is written.
template <typename T>
inline void ary_mul(const T in1[], const T in2[], T out[], size_t len)
{
    const size_t tail = len % 4;
    const size_t body_end = len - tail;
    size_t idx = 0;
    while (idx < body_end)
    {
        out[idx + 0] = in1[idx + 0] * in2[idx + 0];
        out[idx + 1] = in1[idx + 1] * in2[idx + 1];
        out[idx + 2] = in1[idx + 2] * in2[idx + 2];
        out[idx + 3] = in1[idx + 3] * in2[idx + 3];
        idx += 4;
    }
    while (idx < len)
    {
        out[idx] = in1[idx] * in2[idx];
        ++idx;
    }
}
// FFT与类FFT变换的命名空间
namespace hint_transform
{
// In-place length-2 butterfly: (sum, diff) <- (sum + diff, sum - diff).
// Both inputs are copied first, so passing the same object twice is well defined.
template <typename T>
inline void transform2(T& sum, T& diff)
{
    const T first = sum;
    const T second = diff;
    sum = first + second;
    diff = first - second;
}
// Point on the unit circle with argument theta: cos(theta) + i*sin(theta),
// returned as std::complex<FloatTy>.
template <typename FloatTy>
inline auto unit_root(FloatTy theta)
{
    const FloatTy radius = 1.0;
    return std::polar(radius, theta);
}
// Bit-reversal permutation of [begin, end), in place. The element at index i
// is exchanged with the one at index rev(i), where rev reverses the low
// log2(len) bits of i.
// It must be a random-access iterator, and len = end - begin must be a power
// of two. Ranges of length 0 or 1 are returned unchanged.
template <typename It>
void binary_reverse_swap(It begin, It end)
{
    const size_t len = end - begin;
    // Nothing to permute. Returning here also avoids forming begin + 1 and
    // end - 1 outside the range when len == 0.
    if (len < 2)
    {
        return;
    }
    // Swap only when the left iterator precedes the right one, so each pair
    // is exchanged exactly once.
    auto smaller_swap = [=](It it_left, It it_right)
    {
        if (it_left < it_right)
        {
            std::swap(it_left[0], it_right[0]);
        }
    };
    // Given the iterator at rev(i), return the iterator at rev(i + 1)
    // (a "reversed increment" that starts from the top bit).
    auto get_next_bitrev = [=](It last)
    {
        size_t k = len / 2, indx = last - begin;
        indx ^= k;
        while (k > indx)
        {
            k >>= 1;
            indx ^= k;
        }
        return begin + indx;
    };
    // Short ranges: walk i = 1 .. len-2 while tracking j = rev(i).
    if (len <= 16)
    {
        for (auto i = begin + 1, j = begin + len / 2; i < end - 1; i++)
        {
            smaller_swap(i, j);
            j = get_next_bitrev(j);
        }
        return;
    }
    // Longer ranges: for each k in [1, len/8), the eight indices
    // k + {0, len/2, len/4, 3len/4} and the same four plus len/8 have the
    // reversals rev(k) + 0..7. So one reversed increment serves the whole group.
    // The k = 0 group needs no work of its own: its partners other than itself
    // are indices 1..7, which later groups visit.
    const size_t len_8 = len / 8;
    const auto last = begin + len_8;
    auto i0 = begin + 1, i1 = i0 + len / 2, i2 = i0 + len / 4, i3 = i1 + len / 4;
    for (auto j = begin + len / 2; i0 < last; i0++, i1++, i2++, i3++)
    {
        smaller_swap(i0, j);
        smaller_swap(i1, j + 1);
        smaller_swap(i2, j + 2);
        smaller_swap(i3, j + 3);
        smaller_swap(i0 + len_8, j + 4);
        smaller_swap(i1 + len_8, j + 5);
        smaller_swap(i2 + len_8, j + 6);
        smaller_swap(i3 + len_8, j + 7);
        j = get_next_bitrev(j);
    }
}
// Convenience overload: bit-reversal permutation of the len elements starting
// at ary (len must be a power of two). Forwards to the iterator-range overload.
template <typename T>
void binary_reverse_swap(T ary, const size_t len)
{
    const T last = ary + len;
    binary_reverse_swap(ary, last);
}
namespace hint_ntt
{
// Plain (non-Montgomery) residue modulo MOD, stored in `data`.
// The operators assume both operands are already reduced into [0, MOD).
template <uint32_t MOD>
struct ModInt32
{
    uint32_t data;
    // Zero-initialise. Previously `data` was left indeterminate, which also
    // kept this constexpr constructor from initialising its member.
    constexpr ModInt32() : data(0) {}
    constexpr ModInt32(uint32_t in) : data(in) {}
    // Maps a lazily reduced value in [0, 2*MOD) into [0, MOD).
    constexpr ModInt32 largeNorm() const
    {
        return data < MOD ? data : data - MOD;
    }
    // Full 64-bit product of the two stored values (no reduction).
    constexpr uint64_t mul64(ModInt32 in) const
    {
        return uint64_t(data) * uint64_t(in.data);
    }
    // Shoup precomputation for multiplying by *this: floor(data * 2^32 / MOD).
    constexpr ModInt32 getW1() const
    {
        return (uint64_t(data) << 32) / MOD;
    }
    // Shoup multiplication data * w mod MOD, where w1 == w.getW1().
    // Returns a lazily reduced value in [0, 2*MOD); call largeNorm() for
    // [0, MOD). Needs MOD <= 2^31 so the result fits in 32 bits.
    constexpr ModInt32 mulModShoup(ModInt32 w, ModInt32 w1) const
    {
        uint64_t q = (uint64_t(data) * uint32_t(w1.data)) >> 32;
        ModInt32 res = uint64_t(data) * w.data - q * MOD;
        // return res.largeNorm();
        return res;
    }
    constexpr ModInt32 operator+(ModInt32 in) const
    {
        // Compare against MOD - in.data instead of forming data + in.data,
        // which could overflow 32 bits for large MOD.
        uint32_t diff = MOD - in.data;
        return data < diff ? data + in.data : data - diff;
    }
    constexpr ModInt32 operator-(ModInt32 in) const
    {
        // Unsigned wrap-around signals a negative difference; add MOD back.
        in.data = data - in.data;
        return in.data > data ? in.data + MOD : in.data;
    }
    constexpr ModInt32 operator*(ModInt32 in) const
    {
        return mul64(in) % MOD;
    }
    constexpr ModInt32& operator+=(ModInt32 in)
    {
        return *this = *this + in;
    }
    constexpr ModInt32& operator-=(ModInt32 in)
    {
        return *this = *this - in;
    }
    constexpr ModInt32& operator*=(ModInt32 in)
    {
        return *this = *this * in;
    }
    constexpr operator uint32_t() const
    {
        return data;
    }
    static constexpr uint32_t mod()
    {
        return MOD;
    }
};
// Residue modulo an odd MOD, kept in Montgomery form: `data` holds
// n * 2^32 mod MOD. Constructing from uint32_t converts into Montgomery form
// (toMont); converting to uint32_t converts back (toInt / redc).
// Arithmetic keeps `data` fully reduced into [0, MOD).
template <uint32_t MOD>
struct MontInt32
{
    // Checked first: Montgomery reduction needs MOD to be invertible mod 2^32.
    static_assert(MOD % 2 == 1, "Montgomery32 modulus must be odd");
    static constexpr uint64_t R = uint64_t(1) << 32;
    static constexpr uint32_t R_MASK = R - 1;
    // MOD^-1 mod 2^32 and its negation, both used by redc.
    static constexpr uint32_t MOD_INV = inv_mod2pow(MOD, 32);
    static constexpr uint32_t MOD_INV_NEG = R - MOD_INV;
    static constexpr uint32_t MOD2 = MOD * 2;
    // hint_log2(MOD) <= 30 means MOD < 2^31, so 2 * MOD still fits in 32 bits
    // (MOD2, addMont, and redc's intermediate range [0, 2*MOD)).
    static_assert(hint_log2(MOD) <= 30, "MOD can't be larger than 31 bits");
    static_assert(uint32_t(MOD_INV * MOD) == 1, "Montgomery32 modulus is not correct");
    uint32_t data;
    constexpr MontInt32() : data(0) {}
    constexpr MontInt32(uint32_t n) : data(toMont(n)) {}
    // n -> n * 2^32 mod MOD.
    static constexpr uint32_t toMont(uint32_t n)
    {
        return (uint64_t(n) << 32) % MOD;
    }
    // Montgomery reduction. For input < MOD * 2^32, returns
    // input * 2^-32 mod MOD, fully reduced into [0, MOD).
    static constexpr uint32_t redc(uint64_t input)
    {
        // The 32-bit product wraps mod 2^32 by design: n = -input * MOD^-1 mod 2^32.
        uint64_t n = uint32_t(input) * MOD_INV_NEG;
        // input + n * MOD is divisible by 2^32.
        n = n * MOD + input;
        n >>= 32;
        return n < MOD ? n : n - MOD;
    }
    // Montgomery form -> plain residue.
    static constexpr uint32_t toInt(uint32_t n)
    {
        return redc(n);
    }
    static constexpr uint32_t addMont(uint32_t m, uint32_t n)
    {
        n = m + n;
        return n < MOD ? n : n - MOD;
    }
    static constexpr uint32_t subMont(uint32_t m, uint32_t n)
    {
        // Unsigned wrap-around signals a negative difference; add MOD back.
        n = m - n;
        return n > m ? n + MOD : n;
    }
    static constexpr uint32_t mulMont(uint32_t m, uint32_t n)
    {
        return redc(uint64_t(m) * n);
    }
    constexpr void fromInt(uint32_t n)
    {
        data = toMont(n);
    }
    constexpr uint32_t toInt() const
    {
        return toInt(data);
    }
    constexpr operator uint32_t() const
    {
        return toInt();
    }
    constexpr MontInt32 operator+(MontInt32 rhs) const
    {
        rhs.data = addMont(data, rhs.data);
        return rhs;
    }
    constexpr MontInt32 operator-(MontInt32 rhs) const
    {
        rhs.data = subMont(data, rhs.data);
        return rhs;
    }
    constexpr MontInt32 operator*(MontInt32 rhs) const
    {
        rhs.data = mulMont(data, rhs.data);
        return rhs;
    }
    constexpr MontInt32& operator+=(const MontInt32& rhs)
    {
        data = addMont(data, rhs.data);
        return *this;
    }
    constexpr MontInt32& operator-=(const MontInt32& rhs)
    {
        data = subMont(data, rhs.data);
        return *this;
    }
    constexpr MontInt32& operator*=(const MontInt32& rhs)
    {
        data = mulMont(data, rhs.data);
        return *this;
    }
    static constexpr uint32_t mod()
    {
        return MOD;
    }
};
// Eight MontInt32<MOD> values packed into one __m256i (lane i = 32-bit element i).
// Lanes hold Montgomery-form values exactly like MontInt32::data. The
// arithmetic operators keep every lane fully reduced into [0, MOD).
template <uint32_t MOD>
struct MontInt32X8
{
    using MontInt = MontInt32<MOD>;
    // largeNorm / smallNorm / smallNorm2 use the *signed* compare
    // _mm256_cmpgt_epi32. Lane values reach up to 2*MOD - 1 (or down to
    // -(2*MOD - 1)), so 2*MOD must not exceed 2^31. MontInt32 on its own
    // accepts moduli up to 2^31, hence this stricter requirement here.
    static_assert(MOD < (uint32_t(1) << 30), "MontInt32X8 requires MOD < 2^30 (signed AVX2 compares)");
    __m256i data;
    MontInt32X8() { data = _mm256_setzero_si256(); }
    // Broadcast x's Montgomery representation to all lanes.
    MontInt32X8(MontInt x) { data = _mm256_set1_epi32(x.data); }
    // Lane i receives xi verbatim (no Montgomery conversion).
    MontInt32X8(int32_t x0, int32_t x1, int32_t x2, int32_t x3, int32_t x4, int32_t x5, int32_t x6, int32_t x7)
    {
        data = _mm256_set_epi32(x7, x6, x5, x4, x3, x2, x1, x0);
    }
    MontInt32X8(__m256i rhs) : data(rhs) {}
    // Unaligned load of eight 32-bit values starting at p.
    template <typename T>
    MontInt32X8(const T* p)
    {
        loadu(p);
    }
    static constexpr uint32_t mod()
    {
        return MOD;
    }
    static MontInt32X8 zeroX8()
    {
        return _mm256_setzero_si256();
    }
    static MontInt32X8 modX8()
    {
        return _mm256_set1_epi32(MOD);
    }
    static MontInt32X8 mod1X8()
    {
        return _mm256_set1_epi32(MOD - 1);
    }
    static MontInt32X8 mod2X8()
    {
        return _mm256_set1_epi32(MOD * 2);
    }
    static MontInt32X8 modNX8()
    {
        return _mm256_set1_epi32(MontInt::MOD_INV_NEG);
    }
    // 32x32 -> 64-bit products of the even lanes (0, 2, 4, 6), one per 64-bit slot.
    MontInt32X8 mul64(MontInt32X8 rhs) const
    {
        return _mm256_mul_epu32(data, rhs.data);
    }
    MontInt32X8 lShift64(int n) const
    {
        return _mm256_slli_epi64(data, n);
    }
    MontInt32X8 rShift64(int n) const
    {
        return _mm256_srli_epi64(data, n);
    }
    // Keeps lanes 0, 2, 4, 6 and zeroes lanes 1, 3, 5, 7.
    MontInt32X8 evenElements() const
    {
        return blend<0b10101010>(data, zeroX8());
    }
    // Keeps lanes 1, 3, 5, 7 and zeroes lanes 0, 2, 4, 6.
    MontInt32X8 oddElements() const
    {
        return blend<0b01010101>(data, zeroX8());
    }
    // {products of even lanes, products of odd lanes}, each as four 64-bit values.
    std::pair<MontInt32X8, MontInt32X8> mul64hl(MontInt32X8 rhs) const
    {
        return std::make_pair(mul64(rhs), rShift64(32).mul64(rhs.rShift64(32)));
    }
    // NOTE(review): disabled draft below; it refers to `mod`, which is not
    // declared in this struct (MOD / mod() would be needed).
    // MontInt32X8 getW1() const
    // {
    // alignas(32) uint32_t temp[8];
    // store(temp);
    // for (auto &&i : temp)
    // {
    // i = (uint64_t(i) << 32) / mod;
    // }
    // return MontInt32X8(temp);
    // }
    // MontInt32X8 mulModShoup(MontInt32X8 w, MontInt32X8 w1) const
    // {
    // MontInt32X8 q0, q1, t0, t1;
    // std::tie(q0, q1) = mul64hl(w1);
    // std::tie(t0, t1) = mul64hl(w);
    // q0 = q0.rShift64(32), q1 = q1.rShift64(32);
    // q0 = q0.mul64(MontInt32X8(mod)), q1 = q1.mul64(MontInt32X8(mod));
    // t0 = t0.rawSub64(q0), t1 = t1.rawSub64(q1);
    // return t0 | t1.lShift64(32);
    // }
    // Montgomery reduction of four even-lane 64-bit values (even64) and four
    // odd-lane 64-bit values (odd64), as produced by mul64hl. Returns the eight
    // results in their original lane positions, in [0, 2*MOD) (not normalised).
    static MontInt32X8 montRedcLazy(MontInt32X8 even64, MontInt32X8 odd64)
    {
        MontInt32X8 p0 = even64.mul64(modNX8());
        MontInt32X8 p1 = odd64.mul64(modNX8());
        p0 = p0.mul64(modX8()).rawAdd64(even64).rShift64(32);
        // The odd results end up in the high 32 bits of each slot, which is
        // already the odd lane, so no shift is needed.
        p1 = p1.mul64(modX8()).rawAdd64(odd64);
        return blend<0b10101010>(p0, p1);
    }
    static MontInt32X8 montRedc(MontInt32X8 even64, MontInt32X8 odd64)
    {
        return montRedcLazy(even64, odd64).largeNorm();
    }
    // Plain residues -> Montgomery form, lane by lane.
    MontInt32X8 toMont() const
    {
        alignas(32) uint32_t temp[8];
        // Unaligned store: does not rely on the compiler honouring 32-byte
        // stack alignment for this buffer.
        storeu(temp);
        for (auto&& i : temp)
        {
            i = (uint64_t(i) << 32) % MOD;
        }
        return MontInt32X8(temp);
    }
    // Montgomery form -> plain residues.
    MontInt32X8 toInt() const
    {
        MontInt32X8 e = evenElements();
        MontInt32X8 o = rShift64(32);
        return montRedc(e, o);
    }
    // Lane i comes from b if bit i of N is set, otherwise from a.
    template <int N>
    static MontInt32X8 blend(MontInt32X8 a, MontInt32X8 b)
    {
        return _mm256_blend_epi32(a.data, b.data, N);
    }
    // [0, 2*MOD) -> [0, MOD): subtract MOD from lanes greater than MOD - 1.
    MontInt32X8 largeNorm() const
    {
        MontInt32X8 sub = (*this > mod1X8()) & modX8();
        return rawSub(sub);
    }
    // (-MOD, MOD) -> [0, MOD): add MOD to negative lanes.
    MontInt32X8 smallNorm() const
    {
        MontInt32X8 add = (zeroX8() > *this) & modX8();
        return rawAdd(add);
    }
    // Adds 2*MOD to negative lanes.
    MontInt32X8 smallNorm2() const
    {
        MontInt32X8 add = (zeroX8() > *this) & mod2X8();
        return rawAdd(add);
    }
    MontInt32X8 addMont(MontInt32X8 rhs) const
    {
        return rawAdd(rhs).largeNorm();
    }
    MontInt32X8 subMont(MontInt32X8 rhs) const
    {
        return rawSub(rhs).smallNorm();
    }
    MontInt32X8 mulMont(MontInt32X8 rhs) const
    {
        auto mulhl = mul64hl(rhs);
        return montRedc(mulhl.first, mulhl.second);
    }
    MontInt32X8 operator+(MontInt32X8 rhs) const
    {
        return addMont(rhs);
    }
    MontInt32X8 operator-(MontInt32X8 rhs) const
    {
        return subMont(rhs);
    }
    MontInt32X8 operator*(MontInt32X8 rhs) const
    {
        return mulMont(rhs);
    }
    MontInt32X8 rawAdd(MontInt32X8 rhs) const
    {
        return _mm256_add_epi32(data, rhs.data);
    }
    MontInt32X8 rawSub(MontInt32X8 rhs) const
    {
        return _mm256_sub_epi32(data, rhs.data);
    }
    MontInt32X8 rawAdd64(MontInt32X8 rhs) const
    {
        return _mm256_add_epi64(data, rhs.data);
    }
    MontInt32X8 rawSub64(MontInt32X8 rhs) const
    {
        return _mm256_sub_epi64(data, rhs.data);
    }
    // Lane-wise comparisons return all-ones / all-zeros masks (signed compare).
    MontInt32X8 operator>(MontInt32X8 n) const
    {
        return _mm256_cmpgt_epi32(data, n.data);
    }
    MontInt32X8 operator<(MontInt32X8 n) const
    {
        return n > *this;
    }
    MontInt32X8 operator==(MontInt32X8 n) const
    {
        return _mm256_cmpeq_epi32(data, n.data);
    }
    MontInt32X8 operator&(MontInt32X8 n) const
    {
        return _mm256_and_si256(data, n.data);
    }
    MontInt32X8 operator|(MontInt32X8 n) const
    {
        return _mm256_or_si256(data, n.data);
    }
    MontInt32X8 operator^(MontInt32X8 n) const
    {
        return _mm256_xor_si256(data, n.data);
    }
    void set1(int32_t n)
    {
        data = _mm256_set1_epi32(n);
    }
    template <typename T>
    void loadu(const T* p)
    {
        data = _mm256_loadu_si256((const __m256i*)p);
    }
    // Aligned load: p must be 32-byte aligned.
    template <typename T>
    void load(const T* p)
    {
        data = _mm256_load_si256((const __m256i*)p);
    }
    template <typename T>
    void storeu(T* p) const
    {
        _mm256_storeu_si256((__m256i*)p, data);
    }
    // Aligned store: p must be 32-byte aligned.
    template <typename T>
    void store(T* p) const
    {
        _mm256_store_si256((__m256i*)p, data);
    }
    void printI32() const
    {
        alignas(32) int32_t v[8];
        storeu(v);
        std::cout << "[" << v[0] << "," << v[1]
                  << "," << v[2] << "," << v[3]
                  << "," << v[4] << "," << v[5]
                  << "," << v[6] << "," << v[7] << "]" << std::endl;
    }
    void printU32() const
    {
        alignas(32) uint32_t v[8];
        storeu(v);
        std::cout << "[" << v[0] << "," << v[1]
                  << "," << v[2] << "," << v[3]
                  << "," << v[4] << "," << v[5]
                  << "," << v[6] << "," << v[7] << "]" << std::endl;
    }
    void printU64() const
    {
        alignas(32) uint64_t v[4];
        storeu(v);
        std::cout << "[" << v[0] << "," << v[1]
                  << "," << v[2] << "," << v[3] << "]" << std::endl;
    }
};
namespace split_radix_avx
{
// n * ROOT^((MOD - 1) / 4). If ROOT is a primitive root of MOD, this
// multiplies by a primitive 4th root of unity. The constant is computed at
// compile time.
template <uint32_t ROOT, typename ModIntType>
inline ModIntType mul_w41(ModIntType n)
{
    constexpr auto exponent = (ModIntType::mod() - 1) / 4;
    constexpr ModIntType W_4_1 = qpow(ModIntType(ROOT), exponent);
    return n * W_4_1;
}
// n * ROOT^((MOD - 1) / 8). If ROOT is a primitive root of MOD, this
// multiplies by a primitive 8th root of unity. The constant is computed at
// compile time.
template <uint32_t ROOT, typename ModIntType>
inline ModIntType mul_w81(ModIntType n)
{
    constexpr auto exponent = (ModIntType::mod() - 1) / 8;
    constexpr ModIntType W_8_1 = qpow(ModIntType(ROOT), exponent);
    return n * W_8_1;
}
// n * ROOT^(3 * ((MOD - 1) / 8)), i.e. multiplication by w_8^3 when ROOT is a
// primitive root of MOD. The constant is computed at compile time.
template <uint32_t ROOT, typename ModIntType>
inline ModIntType mul_w83(ModIntType n)
{
    constexpr auto exponent = ((ModIntType::mod() - 1) / 8) * 3;
    constexpr ModIntType W_8_3 = qpow(ModIntType(ROOT), exponent);
    return n * W_8_3;
}
// template <uint32_t MOD, uint32_t ROOT>
// inline void dit_butterfly244(ModIntType it[], ModIntType omega1, ModIntType omega3, size_t rank)
// {
// auto temp2 = it[rank * 2] * omega1;
// auto temp3 = it[rank * 3] * omega3;
// transform2(temp2, temp3);
// temp3 = mul_w41<ROOT>(temp3);
// auto temp0 = it[0];
// auto temp1 = it[rank];
// it[0] = temp0 + temp2;
// it[rank] = temp1 + temp3;
// it[rank * 2] = temp0 - temp2;
// it[rank * 3] = temp1 - temp3;
// }
// template <uint32_t MOD, uint32_t ROOT>
// inline void dif_butterfly244(ModIntType it[], ModIntType omega1, ModIntType omega3, size_t rank)
// {
// auto temp0 = it[0];
// auto temp1 = it[rank];
// auto temp2 = it[rank * 2];
// auto temp3 = it[rank * 3];
// it[0] = temp0 + temp2;
// it[rank] = temp1 + temp3;
// temp2 = temp0 - temp2;
// temp3 = temp1 - temp3;
// temp3 = mul_w41<ROOT>(temp3);
// transform2(temp2, temp3);
// it[rank * 2] = temp2 * omega1;
// it[rank * 3] = temp3 * omega3;
// }
// Split-radix ("244") DIT butterfly on 8 lanes at once. It updates in place the
// four 8-element groups starting at it[0], it[rank], it[2*rank] and
// it[3*rank]. omega1 / omega3 carry the per-lane twiddles applied to the
// third and fourth groups.
// Uses unaligned loads/stores, like dif_butterfly244_avx2, so `it` does not
// need to be 32-byte aligned. The aligned versions fault on misaligned
// addresses.
template <uint32_t ROOT, uint32_t MOD>
inline void dit_butterfly244_avx2(MontInt32<MOD> it[], MontInt32X8<MOD> omega1, MontInt32X8<MOD> omega3, size_t rank)
{
    MontInt32X8<MOD> temp0, temp1, temp2, temp3;
    temp2.loadu(&it[rank * 2]);
    temp3.loadu(&it[rank * 3]);
    temp2 = temp2 * omega1;
    temp3 = temp3 * omega3;
    transform2(temp2, temp3);
    // w_4 = ROOT^((MOD - 1) / 4), broadcast to all lanes.
    constexpr MontInt32<MOD> W_4_1 = qpow(MontInt32<MOD>(ROOT), (MOD - 1) / 4);
    MontInt32X8<MOD> Wx8;
    Wx8.set1(W_4_1.data);
    temp3 = temp3 * Wx8;
    temp0.loadu(&it[0]);
    temp1.loadu(&it[rank]);
    (temp0 + temp2).storeu(&it[0]);
    (temp1 + temp3).storeu(&it[rank]);
    (temp0 - temp2).storeu(&it[rank * 2]);
    (temp1 - temp3).storeu(&it[rank * 3]);
}
// Split-radix ("244") DIF butterfly on 8 lanes at once: the mirror image of
// dit_butterfly244_avx2. The sums go back to groups 0 and 1 untwiddled. The
// differences are combined with w_4 and then multiplied by omega1 / omega3.
// All four groups are loaded before anything is stored; loads and stores
// are unaligned.
template <uint32_t ROOT, uint32_t MOD>
inline void dif_butterfly244_avx2(MontInt32<MOD> it[], MontInt32X8<MOD> omega1, MontInt32X8<MOD> omega3, size_t rank)
{
    using Vec = MontInt32X8<MOD>;
    constexpr MontInt32<MOD> W_4_1 = qpow(MontInt32<MOD>(ROOT), (MOD - 1) / 4);
    Vec x0, x1, x2, x3;
    x0.loadu(it);
    x1.loadu(it + rank);
    x2.loadu(it + rank * 2);
    x3.loadu(it + rank * 3);
    (x0 + x2).storeu(it);
    (x1 + x3).storeu(it + rank);
    Vec diff02 = x0 - x2;
    Vec diff13 = x1 - x3;
    Vec w4;
    w4.set1(W_4_1.data);
    diff13 = diff13 * w4;
    transform2(diff02, diff13);
    (diff02 * omega1).storeu(it + rank * 2);
    (diff13 * omega3).storeu(it + rank * 3);
}
// template <uint64_t ROOT, uint32_t MOD>
// static void dit_butterfly2488_avx(MontInt32<MOD> input[],
// MontInt32X8<MOD> omega, MontInt32X8<MOD> omega3, MontInt32X8<MOD> omega7,
// size_t rank)
// {
// MontInt32X8<MOD> temp0 = input[0];
// MontInt32X8<MOD> temp1 = input[rank];
// MontInt32X8<MOD> temp2 = input[rank * 2];
// MontInt32X8<MOD> temp3 = input[rank * 3];
// MontInt32X8<MOD> temp4 = input[rank * 4] * omega;
// MontInt32X8<MOD> temp5 = input[rank * 5] * omega;
// MontInt32X8<MOD> temp6 = input[rank * 6] * omega3;
// MontInt32X8<MOD> temp7 = input[rank * 7] * omega7;
// transform2(temp6, temp7);
// transform2(temp4, temp6);
// temp6 = mul_w41<ROOT>(temp6);
// temp7 = mul_w41<ROOT>(temp7);
// transform2(temp5, temp7);
// temp5 = mul_w81<ROOT>(temp5);
// temp7 = mul_w83<ROOT>(temp7);
// input[0] = temp0 + temp4;
// input[rank] = temp1 + temp5;
// input[rank * 2] = temp2 + temp6;
// input[rank * 3] = temp3 + temp7;
// input[rank * 4] = temp0 - temp4;
// input[rank * 5] = temp1 - temp5;
// input[rank * 6] = temp2 - temp6;
// input[rank * 7] = temp3 - temp7;
// }
// template <uint64_t ROOT, uint32_t MOD>
// static void dif_butterfly2488_avx(MontInt32<MOD> input[],
// MontInt32X8<MOD> omega, MontInt32X8<MOD> omega3, MontInt32X8<MOD> omega7,
// size_t rank)
// {
// MontInt32X8<MOD> temp0 = input[0];
// MontInt32X8<MOD> temp1 = input[rank];
// MontInt32X8<MOD> temp2 = input[rank * 2];
// MontInt32X8<MOD> temp3 = input[rank * 3];
// MontInt32X8<MOD> temp4 = input[rank * 4];
// MontInt32X8<MOD> temp5 = input[rank * 5];
// MontInt32X8<MOD> temp6 = input[rank * 6];
// MontInt32X8<MOD> temp7 = input[rank * 7];
// transform2(temp0, temp4);
// transform2(temp1, temp5);
// transform2(temp2, temp6);
// transform2(temp3, temp7);
// temp5 = mul_w81<ROOT>(temp5);
// temp7 = mul_w83<ROOT>(temp7);
// transform2(temp5, temp7);
// temp6 = mul_w41<ROOT>(temp6);
// temp7 = mul_w41<ROOT>(temp7);
// transform2(temp4, temp6);
// transform2(temp6, temp7);
// input[0] = temp0;
// input[rank] = temp1;
// input[rank * 2] = temp2;
// input[rank * 3] = temp3;
// input[rank * 4] = temp4 * omega;
// input[rank * 5] = temp5 * omega;
// input[rank * 6] = temp6 * omega3;
// input[rank * 7] = temp7 * omega7;
// }
// Split-radix NTT of fixed length LEN over MontInt32<MOD>. The radix-4 ("244")
// butterflies run on 8 lanes at a time through AVX2.
// Recursion: a length LEN/2 transform on [0, LEN/2) plus two length LEN/4
// transforms on [LEN/2, 3LEN/4) and [3LEN/4, LEN). Lengths 0, 1, 2, 4, 8 and 16
// are handled by explicit specialisations.
// NOTE(review): the butterfly loops step by 8 over quarter_len, so this
// primary template assumes LEN is a power of two >= 32. Confirm before
// instantiating other lengths.
template <size_t LEN, uint32_t MOD, uint32_t ROOT>
struct NTTShort
{
    static constexpr size_t ntt_len = LEN;
    static constexpr size_t half_len = ntt_len / 2;
    static constexpr size_t quarter_len = ntt_len / 4;
    static constexpr size_t octant_len = ntt_len / 8;
    static constexpr int log_len = hint_log2(ntt_len);
    using ModIntX8 = MontInt32X8<MOD>;
    using ModIntType = typename ModIntX8::MontInt;
    using HalfNTT = NTTShort<half_len, MOD, ROOT>;
    using QuarterNTT = NTTShort<quarter_len, MOD, ROOT>;
    using TableType = AlignAry<ModIntType, quarter_len>;
    // Returns t with t[i] = r^i for i in [0, LEN/4), where
    // r = ROOT^(((MOD - 1) / LEN) * factor).
    // NOTE(review): the table is built in the local `res` and returned by
    // value. For large LEN (e.g. 2^23 -> 2^21 entries) this relies on NRVO to
    // avoid a multi-MiB stack temporary.
    static constexpr TableType getNTTTable(int factor)
    {
        ModIntType root = qpow(ModIntType(ROOT), ((ModIntType::mod() - 1) / LEN) * factor);
        ModIntType omega(1);
        TableType res;
        for (auto&& i : res)
        {
            i = omega;
            omega *= root;
        }
        return res;
    }
    // Twiddle tables (defined out of class): table1 is built with factor 1,
    // table3 with factor 3.
    static TableType table1;
    static TableType table3;
    static constexpr uint64_t mod()
    {
        return ModIntType::mod();
    }
    static constexpr uint64_t root()
    {
        return ROOT;
    }
    // In-place DIT: first transforms the three sub-blocks, then combines them
    // with 244 butterflies using the same twiddle tables as dif (no inverse
    // roots, no 1/LEN scaling here).
    // Presumably expects the bit-reversed layout that dif produces — confirm
    // against callers.
    static void dit(ModIntType in_out[])
    {
        QuarterNTT::dit(in_out + half_len + quarter_len);
        QuarterNTT::dit(in_out + half_len);
        HalfNTT::dit(in_out);
        ModIntX8 omega1, omega3;
        for (size_t i = 0; i < quarter_len; i += 8)
        {
            omega1.loadu(&table1[i]), omega3.loadu(&table3[i]);
            dit_butterfly244_avx2<ROOT>(in_out + i, omega1, omega3, quarter_len);
        }
    }
    // In-place DIF: 244 butterflies first, then the three sub-transforms. With
    // natural-order input, the layout of the recursion puts the even outputs
    // in the first half and the 4k+1 / 4k+3 outputs in the last two quarters,
    // which is bit-reversed order.
    static void dif(ModIntType in_out[])
    {
        ModIntX8 omega1, omega3;
        for (size_t i = 0; i < quarter_len; i += 8)
        {
            omega1.loadu(&table1[i]), omega3.loadu(&table3[i]);
            dif_butterfly244_avx2<ROOT>(in_out + i, omega1, omega3, quarter_len);
        }
        HalfNTT::dif(in_out);
        QuarterNTT::dif(in_out + half_len);
        QuarterNTT::dif(in_out + half_len + quarter_len);
    }
    // Runtime-length entry points. If len < LEN, delegate to the half-length
    // transform; otherwise transform exactly LEN elements (elements beyond LEN
    // are not touched). The net effect is a transform of the largest
    // supported power of two <= len.
    static void dit(ModIntType in_out[], size_t len)
    {
        if (len < LEN)
        {
            HalfNTT::dit(in_out, len);
            return;
        }
        dit(in_out);
    }
    static void dif(ModIntType in_out[], size_t len)
    {
        if (len < LEN)
        {
            HalfNTT::dif(in_out, len);
            return;
        }
        dif(in_out);
    }
};
// Out-of-class definitions of the per-length twiddle tables. The initializer
// of a static data member definition is looked up in class scope, so
// getNTTTable needs no qualification.
template <size_t LEN, uint32_t MOD, uint32_t ROOT>
typename NTTShort<LEN, MOD, ROOT>::TableType NTTShort<LEN, MOD, ROOT>::table1 = getNTTTable(1);
template <size_t LEN, uint32_t MOD, uint32_t ROOT>
typename NTTShort<LEN, MOD, ROOT>::TableType NTTShort<LEN, MOD, ROOT>::table3 = getNTTTable(3);
// Length-0 transform: every entry point is a no-op.
template <uint32_t MOD, uint32_t ROOT>
struct NTTShort<0, MOD, ROOT>
{
    using ModIntType = MontInt32<MOD>;
    static void dit(ModIntType[]) {}
    static void dif(ModIntType[]) {}
    static void dit(ModIntType[], size_t) {}
    static void dif(ModIntType[], size_t) {}
};
// Length-1 transform: the DFT of a single element is itself, so every entry
// point is a no-op.
template <uint32_t MOD, uint32_t ROOT>
struct NTTShort<1, MOD, ROOT>
{
    using ModIntType = MontInt32<MOD>;
    static void dit(ModIntType[]) {}
    static void dif(ModIntType[]) {}
    static void dit(ModIntType[], size_t) {}
    static void dif(ModIntType[], size_t) {}
};
// Length-2 transform: a single butterfly (x0, x1) -> (x0 + x1, x0 - x1), so
// DIT and DIF coincide. The runtime-length overloads do nothing when len < 2.
template <uint32_t MOD, uint32_t ROOT>
struct NTTShort<2, MOD, ROOT>
{
    using ModIntType = MontInt32<MOD>;
    static void dit(ModIntType in_out[])
    {
        transform2(in_out[0], in_out[1]);
    }
    static void dif(ModIntType in_out[])
    {
        dit(in_out);
    }
    static void dit(ModIntType in_out[], size_t len)
    {
        if (len >= 2)
        {
            dit(in_out);
        }
    }
    static void dif(ModIntType in_out[], size_t len)
    {
        if (len >= 2)
        {
            dif(in_out);
        }
    }
};
template <uint32_t MOD, uint32_t ROOT>
struct NTTShort<4, MOD, ROOT>
{
using ModIntType = MontInt32<MOD>;
static void dit(ModIntType in_out[])
{
auto temp0 = in_out[0];
auto temp1 = in_out[1];
auto temp2 = in_out[2];
auto temp3 = in_out[3];
transform2(temp0, temp1);
transform2(temp2, temp3);
temp3 = mul_w41<ROOT>(temp3);
in_out[0] = temp0 + temp2;
in_out[1] = temp1 + temp3;
in_out[2] = temp0 - temp2;
in_out[3] = temp1 - temp3;
}
static void dif(ModIntType in_out[])
{
auto temp0 = in_out[0];
auto temp1 = in_out[1];
auto temp2 = in_out[2];
auto temp3 = in_out[3];
transform2(temp0, temp2);
transform2(temp1, temp3);
temp3 = mul_w41<ROOT>(temp3);
in_out[0] = temp0 + temp1;
in_out[1] = temp0 - temp1;
in_out[2] = temp2 + temp3;
in_out[3] = temp2 - temp3;
}
static void dit(ModIntType in_out[], size_t len)
{
if (len < 4)
{
NTTShort<2, MOD, ROOT>::dit(in_out, len);
return;
}
dit(in_out);
}
static void dif(ModIntType in_out[], size_t len)
{
if (len < 4)
{
NTTShort<2, MOD, ROOT>::dif(in_out, len);
return;
}
dif(in_out);
}
};
// Length-8 transform, fully unrolled as three radix-2 stages.
// dit: bit-reversed input -> natural-order output.
// dif: the mirror image (natural-order input -> bit-reversed output).
// Twiddles are powers of w1 = ROOT^((MOD - 1) / 8), plus w_4 via mul_w41.
template <uint32_t MOD, uint32_t ROOT>
struct NTTShort<8, MOD, ROOT>
{
    using ModIntType = MontInt32<MOD>;
    static void dit(ModIntType in_out[])
    {
        // w1, w2 = w1^2, w3 = w1^3, evaluated at compile time.
        static constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
        static constexpr ModIntType w2 = qpow(w1, 2);
        static constexpr ModIntType w3 = qpow(w1, 3);
        auto temp0 = in_out[0];
        auto temp1 = in_out[1];
        auto temp2 = in_out[2];
        auto temp3 = in_out[3];
        auto temp4 = in_out[4];
        auto temp5 = in_out[5];
        auto temp6 = in_out[6];
        auto temp7 = in_out[7];
        // Stage 1: length-2 butterflies on adjacent pairs.
        transform2(temp0, temp1);
        transform2(temp2, temp3);
        transform2(temp4, temp5);
        transform2(temp6, temp7);
        // Stage 2: length-4 butterflies; temp3 and temp7 are first twiddled by w_4.
        temp3 = mul_w41<ROOT>(temp3);
        temp7 = mul_w41<ROOT>(temp7);
        transform2(temp0, temp2);
        transform2(temp1, temp3);
        transform2(temp4, temp6);
        transform2(temp5, temp7);
        // Stage 3: length-8 butterflies with twiddles 1, w1, w2, w3.
        temp5 = temp5 * w1;
        temp6 = temp6 * w2;
        temp7 = temp7 * w3;
        in_out[0] = temp0 + temp4;
        in_out[1] = temp1 + temp5;
        in_out[2] = temp2 + temp6;
        in_out[3] = temp3 + temp7;
        in_out[4] = temp0 - temp4;
        in_out[5] = temp1 - temp5;
        in_out[6] = temp2 - temp6;
        in_out[7] = temp3 - temp7;
    }
    static void dif(ModIntType in_out[])
    {
        // Same constants as dit.
        static constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
        static constexpr ModIntType w2 = qpow(w1, 2);
        static constexpr ModIntType w3 = qpow(w1, 3);
        auto temp0 = in_out[0];
        auto temp1 = in_out[1];
        auto temp2 = in_out[2];
        auto temp3 = in_out[3];
        auto temp4 = in_out[4];
        auto temp5 = in_out[5];
        auto temp6 = in_out[6];
        auto temp7 = in_out[7];
        // Stage 1: length-8 butterflies (distance 4), then twiddles 1, w1, w2, w3.
        transform2(temp0, temp4);
        transform2(temp1, temp5);
        transform2(temp2, temp6);
        transform2(temp3, temp7);
        temp5 = temp5 * w1;
        temp6 = temp6 * w2;
        temp7 = temp7 * w3;
        // Stage 2: length-4 butterflies (distance 2), then w_4 on temp3 / temp7.
        transform2(temp0, temp2);
        transform2(temp1, temp3);
        transform2(temp4, temp6);
        transform2(temp5, temp7);
        temp3 = mul_w41<ROOT>(temp3);
        temp7 = mul_w41<ROOT>(temp7);
        // Stage 3: length-2 butterflies, written straight to the output.
        in_out[0] = temp0 + temp1;
        in_out[1] = temp0 - temp1;
        in_out[2] = temp2 + temp3;
        in_out[3] = temp2 - temp3;
        in_out[4] = temp4 + temp5;
        in_out[5] = temp4 - temp5;
        in_out[6] = temp6 + temp7;
        in_out[7] = temp6 - temp7;
    }
    // Runtime-length entry points: fall back to the length-4 transform when len < 8.
    static void dit(ModIntType in_out[], size_t len)
    {
        if (len < 8)
        {
            NTTShort<4, MOD, ROOT>::dit(in_out, len);
            return;
        }
        dit(in_out);
    }
    static void dif(ModIntType in_out[], size_t len)
    {
        if (len < 8)
        {
            NTTShort<4, MOD, ROOT>::dif(in_out, len);
            return;
        }
        dif(in_out);
    }
};
// Length-16 split-radix transform: one length-8 transform on [0, 8) and two
// length-4 transforms on [8, 12) and [12, 16). They are combined with the same
// "244" butterfly as the generic NTTShort, unrolled for k = 0..3 with twiddle
// pairs (w^k, w^3k), w = w1 = ROOT^((MOD - 1) / 16).
template <uint32_t MOD, uint32_t ROOT>
struct NTTShort<16, MOD, ROOT>
{
    using ModIntType = MontInt32<MOD>;
    using NTT4 = NTTShort<4, MOD, ROOT>;
    using NTT8 = NTTShort<8, MOD, ROOT>;
    static void dit(ModIntType in_out[])
    {
        // Twiddle powers needed for k = 1..3: w^k and w^3k.
        constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 16);
        constexpr ModIntType w2 = qpow(w1, 2);
        constexpr ModIntType w3 = qpow(w1, 3);
        constexpr ModIntType w6 = qpow(w1, 6);
        constexpr ModIntType w9 = qpow(w1, 9);
        // Sub-transforms first (DIT), then the combining butterflies.
        NTT4::dit(in_out + 12);
        NTT4::dit(in_out + 8);
        NTT8::dit(in_out);
        ModIntType temp0, temp1, temp2, temp3;
        // k = 0 (twiddles 1, 1).
        temp2 = in_out[8];
        temp3 = in_out[12];
        transform2(temp2, temp3);
        temp3 = mul_w41<ROOT>(temp3);
        temp0 = in_out[0];
        temp1 = in_out[4];
        in_out[0] = temp0 + temp2;
        in_out[4] = temp1 + temp3;
        in_out[8] = temp0 - temp2;
        in_out[12] = temp1 - temp3;
        // k = 1 (twiddles w1, w3).
        temp2 = in_out[9] * w1;
        temp3 = in_out[13] * w3;
        transform2(temp2, temp3);
        temp3 = mul_w41<ROOT>(temp3);
        temp0 = in_out[1];
        temp1 = in_out[5];
        in_out[1] = temp0 + temp2;
        in_out[5] = temp1 + temp3;
        in_out[9] = temp0 - temp2;
        in_out[13] = temp1 - temp3;
        // k = 2 (twiddles w2, w6).
        temp2 = in_out[10] * w2;
        temp3 = in_out[14] * w6;
        transform2(temp2, temp3);
        temp3 = mul_w41<ROOT>(temp3);
        temp0 = in_out[2];
        temp1 = in_out[6];
        in_out[2] = temp0 + temp2;
        in_out[6] = temp1 + temp3;
        in_out[10] = temp0 - temp2;
        in_out[14] = temp1 - temp3;
        // k = 3 (twiddles w3, w9).
        temp2 = in_out[11] * w3;
        temp3 = in_out[15] * w9;
        transform2(temp2, temp3);
        temp3 = mul_w41<ROOT>(temp3);
        temp0 = in_out[3];
        temp1 = in_out[7];
        in_out[3] = temp0 + temp2;
        in_out[7] = temp1 + temp3;
        in_out[11] = temp0 - temp2;
        in_out[15] = temp1 - temp3;
    }
    static void dif(ModIntType in_out[])
    {
        // Same twiddle powers as dit.
        constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 16);
        constexpr ModIntType w2 = qpow(w1, 2);
        constexpr ModIntType w3 = qpow(w1, 3);
        constexpr ModIntType w6 = qpow(w1, 6);
        constexpr ModIntType w9 = qpow(w1, 9);
        ModIntType temp0, temp1, temp2, temp3;
        // k = 0: sums stay in place; the differences get w_4, a butterfly, and
        // twiddles 1, 1.
        temp0 = in_out[0];
        temp1 = in_out[4];
        temp2 = in_out[8];
        temp3 = in_out[12];
        in_out[0] = temp0 + temp2;
        in_out[4] = temp1 + temp3;
        temp2 = temp0 - temp2;
        temp3 = temp1 - temp3;
        temp3 = mul_w41<ROOT>(temp3);
        transform2(temp2, temp3);
        in_out[8] = temp2;
        in_out[12] = temp3;
        // k = 1 (twiddles w1, w3).
        temp0 = in_out[1];
        temp1 = in_out[5];
        temp2 = in_out[9];
        temp3 = in_out[13];
        in_out[1] = temp0 + temp2;
        in_out[5] = temp1 + temp3;
        temp2 = temp0 - temp2;
        temp3 = temp1 - temp3;
        temp3 = mul_w41<ROOT>(temp3);
        transform2(temp2, temp3);
        in_out[9] = temp2 * w1;
        in_out[13] = temp3 * w3;
        // k = 2 (twiddles w2, w6).
        temp0 = in_out[2];
        temp1 = in_out[6];
        temp2 = in_out[10];
        temp3 = in_out[14];
        in_out[2] = temp0 + temp2;
        in_out[6] = temp1 + temp3;
        temp2 = temp0 - temp2;
        temp3 = temp1 - temp3;
        temp3 = mul_w41<ROOT>(temp3);
        transform2(temp2, temp3);
        in_out[10] = temp2 * w2;
        in_out[14] = temp3 * w6;
        // k = 3 (twiddles w3, w9).
        temp0 = in_out[3];
        temp1 = in_out[7];
        temp2 = in_out[11];
        temp3 = in_out[15];
        in_out[3] = temp0 + temp2;
        in_out[7] = temp1 + temp3;
        temp2 = temp0 - temp2;
        temp3 = temp1 - temp3;
        temp3 = mul_w41<ROOT>(temp3);
        transform2(temp2, temp3);
        in_out[11] = temp2 * w3;
        in_out[15] = temp3 * w9;
        // Sub-transforms last (DIF).
        NTT8::dif(in_out);
        NTT4::dif(in_out + 8);
        NTT4::dif(in_out + 12);
    }
    // Runtime-length entry points: fall back to the length-8 transform when len < 16.
    static void dit(ModIntType in_out[], size_t len)
    {
        if (len < 16)
        {
            NTTShort<8, MOD, ROOT>::dit(in_out, len);
            return;
        }
        dit(in_out);
    }
    static void dif(ModIntType in_out[], size_t len)
    {
        if (len < 16)
        {
            NTTShort<8, MOD, ROOT>::dif(in_out, len);
            return;
        }
        dif(in_out);
    }
};
};
}
}
}
using namespace std;
using namespace hint;
using namespace hint_transform;
// using namespace hint_ntt::split_radix_avx;
// template <typename T>
// vector<T> poly_multiply(const vector<T> &in1, const vector<T> &in2)
// {
// size_t len1 = in1.size(), len2 = in2.size(), out_len = len1 + len2;
// vector<T> result(out_len);
// size_t ntt_len = int_ceil2(out_len);
// using ntt = NTTShort<1 << 23, 998244353, 3>;
// using intt = NTTShort<1 << 23, 998244353, 3>;
// auto mod_ary1 = new ntt::NTTModInt32[ntt_len]();
// auto mod_ary2 = new ntt::NTTModInt32[ntt_len]();
// for (size_t i = 0; i < len1; i++)
// {
// mod_ary1[i] = in1[i];
// }
// for (size_t i = 0; i < len2; i++)
// {
// mod_ary2[i] = in2[i];
// }
// ntt::ntt_dif(mod_ary1, ntt_len);
// ntt::ntt_dif(mod_ary2, ntt_len);
// ary_mul(mod_ary1, mod_ary2, mod_ary1, ntt_len);
// intt::ntt_dit(mod_ary1, ntt_len);
// intt::ntt_basic::ntt_normalize(mod_ary1, ntt_len);
// for (size_t i = 0; i < out_len; i++)
// {
// result[i] = mod_ary1[i].data;
// }
// delete[] mod_ary1;
// delete[] mod_ary2;
// return result;
// }
// template <uint32_t MOD, uint32_t G_ROOT>
// void poly_inv(uint32_t *in, uint32_t *out, size_t len)
// {
// using ntt = NTT<MOD, G_ROOT, 1 << 24>;
// using intt = typename ntt::intt;
// using NttInt = typename ntt::NTTModInt32;
// std::vector<NttInt> ntt_ary(len * 2);
// auto in_ntt = ntt_ary.data();
// auto out_ntt = reinterpret_cast<NttInt *>(out);
// out[0] = mod_inv(in[0], MOD);
// for (size_t rank = 2; rank <= len; rank *= 2)
// {
// size_t gap = rank * 2;
// std::copy(in, in + rank, in_ntt);
// std::fill(in_ntt + rank, in_ntt + gap, 0);
// std::fill(out + rank / 2, out + gap, 0);
// // std::cout << gap << "\n";
// auto t1 = std::chrono::high_resolution_clock::now();
// ntt::ntt_dif(in_ntt, gap);
// ntt::ntt_dif(out_ntt, gap);
// for (size_t i = 0; i < gap; i++)
// {
// uint32_t a = in_ntt[i].data, b = out[i];
// out[i] = MOD - ((b * b % MOD) * a - b * 2 + MOD) % MOD;
// }
// intt::ntt_dit(out_ntt, gap);
// auto t2 = std::chrono::high_resolution_clock::now();
// // std::cout << "ntt time: " << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count() << "us\n";
// uint32_t inv = mod_inv(gap, MOD);
// for (size_t i = 0; i < rank; i++)
// {
// out[i] = out[i] * inv % MOD;
// }
// }
// }
// template <typename T>
// void result_test(const vector<T> &res, uint32_t ele)
// {
// size_t len = res.size();
// for (size_t i = 0; i < len / 2; i++)
// {
// uint64_t x = (i + 1) * ele * ele;
// uint64_t y = res[i];
// if (x != y)
// {
// cout << "fail:" << i << "\t" << (i + 1) * ele * ele << "\t" << y << "\n";
// return;
// }
// }
// for (size_t i = len / 2; i < len; i++)
// {
// uint64_t x = (len - i - 1) * ele * ele;
// uint64_t y = res[i];
// if (x != y)
// {
// cout << "fail:" << i << "\t" << x << "\t" << y << "\n";
// return;
// }
// }
// std::cout << "success\n";
// }
// int main()
// {
// StopWatch w(1000);
// int n = 18;
// cin >> n;
// size_t len = 1 << n; // 变换长度
// cout << "fft len:" << len << "\n";
// uint64_t ele = 5;
// vector<uint32_t> in1(len / 2, ele);
// vector<uint32_t> in2(len / 2, ele); // 计算两个长度为len/2,每个元素为ele的卷积
// w.start();
// vector<uint32_t> res = poly_multiply(in1, in2);
// // poly_inv<998244353, 3>(in1.data(), in2.data(), 1 << 16);
// w.stop();
// result_test(res, ele); // 结果校验
// cout << w.duration() << "ms\n";
// }
// Scalar reference radix-2 decimation-in-time NTT, in place.
// Stages run from butterfly span 2 up to ntt_len. In each stage the upper
// element of every butterfly is multiplied by w^k before the add/sub, where
// w = ROOT^((MOD - 1) / span).
// No bit-reversal permutation is done here; the caller provides the input order.
template <uint64_t ROOT, typename ModInt>
void ntt_dit(ModInt in_out[], size_t ntt_len)
{
    for (size_t rank = 2; rank <= ntt_len; rank <<= 1)
    {
        const size_t half = rank / 2;
        const ModInt step = hint::qpow(ModInt(ROOT), (ModInt::mod() - 1) / rank);
        for (size_t base = 0; base < ntt_len; base += rank)
        {
            ModInt* lo = in_out + base;
            ModInt* hi = lo + half;
            ModInt w = 1;
            for (size_t k = 0; k < half; ++k)
            {
                const ModInt u = lo[k];
                const ModInt v = hi[k] * w;
                lo[k] = u + v;
                hi[k] = u - v;
                w = w * step;
            }
        }
    }
}
// Scalar reference radix-2 decimation-in-frequency NTT, in place.
// Stages run from butterfly span ntt_len down to 2. In each stage the
// difference of every butterfly is multiplied by w^k, where
// w = ROOT^((MOD - 1) / span).
// No bit-reversal permutation is done here.
template <uint64_t ROOT, typename ModInt>
void ntt_dif(ModInt in_out[], size_t ntt_len)
{
    for (size_t rank = ntt_len; rank >= 2; rank >>= 1)
    {
        const size_t half = rank / 2;
        const ModInt step = hint::qpow(ModInt(ROOT), (ModInt::mod() - 1) / rank);
        for (size_t base = 0; base < ntt_len; base += rank)
        {
            ModInt* lo = in_out + base;
            ModInt* hi = lo + half;
            ModInt w = 1;
            for (size_t k = 0; k < half; ++k)
            {
                const ModInt u = lo[k];
                const ModInt v = hi[k];
                lo[k] = u + v;
                hi[k] = (u - v) * w;
                w = w * step;
            }
        }
    }
}
void avx2_test()
{
using namespace hint;
using namespace hint_transform;
using namespace hint_ntt;
constexpr size_t len = 1 << 5;
constexpr uint32_t mod = 998244353;
using NTTX8 = MontInt32X8<mod>;
using ModInt = typename NTTX8::MontInt;
alignas(64) static ModInt a[len];
for (size_t i = 0; i < len; i++)
{
a[i] = i;
// b[i] = i;
}
size_t times = 1; // std::max<size_t>(1, (1 << 25) / len);
auto t1 = std::chrono::steady_clock::now();
for (size_t i = 0; i < times; i++)
{
// ntt::dif(a);
NTTX8 x;
x.loadu(a);
x = x * x;
x.storeu(a);
// ntt::dif(b.data(), len);
// ntt::dit(b.data(), len);
}
auto t2 = std::chrono::steady_clock::now();
auto time1 = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count();
for (size_t i = 0; i < std::min<size_t>(len, 1024); i++)
{
std::cout << i << ":\t" << uint32_t(a[i]) << "\n";
}
std::cout << time1 << "\n";
}
// Cross-checks the AVX2 split-radix NTTShort against the scalar radix-2
// reference. Both arrays start as 0..len-1 and run the same sequence
// dif, dif, dit. On the first differing element, prints its index with the
// raw Montgomery values and the converted values, then returns. Otherwise
// prints both timings and their ratio.
void ntt_check()
{
    using namespace hint;
    using namespace hint_transform;
    using namespace hint_ntt;
    using namespace split_radix_avx;
    constexpr size_t len = 1 << 23;
    constexpr uint32_t mod = 469762049, root = 3;
    using ntt = NTTShort<len, mod, root>;
    using ModInt = ntt::ModIntType;
    // Large arrays: kept static so they are not placed on the stack.
    static AlignAry<ModInt, len> a;
    static AlignAry<ModInt, len> b;
    for (size_t i = 0; i < len; i++)
    {
        a[i] = i;
        b[i] = i;
    }
    size_t times = 1; // std::max<size_t>(1, (1 << 25) / len);
    auto t1 = std::chrono::steady_clock::now();
    for (size_t i = 0; i < times; i++)
    {
        ntt::dif(a.data());
        ntt::dif(a.data());
        ntt::dit(a.data());
    }
    auto t2 = std::chrono::steady_clock::now();
    for (size_t i = 0; i < times; i++)
    {
        ntt_dif<root>(b.data(), len);
        ntt_dif<root>(b.data(), len);
        ntt_dit<root>(b.data(), len);
    }
    auto t3 = std::chrono::steady_clock::now();
    auto time1 = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count();
    auto time2 = std::chrono::duration_cast<std::chrono::duration<double>>(t3 - t2).count();
    // Compare every element. Previously only the first 1024 were checked, so a
    // mismatch later in the array went unreported.
    for (size_t i = 0; i < len; i++)
    {
        if (uint32_t(a[i]) != uint32_t(b[i]))
        {
            std::cout << i << ":\t" << uint32_t(a[i].data) << "\t" << uint32_t(b[i].data) << "\n";
            std::cout << i << ":\t" << uint32_t(a[i]) << "\t" << uint32_t(b[i]) << "\n";
            return;
        }
    }
    std::cout << time1 << "\t" << time2 << "\t" << time2 / time1 << "\n";
}
// Entry point: runs the NTT cross-check. avx2_test() is an alternative
// micro-check of MontInt32X8 multiplication.
int main()
{
    // avx2_test();
    ntt_check();
    return 0;
}
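// Online-judge report pasted with the source (kept as a comment so the file compiles):
// Compilation | N/A | N/A | Compile OK | Score: N/A | 显示更多 (show more) |
// Testcase #1 | 50.671 ms | 96 MB + 92 KB | Runtime Error | Score: 0 | 显示更多 (show more) |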