#include <vector>
#include <complex>
#include <iostream>
#include <future>
#include <ctime>
#include <climits>
#include <string>
#include <array>
#include <type_traits>
// Windows x64: enable the _umul128 intrinsic for fast 64x64->128-bit multiplication
#if defined(_WIN64)
#include <intrin.h>
#define UMUL128
#endif //_WIN64
// x86-64 Unix with a GCC-compatible compiler: enable __uint128_t for fast 64x64->128-bit multiplication
#if defined(__unix__) && defined(__x86_64__) && defined(__GNUC__)
#define UINT128T
#endif //__unix__
namespace hint
{
// Floating-point and complex number aliases.
using Float32 = float;
using Float64 = double;
using Complex32 = std::complex<Float32>;
using Complex64 = std::complex<Float64>;
// pi and 2*pi in double precision (the literal carries extra digits and rounds to the nearest double).
constexpr Float64 HINT_PI = 3.141592653589793238462643;
constexpr Float64 HINT_2PI = HINT_PI * 2;
// Formats `input` as exactly `digits` decimal characters.
// Shorter numbers are left-padded with '0'; longer numbers keep only their lowest `digits` digits.
std::string ui64to_string(uint64_t input, uint8_t digits)
{
    std::string text(digits, '0');
    // Fill from the least significant position (the end of the string) backwards.
    for (auto pos = text.rbegin(); pos != text.rend(); ++pos)
    {
        *pos = static_cast<char>('0' + input % 10);
        input /= 10;
    }
    return text;
}
// Binary exponentiation: returns m^n by repeated squaring.
// T must be constructible from 1 and support `*=`; T1 must be an integer type.
template <typename T, typename T1>
constexpr T qpow(T m, T1 n)
{
    T acc = 1;
    for (; n > 0; n >>= 1)
    {
        // Multiply in the current power of m whenever the corresponding exponent bit is set.
        if ((n & 1) != 0)
        {
            acc *= m;
        }
        m *= m;
    }
    return acc;
}
// Modular binary exponentiation: returns m^n mod `mod`.
// T must support `*=`, `%=` and `%`; `mod` must be non-zero.
// The product of two values below `mod` must fit in T.
template <typename T, typename T1>
constexpr T qpow(T m, T1 n, T mod)
{
    // 1 % mod makes the result 0 when mod == 1 (the old version returned 1).
    T result = T(1) % mod;
    // Reduce the base first so the first m * m cannot overflow for large inputs.
    m %= mod;
    while (n > 0)
    {
        if ((n & 1) != 0)
        {
            result *= m;
            result %= mod;
        }
        m *= m;
        m %= mod;
        n >>= 1;
    }
    return result;
}
// Largest power of two not exceeding n (assumes unsigned T; returns 1 for n == 0).
// Smears the highest set bit into every lower position, then keeps only that top bit.
template <typename T>
constexpr T int_floor2(T n)
{
    constexpr int width = sizeof(n) * CHAR_BIT;
    int step = 1;
    while (step < width)
    {
        n |= (n >> step);
        step *= 2;
    }
    return (n >> 1) + 1;
}
// Smallest power of two not below n (assumes unsigned T).
// n == 0 and values above the largest representable power of two wrap to 0.
template <typename T>
constexpr T int_ceil2(T n)
{
    constexpr int width = sizeof(n) * CHAR_BIT;
    // Decrement first so that exact powers of two map to themselves.
    --n;
    for (int step = 1; step < width; step *= 2)
    {
        n |= (n >> step);
    }
    return n + 1;
}
// Value whose lowest `bits` bits are all one, i.e. 2^bits - 1.
// Precondition: bits <= number of bits in T. bits <= 0 yields 0.
template <typename T>
constexpr T all_one(int bits)
{
    // The old version evaluated T(1) << -1 (undefined behaviour) for bits == 0.
    if (bits <= 0)
    {
        return T(0);
    }
    // Build 2^(bits-1) and double it minus one, so bits == width never shifts by the full width.
    T temp = T(1) << (bits - 1);
    return temp - 1 + temp;
}
// Extended Euclid: returns g = gcd(a, b) and sets x, y so that a*x + b*y == g.
template <typename IntTy>
constexpr IntTy exgcd(IntTy a, IntTy b, IntTy &x, IntTy &y)
{
    if (b != 0)
    {
        const IntTy quotient = a / b;
        // Solve for (b, a - q*b) with the output slots swapped.
        // Then b*y + (a - q*b)*x == g, and a*x + b*(y - q*x) == g after the fix-up.
        const IntTy g = exgcd(b, a - quotient * b, y, x);
        y -= quotient * x;
        return g;
    }
    x = 1;
    y = 0;
    return a;
}
// Inverse of n modulo `mod` via extended Euclid, normalized into [0, mod).
// Assumes gcd(n, mod) == 1; otherwise the returned value is not an inverse.
template <typename IntTy>
constexpr IntTy mod_inv(IntTy n, IntTy mod)
{
    IntTy inv = 0, unused = 0;
    exgcd(IntTy(n % mod), mod, inv, unused);
    // Bezout coefficients can be negative; shift once into range.
    if (inv < 0)
    {
        return inv + mod;
    }
    if (inv >= mod)
    {
        return inv - mod;
    }
    return inv;
}
// Returns n^-1 mod 2^pow using Newton iteration, x <- x * (2 - n*x).
// Each step doubles the number of correct low bits.
// Only odd n is invertible modulo a power of two. For even n, or pow <= 0, the function
// returns 0; the old version looped forever on even n.
// pow is expected in [1, 64]; values above 64 are treated as 64.
constexpr uint64_t inv_mod2pow(uint64_t n, int pow)
{
    if (pow <= 0 || (n & 1) == 0)
    {
        return 0;
    }
    const uint64_t mask = pow >= 64 ? ~uint64_t(0) : (uint64_t(1) << pow) - 1;
    // x = 1 is correct modulo 2 for odd n; t tracks n*x mod 2^pow.
    uint64_t xn = 1, t = n & mask;
    while (t != 1)
    {
        xn = (xn * (2 - t));
        t = (xn * n) & mask;
    }
    return xn & mask;
}
// Integer floor(log2(n)) for unsigned n, found by binary search over bit positions.
// Returns 0 for n == 0.
template <typename UintTy>
constexpr int hint_log2(UintTy n)
{
    constexpr int bits = 8 * sizeof(UintTy);
    // Mask covering the upper half of the word (2^(bits/2) - 1, shifted up by bits/2).
    constexpr UintTy half_top_bit = UintTy(1) << (bits / 2 - 1);
    constexpr UintTy half_ones = half_top_bit - 1 + half_top_bit;
    constexpr UintTy upper_mask = half_ones << (bits / 2);
    UintTy window = upper_mask;
    int result = 0;
    for (int step = bits / 2; step > 0;)
    {
        // If any bit at or above `step` is set, the answer is at least `step`; drop those low bits.
        if ((n & window) != 0)
        {
            result += step;
            n >>= step;
        }
        step /= 2;
        window >>= step;
    }
    return result;
}
// Count of leading zero bits of x (unsigned).
// Returns the full bit width for x == 0; the old version returned width - 1.
template <typename IntTy>
constexpr int hint_clz(IntTy x)
{
    constexpr int width = sizeof(IntTy) * CHAR_BIT;
    if (x == 0)
    {
        return width;
    }
    return width - 1 - hint_log2(x);
}
// Count of trailing zero bits of x (unsigned).
// Returns the full bit width for x == 0; the old version returned width - 1.
template <typename IntTy>
constexpr int hint_ctz(IntTy x)
{
    constexpr int width = sizeof(IntTy) * CHAR_BIT;
    if (x == 0)
    {
        return width;
    }
    // x ^ (x - 1) sets exactly the bits up to and including the lowest set bit of x.
    return hint_log2(x ^ (x - 1));
}
namespace hint_transform
{
// In-place radix-2 butterfly: (sum, diff) becomes (sum + diff, sum - diff).
template <typename T>
inline void transform2(T &sum, T &diff)
{
    T left(sum);
    T right(diff);
    sum = left + right;
    diff = left - right;
}
// Bit-reversal permutation of [begin, end) in place.
// Element i is swapped with element rev(i), where rev reverses the low log2(len) bits.
// The length is expected to be a power of two. Lengths 0 and 1 are no-ops; the old code
// formed begin + 1 and end - 1 outside the range when len == 0.
template <typename It>
void binary_reverse_swap(It begin, It end)
{
    const size_t len = end - begin;
    if (len < 2)
    {
        return;
    }
    // Swap only when left index < right index, so each pair is exchanged exactly once.
    auto smaller_swap = [=](It it_left, It it_right)
    {
        if (it_left < it_right)
        {
            std::swap(it_left[0], it_right[0]);
        }
    };
    // Given the iterator to rev(i), return the iterator to rev(i + 1) (reverse-order increment).
    auto get_next_bitrev = [=](It last)
    {
        size_t k = len / 2, indx = last - begin;
        indx ^= k;
        while (k > indx)
        {
            k >>= 1;
            indx ^= k;
        };
        return begin + indx;
    };
    // Short lengths: walk i = 1 .. len-2 while tracking rev(i) incrementally.
    if (len <= 16)
    {
        for (auto i = begin + 1, j = begin + len / 2; i < end - 1; i++)
        {
            smaller_swap(i, j);
            j = get_next_bitrev(j);
        }
        return;
    }
    // Long lengths: for i < len/8, rev(i) is a multiple of 8. The eight indices
    //   i, i+len/2, i+len/4, i+3len/4, and the same four plus len/8
    // reverse to rev(i)+0 .. rev(i)+7, so one rev(i) update serves eight swaps.
    const size_t len_8 = len / 8;
    const auto last = begin + len_8;
    auto i0 = begin + 1, i1 = i0 + len / 2, i2 = i0 + len / 4, i3 = i1 + len / 4;
    for (auto j = begin + len / 2; i0 < last; i0++, i1++, i2++, i3++)
    {
        smaller_swap(i0, j);
        smaller_swap(i1, j + 1);
        smaller_swap(i2, j + 2);
        smaller_swap(i3, j + 3);
        smaller_swap(i0 + len_8, j + 4);
        smaller_swap(i1 + len_8, j + 5);
        smaller_swap(i2 + len_8, j + 6);
        smaller_swap(i3 + len_8, j + 7);
        j = get_next_bitrev(j);
    }
}
// Bit-reversal permutation for a random-access sequence given as (start, length).
// Forwards to the iterator-pair overload.
template <typename T>
void binary_reverse_swap(T ary, const size_t len)
{
    const T past_end = ary + len;
    binary_reverse_swap(ary, past_end);
}
// Multi-mode, self-checking fast number-theoretic transform (NTT)
namespace hint_ntt
{
// NTT moduli with their chosen roots.
// MOD1 = 27 * 2^56 + 1 (61 bits) and MOD2 = 29 * 2^57 + 1 (62 bits), so MOD - 1 has a large
// power-of-two factor, as power-of-two-length transforms need.
// NOTE(review): primality of the moduli and primitivity of ROOT1/ROOT2 are assumed — confirm.
constexpr uint64_t MOD1 = 1945555039024054273, ROOT1 = 5;
constexpr uint64_t MOD2 = 4179340454199820289, ROOT2 = 3;
// Portable 64-bit x 64-bit -> 128-bit multiplication from four 32x32 -> 64-bit partial products.
// Returns {low 64 bits, high 64 bits}.
constexpr std::pair<uint64_t, uint64_t> mul64x64to128_base(uint64_t a, uint64_t b)
{
    const uint64_t a_lo = uint32_t(a), a_hi = a >> 32;
    const uint64_t b_lo = uint32_t(b), b_hi = b >> 32;
    const uint64_t ll = a_lo * b_lo;
    const uint64_t lh = a_lo * b_hi;
    const uint64_t hl = a_hi * b_lo;
    const uint64_t hh = a_hi * b_hi;
    // Middle 32-bit column: low halves of both cross terms plus the carry out of ll.
    // It is below 3 * 2^32, so it cannot overflow.
    const uint64_t mid = uint32_t(lh) + uint64_t(uint32_t(hl)) + (ll >> 32);
    const uint64_t high = hh + (lh >> 32) + (hl >> 32) + (mid >> 32);
    const uint64_t low = (mid << 32) | uint32_t(ll);
    return std::make_pair(low, high);
}
// 64-bit x 64-bit -> 128-bit multiplication returning {low, high}.
// Uses the fastest primitive enabled by the platform macros at the top of the file.
inline std::pair<uint64_t, uint64_t> mul64x64to128(uint64_t a, uint64_t b)
{
#if defined(UMUL128)
    // MSVC x64 intrinsic: returns the low half and writes the high half through the pointer.
    unsigned long long high = 0;
    const unsigned long long low = _umul128(a, b, &high);
    return std::make_pair(low, high);
#elif defined(UINT128T)
    // GCC/Clang 128-bit integer extension.
    const __uint128_t product = static_cast<__uint128_t>(a) * b;
    return std::make_pair(uint64_t(product), uint64_t(product >> 64));
#else
    // Portable fallback built from 32-bit partial products.
    return mul64x64to128_base(a, b);
#endif
}
// Divides the 128-bit value (dividend_hi64:dividend_lo64) by a non-zero 32-bit divisor.
// Returns the full 128-bit quotient as {low 64 bits, high 64 bits}.
// Writes the remainder (< divisor) back into dividend_lo64.
// (The parameter had been garbled to `÷nd_lo64`, an HTML-entity mangling of `&dividend_lo64`.)
constexpr std::pair<uint64_t, uint64_t> div128by32(uint64_t dividend_hi64, uint64_t &dividend_lo64, uint32_t divisor)
{
    // Schoolbook division in base 2^32, most significant digit first.
    // The running remainder r stays below divisor < 2^32, so (r << 32) | digit always fits in 64 bits.
    uint64_t quot_hi64 = 0, quot_lo64 = 0;
    uint64_t r = dividend_hi64 >> 32;
    quot_hi64 = (r / divisor) << 32;
    r %= divisor;
    r = (r << 32) | uint32_t(dividend_hi64);
    quot_hi64 |= r / divisor;
    r %= divisor;
    r = (r << 32) | (dividend_lo64 >> 32);
    quot_lo64 = (r / divisor) << 32;
    r %= divisor;
    r = (r << 32) | uint32_t(dividend_lo64);
    quot_lo64 |= r / divisor;
    r %= divisor;
    dividend_lo64 = r;
    return std::make_pair(quot_lo64, quot_hi64);
}
// Divides a 96-bit dividend by a 64-bit divisor, returning a 32-bit quotient.
// The dividend is (low 32 bits of the first argument) : dividend_lo64.
// Writes the remainder back into dividend_lo64.
// Preconditions (Knuth, TAOCP vol. 2, Algorithm D):
//   - the divisor is normalized (its top bit is set);
//   - the quotient fits in 32 bits.
// (The parameter had been garbled to `÷nd_lo64`; it is the reference `&dividend_lo64`.)
constexpr uint32_t div96by64to32(uint32_t dividend_hi64, uint64_t &dividend_lo64, uint64_t divisor)
{
    uint64_t divid2 = (uint64_t(dividend_hi64) << 32) | (dividend_lo64 >> 32);
    uint64_t divis1 = divisor >> 32;
    divisor = uint32_t(divisor);
    // Estimate the quotient from the top 64 dividend bits and the top 32 divisor bits.
    // The correction below assumes the estimate is at most 2 too large.
    uint64_t qhat = divid2 / divis1;
    divid2 %= divis1;
    divid2 = (divid2 << 32) | uint32_t(dividend_lo64);
    uint64_t prod = qhat * divisor;
    divis1 <<= 32;
    if (prod > divid2)
    {
        qhat--;
        prod -= divisor;
        divid2 += divis1;
        // Detect whether the addition above wrapped.
        // - If it wrapped, the true value exceeds 2^64 > prod, so no second correction is needed.
        // - If it did not wrap, divid2 >= divis1. Equality happens when divid2 was 0 before the
        //   addition, so the test must be `>=`; with `>` the second correction was skipped and the
        //   quotient came out 2^32 for e.g. (2^95 + 2^63) / (2^63 + 2^32 - 1).
        if ((divid2 >= divis1) && (prod > divid2))
        {
            qhat--;
            prod -= divisor;
            divid2 += divis1;
        }
    }
    divid2 -= prod;
    dividend_lo64 = divid2;
    return qhat;
}
// Divides a 128-bit dividend (dividend_hi64:dividend_lo64) by a 64-bit divisor.
// Returns the 64-bit quotient and writes the remainder back into dividend_lo64.
// Preconditions: the quotient fits in 64 bits (dividend_hi64 < divisor), and a divisor of
// 2^32 or more is normalized (top bit set).
// (The parameter had been garbled to `÷nd_lo64`; it is the reference `&dividend_lo64`.)
constexpr uint64_t div128by64to64(uint64_t dividend_hi64, uint64_t &dividend_lo64, uint64_t divisor)
{
    if ((divisor >> 32) == 0)
    {
        return div128by32(dividend_hi64, dividend_lo64, uint32_t(divisor)).first;
    }
    uint32_t q1 = 0, q0 = 0;
    // Step 1: divide the top 96 bits to get the high 32 quotient bits.
    uint32_t divid_hi32 = dividend_hi64 >> 32;
    uint64_t divid_lo64 = (dividend_hi64 << 32) | (dividend_lo64 >> 32);
    // The top 96 bits can reach the divisor even when their highest 32-bit word is zero.
    // The old shortcut only tested divid_hi32 and then fed the next step a remainder >= divisor.
    if (divid_hi32 != 0 || divid_lo64 >= divisor)
    {
        q1 = div96by64to32(divid_hi32, divid_lo64, divisor);
    }
    // Step 2: bring down the lowest 32 bits and divide again for the low 32 quotient bits.
    divid_hi32 = divid_lo64 >> 32;
    dividend_lo64 = uint32_t(dividend_lo64) | (divid_lo64 << 32);
    q0 = div96by64to32(divid_hi32, dividend_lo64, divisor);
    return (uint64_t(q1) << 32) | q0;
}
// Minimal unsigned 128-bit integer stored as two 64-bit words.
// Multiplication uses only the low words (64x64 -> 128).
// Division and modulo use only the low word of the right operand (128 / 64).
class Uint128
{
private:
    uint64_t lo, hi;

    // Divides *this by a non-zero 64-bit divisor.
    // Returns the full 128-bit quotient and writes the remainder (< divisor) to rem.
    constexpr Uint128 divRem64(uint64_t divisor, uint64_t &rem) const
    {
        if ((divisor >> 32) == 0)
        {
            // Short division by a 32-bit divisor already yields the whole 128-bit quotient.
            uint64_t low_rem = lo;
            const Uint128 quot = div128by32(hi, low_rem, uint32_t(divisor));
            rem = low_rem;
            return quot;
        }
        // div128by64to64 needs a quotient that fits in 64 bits (high word < divisor).
        // Reduce the high word first when it does not; previously such inputs gave wrong results.
        uint64_t quot_hi = 0, rem_hi = hi;
        if (hi >= divisor)
        {
            quot_hi = hi / divisor;
            rem_hi = hi % divisor;
        }
        // Normalize so the divisor's top bit is set.
        // rem_hi < divisor guarantees the left shift loses no dividend bits.
        const int k = hint_clz(divisor);
        const Uint128 divid = Uint128(lo, rem_hi) << k;
        uint64_t divid_lo = divid.lo;
        const uint64_t quot_lo = div128by64to64(divid.hi, divid_lo, divisor << k);
        rem = divid_lo >> k;
        return Uint128(quot_lo, quot_hi);
    }

public:
    constexpr Uint128() : Uint128(0, 0) {}
    constexpr Uint128(uint64_t l) : Uint128(l, 0) {}
    constexpr Uint128(uint64_t l, uint64_t h) : lo(l), hi(h) {}
    // Accepts the {low, high} pairs produced by the mul/div helpers.
    constexpr Uint128(std::pair<uint64_t, uint64_t> p) : lo(p.first), hi(p.second) {}
    constexpr Uint128 operator+(Uint128 rhs) const
    {
        rhs.lo += lo;
        rhs.hi += hi + (rhs.lo < lo);
        return rhs;
    }
    constexpr Uint128 operator-(Uint128 rhs) const
    {
        rhs.lo = lo - rhs.lo;
        rhs.hi = hi - rhs.hi - (rhs.lo > lo);
        return rhs;
    }
    constexpr Uint128 operator+(uint64_t rhs) const
    {
        rhs = lo + rhs;
        return Uint128(rhs, hi + (rhs < lo));
    }
    constexpr Uint128 operator-(uint64_t rhs) const
    {
        rhs = lo - rhs;
        return Uint128(rhs, hi - (rhs > lo));
    }
    // Only computes lo * rhs.lo (full 128-bit product of the low words).
    Uint128 operator*(const Uint128 &rhs) const
    {
        return mul64x64to128(lo, rhs.lo);
    }
    // 128-bit / 64-bit: rhs.hi is ignored. The quotient may use all 128 bits.
    constexpr Uint128 operator/(const Uint128 &rhs) const
    {
        uint64_t rem = 0;
        return divRem64(rhs.lo, rem);
    }
    // 128-bit % 64-bit: rhs.hi is ignored.
    constexpr Uint128 operator%(const Uint128 &rhs) const
    {
        uint64_t rem = 0;
        divRem64(rhs.lo, rem);
        return Uint128(rem);
    }
    constexpr Uint128 &operator+=(const Uint128 &rhs)
    {
        return *this = *this + rhs;
    }
    constexpr Uint128 &operator-=(const Uint128 &rhs)
    {
        return *this = *this - rhs;
    }
    constexpr Uint128 &operator+=(uint64_t rhs)
    {
        return *this = *this + rhs;
    }
    constexpr Uint128 &operator-=(uint64_t rhs)
    {
        return *this = *this - rhs;
    }
    // Only computes lo * rhs.lo. Uses the portable constexpr multiply so it works at compile time.
    constexpr Uint128 &operator*=(const Uint128 &rhs)
    {
        return *this = mul64x64to128_base(lo, rhs.lo);
    }
    constexpr Uint128 &operator/=(const Uint128 &rhs)
    {
        return *this = *this / rhs;
    }
    constexpr Uint128 &operator%=(const Uint128 &rhs)
    {
        return *this = *this % rhs;
    }
    constexpr bool operator>(const Uint128 &rhs) const
    {
        if (hi != rhs.hi)
        {
            return hi > rhs.hi;
        }
        return lo > rhs.lo;
    }
    constexpr bool operator<(const Uint128 &rhs) const
    {
        if (hi != rhs.hi)
        {
            return hi < rhs.hi;
        }
        return lo < rhs.lo;
    }
    constexpr bool operator>=(const Uint128 &rhs) const
    {
        return !(*this < rhs);
    }
    constexpr bool operator<=(const Uint128 &rhs) const
    {
        return !(*this > rhs);
    }
    constexpr bool operator==(const Uint128 &rhs) const
    {
        return hi == rhs.hi && lo == rhs.lo;
    }
    constexpr bool operator!=(const Uint128 &rhs) const
    {
        return !(*this == rhs);
    }
    // Shift counts are taken modulo 128; negative counts wrap.
    // The zero test comes after the reduction, so counts like 128 never reach `lo >> 64` (undefined).
    constexpr Uint128 operator<<(int shift) const
    {
        shift %= 128;
        shift = shift < 0 ? shift + 128 : shift;
        if (shift == 0)
        {
            return *this;
        }
        if (shift < 64)
        {
            return Uint128(lo << shift, (hi << shift) | (lo >> (64 - shift)));
        }
        return Uint128(0, lo << (shift - 64));
    }
    constexpr Uint128 operator>>(int shift) const
    {
        shift %= 128;
        shift = shift < 0 ? shift + 128 : shift;
        if (shift == 0)
        {
            return *this;
        }
        if (shift < 64)
        {
            return Uint128((lo >> shift) | (hi << (64 - shift)), hi >> shift);
        }
        return Uint128(hi >> (shift - 64), 0);
    }
    constexpr Uint128 &operator<<=(int shift)
    {
        return *this = *this << shift;
    }
    constexpr Uint128 &operator>>=(int shift)
    {
        return *this = *this >> shift;
    }
    constexpr uint64_t high64() const
    {
        return hi;
    }
    constexpr uint64_t low64() const
    {
        return lo;
    }
    // Truncating conversion to the low 64 bits.
    constexpr operator uint64_t() const
    {
        return low64();
    }
    // Prints the value in decimal followed by a newline.
    // The stream's format flags are restored afterwards.
    void printDec() const
    {
        const std::ios_base::fmtflags old_flags = std::cout.flags();
        if (hi == 0)
        {
            std::cout << std::dec << lo << "\n";
        }
        else
        {
            // Split into base-10^16 limbs; 2^128 < 10^39, so there are at most three.
            const Uint128 base(10000000000000000ULL);
            Uint128 rest(*this);
            const uint64_t low = uint64_t(rest % base);
            rest /= base;
            const uint64_t mid = uint64_t(rest % base);
            rest /= base;
            const uint64_t top = uint64_t(rest);
            // Only the most significant non-zero limb is printed without zero padding.
            // hi != 0 means the value is at least 2^64, so top or mid is non-zero.
            std::string digits = top != 0 ? std::to_string(top) + ui64to_string(mid, 16) : std::to_string(mid);
            digits += ui64to_string(low, 16);
            std::cout << digits << "\n";
        }
        std::cout.flags(old_flags);
    }
    // Prints "0x<high> 0x<low>" in hex without leaving std::cout in hex mode.
    void printHex() const
    {
        const std::ios_base::fmtflags old_flags = std::cout.flags();
        std::cout << std::hex << "0x" << hi << " 0x" << lo << "\n";
        std::cout.flags(old_flags);
    }
};
// Montgomery arithmetic for a runtime 64-bit modulus (intended for mod > 2^32), with R = 2^64.
// Int128Type is the 128-bit helper type (Uint128 by default).
template <typename Int128Type = Uint128>
class Montgomery64
{
public:
    // Modulus. Must be odd, and mod2 = 2*mod must not overflow (so mod < 2^63).
    // The lazy operations keep values below 2*mod, and MontInt64Lazy restricts mod to < 2^62.
    uint64_t mod = 0;
    uint64_t mod_inv = 0;     // mod^-1 mod 2^64
    uint64_t mod_inv_neg = 0; // -mod^-1 mod 2^64
    uint64_t mod2 = 0;        // mod * 2
    uint64_t r2 = 0;          // R^2 mod mod, used by toMont

public:
    constexpr Montgomery64(uint64_t mod_in) : mod(mod_in), mod2(mod_in * 2)
    {
        mod_inv = inv_mod2pow(mod, 64);      // (mod_inv * mod) % 2^64 == 1
        mod_inv_neg = uint64_t(0 - mod_inv); // (mod_inv_neg + mod_inv) % 2^64 == 0
        Int128Type R = (Int128Type(1) << 64) % Int128Type(mod);
        R *= R;
        r2 = uint64_t(R % Int128Type(mod));
    }
    // REDC without the final subtraction: input * R^-1 mod `mod`.
    // For input < mod * 2^64 the result is below 2*mod. Runtime-only (uses operator*).
    uint64_t redcLazy(const Int128Type &input) const
    {
        Int128Type n = uint64_t(input) * mod_inv_neg;
        n = n * Int128Type(mod);
        n += input;
        return n >> 64;
    }
    // Fully reduced REDC (result below mod for input < mod * 2^64), runtime version.
    uint64_t redcRT(const Int128Type &input) const
    {
        uint64_t m = redcLazy(input);
        return m < mod ? m : m - mod;
    }
    // Fully reduced REDC usable in constant expressions (uses operator*=).
    constexpr uint64_t redcCT(const Int128Type &input) const
    {
        Int128Type n = uint64_t(input) * mod_inv_neg;
        n *= Int128Type(mod);
        n += input;
        uint64_t m = n >> 64;
        return m < mod ? m : m - mod;
    }
    uint64_t mulMontRunTime(uint64_t a, uint64_t b) const
    {
        return redcRT(Int128Type(a) * Int128Type(b));
    }
    uint64_t mulMontRunTimeLazy(uint64_t a, uint64_t b) const
    {
        return redcLazy(Int128Type(a) * Int128Type(b));
    }
    constexpr uint64_t mulMontCompileTime(uint64_t a, uint64_t b) const
    {
        Int128Type prod(a);
        prod *= Int128Type(b);
        return redcCT(prod);
    }
    constexpr uint64_t addMont(uint64_t a, uint64_t b) const
    {
        b = a + b;
        return b < mod ? b : b - mod;
    }
    constexpr uint64_t addMontLazy(uint64_t a, uint64_t b) const
    {
        b = a + b;
        return b < mod2 ? b : b - mod2;
    }
    // a - b; unsigned wrap-around (difference > a) signals a borrow, which adds mod back.
    constexpr uint64_t subMont(uint64_t a, uint64_t b) const
    {
        b = a - b;
        return b > a ? b + mod : b;
    }
    constexpr uint64_t subMontLazy(uint64_t a, uint64_t b) const
    {
        b = a - b;
        return b > a ? b + mod2 : b;
    }
    constexpr uint64_t largeNorm2(uint64_t n) const
    {
        return n >= mod2 ? n - mod2 : n;
    }
    constexpr uint64_t toMont(uint64_t a) const
    {
        return mulMontCompileTime(a, r2);
    }
    constexpr uint64_t toInt(uint64_t a) const
    {
        return redcCT(Int128Type(a));
    }
    // Checks that mod is odd and that mod_inv / mod_inv_neg are consistent modulo 2^64.
    // (The old expression cast the boolean comparison to uint64_t instead of the product.)
    constexpr bool selfCheck() const
    {
        return (mod & 1) == 1 && uint64_t(mod_inv * mod) == 1 && uint64_t(mod_inv_neg + mod_inv) == 0;
    }
};
// Modular integer for a compile-time 64-bit modulus (2^32 < MOD < 2^62), stored in Montgomery form.
// Reduction is lazy: +, - and * keep values below 2*MOD; rawAdd/rawSub may go up to 4*MOD.
// largeNorm brings a value back below 2*MOD.
template <uint64_t MOD, typename Int128Type = Uint128>
class MontInt64Lazy
{
private:
    static_assert(MOD > UINT32_MAX, "Montgomery64 modulus must be greater than 2^32");
    static_assert(MOD < (uint64_t(1) << 62), "MOD can't be larger than 62 bits");
    // Montgomery reduction needs MOD^-1 mod 2^64, which exists only for odd MOD.
    // Failing fast here avoids an obscure error from the modulus-inverse computation.
    static_assert((MOD & 1) == 1, "Montgomery64 modulus must be odd");
    uint64_t data; // Montgomery representation, lazily reduced

public:
    using MontgomeryType = Montgomery64<Int128Type>;
    static constexpr MontgomeryType montgomery = MontgomeryType(MOD);
    static_assert(montgomery.selfCheck(), "Montgomery64 modulus is not correct");
    constexpr MontInt64Lazy() : data(0) {}
    constexpr MontInt64Lazy(uint64_t n) : data(montgomery.toMont(n)) {}
    constexpr MontInt64Lazy operator+(MontInt64Lazy rhs) const
    {
        rhs.data = montgomery.addMontLazy(data, rhs.data);
        return rhs;
    }
    constexpr MontInt64Lazy operator-(MontInt64Lazy rhs) const
    {
        rhs.data = montgomery.subMontLazy(data, rhs.data);
        return rhs;
    }
    // Runtime multiplication (lazy REDC; result below 2*MOD).
    MontInt64Lazy operator*(MontInt64Lazy rhs) const
    {
        rhs.data = montgomery.mulMontRunTimeLazy(data, rhs.data);
        return rhs;
    }
    constexpr MontInt64Lazy &operator+=(const MontInt64Lazy &rhs)
    {
        data = montgomery.addMontLazy(data, rhs.data);
        return *this;
    }
    constexpr MontInt64Lazy &operator-=(const MontInt64Lazy &rhs)
    {
        data = montgomery.subMontLazy(data, rhs.data);
        return *this;
    }
    // constexpr-capable multiplication (fully reduced REDC).
    constexpr MontInt64Lazy &operator*=(const MontInt64Lazy &rhs)
    {
        data = montgomery.mulMontCompileTime(data, rhs.data);
        return *this;
    }
    constexpr MontInt64Lazy largeNorm() const
    {
        MontInt64Lazy res;
        res.data = montgomery.largeNorm2(data);
        return res;
    }
    // Unreduced sum.
    constexpr MontInt64Lazy rawAdd(MontInt64Lazy n) const
    {
        n.data = data + n.data;
        return n;
    }
    // Unreduced difference, offset by 2*MOD to stay non-negative.
    constexpr MontInt64Lazy rawSub(MontInt64Lazy n) const
    {
        n.data = data - n.data + mod2();
        return n;
    }
    constexpr uint64_t montToInt() const
    {
        return montgomery.toInt(data);
    }
    constexpr operator uint64_t() const
    {
        return montToInt();
    }
    static constexpr uint64_t mod()
    {
        return MOD;
    }
    static constexpr uint64_t mod2()
    {
        return MOD * 2;
    }
};
// Out-of-class definition of the static constexpr member (needed before C++17 when odr-used).
template <uint64_t MOD, typename Int128Type>
constexpr typename MontInt64Lazy<MOD, Int128Type>::MontgomeryType MontInt64Lazy<MOD, Int128Type>::montgomery;
// Modular integer for a compile-time odd modulus MOD < 2^30, stored in Montgomery form with R = 2^32.
// Reduction is lazy: +, - and * keep values below 2*MOD; rawAdd/rawSub may go up to 4*MOD.
template <uint32_t MOD>
class MontInt32Lazy
{
private:
    static_assert(MOD < (uint32_t(1) << 30), "MOD can't be larger than 30 bits");
    static_assert((MOD & 1) == 1, "Montgomery modulus must be odd");
    uint32_t data;

    // Inverse of an odd n modulo 2^32 by Newton iteration.
    // For odd n, n*n == 1 (mod 8), so the seed is correct to 3 bits.
    // Each step doubles the correct bits: 3 -> 6 -> 12 -> 24 -> 48 >= 32.
    static constexpr uint32_t inv32(uint32_t n)
    {
        uint32_t x = n;
        for (int i = 0; i < 4; i++)
        {
            x *= 2 - n * x;
        }
        return x;
    }

public:
    constexpr MontInt32Lazy() : data(0) {}
    constexpr MontInt32Lazy(uint32_t n) : data(toMont(n)) {}
    static constexpr uint32_t mod()
    {
        return MOD;
    }
    static constexpr uint32_t mod2()
    {
        return MOD * 2;
    }
    // MOD^-1 mod 2^32, computed at compile time.
    // The check lives here, not in the class body: a class-body static_assert calls member
    // functions before the class is complete.
    static constexpr uint32_t modInv()
    {
        constexpr uint32_t mod_inv = inv32(MOD);
        static_assert(uint32_t(MOD * mod_inv) == 1, "mod_inv not correct");
        return mod_inv;
    }
    static constexpr uint32_t modNegInv()
    {
        constexpr uint32_t mod_neg_inv = uint32_t(-modInv());
        return mod_neg_inv;
    }
    static constexpr uint32_t toMont(uint32_t n)
    {
        return (uint64_t(n) << 32) % MOD;
    }
    // REDC without the final subtraction: n * 2^-32 mod MOD, below 2*MOD for n < MOD * 2^32.
    static constexpr uint32_t redcLazy(uint64_t n)
    {
        uint64_t prod = uint32_t(n) * modNegInv();
        return (prod * mod() + n) >> 32;
    }
    static constexpr uint32_t redc(uint64_t n)
    {
        uint32_t res = redcLazy(n);
        return res < mod() ? res : res - mod();
    }
    constexpr MontInt32Lazy operator+(MontInt32Lazy rhs) const
    {
        rhs.data = data + rhs.data;
        rhs.data = rhs.data < mod2() ? rhs.data : rhs.data - mod2();
        return rhs;
    }
    constexpr MontInt32Lazy operator-(MontInt32Lazy rhs) const
    {
        rhs.data = data - rhs.data;
        rhs.data = rhs.data > data ? rhs.data + mod2() : rhs.data;
        return rhs;
    }
    constexpr MontInt32Lazy operator*(MontInt32Lazy rhs) const
    {
        rhs.data = redcLazy(uint64_t(data) * rhs.data);
        return rhs;
    }
    // The old compound operators assigned to rhs.data through a const reference (ill-formed).
    constexpr MontInt32Lazy &operator+=(const MontInt32Lazy &rhs)
    {
        const uint32_t sum = data + rhs.data;
        data = sum < mod2() ? sum : sum - mod2();
        return *this;
    }
    constexpr MontInt32Lazy &operator-=(const MontInt32Lazy &rhs)
    {
        const uint32_t diff = data - rhs.data;
        data = diff > data ? diff + mod2() : diff;
        return *this;
    }
    constexpr MontInt32Lazy &operator*=(const MontInt32Lazy &rhs)
    {
        data = redcLazy(uint64_t(data) * rhs.data);
        return *this;
    }
    constexpr MontInt32Lazy largeNorm() const
    {
        MontInt32Lazy res;
        res.data = data >= mod2() ? data - mod2() : data;
        return res;
    }
    constexpr MontInt32Lazy rawAdd(MontInt32Lazy n) const
    {
        n.data = data + n.data;
        return n;
    }
    constexpr MontInt32Lazy rawSub(MontInt32Lazy n) const
    {
        n.data = data - n.data + mod2();
        return n;
    }
    constexpr uint32_t montToInt() const
    {
        return redc(data);
    }
    constexpr operator uint32_t() const
    {
        return montToInt();
    }
};
// Returns true when (n mod `mod`) * (n_inv mod `mod`) == 1 (mod `mod`), computed in IntType.
template <typename IntType>
constexpr bool check_inv(uint64_t n, uint64_t n_inv, uint64_t mod)
{
    IntType product(n % mod);
    product *= IntType(n_inv % mod);
    product %= IntType(mod);
    return product == IntType(1);
}
// Fast two-modulus Chinese remainder theorem.
// Returns x with x == num1 (mod MOD1) and x == num2 (mod MOD2), built as
// base + modulus * t with the smaller residue as base.
// NOTE(review): assumes coprime moduli and num1 < MOD1, num2 < MOD2 — confirm at call sites.
template <uint64_t MOD1, uint64_t MOD2, typename Int128Type = Uint128>
inline Int128Type qcrt(uint64_t num1, uint64_t num2)
{
    constexpr uint64_t inv1 = mod_inv<int64_t>(MOD1, MOD2); // MOD1^-1 mod MOD2
    constexpr uint64_t inv2 = mod_inv<int64_t>(MOD2, MOD1); // MOD2^-1 mod MOD1
    static_assert(check_inv<Int128Type>(inv1, MOD1, MOD2), "Inv1 error");
    static_assert(check_inv<Int128Type>(inv2, MOD2, MOD1), "Inv2 error");
    if (num1 <= num2)
    {
        // x = num1 + MOD1 * ((num2 - num1) * MOD1^-1 mod MOD2)
        return (Int128Type(num2 - num1) * Int128Type(inv1) % Int128Type(MOD2)) * Int128Type(MOD1) + num1;
    }
    // x = num2 + MOD2 * ((num1 - num2) * MOD2^-1 mod MOD1)
    return (Int128Type(num1 - num2) * Int128Type(inv2) % Int128Type(MOD1)) * Int128Type(MOD2) + num2;
}
// Fast two-modulus Chinese remainder theorem for 32-bit moduli; the result fits in 64 bits.
// Returns x with x == num1 (mod MOD1) and x == num2 (mod MOD2).
// NOTE(review): assumes coprime moduli and num1 < MOD1, num2 < MOD2 — confirm at call sites.
template <uint32_t MOD1, uint32_t MOD2>
inline uint64_t qcrt(uint32_t num1, uint32_t num2)
{
    constexpr uint64_t inv1 = mod_inv<int64_t>(MOD1, MOD2); // MOD1^-1 mod MOD2
    constexpr uint64_t inv2 = mod_inv<int64_t>(MOD2, MOD1); // MOD2^-1 mod MOD1
    static_assert(check_inv<uint64_t>(inv1, MOD1, MOD2), "Inv1 error");
    static_assert(check_inv<uint64_t>(inv2, MOD2, MOD1), "Inv2 error");
    if (num1 <= num2)
    {
        const uint64_t t = uint64_t(num2 - num1) * uint64_t(inv1) % MOD2;
        return t * MOD1 + num1;
    }
    const uint64_t t = uint64_t(num1 - num2) * uint64_t(inv2) % MOD1;
    return t * MOD2 + num2;
}
namespace radix2
{
// Multiplies n by w4 = ROOT^((mod - 1) / 4), the length-4 twiddle; the constant is computed at compile time.
template <uint64_t ROOT, typename ModIntType>
inline ModIntType mul_w41(ModIntType n)
{
    constexpr ModIntType quarter_root = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 4);
    return n * quarter_root;
}
// Multiplies n by w8 = ROOT^((mod - 1) / 8), the length-8 twiddle; the constant is computed at compile time.
template <uint64_t ROOT, typename ModIntType>
inline ModIntType mul_w81(ModIntType n)
{
    constexpr ModIntType eighth_root = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
    return n * eighth_root;
}
// Multiplies n by w8^3 = ROOT^(3 * (mod - 1) / 8); the constant is computed at compile time.
template <uint64_t ROOT, typename ModIntType>
inline ModIntType mul_w83(ModIntType n)
{
    constexpr ModIntType eighth_root_cubed = qpow(ModIntType(ROOT), ((ModIntType::mod() - 1) / 8) * 3);
    return n * eighth_root_cubed;
}
// Radix-4 decimation-in-time butterfly on four elements spaced `gap` apart (as called from NTTShort::dit).
// Twiddles:
//   - omega_last twiddles the odd inputs (in_out1, in_out3) for the inner half-length stage;
//   - omega1 and omega2 twiddle the combined (in_out2, in_out3) results for the outer stage.
// NOTE(review): value bounds assume a lazy Montgomery ModIntType (MontInt64Lazy / MontInt32Lazy):
// largeNorm brings in_out0/in_out2 below 2*mod, and the raw* results can reach 4*mod.
// The statement order matters for those bounds.
template <typename ModIntType>
inline void dit_butterfly(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3, ModIntType omega1, ModIntType omega2, ModIntType omega_last)
{
ModIntType temp0, temp1, temp2, temp3;
// Inner stage: twiddle the odd halves, then radix-2 combine (0,1) and (2,3).
in_out1 = in_out1 * omega_last;
in_out3 = in_out3 * omega_last;
in_out0 = in_out0.largeNorm();
in_out2 = in_out2.largeNorm();
temp0 = in_out0 + in_out1;
temp1 = in_out0 - in_out1;
temp2 = in_out2.rawAdd(in_out3);
temp3 = in_out2.rawSub(in_out3);
// Outer stage: twiddle the second pair and combine with the first.
temp2 = omega1 * temp2;
temp3 = omega2 * temp3;
in_out0 = temp0.rawAdd(temp2);
in_out1 = temp1.rawAdd(temp3);
in_out2 = temp0.rawSub(temp2);
in_out3 = temp1.rawSub(temp3);
}
// Radix-4 decimation-in-frequency butterfly on four elements spaced `gap` apart (as called from NTTShort::dif).
// The mirror image of dit_butterfly:
//   1. combine (0,2) and (1,3);
//   2. twiddle the differences with omega1 / omega2;
//   3. combine again, twiddling the differences of the inner pairs with omega_last.
// NOTE(review): relies on lazy Montgomery bounds (inputs presumably below 2*mod) — confirm.
template <typename ModIntType>
inline void dif_butterfly(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3, ModIntType omega1, ModIntType omega2, ModIntType omega_last)
{
ModIntType temp0, temp1, temp2, temp3;
// Outer stage.
temp0 = in_out0 + in_out2;
temp1 = in_out1 + in_out3;
temp2 = in_out0.rawSub(in_out2) * omega1;
temp3 = in_out1.rawSub(in_out3) * omega2;
// Inner stage.
in_out0 = temp0 + temp1;
in_out1 = temp0.rawSub(temp1) * omega_last;
in_out2 = temp2 + temp3;
in_out3 = temp2.rawSub(temp3) * omega_last;
}
// Table-driven NTT of power-of-two length up to LEN.
// The short fixed-size kernels (specializations for 2, 4, 8, 16) handle the base cases.
template <size_t LEN, uint64_t ROOT, typename ModIntType>
struct NTTShort
{
    static constexpr size_t ntt_len = LEN;
    static constexpr int log_len = hint_log2(ntt_len);
    using TableType = std::array<ModIntType, ntt_len>;
    // Twiddles for a power-of-two transform length `len` (<= LEN).
    // Entry i is w_len^i for i < len / 2, stored contiguously from table[len / 2].
    static constexpr ModIntType *getOmegaIt(size_t len)
    {
        return &table[len / 2];
    }
    // Fills the twiddle groups for lengths 1, 2, 4, ..., LEN with w_len = ROOT^((mod - 1) / len).
    // NOTE(review): this writes into `table` and returns a copy of it while `table` is being
    // initialized from that result. It works in practice but is formally questionable.
    static constexpr TableType getNTTTable()
    {
        for (int omega_log_len = 0; omega_log_len <= log_len; omega_log_len++)
        {
            size_t omega_len = size_t(1) << omega_log_len, omega_count = omega_len / 2;
            auto it = getOmegaIt(omega_len);
            ModIntType root = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / omega_len);
            ModIntType omega(1);
            for (size_t i = 0; i < omega_count; i++)
            {
                it[i] = omega;
                omega *= root;
            }
        }
        return table;
    }
    static TableType table;
    // Radix-4 DIT over `len` elements (power of two, <= LEN), after 16- or 8-point base kernels.
    // NOTE(review): input presumably in bit-reversed order — confirm at caller.
    static void dit(ModIntType in_out[], size_t len)
    {
        size_t rank = len;
        if (hint_log2(len) % 2 == 0)
        {
            NTTShort<16, ROOT, ModIntType>::dit(in_out, len);
            for (size_t i = 16; i < len; i += 16)
            {
                NTTShort<16, ROOT, ModIntType>::dit(in_out + i);
            }
            rank = 64;
        }
        else
        {
            NTTShort<8, ROOT, ModIntType>::dit(in_out, len);
            for (size_t i = 8; i < len; i += 8)
            {
                NTTShort<8, ROOT, ModIntType>::dit(in_out + i);
            }
            rank = 32;
        }
        for (; rank <= len; rank *= 4)
        {
            size_t gap = rank / 4;
            auto in_out2 = in_out + gap * 2;
            auto omega_it = getOmegaIt(rank), last_omega_it = getOmegaIt(rank / 2);
            for (size_t j = 0; j < len; j += rank)
            {
                for (size_t i = 0; i < gap; i++)
                {
                    dit_butterfly(in_out[i + j], in_out[gap + i + j], in_out2[i + j], in_out2[gap + i + j], omega_it[i], omega_it[gap + i], last_omega_it[i]);
                }
            }
        }
    }
    // Radix-4 DIF over `len` elements; the stages of dit in reverse order.
    static void dif(ModIntType in_out[], size_t len)
    {
        size_t rank = len;
        for (; rank >= 32; rank /= 4)
        {
            size_t gap = rank / 4;
            auto in_out2 = in_out + gap * 2;
            auto omega_it = getOmegaIt(rank), last_omega_it = getOmegaIt(rank / 2);
            for (size_t j = 0; j < len; j += rank)
            {
                for (size_t i = 0; i < gap; i++)
                {
                    dif_butterfly(in_out[i + j], in_out[gap + i + j], in_out2[i + j], in_out2[gap + i + j], omega_it[i], omega_it[gap + i], last_omega_it[i]);
                }
            }
        }
        if (hint_log2(rank) % 2 == 0)
        {
            NTTShort<16, ROOT, ModIntType>::dif(in_out, len);
            for (size_t i = 16; i < len; i += 16)
            {
                NTTShort<16, ROOT, ModIntType>::dif(in_out + i);
            }
        }
        else
        {
            NTTShort<8, ROOT, ModIntType>::dif(in_out, len);
            for (size_t i = 8; i < len; i += 8)
            {
                NTTShort<8, ROOT, ModIntType>::dif(in_out + i);
            }
        }
    }
    // Radix-2 DIT over `len` elements, after 8-point base kernels.
    static void dit2(ModIntType in_out[], size_t len)
    {
        if (len < 8)
        {
            // The 8-point kernel below always touches 8 elements; shorter inputs previously
            // read and wrote past the end. Dispatch to the length-aware small kernels instead.
            NTTShort<8, ROOT, ModIntType>::dit(in_out, len);
            return;
        }
        for (size_t i = 0; i < len; i += 8)
        {
            NTTShort<8, ROOT, ModIntType>::dit(in_out + i);
        }
        for (size_t rank = 16; rank <= len; rank *= 2)
        {
            size_t gap = rank / 2;
            auto in_out2 = in_out + gap;
            auto omega_it = getOmegaIt(rank);
            for (size_t j = 0; j < len; j += rank)
            {
                for (size_t i = 0; i < gap; i += 2)
                {
                    auto butterfly = [=](size_t shift)
                    {
                        auto x = in_out[j + i + shift].largeNorm(), y = in_out2[j + i + shift] * omega_it[i + shift];
                        in_out[j + i + shift] = x.rawAdd(y), in_out2[j + i + shift] = x.rawSub(y);
                    };
                    butterfly(0);
                    butterfly(1);
                }
            }
        }
    }
    // Radix-2 DIF over `len` elements, finishing with 8-point base kernels.
    static void dif2(ModIntType in_out[], size_t len)
    {
        if (len < 8)
        {
            // Same out-of-bounds guard as in dit2.
            NTTShort<8, ROOT, ModIntType>::dif(in_out, len);
            return;
        }
        for (size_t rank = len; rank >= 16; rank /= 2)
        {
            size_t gap = rank / 2;
            auto in_out2 = in_out + gap;
            auto omega_it = getOmegaIt(rank);
            for (size_t j = 0; j < len; j += rank)
            {
                for (size_t i = 0; i < gap; i += 2)
                {
                    auto butterfly = [=](size_t shift)
                    {
                        auto x = in_out[j + i + shift], y = in_out2[j + i + shift];
                        in_out[j + i + shift] = x + y, in_out2[j + i + shift] = x.rawSub(y) * omega_it[i + shift];
                    };
                    butterfly(0);
                    butterfly(1);
                }
            }
        }
        for (size_t i = 0; i < len; i += 8)
        {
            NTTShort<8, ROOT, ModIntType>::dif(in_out + i);
        }
    }
};
template <size_t LEN, uint64_t ROOT, typename ModIntType>
typename NTTShort<LEN, ROOT, ModIntType>::TableType NTTShort<LEN, ROOT, ModIntType>::table = NTTShort<LEN, ROOT, ModIntType>::getNTTTable();
// Length-0 transform: nothing to do, so every entry point is a no-op.
template <uint64_t ROOT, typename ModIntType>
struct NTTShort<0, ROOT, ModIntType>
{
    static void dit(ModIntType[]) {}
    static void dif(ModIntType[]) {}
    static void dit(ModIntType[], size_t) {}
    static void dif(ModIntType[], size_t) {}
};
// Length-1 transform: the identity, so every entry point is a no-op.
template <uint64_t ROOT, typename ModIntType>
struct NTTShort<1, ROOT, ModIntType>
{
    static void dit(ModIntType[]) {}
    static void dif(ModIntType[]) {}
    static void dit(ModIntType[], size_t) {}
    static void dif(ModIntType[], size_t) {}
};
// Length-2 transform: a single butterfly, identical for DIT and DIF.
template <uint64_t ROOT, typename ModIntType>
struct NTTShort<2, ROOT, ModIntType>
{
    static void dit(ModIntType in_out[])
    {
        transform2(in_out[0], in_out[1]);
    }
    static void dif(ModIntType in_out[])
    {
        transform2(in_out[0], in_out[1]);
    }
    // Length-aware entry points: inputs shorter than 2 are left unchanged.
    static void dit(ModIntType in_out[], size_t len)
    {
        if (len >= 2)
        {
            dit(in_out);
        }
    }
    static void dif(ModIntType in_out[], size_t len)
    {
        if (len >= 2)
        {
            dif(in_out);
        }
    }
};
// Hand-written 4-point kernels; w4 = ROOT^((mod - 1) / 4) is applied through mul_w41.
template <uint64_t ROOT, typename ModIntType>
struct NTTShort<4, ROOT, ModIntType>
{
    static void dit(ModIntType in_out[])
    {
        ModIntType a0 = in_out[0], a1 = in_out[1];
        ModIntType a2 = in_out[2], a3 = in_out[3];
        // Stage 1: butterflies on (0,1) and (2,3), then twiddle the last odd term.
        transform2(a0, a1);
        transform2(a2, a3);
        a3 = mul_w41<ROOT>(a3);
        // Stage 2: combine the two halves.
        in_out[0] = a0 + a2;
        in_out[2] = a0 - a2;
        in_out[1] = a1 + a3;
        in_out[3] = a1 - a3;
    }
    static void dif(ModIntType in_out[])
    {
        ModIntType a0 = in_out[0], a1 = in_out[1];
        ModIntType a2 = in_out[2], a3 = in_out[3];
        // Stage 1: butterflies on (0,2) and (1,3), then twiddle the last difference.
        transform2(a0, a2);
        transform2(a1, a3);
        a3 = mul_w41<ROOT>(a3);
        // Stage 2: adjacent butterflies.
        in_out[0] = a0 + a1;
        in_out[1] = a0 - a1;
        in_out[2] = a2 + a3;
        in_out[3] = a2 - a3;
    }
    // Length-aware entry points: shorter inputs fall back to the 2-point kernel.
    static void dit(ModIntType in_out[], size_t len)
    {
        if (len >= 4)
        {
            dit(in_out);
        }
        else
        {
            NTTShort<2, ROOT, ModIntType>::dit(in_out, len);
        }
    }
    static void dif(ModIntType in_out[], size_t len)
    {
        if (len >= 4)
        {
            dif(in_out);
        }
        else
        {
            NTTShort<2, ROOT, ModIntType>::dif(in_out, len);
        }
    }
};
// Hand-unrolled 8-point kernels built from three radix-2 stages.
// Twiddles: w1 = ROOT^((mod - 1) / 8), w2 = w1^2, w3 = w1^3, plus w4 via mul_w41.
template <uint64_t ROOT, typename ModIntType>
struct NTTShort<8, ROOT, ModIntType>
{
// 8-point decimation-in-time transform, in place.
// NOTE(review): input presumably in bit-reversed order, as for the other DIT kernels — confirm.
static void dit(ModIntType in_out[])
{
static constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
static constexpr ModIntType w2 = qpow(w1, 2);
static constexpr ModIntType w3 = qpow(w1, 3);
auto temp0 = in_out[0];
auto temp1 = in_out[1];
auto temp2 = in_out[2];
auto temp3 = in_out[3];
auto temp4 = in_out[4];
auto temp5 = in_out[5];
auto temp6 = in_out[6];
auto temp7 = in_out[7];
// Stage 1: adjacent pairs.
transform2(temp0, temp1);
transform2(temp2, temp3);
transform2(temp4, temp5);
transform2(temp6, temp7);
// Stage 2: length-4 twiddle, then pairs at distance 2.
temp3 = mul_w41<ROOT>(temp3);
temp7 = mul_w41<ROOT>(temp7);
transform2(temp0, temp2);
transform2(temp1, temp3);
transform2(temp4, temp6);
transform2(temp5, temp7);
// Stage 3: length-8 twiddles on the upper half, then pairs at distance 4.
temp5 = temp5 * w1;
temp6 = temp6 * w2;
temp7 = temp7 * w3;
in_out[0] = temp0 + temp4;
in_out[1] = temp1 + temp5;
in_out[2] = temp2 + temp6;
in_out[3] = temp3 + temp7;
in_out[4] = temp0 - temp4;
in_out[5] = temp1 - temp5;
in_out[6] = temp2 - temp6;
in_out[7] = temp3 - temp7;
}
// 8-point decimation-in-frequency transform, in place: the stages of dit in reverse order.
static void dif(ModIntType in_out[])
{
static constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
static constexpr ModIntType w2 = qpow(w1, 2);
static constexpr ModIntType w3 = qpow(w1, 3);
auto temp0 = in_out[0];
auto temp1 = in_out[1];
auto temp2 = in_out[2];
auto temp3 = in_out[3];
auto temp4 = in_out[4];
auto temp5 = in_out[5];
auto temp6 = in_out[6];
auto temp7 = in_out[7];
// Stage 1: pairs at distance 4, then length-8 twiddles on the differences.
transform2(temp0, temp4);
transform2(temp1, temp5);
transform2(temp2, temp6);
transform2(temp3, temp7);
temp5 = temp5 * w1;
temp6 = temp6 * w2;
temp7 = temp7 * w3;
// Stage 2: pairs at distance 2, then the length-4 twiddle.
transform2(temp0, temp2);
transform2(temp1, temp3);
transform2(temp4, temp6);
transform2(temp5, temp7);
temp3 = mul_w41<ROOT>(temp3);
temp7 = mul_w41<ROOT>(temp7);
// Stage 3: adjacent pairs.
in_out[0] = temp0 + temp1;
in_out[1] = temp0 - temp1;
in_out[2] = temp2 + temp3;
in_out[3] = temp2 - temp3;
in_out[4] = temp4 + temp5;
in_out[5] = temp4 - temp5;
in_out[6] = temp6 + temp7;
in_out[7] = temp6 - temp7;
}
// Length-aware entry point: fewer than 8 elements fall back to the 4-point kernel.
static void dit(ModIntType in_out[], size_t len)
{
if (len < 8)
{
NTTShort<4, ROOT, ModIntType>::dit(in_out, len);
return;
}
dit(in_out);
}
// Length-aware entry point: fewer than 8 elements fall back to the 4-point kernel.
static void dif(ModIntType in_out[], size_t len)
{
if (len < 8)
{
NTTShort<4, ROOT, ModIntType>::dif(in_out, len);
return;
}
dif(in_out);
}
};
template <uint64_t ROOT, typename ModIntType>
struct NTTShort<16, ROOT, ModIntType>
{
using NTT4 = NTTShort<4, ROOT, ModIntType>;
using NTT8 = NTTShort<8, ROOT, ModIntType>;
static void dit(ModIntType in_out[])
{
constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 16);
constexpr ModIntType w2 = qpow(w1, 2);
constexpr ModIntType w3 = qpow(w1, 3);
constexpr ModIntType w6 = qpow(w1, 6);
constexpr ModIntType w9 = qpow(w1, 9);
NTT4::dit(in_out + 12);
NTT4::dit(in_out + 8);
NTT8::dit(in_out);
ModIntType temp0, temp1, temp2, temp3;
temp2 = in_out[8];
temp3 = in_out[12];
transform2(temp2, temp3);
temp3 = mul_w41<ROOT>(temp3);
temp0 = in_out[0];
temp1 = in_out[4];
in_out[0] = temp0 + temp2;
in_out[4] = temp1 + temp3;
in_out[8] = temp0 - temp2;
in_out[12] = temp1 - temp3;
temp2 = in_out[9] * w1;
temp3 = in_out[13] * w3;
transform2(temp2, temp3);
temp3 = mul_w41<ROOT>(temp3);
temp0 = in_out[1];
temp1 = in_out[5];
in_out[1] = temp0 + temp2;
in_out[5] = temp1 + temp3;
in_out[9] = temp0 - temp2;
in_out[13] = temp1 - temp3;
temp2 = in_out[10] * w2;
temp3 = in_out[14] * w6;
transform2(temp2, temp3);
temp3 = mul_w41<ROOT>(temp3);
temp0 = in_out[2];
temp1 = in_out[6];
in_out[2] = temp0 + temp2;
in_out[6] = temp1 + temp3;
in_out[10] = temp0 - temp2;
in_out[14] = temp1 - temp3;
temp2 = in_out[11] * w3;
temp3 = in_out[15] * w9;
transform2(temp2, temp3);
temp3 = mul_w41<ROOT>(temp3);
temp0 = in_out[3];
temp1 = in_out[7];
in_out[3] = temp0 + temp2;
in_out[7] = temp1 + temp3;
in_out[11] = temp0 - temp2;
in_out[15] = temp1 - temp3;
}
static void dif(ModIntType in_out[])
{
constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 16);
constexpr ModIntType w2 = qpow(w1, 2);
constexpr ModIntType w3 = qpow(w1, 3);
constexpr ModIntType w6 = qpow(w1, 6);
constexpr ModIntType w9 = qpow(w1, 9);
ModIntType temp0, temp1, temp2, temp3;
temp0 = in_out[0];
temp1 = in_out[4];
temp2 = in_out[8];
temp3 = in_out[12];
in_out[0] = temp0 + temp2;
in_out[4] = temp1 + temp3;
temp2 = temp0 - temp2;
temp3 = temp1 - temp3;
temp3 = mul_w41<ROOT>(temp3);
transform2(temp2, temp3);
in_out[8] = temp2;
in_out[12] = temp3;
temp0 = in_out[1];
temp1 = in_out[5];
temp2 = in_out[9];
temp3 = in_out[13];
in_out[1] = temp0 + temp2;
in_out[5] = temp1 + temp3;
temp2 = temp0 - temp2;
temp3 = temp1 - temp3;
temp3 = mul_w41<ROOT>(temp3);
transform2(temp2, temp3);
in_out[9] = temp2 * w1;
in_out[13] = temp3 * w3;
temp0 = in_out[2];
temp1 = in_out[6];
temp2 = in_out[10];
temp3 = in_out[14];
in_out[2] = temp0 + temp2;
in_out[6] = temp1 + temp3;
temp2 = temp0 - temp2;
temp3 = temp1 - temp3;
temp3 = mul_w41<ROOT>(temp3);
transform2(temp2, temp3);
in_out[10] = temp2 * w2;
in_out[14] = temp3 * w6;
temp0 = in_out[3];
temp1 = in_out[7];
temp2 = in_out[11];
temp3 = in_out[15];
in_out[3] = temp0 + temp2;
in_out[7] = temp1 + temp3;
temp2 = temp0 - temp2;
temp3 = temp1 - temp3;
temp3 = mul_w41<ROOT>(temp3);
transform2(temp2, temp3);
in_out[11] = temp2 * w3;
in_out[15] = temp3 * w9;
NTT8::dif(in_out);
NTT4::dif(in_out + 8);
NTT4::dif(in_out + 12);
}
// Length-aware decimation-in-time entry point.
// Lengths below 16 are delegated to the 8-point specialization's length-aware dit.
// Any length of 16 or more runs the fixed 16-point dit, which touches in_out[0..15] only.
static void dit(ModIntType in_out[], size_t len)
{
    using Smaller = NTTShort<8, ROOT, ModIntType>;
    if (len >= 16)
    {
        dit(in_out);
    }
    else
    {
        Smaller::dit(in_out, len);
    }
}
// Length-aware decimation-in-frequency entry point.
// Lengths below 16 are delegated to the 8-point specialization's length-aware dif.
// Any length of 16 or more runs the fixed 16-point dif, which touches in_out[0..15] only.
static void dif(ModIntType in_out[], size_t len)
{
    using Smaller = NTTShort<8, ROOT, ModIntType>;
    if (len >= 16)
    {
        dif(in_out);
    }
    else
    {
        Smaller::dif(in_out, len);
    }
}
};
// Number-theoretic transform over Z/MOD using ROOT as the generator of the roots of unity.
// Transform lengths are powers of two: every entry point rounds the requested length down
// to a power of two and caps it at ntt_max_len (see normalizeLen).
// Lengths up to LONG_THRESHOLD are handled by NTTShort's dit2/dif2; longer ones are split
// into four quarter-length sub-transforms combined with dit_butterfly/dif_butterfly.
// Int128Type is only used for the compile-time self check of root()/iroot().
template <uint64_t MOD, uint64_t ROOT, typename Int128Type = Uint128>
struct NTT
{
    static constexpr uint64_t mod()
    {
        return MOD;
    }
    static constexpr uint64_t root()
    {
        return ROOT;
    }
    // Inverse of ROOT modulo MOD; it is the root used by the inverse transform (INTT).
    static constexpr uint64_t iroot()
    {
        return mod_inv<int64_t>(root(), mod());
    }
    // Compile-time check that root() * iroot() == 1 (mod MOD).
    static constexpr bool selfCheck()
    {
        Int128Type n = root();
        n *= Int128Type(iroot());
        n %= Int128Type(mod());
        return n == Int128Type(1);
    }
    static_assert(root() < mod(), "ROOT must be smaller than MOD");
    static_assert(selfCheck(), "IROOT * ROOT % MOD must be 1");
    static constexpr int mod_bits = hint_log2(mod()) + 1;
    // Presumably the number of trailing zero bits of MOD - 1, i.e. 2^max_log_len is the
    // longest power-of-two transform length the modulus supports.
    static constexpr int max_log_len = hint_ctz(mod() - 1);
    // Longest supported length, additionally limited to what fits in size_t.
    static constexpr size_t getMaxLen()
    {
        // Cast avoids a signed/unsigned comparison; both sides are small and non-negative.
        if (max_log_len < static_cast<int>(sizeof(size_t) * CHAR_BIT))
        {
            return size_t(1) << max_log_len;
        }
        return size_t(1) << (sizeof(size_t) * CHAR_BIT - 1);
    }
    static constexpr size_t ntt_max_len = getMaxLen();
    using INTT = NTT<mod(), iroot(), Int128Type>;
    // Residue type: 64-bit variant for moduli wider than 32 bits, 32-bit variant otherwise.
    using ModIntType = typename std::conditional<(mod_bits > 32), MontInt64Lazy<MOD, Int128Type>, MontInt32Lazy<uint32_t(MOD)>>::type;
    using IntType = typename std::conditional<(mod_bits > 32), uint64_t, uint32_t>::type;
    static constexpr size_t LONG_THRESHOLD = size_t(1) << 16;
    using NTTTemplate = NTTShort<LONG_THRESHOLD, root(), ModIntType>;

    // The length actually transformed for a requested len: rounded down to a power of two
    // and capped at ntt_max_len. Shared by dit, dif and convolution so they always agree.
    static size_t normalizeLen(size_t len)
    {
        return std::min(int_floor2(len), ntt_max_len);
    }
    // In-place decimation-in-time transform of in_out[0 .. normalizeLen(ntt_len)).
    static void dit(ModIntType in_out[], size_t ntt_len)
    {
        ntt_len = normalizeLen(ntt_len);
        if (ntt_len <= LONG_THRESHOLD)
        {
            NTTTemplate::dit2(in_out, ntt_len);
            return;
        }
        size_t quarter_len = ntt_len / 4;
        // DIT: transform the four quarters first, then combine them with butterflies.
        dit(in_out + quarter_len * 3, quarter_len);
        dit(in_out + quarter_len * 2, quarter_len);
        dit(in_out + quarter_len, quarter_len);
        dit(in_out, quarter_len);
        // unit1 = ROOT^((mod-1)/ntt_len), an ntt_len-th root of unity; unit2 = unit1^2.
        const ModIntType unit1 = qpow(ModIntType(root()), (mod() - 1) / ntt_len), unit2(qpow(unit1, 2));
        ModIntType omega1(1), omega2(qpow(unit1, quarter_len)), last(1);
        for (auto it0 = in_out, it2 = in_out + ntt_len / 2; it0 < in_out + quarter_len; it0++, it2++)
        {
            dit_butterfly(it0[0], it0[quarter_len], it2[0], it2[quarter_len], omega1, omega2, last);
            omega1 *= unit1;
            omega2 *= unit1;
            last *= unit2;
        }
    }
    // In-place decimation-in-frequency transform of in_out[0 .. normalizeLen(ntt_len)).
    static void dif(ModIntType in_out[], size_t ntt_len)
    {
        ntt_len = normalizeLen(ntt_len);
        if (ntt_len <= LONG_THRESHOLD)
        {
            NTTTemplate::dif2(in_out, ntt_len);
            return;
        }
        size_t quarter_len = ntt_len / 4;
        // DIF: butterflies first, then transform the four quarters.
        const ModIntType unit1 = qpow(ModIntType(root()), (mod() - 1) / ntt_len), unit2(qpow(unit1, 2));
        ModIntType omega1(1), omega2(qpow(unit1, quarter_len)), last(1);
        for (auto it0 = in_out, it2 = in_out + ntt_len / 2; it0 < in_out + quarter_len; it0++, it2++)
        {
            dif_butterfly(it0[0], it0[quarter_len], it2[0], it2[quarter_len], omega1, omega2, last);
            omega1 *= unit1;
            omega2 *= unit1;
            last *= unit2;
        }
        dif(in_out + quarter_len * 3, quarter_len);
        dif(in_out + quarter_len * 2, quarter_len);
        dif(in_out + quarter_len, quarter_len);
        dif(in_out, quarter_len);
    }
    // Cyclic convolution of in1 and in2 (length normalizeLen(ntt_len)), written to out.
    // in1 and in2 are overwritten with their forward transforms; out may alias in1 or in2
    // because the product is formed element by element.
    // normlize: when true, the result is scaled by 1/len so out holds the true convolution.
    static void convolution(ModIntType in1[], ModIntType in2[], ModIntType out[], size_t ntt_len, bool normlize = true)
    {
        // dif/dit transform normalizeLen(ntt_len) elements; use the same length here so the
        // pointwise product and the 1/len factor cover exactly the transformed range.
        ntt_len = normalizeLen(ntt_len);
        dif(in1, ntt_len);
        dif(in2, ntt_len);
        if (normlize)
        {
            // 1/len via Fermat's little theorem (MOD is assumed prime).
            const ModIntType inv_len(qpow(ModIntType(ntt_len), mod() - 2));
            for (size_t i = 0; i < ntt_len; i++)
            {
                out[i] = in1[i] * in2[i] * inv_len;
            }
        }
        else
        {
            for (size_t i = 0; i < ntt_len; i++)
            {
                out[i] = in1[i] * in2[i];
            }
        }
        INTT::dit(out, ntt_len);
    }
};
}
}
}
}
void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c)
{
using namespace std;
using namespace hint;
using namespace hint_transform::hint_ntt;
size_t conv_len = m + n + 1, ntt_len = int_ceil2(conv_len);
using ntt = hint_transform::hint_ntt::radix2::NTT<998244353, 3>;
using ModInt = ntt::ModIntType;
ModInt *a_ntt = new ModInt[ntt_len];
ModInt *b_ntt = new ModInt[ntt_len];
std::fill(a_ntt + n + 1, a_ntt + ntt_len, ModInt{});
std::fill(b_ntt + m + 1, b_ntt + ntt_len, ModInt{});
std::copy(a, a + n + 1, a_ntt);
std::copy(b, b + m + 1, b_ntt);
ntt::convolution(a_ntt, b_ntt, a_ntt, ntt_len);
size_t rem_len = conv_len % 16, i = 0;
for (; i < conv_len; i++)
{
c[i] = uint32_t(a_ntt[i]);
}
delete[] a_ntt;
delete[] b_ntt;
}
#include <cstring>
// Compile-time lookup table of zero-padded 4-digit decimal strings for 0..9999.
// tostr(n) points at exactly 4 characters (no terminator), e.g. 42 -> "0042".
// Rows are stored as plain char arrays, so the digit order in memory does not depend on
// the platform's byte order.
class ItoStrBase10000
{
private:
    char table[10000][4]{};

public:
    // Packs the 4 ASCII digits of num (most significant digit in the lowest byte) into one
    // uint32_t. Kept for existing callers; the table itself no longer relies on this layout.
    static constexpr uint32_t itosbase10000(uint32_t num)
    {
        uint32_t res = '0' * 0x1010101;
        res += (num / 1000 % 10) | ((num / 100 % 10) << 8) |
               ((num / 10 % 10) << 16) | ((num % 10) << 24);
        return res;
    }
    constexpr ItoStrBase10000()
    {
        for (size_t i = 0; i < 10000; i++)
        {
            table[i][0] = static_cast<char>('0' + i / 1000 % 10);
            table[i][1] = static_cast<char>('0' + i / 100 % 10);
            table[i][2] = static_cast<char>('0' + i / 10 % 10);
            table[i][3] = static_cast<char>('0' + i % 10);
        }
    }
    // Precondition: num < 10000. Returns 4 characters, not NUL-terminated.
    const char *tostr(uint16_t num) const
    {
        return table[num];
    }
};
// Fast whitespace-separated unsigned-integer reader.
// The constructor reads up to max_len bytes of stdin into an owned buffer in one call;
// operator>> then parses tokens from that buffer and never reads past the bytes obtained.
// Reading past the end of input yields 0.
class Qin
{
private:
    char *data = nullptr;
    char *p = nullptr;
    char *end = nullptr;

public:
    Qin(size_t max_len)
    {
        // One extra byte keeps the buffer NUL-terminated after a short read.
        data = new char[max_len + 1];
        const size_t got = fread(data, 1, max_len, stdin);
        data[got] = '\0';
        p = data;
        end = data + got;
    }
    // Owns a raw buffer: copying would double-free it.
    Qin(const Qin &) = delete;
    Qin &operator=(const Qin &) = delete;
    ~Qin()
    {
        delete[] data;
    }
    // Advances past control/space characters (any byte <= ' '), stopping at end of input.
    void skipSpace()
    {
        while (p < end && *p <= ' ')
        {
            p++;
        }
    }
    // Parses the next run of non-space bytes as a decimal number (no sign, no validation).
    template <typename T>
    Qin &operator>>(T &n)
    {
        uint64_t x = 0;
        skipSpace();
        while (p < end && *p > ' ')
        {
            x = x * 10 + (*p - '0');
            p++;
        }
        n = x;
        return *this;
    }
};
// Buffered decimal writer for stdout.
// Output accumulates in an owned buffer of `len` bytes. When a write would not fit, the
// buffered bytes are written to stdout and the buffer is reused, so output of any length is
// safe. Whatever is still buffered is written when the object is destroyed.
class QPrint
{
private:
    // The longest single write is 10 digits (a uint32_t); the buffer is never smaller than this.
    static constexpr size_t MIN_CAPACITY = 16;
    char *data = nullptr;
    char *p = nullptr;
    size_t len = 0;
    static constexpr ItoStrBase10000 itostr{};

    // Writes the buffered bytes to stdout and rewinds the write cursor.
    void flush()
    {
        outPut();
        p = data;
    }
    // Guarantees at least `need` free bytes (need <= MIN_CAPACITY), flushing if necessary.
    void ensureSpace(size_t need)
    {
        if (static_cast<size_t>(data + len - p) < need)
        {
            flush();
        }
    }

public:
    QPrint(size_t max_len) : len(max_len < MIN_CAPACITY ? MIN_CAPACITY : max_len)
    {
        p = data = new char[len];
    }
    // Owns a raw buffer: copying would double-free it.
    QPrint(const QPrint &) = delete;
    QPrint &operator=(const QPrint &) = delete;
    ~QPrint()
    {
        if (data != nullptr)
        {
            outPut();
            delete[] data;
        }
    }
    // Writes n (expected 0..99) without leading zeros.
    void setDigit2(uint32_t n)
    {
        ensureSpace(2);
        if (n <= 9)
        {
            *p = n + '0';
            p++;
        }
        else
        {
            p[0] = n / 10 + '0';
            p[1] = n % 10 + '0';
            p += 2;
        }
    }
    // Writes n (expected 0..9999) without leading zeros.
    void setDigit4(uint32_t n)
    {
        ensureSpace(4);
        const char *num_p = itostr.tostr(n);
        if (n > 99)
        {
            if (n > 999) // 4 digits
            {
                memcpy(p, num_p, 4);
                p += 4;
            }
            else // 3 digits
            {
                memcpy(p, num_p + 1, 3);
                p += 3;
            }
        }
        else
        {
            if (n > 9) // 2 digits
            {
                memcpy(p, num_p + 2, 2);
                p += 2;
            }
            else // 1 digit
            {
                *p = n + '0';
                p++;
            }
        }
    }
    // Writes n (expected 0..9999) as exactly 4 digits, zero-padded.
    void setAllDigit4(uint32_t n)
    {
        ensureSpace(4);
        memcpy(p, itostr.tostr(n), sizeof(uint32_t));
        p += 4;
    }
    // Writes n (expected 0..99999999) as exactly 8 digits, zero-padded.
    void setAllDigit8(uint32_t n)
    {
        ensureSpace(8);
        memcpy(p, itostr.tostr(n / 10000), sizeof(uint32_t));
        memcpy(p + 4, itostr.tostr(n % 10000), sizeof(uint32_t));
        p += 8;
    }
    // Writes n in decimal without leading zeros.
    QPrint &operator<<(uint32_t n)
    {
        if (n >= 100000000)
        {
            setDigit2(n / 100000000);
            setAllDigit8(n % 100000000);
        }
        else if (n >= 10000)
        {
            setDigit4(n / 10000);
            setAllDigit4(n % 10000);
        }
        else
        {
            setDigit4(n);
        }
        return *this;
    }
    QPrint &operator<<(char c)
    {
        ensureSpace(1);
        *p = c;
        p++;
        return *this;
    }
    // Writes the currently buffered bytes to stdout (does not clear the buffer).
    void outPut() const
    {
        fwrite(data, 1, p - data, stdout);
    }
};
constexpr ItoStrBase10000 QPrint::itostr;
// Reads two polynomials and prints the coefficients of their product.
// Input: degrees n and m, then the n + 1 coefficients of A, then the m + 1 coefficients of B.
// Output: the n + m + 1 coefficients of A * B, each followed by a space.
int main()
{
    size_t deg1 = 0, deg2 = 0;
    if (!(std::cin >> deg1 >> deg2))
    {
        return 1;
    }
    std::vector<uint32_t> a(deg1 + 1), b(deg2 + 1), c(deg1 + deg2 + 1);
    // Read through an unsigned temporary so the %u conversion always matches its argument type.
    for (auto &coef : a)
    {
        unsigned value = 0;
        if (scanf("%u", &value) != 1)
        {
            return 1;
        }
        coef = value;
    }
    for (auto &coef : b)
    {
        unsigned value = 0;
        if (scanf("%u", &value) != 1)
        {
            return 1;
        }
        coef = value;
    }
    // poly_multiply takes DEGREES; passing counts would make it read past a/b and write past c.
    poly_multiply(a.data(), static_cast<int>(deg1), b.data(), static_cast<int>(deg2), c.data());
    // At most 10 digits plus one separator per coefficient.
    QPrint qout(c.size() * 11 + 1);
    for (uint32_t value : c)
    {
        qout << value << ' ';
    }
}
/*
Online-judge results recorded for this program:
Compilation | N/A | N/A | Compile OK | Score: N/A | 显示更多 |
Subtask #1 Testcase #1 | 517.18 us | 552 KB | Accepted | Score: 100 | 显示更多 |
Subtask #1 Testcase #2 | 17.565 ms | 5 MB + 12 KB | Runtime Error | Score: -100 | 显示更多 |
Subtask #1 Testcase #3 | 8.574 ms | 2 MB + 592 KB | Runtime Error | Score: 0 | 显示更多 |
Subtask #1 Testcase #4 | 8.599 ms | 2 MB + 580 KB | Runtime Error | Score: 0 | 显示更多 |
Subtask #1 Testcase #5 | 508.51 us | 548 KB | Runtime Error | Score: 0 | 显示更多 |
Subtask #1 Testcase #6 | 508.87 us | 552 KB | Accepted | Score: 0 | 显示更多 |
Subtask #1 Testcase #7 | 508.16 us | 552 KB | Accepted | Score: 0 | 显示更多 |
Subtask #1 Testcase #8 | 16.543 ms | 4 MB + 768 KB | Runtime Error | Score: 0 | 显示更多 |
Subtask #1 Testcase #9 | 16.546 ms | 4 MB + 768 KB | Runtime Error | Score: 0 | 显示更多 |
Subtask #1 Testcase #10 | 15.527 ms | 4 MB + 440 KB | Runtime Error | Score: 0 | 显示更多 |
Subtask #1 Testcase #11 | 17.581 ms | 5 MB + 12 KB | Runtime Error | Score: 0 | 显示更多 |
Subtask #1 Testcase #12 | 17.179 ms | 4 MB + 456 KB | Runtime Error | Score: 0 | 显示更多 |
Subtask #1 Testcase #13 | 507.24 us | 548 KB | Runtime Error | Score: 0 | 显示更多 |
*/