#include <cctype>
#include <concepts>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <new>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
using namespace std;
using i64 = int64_t;
using u32 = uint32_t;
using u64 = uint64_t;
namespace detail
{
/// Whitespace-delimited token reader layered on a byte source.
///
/// `Buf` must be constructible from `(FILE *, u32)` and provide `top()`
/// (peek the current byte) and `pop()` (return it and advance). This layer
/// performs no end-of-input detection; reading past the end behaves however
/// `Buf` behaves.
template <class Buf>
struct FastI : Buf
{
    using Buf::pop;
    using Buf::top;
    FastI(FILE *f, u32 size = 1 << 18) : Buf(f, size) {}
    /// True only for the ten ASCII decimal digits.
    static bool isDigit(char c)
    {
        return c >= '0' && c <= '9';
    }
    /// Consumes every byte that compares `<= ' '` as a `char`.
    void skipSpace()
    {
        while (top() <= ' ')
            pop();
    }
    /// Reads the next non-space byte.
    FastI &operator>>(char &x)
    {
        skipSpace();
        x = pop();
        return *this;
    }
    /// Reads a maximal run of printable, non-space bytes into `x`.
    FastI &operator>>(string &x)
    {
        x.resize(0);
        skipSpace();
        // isgraph() is undefined for negative arguments, so widen through
        // unsigned char before the call.
        while (isgraph(static_cast<unsigned char>(top())))
            x.push_back(pop());
        return *this;
    }
    /// Reads an unsigned decimal number; stops at the first non-digit byte.
    template <unsigned_integral T>
    FastI &operator>>(T &x)
    {
        x = 0;
        skipSpace();
        while (isDigit(top()))
            x = x * 10 + (pop() - '0');
        return *this;
    }
    /// Reads a signed decimal number with an optional leading '-'.
    template <signed_integral T>
    FastI &operator>>(T &x)
    {
        x = 0;
        skipSpace();
        if (top() == '-')
        {
            pop();
            // Accumulate downwards so the most negative value of T can be
            // parsed without first overflowing its positive counterpart.
            while (isDigit(top()))
                x = x * 10 - (pop() - '0');
        }
        else
        {
            while (isDigit(top()))
                x = x * 10 + (pop() - '0');
        }
        return *this;
    }
};
template <class Buf>
struct FastO : Buf
{
using Buf::push;
using Buf::push_uncheck;
using Buf::puts;
vector<u32> pre;
FastO(FILE *f, u32 size = 1 << 18) : Buf(f, size), pre(u64(1E4))
{
for (int i = 0; i < u64(1E4); ++i)
{
int ti = i;
for (int j = 0; j < 4; ++j)
{
pre[i] = pre[i] << 8 | ti % 10 | 0x30;
ti /= 10;
}
}
}
~FastO()
{
Buf::flush();
}
template <signed_integral T>
FastO &operator<<(T x)
{
if (x < 0)
push('-'), x = -x;
return *this << make_unsigned<T>::type(x);
}
void output4(int t)
{
auto tp = (const char *)&pre[t];
if (t >= u64(1E2))
{
if (t >= u64(1E3))
push_uncheck(tp, 4);
else
push_uncheck(tp + 1, 3);
}
else
{
if (t >= u64(1E1))
push_uncheck(tp + 2, 2);
else
push_uncheck(t | 0x30);
}
};
template <unsigned_integral T>
FastO &operator<<(T x)
{
Buf::reserve(32);
if (x >= u64(1E8))
{
u64 q0 = x / u64(1E8), r0 = x % u64(1E8);
if (x >= u64(1E16))
{
u64 q1 = q0 / u64(1E8), r1 = q0 % u64(1E8);
output4(q1);
push_uncheck(&pre[r1 / u64(1E4)], 4);
push_uncheck(&pre[r1 % u64(1E4)], 4);
}
else if (x >= u64(1E12))
{
output4(q0 / u64(1E4));
push_uncheck(&pre[q0 % u64(1E4)], 4);
}
else
{
output4(q0);
}
push_uncheck(&pre[r0 / u64(1E4)], 4);
push_uncheck(&pre[r0 % u64(1E4)], 4);
}
else
{
if (x >= u64(1E4))
{
output4(x / u64(1E4));
push_uncheck(&pre[x % u64(1E4)], 4);
}
else
{
output4(x);
}
}
return *this;
}
FastO &operator<<(char x)
{
return push(x), *this;
}
FastO &operator<<(const char *x)
{
return puts(x), *this;
}
template <size_t N>
FastO &operator<<(const char x[N])
{
return push(x, N), *this;
}
FastO &operator<<(const string &x)
{
return push(x.c_str(), x.size()), *this;
}
};
struct BufO
{
FILE *f;
char *beg, *end, *p;
BufO(FILE *f_, u32 sz) : f(f_), beg(new char[sz]), end(beg + sz - 1), p(beg) {}
~BufO()
{
delete[] beg;
}
void flush()
{
fwrite(beg, 1, p - beg, f);
p = beg;
}
void reserve(u32 len)
{
if (end - p <= int(len))
flush();
}
void push(char s)
{
*p++ = s;
reserve(0);
}
void push(const char *s, u32 len)
{
reserve(len);
push_uncheck(s, len);
}
void push_uncheck(char s) { *p++ = s; }
void push_uncheck(const void *s, u32 len)
{
memcpy(p, s, len);
p += len;
}
void puts(const char *s)
{
while (*s)
push(*s++);
}
};
}
#include <sys/mman.h>
#include <sys/stat.h>
namespace detail
{
struct BufI
{
struct stat sb;
char *p;
BufI(FILE *f, u32)
{
int fd = fileno(f);
fstat(fd, &sb);
p = (char *)mmap(0, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
madvise(p, sb.st_size, MADV_SEQUENTIAL);
}
~BufI() { munmap(p, sb.st_size); }
char pop() { return *p++; }
char top() const { return *p; }
};
}
using FastI = detail::FastI<detail::BufI>;
using FastO = detail::FastO<detail::BufO>;
#include <type_traits>
/// Primary template; only the u32 specialization below is defined.
template <class T, T MOD>
struct MontgomerySpace;
/// Montgomery arithmetic modulo an odd MOD in [3, 2^30), radix 2^32.
///
/// add/sub fold their result back with a conditional +2*MOD, so transformed
/// values are handled lazily modulo 2*MOD; val() produces the canonical
/// residue.
template <u32 MOD>
struct MontgomerySpace<u32, MOD>
{
    static_assert(2 < MOD && MOD < u32(1) << 30, "mod must in [3, 2^30)");
    static_assert(MOD % 2 == 1, "mod must be odd");
    using ValueT = u32;
    using TransT = u32;
    using rawU32 = false_type;
    using isMontgomery = true_type;
    /// Inverse of MOD modulo 2^32 by Newton iteration; each of the five steps
    /// doubles the number of correct low bits starting from one.
    constexpr static u32 get_nr()
    {
        u32 x = 1;
        for (int i = 0; i < 5; ++i)
            x *= 2 - x * MOD;
        return x;
    }
    consteval static u32 mod()
    {
        return MOD;
    }
    enum : u32
    {
        // Parenthesised on purpose: `%` binds tighter than `<<`, so the
        // unparenthesised form shifts by (32 % MOD) and truncates to 0.
        R = u32((u64(1) << 32) % MOD), // 2^32 mod MOD
        IR = u32(-get_nr()),           // -MOD^{-1} mod 2^32
        MOD2 = MOD * 2,
    };
    /// Plain value -> Montgomery form (x * 2^32 mod MOD).
    constexpr static TransT trans(ValueT x)
    {
        return (u64(x) << 32) % MOD;
    }
    /// Montgomery reduction: adds the multiple of MOD that clears the low 32
    /// bits of x, then keeps the high half.
    constexpr static u32 reduce(u64 x)
    {
        return (x + u64(u32(x) * IR) * MOD) >> 32;
    }
    /// Adds MOD when the top bit is set (i.e. the value wrapped below zero).
    constexpr static u32 reduce_m(u32 n)
    {
        return n >> 31 ? n + MOD : n;
    }
    /// Adds 2*MOD when the top bit is set.
    constexpr static u32 reduce_2m(u32 n)
    {
        return n >> 31 ? n + MOD2 : n;
    }
    constexpr static u32 add(u32 a, u32 b)
    {
        return reduce_2m(a + b - MOD2);
    }
    constexpr static u32 sub(u32 a, u32 b)
    {
        return reduce_2m(a - b);
    }
    constexpr static u32 mul(u32 a, u32 b)
    {
        return reduce(u64(a) * b);
    }
    /// Canonical residue of a signed value; the result is a plain value, not
    /// Montgomery form.
    constexpr static u32 safe(i64 x)
    {
        return reduce_m(x % MOD);
    }
    /// Montgomery form -> canonical plain value.
    constexpr static ValueT val(TransT x)
    {
        return reduce_m(reduce(x) - MOD);
    }
    /// Takes a Montgomery-form value, leaves Montgomery form, and returns half
    /// of it modulo MOD as a plain value.
    constexpr static u32 shift2(u32 x)
    {
        x = reduce(x);
        return (x & 1 ? x + MOD : x) >> 1;
    }
};
/// Binary exponentiation: a^b modulo m, with products widened to 64 bits.
/// The accumulator starts at 1 and is only reduced when multiplied.
constexpr u32 qpow(u32 a, u64 b, u32 m)
{
    u32 result = 1;
    while (b != 0)
    {
        if (b & 1)
            result = static_cast<u32>(u64(a) * result % m);
        a = static_cast<u32>(u64(a) * a % m);
        b >>= 1;
    }
    return result;
}
#include <algorithm>
#include <cassert>
#include <optional>
/// Euler's criterion: returns a^((p-1)/2) mod p.
u32 legendre(u32 a, u32 p)
{
    const u64 half_order = (p - 1) / 2;
    return qpow(a, half_order, p);
}
/// Modular square root by Cipolla's algorithm.
///
/// Returns the smaller of the two roots of n modulo p, 0 when n == 0, or
/// nullopt when legendre(n, p) != 1. p is expected to be prime; the field
/// products assume p < 2^31 so that they fit in 64 bits.
optional<int> cipola(u32 n, u32 p)
{
    if (n == 0)
        return 0;
    if (legendre(n, p) != 1)
        return nullopt;
    if (p == 2)
        return 1;
    using FP2 = pair<u64, u64>;
    const u64 n_red = n % p;
    for (u32 a = 0; a < p; a++)
    {
        // w = a^2 - n (mod p), computed in 64 bits so a * a cannot wrap.
        const u32 i = static_cast<u32>((u64(a) * a % p + p - n_red) % p);
        if (legendre(i, p) != p - 1)
            continue; // need a^2 - n to be a non-residue
        // Multiplication in F_p[x] / (x^2 - w); elements are (real, coefficient of x).
        auto mul = [p, i](const FP2 &l, const FP2 &r)
        {
            auto [la, lb] = l;
            auto [ra, rb] = r;
            return FP2{(la * ra + lb * rb % p * i) % p, (lb * ra + la * rb) % p};
        };
        FP2 x = {1, 0}, u = {a, 1}; // x starts at the multiplicative identity
        for (u64 b = (u64(p) + 1) / 2; b; b /= 2)
        {
            if (b % 2 == 1)
                x = mul(x, u);
            u = mul(u, u);
        }
        return static_cast<int>(min(x.first, p - x.first));
    }
    return nullopt;
}
#include <type_traits>
/// Primary template; only the u32 specialization below is defined.
template <class T, T MOD>
struct BasicModSpace;
/// Plain (non-Montgomery) arithmetic modulo MOD in [3, 2^31): the stored
/// representation is the value itself.
template <u32 MOD>
struct BasicModSpace<u32, MOD>
{
    static_assert(2 < MOD && MOD < u32(1) << 31, "mod must in [3, 2^31)");
    using ValueT = u32;
    using TransT = u32;
    using rawU32 = true_type;
    using isMontgomery = false_type;
    enum : u32
    {
        MOD2 = MOD * 2,
    };
    constexpr static u32 mod()
    {
        return MOD;
    }
    /// Identity: values are stored untransformed.
    constexpr static TransT trans(ValueT x)
    {
        return x;
    }
    /// Adds MOD when the top bit is set, otherwise returns n unchanged.
    constexpr static u32 reduce_m(ValueT n)
    {
        if (n >> 31)
            return n + MOD;
        return n;
    }
    /// Adds 2*MOD when the top bit is set, otherwise returns n unchanged.
    constexpr static u32 reduce_2m(u32 n)
    {
        if (n >> 31)
            return n + MOD2;
        return n;
    }
    constexpr static ValueT val(TransT x)
    {
        return reduce_m(x);
    }
    constexpr static u32 add(u32 a, u32 b)
    {
        const u32 shifted = a + b - MOD;
        return reduce_m(shifted);
    }
    constexpr static u32 sub(u32 a, u32 b)
    {
        const u32 diff = a - b;
        return reduce_m(diff);
    }
    constexpr static u32 mul(u32 a, u32 b)
    {
        const u64 wide = u64(a) * b;
        return wide % MOD;
    }
    /// Residue of a signed 64-bit value.
    constexpr static u32 safe(i64 x)
    {
        return reduce_m(x % MOD);
    }
    /// Halves x modulo MOD: an odd x is made even by adding MOD first.
    constexpr static u32 shift2(u32 x)
    {
        if (x & 1)
            return (x + MOD) >> 1;
        return x >> 1;
    }
};
#include <iostream>
// Modint wrapper: holds one value in the representation chosen by Space and
// forwards every arithmetic operation to Space's static functions.
template <class Space_>
struct StaticModint
{
    using Space = Space_;
    using Self = StaticModint;
    using ValueT = typename Space::ValueT;
    using TransT = typename Space::TransT;
    using isStatic = true_type;
    using rawU32 = typename Space::rawU32;
    using isMontgomery = typename Space::isMontgomery;
    TransT v; // value in Space's internal representation
    constexpr StaticModint() = default;
    // Implicit on purpose: plain values convert through Space::trans.
    constexpr StaticModint(ValueT v_) : v(Space::trans(v_)) {}
    explicit operator ValueT() const
    {
        return val();
    }
    // Builds from a signed 64-bit value via Space::safe.
    constexpr static Self safe(i64 v)
    {
        return Self(Space::safe(v));
    }
    // Plain value as produced by Space::val.
    constexpr ValueT val() const
    {
        return Space::val(v);
    }
    // Internal representation, unconverted.
    constexpr TransT raw() const
    {
        return v;
    }
    constexpr static ValueT mod()
    {
        return Space::mod();
    }
    constexpr Self &operator+=(const Self &rhs)
    {
        v = Space::add(v, rhs.v);
        return *this;
    }
    constexpr Self &operator-=(const Self &rhs)
    {
        v = Space::sub(v, rhs.v);
        return *this;
    }
    constexpr Self &operator*=(const Self &rhs)
    {
        v = Space::mul(v, rhs.v);
        return *this;
    }
    friend constexpr inline Self operator+(const Self &lhs, const Self &rhs)
    {
        Self out = lhs;
        out += rhs;
        return out;
    }
    friend constexpr inline Self operator-(const Self &lhs, const Self &rhs)
    {
        Self out = lhs;
        out -= rhs;
        return out;
    }
    friend constexpr inline Self operator*(const Self &lhs, const Self &rhs)
    {
        Self out = lhs;
        out *= rhs;
        return out;
    }
    // Square-and-multiply exponentiation.
    constexpr Self pow(u64 n) const
    {
        Self acc(1), base(*this);
        while (n != 0)
        {
            if (n & 1)
                acc *= base;
            base *= base;
            n >>= 1;
        }
        return acc;
    }
    // Computes this^(mod - 2); that is the inverse when the modulus is prime.
    constexpr Self inv() const
    {
        return pow(Space::mod() - 2);
    }
    constexpr Self &operator/=(const Self &rhs)
    {
        *this *= rhs.inv();
        return *this;
    }
    friend constexpr inline Self operator/(const Self &lhs, const Self &rhs)
    {
        Self out = lhs;
        out /= rhs;
        return out;
    }
    // Additive inverse: a value-initialised (zero) element minus *this.
    constexpr Self operator-() const
    {
        Self zero{};
        zero -= *this;
        return zero;
    }
    // Square root via cipola(); empty when cipola reports none.
    constexpr optional<Self> sqrt() const
    {
        return cipola(val(), mod());
    }
    // Space::shift2 returns a plain value, which converts back via the constructor.
    constexpr Self shift2() const
    {
        return Space::shift2(v);
    }
    friend inline istream &operator>>(istream &is, Self &m)
    {
        i64 parsed;
        is >> parsed;
        m = Self::safe(parsed);
        return is;
    }
    friend inline ostream &operator<<(ostream &os, const Self &m)
    {
        os << m.val();
        return os;
    }
    // Equality compares plain values, not internal representations.
    friend inline bool operator==(const Self &lhs, const Self &rhs)
    {
        return lhs.val() == rhs.val();
    }
    friend inline bool operator!=(const Self &lhs, const Self &rhs)
    {
        return lhs.val() != rhs.val();
    }
};
// Convenience alias: a StaticModint backed by BasicModSpace<T, MOD>.
template <class T, T MOD>
using BasicStaticModint = StaticModint<BasicModSpace<T, MOD>>;
// Reads a signed 64-bit integer from the fast reader and reduces it into the
// modint. Uses safe(), like the istream overload, so negative or oversized
// inputs are reduced instead of being truncated to 32 bits.
template <class Space>
inline FastI &operator>>(FastI &is, StaticModint<Space> &m)
{
    i64 x;
    is >> x;
    m = StaticModint<Space>::safe(x);
    return is;
}
// Writes the plain value of the modint through the fast writer.
template <class Space>
inline FastO &operator<<(FastO &os, const StaticModint<Space> &m)
{
    const auto plain = m.val();
    os << plain;
    return os;
}
#include <type_traits>
// Classification of modint types through their nested tag types
// (isStatic, rawU32, isMontgomery), each expected to expose a bool `value`.

// The modulus is fixed at compile time.
template <class M>
concept static_modint_concept = M::isStatic::value;

// The type is tagged rawU32.
template <class M>
concept raw32_modint_concept = M::rawU32::value;

// Both of the above.
template <class M>
concept static_raw32_modint_concept = static_modint_concept<M> && raw32_modint_concept<M>;

// The type is tagged as not static.
template <class M>
concept runtime_modint_concept = !M::isStatic::value;

// The type is tagged isMontgomery.
template <class M>
concept montgomery_modint_concept = M::isMontgomery::value;

// Static but not Montgomery.
template <class M>
concept static_basic_modint_concept = !montgomery_modint_concept<M> && static_modint_concept<M>;
#include <algorithm>
#include <bit>
#include <cassert>
#include <span>
#include <vector>
namespace detail
{
// Running total of the lengths of all spans passed to ntt() and intt().
u32 ntt_size = 0;
} // namespace detail
// #include "ntt-twisted-radix-2-basic.hpp"
// #include "ntt-barrett.hpp"
#include <algorithm>
#include <bit>
#include <cassert>
#include <span>
#include <vector>
namespace detail
{
/// Twiddle tables for the radix-2 transforms.
///
/// rank2 is the number of trailing zero bits of P - 1. rt[k] / irt[k] are
/// built by repeated squaring from g^((P-1) >> rank2) and its inverse.
/// rate2[i] is rt[i+2] times the product of irt[2..i+1]; irate2 mirrors it.
/// NOTE(review): g = 3 is hard-coded; confirm it is a primitive root of P.
template <static_modint_concept ModT>
struct NttClassicalInfo
{
    using ValueT = typename ModT::ValueT;
    static constexpr ValueT P = ModT::mod();
    static constexpr ValueT g = 3;
    static constexpr int rank2 = countr_zero(P - 1);
    array<ModT, rank2 + 1> rt, irt;
    array<ModT, max<int>(0, rank2 - 1)> rate2, irate2;
    constexpr NttClassicalInfo()
    {
        rt[rank2] = ModT(g).pow((P - 1) >> rank2);
        irt[rank2] = rt[rank2].inv();
        for (int lvl = rank2; lvl > 0; --lvl)
        {
            rt[lvl - 1] = rt[lvl] * rt[lvl];
            irt[lvl - 1] = irt[lvl] * irt[lvl];
        }
        ModT acc = 1, iacc = 1;
        for (int lvl = 2; lvl <= rank2; ++lvl)
        {
            rate2[lvl - 2] = acc * rt[lvl];
            irate2[lvl - 2] = iacc * irt[lvl];
            acc *= irt[lvl];
            iacc *= rt[lvl];
        }
    }
};
/// In-place radix-2 forward transform (the half-length l shrinks from n/2 to 1).
/// f.size() is expected to be a power of two no larger than 2^rank2.
template <static_modint_concept ModT>
static void ntt_classical_basic(span<ModT> f)
{ // dif
    static constexpr NttClassicalInfo<ModT> info;
    int n = f.size();
    for (int l = n / 2; l > 0; l /= 2)
    {
        ModT r = 1;
        for (int i = 0, k = 0; i < n; i += l * 2, ++k)
        {
            for (int j = 0; j < l; ++j)
            {
                ModT x = f[i + j], y = f[i + j + l] * r;
                f[i + j] = x + y;
                f[i + j + l] = x - y;
            }
            // The twiddle is only needed by a following block. Skipping the
            // update after the last one keeps the index inside rate2 even
            // when n == 2^rank2.
            if (i + l * 2 < n)
                r *= info.rate2[countr_one<u32>(k)];
        }
    }
}
/// In-place radix-2 inverse transform (the half-length l grows from 1 to n/2),
/// followed by scaling every element by 1/n.
/// f.size() is expected to be a power of two no larger than 2^rank2.
template <static_modint_concept ModT>
static void intt_classical_basic(span<ModT> f)
{ // dit
    static constexpr NttClassicalInfo<ModT> info;
    int n = f.size();
    for (int l = 1; l < n; l *= 2)
    {
        ModT r = 1;
        for (int i = 0, k = 0; i < n; i += l * 2, ++k)
        {
            for (int j = 0; j < l; ++j)
            {
                ModT x = f[i + j], y = f[i + j + l];
                f[i + j] = x + y;
                f[i + j + l] = r * (x - y);
            }
            // Skip the update after the last block so the index stays inside
            // irate2 even when n == 2^rank2.
            if (i + l * 2 < n)
                r *= info.irate2[countr_one<u32>(k)];
        }
    }
    const ModT ivn = ModT(n).inv();
    for (int i = 0; i < n; i++)
        f[i] *= ivn;
}
} // namespace detail
#include <algorithm>
#include <bit>
#include <cassert>
#include <span>
#include <vector>
namespace detail
{
/// Twiddle tables for the radix-4 transforms.
///
/// rt / irt / rate2 / irate2 are built as in the radix-2 tables. rate3[i] is
/// rt[i+3] times the product of irt[3..i+2]; irate3 mirrors it.
/// NOTE(review): g = 3 is hard-coded; confirm it is a primitive root of P.
template <static_modint_concept ModT>
struct NttClassicalInfo4
{
    using ValueT = typename ModT::ValueT;
    static constexpr ValueT P = ModT::mod();
    static constexpr ValueT g = 3;
    static constexpr int rank2 = countr_zero(P - 1);
    array<ModT, rank2 + 1> rt, irt;
    array<ModT, max<int>(0, rank2 - 1)> rate2, irate2;
    array<ModT, max<int>(0, rank2 - 2)> rate3, irate3;
    constexpr NttClassicalInfo4()
    {
        rt[rank2] = ModT(g).pow((P - 1) >> rank2);
        irt[rank2] = rt[rank2].inv();
        for (int lvl = rank2; lvl > 0; --lvl)
        {
            rt[lvl - 1] = rt[lvl] * rt[lvl];
            irt[lvl - 1] = irt[lvl] * irt[lvl];
        }
        ModT acc = 1, iacc = 1;
        for (int idx = 0; idx < rate2.size(); ++idx)
        {
            const ModT w = rt[idx + 2], iw = irt[idx + 2];
            rate2[idx] = acc * w;
            irate2[idx] = iacc * iw;
            acc *= iw;
            iacc *= w;
        }
        acc = 1, iacc = 1;
        for (int idx = 0; idx < rate3.size(); ++idx)
        {
            const ModT w = rt[idx + 3], iw = irt[idx + 3];
            rate3[idx] = acc * w;
            irate3[idx] = iacc * iw;
            acc *= iw;
            iacc *= w;
        }
    }
};
/// In-place radix-4 forward transform. When log2(n) is odd, one twiddle-free
/// radix-2 pass runs first so the remaining levels pair up.
/// f.size() is expected to be a power of two no larger than 2^rank2.
template <static_modint_concept ModT>
static void ntt_classical_basic4(span<ModT> f)
{ // dif
    static constexpr NttClassicalInfo4<ModT> info;
    int n = f.size(), l = n / 2, n_4b = countr_zero<u32>(n) & 1;
    if (n_4b)
    {
        for (int j = 0; j < l; ++j)
        {
            ModT x = f[j], y = f[j + l];
            f[j] = x + y;
            f[j + l] = x - y;
        }
        l >>= 1;
    }
    for (l /= 2; l >= 1; l /= 4)
    {
        ModT r = 1, img = info.rt[2]; // rt[2] squares to rt[1]
        for (int i = 0, k = 0; i < n; i += l * 4, ++k)
        {
            ModT r2 = r * r, r3 = r2 * r;
            for (int j = 0; j < l; ++j)
            {
                ModT x0 = f[i + j + 0 * l];
                ModT x1 = f[i + j + 1 * l] * r;
                ModT x2 = f[i + j + 2 * l] * r2;
                ModT x3 = f[i + j + 3 * l] * r3;
                ModT x1x3 = (x1 - x3) * img;
                f[i + j + 0 * l] = x0 + x2 + x1 + x3;
                f[i + j + 1 * l] = x0 + x2 - x1 - x3;
                f[i + j + 2 * l] = x0 - x2 + x1x3;
                f[i + j + 3 * l] = x0 - x2 - x1x3;
            }
            // Skip the update after the last block so the index stays inside
            // rate3 even when n == 2^rank2.
            if (i + l * 4 < n)
                r *= info.rate3[countr_one<u32>(k)];
        }
    }
}
/// In-place radix-4 inverse transform. When log2(n) is odd, a final
/// twiddle-free radix-2 pass handles the leftover level. Every element is
/// then scaled by 1/n.
/// f.size() is expected to be a power of two no larger than 2^rank2.
template <static_modint_concept ModT>
static void intt_classical_basic4(span<ModT> f)
{ // dit
    static constexpr NttClassicalInfo4<ModT> info;
    int n = f.size(), l = 1, n_4b = countr_zero<u32>(n) & 1;
    for (; l < (n_4b ? n / 2 : n); l *= 4)
    {
        ModT r = 1, img = info.irt[2]; // irt[2] squares to irt[1]
        for (int i = 0, k = 0; i < n; i += l * 4, ++k)
        {
            ModT r2 = r * r, r3 = r2 * r;
            for (int j = 0; j < l; ++j)
            {
                ModT x0 = f[i + j + 0 * l];
                ModT x1 = f[i + j + 1 * l];
                ModT x2 = f[i + j + 2 * l];
                ModT x3 = f[i + j + 3 * l];
                ModT x2x3 = (x2 - x3) * img;
                f[i + j + 0 * l] = x0 + x1 + x2 + x3;
                f[i + j + 1 * l] = (x0 - x1 + x2x3) * r;
                f[i + j + 2 * l] = (x0 + x1 - x2 - x3) * r2;
                f[i + j + 3 * l] = (x0 - x1 - x2x3) * r3;
            }
            // Skip the update after the last block so the index stays inside
            // irate3 even when n == 2^rank2.
            if (i + l * 4 < n)
                r *= info.irate3[countr_one<u32>(k)];
        }
    }
    if (n_4b)
    {
        for (int j = 0; j < l; ++j)
        {
            ModT x = f[j], y = f[j + l];
            f[j] = x + y;
            f[j + l] = x - y;
        }
    }
    const ModT ivn = ModT(n).inv();
    for (int i = 0; i < n; i++)
        f[i] *= ivn;
}
} // namespace detail
// #include "ntt-twisted-radix-2-avx.hpp"
#include <algorithm>
#include <bit>
#include <cassert>
#include <span>
#include <vector>
#include <type_traits>
// https://judge.yosupo.jp/submission/92714
#pragma GCC target("avx2")
#include <immintrin.h>
#include <array>
// Thin named wrappers over AVX2 intrinsics, grouped by the lane layout they
// operate on.
namespace simd
{
using I256 = __m256i;
// Whole-register operations.
namespace i256
{
inline I256 loadu(const I256 *p) { return _mm256_loadu_si256(p); }
// Aligned variants: p must be 32-byte aligned.
inline I256 load(const I256 *p) { return _mm256_load_si256(p); }
inline void store(I256 *p, const I256 &v) { _mm256_store_si256(p, v); }
inline void storeu(I256 *p, const I256 &v) { _mm256_storeu_si256(p, v); }
// Copies the register into an array of sizeof(I256) / sizeof(T) elements.
template <class T>
inline auto to_array(const I256 &v)
{
    constexpr u32 sizeT = sizeof(T);
    static_assert(sizeof(I256) % sizeT == 0);
    // The array must hold the full 32 bytes that the store below writes.
    alignas(32) array<T, sizeof(I256) / sizeT> arr;
    _mm256_store_si256((I256 *)arr.data(), v);
    return arr;
}
inline I256 bit_and(const I256 &a, const I256 &b) { return _mm256_and_si256(a, b); }
}
// Two 128-bit halves.
namespace i128x2
{
template <int imm>
inline I256 permute(const I256 &a, const I256 &b)
{
    return _mm256_permute2x128_si256(a, b, imm);
}
template <int imm>
inline I256 shuffle(const I256 &a)
{
    return permute<imm>(a, a);
}
} // namespace i128x2
// Four 64-bit lanes.
namespace i64x4
{
inline I256 add(const I256 &a, const I256 &b)
{
    return _mm256_add_epi64(a, b);
}
} // namespace i64x4
// Eight signed 32-bit lanes.
namespace i32x8
{
inline I256 from(int v)
{
    return _mm256_set1_epi32(v);
}
inline I256 add(const I256 &a, const I256 &b)
{
    return _mm256_add_epi32(a, b);
}
inline I256 sub(const I256 &a, const I256 &b)
{
    return _mm256_sub_epi32(a, b);
}
// Signed 32x32 -> 64 multiply of the even-indexed 32-bit lanes.
inline I256 mul(const I256 &a, const I256 &b) { return _mm256_mul_epi32(a, b); }
template <int imm>
inline I256 shuffle(const I256 &a) { return _mm256_shuffle_epi32(a, imm); }
template <int imm>
inline I256 blend(const I256 &a, const I256 &b) { return _mm256_blend_epi32(a, b, imm); }
inline I256 zero() { return _mm256_setzero_si256(); }
// All-ones in every lane where a is negative, zero elsewhere.
inline I256 sign(const I256 &a) { return _mm256_cmpgt_epi32(zero(), a); }
// 64-bit products of the even lanes and, after moving the odd lanes into
// even positions, of the odd lanes.
inline pair<I256, I256> mul_0246_1357(const I256 &a, const I256 &b)
{
    auto x0246 = mul(a, b);
    auto x1357 = mul(shuffle<0b11110101>(a), shuffle<0b11110101>(b));
    return {x0246, x1357};
}
inline I256 abs(const I256 &a) { return _mm256_abs_epi32(a); }
}
// Eight unsigned 32-bit lanes.
namespace u32x8
{
// Unsigned 32x32 -> 64 multiply of the even-indexed 32-bit lanes.
inline I256 mul(const I256 &a, const I256 &b) { return _mm256_mul_epu32(a, b); }
}
}
namespace simd
{
/// Eight Montgomery-form residues of ModT packed into one 256-bit register.
/// Lane arithmetic mirrors ModT::Space: add/sub fold back with a conditional
/// +2*MOD, and mul performs a Montgomery reduction.
template <class ModT>
struct M32x8
{
    I256 v;
    M32x8() : v() {}
    M32x8(const I256 &a) : v(a) {}
    template <class S>
    M32x8(const M32x8<S> &a) : v(a.v) {}
    /// Packs eight 4-byte elements. The source array only has the alignment
    /// of its element type, so an unaligned load is required here.
    template <class U32>
    M32x8(const array<U32, 8> &a)
    {
        static_assert(sizeof(U32) == 4);
        v = i256::loadu((const I256 *)a.data());
    }
    template <bool aligned = false>
    static M32x8 load(const I256 *p)
    {
        M32x8 r;
        if constexpr (aligned)
        {
            r = i256::load(p);
        }
        else
        {
            r = i256::loadu(p);
        }
        return r;
    }
    /// Broadcasts a raw 32-bit pattern to every lane.
    static M32x8 from(int v)
    {
        return i32x8::from(v);
    }
    /// Broadcasts the internal representation of a modint to every lane.
    static M32x8 from(ModT v)
    {
        return from(v.raw());
    }
    inline static I256 Rx8 = i32x8::from(ModT::Space::R);
    inline static I256 IRx8 = i32x8::from(ModT::Space::IR);
    inline static I256 MOD2x8 = i32x8::from(ModT::Space::MOD2);
    inline static I256 MODx8 = i32x8::from(ModT::Space::mod());
    M32x8 &operator+=(const M32x8 &rhs)
    {
        v = i32x8::add(v, rhs.v);
        v = i32x8::sub(v, MOD2x8);
        // Add 2*MOD back in the lanes that went negative.
        I256 sign = i32x8::sign(v);
        v = i32x8::add(v, i256::bit_and(sign, MOD2x8));
        return *this;
    }
    M32x8 &operator-=(const M32x8 &rhs)
    {
        v = i32x8::sub(v, rhs.v);
        I256 sign = i32x8::sign(v);
        v = i32x8::add(v, i256::bit_and(sign, MOD2x8));
        return *this;
    }
    /// Montgomery-reduces the 64-bit products of the even and odd lanes and
    /// interleaves the high halves back into eight 32-bit lanes.
    static I256 reduce(const I256 &x0246, const I256 &x1357)
    {
        auto km0246 = u32x8::mul(u32x8::mul(x0246, IRx8), MODx8);
        auto km1357 = u32x8::mul(u32x8::mul(x1357, IRx8), MODx8);
        auto z0246 = i64x4::add(x0246, km0246);
        z0246 = i32x8::shuffle<0b11110101>(z0246); // high halves -> both positions
        auto z1357 = i64x4::add(x1357, km1357);
        z1357 = i32x8::shuffle<0b11110101>(z1357);
        return i32x8::blend<0b10101010>(z0246, z1357);
    }
    M32x8 &operator*=(const M32x8 &rhs)
    {
        auto [x0246, x1357] = i32x8::mul_0246_1357(v, rhs.v);
        v = reduce(x0246, x1357);
        return *this;
    }
    friend M32x8 operator+(const M32x8 &lhs, const M32x8 &rhs)
    {
        return M32x8(lhs) += rhs;
    }
    friend M32x8 operator-(const M32x8 &lhs, const M32x8 &rhs)
    {
        return M32x8(lhs) -= rhs;
    }
    friend M32x8 operator*(const M32x8 &lhs, const M32x8 &rhs)
    {
        return M32x8(lhs) *= rhs;
    }
    I256 raw() const
    {
        return v;
    }
    /// Replaces each lane selected by `imm` with |v - 2*MOD|; other lanes are
    /// returned unchanged.
    template <int imm>
    M32x8 neg() const
    {
        auto m2 = i32x8::blend<imm>(i32x8::zero(), MOD2x8);
        return i32x8::abs(i32x8::sub(v, m2));
    }
    template <bool aligned = false>
    void store(I256 *p)
    {
        if constexpr (aligned)
        {
            i256::store(p, v);
        }
        else
        {
            i256::storeu(p, v);
        }
    }
    auto to_array() const
    {
        return i256::to_array<u32>(v);
    }
    /// Permutes 32-bit lanes within each 128-bit half.
    template <int imm>
    M32x8 shuffle() const
    {
        return i32x8::shuffle<imm>(v);
    }
    /// Permutes the two 128-bit halves.
    template <int imm>
    M32x8 shufflex4() const
    {
        return i128x2::shuffle<imm>(v);
    }
};
} // namespace simd
namespace detail
{
/// Twiddle tables for the AVX radix-2 transforms.
///
/// rt / irt / rate2 / irate2 are built as in the scalar tables. rate4[i] is
/// rt[i+4] times the product of irt[4..i+3]. rate2x8 broadcasts rate2[i] to
/// all lanes; rate4ix8[i] holds rate4[i]^0 .. rate4[i]^7, one power per lane.
/// NOTE(review): g = 3 is hard-coded; confirm it is a primitive root of P.
template <montgomery_modint_concept ModT>
struct NttClassicalInfoAvx
{
    using X8 = simd::M32x8<ModT>;
    using ValueT = typename ModT::ValueT;
    static constexpr ValueT P = ModT::mod();
    static constexpr ValueT g = 3;
    static constexpr int rank2 = countr_zero(P - 1);
    array<ModT, rank2 + 1> rt, irt;
    array<ModT, max<int>(0, rank2 - 1)> rate2, irate2;
    array<ModT, max<int>(0, rank2 - 3)> rate4, irate4;
    array<X8, max<int>(0, rank2 - 1)> rate2x8, irate2x8;
    array<X8, max<int>(0, rank2 - 3)> rate4ix8, irate4ix8;
    constexpr NttClassicalInfoAvx()
    {
        rt[rank2] = ModT(g).pow((P - 1) >> rank2);
        irt[rank2] = rt[rank2].inv();
        for (int lvl = rank2; lvl > 0; --lvl)
        {
            rt[lvl - 1] = rt[lvl] * rt[lvl];
            irt[lvl - 1] = irt[lvl] * irt[lvl];
        }
        ModT acc = 1, iacc = 1;
        for (int idx = 0; idx < rate2.size(); ++idx)
        {
            const ModT w = rt[idx + 2], iw = irt[idx + 2];
            rate2[idx] = acc * w;
            irate2[idx] = iacc * iw;
            acc *= iw;
            iacc *= w;
            rate2x8[idx] = X8::from(rate2[idx]);
            irate2x8[idx] = X8::from(irate2[idx]);
        }
        acc = 1, iacc = 1;
        for (int idx = 0; idx < rate4.size(); ++idx)
        {
            const ModT w = rt[idx + 4], iw = irt[idx + 4];
            rate4[idx] = acc * w;
            irate4[idx] = iacc * iw;
            acc *= iw;
            iacc *= w;
            array<ModT, 8> fwd, bwd;
            for (int e = 0; e < 8; ++e)
            {
                fwd[e] = rate4[idx].pow(e);
                bwd[e] = irate4[idx].pow(e);
            }
            rate4ix8[idx] = fwd;
            irate4ix8[idx] = bwd;
        }
    }
    /// Per-lane twiddles for the in-register levels. L == 2 puts rt[2] in
    /// lanes 3 and 7; L == 4 puts rt[3]^1..rt[3]^3 in lanes 5..7. All other
    /// lanes are 1.
    template <int L>
    X8 rt_small()
    {
        array<ModT, 8> lanes;
        lanes.fill(1);
        if constexpr (L == 2)
        {
            lanes[3] = lanes[7] = rt[2];
        }
        else if constexpr (L == 4)
        {
            for (int lane = 5; lane < 8; ++lane)
                lanes[lane] = lanes[lane - 1] * rt[3];
        }
        return lanes;
    }
    /// Same layout as rt_small, built from the inverse roots.
    template <int L>
    X8 irt_small()
    {
        array<ModT, 8> lanes;
        lanes.fill(1);
        if constexpr (L == 2)
        {
            lanes[3] = lanes[7] = irt[2];
        }
        else if constexpr (L == 4)
        {
            for (int lane = 5; lane < 8; ++lane)
                lanes[lane] = lanes[lane - 1] * irt[3];
        }
        return lanes;
    }
};
/// AVX radix-2 forward transform. The upper levels run butterflies on whole
/// 8-lane vectors; the last three levels are done inside each vector with
/// lane negation and shuffles.
/// `aligned` selects aligned loads/stores, so the data must then be 32-byte
/// aligned. f0.size() must be a multiple of 16 and is expected to be a power
/// of two no larger than 2^rank2.
template <montgomery_modint_concept ModT, bool aligned>
static void ntt_classical_avx(span<ModT> f0)
{ // dif
    using X8 = simd::M32x8<ModT>;
    static NttClassicalInfoAvx<ModT> info;
    int n8 = f0.size(), n = n8 / 8;
    assert(n8 % 16 == 0);
    span<simd::I256> f{(simd::I256 *)f0.data(), u32(n)};
    static X8 rt2 = info.template rt_small<2>();
    static X8 rt4 = info.template rt_small<4>();
    for (int l = n / 2; l >= 1 * 1; l /= 2)
    {
        X8 r = X8::from(ModT(1));
        for (int i = 0, k = 0; i < n; i += l * 2, ++k)
        {
            for (int j = 0; j < l; ++j)
            {
                X8 fx = X8::template load<aligned>(&f[i + j]);
                X8 fy = X8::template load<aligned>(&f[i + j + l]) * r;
                X8 rx = fx + fy;
                X8 ry = fx - fy;
                rx.template store<aligned>(&f[i + j]);
                ry.template store<aligned>(&f[i + j + l]);
            }
            r *= info.rate2x8[countr_one<u32>(k)];
        }
    }
    X8 rti = X8::from(ModT(1));
    for (int i = 0; i < n; ++i)
    {
        X8 fi = X8::template load<aligned>(&f[i]);
        fi *= rti;
        // Butterfly between the two 128-bit halves.
        fi = fi.template neg<0b11110000>() + fi.template shufflex4<0b01>();
        fi *= rt4;
        // Butterfly between lane pairs two apart.
        fi = fi.template neg<0b11001100>() + fi.template shuffle<0b01001110>();
        fi *= rt2;
        // Butterfly between adjacent lanes.
        fi = fi.template neg<0b10101010>() + fi.template shuffle<0b10110001>();
        fi.template store<aligned>(&f[i]);
        // Skip the update after the last vector so the index stays inside
        // rate4ix8 even when the transform length equals 2^rank2.
        if (i + 1 < n)
            rti *= info.rate4ix8[countr_one<u32>(i)];
    }
}
/// AVX radix-2 inverse transform: the three in-register levels first, then
/// butterflies on whole vectors, then scaling by the inverse of the length.
/// `aligned` selects aligned loads/stores, so the data must then be 32-byte
/// aligned. f0.size() must be a multiple of 16 and is expected to be a power
/// of two no larger than 2^rank2.
template <montgomery_modint_concept ModT, bool aligned>
static void intt_classical_avx(span<ModT> f0)
{ // dit
    using X8 = simd::M32x8<ModT>;
    static NttClassicalInfoAvx<ModT> info;
    int n8 = f0.size(), n = n8 / 8;
    assert(n8 % 16 == 0);
    span<simd::I256> f{(simd::I256 *)f0.data(), u32(n)};
    static X8 rt2 = info.template irt_small<2>();
    static X8 rt4 = info.template irt_small<4>();
    X8 rti = X8::from(ModT(1));
    for (int i = 0; i < n; ++i)
    {
        X8 fi = X8::template load<aligned>(&f[i]);
        // Butterfly between adjacent lanes.
        fi = fi.template neg<0b10101010>() + fi.template shuffle<0b10110001>();
        fi *= rt2;
        // Butterfly between lane pairs two apart.
        fi = fi.template neg<0b11001100>() + fi.template shuffle<0b01001110>();
        fi *= rt4;
        // Butterfly between the two 128-bit halves.
        fi = fi.template neg<0b11110000>() + fi.template shufflex4<0b01>();
        fi *= rti;
        fi.template store<aligned>(&f[i]);
        // Skip the update after the last vector so the index stays inside
        // irate4ix8 even when the transform length equals 2^rank2.
        if (i + 1 < n)
            rti *= info.irate4ix8[countr_one<u32>(i)];
    }
    for (i64 l = 1; l < n; l *= 2)
    {
        X8 r = X8::from(ModT(1));
        for (int i = 0, k = 0; i < n; i += l * 2, ++k)
        {
            for (int j = 0; j < l; ++j)
            {
                X8 fx = X8::template load<aligned>(&f[i + j]);
                X8 fy = X8::template load<aligned>(&f[i + j + l]);
                X8 rx = fx + fy;
                X8 ry = r * (fx - fy);
                rx.template store<aligned>(&f[i + j]);
                ry.template store<aligned>(&f[i + j + l]);
            }
            r *= info.irate2x8[countr_one<u32>(k)];
        }
    }
    X8 ivn8 = X8::from(ModT(n8).inv());
    for (int i = 0; i < n; ++i)
    {
        X8 fi = X8::template load<aligned>(&f[i]);
        fi *= ivn8;
        fi.template store<aligned>(&f[i]);
    }
}
} // namespace detail
#include <algorithm>
#include <bit>
#include <cassert>
#include <span>
#include <vector>
#include <iostream>
namespace detail
{
/// Twiddle tables for the AVX radix-4 transforms.
///
/// rt / irt / rate2 / rate3 / rate4 (and inverses) are built as in the other
/// tables: rateK[i] is rt[i+K] times the product of irt[K..i+K-1]. rate2x8
/// and rate3x8 broadcast one entry to all lanes; rate4ix8[i] holds
/// rate4[i]^0 .. rate4[i]^7, one power per lane.
/// NOTE(review): g = 3 is hard-coded; confirm it is a primitive root of P.
template <montgomery_modint_concept ModT>
struct NttClassicalInfoAvx4
{
    using X8 = simd::M32x8<ModT>;
    using ValueT = typename ModT::ValueT;
    static constexpr ValueT P = ModT::mod();
    static constexpr ValueT g = 3;
    static constexpr int rank2 = countr_zero(P - 1);
    array<ModT, rank2 + 1> rt, irt;
    array<ModT, max<int>(0, rank2 - 1)> rate2, irate2;
    array<ModT, max<int>(0, rank2 - 2)> rate3, irate3;
    array<ModT, max<int>(0, rank2 - 3)> rate4, irate4;
    array<X8, max<int>(0, rank2 - 1)> rate2x8, irate2x8;
    array<X8, max<int>(0, rank2 - 2)> rate3x8, irate3x8;
    array<X8, max<int>(0, rank2 - 3)> rate4ix8, irate4ix8;
    constexpr NttClassicalInfoAvx4()
    {
        rt[rank2] = ModT(g).pow((P - 1) >> rank2);
        irt[rank2] = rt[rank2].inv();
        for (int lvl = rank2; lvl > 0; --lvl)
        {
            rt[lvl - 1] = rt[lvl] * rt[lvl];
            irt[lvl - 1] = irt[lvl] * irt[lvl];
        }
        ModT acc = 1, iacc = 1;
        for (int idx = 0; idx < rate2.size(); ++idx)
        {
            const ModT w = rt[idx + 2], iw = irt[idx + 2];
            rate2[idx] = acc * w;
            irate2[idx] = iacc * iw;
            acc *= iw;
            iacc *= w;
            rate2x8[idx] = X8::from(rate2[idx]);
            irate2x8[idx] = X8::from(irate2[idx]);
        }
        acc = 1, iacc = 1;
        for (int idx = 0; idx < rate3.size(); ++idx)
        {
            const ModT w = rt[idx + 3], iw = irt[idx + 3];
            rate3[idx] = acc * w;
            irate3[idx] = iacc * iw;
            acc *= iw;
            iacc *= w;
            rate3x8[idx] = X8::from(rate3[idx]);
            irate3x8[idx] = X8::from(irate3[idx]);
        }
        acc = 1, iacc = 1;
        for (int idx = 0; idx < rate4.size(); ++idx)
        {
            const ModT w = rt[idx + 4], iw = irt[idx + 4];
            rate4[idx] = acc * w;
            irate4[idx] = iacc * iw;
            acc *= iw;
            iacc *= w;
            array<ModT, 8> fwd, bwd;
            for (int e = 0; e < 8; ++e)
            {
                fwd[e] = rate4[idx].pow(e);
                bwd[e] = irate4[idx].pow(e);
            }
            rate4ix8[idx] = fwd;
            irate4ix8[idx] = bwd;
        }
    }
    /// Per-lane twiddles for the in-register levels. L == 2 puts rt[2] in
    /// lanes 3 and 7; L == 4 puts rt[3]^1..rt[3]^3 in lanes 5..7. All other
    /// lanes are 1.
    template <int L>
    X8 rt_small()
    {
        array<ModT, 8> lanes;
        lanes.fill(1);
        if constexpr (L == 2)
        {
            lanes[3] = lanes[7] = rt[2];
        }
        else if constexpr (L == 4)
        {
            for (int lane = 5; lane < 8; ++lane)
                lanes[lane] = lanes[lane - 1] * rt[3];
        }
        return lanes;
    }
    /// Same layout as rt_small, built from the inverse roots.
    template <int L>
    X8 irt_small()
    {
        array<ModT, 8> lanes;
        lanes.fill(1);
        if constexpr (L == 2)
        {
            lanes[3] = lanes[7] = irt[2];
        }
        else if constexpr (L == 4)
        {
            for (int lane = 5; lane < 8; ++lane)
                lanes[lane] = lanes[lane - 1] * irt[3];
        }
        return lanes;
    }
};
/// AVX radix-4 forward transform. An optional twiddle-free radix-2 pass (when
/// the vector count has an odd log2) is followed by radix-4 passes on whole
/// 8-lane vectors; the last three levels are done inside each vector.
/// `aligned` selects aligned loads/stores, so the data must then be 32-byte
/// aligned. f0.size() must be a multiple of 16 and is expected to be a power
/// of two no larger than 2^rank2.
template <montgomery_modint_concept ModT, bool aligned>
static void ntt_classical_avx4(span<ModT> f0)
{ // dif
    using X8 = simd::M32x8<ModT>;
    static NttClassicalInfoAvx4<ModT> info;
    int n8 = f0.size(), n = n8 / 8, l = n / 2, n_4b = countr_zero<u32>(n) & 1;
    assert(n8 % 16 == 0);
    span<simd::I256> f{(simd::I256 *)f0.data(), u32(n)};
    static X8 rt2 = info.template rt_small<2>();
    static X8 rt4 = info.template rt_small<4>();
    if (n_4b)
    {
        for (int j = 0; j < l; ++j)
        {
            X8 fx = X8::template load<aligned>(&f[j]);
            X8 fy = X8::template load<aligned>(&f[j + l]);
            X8 rx = fx + fy;
            X8 ry = fx - fy;
            rx.template store<aligned>(&f[j]);
            ry.template store<aligned>(&f[j + l]);
        }
        l /= 2;
    }
    for (l /= 2; l >= 1; l /= 4)
    {
        X8 r = X8::from(ModT(1)), img = X8::from(info.rt[2]);
        for (int i = 0, k = 0; i < n; i += l * 4, ++k)
        {
            X8 r2 = r * r, r3 = r2 * r;
            for (int j = 0; j < l; ++j)
            {
                X8 x0 = X8::template load<aligned>(&f[i + j + 0 * l]);
                X8 x1 = X8::template load<aligned>(&f[i + j + 1 * l]) * r;
                X8 x2 = X8::template load<aligned>(&f[i + j + 2 * l]) * r2;
                X8 x3 = X8::template load<aligned>(&f[i + j + 3 * l]) * r3;
                X8 x1x3 = (x1 - x3) * img;
                X8 y0 = x0 + x2 + x1 + x3;
                X8 y1 = x0 + x2 - x1 - x3;
                X8 y2 = x0 - x2 + x1x3;
                X8 y3 = x0 - x2 - x1x3;
                y0.template store<aligned>(&f[i + j + 0 * l]);
                y1.template store<aligned>(&f[i + j + 1 * l]);
                y2.template store<aligned>(&f[i + j + 2 * l]);
                y3.template store<aligned>(&f[i + j + 3 * l]);
            }
            r *= info.rate3x8[countr_one<u32>(k)];
        }
    }
    X8 rti = X8::from(ModT(1));
    for (int i = 0; i < n; ++i)
    {
        X8 fi = X8::template load<aligned>(&f[i]);
        fi *= rti;
        // Butterfly between the two 128-bit halves.
        fi = fi.template neg<0b11110000>() + fi.template shufflex4<0b01>();
        fi *= rt4;
        // Butterfly between lane pairs two apart.
        fi = fi.template neg<0b11001100>() + fi.template shuffle<0b01001110>();
        fi *= rt2;
        // Butterfly between adjacent lanes.
        fi = fi.template neg<0b10101010>() + fi.template shuffle<0b10110001>();
        fi.template store<aligned>(&f[i]);
        // Skip the update after the last vector so the index stays inside
        // rate4ix8 even when the transform length equals 2^rank2.
        if (i + 1 < n)
            rti *= info.rate4ix8[countr_one<u32>(i)];
    }
}
/// AVX radix-4 inverse transform: the three in-register levels first, then
/// radix-4 passes on whole vectors, an optional final twiddle-free radix-2
/// pass, and scaling by the inverse of the length.
/// `aligned` selects aligned loads/stores, so the data must then be 32-byte
/// aligned. f0.size() must be a multiple of 16 and is expected to be a power
/// of two no larger than 2^rank2.
template <montgomery_modint_concept ModT, bool aligned>
static void intt_classical_avx4(span<ModT> f0)
{ // dit
    using X8 = simd::M32x8<ModT>;
    static NttClassicalInfoAvx4<ModT> info;
    int n8 = f0.size(), n = n8 / 8, l = 1, n_4b = countr_zero<u32>(n) & 1;
    assert(n8 % 16 == 0);
    span<simd::I256> f{(simd::I256 *)f0.data(), u32(n)};
    static X8 rt2 = info.template irt_small<2>();
    static X8 rt4 = info.template irt_small<4>();
    X8 rti = X8::from(ModT(1));
    for (int i = 0; i < n; ++i)
    {
        X8 fi = X8::template load<aligned>(&f[i]);
        // Butterfly between adjacent lanes.
        fi = fi.template neg<0b10101010>() + fi.template shuffle<0b10110001>();
        fi *= rt2;
        // Butterfly between lane pairs two apart.
        fi = fi.template neg<0b11001100>() + fi.template shuffle<0b01001110>();
        fi *= rt4;
        // Butterfly between the two 128-bit halves.
        fi = fi.template neg<0b11110000>() + fi.template shufflex4<0b01>();
        fi *= rti;
        fi.template store<aligned>(&f[i]);
        // Skip the update after the last vector so the index stays inside
        // irate4ix8 even when the transform length equals 2^rank2.
        if (i + 1 < n)
            rti *= info.irate4ix8[countr_one<u32>(i)];
    }
    for (; l < (n_4b ? n / 2 : n); l *= 4)
    {
        X8 r = X8::from(ModT(1)), img = X8::from(info.irt[2]);
        for (int i = 0, k = 0; i < n; i += l * 4, ++k)
        {
            X8 r2 = r * r, r3 = r2 * r;
            for (int j = 0; j < l; ++j)
            {
                X8 x0 = X8::template load<aligned>(&f[i + j + 0 * l]);
                X8 x1 = X8::template load<aligned>(&f[i + j + 1 * l]);
                X8 x2 = X8::template load<aligned>(&f[i + j + 2 * l]);
                X8 x3 = X8::template load<aligned>(&f[i + j + 3 * l]);
                X8 x2x3 = (x2 - x3) * img;
                X8 y0 = x0 + x1 + x2 + x3;
                X8 y1 = (x0 - x1 + x2x3) * r;
                X8 y2 = (x0 + x1 - x2 - x3) * r2;
                X8 y3 = (x0 - x1 - x2x3) * r3;
                y0.template store<aligned>(&f[i + j + 0 * l]);
                y1.template store<aligned>(&f[i + j + 1 * l]);
                y2.template store<aligned>(&f[i + j + 2 * l]);
                y3.template store<aligned>(&f[i + j + 3 * l]);
            }
            r *= info.irate3x8[countr_one<u32>(k)];
        }
    }
    if (n_4b)
    {
        for (int j = 0; j < l; ++j)
        {
            X8 fx = X8::template load<aligned>(&f[j]);
            X8 fy = X8::template load<aligned>(&f[j + l]);
            X8 rx = fx + fy;
            X8 ry = fx - fy;
            rx.template store<aligned>(&f[j]);
            ry.template store<aligned>(&f[j + l]);
        }
    }
    X8 ivn8 = X8::from(ModT(n8).inv());
    for (int i = 0; i < n; ++i)
    {
        X8 fi = X8::template load<aligned>(&f[i]);
        fi *= ivn8;
        fi.template store<aligned>(&f[i]);
    }
}
}
// Forward-transform dispatcher. Non-Montgomery modints always take the scalar
// radix-4 path. Montgomery modints take it only below 16 elements; otherwise
// the AVX radix-4 path runs, in its aligned variant when the data pointer is
// a multiple of 32.
template <static_modint_concept ModT>
void ntt_classical(span<ModT> f)
{
    if constexpr (!montgomery_modint_concept<ModT>)
    {
        detail::ntt_classical_basic4(f);
    }
    else
    {
        if (f.size() < 16)
        {
            detail::ntt_classical_basic4(f);
            return;
        }
        const bool misaligned = (u64(f.data()) & 0x1f) != 0;
        if (misaligned)
            detail::ntt_classical_avx4<ModT, false>(f);
        else
            detail::ntt_classical_avx4<ModT, true>(f);
    }
}
// Inverse-transform dispatcher. Non-Montgomery modints always take the scalar
// radix-4 path. Montgomery modints take it only below 16 elements; otherwise
// the AVX radix-4 path runs, in its aligned variant when the data pointer is
// a multiple of 32.
template <static_modint_concept ModT>
void intt_classical(span<ModT> f)
{
    if constexpr (!montgomery_modint_concept<ModT>)
    {
        detail::intt_classical_basic4(f);
    }
    else
    {
        if (f.size() < 16)
        {
            detail::intt_classical_basic4(f);
            return;
        }
        const bool misaligned = (u64(f.data()) & 0x1f) != 0;
        if (misaligned)
            detail::intt_classical_avx4<ModT, false>(f);
        else
            detail::intt_classical_avx4<ModT, true>(f);
    }
}
// Public forward transform: asserts a power-of-two length, adds the length to
// detail::ntt_size, then runs the classical transform in place.
template <static_modint_concept ModT>
void ntt(span<ModT> f)
{
    const auto len = f.size();
    assert(has_single_bit<u32>(len));
    detail::ntt_size += len;
    ntt_classical(f);
    // ntt_twisted(f);
}
// Public inverse transform: asserts a power-of-two length, adds the length to
// detail::ntt_size, then runs the classical inverse transform in place.
template <static_modint_concept ModT>
void intt(span<ModT> f)
{
    const auto len = f.size();
    assert(has_single_bit<u32>(len));
    detail::ntt_size += len;
    intt_classical(f);
    // intt_twisted(f);
}
#include <span>
// Element-wise product written to dst: dst[i] = f[i] * g[i] for i < dst.size().
template <static_modint_concept ModT>
static void dot_basic(span<ModT> f, span<const ModT> g, span<ModT> dst)
{
    const u32 count = dst.size();
    u32 idx = 0;
    while (idx < count)
    {
        dst[idx] = f[idx] * g[idx];
        ++idx;
    }
}
// In-place element-wise product: f[i] *= g[i] for i < f.size().
template <static_modint_concept ModT>
static void dot_basic(span<ModT> f, span<const ModT> g)
{
    const u32 count = f.size();
    u32 idx = 0;
    while (idx < count)
    {
        f[idx] *= g[idx];
        ++idx;
    }
}
// In-place element-wise product, eight elements per step with unaligned
// vector loads/stores; the fewer-than-eight leftovers are multiplied one by one.
template <montgomery_modint_concept ModT>
static void dot_avx(span<ModT> f, span<const ModT> g)
{
    using X8 = simd::M32x8<ModT>;
    const u32 total = f.size();
    const u32 vec_end = total - total % 8;
    u32 pos = 0;
    while (pos < vec_end)
    {
        X8 lhs = X8::load((simd::I256 *)&f[pos]);
        X8 rhs = X8::load((simd::I256 *)&g[pos]);
        lhs *= rhs;
        lhs.store((simd::I256 *)&f[pos]);
        pos += 8;
    }
    while (pos < total)
    {
        f[pos] *= g[pos];
        ++pos;
    }
}
// Element-wise product written to dst: dst[i] = f[i] * g[i] for
// i < dst.size(), eight elements per step with unaligned vector
// loads/stores and a scalar tail.
// The first parameter is a span of ModT, matching how the body indexes it and
// how dot() calls this function.
template <montgomery_modint_concept ModT>
static void dot_avx(span<ModT> f, span<const ModT> g, span<ModT> dst)
{
    u32 n = dst.size();
    u32 i = 0;
    using X8 = simd::M32x8<ModT>;
    for (; i + 7 < n; i += 8)
    {
        X8 fi = X8::load((simd::I256 *)&f[i]);
        X8 gi = X8::load((simd::I256 *)&g[i]);
        X8 di = fi * gi;
        di.store((simd::I256 *)&dst[i]);
    }
    for (; i < n; i++)
        dst[i] = f[i] * g[i];
}
// Element-wise product into dst: the AVX path for Montgomery modints, the
// scalar path otherwise.
template <static_modint_concept ModT>
static void dot(span<ModT> f, span<const ModT> g, span<ModT> dst)
{
    if constexpr (!montgomery_modint_concept<ModT>)
        dot_basic(f, g, dst);
    else
        dot_avx(f, g, dst);
}
// In-place element-wise product: the AVX path for Montgomery modints, the
// scalar path otherwise.
template <static_modint_concept ModT>
static void dot(span<ModT> f, span<const ModT> g)
{
    if constexpr (!montgomery_modint_concept<ModT>)
    {
        dot_basic(f, g);
    }
    else
    {
        dot_avx(f, g);
    }
}
using Space = MontgomerySpace<u32, 998244353>;
using ModT = StaticModint<Space>;
main()
{
FastI fin(stdin);
FastO fout(stdout);
u32 n, m;
fin >> n >> m;
++n, ++m;
u32 L = bit_ceil(n + m - 1);
ModT *f = new (align_val_t(32)) ModT[L];
ModT *g = new (align_val_t(32)) ModT[L];
for (int i = 0; i < n; ++i)
fin >> f[i];
for (int i = 0; i < m; ++i)
fin >> g[i];
ntt<ModT>({f, L}), ntt<ModT>({g, L});
dot<ModT>({f, L}, {g, L});
intt<ModT>({f, L});
for (int i = 0; i < n + m - 1; ++i)
fout << f[i] << ' ';
}
// Judge result: Compilation | N/A | N/A | Compile Error | Score: N/A | (show more)