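/*
 * Online-judge submission record, copied from the judge page (metadata only, not part of the program).
 * Columns: user, problem, status, score, time, memory, language, code length; submitted at / judged at.
 *
 * 提交记录 21651
 *
 * 用户 题目 状态 得分 用时 内存 语言 代码长度
 * TSKY 1002. 测测你的多项式乘法 Compile Error 0 0 ns 0 KB C++14 56.87 KB
 * 提交时间 评测时间
 * 2024-04-26 16:32:05 2024-04-27 15:02:07
 */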
#pragma GCC target("avx2")
#pragma GCC optimize("O2")
#include <vector>
#include <complex>
#include <iostream>
#include <future>
#include <array>
#include <ctime>
#include <cstring>
#include <immintrin.h>
template <typename T, size_t LEN>
class AlignAry
{
private:
    alignas(32) T ary[LEN];

public:
    constexpr AlignAry() {}
    constexpr T &operator[](size_t index)
    {
        return ary[index];
    }
    constexpr const T &operator[](size_t index) const
    {
        return ary[index];
    }
    constexpr T *data()
    {
        return ary;
    }
    constexpr T *begin()
    {
        return ary;
    }
    constexpr T *end()
    {
        return begin() + LEN;
    }
    constexpr const T *data() const
    {
        return ary;
    }
    template <typename Ty>
    Ty *cast_ptr()
    {
        return reinterpret_cast<Ty *>(ary);
    }
    template <typename Ty>
    const Ty *cast_ptr() const
    {
        return reinterpret_cast<const Ty *>(ary);
    }
};
namespace hint
{
    // Largest power of two not exceeding n, for n >= 1.
    // Note: n == 0 yields 1 (the smeared value is 0, so (0 >> 1) + 1 == 1).
    template <typename T>
    constexpr T int_floor2(T n)
    {
        // Propagate the highest set bit into every lower bit position.
        for (int shift = 1; shift < int(sizeof(n) * 8); shift <<= 1)
        {
            n = n | (n >> shift);
        }
        // n is now 2^(k+1) - 1; halving it and adding one leaves 2^k.
        return (n >> 1) + 1;
    }

    // Smallest power of two not less than n, for n >= 1.
    // Note: for unsigned n == 0 the decrement wraps, so the result wraps to 0.
    template <typename T>
    constexpr T int_ceil2(T n)
    {
        T v = n - 1;
        // Fill every bit below the highest set bit of n - 1 ...
        for (int shift = 1; shift < int(sizeof(T) * 8); shift <<= 1)
        {
            v |= (v >> shift);
        }
        // ... giving 2^k - 1; one more is the next power of two.
        return v + 1;
    }

    // Value of type T whose low `bits` bits are all 1, i.e. 2^bits - 1.
    // bits is expected in [0, bit width of T]; bits <= 0 yields 0.
    template <typename T>
    constexpr T all_one(int bits)
    {
        if (bits <= 0)
        {
            // Without this guard the shift below would be by a negative amount (UB).
            return T(0);
        }
        // Build 2^bits - 1 as (2^(bits-1) - 1) + 2^(bits-1) so that bits == width of T
        // never shifts by the full width.
        T tmp = T(1) << (bits - 1);
        return tmp - 1 + tmp;
    }

    // Integer log2: floor(log2(n)) for n >= 1; n == 0 returns 0 as well.
    template <typename UintTy>
    constexpr int hint_log2(UintTy n)
    {
        int log = 0;
        // Binary search for the highest set bit: test the upper half of the remaining
        // width, and if anything is there shift it down and record the distance.
        for (int step = int(sizeof(UintTy) * 4); step != 0; step /= 2)
        {
            if ((n >> step) != 0)
            {
                n >>= step;
                log += step;
            }
        }
        return log;
    }
    // Extended Euclid: returns g = gcd(a, b) and sets x, y so that a * x + b * y == g.
    template <typename IntTy>
    constexpr IntTy exgcd(IntTy a, IntTy b, IntTy &x, IntTy &y)
    {
        if (b != 0)
        {
            const IntTy quotient = a / b;
            // Recurse on (b, a mod b) with the roles of x and y swapped,
            // then fix up y for the quotient step.
            const IntTy g = exgcd(b, a - quotient * b, y, x);
            y = y - quotient * x;
            return g;
        }
        // gcd(a, 0) == a == a * 1 + 0 * 0
        x = 1;
        y = 0;
        return a;
    }

    // Modular inverse of n modulo mod via exgcd.
    // Returns x in [0, mod) with n * x == 1 (mod mod) when gcd(n, mod) == 1; when n and mod
    // are not coprime no inverse exists and the result is not meaningful.
    // Works for signed IntTy; for unsigned IntTy it assumes mod <= 2^(w-1) (w = bit width).
    template <typename IntTy>
    constexpr IntTy mod_inv(IntTy n, IntTy mod)
    {
        n %= mod;
        IntTy x = 0, y = 0;
        exgcd(n, mod, x, y);
        // exgcd yields a Bezout coefficient with |x| < mod that may be negative.
        // For a signed IntTy that appears as x < 0.
        // For an unsigned IntTy the negative value has wrapped to 2^w - |x|, which is >= mod.
        // Subtracting mod there (as a ">= mod" normalisation would) gives a wrong answer.
        // In both cases adding mod (modulo 2^w for unsigned) yields the canonical mod - |x|.
        if (x < 0 || x >= mod)
        {
            x += mod;
        }
        return x;
    }
    // Count of trailing zero bits of x (x != 0).
    // x ^ (x - 1) keeps the lowest set bit of x and sets every bit below it, so the
    // log2 of that mask is the index of the lowest set bit. For x == 0 the mask is
    // all ones and the result is (bit width of the mask's type) - 1.
    template <typename IntTy>
    constexpr int hint_ctz(IntTy x)
    {
        const auto low_mask = x ^ (x - 1);
        return hint_log2(low_mask);
    }
    // Returns n^-1 mod 2^pow: the x in [0, 2^pow) with n * x == 1 (mod 2^pow).
    // Uses Newton/Hensel iteration x <- x * (2 - n * x), which doubles the number of
    // correct low bits per step.
    // pow is expected in [1, 64]; pow > 64 is treated as 64.
    // Only odd n are invertible modulo a power of two. For even n the function returns 0
    // (never a valid inverse) instead of iterating forever. pow <= 0 (modulus 1) also
    // returns 0.
    constexpr uint64_t inv_mod2pow(uint64_t n, int pow)
    {
        if (pow <= 0 || (n & 1) == 0)
        {
            return 0;
        }
        // Low `pow` bits set; the pow >= 64 case avoids shifting by the full width.
        const uint64_t mask = pow >= 64 ? ~uint64_t(0) : (uint64_t(1) << pow) - 1;
        uint64_t xn = 1, t = n & mask;
        while (t != 1)
        {
            xn = xn * (2 - t);
            t = (xn * n) & mask;
        }
        return xn & mask;
    }
    // Generic fast exponentiation (square-and-multiply): returns m^n.
    // T must be constructible from 1 and support operator*.
    template <typename T>
    constexpr T qpow(T m, uint32_t n)
    {
        T result = 1;
        for (; n != 0; n >>= 1)
        {
            if (n & 1)
            {
                result = result * m;
            }
            m = m * m;
        }
        return result;
    }
    // FFT与类FFT变换的命名空间
    namespace hint_transform
    {
        // In-place butterfly: (sum, diff) <- (sum + diff, sum - diff).
        template <typename T>
        inline void transform2(T &sum, T &diff)
        {
            const T lhs = sum;
            const T rhs = diff;
            sum = lhs + rhs;
            diff = lhs - rhs;
        }
        namespace hint_ntt
        {
            // Scalar integer modulo MOD stored in Montgomery form with R = 2^32.
            // `data` holds n * R mod MOD in "lazy" form: +, -, * keep it in [0, 2*MOD) rather
            // than [0, MOD). toInt() / operator uint32_t() return the fully reduced value.
            template <uint32_t MOD>
            struct MontInt32
            {
                static constexpr uint64_t R = uint64_t(1) << 32;
                static constexpr uint32_t R_MASK = R - 1;
                static constexpr uint32_t MOD_INV = inv_mod2pow(MOD, 32); // MOD^-1 mod 2^32
                static constexpr uint32_t MOD_INV_NEG = R - MOD_INV;      // -MOD^-1 mod 2^32
                static constexpr uint32_t MOD2 = MOD * 2;
                // The lazy representation requires MOD < 2^30 (hint_log2(MOD) <= 29), for two reasons:
                // - addMont forms m + n < 4*MOD, which must not wrap a uint32_t;
                // - redcLazy of a product of two lazy values (< 4*MOD^2) returns a value below
                //   4*MOD^2 / 2^32 + MOD, which stays under 2*MOD only when MOD < 2^30.
                // Allowing hint_log2(MOD) == 30 would admit 31-bit moduli and silently break both.
                static_assert(hint_log2(MOD) < 30, "MOD can't be larger than 30 bits");
                static_assert(uint32_t(MOD_INV * MOD) == 1, "Montgomery32 modulus is not correct");

                uint32_t data;
                constexpr MontInt32() : data(0) {}
                // Implicit on purpose: lets code write `T x = 1;` (e.g. in qpow).
                constexpr MontInt32(uint32_t n) : data(toMont(n)) {}

                // n -> n * 2^32 mod MOD (fully reduced).
                static constexpr uint32_t toMont(uint32_t n)
                {
                    return (uint64_t(n) << 32) % MOD;
                }
                // Montgomery REDC without the final subtraction: returns input * 2^-32 mod MOD,
                // in [0, input / 2^32 + MOD).
                static constexpr uint32_t redcLazy(uint64_t input)
                {
                    uint64_t n = uint32_t(input) * MOD_INV_NEG;
                    n = n * MOD + input;
                    return n >> 32;
                }
                // REDC with one conditional subtraction (correct when redcLazy's result < 2*MOD).
                static constexpr uint32_t redc(uint64_t input)
                {
                    uint32_t n = redcLazy(input);
                    return n < MOD ? n : n - MOD;
                }
                // Montgomery form -> plain integer in [0, MOD).
                static constexpr uint32_t toInt(uint32_t n)
                {
                    return redc(n);
                }
                // Lazy add: operands in [0, 2*MOD) -> result in [0, 2*MOD).
                static constexpr uint32_t addMont(uint32_t m, uint32_t n)
                {
                    n = m + n;
                    return n < MOD2 ? n : n - MOD2;
                }
                // Lazy subtract: operands in [0, 2*MOD) -> result in [0, 2*MOD).
                static constexpr uint32_t subMont(uint32_t m, uint32_t n)
                {
                    n = m - n;
                    return n > m ? n + MOD2 : n; // n > m means the subtraction wrapped
                }
                // Lazy multiply: result stays in [0, 2*MOD) for lazy operands (needs MOD < 2^30).
                static constexpr uint32_t mulMont(uint32_t m, uint32_t n)
                {
                    return redcLazy(uint64_t(m) * n);
                }
                constexpr void fromInt(uint32_t n)
                {
                    data = toMont(n);
                }
                constexpr uint32_t toInt() const
                {
                    return toInt(data);
                }
                constexpr operator uint32_t() const
                {
                    return toInt();
                }
                constexpr MontInt32 operator+(MontInt32 rhs) const
                {
                    rhs.data = addMont(data, rhs.data);
                    return rhs;
                }
                constexpr MontInt32 operator-(MontInt32 rhs) const
                {
                    rhs.data = subMont(data, rhs.data);
                    return rhs;
                }
                constexpr MontInt32 operator*(MontInt32 rhs) const
                {
                    rhs.data = mulMont(data, rhs.data);
                    return rhs;
                }
                constexpr MontInt32 &operator+=(const MontInt32 &rhs)
                {
                    data = addMont(data, rhs.data);
                    return *this;
                }
                constexpr MontInt32 &operator-=(const MontInt32 &rhs)
                {
                    data = subMont(data, rhs.data);
                    return *this;
                }
                constexpr MontInt32 &operator*=(const MontInt32 &rhs)
                {
                    data = mulMont(data, rhs.data);
                    return *this;
                }
                static constexpr uint32_t mod()
                {
                    return MOD;
                }
            };

            // Eight MontInt32<MOD> values packed in one AVX2 register (__m256i, 8 x 32-bit lanes).
            // Lane-wise Montgomery arithmetic. operator+, operator- and operator* use the "lazy"
            // variants (addMont2 / subMont2 / mulMont2), which keep lanes in [0, 2*MOD) given
            // inputs in that range. Comparisons use signed 32-bit compares (_mm256_cmpgt_epi32),
            // so lane values are assumed to stay below 2^31.
            // NOTE(review): relies on MontInt32's bound on MOD.
            template <uint32_t MOD>
            struct MontInt32X8
            {
                using MontInt = MontInt32<MOD>;
                __m256i data;

                MontInt32X8() : data(_mm256_setzero_si256()) {}
                // Broadcast x's Montgomery-form value (x.data) into all 8 lanes, unchanged.
                MontInt32X8(MontInt x) : data(_mm256_set1_epi32(x.data)) {}
                // Lane i receives xi (x0 is the lowest lane); values are stored as given.
                MontInt32X8(int32_t x0, int32_t x1, int32_t x2, int32_t x3, int32_t x4, int32_t x5, int32_t x6, int32_t x7)
                {
                    data = _mm256_set_epi32(x7, x6, x5, x4, x3, x2, x1, x0);
                }
                MontInt32X8(__m256i rhs) : data(rhs) {}
                // Unaligned load of 32 bytes starting at p.
                template <typename T>
                MontInt32X8(const T *p)
                {
                    loadu(p);
                }
                static constexpr uint32_t mod()
                {
                    return MOD;
                }
                // Broadcast constants used by the arithmetic below.
                static MontInt32X8 zeroX8()
                {
                    return _mm256_setzero_si256();
                }
                static MontInt32X8 modX8()
                {
                    return _mm256_set1_epi32(mod());
                }
                // MOD - 1: threshold for "lane >= MOD" via a strict greater-than compare.
                static MontInt32X8 mod1X8()
                {
                    constexpr uint32_t MOD1 = mod() - 1;
                    return _mm256_set1_epi32(MOD1);
                }
                static MontInt32X8 mod2X8()
                {
                    constexpr uint32_t MOD2 = mod() * 2;
                    return _mm256_set1_epi32(MOD2);
                }
                // -MOD^-1 mod 2^32, the REDC multiplier.
                static MontInt32X8 modNX8()
                {
                    constexpr uint32_t MOD_INV_NEG = MontInt::MOD_INV_NEG;
                    return _mm256_set1_epi32(MOD_INV_NEG);
                }
                // 2^32 mod MOD, i.e. the Montgomery form of 1.
                static MontInt32X8 RX8()
                {
                    constexpr uint32_t R = (uint64_t(1) << 32) % mod();
                    return _mm256_set1_epi32(R);
                }

                // Per 64-bit lane: (low 32 bits of this) * (low 32 bits of rhs) as a 64-bit product.
                MontInt32X8 mul64(MontInt32X8 rhs) const
                {
                    return _mm256_mul_epu32(data, rhs.data);
                }
                MontInt32X8 lShift64(int n) const
                {
                    return _mm256_slli_epi64(data, n);
                }
                MontInt32X8 rShift64(int n) const
                {
                    return _mm256_srli_epi64(data, n);
                }
                // Byte shifts within each 128-bit half.
                MontInt32X8 lShiftByte128(int n) const
                {
                    return _mm256_bslli_epi128(data, n);
                }
                MontInt32X8 rShiftByte128(int n) const
                {
                    return _mm256_bsrli_epi128(data, n);
                }
                // Lane i comes from b where bit i of N is set, otherwise from a.
                template <int N>
                static MontInt32X8 blend(MontInt32X8 a, MontInt32X8 b)
                {
                    return _mm256_blend_epi32(a.data, b.data, N);
                }
                template <int N>
                static MontInt32X8 permute2X128(MontInt32X8 a, MontInt32X8 b)
                {
                    return _mm256_permute2x128_si256(a.data, b.data, N);
                }
                // Keep the even-indexed lanes, zero the odd ones:
                // a,b,c,d -> a,0,c,0 (per pair of 32-bit lanes).
                MontInt32X8 evenElements() const
                {
                    return blend<0b10101010>(data, zeroX8());
                }
                // a,b,c,d -> 0,b,0,d
                MontInt32X8 oddElements() const
                {
                    return blend<0b01010101>(data, zeroX8());
                }
                // Full 32x32->64 products of all 8 lanes.
                // first = products of the even lanes, second = products of the odd lanes,
                // each as four 64-bit values.
                std::pair<MontInt32X8, MontInt32X8> mul32X32To64(MontInt32X8 rhs) const
                {
                    return std::make_pair(mul64(rhs), rShift64(32).mul64(rhs.rShift64(32)));
                }
                // Vectorised REDC without the final subtraction.
                // even64 holds the 64-bit values belonging to the even lanes, odd64 those of the odd lanes.
                // Even results are shifted down into the low halves; odd results are left in the high halves.
                // The blend then interleaves them back into lane order.
                static MontInt32X8 montRedcLazy(MontInt32X8 even64, MontInt32X8 odd64)
                {
                    MontInt32X8 p0 = even64.mul64(modNX8());
                    MontInt32X8 p1 = odd64.mul64(modNX8());
                    p0 = p0.mul64(modX8()).rawAdd64(even64).rShift64(32);
                    p1 = p1.mul64(modX8()).rawAdd64(odd64);
                    return blend<0b10101010>(p0, p1);
                }
                // REDC followed by one conditional subtraction of MOD.
                static MontInt32X8 montRedc(MontInt32X8 even64, MontInt32X8 odd64)
                {
                    return montRedcLazy(even64, odd64).largeNorm();
                }
                // Plain integers -> Montgomery form, done lane by lane with scalar arithmetic.
                MontInt32X8 toMont() const
                {
                    alignas(32) uint32_t temp[8];
                    store(temp);
                    for (auto &&i : temp)
                    {
                        i = (uint64_t(i) << 32) % MOD;
                    }
                    return MontInt32X8(temp);
                }
                // Montgomery form -> plain integers in [0, MOD) (REDC of each lane as a 64-bit value).
                MontInt32X8 toInt() const
                {
                    MontInt32X8 e = evenElements();
                    MontInt32X8 o = rShift64(32);
                    return montRedc(e, o);
                }
                // Subtract MOD from lanes >= MOD (signed compare against MOD - 1).
                MontInt32X8 largeNorm() const
                {
                    MontInt32X8 sub = (*this > mod1X8()) & modX8();
                    return rawSub(sub);
                }
                // Add MOD to lanes that are negative when read as signed.
                MontInt32X8 smallNorm() const
                {
                    MontInt32X8 add = (zeroX8() > *this) & modX8();
                    return rawAdd(add);
                }
                // Add 2*MOD to lanes that are negative when read as signed.
                MontInt32X8 smallNorm2() const
                {
                    MontInt32X8 add = (zeroX8() > *this) & mod2X8();
                    return rawAdd(add);
                }

                // Fully reduced variants: inputs and outputs in [0, MOD).
                MontInt32X8 addMont(MontInt32X8 rhs) const
                {
                    return rawAdd(rhs).largeNorm();
                }
                MontInt32X8 subMont(MontInt32X8 rhs) const
                {
                    return rawSub(rhs).smallNorm();
                }
                MontInt32X8 mulMont(MontInt32X8 rhs) const
                {
                    auto mulhl = mul32X32To64(rhs);
                    return montRedc(mulhl.first, mulhl.second);
                }

                // Lazy variants: keep lanes in [0, 2*MOD) given inputs in that range.
                // addMont2 computes a + b - 2*MOD and adds 2*MOD back if that went negative.
                MontInt32X8 addMont2(MontInt32X8 rhs) const
                {
                    return rawAdd(rhs).rawSub(mod2X8()).smallNorm2();
                }
                MontInt32X8 subMont2(MontInt32X8 rhs) const
                {
                    return rawSub(rhs).smallNorm2();
                }
                MontInt32X8 mulMont2(MontInt32X8 rhs) const
                {
                    auto mulhl = mul32X32To64(rhs);
                    return montRedcLazy(mulhl.first, mulhl.second);
                }

                // Operators use the lazy variants.
                MontInt32X8 operator+(MontInt32X8 rhs) const
                {
                    return addMont2(rhs);
                }
                MontInt32X8 operator-(MontInt32X8 rhs) const
                {
                    return subMont2(rhs);
                }
                MontInt32X8 operator*(MontInt32X8 rhs) const
                {
                    return mulMont2(rhs);
                }
                // Raw lane arithmetic without any modular correction.
                MontInt32X8 rawAdd(MontInt32X8 rhs) const
                {
                    return _mm256_add_epi32(data, rhs.data);
                }
                MontInt32X8 rawSub(MontInt32X8 rhs) const
                {
                    return _mm256_sub_epi32(data, rhs.data);
                }
                MontInt32X8 rawAdd64(MontInt32X8 rhs) const
                {
                    return _mm256_add_epi64(data, rhs.data);
                }
                MontInt32X8 rawSub64(MontInt32X8 rhs) const
                {
                    return _mm256_sub_epi64(data, rhs.data);
                }

                // Comparisons return lane masks (all ones / all zeros) using SIGNED 32-bit compares.
                MontInt32X8 operator>(MontInt32X8 n) const
                {
                    return _mm256_cmpgt_epi32(data, n.data);
                }
                MontInt32X8 operator<(MontInt32X8 n) const
                {
                    return n > *this;
                }
                MontInt32X8 operator==(MontInt32X8 n) const
                {
                    return _mm256_cmpeq_epi32(data, n.data);
                }
                MontInt32X8 operator&(MontInt32X8 n) const
                {
                    return _mm256_and_si256(data, n.data);
                }
                MontInt32X8 operator|(MontInt32X8 n) const
                {
                    return _mm256_or_si256(data, n.data);
                }
                MontInt32X8 operator^(MontInt32X8 n) const
                {
                    return _mm256_xor_si256(data, n.data);
                }
                void set1(int32_t n)
                {
                    data = _mm256_set1_epi32(n);
                }
                // loadu/storeu: any alignment.
                // load/store: p must be 32-byte aligned (_mm256_load_si256 / _mm256_store_si256).
                template <typename T>
                void loadu(const T *p)
                {
                    data = _mm256_loadu_si256((const __m256i *)p);
                }
                template <typename T>
                void load(const T *p)
                {
                    data = _mm256_load_si256((const __m256i *)p);
                    // data = *reinterpret_cast<const __m256i *>(p);
                }
                template <typename T>
                void storeu(T *p) const
                {
                    _mm256_storeu_si256((__m256i *)p, data);
                }
                template <typename T>
                void store(T *p) const
                {
                    _mm256_store_si256((__m256i *)p, data);
                    // *reinterpret_cast<__m256i *>(p) = data;
                }
                // Debug printers: dump the raw register contents (no Montgomery conversion).
                void printI32() const
                {
                    alignas(32) int32_t v[8];
                    store(v);
                    std::cout << "[" << v[0] << "," << v[1]
                              << "," << v[2] << "," << v[3]
                              << "," << v[4] << "," << v[5]
                              << "," << v[6] << "," << v[7] << "]" << std::endl;
                }
                void printU32() const
                {
                    alignas(32) uint32_t v[8];
                    store(v);
                    std::cout << "[" << v[0] << "," << v[1]
                              << "," << v[2] << "," << v[3]
                              << "," << v[4] << "," << v[5]
                              << "," << v[6] << "," << v[7] << "]" << std::endl;
                }
                void printU64() const
                {
                    alignas(32) uint64_t v[4];
                    store(v);
                    std::cout << "[" << v[0] << "," << v[1]
                              << "," << v[2] << "," << v[3] << "]" << std::endl;
                }
            };

            namespace split_radix_avx
            {
                // n * ROOT^((mod - 1) / 4): a 4th root of unity (primitive when ROOT generates
                // the multiplicative group), evaluated at compile time.
                template <uint32_t ROOT, typename ModIntType>
                inline ModIntType mul_w41(ModIntType n)
                {
                    constexpr auto exponent = (ModIntType::mod() - 1) / 4;
                    constexpr ModIntType root_4_1 = qpow(ModIntType(ROOT), exponent);
                    return n * root_4_1;
                }
                // n * ROOT^((mod - 1) / 8): an 8th root of unity (primitive when ROOT generates
                // the multiplicative group), evaluated at compile time.
                template <uint64_t ROOT, typename ModIntType>
                inline ModIntType mul_w81(ModIntType n)
                {
                    constexpr auto exponent = (ModIntType::mod() - 1) / 8;
                    constexpr ModIntType root_8_1 = qpow(ModIntType(ROOT), exponent);
                    return n * root_8_1;
                }
                // n * ROOT^(3 * (mod - 1) / 8): the cube of the 8th root used by mul_w81,
                // evaluated at compile time.
                template <uint64_t ROOT, typename ModIntType>
                inline ModIntType mul_w83(ModIntType n)
                {
                    constexpr auto exponent = ((ModIntType::mod() - 1) / 8) * 3;
                    constexpr ModIntType root_8_3 = qpow(ModIntType(ROOT), exponent);
                    return n * root_8_3;
                }
                template <size_t LEN, uint32_t MOD, uint32_t ROOT>
                struct NTTShort
                {
                    static constexpr size_t ntt_len = LEN;
                    static constexpr size_t half_len = ntt_len / 2;
                    static constexpr size_t quarter_len = ntt_len / 4;
                    static constexpr size_t octant_len = ntt_len / 8;
                    static constexpr size_t rank = quarter_len;
                    static constexpr int log_len = hint_log2(ntt_len);

                    using ModIntX8 = MontInt32X8<MOD>;
                    using ModIntType = typename ModIntX8::MontInt;
                    using HalfNTT = NTTShort<half_len, MOD, ROOT>;
                    using QuarterNTT = NTTShort<quarter_len, MOD, ROOT>;
                    using TableType = AlignAry<ModIntType, quarter_len>;

                    static constexpr TableType getNTTTable(int factor)
                    {
                        ModIntType root = qpow(ModIntType(ROOT), ((ModIntType::mod() - 1) / LEN) * factor);
                        ModIntType omega(1);
                        TableType res;
                        for (auto &&i : res)
                        {
                            i = omega;
                            omega *= root;
                        }
                        return res;
                    }

                    static TableType table1;
                    static TableType table3;

                    static constexpr uint64_t mod()
                    {
                        return ModIntType::mod();
                    }

                    static constexpr uint64_t root()
                    {
                        return ROOT;
                    }
                    static void dit(ModIntType in_out[])
                    {
                        QuarterNTT::dit(in_out + half_len + quarter_len);
                        QuarterNTT::dit(in_out + half_len);
                        HalfNTT::dit(in_out);

                        ModIntX8 omega1(&table1[0]), omega3(&table3[0]);
                        for (auto it = in_out, it1 = &table1[0], it3 = &table3[0]; it < in_out + quarter_len; it += 8, it1 += 8, it3 += 8)
                        {
                            omega1.load(it1), omega3.load(it3);
                            ModIntX8 temp0, temp1, temp2, temp3;
                            temp2.load(&it[rank * 2]);
                            temp3.load(&it[rank * 3]);
                            temp2 = temp2 * omega1;
                            temp3 = temp3 * omega3;

                            transform2(temp2, temp3);
                            constexpr ModIntType W_4_1 = qpow(ModIntType(root()), (ModIntType::mod() - 1) / 4);
                            temp3 = temp3 * ModIntX8(W_4_1);

                            temp0.load(&it[0]);
                            temp1.load(&it[rank]);
                            (temp0 + temp2).store(&it[0]);
                            (temp1 + temp3).store(&it[rank]);
                            (temp0 - temp2).store(&it[rank * 2]);
                            (temp1 - temp3).store(&it[rank * 3]);
                        }
                    }

                    static void dif(ModIntType in_out[])
                    {
                        // constexpr ModIntType u1 = qpow(ModIntType(ROOT), (MOD - 1) / ntt_len * 8);
                        // constexpr ModIntType u3 = qpow(u1, 3);
                        ModIntX8 omega1(&table1[0]), omega3(&table3[0]);
                        for (auto it = in_out, it1 = &table1[0], it3 = &table3[0]; it < in_out + quarter_len; it += 8, it1 += 8, it3 += 8)
                        {
                            omega1.load(it1), omega3.load(it3);
                            ModIntX8 temp0, temp1, temp2, temp3;
                            temp0.load(&it[0]);
                            temp1.load(&it[rank]);
                            temp2.load(&it[rank * 2]);
                            temp3.load(&it[rank * 3]);
                            (temp0 + temp2).store(&it[0]);
                            (temp1 + temp3).store(&it[rank]);

                            temp2 = temp0 - temp2;
                            temp3 = temp1 - temp3;
                            constexpr ModIntType W_4_1 = qpow(ModIntType(root()), (ModIntType::mod() - 1) / 4);
                            temp3 = temp3 * ModIntX8(W_4_1);
                            transform2(temp2, temp3);

                            (temp2 * omega1).store(&it[rank * 2]);
                            (temp3 * omega3).store(&it[rank * 3]);
                            // omega1 = omega1 * ModIntX8(u1);
                            // omega3 = omega3 * ModIntX8(u3);
                        }
                        HalfNTT::dif(in_out);
                        QuarterNTT::dif(in_out + half_len);
                        QuarterNTT::dif(in_out + half_len + quarter_len);
                    }
                    static void dit(ModIntType in_out[], size_t len)
                    {
                        if (len < LEN)
                        {
                            HalfNTT::dit(in_out, len);
                            return;
                        }
                        dit(in_out);
                    }
                    static void dif(ModIntType in_out[], size_t len)
                    {
                        if (len < LEN)
                        {
                            HalfNTT::dif(in_out, len);
                            return;
                        }
                        dif(in_out);
                    }
                };
                // Out-of-class definitions of the per-length twiddle tables.
                // The initialiser is looked up in class scope, so getNTTTable needs no qualification.
                // table1[i] = w^i and table3[i] = w^(3i), with w = ROOT^((MOD - 1) / LEN).
                template <size_t LEN, uint32_t MOD, uint32_t ROOT>
                typename NTTShort<LEN, MOD, ROOT>::TableType NTTShort<LEN, MOD, ROOT>::table1 = getNTTTable(1);
                template <size_t LEN, uint32_t MOD, uint32_t ROOT>
                typename NTTShort<LEN, MOD, ROOT>::TableType NTTShort<LEN, MOD, ROOT>::table3 = getNTTTable(3);

                // Length-0 transform: there is nothing to transform, so every entry point is a no-op.
                template <uint32_t MOD, uint32_t ROOT>
                struct NTTShort<0, MOD, ROOT>
                {
                    using ModIntType = MontInt32<MOD>;

                    static void dit(ModIntType * /*in_out*/) {}
                    static void dif(ModIntType * /*in_out*/) {}
                    static void dit(ModIntType * /*in_out*/, size_t /*len*/) {}
                    static void dif(ModIntType * /*in_out*/, size_t /*len*/) {}
                };

                // Length-1 transform: the identity, so every entry point is a no-op.
                template <uint32_t MOD, uint32_t ROOT>
                struct NTTShort<1, MOD, ROOT>
                {
                    using ModIntType = MontInt32<MOD>;

                    static void dit(ModIntType * /*in_out*/) {}
                    static void dif(ModIntType * /*in_out*/) {}
                    static void dit(ModIntType * /*in_out*/, size_t /*len*/) {}
                    static void dif(ModIntType * /*in_out*/, size_t /*len*/) {}
                };

                // Length-2 transform: a single butterfly (x0, x1) -> (x0 + x1, x0 - x1),
                // which is the same for decimation in time and in frequency.
                template <uint32_t MOD, uint32_t ROOT>
                struct NTTShort<2, MOD, ROOT>
                {
                    using ModIntType = MontInt32<MOD>;

                    static void dit(ModIntType in_out[])
                    {
                        transform2(in_out[0], in_out[1]);
                    }
                    static void dif(ModIntType in_out[])
                    {
                        transform2(in_out[0], in_out[1]);
                    }
                    // Transform only when at least two elements are requested; shorter is a no-op.
                    static void dit(ModIntType in_out[], size_t len)
                    {
                        if (len >= 2)
                        {
                            dit(in_out);
                        }
                    }
                    static void dif(ModIntType in_out[], size_t len)
                    {
                        if (len >= 2)
                        {
                            dif(in_out);
                        }
                    }
                };

                template <uint32_t MOD, uint32_t ROOT>
                struct NTTShort<4, MOD, ROOT>
                {
                    using ModIntType = MontInt32<MOD>;

                    static void dit(ModIntType in_out[])
                    {
                        auto temp0 = in_out[0];
                        auto temp1 = in_out[1];
                        auto temp2 = in_out[2];
                        auto temp3 = in_out[3];

                        transform2(temp0, temp1);
                        transform2(temp2, temp3);
                        temp3 = mul_w41<ROOT>(temp3);

                        in_out[0] = temp0 + temp2;
                        in_out[1] = temp1 + temp3;
                        in_out[2] = temp0 - temp2;
                        in_out[3] = temp1 - temp3;
                    }
                    static void dif(ModIntType in_out[])
                    {
                        auto temp0 = in_out[0];
                        auto temp1 = in_out[1];
                        auto temp2 = in_out[2];
                        auto temp3 = in_out[3];

                        transform2(temp0, temp2);
                        transform2(temp1, temp3);
                        temp3 = mul_w41<ROOT>(temp3);

                        in_out[0] = temp0 + temp1;
                        in_out[1] = temp0 - temp1;
                        in_out[2] = temp2 + temp3;
                        in_out[3] = temp2 - temp3;
                    }
                    static void dit(ModIntType in_out[], size_t len)
                    {
                        if (len < 4)
                        {
                            NTTShort<2, MOD, ROOT>::dit(in_out, len);
                            return;
                        }
                        dit(in_out);
                    }
                    static void dif(ModIntType in_out[], size_t len)
                    {
                        if (len < 4)
                        {
                            NTTShort<2, MOD, ROOT>::dif(in_out, len);
                            return;
                        }
                        dif(in_out);
                    }
                };

                template <uint32_t MOD, uint32_t ROOT>
                struct NTTShort<8, MOD, ROOT>
                {
                    using ModIntX8 = MontInt32X8<MOD>;
                    using ModIntType = typename ModIntX8::MontInt;

                    static void dit(ModIntType in_out[])
                    {
                        static constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
                        static constexpr ModIntType w2 = qpow(w1, 2);
                        static constexpr ModIntType w3 = qpow(w1, 3);
                        auto temp0 = in_out[0];
                        auto temp1 = in_out[1];
                        auto temp2 = in_out[2];
                        auto temp3 = in_out[3];
                        auto temp4 = in_out[4];
                        auto temp5 = in_out[5];
                        auto temp6 = in_out[6];
                        auto temp7 = in_out[7];

                        transform2(temp0, temp1);
                        transform2(temp2, temp3);
                        transform2(temp4, temp5);
                        transform2(temp6, temp7);
                        temp3 = mul_w41<ROOT>(temp3);
                        temp7 = mul_w41<ROOT>(temp7);

                        transform2(temp0, temp2);
                        transform2(temp1, temp3);
                        transform2(temp4, temp6);
                        transform2(temp5, temp7);
                        temp5 = temp5 * w1;
                        temp6 = temp6 * w2;
                        temp7 = temp7 * w3;

                        in_out[0] = temp0 + temp4;
                        in_out[1] = temp1 + temp5;
                        in_out[2] = temp2 + temp6;
                        in_out[3] = temp3 + temp7;
                        in_out[4] = temp0 - temp4;
                        in_out[5] = temp1 - temp5;
                        in_out[6] = temp2 - temp6;
                        in_out[7] = temp3 - temp7;
                    }
                    // 8-point decimation-in-frequency (DIF) transform, computed in place on in_out[0..7].
                    // Three radix-2 stages: distance 4 (with twiddles w1..w3), distance 2 (with
                    // mul_w41 on entries 3 and 7), then distance 1 while writing back.
                    // NOTE(review): transform2(a, b) is defined elsewhere; assumed to be the radix-2
                    // butterfly (a, b) -> (a + b, a - b). Outputs are presumably left in the permuted
                    // (bit-reversed) order usual for DIF without a reorder pass — confirm.
                    static void dif(ModIntType in_out[])
                    {
                        // w1 = ROOT^((mod - 1) / 8), an 8th root of unity (primitive if ROOT is a
                        // primitive root of mod — assumed); w2, w3 are its 2nd and 3rd powers.
                        static constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
                        static constexpr ModIntType w2 = qpow(w1, 2);
                        static constexpr ModIntType w3 = qpow(w1, 3);
                        auto temp0 = in_out[0];
                        auto temp1 = in_out[1];
                        auto temp2 = in_out[2];
                        auto temp3 = in_out[3];
                        auto temp4 = in_out[4];
                        auto temp5 = in_out[5];
                        auto temp6 = in_out[6];
                        auto temp7 = in_out[7];

                        // Stage 1: butterflies between the two halves, then twiddle the upper half.
                        transform2(temp0, temp4);
                        transform2(temp1, temp5);
                        transform2(temp2, temp6);
                        transform2(temp3, temp7);
                        temp5 = temp5 * w1;
                        temp6 = temp6 * w2;
                        temp7 = temp7 * w3;

                        // Stage 2: butterflies at distance 2 within each half; mul_w41 is presumably
                        // multiplication by the 4th root W_4^1 (defined elsewhere).
                        transform2(temp0, temp2);
                        transform2(temp1, temp3);
                        transform2(temp4, temp6);
                        transform2(temp5, temp7);
                        temp3 = mul_w41<ROOT>(temp3);
                        temp7 = mul_w41<ROOT>(temp7);

                        // Stage 3: final distance-1 butterflies, written straight back to in_out.
                        in_out[0] = temp0 + temp1;
                        in_out[1] = temp0 - temp1;
                        in_out[2] = temp2 + temp3;
                        in_out[3] = temp2 - temp3;
                        in_out[4] = temp4 + temp5;
                        in_out[5] = temp4 - temp5;
                        in_out[6] = temp6 + temp7;
                        in_out[7] = temp6 - temp7;
                    }
                    // Length-dispatching DIT entry point: run the full 8-point transform when
                    // at least 8 elements are available, otherwise hand the original len to the
                    // 4-point specialization.
                    static void dit(ModIntType in_out[], size_t len)
                    {
                        if (len >= 8)
                        {
                            dit(in_out);
                        }
                        else
                        {
                            NTTShort<4, MOD, ROOT>::dit(in_out, len);
                        }
                    }
                    // Length-dispatching DIF entry point: run the full 8-point transform when
                    // at least 8 elements are available, otherwise hand the original len to the
                    // 4-point specialization.
                    static void dif(ModIntType in_out[], size_t len)
                    {
                        if (len >= 8)
                        {
                            dif(in_out);
                        }
                        else
                        {
                            NTTShort<4, MOD, ROOT>::dif(in_out, len);
                        }
                    }
                };

                // Fixed-size 16-point NTT held in two MontInt32X8 registers (elements 0..7 in the
                // first, 8..15 in the second). The dit4X4/dit8X2/dif4X4/dif8X2 building blocks are
                // also used by NTTShort<32>.
                // NOTE(review): transform2(ModIntX8&, ModIntX8&) is defined elsewhere; assumed to be
                // the lane-wise radix-2 butterfly (a, b) -> (a + b, a - b).
                template <uint32_t MOD, uint32_t ROOT>
                struct NTTShort<16, MOD, ROOT>
                {
                    using ModIntX8 = MontInt32X8<MOD>;
                    using ModIntType = typename ModIntX8::MontInt;
                    // W_16_k = W_16_1^k with W_16_1 = ROOT^((mod - 1) / 16): powers of a 16th root of unity.
                    static constexpr ModIntType W_16_1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 16);
                    static constexpr ModIntType W_16_2 = qpow(W_16_1, 2);
                    static constexpr ModIntType W_16_3 = qpow(W_16_1, 3);
                    static constexpr ModIntType W_16_4 = qpow(W_16_1, 4);
                    static constexpr ModIntType W_16_5 = qpow(W_16_1, 5);
                    static constexpr ModIntType W_16_6 = qpow(W_16_1, 6);
                    static constexpr ModIntType W_16_7 = qpow(W_16_1, 7);
                    // Radix-2 butterfly on every adjacent lane pair (a, b) -> (a + b, a - b), done with
                    // 64-bit-lane shifts and a blend that takes the odd lanes from the difference.
                    // NOTE(review): the rawSub(mod2X8()) offset and smallNorm2() appear to keep the
                    // lazily reduced values in range — confirm against MontInt32X8.
                    static ModIntX8 transform2X4(ModIntX8 in)
                    {
                        ModIntX8 temp1 = in.rShift64(32);                    // b, 0
                        ModIntX8 temp2 = in.lShift64(32);                    // 0, a
                        temp1 = in.rawSub(ModIntX8::mod2X8()).rawAdd(temp1); // a + b ,X
                        temp2 = temp2.rawSub(in);                            // X, a - b
                        return ModIntX8::template blend<0b10101010>(temp1, temp2).smallNorm2();
                    }
                    // Four independent 4-point DITs: each 128-bit half of A and of B holds one
                    // 4-element group (comments below name a group's elements 0..3).
                    static void dit4X4(ModIntX8 &A, ModIntX8 &B)
                    {
                        // Twiddles (1, W_4^1) repeated: applied to elements 2 and 3 of each group.
                        constexpr ModIntType W_4_1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 4);
                        alignas(32) constexpr ModIntType w_arr[8]{ModIntType(1), W_4_1, ModIntType(1), W_4_1, ModIntType(1), W_4_1, ModIntType(1), W_4_1};

                        ModIntX8 temp0, temp1, temp2, temp3, omega;
                        // First stage: butterflies on adjacent pairs (0,1) and (2,3).
                        temp0 = transform2X4(A); // A
                        temp1 = transform2X4(B); // B

                        // Regroup so that elements 0,1 of A and B share one register and 2,3 the other.
                        temp2 = temp0.rShiftByte128(8); // A2,A3,X,X
                        temp3 = temp1.lShiftByte128(8); // X,X,B0,B1

                        temp0 = ModIntX8::template blend<0b11001100>(temp0, temp3); // A0,A1,B0,B1
                        temp1 = ModIntX8::template blend<0b11001100>(temp2, temp1); // A2,A3,B2,B3

                        omega.load(w_arr);
                        temp1 = temp1 * omega; // (A2,A3,B2,B3)*w

                        // Second stage: butterflies at distance 2.
                        temp2 = temp0 + temp1; // A0,A1,B0,B1
                        temp3 = temp0 - temp1; // A2,A3,B2,B3

                        // Undo the regrouping so A and B again hold their own groups.
                        temp0 = temp2.rShiftByte128(8); // B0,B1,X,X
                        temp1 = temp3.lShiftByte128(8); // X,X,A2,A3

                        A = ModIntX8::template blend<0b11001100>(temp2, temp1); // A0,A1,A2,A3
                        B = ModIntX8::template blend<0b11001100>(temp0, temp3); // B0,B1,B2,B3
                    }
                    // Two independent 8-point DITs, one on all 8 lanes of A and one on B: 4-point DITs
                    // on the 128-bit halves, then a twiddled radix-2 combine of the halves.
                    static void dit8X2(ModIntX8 &A, ModIntX8 &B)
                    {
                        // Twiddles W_8^0..W_8^3 for the upper 4-element half, once per transform.
                        constexpr ModIntType W_8_1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
                        constexpr ModIntType W_8_2 = qpow(W_8_1, 2);
                        constexpr ModIntType W_8_3 = qpow(W_8_1, 3);
                        alignas(32) constexpr ModIntType w_arr[8]{ModIntType(1), W_8_1, W_8_2, W_8_3, ModIntType(1), W_8_1, W_8_2, W_8_3};

                        dit4X4(A, B);
                        ModIntX8 temp0, temp1, temp2, temp3, omega;
                        // Gather the low halves of A and B into temp0 and the high halves into temp1.
                        temp0 = ModIntX8::template permute2X128<0x20>(A, B); // A0,B0
                        temp1 = ModIntX8::template permute2X128<0x31>(A, B); // A1,B1
                        omega.load(w_arr);
                        temp1 = temp1 * omega;
                        temp2 = temp0 + temp1;                                   // A0,B0
                        temp3 = temp0 - temp1;                                   // A1,B1
                        // Scatter the halves back into A and B.
                        A = ModIntX8::template permute2X128<0x20>(temp2, temp3); // A0,A1
                        B = ModIntX8::template permute2X128<0x31>(temp2, temp3); // B0,B1
                    }
                    // 16-point DIT over temp0 (elements 0..7) and temp1 (8..15): an 8-point DIT on
                    // each register, twiddle temp1 by W_16^0..7, then one radix-2 butterfly.
                    static void dit(ModIntX8 &temp0, ModIntX8 &temp1)
                    {
                        alignas(32) constexpr ModIntType w_arr[8]{ModIntType(1), W_16_1, W_16_2, W_16_3, W_16_4, W_16_5, W_16_6, W_16_7};
                        ModIntX8 omega;
                        omega.load(w_arr);
                        dit8X2(temp0, temp1);
                        temp1 = temp1 * omega;
                        transform2(temp0, temp1);
                    }
                    // In-place 16-point DIT on in_out[0..15].
                    // NOTE(review): whether ModIntX8::load/store need 32-byte alignment is decided
                    // elsewhere — confirm before passing unaligned pointers.
                    static void dit(ModIntType in_out[])
                    {
                        ModIntX8 temp0, temp1;
                        temp0.load(&in_out[0]), temp1.load(&in_out[8]);
                        dit(temp0, temp1);
                        temp0.store(&in_out[0]);
                        temp1.store(&in_out[8]);
                    }
                    // Four independent 4-point DIFs, one per 128-bit half of A and B; the stages of
                    // dit4X4 in reverse order.
                    static void dif4X4(ModIntX8 &A, ModIntX8 &B)
                    {
                        constexpr ModIntType W_4_1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 4);
                        alignas(32) constexpr ModIntType w_arr[8]{ModIntType(1), W_4_1, ModIntType(1), W_4_1, ModIntType(1), W_4_1, ModIntType(1), W_4_1};

                        ModIntX8 temp0, temp1, temp2, temp3, omega;
                        // Regroup elements 0,1 and 2,3 of each group into separate registers.
                        temp2 = A.rShiftByte128(8); // A2,A3,X,X
                        temp3 = B.lShiftByte128(8); // X,X,B0,B1

                        temp0 = ModIntX8::template blend<0b11001100>(A, temp3); // A0,A1,B0,B1
                        temp1 = ModIntX8::template blend<0b11001100>(temp2, B); // A2,A3,B2,B3

                        // Distance-2 butterflies, then twiddle the differences.
                        temp2 = temp0 + temp1; // A0,A1,B0,B1
                        temp3 = temp0 - temp1; // A2,A3,B2,B3
                        omega.load(w_arr);
                        temp3 = temp3 * omega; // (A2,A3,B2,B3)*w

                        temp0 = temp2.rShiftByte128(8); // B0,B1,X,X
                        temp1 = temp3.lShiftByte128(8); // X,X,A2,A3

                        temp2 = ModIntX8::template blend<0b11001100>(temp2, temp1); // A0,A1,A2,A3
                        temp3 = ModIntX8::template blend<0b11001100>(temp0, temp3); // B0,B1,B2,B3

                        // Final adjacent-pair butterflies.
                        A = transform2X4(temp2); // A
                        B = transform2X4(temp3); // B
                    }
                    // Two independent 8-point DIFs (on A and on B); the stages of dit8X2 in reverse order.
                    static void dif8X2(ModIntX8 &A, ModIntX8 &B)
                    {
                        constexpr ModIntType W_8_1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
                        constexpr ModIntType W_8_2 = qpow(W_8_1, 2);
                        constexpr ModIntType W_8_3 = qpow(W_8_1, 3);
                        alignas(32) constexpr ModIntType w_arr[8]{ModIntType(1), W_8_1, W_8_2, W_8_3, ModIntType(1), W_8_1, W_8_2, W_8_3};

                        ModIntX8 temp0, temp1, temp2, temp3, omega;
                        temp0 = ModIntX8::template permute2X128<0x20>(A, B); // A0,B0
                        temp1 = ModIntX8::template permute2X128<0x31>(A, B); // A1,B1
                        temp2 = temp0 + temp1;                               // A0,B0
                        temp3 = temp0 - temp1;                               // A1,B1
                        omega.load(w_arr);
                        temp3 = temp3 * omega;
                        A = ModIntX8::template permute2X128<0x20>(temp2, temp3); // A0,A1
                        B = ModIntX8::template permute2X128<0x31>(temp2, temp3); // B0,B1
                        dif4X4(A, B);
                    }
                    // 16-point DIF over temp0 (elements 0..7) and temp1 (8..15): radix-2 butterfly,
                    // twiddle temp1 by W_16^0..7, then an 8-point DIF on each register.
                    static void dif(ModIntX8 &temp0, ModIntX8 &temp1)
                    {
                        alignas(32) constexpr ModIntType w_arr[8]{ModIntType(1), W_16_1, W_16_2, W_16_3, W_16_4, W_16_5, W_16_6, W_16_7};
                        ModIntX8 omega;
                        omega.load(w_arr);
                        transform2(temp0, temp1);
                        temp1 = temp1 * omega;
                        dif8X2(temp0, temp1);
                    }
                    // In-place 16-point DIF on in_out[0..15].
                    static void dif(ModIntType in_out[])
                    {
                        ModIntX8 temp0, temp1;
                        temp0.load(&in_out[0]), temp1.load(&in_out[8]);
                        dif(temp0, temp1);
                        temp0.store(&in_out[0]);
                        temp1.store(&in_out[8]);
                    }
                    // Length dispatch: fewer than 16 elements falls back to the 8-point specialization.
                    static void dit(ModIntType in_out[], size_t len)
                    {
                        if (len < 16)
                        {
                            NTTShort<8, MOD, ROOT>::dit(in_out, len);
                            return;
                        }
                        dit(in_out);
                    }
                    // Length dispatch: fewer than 16 elements falls back to the 8-point specialization.
                    static void dif(ModIntType in_out[], size_t len)
                    {
                        if (len < 16)
                        {
                            NTTShort<8, MOD, ROOT>::dif(in_out, len);
                            return;
                        }
                        dif(in_out);
                    }
                };
                // Fixed-size 32-point NTT held in four MontInt32X8 registers. Split-radix layout:
                // a 16-point transform on elements 0..15 plus two 8-point transforms on 16..23 and
                // 24..31, combined with twiddles w^j and w^(3j), where w = ROOT^((mod - 1) / 32).
                template <uint32_t MOD, uint32_t ROOT>
                struct NTTShort<32, MOD, ROOT>
                {
                    using ModIntX8 = MontInt32X8<MOD>;
                    using ModIntType = typename ModIntX8::MontInt;

                    using NTT16 = NTTShort<16, MOD, ROOT>;
                    using TableType = AlignAry<ModIntType, 8>;

                    // Returns the 8 consecutive powers r^0 .. r^7, where r = ROOT^(((mod - 1) / 32) * factor).
                    static constexpr TableType getNTTTable(int factor)
                    {
                        ModIntType root = qpow(ModIntType(ROOT), ((ModIntType::mod() - 1) / 32) * factor);
                        ModIntType omega(1);
                        TableType res;
                        for (auto &&i : res)
                        {
                            i = omega;
                            omega *= root;
                        }
                        return res;
                    }

                    // In-place 32-point DIT on in_out[0..31].
                    static void dit(ModIntType in_out[])
                    {
                        constexpr TableType w1_arr = getNTTTable(1);
                        constexpr TableType w3_arr = getNTTTable(3);
                        ModIntX8 temp0, temp1, temp2, temp3, omega1, omega3;
                        omega1.load(w1_arr.data());
                        omega3.load(w3_arr.data());
                        temp0.load(&in_out[0]);
                        temp1.load(&in_out[8]);
                        temp2.load(&in_out[16]);
                        temp3.load(&in_out[24]);
                        // Sub-transforms first: 16-point on temp0/temp1, two 8-point on temp2 and temp3.
                        NTT16::dit(temp0, temp1);
                        NTT16::dit8X2(temp2, temp3);
                        {
                            // Split-radix combine: twiddle the quarter-length parts by w^j and w^(3j)...
                            temp2 = temp2 * omega1;
                            temp3 = temp3 * omega3;

                            // ...butterfly them together, rotate the difference by W_4^1...
                            transform2(temp2, temp3);
                            constexpr ModIntType W_4_1 = qpow(ModIntType(ROOT), (MOD - 1) / 4);
                            temp3 = temp3 * ModIntX8(W_4_1);

                            // ...and merge them into the half-length part.
                            transform2(temp0, temp2);
                            transform2(temp1, temp3);
                        }
                        temp0.store(&in_out[0]);
                        temp1.store(&in_out[8]);
                        temp2.store(&in_out[16]);
                        temp3.store(&in_out[24]);
                    }
                    // In-place 32-point DIF on in_out[0..31]; the steps of dit in reverse order.
                    static void dif(ModIntType in_out[])
                    {
                        constexpr TableType w1_arr = getNTTTable(1);
                        constexpr TableType w3_arr = getNTTTable(3);
                        ModIntX8 temp0, temp1, temp2, temp3, omega1, omega3;
                        omega1.load(w1_arr.data());
                        omega3.load(w3_arr.data());
                        temp0.load(&in_out[0]);
                        temp1.load(&in_out[8]);
                        temp2.load(&in_out[16]);
                        temp3.load(&in_out[24]);
                        {
                            // Split the half-length part off the two quarter-length parts.
                            transform2(temp0, temp2);
                            transform2(temp1, temp3);

                            constexpr ModIntType W_4_1 = qpow(ModIntType(ROOT), (MOD - 1) / 4);
                            temp3 = temp3 * ModIntX8(W_4_1);
                            transform2(temp2, temp3);

                            // Twiddle the quarter-length parts by w^j and w^(3j).
                            temp2 = temp2 * omega1;
                            temp3 = temp3 * omega3;
                        }
                        // Then the sub-transforms: 16-point on temp0/temp1, two 8-point on temp2 and temp3.
                        NTT16::dif(temp0, temp1);
                        NTT16::dif8X2(temp2, temp3);
                        temp0.store(&in_out[0]);
                        temp1.store(&in_out[8]);
                        temp2.store(&in_out[16]);
                        temp3.store(&in_out[24]);
                    }
                    // Length dispatch: fewer than 32 elements falls back to the 16-point specialization.
                    static void dit(ModIntType in_out[], size_t len)
                    {
                        if (len < 32)
                        {
                            NTT16::dit(in_out, len);
                            return;
                        }
                        dit(in_out);
                    }
                    // Length dispatch: fewer than 32 elements falls back to the 16-point specialization.
                    static void dif(ModIntType in_out[], size_t len)
                    {
                        if (len < 32)
                        {
                            NTT16::dif(in_out, len);
                            return;
                        }
                        dif(in_out);
                    }
                };

                template <uint32_t MOD, uint32_t ROOT>
                struct NTT
                {
                    static constexpr uint32_t mod()
                    {
                        return MOD;
                    }
                    static constexpr uint32_t root()
                    {
                        return ROOT;
                    }
                    static constexpr uint32_t iroot()
                    {
                        return mod_inv<int64_t>(root(), mod());
                    }
                    static constexpr bool selfCheck()
                    {
                        uint64_t n = root();
                        n *= uint64_t(iroot());
                        n %= uint64_t(mod());
                        return n == uint64_t(1);
                    }
                    static_assert(root() < mod(), "ROOT must be smaller than MOD");
                    static_assert(selfCheck(), "IROOT * ROOT % MOD must be 1");
                    static constexpr int mod_bits = hint_log2(mod()) + 1;
                    static constexpr int max_log_len = hint_ctz(mod() - 1);

                    static constexpr size_t getMaxLen()
                    {
                        if (max_log_len < sizeof(size_t) * CHAR_BIT)
                        {
                            return size_t(1) << max_log_len;
                        }
                        return size_t(1) << (sizeof(size_t) * CHAR_BIT - 1);
                    }
                    static constexpr size_t ntt_max_len = getMaxLen();

                    using INTT = NTT<mod(), iroot()>;

                    static constexpr size_t LONG_THRESHOLD = size_t(1) << 12;

                    using NTTTemplate = NTTShort<LONG_THRESHOLD, MOD, ROOT>;
                    using ModIntType = typename NTTTemplate::ModIntType;
                    using ModIntX8 = typename NTTTemplate::ModIntX8;

                    static constexpr ModIntType W_4_1 = qpow(ModIntType(root()), (mod() - 1) / 4);
                    static constexpr ModIntType W_8_1 = qpow(ModIntType(root()), (mod() - 1) / 8);
                    static constexpr ModIntType W_8_3 = qpow(W_8_1, 3);

                    static ModIntX8 unitx8(size_t ntt_len, int factor)
                    {
                        return ModIntX8(qpow(ModIntType(root()), (mod() - 1) / ntt_len * factor * 8));
                    }
                    static ModIntX8 omegax8(size_t ntt_len, int factor)
                    {
                        alignas(32) ModIntType w_arr[8]{};
                        ModIntType w(1), unit(qpow(ModIntType(root()), (mod() - 1) / ntt_len * factor));
                        for (auto &&i : w_arr)
                        {
                            i = w;
                            w = w * unit;
                        }
                        return ModIntX8(w_arr);
                    }
                    static void dit(ModIntType in_out[], size_t ntt_len)
                    {
                        ntt_len = std::min(int_floor2(ntt_len), ntt_max_len);
                        if (ntt_len <= LONG_THRESHOLD)
                        {
                            NTTTemplate::dit(in_out, ntt_len);
                            return;
                        }
                        size_t octant_len = ntt_len / 8;
                        dit(in_out + octant_len * 7, ntt_len / 8);
                        dit(in_out + octant_len * 6, ntt_len / 8);
                        dit(in_out + octant_len * 4, ntt_len / 4);
                        dit(in_out, ntt_len / 2);
                        const ModIntX8 unit1_x8 = unitx8(ntt_len, 1), unit3_x8 = unitx8(ntt_len, 3), unit7_x8 = unitx8(ntt_len, 7);
                        ModIntX8 omega1 = omegax8(ntt_len, 1), omega3 = omegax8(ntt_len, 3), omega7 = omegax8(ntt_len, 7);
                        for (auto it = in_out; it < in_out + octant_len; it += 8)
                        {
                            {
                                ModIntX8 temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
                                temp0.load(&it[0]);
                                temp1.load(&it[octant_len]);
                                temp2.load(&it[octant_len * 2]);
                                temp3.load(&it[octant_len * 3]);
                                temp4.load(&it[octant_len * 4]);
                                temp5.load(&it[octant_len * 5]);
                                temp6.load(&it[octant_len * 6]);
                                temp7.load(&it[octant_len * 7]);
                                temp4 = temp4 * omega1;
                                temp5 = temp5 * omega1;
                                temp6 = temp6 * omega3;
                                temp7 = temp7 * omega7;

                                transform2(temp6, temp7);
                                transform2(temp4, temp6);
                                temp6 = temp6 * ModIntX8(W_4_1);
                                temp7 = temp7 * ModIntX8(W_4_1);
                                transform2(temp5, temp7);
                                temp5 = temp5 * ModIntX8(W_8_1);
                                temp7 = temp7 * ModIntX8(W_8_3);

                                (temp0 + temp4).store(&it[0]);
                                (temp1 + temp5).store(&it[octant_len]);
                                (temp2 + temp6).store(&it[octant_len * 2]);
                                (temp3 + temp7).store(&it[octant_len * 3]);
                                (temp0 - temp4).store(&it[octant_len * 4]);
                                (temp1 - temp5).store(&it[octant_len * 5]);
                                (temp2 - temp6).store(&it[octant_len * 6]);
                                (temp3 - temp7).store(&it[octant_len * 7]);
                            }
                            omega1 = omega1 * unit1_x8;
                            omega3 = omega3 * unit3_x8;
                            omega7 = omega7 * unit7_x8;
                        }
                    }
                    static void dif(ModIntType in_out[], size_t ntt_len)
                    {
                        ntt_len = std::min(int_floor2(ntt_len), ntt_max_len);
                        if (ntt_len <= LONG_THRESHOLD)
                        {
                            NTTTemplate::dif(in_out, ntt_len);
                            return;
                        }
                        size_t octant_len = ntt_len / 8;
                        const ModIntX8 unit1_x8 = unitx8(ntt_len, 1), unit3_x8 = unitx8(ntt_len, 3), unit7_x8 = unitx8(ntt_len, 7);
                        ModIntX8 omega1 = omegax8(ntt_len, 1), omega3 = omegax8(ntt_len, 3), omega7 = omegax8(ntt_len, 7);
                        for (auto it = in_out; it < in_out + octant_len; it += 8)
                        {
                            {
                                ModIntX8 temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
                                temp0.load(&it[0]);
                                temp1.load(&it[octant_len]);
                                temp2.load(&it[octant_len * 2]);
                                temp3.load(&it[octant_len * 3]);
                                temp4.load(&it[octant_len * 4]);
                                temp5.load(&it[octant_len * 5]);
                                temp6.load(&it[octant_len * 6]);
                                temp7.load(&it[octant_len * 7]);
                                transform2(temp0, temp4);
                                transform2(temp1, temp5);
                                transform2(temp2, temp6);
                                transform2(temp3, temp7);

                                temp5 = temp5 * ModIntX8(W_8_1);
                                temp7 = temp7 * ModIntX8(W_8_3);
                                transform2(temp5, temp7);
                                temp6 = temp6 * ModIntX8(W_4_1);
                                temp7 = temp7 * ModIntX8(W_4_1);
                                transform2(temp4, temp6);
                                transform2(temp6, temp7);

                                (temp0).store(&it[0]);
                                (temp1).store(&it[octant_len]);
                                (temp2).store(&it[octant_len * 2]);
                                (temp3).store(&it[octant_len * 3]);
                                (temp4 * omega1).store(&it[octant_len * 4]);
                                (temp5 * omega1).store(&it[octant_len * 5]);
                                (temp6 * omega3).store(&it[octant_len * 6]);
                                (temp7 * omega7).store(&it[octant_len * 7]);
                            }
                            omega1 = omega1 * unit1_x8;
                            omega3 = omega3 * unit3_x8;
                            omega7 = omega7 * unit7_x8;
                        }
                        dif(in_out, octant_len * 4);
                        dif(in_out + octant_len * 4, octant_len * 2);
                        dif(in_out + octant_len * 6, octant_len);
                        dif(in_out + octant_len * 7, octant_len);
                    }
                    static void convolution(ModIntType in1[], ModIntType in2[], ModIntType out[], size_t ntt_len)
                    {
                        const ModIntType inv_len(qpow(ModIntType(ntt_len), mod() - 2));
                        dif(in1, ntt_len);
                        dif(in2, ntt_len);
                        if (ntt_len < 8)
                        {
                            for (size_t i = 0; i < ntt_len; i++)
                            {
                                out[i] = in1[i] * in2[i] * inv_len;
                            }
                        }
                        else
                        {
                            const ModIntX8 inv8(inv_len);
                            for (size_t i = 0; i < ntt_len; i += 8)
                            {
                                ModIntX8 temp0, temp1;
                                temp0.load(&in1[i]), temp1.load(&in2[i]);
                                (temp0 * temp1 * inv8).store(&out[i]);
                            }
                        }
                        INTT::dit(out, ntt_len);
                    }
                };
            };
        }
    }
}

using namespace std;
using namespace hint;
using namespace hint_transform::hint_ntt::split_radix_avx;

void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c)
{
    size_t conv_len = m + n + 1, ntt_len = int_ceil2(conv_len);
    using ntt = NTT<998244353, 3>;
    using ModInt = ntt::ModIntType;
    ModInt *a_ntt = (ModInt *)_mm_malloc(ntt_len * sizeof(ModInt), 32);
    ModInt *b_ntt = (ModInt *)_mm_malloc(ntt_len * sizeof(ModInt), 32);
    std::copy(a, a + n, a_ntt);
    std::copy(b, a + m, b_ntt);
    ntt::convolution(a_ntt, b_ntt, a_ntt, ntt_len);
    std::copy(a_ntt, a_ntt + conv_len, c);
    _mm_free(a_ntt);
    _mm_free(b_ntt);
}

Compilation | N/A | N/A | Compile Error | Score: N/A


Judge Duck Online | 评测鸭在线
Server Time: 2024-05-09 15:15:38 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用