#include <algorithm>
#include <array>
#include <climits>
#include <complex>
#include <cstdint>
#include <ctime>
#include <future>
#include <iostream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
// Windows 64-bit: enable the _umul128 intrinsic for fast 64x64 -> 128-bit multiplication
#if defined(_WIN64)
#include <intrin.h>
#define UMUL128
#endif //_WIN64
// GCC (and compatible compilers): enable the native __uint128_t type for fast 64x64 -> 128-bit multiplication
#if defined(__SIZEOF_INT128__)
#define UINT128T
#endif //__SIZEOF_INT128__
namespace hint
{
// Floating-point and complex aliases used throughout the library.
using Float32 = float;
using Float64 = double;
using Complex32 = std::complex<Float32>;
using Complex64 = std::complex<Float64>;
// pi and 2 * pi in double precision.
constexpr Float64 HINT_PI = 3.141592653589793238462643;
constexpr Float64 HINT_2PI = 2 * HINT_PI;
// All-ones mask of the given width: returns 2^bits - 1.
// bits <= 0 yields 0; bits must not exceed the bit width of T.
// The value is built as 2 * 2^(bits-1) - 1 so that bits == width(T) never needs the
// undefined full-width shift T(1) << width.
template <typename T>
constexpr T all_one(int bits)
{
    if (bits <= 0)
    {
        return T(0); // T(1) << (bits - 1) would shift by a negative amount (undefined behavior)
    }
    T temp = T(1) << (bits - 1);
    return temp - 1 + temp;
}
// Leading zeros: number of zero bits above the highest set bit of x.
// Returns the full bit width of IntTy when x == 0.
// Intended for unsigned integer types whose width is a power of two (8/16/32/64/128 bits).
// Wider types such as `unsigned long long` (distinct from uint64_t on some platforms) or
// __uint128_t are handled correctly, not only types of at most 32 bits.
template <typename IntTy>
constexpr int hint_clz(IntTy x)
{
    constexpr int bits = sizeof(IntTy) * CHAR_BIT;
    int res = bits;
    // Binary search for the highest set bit, halving the probe width every step.
    for (int shift = bits / 2; shift > 0; shift /= 2)
    {
        if ((x >> shift) != 0)
        {
            res -= shift;
            x >>= shift;
        }
    }
    return res - int(x); // x is now 0 (input was 0) or 1
}
// Leading zeros of a 64-bit value, composed from two 32-bit counts.
// Returns 64 when x == 0.
constexpr int hint_clz(uint64_t x)
{
    const uint32_t upper = uint32_t(x >> 32);
    if (upper == 0)
    {
        // The whole upper half is leading zeros; continue counting in the lower half.
        return 32 + hint_clz(uint32_t(x));
    }
    return hint_clz(upper);
}
// Integer bit length: number of bits needed to represent x (0 for x == 0).
template <typename IntTy>
constexpr int hint_bit_length(IntTy x)
{
    return x == 0 ? 0 : int(sizeof(IntTy) * CHAR_BIT) - hint_clz(x);
}
// Integer log2: index of the highest set bit of x, i.e. floor(log2(x)).
// For x == 0 this gives -1, because hint_clz(0) equals the full bit width.
template <typename IntTy>
constexpr int hint_log2(IntTy x)
{
    constexpr int top_bit = int(sizeof(IntTy) * CHAR_BIT) - 1;
    return top_bit - hint_clz(x);
}
// Trailing zeros of a 32-bit value: index of the lowest set bit.
// Returns 32 for x == 0 (the std::countr_zero convention), so a zero input is no longer
// indistinguishable from x == 0x80000000 (which yields 31).
constexpr int hint_ctz(uint32_t x)
{
    if (x == 0)
    {
        return 32;
    }
    int r = 31;
    x &= (0u - x); // keep only the lowest set bit; 0u - x avoids unary minus on an unsigned value
    // Binary search for the position of that single set bit.
    if (x & 0x0000FFFF)
    {
        r -= 16;
    }
    if (x & 0x00FF00FF)
    {
        r -= 8;
    }
    if (x & 0x0F0F0F0F)
    {
        r -= 4;
    }
    if (x & 0x33333333)
    {
        r -= 2;
    }
    if (x & 0x55555555)
    {
        r -= 1;
    }
    return r;
}
// Trailing zeros of a 64-bit value, composed from two 32-bit counts.
constexpr int hint_ctz(uint64_t x)
{
    const uint32_t lower = uint32_t(x);
    if (lower != 0)
    {
        return hint_ctz(lower);
    }
    // Lower half is all zeros: the count continues into the upper half.
    return 32 + hint_ctz(uint32_t(x >> 32));
}
// Fast power: m^n by binary exponentiation (square-and-multiply).
// T must be constructible from 1 and support *=; T1 must support > 0, & 1 and >>=.
template <typename T, typename T1>
constexpr T qpow(T m, T1 n)
{
    T acc = 1;
    for (; n > 0; n >>= 1)
    {
        if ((n & 1) != 0)
        {
            acc *= m; // fold in the current power of m for every set exponent bit
        }
        m *= m;
    }
    return acc;
}
// Fast power with mod: m^n % mod by binary exponentiation.
// Both the accumulator and the base are reduced before the loop. The result is therefore
// always in [0, mod) (e.g. qpow(x, 0, 1) == 0), and an unreduced base cannot overflow when
// squared as long as (mod - 1)^2 fits in T.
template <typename T, typename T1>
constexpr T qpow(T m, T1 n, T mod)
{
    T result = 1;
    result %= mod;
    m %= mod;
    while (n > 0)
    {
        if ((n & 1) != 0)
        {
            result *= m;
            result %= mod;
        }
        m *= m;
        m %= mod;
        n >>= 1;
    }
    return result;
}
// Get the closest power of 2 that is not larger than n.
// Returns 0 for n == 0, since no power of 2 is <= 0 (same convention as std::bit_floor).
// Intended for unsigned T.
template <typename T>
constexpr T int_floor2(T n)
{
    if (n == 0)
    {
        return 0; // the bit smear below would otherwise produce 1, which is larger than n
    }
    constexpr int bits = sizeof(n) * CHAR_BIT;
    // Smear the highest set bit into every lower position: n becomes 2^(k+1) - 1.
    for (int i = 1; i < bits; i *= 2)
    {
        n |= (n >> i);
    }
    return (n >> 1) + 1;
}
// Get the closest power of 2 that is not smaller than n (intended for unsigned T).
// Note: n == 0 wraps around to 0, and values above the largest power of 2 representable
// in T overflow to 0.
template <typename T>
constexpr T int_ceil2(T n)
{
    constexpr int width = sizeof(n) * CHAR_BIT;
    T smeared = n - 1;
    for (int step = 1; step < width; step <<= 1)
    {
        smeared |= smeared >> step; // propagate the top set bit into all lower positions
    }
    return smeared + 1;
}
// x + y; cf receives the carry-out (true iff the sum wrapped around).
template <typename UintTy>
constexpr UintTy add_half(UintTy x, UintTy y, bool &cf)
{
    const UintTy sum = x + y;
    cf = sum < y; // a wrapped unsigned sum is smaller than either addend
    return sum;
}
// x - y; bf receives the borrow-out (true iff the difference wrapped around).
template <typename UintTy>
constexpr UintTy sub_half(UintTy x, UintTy y, bool &bf)
{
    const UintTy diff = x - y;
    bf = x < diff; // a wrapped difference is larger than the minuend
    return diff;
}
// x + y + cf, where cf is the carry-in on entry and the carry-out on return.
template <typename UintTy>
constexpr UintTy add_carry(UintTy x, UintTy y, bool &cf)
{
    const UintTy with_carry = x + cf;
    const bool carry_in_wrapped = with_carry < x;
    const UintTy sum = with_carry + y;
    // At most one of the two additions can wrap (if the first wraps, with_carry is 0).
    cf = carry_in_wrapped || sum < y;
    return sum;
}
// x - y - bf, where bf is the borrow-in on entry and the borrow-out on return.
template <typename UintTy>
constexpr UintTy sub_borrow(UintTy x, UintTy y, bool &bf)
{
    const UintTy minus_borrow = x - bf;
    const bool borrow_in_wrapped = minus_borrow > x;
    const UintTy diff = minus_borrow - y;
    bf = borrow_in_wrapped || diff > minus_borrow;
    return diff;
}
// Extended Euclid: returns g = gcd(a, b) and sets x, y so that a * x + b * y = g.
template <typename IntTy>
constexpr IntTy exgcd(IntTy a, IntTy b, IntTy &x, IntTy &y)
{
    if (b != 0)
    {
        const IntTy quot = a / b;
        // Recurse on (b, a - quot * b) with x and y swapped, then back-substitute.
        const IntTy g = exgcd(b, a - quot * b, y, x);
        y -= quot * x;
        return g;
    }
    x = 1;
    y = 0;
    return a;
}
// Modular inverse: returns n^-1 mod `mod`, normalized into [0, mod).
// Assumes gcd(n, mod) == 1; otherwise the result is not an inverse (not checked here).
template <typename IntTy>
constexpr IntTy mod_inv(IntTy n, IntTy mod)
{
    IntTy inv = 0, unused = 0;
    exgcd(IntTy(n % mod), mod, inv, unused);
    // The Bezout coefficient may lie outside [0, mod); apply one correction step.
    if (inv < 0)
    {
        return inv + mod;
    }
    if (inv >= mod)
    {
        return inv - mod;
    }
    return inv;
}
// Return n^-1 mod 2^pow using the Newton iteration x <- x * (2 - n * x), which doubles
// the number of correct low bits every step.
// Returns 0 when no inverse exists (n even) and for pow <= 0 (modulus 1).
// pow values above 64 are treated as 64.
constexpr uint64_t inv_mod2pow(uint64_t n, int pow)
{
    if (pow <= 0)
    {
        return 0;
    }
    if ((n & 1) == 0)
    {
        return 0; // even n is not invertible mod 2^pow; the iteration would never reach 1
    }
    // Mask of the low `pow` bits, without the undefined full-width shift for pow >= 64.
    const uint64_t mask = pow >= 64 ? ~uint64_t(0) : (uint64_t(1) << pow) - 1;
    uint64_t xn = 1, t = n & mask;
    while (t != 1)
    {
        xn = (xn * (2 - t));
        t = (xn * n) & mask;
    }
    return xn & mask;
}
// Portable 64bit x 64bit -> 128bit multiplication by 32-bit schoolbook partial products.
// Returns (low 64 bits, high 64 bits).
constexpr std::pair<uint64_t, uint64_t> mul64x64to128_base(uint64_t a, uint64_t b)
{
    const uint64_t a_lo = uint32_t(a), a_hi = a >> 32;
    const uint64_t b_lo = uint32_t(b), b_hi = b >> 32;
    const uint64_t lo_lo = a_lo * b_lo;
    const uint64_t lo_hi = a_lo * b_hi;
    const uint64_t hi_lo = a_hi * b_lo;
    const uint64_t hi_hi = a_hi * b_hi;
    // Bits 32..95: low halves of both cross products plus the upper half of lo_lo (< 3 * 2^32).
    const uint64_t middle = uint64_t(uint32_t(lo_hi)) + uint32_t(hi_lo) + (lo_lo >> 32);
    const uint64_t high = hi_hi + (lo_hi >> 32) + (hi_lo >> 32) + (middle >> 32);
    const uint64_t low = (middle << 32) | uint32_t(lo_lo);
    return std::make_pair(low, high);
}
// 64bit x 64bit -> 128bit multiplication using the fastest primitive available:
// MSVC's _umul128 (UMUL128), the compiler's __uint128_t (UINT128T), or the portable
// mul64x64to128_base fallback. Returns (low 64 bits, high 64 bits).
inline std::pair<uint64_t, uint64_t> mul64x64to128(uint64_t a, uint64_t b)
{
#if defined(UMUL128)
#pragma message("Using _umul128 to compute 64bit x 64bit to 128bit")
    unsigned long long high = 0;
    const unsigned long long low = _umul128(a, b, &high);
    return std::make_pair(uint64_t(low), uint64_t(high));
#elif defined(UINT128T)
#pragma message("Using __uint128_t to compute 64bit x 64bit to 128bit")
    const __uint128_t product = __uint128_t(a) * b;
    return std::make_pair(uint64_t(product), uint64_t(product >> 64));
#else
#pragma message("Using basic function to compute 64bit x 64bit to 128bit")
    return mul64x64to128_base(a, b);
#endif
}
// Divide the 128-bit value dividend_hi64:dividend_lo64 by a 32-bit divisor, in place.
// On return dividend_hi64:dividend_lo64 holds the 128-bit quotient and the remainder
// (< divisor) is returned. divisor must be non-zero.
// The parameters are references because the quotient is written back (Uint128::selfDivRem
// relies on this).
constexpr uint32_t div128by32(uint64_t &dividend_hi64, uint64_t &dividend_lo64, uint32_t divisor)
{
    // Schoolbook long division over four 32-bit digits, most significant first. The running
    // remainder is always < divisor, so (remainder << 32 | digit) / divisor fits in 32 bits.
    uint32_t quot_hi32 = 0, quot_lo32 = 0;
    uint64_t dividend = dividend_hi64 >> 32;
    quot_hi32 = dividend / divisor;
    dividend %= divisor;
    dividend = (dividend << 32) | uint32_t(dividend_hi64);
    quot_lo32 = dividend / divisor;
    dividend %= divisor;
    dividend_hi64 = (uint64_t(quot_hi32) << 32) | quot_lo32;
    dividend = (dividend << 32) | uint32_t(dividend_lo64 >> 32);
    quot_hi32 = dividend / divisor;
    dividend %= divisor;
    dividend = (dividend << 32) | uint32_t(dividend_lo64);
    quot_lo32 = dividend / divisor;
    dividend %= divisor;
    dividend_lo64 = (uint64_t(quot_hi32) << 32) | quot_lo32;
    return uint32_t(dividend);
}
// 96-bit integer divided by 64-bit integer.
// Computes (dividend_hi32 : dividend_lo64) / divisor, returns the quotient and writes the
// remainder back into dividend_lo64.
// Preconditions (not checked): the quotient fits in 32 bits (dividend < divisor * 2^32),
// and whenever dividend_hi32 != 0 the divisor is normalized (bit 63 set). Normalization keeps
// the estimated quotient digit at most 2 above the true one (Knuth, TAOCP vol. 2, 4.3.1 D).
constexpr uint32_t div96by64to32(uint32_t dividend_hi32, uint64_t &dividend_lo64, uint64_t divisor)
{
    if (0 == dividend_hi32)
    {
        uint32_t quotient = dividend_lo64 / divisor;
        dividend_lo64 %= divisor;
        return quotient;
    }
    // Estimate the quotient digit from the top 64 dividend bits and the top 32 divisor bits.
    uint64_t divid2 = (uint64_t(dividend_hi32) << 32) | (dividend_lo64 >> 32);
    uint64_t divis1 = divisor >> 32;
    divisor = uint32_t(divisor);
    uint64_t qhat = divid2 / divis1;
    divid2 %= divis1;
    divid2 = (divid2 << 32) | uint32_t(dividend_lo64);
    // The true remainder is divid2 - qhat * (low 32 divisor bits); decrement qhat while it is negative.
    uint64_t prod = qhat * divisor;
    divis1 <<= 32;
    if (prod > divid2)
    {
        qhat--;
        prod -= divisor;
        divid2 += divis1;
        // divid2 wrapped around (its true value is >= 2^64 > prod) exactly when it is now
        // < divis1. Equality means no wrap, so the second correction must still be
        // considered. Testing with '>' instead misses it, e.g. for
        // 0x40000001:0000000000000000 / 0x80000000FFFFFFFF.
        if ((divid2 >= divis1) && (prod > divid2))
        {
            qhat--;
            prod -= divisor;
            divid2 += divis1;
        }
    }
    divid2 -= prod;
    dividend_lo64 = divid2;
    return uint32_t(qhat);
}
// 128-bit integer divided by 64-bit integer.
// Returns (dividend_hi64 : dividend_lo64) / divisor and writes the remainder back into
// dividend_lo64 (hence the reference parameter).
// Precondition (not checked): the quotient fits in 64 bits, i.e. dividend_hi64 < divisor,
// which also rules out divisor == 0.
constexpr uint64_t div128by64to64(uint64_t dividend_hi64, uint64_t &dividend_lo64, uint64_t divisor)
{
    int k = 0;
    if (divisor < (uint64_t(1) << 63))
    {
        // Normalize so bit 63 of the divisor is set, as div96by64to32 requires. Here
        // 1 <= k <= 63, so both shifts below are well defined.
        k = hint::hint_clz(divisor);
        divisor <<= k;
        dividend_hi64 = (dividend_hi64 << k) | (dividend_lo64 >> (64 - k));
        dividend_lo64 <<= k;
    }
    // High quotient digit: divide the top 96 bits of the normalized dividend.
    uint32_t divid_hi32 = dividend_hi64 >> 32;
    uint64_t divid_lo64 = (dividend_hi64 << 32) | (dividend_lo64 >> 32);
    uint64_t quotient = hint::div96by64to32(divid_hi32, divid_lo64, divisor);
    // Low quotient digit: the 64-bit remainder from above followed by the low 32 dividend bits.
    divid_hi32 = divid_lo64 >> 32;
    dividend_lo64 = uint32_t(dividend_lo64) | (divid_lo64 << 32);
    quotient = (quotient << 32) | hint::div96by64to32(divid_hi32, dividend_lo64, divisor);
    dividend_lo64 >>= k; // undo the normalization shift on the remainder
    return quotient;
}
// uint64_t to std::string: the lowest `digits` decimal digits of input, zero-padded on the left.
// Higher digits are silently dropped.
inline std::string ui64to_string_base10(uint64_t input, uint8_t digits)
{
    std::string result(digits, '0');
    for (auto it = result.rbegin(); it != result.rend(); ++it)
    {
        *it = static_cast<char>('0' + input % 10);
        input /= 10;
    }
    return result;
}
namespace transform
{
// In-place sum/difference butterfly: (sum, diff) <- (sum + diff, sum - diff).
template <typename T>
inline void transform2(T &sum, T &diff)
{
    const T first = sum;
    const T second = diff;
    sum = first + second;
    diff = first - second;
}
// Bit-reversal permutation: reorders [begin, end) so that the element at index i ends up at
// index rev(i), where rev reverses the low log2(len) bits of i.
// It must be a random-access iterator (the code uses -, +, < and [] on it).
// NOTE(review): assumes len is a power of two. get_next_bitrev's increment relies on it and
// nothing checks it. For len == 0 the loop bound `end - 1` forms an iterator before begin.
template <typename It>
void binary_reverse_swap(It begin, It end)
{
const size_t len = end - begin;
// Swap only when the left index is smaller than the right one, so each pair is swapped exactly once.
auto smaller_swap = [=](It it_left, It it_right)
{
if (it_left < it_right)
{
std::swap(it_left[0], it_right[0]);
}
};
// If `last` points at rev(i), return an iterator to rev(i + 1): a bit-reversed increment that
// flips bits from the top (len / 2) downward until one flips from 0 to 1.
auto get_next_bitrev = [=](It last)
{
size_t k = len / 2, indx = last - begin;
indx ^= k;
while (k > indx)
{
k >>= 1;
indx ^= k;
};
return begin + indx;
};
// Short arrays: walk i = 1 .. len - 2 while tracking j = rev(i) incrementally
// (indices 0 and len - 1 map to themselves).
if (len <= 16)
{
for (auto i = begin + 1, j = begin + len / 2; i < end - 1; i++)
{
smaller_swap(i, j);
j = get_next_bitrev(j);
}
return;
}
// Long arrays: handle eight indices per step. For i0 < len / 8, the indices
// i0 + {0, len/2, len/4, 3*len/4} and the same four plus len/8 have bit-reversals
// j + 0 .. j + 7, where j = rev(i0).
// i0 == 0 is skipped. Its partners 0..7 appear to be reached from their own (smaller-index)
// side in later steps, or are fixed points.
const size_t len_8 = len / 8;
const auto last = begin + len_8;
auto i0 = begin + 1, i1 = i0 + len / 2, i2 = i0 + len / 4, i3 = i1 + len / 4;
for (auto j = begin + len / 2; i0 < last; i0++, i1++, i2++, i3++)
{
smaller_swap(i0, j);
smaller_swap(i1, j + 1);
smaller_swap(i2, j + 2);
smaller_swap(i3, j + 3);
smaller_swap(i0 + len_8, j + 4);
smaller_swap(i1 + len_8, j + 5);
smaller_swap(i2 + len_8, j + 6);
smaller_swap(i3 + len_8, j + 7);
j = get_next_bitrev(j);
}
}
// Bit-reversal permutation of the len elements starting at ary.
// Convenience overload that forwards to the iterator-range version.
template <typename T>
void binary_reverse_swap(T ary, const size_t len)
{
    const T last = ary + len;
    binary_reverse_swap(ary, last);
}
// Multi-modulus, automatically typed, self-checking fast number-theoretic transform (NTT)
namespace ntt
{
// NTT moduli of the form c * 2^k + 1 (presumably prime, with ROOTn a primitive root of
// MODn — TODO confirm). The 2^k factor bounds the largest supported transform length.
constexpr uint64_t MOD0 = 69 * (uint64_t(1) << 55) + 1, ROOT0 = 5; // 2485986994308513793
constexpr uint64_t MOD1 = 27 * (uint64_t(1) << 56) + 1, ROOT1 = 5; // 1945555039024054273
constexpr uint64_t MOD2 = 29 * (uint64_t(1) << 57) + 1, ROOT2 = 3; // 4179340454199820289
constexpr uint64_t MOD3 = 45 * (uint64_t(1) << 24) + 1, ROOT3 = 11; // 754974721
constexpr uint64_t MOD4 = 7 * (uint64_t(1) << 26) + 1, ROOT4 = 3; // 469762049
constexpr uint64_t MOD5 = 13 * (uint64_t(1) << 28) + 1, ROOT5 = 3; // 3489660929
constexpr uint64_t MOD6 = 3 * (uint64_t(1) << 30) + 1, ROOT6 = 5; // 3221225473
// Portable unsigned 128-bit integer stored as two 64-bit halves (low, high).
// Several operators intentionally handle only a 64-bit operand (see per-method notes):
// operator* / operator*= multiply only the low halves, and the Uint128 overloads of / and %
// divide by rhs's low 64 bits.
class Uint128
{
private:
    uint64_t low, high;

public:
    constexpr Uint128(uint64_t l = 0, uint64_t h = 0) : low(l), high(h) {}
    // From a (low, high) pair as returned by mul64x64to128 / mul64x64to128_base.
    constexpr Uint128(std::pair<uint64_t, uint64_t> p) : low(p.first), high(p.second) {}
    constexpr Uint128 operator+(Uint128 rhs) const
    {
        rhs.low += low;
        rhs.high += high + (rhs.low < low); // carry out of the low half
        return rhs;
    }
    constexpr Uint128 operator-(Uint128 rhs) const
    {
        rhs.low = low - rhs.low;
        rhs.high = high - rhs.high - (rhs.low > low); // borrow from the low half
        return rhs;
    }
    constexpr Uint128 operator+(uint64_t rhs) const
    {
        rhs = low + rhs;
        return Uint128(rhs, high + (rhs < low));
    }
    constexpr Uint128 operator-(uint64_t rhs) const
    {
        rhs = low - rhs;
        return Uint128(rhs, high - (rhs > low));
    }
    // Only compute the low * rhs.low
    Uint128 operator*(const Uint128 &rhs) const
    {
        return mul64x64to128(low, rhs.low);
    }
    // Only compute the low * rhs
    Uint128 operator*(uint64_t rhs) const
    {
        return mul64x64to128(low, rhs);
    }
    // Only compute the 128bit / 64 bit
    constexpr Uint128 operator/(const Uint128 &rhs) const
    {
        return *this / rhs.low;
    }
    // Only compute the 128bit % 64 bit
    constexpr Uint128 operator%(const Uint128 &rhs) const
    {
        return *this % rhs.low;
    }
    // Only compute the 128bit / 64 bit
    constexpr Uint128 operator/(uint64_t rhs) const
    {
        Uint128 quot = *this;
        quot.selfDivRem(rhs);
        return quot;
    }
    // Only compute the 128bit % 64 bit
    constexpr Uint128 operator%(uint64_t rhs) const
    {
        Uint128 quot = *this;
        uint64_t rem = quot.selfDivRem(rhs);
        return Uint128(rem);
    }
    constexpr Uint128 &operator+=(const Uint128 &rhs)
    {
        return *this = *this + rhs;
    }
    constexpr Uint128 &operator-=(const Uint128 &rhs)
    {
        return *this = *this - rhs;
    }
    constexpr Uint128 &operator+=(uint64_t rhs)
    {
        return *this = *this + rhs;
    }
    constexpr Uint128 &operator-=(uint64_t rhs)
    {
        return *this = *this - rhs;
    }
    // Only compute the low * rhs.low (constexpr-capable, via the portable multiply)
    constexpr Uint128 &operator*=(const Uint128 &rhs)
    {
        return *this = mul64x64to128_base(low, rhs.low);
    }
    constexpr Uint128 &operator/=(const Uint128 &rhs)
    {
        return *this = *this / rhs;
    }
    constexpr Uint128 &operator%=(const Uint128 &rhs)
    {
        return *this = *this % rhs;
    }
    // Divides *this by divisor in place and returns the remainder. divisor must be non-zero.
    constexpr uint64_t selfDivRem(uint64_t divisor)
    {
        if ((divisor >> 32) == 0)
        {
            // 32-bit divisor: long division updates high/low in place.
            return div128by32(high, low, uint32_t(divisor));
        }
        // Reduce the high word first so the remaining 128/64 division has a 64-bit quotient.
        uint64_t divid1 = high % divisor, divid0 = low;
        high /= divisor;
        low = div128by64to64(divid1, divid0, divisor);
        return divid0; // div128by64to64 leaves the remainder here
    }
    static constexpr Uint128 mul64x64(uint64_t a, uint64_t b)
    {
        return Uint128(mul64x64to128_base(a, b));
    }
    constexpr bool operator<(const Uint128 &rhs) const
    {
        if (high != rhs.high)
        {
            return high < rhs.high;
        }
        return low < rhs.low;
    }
    constexpr bool operator==(const Uint128 &rhs) const
    {
        return high == rhs.high && low == rhs.low;
    }
    // Left shift; the shift amount is taken modulo 128 (negative amounts wrap around).
    constexpr Uint128 operator<<(int shift) const
    {
        shift %= 128;
        shift = shift < 0 ? shift + 128 : shift;
        // The zero check comes after normalization. Otherwise a multiple of 128 would
        // reach `low >> 64` below, which is undefined behavior.
        if (shift == 0)
        {
            return *this;
        }
        if (shift < 64)
        {
            return Uint128(low << shift, (high << shift) | (low >> (64 - shift)));
        }
        return Uint128(0, low << (shift - 64));
    }
    // Right shift; the shift amount is taken modulo 128 (negative amounts wrap around).
    constexpr Uint128 operator>>(int shift) const
    {
        shift %= 128;
        shift = shift < 0 ? shift + 128 : shift;
        // See operator<<: a multiple of 128 would otherwise reach `high << 64`.
        if (shift == 0)
        {
            return *this;
        }
        if (shift < 64)
        {
            return Uint128((low >> shift) | (high << (64 - shift)), high >> shift);
        }
        return Uint128(high >> (shift - 64), 0);
    }
    constexpr Uint128 &operator<<=(int shift)
    {
        return *this = *this << shift;
    }
    constexpr Uint128 &operator>>=(int shift)
    {
        return *this = *this >> shift;
    }
    constexpr uint64_t high64() const
    {
        return high;
    }
    constexpr uint64_t low64() const
    {
        return low;
    }
    // Truncating conversion: the low 64 bits.
    constexpr operator uint64_t() const
    {
        return low64();
    }
    // Decimal representation without leading zeros.
    std::string toStringBase10() const
    {
        if (high == 0)
        {
            return std::to_string(low);
        }
        constexpr uint64_t BASE(10000'0000'0000'0000); // 10^16
        Uint128 copy(*this);
        std::string s;
        // Peel off zero-padded 16-digit groups until the rest fits in 64 bits, then print
        // that leading part unpadded. A fixed number of groups would print leading zeros
        // for values in [2^64, 10^32).
        while (copy.high != 0)
        {
            s = ui64to_string_base10(copy.selfDivRem(BASE), 16) + s;
        }
        return std::to_string(copy.low) + s;
    }
    void printDec() const
    {
        std::cout << std::dec << toStringBase10() << '\n';
    }
    void printHex() const
    {
        std::cout << std::hex << "0x" << high << ' ' << low << std::dec << '\n';
    }
};
// Portable unsigned 192-bit integer stored as three 64-bit words (low, mid, high).
// Supports add/sub, division by a 64-bit value, left shift, comparison and decimal output.
// The Uint192 overloads of /= and %= divide by rhs's low 64 bits only (via operator uint64_t).
class Uint192
{
private:
    uint64_t low, mid, high;

public:
    constexpr Uint192(uint64_t low = 0, uint64_t mi = 0, uint64_t high = 0) : low(low), mid(mi), high(high) {}
    constexpr Uint192(Uint128 n) : low(n.low64()), mid(n.high64()), high(0) {}
    constexpr Uint192 operator+(Uint192 rhs) const
    {
        bool cf = false;
        rhs.low = add_half(low, rhs.low, cf);
        rhs.mid = add_carry(mid, rhs.mid, cf);
        rhs.high = high + rhs.high + cf;
        return rhs;
    }
    constexpr Uint192 operator-(Uint192 rhs) const
    {
        bool bf = false;
        rhs.low = sub_half(low, rhs.low, bf);
        rhs.mid = sub_borrow(mid, rhs.mid, bf);
        rhs.high = high - rhs.high - bf;
        return rhs;
    }
    constexpr Uint192 operator/(uint64_t rhs) const
    {
        Uint192 result(*this);
        result.selfDivRem(rhs);
        return result;
    }
    constexpr Uint192 operator%(uint64_t rhs) const
    {
        Uint192 result(*this);
        return result.selfDivRem(rhs);
    }
    constexpr Uint192 &operator+=(const Uint192 &rhs)
    {
        return *this = *this + rhs;
    }
    constexpr Uint192 &operator-=(const Uint192 &rhs)
    {
        return *this = *this - rhs;
    }
    constexpr Uint192 &operator/=(const Uint192 &rhs)
    {
        return *this = *this / rhs;
    }
    constexpr Uint192 &operator%=(const Uint192 &rhs)
    {
        return *this = *this % rhs;
    }
    // Left shift; the shift amount is taken modulo 192 (negative amounts wrap around).
    constexpr Uint192 operator<<(int shift) const
    {
        shift %= 192;
        shift = shift < 0 ? shift + 192 : shift;
        // The zero check comes after normalization, so a multiple of 192 cannot reach `x >> 64`.
        if (shift == 0)
        {
            return *this;
        }
        if (shift < 64)
        {
            return Uint192(low << shift, (mid << shift) | (low >> (64 - shift)), (high << shift) | (mid >> (64 - shift)));
        }
        if (shift < 128)
        {
            shift -= 64;
            if (shift == 0)
            {
                return Uint192(0, low, mid); // whole-word shift; `low >> 64` would be undefined
            }
            return Uint192(0, low << shift, (mid << shift) | (low >> (64 - shift)));
        }
        return Uint192(0, 0, low << (shift - 128));
    }
    constexpr bool operator<(Uint192 rhs) const
    {
        if (high != rhs.high)
        {
            return high < rhs.high;
        }
        if (mid != rhs.mid)
        {
            return mid < rhs.mid;
        }
        return low < rhs.low;
    }
    constexpr bool operator==(Uint192 rhs) const
    {
        return high == rhs.high && mid == rhs.mid && low == rhs.low;
    }
    // Full 128 x 64 -> 192-bit product.
    static Uint192 mul128x64(Uint128 a, uint64_t b)
    {
        auto prod1 = mul64x64to128(b, a.low64());
        auto prod2 = mul64x64to128(b, a.high64());
        Uint192 result;
        result.low = prod1.first;
        result.mid = prod1.second + prod2.first;
        result.high = prod2.second + (result.mid < prod1.second);
        return result;
    }
    // Full a * b * c product (64 x 64 x 64 -> 192 bits), usable in constant expressions.
    static constexpr Uint192 mul64x64x64(uint64_t a, uint64_t b, uint64_t c)
    {
        auto prod0 = mul64x64to128_base(a, b);
        auto prod1 = mul64x64to128_base(c, prod0.first);
        auto prod2 = mul64x64to128_base(c, prod0.second);
        Uint192 result;
        result.low = prod1.first;
        result.mid = prod1.second + prod2.first;
        result.high = prod2.second + (result.mid < prod1.second);
        return result;
    }
    // Divides *this by divisor in place and returns the remainder. divisor must be non-zero.
    constexpr uint64_t selfDivRem(uint64_t divisor)
    {
        uint64_t divid1 = high % divisor, divid0 = mid;
        high /= divisor;
        mid = div128by64to64(divid1, divid0, divisor);
        divid1 = divid0, divid0 = low;
        low = div128by64to64(divid1, divid0, divisor);
        return divid0;
    }
    constexpr Uint192 rShift64() const
    {
        return Uint192(mid, high, 0);
    }
    // Truncating conversion: the low 64 bits.
    constexpr operator uint64_t() const
    {
        return low;
    }
    // Decimal representation without leading zeros.
    std::string toStringBase10() const
    {
        constexpr uint64_t BASE(10000'0000'0000'0000); // 10^16
        Uint192 copy(*this);
        std::string s;
        // Peel off zero-padded 16-digit groups until the rest fits in 64 bits, then print that
        // leading part unpadded. This is self-contained, so the low/mid words are never
        // handed to Uint128 in the wrong order.
        while (copy.high != 0 || copy.mid != 0)
        {
            s = ui64to_string_base10(copy.selfDivRem(BASE), 16) + s;
        }
        return std::to_string(copy.low) + s;
    }
    void printDec() const
    {
        std::cout << std::dec << toStringBase10() << '\n';
    }
    void printHex() const
    {
        std::cout << std::hex << "0x" << high << ' ' << mid << ' ' << low << std::dec << '\n';
    }
};
// Upper 64 bits of a native 128-bit integer (e.g. __uint128_t).
template <typename Int128Type>
constexpr uint64_t high64(const Int128Type &n)
{
    const auto upper = n >> 64;
    return upper; // implicit narrowing keeps the (former) high half
}
// Upper 64 bits of the portable Uint128, which stores its halves separately.
constexpr uint64_t high64(const Uint128 &n)
{
    const uint64_t upper = n.high64();
    return upper;
}
// Preferred 128-bit unsigned type: the compiler's native __uint128_t when UINT128T is
// defined, otherwise the portable Uint128 class.
#if defined(UINT128T)
using Uint128Default = __uint128_t;
#else
using Uint128Default = Uint128;
#endif // UINT128T
// Montgomery-form modular integer for a 64-bit modulus MOD with 2^32 <= MOD < 2^62 (both
// bounds enforced by the static_asserts below). The Montgomery radix is R = 2^64.
// `data` holds the Montgomery representation (x * R mod MOD), not necessarily fully reduced.
// "Lazy": operator*, add() and sub() skip the final reduction step, while converting to
// uint64_t always reduces into [0, MOD). MOD < 2^62 keeps values up to 4 * MOD within 64 bits.
// Int128Type is the 128-bit helper type used for products (the portable Uint128 by default).
template <uint64_t MOD, typename Int128Type = Uint128>
class MontInt64Lazy
{
private:
static_assert(MOD > UINT32_MAX, "Montgomery64 modulus must be greater than 2^32");
static_assert(hint_log2(MOD) < 62, "MOD can't be larger than 62 bits");
uint64_t data; // Montgomery representation
public:
using IntType = uint64_t;
constexpr MontInt64Lazy() : data(0) {}
// Converts n to Montgomery form: redc(n * R^2) = n * R mod MOD.
// NOTE(review): presumably expects n < MOD — confirm with callers.
constexpr MontInt64Lazy(uint64_t n) : data(mulMontCompileTime(n, rSquare())) {}
// Sum followed by one conditional subtraction of 2 * MOD (result stays below 2 * MOD when both operands do).
constexpr MontInt64Lazy operator+(MontInt64Lazy rhs) const
{
rhs.data = data + rhs.data;
rhs.data = rhs.data < mod2() ? rhs.data : rhs.data - mod2();
return rhs;
}
// Difference; adds 2 * MOD back when the unsigned subtraction wraps around.
constexpr MontInt64Lazy operator-(MontInt64Lazy rhs) const
{
rhs.data = data - rhs.data;
rhs.data = rhs.data > data ? rhs.data + mod2() : rhs.data;
return rhs;
}
// Lazy Montgomery product: no final conditional subtraction (see redcFastLazy).
MontInt64Lazy operator*(MontInt64Lazy rhs) const
{
rhs.data = mulMontRunTimeLazy(data, rhs.data);
return rhs;
}
constexpr MontInt64Lazy &operator+=(const MontInt64Lazy &rhs)
{
return *this = *this + rhs;
}
constexpr MontInt64Lazy &operator-=(const MontInt64Lazy &rhs)
{
return *this = *this - rhs;
}
// Unlike operator*, this uses the fully reducing constexpr path (mulMontCompileTime).
constexpr MontInt64Lazy &operator*=(const MontInt64Lazy &rhs)
{
data = mulMontCompileTime(data, rhs.data);
return *this;
}
// Subtracts 2 * MOD once if data >= 2 * MOD (e.g. brings a value below 4 * MOD under 2 * MOD).
constexpr MontInt64Lazy largeNorm() const
{
MontInt64Lazy res;
res.data = data >= mod2() ? data - mod2() : data;
return res;
}
// Unreduced sum: no conditional subtraction at all.
constexpr MontInt64Lazy add(MontInt64Lazy rhs) const
{
rhs.data = data + rhs.data;
return rhs;
}
// Unreduced difference biased by 2 * MOD; it does not wrap as long as rhs.data <= data + 2 * MOD.
constexpr MontInt64Lazy sub(MontInt64Lazy rhs) const
{
rhs.data = data - rhs.data + mod2();
return rhs;
}
// Leaves Montgomery form: the ordinary value, fully reduced into [0, MOD).
constexpr operator uint64_t() const
{
return toInt(data);
}
static constexpr uint64_t mod()
{
return MOD;
}
static constexpr uint64_t mod2()
{
return MOD * 2;
}
// MOD^-1 mod 2^64 (verified by the static_assert below).
static constexpr uint64_t modInv()
{
constexpr uint64_t mod_inv = inv_mod2pow(mod(), 64); //(mod_inv * mod)%(2^64) = 1
return mod_inv;
}
// -MOD^-1 mod 2^64, the multiplier used by Montgomery reduction.
static constexpr uint64_t modInvNeg()
{
constexpr uint64_t mod_inv_neg = uint64_t(0 - modInv()); //(mod_inv_neg + mod_inv)%(2^64) = 0
return mod_inv_neg;
}
// R^2 mod MOD with R = 2^64, evaluated at compile time; used to enter Montgomery form.
static constexpr uint64_t rSquare()
{
constexpr Int128Type r = (Int128Type(1) << 64) % Int128Type(mod()); // R % mod
constexpr uint64_t r2 = uint64_t(qpow(r, 2, Int128Type(mod()))); // R^2 % mod
return r2;
}
static_assert((mod() * modInv()) == 1, "mod_inv not correct");
// Ordinary value -> Montgomery form (same computation as the converting constructor).
static constexpr uint64_t toMont(uint64_t n)
{
return mulMontCompileTime(n, rSquare());
}
// Montgomery form -> ordinary value in [0, MOD).
static constexpr uint64_t toInt(uint64_t n)
{
return redc(Int128Type(n));
}
// Montgomery reduction without the final subtraction. Returns (input + m * MOD) / 2^64,
// where m = (input mod 2^64) * (-MOD^-1) mod 2^64 (a wrapping 64-bit multiply). The low 64 bits
// of the sum are zero by construction, so high64 yields the exact quotient. The result is
// congruent to input * R^-1 (mod MOD) but may be >= MOD.
static uint64_t redcFastLazy(const Int128Type &input)
{
Int128Type n = uint64_t(input) * modInvNeg();
n = n * mod();
n += input;
return high64(n);
}
// redcFastLazy followed by one conditional subtraction of MOD.
static uint64_t redcFast(const Int128Type &input)
{
uint64_t n = redcFastLazy(input);
return n < mod() ? n : n - mod();
}
// Same reduction as redcFast, written only with constexpr-capable Int128Type operations
// (operator*= instead of operator*).
static constexpr uint64_t redc(const Int128Type &input)
{
Int128Type n = uint64_t(input) * modInvNeg();
n *= Int128Type(mod());
n += input;
uint64_t m = high64(n);
return m < mod() ? m : m - mod();
}
// a * b * R^-1, reduced once (runtime path).
static uint64_t mulMontRunTime(uint64_t a, uint64_t b)
{
return redcFast(Int128Type(a) * b);
}
// a * b * R^-1 without the final subtraction (runtime path).
static uint64_t mulMontRunTimeLazy(uint64_t a, uint64_t b)
{
return redcFastLazy(Int128Type(a) * b);
}
// a * b * R^-1, reduced once; usable in constant expressions.
static constexpr uint64_t mulMontCompileTime(uint64_t a, uint64_t b)
{
Int128Type prod(a);
prod *= Int128Type(b);
return redc(prod);
}
};
// Montgomery-form modular integer for a 32-bit modulus MOD < 2^30 (enforced by the
// static_assert below). The Montgomery radix is R = 2^32.
// `data` holds the Montgomery representation (x * R mod MOD), not necessarily fully reduced.
// "Lazy": operator*, add() and sub() skip the final reduction, while converting to uint32_t
// reduces into [0, MOD). MOD < 2^30 keeps values up to 4 * MOD within 32 bits.
template <uint32_t MOD>
class MontInt32Lazy
{
private:
static_assert(hint_log2(MOD) < 30, "MOD can't be larger than 30 bits");
uint32_t data; // Montgomery representation
public:
using IntType = uint32_t;
constexpr MontInt32Lazy() : data(0) {}
// Converts n to Montgomery form (n * 2^32 mod MOD).
constexpr MontInt32Lazy(uint32_t n) : data(toMont(n)) {}
// Sum followed by one conditional subtraction of 2 * MOD (result stays below 2 * MOD when both operands do).
constexpr MontInt32Lazy operator+(MontInt32Lazy rhs) const
{
rhs.data = data + rhs.data;
rhs.data = rhs.data < mod2() ? rhs.data : rhs.data - mod2();
return rhs;
}
// Difference; adds 2 * MOD back when the unsigned subtraction wraps around.
constexpr MontInt32Lazy operator-(MontInt32Lazy rhs) const
{
rhs.data = data - rhs.data;
rhs.data = rhs.data > data ? rhs.data + mod2() : rhs.data;
return rhs;
}
// Lazy Montgomery product: no final conditional subtraction (redcLazy).
constexpr MontInt32Lazy operator*(MontInt32Lazy rhs) const
{
rhs.data = redcLazy(uint64_t(data) * rhs.data);
return rhs;
}
constexpr MontInt32Lazy &operator+=(const MontInt32Lazy &rhs)
{
return *this = *this + rhs;
}
constexpr MontInt32Lazy &operator-=(const MontInt32Lazy &rhs)
{
return *this = *this - rhs;
}
// Unlike operator*, this applies the conditional subtraction (redc).
constexpr MontInt32Lazy &operator*=(const MontInt32Lazy &rhs)
{
data = redc(uint64_t(data) * rhs.data);
return *this;
}
// Subtracts 2 * MOD once if data >= 2 * MOD (e.g. brings a value below 4 * MOD under 2 * MOD).
constexpr MontInt32Lazy largeNorm() const
{
MontInt32Lazy res;
res.data = data >= mod2() ? data - mod2() : data;
return res;
}
// Unreduced sum: no conditional subtraction at all.
constexpr MontInt32Lazy add(MontInt32Lazy rhs) const
{
rhs.data = data + rhs.data;
return rhs;
}
// Unreduced difference biased by 2 * MOD; it does not wrap as long as rhs.data <= data + 2 * MOD.
constexpr MontInt32Lazy sub(MontInt32Lazy rhs) const
{
rhs.data = data - rhs.data + mod2();
return rhs;
}
// Leaves Montgomery form: the ordinary value after reduction.
constexpr operator uint32_t() const
{
return toInt(data);
}
static constexpr uint32_t mod()
{
return MOD;
}
static constexpr uint32_t mod2()
{
return MOD * 2;
}
// MOD^-1 mod 2^32 (verified by the static_assert below).
static constexpr uint32_t modInv()
{
constexpr uint32_t mod_inv = uint32_t(inv_mod2pow(mod(), 32));
return mod_inv;
}
// -MOD^-1 mod 2^32, the multiplier used by Montgomery reduction.
static constexpr uint32_t modNegInv()
{
constexpr uint32_t mod_neg_inv = uint32_t(0 - modInv());
return mod_neg_inv;
}
static_assert((mod() * modInv()) == 1, "mod_inv not correct");
// Ordinary value -> Montgomery form, computed directly as (n << 32) % MOD.
static constexpr uint32_t toMont(uint32_t n)
{
return (uint64_t(n) << 32) % MOD;
}
// Montgomery form -> ordinary value.
static constexpr uint32_t toInt(uint32_t n)
{
return redc(n);
}
// Montgomery reduction without the final subtraction: (n + m * MOD) >> 32 with
// m = (n mod 2^32) * (-MOD^-1) mod 2^32. The result is congruent to n * 2^-32 (mod MOD) but may be >= MOD.
static constexpr uint32_t redcLazy(uint64_t n)
{
uint32_t prod = uint32_t(n) * modNegInv();
return (uint64_t(prod) * mod() + n) >> 32;
}
// redcLazy followed by one conditional subtraction of MOD.
static constexpr uint32_t redc(uint64_t n)
{
uint32_t res = redcLazy(n);
return res < mod() ? res : res - mod();
}
};
// Plain (non-Montgomery) modular integer for a modulus MOD < 2^32.
// Values are kept in [0, MOD). The constructor does not reduce, so callers pass in < MOD.
// largeNorm/add/sub mirror the lazy Montgomery types' interface and simply perform the
// normal reduced operations here.
template <uint32_t MOD>
class ModInt32
{
private:
    uint32_t data;

public:
    using IntType = uint32_t;
    // Zero-initialized, consistent with MontInt32Lazy / MontInt64Lazy. An indeterminate
    // `data` would make any later read undefined behavior.
    constexpr ModInt32() : data(0) {}
    constexpr ModInt32(uint32_t in) : data(in) {}
    constexpr ModInt32 operator+(ModInt32 in) const
    {
        // data + in.data >= MOD  <=>  in.data >= MOD - data. Comparing against the gap avoids
        // uint32_t overflow, and '>=' makes a sum of exactly MOD wrap to 0 instead of MOD.
        uint32_t diff = MOD - data;
        return in.data >= diff ? in.data - diff : in.data + data;
    }
    constexpr ModInt32 operator-(ModInt32 in) const
    {
        in.data = data - in.data;
        return in.data > data ? in.data + MOD : in.data; // add MOD back if the subtraction wrapped
    }
    constexpr ModInt32 operator*(ModInt32 in) const
    {
        return mul64(in) % MOD;
    }
    constexpr ModInt32 &operator+=(ModInt32 in)
    {
        return *this = *this + in;
    }
    constexpr ModInt32 &operator-=(ModInt32 in)
    {
        return *this = *this - in;
    }
    constexpr ModInt32 &operator*=(ModInt32 in)
    {
        return *this = *this * in;
    }
    constexpr ModInt32 largeNorm() const
    {
        return data;
    }
    constexpr ModInt32 add(ModInt32 n) const
    {
        return *this + n;
    }
    constexpr ModInt32 sub(ModInt32 n) const
    {
        return *this - n;
    }
    // Full 64-bit product of the two residues (not reduced).
    constexpr uint64_t mul64(ModInt32 in) const
    {
        return uint64_t(data) * in.data;
    }
    constexpr operator uint32_t() const
    {
        return data;
    }
    static constexpr uint32_t mod()
    {
        return MOD;
    }
};
// Returns true iff n * n_inv == 1 (mod mod), computing the product in IntType
// (e.g. a 128-bit type when the modulus exceeds 32 bits).
template <typename IntType>
constexpr bool check_inv(uint64_t n, uint64_t n_inv, uint64_t mod)
{
    IntType product(n % mod);
    product *= IntType(n_inv % mod);
    product %= IntType(mod);
    return product == IntType(1);
}
// Fast Chinese remainder theorem for two 32-bit moduli.
// Returns x in [0, MOD1 * MOD2) with x = num1 (mod MOD1) and x = num2 (mod MOD2),
// assuming num1 < MOD1, num2 < MOD2 and coprime moduli.
template <uint32_t MOD1, uint32_t MOD2>
inline uint64_t crt2(uint32_t num1, uint32_t num2)
{
    constexpr uint64_t inv1 = mod_inv<int64_t>(MOD1, MOD2); // MOD1^-1 mod MOD2
    constexpr uint64_t inv2 = mod_inv<int64_t>(MOD2, MOD1); // MOD2^-1 mod MOD1
    static_assert(check_inv<uint64_t>(inv1, MOD1, MOD2), "Inv1 error");
    static_assert(check_inv<uint64_t>(inv2, MOD2, MOD1), "Inv2 error");
    // Lift from the smaller residue so the difference is non-negative:
    // x = base + ((other - base) * inv mod M_other) * M_base.
    if (num1 <= num2)
    {
        return (uint64_t(num2 - num1) * uint64_t(inv1) % MOD2) * MOD1 + num1;
    }
    return (uint64_t(num1 - num2) * uint64_t(inv2) % MOD1) * MOD2 + num2;
}
// Fast Chinese remainder theorem for two 64-bit moduli, with a 128-bit result type.
// Returns x with x = num1 (mod MOD1) and x = num2 (mod MOD2), assuming reduced inputs
// and coprime moduli.
template <uint64_t MOD1, uint64_t MOD2, typename Int128Type = Uint128>
inline Int128Type crt2(uint64_t num1, uint64_t num2)
{
    constexpr uint64_t inv1 = mod_inv<int64_t>(MOD1, MOD2); // MOD1^-1 mod MOD2
    constexpr uint64_t inv2 = mod_inv<int64_t>(MOD2, MOD1); // MOD2^-1 mod MOD1
    static_assert(check_inv<Int128Type>(inv1, MOD1, MOD2), "Inv1 error");
    static_assert(check_inv<Int128Type>(inv2, MOD2, MOD1), "Inv2 error");
    // Lift from the smaller residue so the difference is non-negative.
    if (num1 <= num2)
    {
        const Int128Type lift = Int128Type(num2 - num1) * Int128Type(inv1) % Int128Type(MOD2);
        return lift * Int128Type(MOD1) + num1;
    }
    const Int128Type lift = Int128Type(num1 - num2) * Int128Type(inv2) % Int128Type(MOD1);
    return lift * Int128Type(MOD2) + num2;
}
// Fast Chinese remainder theorem for two modular-integer types.
// Computes x = (n1 * MOD2^-1 mod MOD1) * MOD2 + (n2 * MOD1^-1 mod MOD2) * MOD1. When the
// converted residues are fully reduced this is below 2 * MOD1 * MOD2, so one conditional
// subtraction of MOD1 * MOD2 finishes the job.
template <typename ModInt1, typename ModInt2>
inline Uint128 crt2(ModInt1 num1, ModInt2 num2)
{
    constexpr uint64_t MOD1 = ModInt1::mod();
    constexpr uint64_t MOD2 = ModInt2::mod();
    constexpr ModInt1 MOD2_INV1 = mod_inv<int64_t>(MOD2, MOD1);
    constexpr ModInt2 MOD1_INV2 = mod_inv<int64_t>(MOD1, MOD2);
    constexpr Uint128 MOD12 = Uint128::mul64x64(MOD1, MOD2);
    static_assert(check_inv<Uint128>(MOD1, MOD1_INV2, MOD2), "INV1 error");
    static_assert(check_inv<Uint128>(MOD2, MOD2_INV1, MOD1), "INV2 error");
    const uint64_t part1 = uint64_t(num1 * MOD2_INV1);
    const uint64_t part2 = uint64_t(num2 * MOD1_INV2);
    Uint128 sum = Uint128(MOD2) * part1;
    sum += Uint128(MOD1) * part2;
    if (sum < MOD12)
    {
        return sum;
    }
    return sum - MOD12;
}
// Fast Chinese remainder theorem for three moduli, with a 192-bit result.
// With M = MOD1 * MOD2 * MOD3 it computes
// x = sum_i (M / MOD_i) * ((n_i * ((M / MOD_i) mod MOD_i)^-1) mod MOD_i) and reduces the sum by M.
template <typename ModInt1, typename ModInt2, typename ModInt3>
inline Uint192 crt3(ModInt1 n1, ModInt2 n2, ModInt3 n3)
{
    constexpr uint64_t MOD1 = ModInt1::mod(), MOD2 = ModInt2::mod(), MOD3 = ModInt3::mod();
    constexpr Uint192 MOD123 = Uint192::mul64x64x64(MOD1, MOD2, MOD3); // M
    constexpr Uint128 MOD12 = Uint128::mul64x64(MOD1, MOD2);          // M / MOD3
    constexpr Uint128 MOD23 = Uint128::mul64x64(MOD2, MOD3);          // M / MOD1
    constexpr Uint128 MOD13 = Uint128::mul64x64(MOD1, MOD3);          // M / MOD2
    // (M / MOD_i) mod MOD_i
    constexpr uint64_t MOD23_M1 = Uint128::mul64x64(MOD2 % MOD1, MOD3 % MOD1) % Uint128(MOD1);
    constexpr uint64_t MOD13_M2 = Uint128::mul64x64(MOD1 % MOD2, MOD3 % MOD2) % Uint128(MOD2);
    constexpr uint64_t MOD12_M3 = Uint128::mul64x64(MOD1 % MOD3, MOD2 % MOD3) % Uint128(MOD3);
    // ((M / MOD_i) mod MOD_i)^-1 mod MOD_i, held in each residue's own type
    constexpr ModInt1 MOD23_INV1 = mod_inv<int64_t>(MOD23_M1, MOD1);
    constexpr ModInt2 MOD13_INV2 = mod_inv<int64_t>(MOD13_M2, MOD2);
    constexpr ModInt3 MOD12_INV3 = mod_inv<int64_t>(MOD12_M3, MOD3);
    static_assert(check_inv<Uint128>(MOD23_INV1, MOD23_M1, MOD1), "INV1 error");
    static_assert(check_inv<Uint128>(MOD13_INV2, MOD13_M2, MOD2), "INV2 error");
    static_assert(check_inv<Uint128>(MOD12_INV3, MOD12_M3, MOD3), "INV3 error");
    const uint64_t coef1 = uint64_t(n1 * MOD23_INV1);
    const uint64_t coef2 = uint64_t(n2 * MOD13_INV2);
    const uint64_t coef3 = uint64_t(n3 * MOD12_INV3);
    Uint192 result = Uint192::mul128x64(MOD23, coef1);
    result += Uint192::mul128x64(MOD13, coef2);
    result += Uint192::mul128x64(MOD12, coef3);
    // Each term is below M when the coefficients are fully reduced, so the sum is below 3M:
    // two conditional subtractions bring it into [0, M).
    for (int round = 0; round < 2; ++round)
    {
        if (!(result < MOD123))
        {
            result -= MOD123;
        }
    }
    return result;
}
namespace split_radix
{
// Multiply n by W_4^1 = ROOT^((mod - 1) / 4), a 4th root of unity modulo mod
// (a primitive one when ROOT generates the multiplicative group). Computed at compile time.
template <uint64_t ROOT, typename ModIntType>
inline ModIntType mul_w41(ModIntType n)
{
    constexpr auto exponent = (ModIntType::mod() - 1) / 4;
    constexpr ModIntType W_4_1 = qpow(ModIntType(ROOT), exponent);
    return n * W_4_1;
}
// Radix-4 DIT butterfly: combines (in_out2 + in_out3) and W_4 * (in_out2 - in_out3) with the
// range-normalized in_out0 / in_out1.
// in:  in_out0 < 4p, in_out1 < 4p; in_out2 < 2p, in_out3 < 2p
// out: in_out0 < 4p, in_out1 < 4p; in_out2 < 4p, in_out3 < 4p
template <uint64_t ROOT, typename ModIntType>
inline void dit_butterfly244(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3)
{
    const ModIntType a = in_out0.largeNorm();
    const ModIntType b = in_out1.largeNorm();
    const ModIntType s23 = in_out2 + in_out3;
    const ModIntType d23 = mul_w41<ROOT>(in_out2.sub(in_out3));
    in_out0 = a.add(s23);
    in_out2 = a.sub(s23);
    in_out1 = b.add(d23);
    in_out3 = b.sub(d23);
}
// Radix-4 DIF butterfly: sums (0,2) and (1,3) feed outputs 0/1, and the difference (0-2) is
// combined with W_4 * (1-3) to produce outputs 2/3.
// in:  in_out0 < 2p, in_out1 < 2p; in_out2 < 2p, in_out3 < 2p
// out: in_out0 < 2p, in_out1 < 2p; in_out2 < 4p, in_out3 < 4p
template <uint64_t ROOT, typename ModIntType>
inline void dif_butterfly244(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3)
{
    const ModIntType s02 = in_out0.add(in_out2);
    const ModIntType d02 = in_out0 - in_out2;
    const ModIntType s13 = in_out1.add(in_out3);
    const ModIntType d13 = mul_w41<ROOT>(in_out1.sub(in_out3));
    in_out0 = s02.largeNorm();
    in_out1 = s13.largeNorm();
    in_out2 = d02.add(d13);
    in_out3 = d02.sub(d13);
}
// Radix-2 DIT butterfly: (a, b) <- (a + b * omega, a - b * omega), with a range-normalized first.
// in:  in_out0 < 4p, in_out1 < 4p
// out: in_out0 < 4p, in_out1 < 4p
template <typename ModIntType>
inline void dit_butterfly2(ModIntType &in_out0, ModIntType &in_out1, const ModIntType &omega)
{
    const auto rotated = in_out1 * omega;
    const auto base = in_out0.largeNorm();
    in_out0 = base.add(rotated);
    in_out1 = base.sub(rotated);
}
// Radix-2 DIF butterfly: (a, b) <- (a + b, (a - b) * omega).
// in:  in_out0 < 2p, in_out1 < 2p
// out: in_out0 < 2p, in_out1 < 2p
template <typename ModIntType>
inline void dif_butterfly2(ModIntType &in_out0, ModIntType &in_out1, const ModIntType &omega)
{
    const auto diff = in_out0.sub(in_out1);
    in_out0 = in_out0 + in_out1;
    in_out1 = diff * omega;
}
// Fixed-maximum-length NTT over ModIntType with generator ROOT. Transforms of up to MAX_LEN
// points (len is clamped to MAX_LEN) are built from radix-4 passes, i.e. two radix-2 layers
// each. They start from size-4 or size-8 base transforms supplied by the NTTShort<4>/<8>
// specializations (not visible here).
// NOTE(review): len is presumably a power of two — confirm with callers.
template <size_t MAX_LEN, uint64_t ROOT, typename ModIntType>
struct NTTShort
{
static constexpr size_t NTT_LEN = MAX_LEN;
static constexpr int LOG_LEN = hint_log2(NTT_LEN);
// Twiddle table. For every power-of-two length L = 2^k <= NTT_LEN, entries [L/2, L) hold
// w_L^0 .. w_L^(L/2 - 1), with w_L = ROOT^((mod - 1) / L).
struct TableType
{
std::array<ModIntType, NTT_LEN> omega_table;
// Compute in compile time if need.
/*constexpr*/ TableType()
{
for (int omega_log_len = 0; omega_log_len <= LOG_LEN; omega_log_len++)
{
size_t omega_len = size_t(1) << omega_log_len, omega_count = omega_len / 2;
auto it = &omega_table[omega_len / 2];
ModIntType root = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / omega_len);
ModIntType omega(1);
for (size_t i = 0; i < omega_count; i++)
{
it[i] = omega;
omega *= root;
}
}
}
constexpr ModIntType &operator[](size_t i)
{
return omega_table[i];
}
constexpr const ModIntType &operator[](size_t i) const
{
return omega_table[i];
}
// Pointer to the twiddles of transform length `len` (w_len^0 at index 0).
constexpr const ModIntType *getOmegaIt(size_t len) const
{
return &omega_table[len / 2];
}
};
static TableType table;
// Decimation-in-time transform: small base transforms first, then radix-4 passes of
// growing rank (16 or 32 up to len).
// NOTE(review): the first base block is given `len`, later ones are not; what the two-argument
// overload of the NTTShort<4>/<8> specializations does with it isn't visible here.
static void dit(ModIntType in_out[], size_t len)
{
len = std::min(NTT_LEN, len);
size_t rank = len;
if (hint_log2(len) % 2 == 0)
{
NTTShort<4, ROOT, ModIntType>::dit(in_out, len);
for (size_t i = 4; i < len; i += 4)
{
NTTShort<4, ROOT, ModIntType>::dit(in_out + i);
}
rank = 16;
}
else
{
NTTShort<8, ROOT, ModIntType>::dit(in_out, len);
for (size_t i = 8; i < len; i += 8)
{
NTTShort<8, ROOT, ModIntType>::dit(in_out + i);
}
rank = 32;
}
for (; rank <= len; rank *= 4)
{
// Each pass merges four sub-transforms of size rank / 4. Pairs (0,1) and (2,3) use the
// length-rank/2 twiddles; then pairs (0,2) and (1,3) use the length-rank twiddles.
size_t gap = rank / 4;
auto omega_it = table.getOmegaIt(rank), last_omega_it = table.getOmegaIt(rank / 2);
auto it0 = in_out, it1 = in_out + gap, it2 = in_out + gap * 2, it3 = in_out + gap * 3;
for (size_t j = 0; j < len; j += rank)
{
for (size_t i = 0; i < gap; i++)
{
auto temp0 = it0[j + i], temp1 = it1[j + i], temp2 = it2[j + i], temp3 = it3[j + i], omega = last_omega_it[i];
dit_butterfly2(temp0, temp1, omega);
dit_butterfly2(temp2, temp3, omega);
dit_butterfly2(temp0, temp2, omega_it[i]);
dit_butterfly2(temp1, temp3, omega_it[gap + i]);
it0[j + i] = temp0, it1[j + i] = temp1, it2[j + i] = temp2, it3[j + i] = temp3;
}
}
}
}
// Decimation-in-frequency transform: the mirror of dit. Radix-4 passes of shrinking rank
// (len down to 16) come first, then the size-4 or size-8 base transforms.
static void dif(ModIntType in_out[], size_t len)
{
len = std::min(NTT_LEN, len);
size_t rank = len;
for (; rank >= 16; rank /= 4)
{
size_t gap = rank / 4;
auto omega_it = table.getOmegaIt(rank), last_omega_it = table.getOmegaIt(rank / 2);
auto it0 = in_out, it1 = in_out + gap, it2 = in_out + gap * 2, it3 = in_out + gap * 3;
for (size_t j = 0; j < len; j += rank)
{
for (size_t i = 0; i < gap; i++)
{
auto temp0 = it0[j + i], temp1 = it1[j + i], temp2 = it2[j + i], temp3 = it3[j + i], omega = last_omega_it[i];
dif_butterfly2(temp0, temp2, omega_it[i]);
dif_butterfly2(temp1, temp3, omega_it[gap + i]);
dif_butterfly2(temp0, temp1, omega);
dif_butterfly2(temp2, temp3, omega);
it0[j + i] = temp0, it1[j + i] = temp1, it2[j + i] = temp2, it3[j + i] = temp3;
}
}
}
// rank is now the leftover base size (4 or 8 for the usual power-of-two lengths).
if (hint_log2(rank) % 2 == 0)
{
NTTShort<4, ROOT, ModIntType>::dif(in_out, len);
for (size_t i = 4; i < len; i += 4)
{
NTTShort<4, ROOT, ModIntType>::dif(in_out + i);
}
}
else
{
NTTShort<8, ROOT, ModIntType>::dif(in_out, len);
for (size_t i = 8; i < len; i += 8)
{
NTTShort<8, ROOT, ModIntType>::dif(in_out + i);
}
}
}
};
template <size_t LEN, uint64_t ROOT, typename ModIntType>
typename NTTShort<LEN, ROOT, ModIntType>::TableType NTTShort<LEN, ROOT, ModIntType>::table;
template <size_t LEN, uint64_t ROOT, typename ModIntType>
constexpr size_t NTTShort<LEN, ROOT, ModIntType>::NTT_LEN;
template <size_t LEN, uint64_t ROOT, typename ModIntType>
constexpr int NTTShort<LEN, ROOT, ModIntType>::LOG_LEN;
// Length-0 kernel: there is nothing to transform, so every entry point is a no-op.
template <uint64_t ROOT, typename ModIntType>
struct NTTShort<0, ROOT, ModIntType>
{
    static void dit(ModIntType *) {}
    static void dif(ModIntType *) {}
    static void dit(ModIntType *, size_t) {}
    static void dif(ModIntType *, size_t) {}
};
// Length-1 kernel: a one-point transform leaves its single element unchanged, so every
// entry point is a no-op.
template <uint64_t ROOT, typename ModIntType>
struct NTTShort<1, ROOT, ModIntType>
{
    static void dit(ModIntType *) {}
    static void dif(ModIntType *) {}
    static void dit(ModIntType *, size_t) {}
    static void dif(ModIntType *, size_t) {}
};
// Length-2 kernel: one radix-2 butterfly via transform2 (defined elsewhere), used for both
// the DIT and the DIF step. The len-taking overloads do nothing when len < 2.
template <uint64_t ROOT, typename ModIntType>
struct NTTShort<2, ROOT, ModIntType>
{
    static void dit(ModIntType *in_out)
    {
        transform2(in_out[0], in_out[1]);
    }
    static void dif(ModIntType *in_out)
    {
        transform2(in_out[0], in_out[1]);
    }
    static void dit(ModIntType *in_out, size_t len)
    {
        if (len >= 2)
        {
            dit(in_out);
        }
    }
    static void dif(ModIntType *in_out, size_t len)
    {
        if (len >= 2)
        {
            dif(in_out);
        }
    }
};
// Length-4 kernel, fully unrolled. transform2, mul_w41, add/sub and largeNorm are defined
// elsewhere. NOTE(review): mul_w41<ROOT> presumably multiplies by the primitive 4th root of
// unity, and largeNorm() presumably reduces a lazily-reduced residue; confirm.
template <uint64_t ROOT, typename ModIntType>
struct NTTShort<4, ROOT, ModIntType>
{
// DIT: inputs are loaded through largeNorm(). Radix-2 on (0,1); radix-2 on (2,3) with the
// difference passed through mul_w41. Outputs k and k + 2 combine temp k with temp k + 2.
static void dit(ModIntType in_out[])
{
auto temp0 = in_out[0].largeNorm();
auto temp1 = in_out[1].largeNorm();
auto temp2 = in_out[2].largeNorm();
auto temp3 = in_out[3].largeNorm();
transform2(temp0, temp1);
auto sum = temp2.add(temp3);
auto dif = temp2.sub(temp3);
temp2 = sum.largeNorm();
temp3 = mul_w41<ROOT>(dif);
in_out[0] = temp0.add(temp2);
in_out[1] = temp1.add(temp3);
in_out[2] = temp0.sub(temp2);
in_out[3] = temp1.sub(temp3);
}
// DIF (mirror of dit): radix-2 on (0,2); radix-2 on (1,3) with the difference passed through
// mul_w41. Outputs 2k and 2k + 1 are the sum/difference of the resulting pairs.
static void dif(ModIntType in_out[])
{
auto temp0 = in_out[0];
auto temp1 = in_out[1];
auto temp2 = in_out[2];
auto temp3 = in_out[3];
transform2(temp0, temp2);
auto sum = temp1.add(temp3);
auto dif = temp1.sub(temp3);
temp1 = sum.largeNorm();
temp3 = mul_w41<ROOT>(dif);
in_out[0] = temp0 + temp1;
in_out[1] = temp0 - temp1;
in_out[2] = temp2 + temp3;
in_out[3] = temp2 - temp3;
}
// Lengths below 4 fall back to the length-2 kernel (which itself ignores len < 2).
static void dit(ModIntType in_out[], size_t len)
{
if (len < 4)
{
NTTShort<2, ROOT, ModIntType>::dit(in_out, len);
return;
}
dit(in_out);
}
static void dif(ModIntType in_out[], size_t len)
{
if (len < 4)
{
NTTShort<2, ROOT, ModIntType>::dif(in_out, len);
return;
}
dif(in_out);
}
};
// Length-8 kernel, fully unrolled. w1 = ROOT^((mod() - 1) / 8), w2 = w1^2, w3 = w1^3.
// transform2, mul_w41, add/sub and largeNorm are defined elsewhere (see NTTShort<4> notes).
template <uint64_t ROOT, typename ModIntType>
struct NTTShort<8, ROOT, ModIntType>
{
// DIT: inputs are loaded through largeNorm().
// Stage 1: radix-2 on (0,1), (4,5), (2,3), (6,7); the (2,3) and (6,7) differences go through mul_w41.
// Stage 2: radix-2 on (0,2), (1,3), and on (4,6), (5,7) with the (4,6) difference times w2,
// the (5,7) sum times w1 and the (5,7) difference times w3.
// Stage 3: outputs k and k + 4 combine temp k with temp k + 4.
static void dit(ModIntType in_out[])
{
static constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
static constexpr ModIntType w2 = qpow(w1, 2);
static constexpr ModIntType w3 = qpow(w1, 3);
auto temp0 = in_out[0].largeNorm();
auto temp1 = in_out[1].largeNorm();
auto temp2 = in_out[2].largeNorm();
auto temp3 = in_out[3].largeNorm();
auto temp4 = in_out[4].largeNorm();
auto temp5 = in_out[5].largeNorm();
auto temp6 = in_out[6].largeNorm();
auto temp7 = in_out[7].largeNorm();
transform2(temp0, temp1);
transform2(temp4, temp5);
auto sum = temp2.add(temp3);
auto dif = temp2.sub(temp3);
temp2 = sum.largeNorm();
temp3 = mul_w41<ROOT>(dif);
sum = temp6.add(temp7);
dif = temp6.sub(temp7);
temp6 = sum.largeNorm();
temp7 = mul_w41<ROOT>(dif);
transform2(temp0, temp2);
transform2(temp1, temp3);
sum = temp4.add(temp6);
dif = temp4.sub(temp6);
temp4 = sum.largeNorm();
temp6 = dif * w2;
sum = temp5.add(temp7);
dif = temp5.sub(temp7);
temp5 = sum * w1;
temp7 = dif * w3;
in_out[0] = temp0.add(temp4);
in_out[1] = temp1.add(temp5);
in_out[2] = temp2.add(temp6);
in_out[3] = temp3.add(temp7);
in_out[4] = temp0.sub(temp4);
in_out[5] = temp1.sub(temp5);
in_out[6] = temp2.sub(temp6);
in_out[7] = temp3.sub(temp7);
}
// DIF (mirror of dit).
// Stage 1: radix-2 on (0,4), (1,5), (2,6), (3,7); the differences of the last three are
// multiplied by w1, w2, w3.
// Stage 2: radix-2 on (0,2), (4,6), (1,3), (5,7); the (1,3) and (5,7) differences go through mul_w41.
// Stage 3: outputs 2k and 2k + 1 are the sum/difference of each adjacent pair.
static void dif(ModIntType in_out[])
{
static constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
static constexpr ModIntType w2 = qpow(w1, 2);
static constexpr ModIntType w3 = qpow(w1, 3);
auto temp0 = in_out[0];
auto temp1 = in_out[1];
auto temp2 = in_out[2];
auto temp3 = in_out[3];
auto temp4 = in_out[4];
auto temp5 = in_out[5];
auto temp6 = in_out[6];
auto temp7 = in_out[7];
transform2(temp0, temp4);
auto sum = temp1.add(temp5);
auto dif = temp1.sub(temp5);
temp1 = sum.largeNorm();
temp5 = dif * w1;
sum = temp2.add(temp6);
dif = temp2.sub(temp6);
temp2 = sum.largeNorm();
temp6 = dif * w2;
sum = temp3.add(temp7);
dif = temp3.sub(temp7);
temp3 = sum.largeNorm();
temp7 = dif * w3;
transform2(temp0, temp2);
transform2(temp4, temp6);
sum = temp1.add(temp3);
dif = temp1.sub(temp3);
temp1 = sum.largeNorm();
temp3 = mul_w41<ROOT>(dif);
sum = temp5.add(temp7);
dif = temp5.sub(temp7);
temp5 = sum.largeNorm();
temp7 = mul_w41<ROOT>(dif);
in_out[0] = temp0 + temp1;
in_out[1] = temp0 - temp1;
in_out[2] = temp2 + temp3;
in_out[3] = temp2 - temp3;
in_out[4] = temp4 + temp5;
in_out[5] = temp4 - temp5;
in_out[6] = temp6 + temp7;
in_out[7] = temp6 - temp7;
}
// Lengths below 8 fall back to the length-4 kernel.
static void dit(ModIntType in_out[], size_t len)
{
if (len < 8)
{
NTTShort<4, ROOT, ModIntType>::dit(in_out, len);
return;
}
dit(in_out);
}
static void dif(ModIntType in_out[], size_t len)
{
if (len < 8)
{
NTTShort<4, ROOT, ModIntType>::dif(in_out, len);
return;
}
dif(in_out);
}
};
// Split-radix ("2-4-4") number-theoretic transform modulo MOD with primitive root ROOT.
// Transforms of length <= LONG_THRESHOLD run in the table-driven NTTShort kernel. Longer ones
// are split into one half and two quarters (dif244 / dit244). INTT is the same transform
// instantiated with rootInv(); convolution() uses INTT::dit244 to undo dif244 up to a factor
// of ntt_len, which the optional normalisation removes.
template <uint64_t MOD, uint64_t ROOT, typename Int128Type = Uint128Default>
struct NTT
{
    static constexpr uint64_t mod()
    {
        return MOD;
    }
    static constexpr uint64_t root()
    {
        return ROOT;
    }
    // ROOT^-1 mod MOD (verified by the static_assert below).
    static constexpr uint64_t rootInv()
    {
        constexpr uint64_t IROOT = mod_inv<int64_t>(ROOT, MOD);
        return IROOT;
    }
    static_assert(root() < mod(), "ROOT must be smaller than MOD");
    static_assert(check_inv<Int128Type>(root(), rootInv(), mod()), "IROOT * ROOT % MOD must be 1");
    static constexpr int MOD_BITS = hint_log2(mod()) + 1;
    // Trailing zero bits of MOD - 1: the largest power-of-two transform length is 2^MAX_LOG_LEN.
    static constexpr int MAX_LOG_LEN = hint_ctz(mod() - 1);
    // 2^MAX_LOG_LEN, capped at the largest power of two representable in size_t.
    static constexpr size_t getMaxLen()
    {
        if (MAX_LOG_LEN < sizeof(size_t) * CHAR_BIT)
        {
            return size_t(1) << MAX_LOG_LEN;
        }
        return size_t(1) << (sizeof(size_t) * CHAR_BIT - 1);
    }
    static constexpr size_t NTT_MAX_LEN = getMaxLen();
    using INTT = NTT<mod(), rootInv(), Int128Type>;
    // Residue type by modulus width: more than 32 bits -> MontInt64Lazy,
    // 31..32 bits -> ModInt32, otherwise MontInt32Lazy.
    using ModInt32Type = typename std::conditional<(MOD_BITS > 30), ModInt32<uint32_t(MOD)>, MontInt32Lazy<uint32_t(MOD)>>::type;
    using ModInt64Type = MontInt64Lazy<MOD, Int128Type>;
    using ModIntType = typename std::conditional<(MOD_BITS > 32), ModInt64Type, ModInt32Type>::type;
    using IntType = typename ModIntType::IntType;
    static constexpr size_t L2_BYTE = size_t(1) << 20; // 1MB L2 cache size, change this if you know your cache size.
    // Transforms with at most this many elements go straight to the NTTShort kernel.
    static constexpr size_t LONG_THRESHOLD = std::min(L2_BYTE / sizeof(ModIntType), NTT_MAX_LEN);
    using NTTTemplate = NTTShort<LONG_THRESHOLD, root(), ModIntType>;

    // In-place decimation-in-time transform. ntt_len is rounded down to a power of two
    // (int_floor2) and capped at NTT_MAX_LEN.
    static void dit244(ModIntType in_out[], size_t ntt_len)
    {
        ntt_len = std::min(int_floor2(ntt_len), NTT_MAX_LEN);
        if (ntt_len <= LONG_THRESHOLD)
        {
            NTTTemplate::dit(in_out, ntt_len);
            return;
        }
        size_t quarter_len = ntt_len / 4;
        // Sub-transforms first (two quarters, one half), then one combining 2-4-4 pass.
        dit244(in_out + quarter_len * 3, ntt_len / 4);
        dit244(in_out + quarter_len * 2, ntt_len / 4);
        dit244(in_out, ntt_len / 2);
        const ModIntType unit_omega1 = qpow(ModIntType(root()), (mod() - 1) / ntt_len);
        const ModIntType unit_omega3 = qpow(unit_omega1, 3);
        ModIntType omega1(1), omega3(1);
        auto it0 = in_out, it1 = in_out + quarter_len, it2 = in_out + quarter_len * 2, it3 = in_out + quarter_len * 3;
        for (size_t i = 0; i < quarter_len; i++)
        {
            ModIntType temp0 = it0[i], temp1 = it1[i], temp2 = it2[i] * omega1, temp3 = it3[i] * omega3;
            dit_butterfly244<ROOT>(temp0, temp1, temp2, temp3);
            it0[i] = temp0, it1[i] = temp1, it2[i] = temp2, it3[i] = temp3;
            omega1 = omega1 * unit_omega1;
            omega3 = omega3 * unit_omega3;
        }
    }
    // In-place decimation-in-frequency transform (mirror of dit244). ntt_len is rounded down
    // to a power of two and capped at NTT_MAX_LEN.
    static void dif244(ModIntType in_out[], size_t ntt_len)
    {
        ntt_len = std::min(int_floor2(ntt_len), NTT_MAX_LEN);
        if (ntt_len <= LONG_THRESHOLD)
        {
            NTTTemplate::dif(in_out, ntt_len);
            return;
        }
        size_t quarter_len = ntt_len / 4;
        const ModIntType unit_omega1 = qpow(ModIntType(root()), (mod() - 1) / ntt_len);
        const ModIntType unit_omega3 = qpow(unit_omega1, 3);
        ModIntType omega1(1), omega3(1);
        auto it0 = in_out, it1 = in_out + quarter_len, it2 = in_out + quarter_len * 2, it3 = in_out + quarter_len * 3;
        for (size_t i = 0; i < quarter_len; i++)
        {
            ModIntType temp0 = it0[i], temp1 = it1[i], temp2 = it2[i], temp3 = it3[i];
            dif_butterfly244<ROOT>(temp0, temp1, temp2, temp3);
            it0[i] = temp0, it1[i] = temp1, it2[i] = temp2 * omega1, it3[i] = temp3 * omega3;
            omega1 = omega1 * unit_omega1;
            omega3 = omega3 * unit_omega3;
        }
        dif244(in_out, ntt_len / 2);
        dif244(in_out + quarter_len * 3, ntt_len / 4);
        dif244(in_out + quarter_len * 2, ntt_len / 4);
    }
    // Cyclic convolution: out = INTT(DIF(in1) * DIF(in2)), scaled by 1/ntt_len when normlize.
    // in1 and in2 are overwritten by their transforms. out may alias either input, and
    // in1 == in2 (squaring) is supported. ntt_len is rounded down to a power of two and
    // capped at NTT_MAX_LEN, exactly as the transforms do, so every loop covers the same range.
    static void convolution(ModIntType in1[], ModIntType in2[], ModIntType out[], size_t ntt_len, bool normlize = true)
    {
        ntt_len = std::min(int_floor2(ntt_len), NTT_MAX_LEN);
        dif244(in1, ntt_len);
        // Transforming the same buffer twice would corrupt a squaring call.
        if (in1 != in2)
        {
            dif244(in2, ntt_len);
        }
        if (normlize)
        {
            const ModIntType inv_len(qpow(ModIntType(ntt_len), mod() - 2));
            for (size_t i = 0; i < ntt_len; i++)
            {
                out[i] = in1[i] * in2[i] * inv_len;
            }
        }
        else
        {
            for (size_t i = 0; i < ntt_len; i++)
            {
                out[i] = in1[i] * in2[i];
            }
        }
        INTT::dit244(out, ntt_len);
    }
    // Cyclic convolution that fuses forward transform, pointwise product and inverse transform
    // in one recursion, so each sub-problem is finished while it is still in cache.
    // ntt_len must be a power of two. in1/in2 are overwritten; in1 == in2 is handled, and out
    // may alias in1 (each output element depends only on inputs at the same index).
    static void convolutionRecursion(ModIntType in1[], ModIntType in2[], ModIntType out[], size_t ntt_len, bool normlize = true)
    {
        if (ntt_len <= LONG_THRESHOLD)
        {
            NTTTemplate::dif(in1, ntt_len);
            if (in1 != in2)
            {
                NTTTemplate::dif(in2, ntt_len);
            }
            if (normlize)
            {
                const ModIntType inv_len(qpow(ModIntType(ntt_len), mod() - 2));
                for (size_t i = 0; i < ntt_len; i++)
                {
                    out[i] = in1[i] * in2[i] * inv_len;
                }
            }
            else
            {
                for (size_t i = 0; i < ntt_len; i++)
                {
                    out[i] = in1[i] * in2[i];
                }
            }
            INTT::NTTTemplate::dit(out, ntt_len);
            return;
        }
        const size_t quarter_len = ntt_len / 4;
        ModIntType unit_omega1 = qpow(ModIntType(root()), (mod() - 1) / ntt_len);
        ModIntType unit_omega3 = qpow(unit_omega1, 3);
        ModIntType omega1(1), omega3(1);
        // Forward 2-4-4 DIF pass on both inputs (only once when they are the same buffer).
        if (in1 != in2)
        {
            for (size_t i = 0; i < quarter_len; i++)
            {
                ModIntType temp0 = in1[i], temp1 = in1[quarter_len + i], temp2 = in1[quarter_len * 2 + i], temp3 = in1[quarter_len * 3 + i];
                dif_butterfly244<ROOT>(temp0, temp1, temp2, temp3);
                in1[i] = temp0, in1[quarter_len + i] = temp1, in1[quarter_len * 2 + i] = temp2 * omega1, in1[quarter_len * 3 + i] = temp3 * omega3;
                temp0 = in2[i], temp1 = in2[quarter_len + i], temp2 = in2[quarter_len * 2 + i], temp3 = in2[quarter_len * 3 + i];
                dif_butterfly244<ROOT>(temp0, temp1, temp2, temp3);
                in2[i] = temp0, in2[quarter_len + i] = temp1, in2[quarter_len * 2 + i] = temp2 * omega1, in2[quarter_len * 3 + i] = temp3 * omega3;
                omega1 = omega1 * unit_omega1;
                omega3 = omega3 * unit_omega3;
            }
        }
        else
        {
            for (size_t i = 0; i < quarter_len; i++)
            {
                ModIntType temp0 = in1[i], temp1 = in1[quarter_len + i], temp2 = in1[quarter_len * 2 + i], temp3 = in1[quarter_len * 3 + i];
                dif_butterfly244<ROOT>(temp0, temp1, temp2, temp3);
                in1[i] = temp0, in1[quarter_len + i] = temp1, in1[quarter_len * 2 + i] = temp2 * omega1, in1[quarter_len * 3 + i] = temp3 * omega3;
                omega1 = omega1 * unit_omega1;
                omega3 = omega3 * unit_omega3;
            }
        }
        // Recurse on the half and the two quarters without normalisation; scaling is applied once below.
        convolutionRecursion(in1, in2, out, ntt_len / 2, false);
        convolutionRecursion(in1 + quarter_len * 2, in2 + quarter_len * 2, out + quarter_len * 2, ntt_len / 4, false);
        convolutionRecursion(in1 + quarter_len * 3, in2 + quarter_len * 3, out + quarter_len * 3, ntt_len / 4, false);
        // Inverse 2-4-4 DIT pass with the inverse root.
        unit_omega1 = qpow(ModIntType(rootInv()), (mod() - 1) / ntt_len);
        unit_omega3 = qpow(unit_omega1, 3);
        if (normlize)
        {
            // Fold the 1/ntt_len factor into the twiddles so no extra pass is needed.
            const ModIntType inv_len(qpow(ModIntType(ntt_len), mod() - 2));
            omega1 = inv_len, omega3 = inv_len;
            for (size_t i = 0; i < quarter_len; i++)
            {
                ModIntType temp0 = out[i] * inv_len, temp1 = out[quarter_len + i] * inv_len, temp2 = out[quarter_len * 2 + i] * omega1, temp3 = out[quarter_len * 3 + i] * omega3;
                dit_butterfly244<rootInv()>(temp0, temp1, temp2, temp3);
                out[i] = temp0, out[quarter_len + i] = temp1, out[quarter_len * 2 + i] = temp2, out[quarter_len * 3 + i] = temp3;
                omega1 = omega1 * unit_omega1;
                omega3 = omega3 * unit_omega3;
            }
        }
        else
        {
            omega1 = 1, omega3 = 1;
            for (size_t i = 0; i < quarter_len; i++)
            {
                ModIntType temp0 = out[i], temp1 = out[quarter_len + i], temp2 = out[quarter_len * 2 + i] * omega1, temp3 = out[quarter_len * 3 + i] * omega3;
                dit_butterfly244<rootInv()>(temp0, temp1, temp2, temp3);
                out[i] = temp0, out[quarter_len + i] = temp1, out[quarter_len * 2 + i] = temp2, out[quarter_len * 3 + i] = temp3;
                omega1 = omega1 * unit_omega1;
                omega3 = omega3 * unit_omega3;
            }
        }
    }
};
// Out-of-class definitions of ODR-used static constexpr members (required before C++17).
template <uint64_t MOD, uint64_t ROOT, typename Int128Type>
constexpr int NTT<MOD, ROOT, Int128Type>::MOD_BITS;
template <uint64_t MOD, uint64_t ROOT, typename Int128Type>
constexpr int NTT<MOD, ROOT, Int128Type>::MAX_LOG_LEN;
template <uint64_t MOD, uint64_t ROOT, typename Int128Type>
constexpr size_t NTT<MOD, ROOT, Int128Type>::NTT_MAX_LEN;
} // namespace split_radix
// Ready-made split-radix NTT instantiations for the (MODn, ROOTn) pairs declared elsewhere in
// this namespace. NTT::ModIntType is MontInt64Lazy when the modulus needs more than 32 bits,
// ModInt32 for 31-32 bits and MontInt32Lazy otherwise.
// NOTE(review): the per-line width notes below are the author's; confirm them against the MODn values.
using NTT0 = split_radix::NTT<MOD0, ROOT0>; // using 64bit integer, Montgomery speed up
using NTT1 = split_radix::NTT<MOD1, ROOT1>; // using 64bit integer, Montgomery speed up
using NTT2 = split_radix::NTT<MOD2, ROOT2>; // using 64bit integer, Montgomery speed up
using NTT3 = split_radix::NTT<MOD3, ROOT3>; // using 32bit integer, Montgomery speed up
using NTT4 = split_radix::NTT<MOD4, ROOT4>; // using 32bit integer, Montgomery speed up
using NTT5 = split_radix::NTT<MOD5, ROOT5>; // using 32bit integer
using NTT6 = split_radix::NTT<MOD6, ROOT6>; // using 32bit integer
}
}
}
void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c)
{
using namespace std;
using namespace hint;
using namespace transform::ntt;
size_t conv_len = m + n + 1, ntt_len = int_ceil2(conv_len);
using NTT = split_radix::NTT<998244353, 3>;
using ModInt = NTT::ModIntType;
ModInt *a_ntt = new ModInt[ntt_len];
ModInt *b_ntt = new ModInt[ntt_len];
std::fill(a_ntt + n + 1, a_ntt + ntt_len, ModInt{});
std::fill(b_ntt + m + 1, b_ntt + ntt_len, ModInt{});
std::copy(a, a + n + 1, a_ntt);
std::copy(b, b + m + 1, b_ntt);
NTT::convolutionRecursion(a_ntt, b_ntt, a_ntt, ntt_len);
size_t rem_len = conv_len % 16, i = 0;
for (; i < conv_len; i++)
{
c[i] = uint32_t(a_ntt[i]);
}
delete[] a_ntt;
delete[] b_ntt;
}
#include <cstring>
// Lookup table mapping every integer 0..9999 to its four zero-padded ASCII digits packed in a
// uint32_t. The leading digit is stored in the lowest byte, so on little-endian targets the
// four bytes, read in memory order, spell the number.
class ItoStrBase10000
{
private:
    uint32_t table[10000]{};

public:
    // Packs the last four decimal digits of num as ASCII: thousands in byte 0, ones in byte 3.
    static constexpr uint32_t itosbase10000(uint32_t num)
    {
        const uint32_t d_thousands = num / 1000 % 10;
        const uint32_t d_hundreds = num / 100 % 10;
        const uint32_t d_tens = num / 10 % 10;
        const uint32_t d_ones = num % 10;
        const uint32_t ascii_zeros = uint32_t('0') * 0x01010101u;
        return ascii_zeros + (d_thousands | (d_hundreds << 8) | (d_tens << 16) | (d_ones << 24));
    }
    constexpr ItoStrBase10000()
    {
        for (uint32_t value = 0; value < 10000; ++value)
        {
            table[value] = itosbase10000(value);
        }
    }
    // Pointer to the four ASCII digits of num (not NUL-terminated); num must be below 10000.
    const char *tostr(uint16_t num) const
    {
        return reinterpret_cast<const char *>(table + num);
    }
};
// Bulk stdin reader: reads up to max_len bytes once in the constructor, then parses unsigned
// decimal integers separated by whitespace/control characters.
class Qin
{
private:
    char *data = nullptr; // owned buffer, NUL-terminated after the bytes actually read
    char *p = nullptr;    // current parse position

public:
    Qin(size_t max_len)
    {
        // One extra byte for the terminating sentinel.
        p = data = new char[max_len + 1];
        const size_t read_len = fread(data, 1, max_len, stdin);
        data[read_len] = '\0';
    }
    // Owns a raw buffer: copying would lead to a double delete[].
    Qin(const Qin &) = delete;
    Qin &operator=(const Qin &) = delete;
    ~Qin()
    {
        delete[] data;
    }
    // Skips bytes <= ' ' but stops at the end-of-input sentinel.
    void skipSpace()
    {
        while (*p != '\0' && *p <= ' ')
        {
            p++;
        }
    }
    // Parses the next unsigned decimal token into n. At end of input n receives 0.
    template <typename T>
    Qin &operator>>(T &n)
    {
        uint64_t x = 0;
        skipSpace();
        while (*p > ' ')
        {
            x = x * 10 + (*p - '0');
            p++; // the original never advanced here and looped forever
        }
        n = x;
        return *this;
    }
};
// Buffered stdout writer. Output accumulates in a heap buffer of `size` bytes and is written
// with fwrite whenever more than 4/5 of the buffer is in use, and once more on destruction.
class QPrint
{
private:
    char *data = nullptr; // owned output buffer
    char *p = nullptr;    // next free byte
    char *top = nullptr;  // flush threshold (data + size * 4 / 5)
    static constexpr ItoStrBase10000 itostr{};

public:
    QPrint(size_t size)
    {
        p = data = new char[size];
        top = data + size * 4 / 5;
    }
    // Owns a raw buffer: copying would lead to a double delete[] and duplicated output.
    QPrint(const QPrint &) = delete;
    QPrint &operator=(const QPrint &) = delete;
    ~QPrint()
    {
        if (data != nullptr)
        {
            outPut();
            delete[] data;
        }
    }
    // Writes n (0..99) without leading zeros.
    void setDigit2(uint32_t n)
    {
        if (n <= 9)
        {
            *p = n + '0';
            p++;
        }
        else
        {
            p[0] = n / 10 + '0';
            p[1] = n % 10 + '0';
            p += 2;
        }
    }
    // Writes n (0..9999) without leading zeros.
    void setDigit4(uint32_t n)
    {
        const char *num_p = itostr.tostr(n);
        if (n > 99)
        {
            if (n > 999) // 4digits
            {
                memcpy(p, num_p, 4);
                p += 4;
            }
            else // 3digits
            {
                memcpy(p, num_p + 1, 3);
                p += 3;
            }
        }
        else
        {
            if (n > 9) // 2digits
            {
                memcpy(p, num_p + 2, 2);
                p += 2;
            }
            else // 1digit
            {
                *p = n + '0';
                p++;
            }
        }
    }
    // Writes n (0..9999) zero-padded to exactly 4 digits.
    void setAllDigit4(uint32_t n)
    {
        memcpy(p, itostr.tostr(n), sizeof(uint32_t));
        p += 4;
    }
    // Writes n (0..99999999) zero-padded to exactly 8 digits.
    void setAllDigit8(uint32_t n)
    {
        memcpy(p, itostr.tostr(n / 10000), sizeof(uint32_t));
        memcpy(p + 4, itostr.tostr(n % 10000), sizeof(uint32_t));
        p += 8;
    }
    QPrint &operator<<(uint32_t n)
    {
        if (n >= 100000000)
        {
            setDigit2(n / 100000000);
            setAllDigit8(n % 100000000);
        }
        else if (n >= 10000)
        {
            setDigit4(n / 10000);
            setAllDigit4(n % 10000);
        }
        else
        {
            setDigit4(n);
        }
        if (p > top) [[unlikely]]
        {
            outPut();
        }
        return *this;
    }
    QPrint &operator<<(char c)
    {
        *p = c;
        p++;
        // Characters consume buffer space too; without this check a long run of
        // character writes could overrun the buffer.
        if (p > top) [[unlikely]]
        {
            outPut();
        }
        return *this;
    }
    // Writes everything buffered so far to stdout and resets the buffer.
    void outPut()
    {
        fwrite(data, 1, p - data, stdout);
        p = data;
    }
};
constexpr ItoStrBase10000 QPrint::itostr;
// Reads "n m" (polynomial degrees), then the n + 1 coefficients of the first polynomial and
// the m + 1 coefficients of the second, and prints the n + m + 1 product coefficients,
// each followed by a space.
int main()
{
    // Qin qin(1000000);
    QPrint qout(10000);
    size_t deg1 = 0, deg2 = 0;
    std::cin >> deg1 >> deg2;
    const size_t len1 = deg1 + 1, len2 = deg2 + 1;
    std::vector<uint32_t> a(len1), b(len2), c(len1 + len2 - 1);
    for (size_t i = 0; i < len1; i++)
    {
        scanf("%u", &a[i]);
    }
    for (size_t i = 0; i < len2; i++)
    {
        scanf("%u", &b[i]);
    }
    // poly_multiply takes DEGREES: it reads deg + 1 coefficients per input and writes
    // deg1 + deg2 + 1 results. Passing the lengths would read one element past a and b
    // and write two elements past the end of c.
    poly_multiply(a.data(), int(deg1), b.data(), int(deg2), c.data());
    for (auto coef : c)
    {
        qout << coef << ' ';
    }
}
/*
Online-judge results recorded for this submission:
Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
Subtask #1 Testcase #1 | 2.407 ms | 2 MB + 40 KB | Accepted | Score: 100 | Show more |
Subtask #1 Testcase #2 | 20.227 ms | 6 MB + 1012 KB | Runtime Error | Score: -100 | Show more |
Subtask #1 Testcase #3 | 10.317 ms | 4 MB + 76 KB | Runtime Error | Score: 0 | Show more |
Subtask #1 Testcase #4 | 10.309 ms | 4 MB + 68 KB | Runtime Error | Score: 0 | Show more |
Subtask #1 Testcase #5 | 2.402 ms | 2 MB + 36 KB | Runtime Error | Score: 0 | Show more |
Subtask #1 Testcase #6 | 2.402 ms | 2 MB + 40 KB | Accepted | Score: 0 | Show more |
Subtask #1 Testcase #7 | 2.403 ms | 2 MB + 40 KB | Accepted | Score: 0 | Show more |
Subtask #1 Testcase #8 | 18.999 ms | 6 MB + 472 KB | Runtime Error | Score: 0 | Show more |
Subtask #1 Testcase #9 | 18.999 ms | 6 MB + 472 KB | Runtime Error | Score: 0 | Show more |
Subtask #1 Testcase #10 | 17.981 ms | 5 MB + 960 KB | Runtime Error | Score: 0 | Show more |
Subtask #1 Testcase #11 | 20.194 ms | 7 MB + 68 KB | Runtime Error | Score: 0 | Show more |
Subtask #1 Testcase #12 | 18.096 ms | 5 MB + 964 KB | Runtime Error | Score: 0 | Show more |
Subtask #1 Testcase #13 | 2.403 ms | 2 MB + 40 KB | Accepted | Score: 0 | Show more |
*/