提交记录 27922


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1002i. 【模板题】多项式乘法 Accepted 100 6.927 ms 5220 KB C++14 88.70 KB
提交时间 评测时间
2025-02-12 14:12:47 2025-02-12 14:12:53
// https://github.com/With-Sky/

#include <vector>
#include <complex>
#include <iostream>
#include <cassert>
#include <cstring>
#include <ctime>
#include <cstddef>
#include <cstdint>
#include <climits>
#include <string>
#include <array>
#include <fstream>
#include <type_traits>
#include <immintrin.h>
#pragma GCC optimize("inline")
#pragma GCC target("avx2")

namespace hint
{
    // Short aliases for the scalar and complex floating-point types used below.
    using Float32 = float;
    using Float64 = double;
    using Complex32 = std::complex<Float32>;
    using Complex64 = std::complex<Float64>;
    // Cache-size tuning constants in bytes (defaults: 2^14 and 2^16).
    constexpr size_t L1_BYTE = size_t(1) << 14; // L1 cache size, change this if you know your cache size.
    constexpr size_t L2_BYTE = size_t(1) << 16; // L2 cache size, change this if you know your cache size.

    // Pi and 2 * Pi as Float64 constants.
    constexpr Float64 HINT_PI = 3.141592653589793238462643;
    constexpr Float64 HINT_2PI = HINT_PI * 2;
    // Returns a value of type T whose lowest `bits` bits are set, i.e. 2^bits - 1.
    // `bits` may be as large as the width of T; a non-positive `bits` yields 0
    // (previously this shifted by a negative amount, which is undefined).
    template <typename T>
    constexpr T all_one(int bits)
    {
        if (bits <= 0)
        {
            return T(0);
        }
        // Build the mask as (2^(bits-1) - 1) + 2^(bits-1) so that a full-width
        // request never shifts by the whole width of T.
        T temp = T(1) << (bits - 1);
        return temp - 1 + temp;
    }

    // Counts the leading zero bits of x within the width of IntTy.
    // Returns the full bit width for x == 0. Intended for unsigned types.
    //
    // The search halves the remaining window at every step. Types wider than
    // 32 bits first fold their upper half down, so 64-bit types that do not
    // match the dedicated uint64_t overload are still handled correctly.
    template <typename IntTy>
    constexpr int hint_clz(IntTy x)
    {
        constexpr uint32_t MASK32 = uint32_t(0xFFFF) << 16;
        int res = sizeof(IntTy) * CHAR_BIT;
        // Two 16-bit shifts instead of one 32-bit shift keep the expression
        // well-defined when IntTy is 32 bits wide or narrower.
        if (sizeof(IntTy) > sizeof(uint32_t) && ((x >> 16) >> 16) != 0)
        {
            res -= 32;
            x = (x >> 16) >> 16;
        }
        if (x & MASK32)
        {
            res -= 16;
            x >>= 16;
        }
        if (x & (MASK32 >> 8))
        {
            res -= 8;
            x >>= 8;
        }
        if (x & (MASK32 >> 12))
        {
            res -= 4;
            x >>= 4;
        }
        if (x & (MASK32 >> 14))
        {
            res -= 2;
            x >>= 2;
        }
        if (x & (MASK32 >> 15))
        {
            res -= 1;
            x >>= 1;
        }
        // x is now 0 or 1; a remaining 1 accounts for the last significant bit.
        return res - x;
    }
    // Leading zeros of a 64-bit value, derived from its two 32-bit halves.
    constexpr int hint_clz(uint64_t x)
    {
        const uint32_t upper = uint32_t(x >> 32);
        // A non-zero upper half decides the answer alone; otherwise all 32
        // upper bits are zero and the lower half supplies the remainder.
        return upper != 0 ? hint_clz(upper) : hint_clz(uint32_t(x)) + 32;
    }

    // Number of bits needed to represent x: 0 for x == 0, otherwise the
    // width of IntTy minus the count of leading zeros.
    template <typename IntTy>
    constexpr int hint_bit_length(IntTy x)
    {
        return x == 0 ? 0 : int(sizeof(IntTy) * CHAR_BIT) - hint_clz(x);
    }

    // Floor of log2(x): the index of the highest set bit of x.
    template <typename IntTy>
    constexpr int hint_log2(IntTy x)
    {
        constexpr int top_bit = int(sizeof(IntTy) * CHAR_BIT) - 1;
        return top_bit - hint_clz(x);
    }

    // Counts the trailing zero bits of a 32-bit value.
    // The lowest set bit is isolated first; each mask then tests whether that
    // bit lies in the "lower" half of a 16/8/4/2/1-bit grouping.
    // Note: an input of 0 yields 31.
    constexpr int hint_ctz(uint32_t x)
    {
        const uint32_t lowest = x & (0 - x);
        int count = 31;
        count -= (lowest & 0x0000FFFF) ? 16 : 0;
        count -= (lowest & 0x00FF00FF) ? 8 : 0;
        count -= (lowest & 0x0F0F0F0F) ? 4 : 0;
        count -= (lowest & 0x33333333) ? 2 : 0;
        count -= (lowest & 0x55555555) ? 1 : 0;
        return count;
    }

    // Trailing zeros of a 64-bit value, derived from its two 32-bit halves.
    constexpr int hint_ctz(uint64_t x)
    {
        const uint32_t lower = uint32_t(x);
        // A non-zero lower half decides the answer alone; otherwise skip its
        // 32 zero bits and count within the upper half.
        return lower != 0 ? hint_ctz(lower) : hint_ctz(uint32_t(x >> 32)) + 32;
    }

    // Exponentiation by squaring: returns m raised to the power n.
    // A non-positive n yields T(1).
    template <typename T, typename T1>
    constexpr T qpow(T m, T1 n)
    {
        T acc = 1;
        for (; n > 0; n >>= 1)
        {
            // Multiply in the current power of m for every set bit of n.
            if ((n & 1) != 0)
            {
                acc *= m;
            }
            m *= m;
        }
        return acc;
    }

    // Modular exponentiation by squaring: returns m^n reduced modulo `mod`.
    //
    // The base is reduced before the first squaring, so a base larger than
    // sqrt(max(T)) no longer overflows in `m *= m`. A modulus of 1 yields 0.
    // Products of two residues must still fit in T, i.e. mod^2 <= max(T).
    template <typename T, typename T1>
    constexpr T qpow(T m, T1 n, T mod)
    {
        T result = (mod == T(1)) ? T(0) : T(1);
        if (n > 0)
        {
            m %= mod;
        }
        while (n > 0)
        {
            if ((n & 1) != 0)
            {
                result *= m;
                result %= mod;
            }
            m *= m;
            m %= mod;
            n >>= 1;
        }
        return result;
    }

    // Returns the closest power of two that is not larger than n.
    // All bits below the highest set bit are filled in, then the result is
    // halved and incremented. An input of 0 yields 1.
    template <typename T>
    constexpr T int_floor2(T n)
    {
        constexpr int bits = sizeof(n) * CHAR_BIT;
        int shift = 1;
        while (shift < bits)
        {
            n |= (n >> shift);
            shift *= 2;
        }
        return (n >> 1) + 1;
    }

    // Returns the closest power of two that is not smaller than n.
    // n - 1 has every bit below its highest set bit filled in, and adding one
    // then rounds up to the next power of two.
    template <typename T>
    constexpr T int_ceil2(T n)
    {
        constexpr int bits = sizeof(n) * CHAR_BIT;
        n--;
        int shift = 1;
        while (shift < bits)
        {
            n |= (n >> shift);
            shift *= 2;
        }
        return n + 1;
    }

    // Half adder: returns x + y (wrapping) and sets cf when the sum overflowed.
    template <typename UintTy>
    constexpr UintTy add_half(UintTy x, UintTy y, bool &cf)
    {
        const UintTy sum = x + y;
        // A wrapped sum is strictly smaller than either operand.
        cf = (sum < y);
        return sum;
    }

    // Half subtractor: returns x - y (wrapping) and sets bf when a borrow occurred.
    template <typename UintTy>
    constexpr UintTy sub_half(UintTy x, UintTy y, bool &bf)
    {
        const UintTy diff = x - y;
        // A wrapped difference is strictly larger than the minuend.
        bf = (diff > x);
        return diff;
    }

    // Full adder: returns x + y + cf (wrapping). On return cf holds the
    // outgoing carry, which is set if either of the two additions wrapped.
    template <typename UintTy>
    constexpr UintTy add_carry(UintTy x, UintTy y, bool &cf)
    {
        const UintTy with_carry = x + cf;
        const bool carry_from_cf = (with_carry < x);
        const UintTy total = with_carry + y;
        cf = carry_from_cf || (total < y);
        return total;
    }

    // Full subtractor: returns x - y - bf (wrapping). On return bf holds the
    // outgoing borrow, which is set if either of the two subtractions wrapped.
    template <typename UintTy>
    constexpr UintTy sub_borrow(UintTy x, UintTy y, bool &bf)
    {
        const UintTy without_borrow = x - bf;
        const bool borrow_from_bf = (without_borrow > x);
        const UintTy result = without_borrow - y;
        bf = borrow_from_bf || (result > without_borrow);
        return result;
    }

    // Extended Euclidean algorithm.
    // Returns gcd(a, b) and stores in x, y a pair of coefficients satisfying
    // a * x + b * y == gcd(a, b).
    template <typename IntTy>
    constexpr IntTy exgcd(IntTy a, IntTy b, IntTy &x, IntTy &y)
    {
        if (b == 0)
        {
            // Base case: gcd(a, 0) = a = a * 1 + 0 * 0.
            x = 1;
            y = 0;
            return a;
        }
        const IntTy quotient = a / b;
        const IntTy remainder = a - quotient * b;
        // Recurse with the coefficient roles swapped, then correct y for the
        // quotient that was removed from a.
        const IntTy gcd = exgcd(b, remainder, y, x);
        y -= quotient * x;
        return gcd;
    }

    // Returns the inverse of n modulo `mod`, computed with exgcd.
    // The Bezout coefficient is shifted by one multiple of `mod` when it
    // falls outside [0, mod).
    template <typename IntTy>
    constexpr IntTy mod_inv(IntTy n, IntTy mod)
    {
        n %= mod;
        IntTy coef = 0, other = 0;
        exgcd(n, mod, coef, other);
        if (coef < 0)
        {
            return coef + mod;
        }
        if (coef >= mod)
        {
            return coef - mod;
        }
        return coef;
    }

    // Returns n^-1 mod 2^pow using Newton iteration.
    //
    // Each step computes x <- x * (2 - n * x), which doubles the number of
    // correct low bits, so the loop ends once n * x == 1 (mod 2^pow).
    // Only odd n is invertible modulo a power of two: for an even n, or for
    // pow <= 0, the function returns 0 (previously an even n never terminated).
    // pow values of 64 or more use the full 64-bit width.
    constexpr uint64_t inv_mod2pow(uint64_t n, int pow)
    {
        if ((n & 1) == 0 || pow <= 0)
        {
            return 0;
        }
        // Low `pow` bits set; written without a shift by the full word width.
        const uint64_t mask = pow >= 64 ? ~uint64_t(0) : (uint64_t(1) << pow) - 1;
        uint64_t xn = 1, t = n & mask;
        while (t != 1)
        {
            xn = (xn * (2 - t));
            t = (xn * n) & mask;
        }
        return xn & mask;
    }

    namespace modint
    {
        //  Montgomery-form modular integer for a compile-time modulus below 2^30.
        //  The Montgomery radix is R = 2^32. `data` holds the Montgomery
        //  representation; the "lazy" operations reduce it only partially
        //  (see each member for the correction it applies).
        template <uint32_t MOD>
        struct MontInt32Lazy
        {
            static_assert(hint_log2(MOD) < 30, "MOD can't be larger than 30 bits");
            uint32_t data;
            using IntType = uint32_t;

            // Zero value.
            constexpr MontInt32Lazy() : data(0) {}
            // Converts an ordinary integer into Montgomery form.
            constexpr MontInt32Lazy(uint32_t n) : data(toMont(n)) {}

            // Raw sum followed by one conditional subtraction of 2 * MOD.
            constexpr MontInt32Lazy operator+(MontInt32Lazy rhs) const
            {
                return add(rhs).largeNorm();
            }
            // Wrapping difference; 2 * MOD is added back when it wrapped.
            constexpr MontInt32Lazy operator-(MontInt32Lazy rhs) const
            {
                const uint32_t diff = data - rhs.data;
                rhs.data = diff > data ? diff + mod2() : diff;
                return rhs;
            }
            // Montgomery product using the lazy reduction (no final subtraction).
            constexpr MontInt32Lazy operator*(MontInt32Lazy rhs) const
            {
                MontInt32Lazy product;
                product.data = redcLazy(uint64_t(data) * rhs.data);
                return product;
            }
            constexpr MontInt32Lazy &operator+=(const MontInt32Lazy &rhs)
            {
                *this = *this + rhs;
                return *this;
            }
            constexpr MontInt32Lazy &operator-=(const MontInt32Lazy &rhs)
            {
                *this = *this - rhs;
                return *this;
            }
            // Unlike operator*, the compound form applies the full reduction.
            constexpr MontInt32Lazy &operator*=(const MontInt32Lazy &rhs)
            {
                const uint64_t wide = uint64_t(data) * rhs.data;
                data = redc(wide);
                return *this;
            }
            // One conditional subtraction of MOD.
            constexpr MontInt32Lazy norm() const
            {
                MontInt32Lazy res;
                res.data = data;
                if (res.data >= mod())
                {
                    res.data -= mod();
                }
                return res;
            }
            // One conditional subtraction of 2 * MOD.
            constexpr MontInt32Lazy largeNorm() const
            {
                MontInt32Lazy res;
                res.data = data;
                if (res.data >= mod2())
                {
                    res.data -= mod2();
                }
                return res;
            }
            // Raw sum without any reduction.
            constexpr MontInt32Lazy add(MontInt32Lazy rhs) const
            {
                MontInt32Lazy sum;
                sum.data = data + rhs.data;
                return sum;
            }
            // Raw difference, unconditionally offset by 2 * MOD.
            constexpr MontInt32Lazy sub(MontInt32Lazy rhs) const
            {
                MontInt32Lazy diff;
                diff.data = data - rhs.data + mod2();
                return diff;
            }
            // Converts back from Montgomery form to an ordinary integer.
            constexpr operator uint32_t() const
            {
                return toInt(data);
            }
            // Multiplicative inverse computed as this^(MOD - 2).
            constexpr MontInt32Lazy inv() const
            {
                return qpow(*this, mod() - 2);
            }
            static constexpr uint32_t mod()
            {
                return MOD;
            }
            static constexpr uint32_t mod2()
            {
                return MOD * 2;
            }
            // MOD^-1 modulo 2^32, evaluated at compile time.
            static constexpr uint32_t modInv()
            {
                constexpr uint32_t inverse = uint32_t(inv_mod2pow(mod(), 32));
                return inverse;
            }
            // -(MOD^-1) modulo 2^32, evaluated at compile time.
            static constexpr uint32_t modNegInv()
            {
                constexpr uint32_t negated = uint32_t(0 - modInv());
                return negated;
            }
            static_assert((mod() * modInv()) == 1, "mod_inv not correct");

            // n * 2^32 mod MOD.
            static constexpr uint32_t toMont(uint32_t n)
            {
                const uint64_t scaled = uint64_t(n) << 32;
                return scaled % MOD;
            }
            // Fully reduced Montgomery reduction of a single word.
            static constexpr uint32_t toInt(uint32_t n)
            {
                return redc(n);
            }

            // Montgomery reduction without the final conditional subtraction:
            // adds the multiple of MOD that clears the low 32 bits of n, then
            // keeps the high word.
            static constexpr uint32_t redcLazy(uint64_t n)
            {
                const uint32_t factor = uint32_t(n) * modNegInv();
                const uint64_t cleared = uint64_t(factor) * mod() + n;
                return cleared >> 32;
            }
            // Montgomery reduction with one conditional subtraction of MOD.
            static constexpr uint32_t redc(uint64_t n)
            {
                const uint32_t lazy = redcLazy(n);
                return lazy >= mod() ? lazy - mod() : lazy;
            }
            // Applies one Montgomery reduction to the stored word.
            constexpr MontInt32Lazy divR() const
            {
                MontInt32Lazy res;
                res.data = redc(data);
                return res;
            }
            // Multiplies the stored word by 2^32 modulo MOD.
            constexpr MontInt32Lazy mulR() const
            {
                MontInt32Lazy res;
                res.data = toMont(data);
                return res;
            }
        };
    }

    namespace simd
    {
        // Eight 32-bit Montgomery residues packed into one 256-bit AVX2 register.
        // MontInt32Type supplies the modulus through mod() / modNegInv() and the
        // scalar representation through its `data` member.
        template <typename MontInt32Type>
        struct MontInt32X8
        {
            using MontInt = MontInt32Type;
            using Int32X8 = __m256i;
            __m256i data;

            // All lanes zero.
            MontInt32X8() : data(_mm256_setzero_si256()) {}
            // Broadcasts the stored word of one scalar residue to every lane.
            MontInt32X8(MontInt x) : data(_mm256_set1_epi32(x.data)) {}
            // Converts eight ordinary integers to Montgomery form via toMont().
            MontInt32X8(Int32X8 n) : data(toMont(n)) {}
            // Unaligned load of 32 bytes from p; the values are taken as-is.
            template <typename T>
            MontInt32X8(const T *p)
            {
                loadu(p);
            }

            // Lane-wise sum followed by one conditional subtraction of 2 * mod.
            MontInt32X8 operator+(MontInt32X8 rhs) const
            {
                rhs.data = _mm256_add_epi32(data, rhs.data);
                return rhs.largeNorm();
            }
            // Lane-wise difference; smallNorm() adds 2 * mod back to wrapped lanes.
            MontInt32X8 operator-(MontInt32X8 rhs) const
            {
                rhs.data = _mm256_sub_epi32(data, rhs.data);
                return rhs.smallNorm();
            }
            // Lane-wise Montgomery product using the lazy reduction.
            MontInt32X8 operator*(MontInt32X8 rhs) const
            {
                rhs.data = mulMontLazy(data, rhs.data);
                return rhs;
            }
            MontInt32X8 &operator+=(const MontInt32X8 &rhs)
            {
                return *this = *this + rhs;
            }
            MontInt32X8 &operator-=(const MontInt32X8 &rhs)
            {
                return *this = *this - rhs;
            }
            MontInt32X8 &operator*=(const MontInt32X8 &rhs)
            {
                return *this = *this * rhs;
            }

            // Raw lane-wise sum without any reduction.
            MontInt32X8 add(MontInt32X8 rhs) const
            {
                rhs.data = _mm256_add_epi32(data, rhs.data);
                return rhs;
            }
            // Raw lane-wise difference, unconditionally offset by 2 * mod.
            MontInt32X8 sub(MontInt32X8 rhs) const
            {
                rhs.data = _mm256_sub_epi32(data, rhs.data);
                rhs.data = _mm256_add_epi32(mod2X8(), rhs.data);
                return rhs;
            }

            // Montgomery reduction of eight 64-bit products without the final
            // conditional subtraction. even64 / odd64 each hold four 64-bit values;
            // the results are interleaved so that even64's land in the even 32-bit
            // lanes and odd64's in the odd ones.
            static Int32X8 montRedcLazy(Int32X8 even64, Int32X8 odd64)
            {
                Int32X8 prod0 = mul64(even64, modNX8());
                Int32X8 prod1 = mul64(odd64, modNX8());
                prod0 = mul64(prod0, modX8());
                prod1 = mul64(prod1, modX8());
                prod0 = rawAdd64(prod0, even64);
                prod1 = rawAdd64(prod1, odd64);
                // Move the high word of each even result down to the low 32 bits;
                // the odd results are already in the high 32 bits of their lanes.
                prod0 = rShift64<32>(prod0);
                return blend<0b10101010>(prod0, prod1);
            }
            // montRedcLazy() followed by one conditional subtraction of mod.
            static Int32X8 montRedc(Int32X8 even64, Int32X8 odd64)
            {
                MontInt32X8 res;
                res.data = montRedcLazy(even64, odd64);
                return res.norm().data;
            }

            // Lane-wise Montgomery product with the full reduction.
            static Int32X8 mulMont(Int32X8 lhs, Int32X8 rhs)
            {
                mul32X32To64(lhs, rhs, lhs, rhs);
                return montRedc(lhs, rhs);
            }
            // Lane-wise Montgomery product with the lazy reduction.
            static Int32X8 mulMontLazy(Int32X8 lhs, Int32X8 rhs)
            {
                mul32X32To64(lhs, rhs, lhs, rhs);
                return montRedcLazy(lhs, rhs);
            }

            // Widening multiply: `low` receives the 64-bit products of the even
            // 32-bit lanes, `high` those of the odd lanes.
            static void mul32X32To64(Int32X8 lhs, Int32X8 rhs, Int32X8 &low, Int32X8 &high)
            {
                low = mul64(lhs, rhs);
                high = mul64(rShift64<32>(lhs), rShift64<32>(rhs));
            }

            // Per lane: unsigned min(data, data - mod), i.e. subtracts mod from
            // every lane where that subtraction does not wrap.
            MontInt32X8 norm() const
            {
                MontInt32X8 dif;
                dif.data = rawSub(data, modX8());
                dif.data = minU32(data, dif.data);
                return dif;
            }
            // Per lane: unsigned min(data, data - 2 * mod).
            MontInt32X8 largeNorm() const
            {
                MontInt32X8 dif;
                dif.data = rawSub(data, mod2X8());
                dif.data = minU32(data, dif.data);
                return dif;
            }
            // Per lane: unsigned min(data, data + 2 * mod); repairs lanes that
            // wrapped below zero in a preceding subtraction.
            MontInt32X8 smallNorm() const
            {
                MontInt32X8 sum;
                sum.data = rawAdd(data, mod2X8());
                sum.data = minU32(data, sum.data);
                return sum;
            }
            // [a,b]->[0,a]
            MontInt32X8 lshift32In64() const
            {
                MontInt32X8 res;
                res.data = lShift64<32>(data);
                return res;
            }
            // [a,b]->[b,0]
            MontInt32X8 rshift32In64() const
            {
                MontInt32X8 res;
                res.data = rShift64<32>(data);
                return res;
            }
            // Per 32-bit lane i: takes b when bit i of N is set, otherwise a.
            template <int N>
            static MontInt32X8 blend(MontInt32X8 a, MontInt32X8 b)
            {
                a.data = _mm256_blend_epi32(a.data, b.data, N);
                return a;
            }
            // Selects 128-bit halves of a and b according to the control byte N.
            template <int N>
            static MontInt32X8 permute2X128(MontInt32X8 a, MontInt32X8 b)
            {
                a.data = _mm256_permute2x128_si256(a.data, b.data, N);
                return a;
            }
            // Shifts each 128-bit half left by N bytes.
            template <int N>
            MontInt32X8 lShiftByte128() const
            {
                MontInt32X8 res;
                res.data = _mm256_bslli_epi128(data, N);
                return res;
            }
            // Shifts each 128-bit half right by N bytes.
            template <int N>
            MontInt32X8 rShiftByte128() const
            {
                MontInt32X8 res;
                res.data = _mm256_bsrli_epi128(data, N);
                return res;
            }
            // Shifts each 128-bit half left by 64 bits.
            MontInt32X8 lshift64In128() const
            {
                return lShiftByte128<8>();
            }
            // Shifts each 128-bit half right by 64 bits.
            MontInt32X8 rshift64In128() const
            {
                return rShiftByte128<8>();
            }
            // even[a,b],odd[c,d]->[a,d]
            static MontInt32X8 cross32(const MontInt32X8 &even, const MontInt32X8 &odd)
            {
                return blend<0b10101010>(even, odd);
            }
            // even[a,b,c,d],odd[e,f,g,h]->[a,b,g,h]
            static MontInt32X8 cross64(const MontInt32X8 &even, const MontInt32X8 &odd)
            {
                return blend<0b11001100>(even, odd);
            }
            // lo[a,b],hi[c,d]->[a,c]
            static MontInt32X8 packLo128(const MontInt32X8 &lo, const MontInt32X8 &hi)
            {
                return permute2X128<0x20>(lo, hi);
            }
            // lo[a,b],hi[c,d]->[b,d]
            static MontInt32X8 packHi128(const MontInt32X8 &lo, const MontInt32X8 &hi)
            {
                return permute2X128<0x31>(lo, hi);
            }
            // Modulus of the underlying scalar type.
            static constexpr uint32_t mod()
            {
                return MontInt::mod();
            }
            // Negated modular inverse supplied by the underlying scalar type.
            static constexpr uint32_t modNegInv()
            {
                return MontInt::modNegInv();
            }
            // Broadcast constants: 0, mod, 2 * mod, modNegInv and R^2 mod mod.
            static Int32X8 zeroX8()
            {
                return _mm256_setzero_si256();
            }
            static Int32X8 modX8()
            {
                return _mm256_set1_epi32(mod());
            }
            static Int32X8 mod2X8()
            {
                constexpr uint32_t MOD2 = mod() * 2;
                return _mm256_set1_epi32(MOD2);
            }
            static Int32X8 modNX8()
            {
                constexpr uint32_t MOD_INV_NEG = modNegInv();
                return _mm256_set1_epi32(MOD_INV_NEG);
            }
            static Int32X8 r2X8()
            {
                // R = 2^32 mod mod, R2 = R^2 mod mod.
                constexpr uint32_t R = (uint64_t(1) << 32) % mod(), R2 = uint64_t(R) * R % mod();
                return _mm256_set1_epi32(R2);
            }

            // Multiplies the low 32 bits of each 64-bit lane, giving 64-bit products.
            static Int32X8 mul64(const Int32X8 &lhs, const Int32X8 &rhs)
            {
                return _mm256_mul_epu32(lhs, rhs);
            }
            // Logical shifts of each 64-bit lane by N bits.
            template <int N>
            static Int32X8 lShift64(const Int32X8 &n)
            {
                return _mm256_slli_epi64(n, N);
            }
            template <int N>
            static Int32X8 rShift64(const Int32X8 &n)
            {
                return _mm256_srli_epi64(n, N);
            }

            // Wrapping lane-wise arithmetic on 32-bit and 64-bit lanes.
            static Int32X8 rawAdd(const Int32X8 &lhs, const Int32X8 &rhs)
            {
                return _mm256_add_epi32(lhs, rhs);
            }
            static Int32X8 rawSub(const Int32X8 &lhs, const Int32X8 &rhs)
            {
                return _mm256_sub_epi32(lhs, rhs);
            }
            static Int32X8 rawAdd64(const Int32X8 &lhs, const Int32X8 &rhs)
            {
                return _mm256_add_epi64(lhs, rhs);
            }
            static Int32X8 rawSub64(const Int32X8 &lhs, const Int32X8 &rhs)
            {
                return _mm256_sub_epi64(lhs, rhs);
            }

            // Unsigned lane-wise maximum / minimum of 32-bit lanes.
            static Int32X8 maxU32(const Int32X8 &lhs, const Int32X8 &rhs)
            {
                return _mm256_max_epu32(lhs, rhs);
            }
            static Int32X8 minU32(const Int32X8 &lhs, const Int32X8 &rhs)
            {
                return _mm256_min_epu32(lhs, rhs);
            }

            // Raw-register variant of blend: lane i comes from b when bit i of N is set.
            template <int N>
            static Int32X8 blend(const Int32X8 &a, const Int32X8 &b)
            {
                return _mm256_blend_epi32(a, b, N);
            }

            // Converts ordinary integers to Montgomery form by multiplying with
            // R^2 under Montgomery reduction.
            static Int32X8 toMont(const Int32X8 &n)
            {
                return mulMont(r2X8(), n);
            }
            // Converts Montgomery-form lanes back to ordinary integers by reducing
            // each lane as a zero-extended 64-bit value.
            static Int32X8 toInt(const Int32X8 &n)
            {
                Int32X8 e = evenElements(n);
                Int32X8 o = rShift64<32>(n);
                return montRedc(e, o);
            }
            Int32X8 toInt() const
            {
                return toInt(data);
            }
            // a,b,c,d -> a,0,c,0
            static Int32X8 evenElements(const Int32X8 &n)
            {
                return blend<0b10101010>(n, zeroX8());
            }
            // a,b,c,d -> 0,b,0,d
            static Int32X8 oddElements(const Int32X8 &n)
            {
                return blend<0b01010101>(n, zeroX8());
            }

            // Broadcasts the raw value n to every lane (no Montgomery conversion).
            void set1(int32_t n)
            {
                data = _mm256_set1_epi32(n);
            }
            // Unaligned 32-byte load.
            template <typename T>
            void loadu(const T *p)
            {
                data = _mm256_loadu_si256((const Int32X8 *)p);
            }
            // Aligned 32-byte load through _mm256_load_si256.
            template <typename T>
            void load(const T *p)
            {
                data = _mm256_load_si256((const Int32X8 *)p);
            }
            // Masked load of 32-bit lanes selected by `mask`.
            template <typename T>
            void loadMask(const T *p, const Int32X8 &mask)
            {
                data = _mm256_maskload_epi32((const int *)p, mask);
            }
            // Loads only the first n lanes.
            // mask_arr is eight all-ones words followed by eight zero words, so
            // reading eight words at offset 8 - n gives n leading all-ones lanes.
            // NOTE(review): assumes 0 <= n <= 8; other values read outside mask_arr.
            template <typename T>
            void loadN(const T *p, int n)
            {
                constexpr uint32_t m = UINT32_MAX;
                constexpr uint32_t mask_arr[16]{m, m, m, m, m, m, m, m};
                Int32X8 mask = _mm256_loadu_si256((const Int32X8 *)(mask_arr + 8 - n));
                loadMask(p, mask);
            }
            // Unaligned 32-byte store.
            template <typename T>
            void storeu(T *p) const
            {
                _mm256_storeu_si256((__m256i *)p, data);
            }
            // Aligned 32-byte store through _mm256_store_si256.
            template <typename T>
            void store(T *p) const
            {
                _mm256_store_si256((__m256i *)p, data);
            }
            // Masked store of 32-bit lanes selected by `mask`.
            template <typename T>
            void storeMask(T *p, const __m256i &mask) const
            {
                _mm256_maskstore_epi32((int *)p, mask, data);
            }
            // Stores only the first n lanes; the mask is built as in loadN().
            // NOTE(review): assumes 0 <= n <= 8; other values read outside mask_arr.
            template <typename T>
            void storeN(T *p, int n) const
            {
                constexpr uint32_t m = UINT32_MAX;
                constexpr uint32_t mask_arr[16]{m, m, m, m, m, m, m, m};
                Int32X8 mask = _mm256_loadu_si256((const Int32X8 *)(mask_arr + 8 - n));
                storeMask(p, mask);
            }
            // Extracts the i-th 32-bit lane.
            // NOTE(review): the extract intrinsic presumably needs i to be a
            // compile-time constant after inlining - confirm for unoptimized builds.
            uint32_t nthU32(size_t i) const
            {
                return _mm256_extract_epi32(data, i);
            }
            // Extracts the i-th 64-bit lane (same caveat as nthU32).
            uint64_t nthU64(size_t i) const
            {
                return _mm256_extract_epi64(data, i);
            }
            // Debug helper: prints the eight 32-bit lanes to std::cout. With
            // norm == true the lanes are passed through largeNorm().norm() first.
            void printU32(bool norm = false) const
            {
                if (!norm)
                {
                    std::cout << "[" << nthU32(0) << "," << nthU32(1)
                              << "," << nthU32(2) << "," << nthU32(3)
                              << "," << nthU32(4) << "," << nthU32(5)
                              << "," << nthU32(6) << "," << nthU32(7) << "]" << std::endl;
                    return;
                }
                largeNorm().norm().printU32();
            }
            // Debug helper: prints the four 64-bit lanes to std::cout.
            void printU64() const
            {
                std::cout << "[" << nthU64(0) << "," << nthU64(1)
                          << "," << nthU64(2) << "," << nthU64(3) << "]" << std::endl;
            }
            // Debug helper: prints the lanes after converting them back with toInt().
            void printU32Int() const
            {
                MontInt32X8 res;
                res.data = toInt();
                res.printU32();
            }
        };

        // In-place transpose of an 8x8 matrix of 32-bit elements held in eight
        // 256-bit registers (row0..row7).
        //
        // Three stages, all performed through the float-typed shuffles:
        //   1. unpacklo/unpackhi interleave adjacent row pairs,
        //   2. shuffle_ps gathers 2x2 groups into 4-element column fragments
        //      within each 128-bit half,
        //   3. permute2f128 exchanges 128-bit halves between rows i and i + 4.
        //
        // YMM must be a 32-byte type convertible to and from __m256 with a
        // functional cast. The static_assert now carries a message, which C++14
        // requires (the message-less form is a C++17 feature).
        template <typename YMM>
        inline void transpose8x8(YMM &row0, YMM &row1, YMM &row2, YMM &row3, YMM &row4, YMM &row5, YMM &row6, YMM &row7)
        {
            using Type = YMM;
            static_assert(sizeof(Type) == 32, "transpose8x8 requires a 256-bit vector type");
            __m256 temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
            // Stage 1: interleave 32-bit elements of adjacent rows.
            temp0 = _mm256_unpacklo_ps(__m256(row0), __m256(row1));
            temp1 = _mm256_unpackhi_ps(__m256(row0), __m256(row1));
            temp2 = _mm256_unpacklo_ps(__m256(row2), __m256(row3));
            temp3 = _mm256_unpackhi_ps(__m256(row2), __m256(row3));
            temp4 = _mm256_unpacklo_ps(__m256(row4), __m256(row5));
            temp5 = _mm256_unpackhi_ps(__m256(row4), __m256(row5));
            temp6 = _mm256_unpacklo_ps(__m256(row6), __m256(row7));
            temp7 = _mm256_unpackhi_ps(__m256(row6), __m256(row7));
            // Stage 2: combine pairs of interleaved rows into 4-element groups.
            row0 = Type(_mm256_shuffle_ps(temp0, temp2, _MM_SHUFFLE(1, 0, 1, 0)));
            row1 = Type(_mm256_shuffle_ps(temp0, temp2, _MM_SHUFFLE(3, 2, 3, 2)));
            row2 = Type(_mm256_shuffle_ps(temp1, temp3, _MM_SHUFFLE(1, 0, 1, 0)));
            row3 = Type(_mm256_shuffle_ps(temp1, temp3, _MM_SHUFFLE(3, 2, 3, 2)));
            row4 = Type(_mm256_shuffle_ps(temp4, temp6, _MM_SHUFFLE(1, 0, 1, 0)));
            row5 = Type(_mm256_shuffle_ps(temp4, temp6, _MM_SHUFFLE(3, 2, 3, 2)));
            row6 = Type(_mm256_shuffle_ps(temp5, temp7, _MM_SHUFFLE(1, 0, 1, 0)));
            row7 = Type(_mm256_shuffle_ps(temp5, temp7, _MM_SHUFFLE(3, 2, 3, 2)));
            // Stage 3: 0x20 joins the low 128-bit halves, 0x31 the high halves.
            temp0 = _mm256_permute2f128_ps(__m256(row0), __m256(row4), 0x20);
            temp1 = _mm256_permute2f128_ps(__m256(row1), __m256(row5), 0x20);
            temp2 = _mm256_permute2f128_ps(__m256(row2), __m256(row6), 0x20);
            temp3 = _mm256_permute2f128_ps(__m256(row3), __m256(row7), 0x20);
            temp4 = _mm256_permute2f128_ps(__m256(row0), __m256(row4), 0x31);
            temp5 = _mm256_permute2f128_ps(__m256(row1), __m256(row5), 0x31);
            temp6 = _mm256_permute2f128_ps(__m256(row2), __m256(row6), 0x31);
            temp7 = _mm256_permute2f128_ps(__m256(row3), __m256(row7), 0x31);
            row0 = Type(temp0), row1 = Type(temp1), row2 = Type(temp2), row3 = Type(temp3);
            row4 = Type(temp4), row5 = Type(temp5), row6 = Type(temp6), row7 = Type(temp7);
        }
    }

    namespace transform
    {
        // Radix-2 butterfly without twiddle factor: replaces (sum, diff) with
        // (sum + diff, sum - diff). Both inputs are copied before either is
        // overwritten.
        template <typename T>
        constexpr void transform2(T &sum, T &diff)
        {
            T first = sum;
            T second = diff;
            sum = first + second;
            diff = first - second;
        }

        // Multi-mode, automatically typed, self-checking fast number-theoretic transform
        namespace ntt
        {
            using namespace modint;
            using namespace simd;
            // Returns true when n * n_inv == 1 (mod `mod`), with the arithmetic
            // carried out in IntType after both factors are reduced modulo `mod`.
            template <typename IntType>
            constexpr bool check_inv(uint64_t n, uint64_t n_inv, uint64_t mod)
            {
                IntType product(n % mod);
                product *= IntType(n_inv % mod);
                product %= IntType(mod);
                return product == IntType(1);
            }

            namespace radix2_avx
            {
                // Multiplies n by W_4_1 = ROOT^((mod - 1) / 4), evaluated at
                // compile time as a ModIntType constant.
                template <uint32_t ROOT, typename ModIntType, typename T>
                inline T mul_w41(const T &n)
                {
                    constexpr uint32_t exponent = (ModIntType::mod() - 1) / 4;
                    constexpr ModIntType W_4_1 = qpow(ModIntType(ROOT), exponent);
                    return n * T(W_4_1);
                }
                // Multiplies n by W_8_1 = ROOT^((mod - 1) / 8), evaluated at
                // compile time as a ModIntType constant.
                template <uint32_t ROOT, typename ModIntType, typename T>
                inline T mul_w81(const T &n)
                {
                    constexpr uint32_t exponent = (ModIntType::mod() - 1) / 8;
                    constexpr ModIntType W_8_1 = qpow(ModIntType(ROOT), exponent);
                    return n * T(W_8_1);
                }
                // Multiplies n by W_8_3 = ROOT^((mod - 1) / 8 * 3), evaluated at
                // compile time as a ModIntType constant.
                template <uint32_t ROOT, typename ModIntType, typename T>
                inline T mul_w83(const T &n)
                {
                    constexpr uint32_t exponent = (ModIntType::mod() - 1) / 8 * 3;
                    constexpr ModIntType W_8_3 = qpow(ModIntType(ROOT), exponent);
                    return n * T(W_8_3);
                }

                // Decimation-in-time butterfly: (a, b) -> (a + b * omega, a - b * omega).
                // in:  in_out0 < 4p, in_out1 < 4p
                // out: in_out0 < 4p, in_out1 < 4p
                // The first operand is passed through largeNorm() before the raw
                // add/sub; the product comes from operator*.
                template <typename ModIntType>
                inline void dit_butterfly2(ModIntType &in_out0, ModIntType &in_out1, ModIntType omega)
                {
                    auto upper = in_out0.largeNorm();
                    auto twiddled = in_out1 * omega;
                    in_out0 = upper.add(twiddled);
                    in_out1 = upper.sub(twiddled);
                }

                // Decimation-in-time butterfly, unnormalized-output variant
                // (selected by std::false_type).
                // in:  in_out0 < 2p, in_out1 < 4p
                // out: in_out0 < 4p, in_out1 < 4p
                // Uses the raw add/sub members, so no reduction follows the sum
                // or the difference.
                template <typename ModIntType>
                inline void dit_butterfly2_i24(ModIntType &in_out0, ModIntType &in_out1, ModIntType omega, std::false_type)
                {
                    auto upper = in_out0;
                    auto twiddled = in_out1 * omega;
                    in_out0 = upper.add(twiddled);
                    in_out1 = upper.sub(twiddled);
                }

                // Decimation-in-time butterfly, normalized-output variant
                // (selected by std::true_type).
                // in:  in_out0 < 2p, in_out1 < 4p
                // out: in_out0 < 2p, in_out1 < 2p
                // Uses operator+ / operator- instead of the raw add/sub members.
                template <typename ModIntType>
                inline void dit_butterfly2_i24(ModIntType &in_out0, ModIntType &in_out1, ModIntType omega, std::true_type)
                {
                    auto upper = in_out0;
                    auto twiddled = in_out1 * omega;
                    in_out0 = upper + twiddled;
                    in_out1 = upper - twiddled;
                }

                // Two fused decimation-in-time layers over four elements.
                // in:  in_out0 < 2p, in_out1 < 4p, in_out2 < 2p, in_out3 < 4p
                // out: every element < 2p when OUT2P is true, otherwise < 4p
                // Layer 1 pairs (0,1) and (2,3) with omega_last; layer 2 pairs
                // (0,2) with omega0 and (1,3) with omega1.
                template <bool OUT2P, typename ModIntType>
                inline void dit_butterfly2_i2424_2layer(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                                        ModIntType omega0, ModIntType omega1, ModIntType omega_last)
                {
                    using OutputTag = std::integral_constant<bool, OUT2P>;
                    // Layer 1: the first pair is normalized because it supplies the
                    // "< 2p" operands of layer 2; the second pair may stay below 4p.
                    dit_butterfly2_i24(in_out0, in_out1, omega_last, std::true_type{});
                    dit_butterfly2_i24(in_out2, in_out3, omega_last, std::false_type{});
                    // Layer 2: output range chosen by OUT2P.
                    dit_butterfly2_i24(in_out0, in_out2, omega0, OutputTag{});
                    dit_butterfly2_i24(in_out1, in_out3, omega1, OutputTag{});
                }

                // Two fused decimation-in-time layers where the second layer
                // finishes only three outputs: in_out0, in_out1 and in_out2.
                // in_out3 keeps its layer-1 value.
                // in:  in_out0 < 2p, in_out1 < 4p, in_out2 < 2p, in_out3 < 4p
                // out: in_out0 < 2p, in_out1 < 2p, in_out2 < 2p
                template <typename ModIntType>
                inline void dit_butterfly2_2layer_out3(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                                       ModIntType omega0, ModIntType omega1, ModIntType omega_last)
                {
                    // Layer 1 on pairs (0,1) and (2,3).
                    dit_butterfly2_i24(in_out0, in_out1, omega_last, std::true_type{});
                    dit_butterfly2_i24(in_out2, in_out3, omega_last, std::false_type{});
                    // Layer 2: full butterfly on (0,2), sum half only on (1,3).
                    dit_butterfly2_i24(in_out0, in_out2, omega0, std::true_type{});
                    auto twiddled = in_out3 * omega1;
                    in_out1 = in_out1 + twiddled;
                }

                // Two fused decimation-in-time layers where the second layer
                // finishes only two outputs: in_out0 and in_out1. in_out2 and
                // in_out3 keep their layer-1 values.
                // in:  in_out0 < 2p, in_out1 < 4p, in_out2 < 2p, in_out3 < 4p
                // out: in_out0 < 2p, in_out1 < 2p
                template <typename ModIntType>
                inline void dit_butterfly2_2layer_out2(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                                       ModIntType omega0, ModIntType omega1, ModIntType omega_last)
                {
                    // Layer 1 on pairs (0,1) and (2,3).
                    dit_butterfly2_i24(in_out0, in_out1, omega_last, std::true_type{});
                    dit_butterfly2_i24(in_out2, in_out3, omega_last, std::false_type{});
                    // Layer 2: only the sum halves of (0,2) and (1,3).
                    auto twiddled0 = in_out2 * omega0;
                    in_out0 = in_out0 + twiddled0;
                    auto twiddled1 = in_out3 * omega1;
                    in_out1 = in_out1 + twiddled1;
                }

                // Decimation-in-frequency butterfly: (a, b) -> (a + b, (a - b) * omega).
                // in:  in_out0 < 2p, in_out1 < 2p
                // out: in_out0 < 2p, in_out1 < 2p
                // The raw sum is passed through largeNorm(); the raw difference is
                // reduced by the multiplication.
                template <typename ModIntType>
                inline void dif_butterfly2(ModIntType &in_out0, ModIntType &in_out1, ModIntType omega)
                {
                    auto raw_sum = in_out0.add(in_out1);
                    auto raw_diff = in_out0.sub(in_out1);
                    in_out0 = raw_sum.largeNorm();
                    in_out1 = raw_diff * omega;
                }

                // Two fused decimation-in-frequency layers over four elements.
                // in:  every element < 2p
                // out: every element < 2p
                // Layer 1 pairs (0,2) with omega0 and (1,3) with omega1; layer 2
                // pairs (0,1) and (2,3) with omega_last.
                template <typename ModIntType>
                inline void dif_butterfly2_2layer(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                                  ModIntType omega0, ModIntType omega1, ModIntType omega_last)
                {
                    // Layer 1.
                    dif_butterfly2(in_out0, in_out2, omega0);
                    dif_butterfly2(in_out1, in_out3, omega1);
                    // Layer 2.
                    dif_butterfly2(in_out0, in_out1, omega_last);
                    dif_butterfly2(in_out2, in_out3, omega_last);
                }

                // Variant of the two-layer DIF step in which only in_out0 and in_out1 are read as
                // inputs: the incoming values of in_out2 and in_out3 are ignored and overwritten with
                // in_out0 * omega0 and in_out1 * omega1 before the second layer runs.
                // NOTE(review): this matches a first layer whose partner values are zero; confirm with callers.
                // in: in_out0<2p, in_out1<2p (range note from the original author)
                // out: in_out0<2p, in_out1<2p, in_out2<2p, in_out3<2p (range note from the original author)
                template <typename ModIntType>
                inline void dif_butterfly2_2layer_in2(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                                      ModIntType omega0, ModIntType omega1, ModIntType omega_last)
                {
                    // Layer 1: in_out0 / in_out1 stay as they are, the other half is the twiddled copy.
                    auto scaled0 = in_out0 * omega0;
                    auto scaled1 = in_out1 * omega1;
                    in_out2 = scaled0;
                    in_out3 = scaled1;

                    // Layer 2: regular butterflies inside each half.
                    dif_butterfly2(in_out0, in_out1, omega_last);
                    dif_butterfly2(in_out2, in_out3, omega_last);
                }

                // Variant of the two-layer DIF step in which only in_out0 is read as input.
                // in_out0 is left unchanged; the other three values are overwritten:
                //   in_out1 <- in_out0 * omega_last
                //   in_out2 <- in_out0 * omega0
                //   in_out3 <- in_out2 * omega_last
                // omega1 is accepted only to keep the signature parallel to the other variants.
                // in: in_out0<2p (range note from the original author)
                // out: in_out0<2p, in_out1<2p, in_out2<2p, in_out3<2p (range note from the original author)
                template <typename ModIntType>
                inline void dif_butterfly2_2layer_in1(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                                      ModIntType omega0, ModIntType omega1, ModIntType omega_last)
                {
                    (void)omega1; // intentionally unused
                    in_out1 = in_out0 * omega_last;
                    in_out2 = in_out0 * omega0;
                    in_out3 = in_out2 * omega_last;
                }

                // Returns an 8-lane vector holding the consecutive powers
                //   unit^begin, unit^(begin + 1), ..., unit^(begin + 7)
                // where unit = ROOT^((mod - 1) / ntt_len * factor), replaced by unit.inv() when `inv` is set.
                template <typename ModIntType, uint32_t ROOT>
                static auto omegax8(size_t ntt_len, int factor, size_t begin = 0, bool inv = false)
                {
                    using ModIntX8 = MontInt32X8<ModIntType>;
                    ModIntType step(qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / ntt_len * factor));
                    if (inv)
                    {
                        step = step.inv();
                    }
                    // Staging buffer for the eight lanes, 32-byte aligned.
                    alignas(32) ModIntType lanes[8]{};
                    ModIntType power(qpow(step, begin));
                    for (size_t k = 0; k < 8; ++k)
                    {
                        lanes[k] = power;
                        power = power * step;
                    }
                    return ModIntX8(lanes);
                }

                // LEN <= MAX_LEN
                template <size_t MAX_LEN, uint32_t ROOT, typename ModIntTy>
                struct NTTShort
                {
                    static constexpr size_t NTT_LEN = MAX_LEN;
                    static constexpr int LOG_LEN = hint_log2(NTT_LEN);

                    // Modulus of the underlying modular-integer type, forwarded from ModIntTy::mod().
                    static constexpr uint32_t mod()
                    {
                        return ModIntTy::mod();
                    }
                    // The ROOT template argument this transform was instantiated with.
                    static constexpr uint32_t root()
                    {
                        return ROOT;
                    }
                    // Result of mod_inv<int64_t>(ROOT, mod()), presumably the modular inverse of ROOT.
                    // The local constexpr variable forces the value to be computed at compile time.
                    static constexpr uint32_t rootInv()
                    {
                        constexpr uint32_t IROOT = mod_inv<int64_t>(ROOT, mod());
                        return IROOT;
                    }

                    // Scalar element type and its 8-lane vector counterpart.
                    using ModIntType = ModIntTy;
                    using ModIntX8 = MontInt32X8<ModIntType>;
                    // The same transform instantiated with rootInv() as its root; used for the inverse direction.
                    using INTT = NTTShort<MAX_LEN, rootInv(), ModIntTy>;

                    struct TableType
                    {
                        alignas(64) std::array<ModIntType, NTT_LEN> omega_table;
                        // Compute in compile time if need.
                        /*constexpr*/ TableType()
                        {
                            for (int omega_log_len = 4; omega_log_len <= LOG_LEN; omega_log_len++)
                            {
                                size_t omega_len = size_t(1) << omega_log_len, omega_count = omega_len / 2;
                                auto it = &omega_table[omega_len / 2];
                                ModIntType root = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / omega_len);
                                ModIntType omega(1);
                                size_t i = 0;
                                for (; i < 8; i++)
                                {
                                    it[i] = omega;
                                    omega *= root;
                                }
                                ModIntX8 omegaX8, rootX8 = qpow(root, 8);
                                omegaX8.loadu(&it[0]);
                                for (; i < omega_count; i += 8)
                                {
                                    omegaX8 *= rootX8;
                                    omegaX8.storeu(&it[i]);
                                }
                            }
                        }
                        constexpr ModIntType &operator[](size_t i)
                        {
                            return omega_table[i];
                        }
                        constexpr const ModIntType &operator[](size_t i) const
                        {
                            return omega_table[i];
                        }
                        constexpr const ModIntType *getOmegaIt(size_t len) const
                        {
                            return &omega_table[len / 2];
                        }
                    };

                    // Shared twiddle table for this instantiation (declared here; the definition is outside this view).
                    static const TableType table;

                    // W_8_k = (ROOT^((mod - 1) / 8))^k for k = 1, 2, 3, used by the 8-point kernels.
                    static constexpr ModIntType W_8_1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
                    static constexpr ModIntType W_8_2 = qpow(W_8_1, 2);
                    static constexpr ModIntType W_8_3 = qpow(W_8_1, 3);
                    // Start of the twiddle runs inside `table` for transform lengths 16, 32 and 64.
                    static constexpr const ModIntType *W16_IT = table.getOmegaIt(16);
                    static constexpr const ModIntType *W32_IT = table.getOmegaIt(32);
                    static constexpr const ModIntType *W64_IT = table.getOmegaIt(64);

                    // Applies two fused radix-2 DIT layers to one block of `rank` elements.
                    // The block is viewed as four quarters of rank / 4 elements; each iteration handles
                    // two 8-lane columns (16 elements per quarter), so rank / 4 must be a multiple of 16.
                    // OUT2P is forwarded to dit_butterfly2_i2424_2layer.
                    // Input ranges per quarter, as noted by the original author: 2p, 4p, 2p, 4p.
                    template <bool OUT2P = false>
                    static void ditLayer(ModIntType in_out[], size_t rank)
                    {
                        const size_t quarter = rank / 4;
                        const ModIntType *const w_outer = table.getOmegaIt(rank);
                        const ModIntType *const w_inner = table.getOmegaIt(rank / 2);
                        ModIntType *const q0 = in_out;
                        ModIntType *const q1 = q0 + quarter;
                        ModIntType *const q2 = q1 + quarter;
                        ModIntType *const q3 = q2 + quarter;
                        for (size_t lo = 0; lo < quarter; lo += 16)
                        {
                            const size_t hi = lo + 8;
                            ModIntX8 a0, a1, a2, a3, b0, b1, b2, b3, w0, w1, w_last;

                            // Load both columns up front.
                            a0.load(q0 + lo);
                            a1.load(q1 + lo);
                            a2.load(q2 + lo);
                            a3.load(q3 + lo);
                            b0.load(q0 + hi);
                            b1.load(q1 + hi);
                            b2.load(q2 + hi);
                            b3.load(q3 + hi);

                            // First column.
                            w0.load(w_outer + lo);
                            w1.load(w_outer + quarter + lo);
                            w_last.load(w_inner + lo);
                            dit_butterfly2_i2424_2layer<OUT2P>(a0, a1, a2, a3, w0, w1, w_last);

                            // Second column.
                            w0.load(w_outer + hi);
                            w1.load(w_outer + quarter + hi);
                            w_last.load(w_inner + hi);
                            dit_butterfly2_i2424_2layer<OUT2P>(b0, b1, b2, b3, w0, w1, w_last);

                            // Write both columns back.
                            a0.store(q0 + lo);
                            a1.store(q1 + lo);
                            a2.store(q2 + lo);
                            a3.store(q3 + lo);
                            b0.store(q0 + hi);
                            b1.store(q1 + hi);
                            b2.store(q2 + hi);
                            b3.store(q3 + hi);
                        }
                    }

                    // Applies two fused radix-2 DIF layers to one block of `rank` elements.
                    // The block is viewed as four quarters of rank / 4 elements; each iteration handles
                    // two 8-lane columns (16 elements per quarter), so rank / 4 must be a multiple of 16.
                    static void difLayer(ModIntType in_out[], size_t rank)
                    {
                        const size_t quarter = rank / 4;
                        const ModIntType *const w_outer = table.getOmegaIt(rank);
                        const ModIntType *const w_inner = table.getOmegaIt(rank / 2);
                        ModIntType *const q0 = in_out;
                        ModIntType *const q1 = q0 + quarter;
                        ModIntType *const q2 = q1 + quarter;
                        ModIntType *const q3 = q2 + quarter;
                        for (size_t lo = 0; lo < quarter; lo += 16)
                        {
                            const size_t hi = lo + 8;
                            ModIntX8 a0, a1, a2, a3, b0, b1, b2, b3, w0, w1, w_last;

                            // Load both columns up front.
                            a0.load(q0 + lo);
                            a1.load(q1 + lo);
                            a2.load(q2 + lo);
                            a3.load(q3 + lo);
                            b0.load(q0 + hi);
                            b1.load(q1 + hi);
                            b2.load(q2 + hi);
                            b3.load(q3 + hi);

                            // First column.
                            w0.load(w_outer + lo);
                            w1.load(w_outer + quarter + lo);
                            w_last.load(w_inner + lo);
                            dif_butterfly2_2layer(a0, a1, a2, a3, w0, w1, w_last);

                            // Second column.
                            w0.load(w_outer + hi);
                            w1.load(w_outer + quarter + hi);
                            w_last.load(w_inner + hi);
                            dif_butterfly2_2layer(b0, b1, b2, b3, w0, w1, w_last);

                            // Write both columns back.
                            a0.store(q0 + lo);
                            a1.store(q1 + lo);
                            a2.store(q2 + lo);
                            a3.store(q3 + lo);
                            b0.store(q0 + hi);
                            b1.store(q1 + hi);
                            b2.store(q2 + hi);
                            b3.store(q3 + hi);
                        }
                    }
                    // Same two-layer DIF pass as difLayer, applied to two independent blocks at once so
                    // that each set of twiddle factors is loaded a single time and used for both.
                    // One 8-lane column of each block is handled per iteration.
                    static void difLayerX2(ModIntType in_out1[], ModIntType in_out2[], size_t rank)
                    {
                        const size_t quarter = rank / 4;
                        const ModIntType *const w_outer = table.getOmegaIt(rank);
                        const ModIntType *const w_inner = table.getOmegaIt(rank / 2);
                        ModIntType *const p0 = in_out1;
                        ModIntType *const p1 = p0 + quarter;
                        ModIntType *const p2 = p1 + quarter;
                        ModIntType *const p3 = p2 + quarter;
                        ModIntType *const q0 = in_out2;
                        ModIntType *const q1 = q0 + quarter;
                        ModIntType *const q2 = q1 + quarter;
                        ModIntType *const q3 = q2 + quarter;
                        for (size_t pos = 0; pos < quarter; pos += 8)
                        {
                            ModIntX8 a0, a1, a2, a3, b0, b1, b2, b3, w0, w1, w_last;

                            a0.load(p0 + pos);
                            a1.load(p1 + pos);
                            a2.load(p2 + pos);
                            a3.load(p3 + pos);
                            b0.load(q0 + pos);
                            b1.load(q1 + pos);
                            b2.load(q2 + pos);
                            b3.load(q3 + pos);
                            w0.load(w_outer + pos);
                            w1.load(w_outer + quarter + pos);
                            w_last.load(w_inner + pos);

                            dif_butterfly2_2layer(a0, a1, a2, a3, w0, w1, w_last);
                            dif_butterfly2_2layer(b0, b1, b2, b3, w0, w1, w_last);

                            a0.store(p0 + pos);
                            a1.store(p1 + pos);
                            a2.store(p2 + pos);
                            a3.store(p3 + pos);
                            b0.store(q0 + pos);
                            b1.store(q1 + pos);
                            b2.store(q2 + pos);
                            b3.store(q3 + pos);
                        }
                    }
                    // Runs the remaining DIT passes on `in_out`, starting at block size `rank` and
                    // growing by a factor of four per pass until the block size reaches `len`
                    // (clamped to NTT_LEN). Intermediate passes alternate ditLayer<true> / ditLayer<false>
                    // over neighbouring blocks; the final full-length pass uses OUT2P.
                    // Precondition (asserted): `rank` reaches `len` exactly, i.e. len / rank is a power of four.
                    template <bool OUT2P>
                    static void dit(ModIntType in_out[], size_t len, size_t rank)
                    {
                        len = std::min(NTT_LEN, len);
                        for (; rank < len; rank *= 4)
                        {
                            // ditLayer looks up its own twiddle pointers, so nothing is precomputed here.
                            for (size_t j = 0; j < len; j += rank * 2)
                            {
                                ditLayer<true>(in_out + j, rank);
                                ditLayer<false>(in_out + j + rank, rank);
                            }
                        }
                        assert(rank == len);
                        ditLayer<OUT2P>(in_out, len);
                    }
                    // Runs DIF passes on `in_out`, starting with one block of `len` elements (clamped to
                    // NTT_LEN) and shrinking the block size by a factor of four per pass while it is at
                    // least `rank_end`. On return `rank` holds the first block size that was NOT
                    // processed, so the caller can continue with a smaller kernel.
                    static void dif(ModIntType in_out[], size_t len, size_t &rank, size_t rank_end = 128)
                    {
                        len = std::min(NTT_LEN, len);
                        rank = len;
                        while (rank >= rank_end)
                        {
                            for (size_t offset = 0; offset < len; offset += rank)
                            {
                                difLayer(in_out + offset, rank);
                            }
                            rank /= 4;
                        }
                    }
                    // Two-array version of dif(): runs the same DIF passes on `in_out1` and `in_out2`
                    // together through difLayerX2. Block size starts at `len` (clamped to NTT_LEN) and is
                    // divided by four per pass while it is at least `rank_end`. On return `rank` holds
                    // the first block size that was NOT processed.
                    static void dif(ModIntType in_out1[], ModIntType in_out2[], size_t len, size_t &rank, size_t rank_end = 128)
                    {
                        len = std::min(NTT_LEN, len);
                        for (rank = len; rank >= rank_end; rank /= 4)
                        {
                            // difLayerX2 looks up its own twiddle pointers, so nothing is precomputed here.
                            for (size_t j = 0; j < len; j += rank)
                            {
                                difLayerX2(in_out1 + j, in_out2 + j, rank);
                            }
                        }
                    }
                    // Recursive convolution driver. Blocks that fit the L1 budget
                    // (L1_BYTE / (2 * sizeof(ModIntType)) elements) go to convolutionL1. Larger blocks get
                    // one two-layer DIF pass on both arrays, four recursive calls on the quarters
                    // (alternating OUT2P = true / false), and one two-layer inverse DIT pass on the result.
                    // Both `in_out` and `in` are modified; the result ends up in `in_out`.
                    template <bool OUT2P = false>
                    static void convolution(ModIntType in_out[], ModIntType in[], size_t len, ModIntType ntt_len_inv_r)
                    {
                        constexpr size_t L1_THRESHOLD = L1_BYTE / (2 * sizeof(ModIntType));
                        if (len > L1_THRESHOLD)
                        {
                            difLayerX2(in_out, in, len);

                            const size_t quarter = len / 4;
                            ModIntType *lhs = in_out;
                            ModIntType *rhs = in;
                            convolution<true>(lhs, rhs, quarter, ntt_len_inv_r);
                            lhs += quarter;
                            rhs += quarter;
                            convolution<false>(lhs, rhs, quarter, ntt_len_inv_r);
                            lhs += quarter;
                            rhs += quarter;
                            convolution<true>(lhs, rhs, quarter, ntt_len_inv_r);
                            lhs += quarter;
                            rhs += quarter;
                            convolution<false>(lhs, rhs, quarter, ntt_len_inv_r);

                            INTT::template ditLayer<OUT2P>(in_out, len);
                            return;
                        }
                        convolutionL1<OUT2P>(in_out, in, len, ntt_len_inv_r);
                    }
                    // Convolution for blocks small enough to be processed without further recursion.
                    //  - len <= 64: handled by convolutionTiny; when OUT2P is set the result is
                    //    additionally passed through largeNorm().
                    //  - len >= 128: DIF passes on both arrays down to block size 32 or 64, the matching
                    //    fixed-size kernel on every block (alternating OUT2P = true / false), then the
                    //    inverse DIT passes back up to `len`.
                    template <bool OUT2P>
                    static void convolutionL1(ModIntType in_out[], ModIntType in[], size_t len, ModIntType ntt_len_inv_r)
                    {
                        if (len <= 64)
                        {
                            convolutionTiny(in_out, in, len, ntt_len_inv_r);
                            if (!OUT2P)
                            {
                                return;
                            }
                            // Vector part first, then any leftover elements one by one.
                            const size_t vec_end = len - len % 8;
                            size_t pos = 0;
                            while (pos < vec_end)
                            {
                                ModIntX8 lanes;
                                lanes.load(&in_out[pos]);
                                lanes.largeNorm().store(&in_out[pos]);
                                pos += 8;
                            }
                            while (pos < len)
                            {
                                in_out[pos] = in_out[pos].largeNorm();
                                ++pos;
                            }
                            return;
                        }

                        // len >= 128
                        size_t rank = 0;
                        dif(in_out, in, len, rank);
                        switch (rank)
                        {
                        case 32:
                            for (size_t base = 0; base < len; base += 64)
                            {
                                convolution32<true>(in_out + base, in + base, ntt_len_inv_r);
                                convolution32<false>(in_out + 32 + base, in + 32 + base, ntt_len_inv_r);
                            }
                            break;
                        case 64:
                            for (size_t base = 0; base < len; base += 128)
                            {
                                convolution64<true>(in_out + base, in + base, ntt_len_inv_r);
                                convolution64<false>(in_out + 64 + base, in + 64 + base, ntt_len_inv_r);
                            }
                            break;
                        default:
                            // dif() with its default rank_end is expected to stop at 32 or 64.
                            assert(0);
                        }
                        INTT::template dit<OUT2P>(in_out, len, rank * 4);
                    }

                    // Dispatches to the fixed-size convolution kernel for len in {1, 2, 4, 8, 16, 32, 64}.
                    // Any other length trips the assertion.
                    static void convolutionTiny(ModIntType in_out[], ModIntType in[], size_t len, ModIntType ntt_len_inv_r)
                    {
                        if (len == 1)
                        {
                            in_out[0] *= in[0] * ntt_len_inv_r;
                        }
                        else if (len == 2)
                        {
                            convolution2(in_out, in, ntt_len_inv_r);
                        }
                        else if (len == 4)
                        {
                            convolution4(in_out, in, ntt_len_inv_r);
                        }
                        else if (len == 8)
                        {
                            convolution8(in_out, in, ntt_len_inv_r);
                        }
                        else if (len == 16)
                        {
                            convolution16(in_out, in, ntt_len_inv_r);
                        }
                        else if (len == 32)
                        {
                            convolution32<false>(in_out, in, ntt_len_inv_r);
                        }
                        else if (len == 64)
                        {
                            convolution64<false>(in_out, in, ntt_len_inv_r);
                        }
                        else
                        {
                            assert(0);
                        }
                    }

                    // Length-2 convolution: transform both pairs with transform2, multiply pointwise
                    // (scaled by ntt_len_inv_r), then recombine as sum and difference.
                    // `in` is only read; the result is written to `in_out`.
                    static void convolution2(ModIntType in_out[], ModIntType in[], ModIntType ntt_len_inv_r)
                    {
                        ModIntType a0 = in_out[0], a1 = in_out[1];
                        ModIntType b0 = in[0], b1 = in[1];
                        transform2(a0, a1);
                        transform2(b0, b1);

                        ModIntType prod0 = a0 * b0 * ntt_len_inv_r;
                        ModIntType prod1 = a1 * b1 * ntt_len_inv_r;

                        in_out[0] = prod0 + prod1;
                        in_out[1] = prod0 - prod1;
                    }

                    // Length-4 convolution: forward dif4 on both arrays, pointwise product scaled by
                    // ntt_len_inv_r, then the inverse transform's dit4.
                    // Note that `in` is transformed in place.
                    static void convolution4(ModIntType in_out[], ModIntType in[], ModIntType ntt_len_inv_r)
                    {
                        dif4(in_out);
                        dif4(in);
                        for (size_t k = 0; k < 4; ++k)
                        {
                            in_out[k] = in_out[k] * in[k] * ntt_len_inv_r;
                        }
                        INTT::dit4(in_out);
                    }

                    // Length-8 convolution kept entirely in registers: both inputs go through one dif8X2
                    // call, are multiplied pointwise (scaled by ntt_len_inv_r), and the product is
                    // transformed back with the inverse dit8X2. `in` is only read.
                    static void convolution8(ModIntType in_out[], ModIntType in[], ModIntType ntt_len_inv_r)
                    {
                        // `spare` only fills the second slot of dit8X2; its contents are never stored.
                        ModIntX8 lhs, rhs, spare;
                        lhs.load(in_out);
                        rhs.load(in);
                        dif8X2(lhs, rhs);
                        lhs *= rhs * ntt_len_inv_r;
                        INTT::dit8X2(lhs, spare);
                        lhs.store(in_out);
                    }

                    // Length-16 convolution: forward dif16 on both arrays, pointwise product of the two
                    // 8-lane halves scaled by ntt_len_inv_r, then the inverse transform's dit16.
                    // Note that `in` is transformed in place by dif16.
                    static void convolution16(ModIntType in_out[], ModIntType in[], ModIntType ntt_len_inv_r)
                    {
                        ModIntX8 temp0, temp1, temp2, temp3;
                        dif16(in_out);
                        dif16(in);
                        temp0.load(in_out), temp1.load(in_out + 8);
                        temp2.load(in), temp3.load(in + 8);
                        temp0 *= temp2 * ntt_len_inv_r;
                        temp1 *= temp3 * ntt_len_inv_r;
                        temp0.store(in_out), temp1.store(in_out + 8);
                        INTT::dit16(in_out);
                    }

                    // Length-2 DIT butterfly: both inputs go through largeNorm(), then
                    // in_out0 <- a + b and in_out1 <- a - b using the add()/sub() members.
                    static void dit2(ModIntType &in_out0, ModIntType &in_out1)
                    {
                        auto lhs = in_out0.largeNorm();
                        auto rhs = in_out1.largeNorm();
                        auto sum = lhs.add(rhs);
                        auto diff = lhs.sub(rhs);
                        in_out0 = sum;
                        in_out1 = diff;
                    }
                    // Length-2 DIF butterfly; simply forwards to transform2.
                    static void dif2(ModIntType &in_out0, ModIntType &in_out1)
                    {
                        transform2(in_out0, in_out1);
                    }
                    // In-place length-4 DIT transform on in_out[0..3].
                    // All four inputs are first passed through largeNorm(). The pair (0, 1) is combined
                    // with transform2; the pair (2, 3) becomes largeNorm(sum) and mul_w41(difference).
                    // The final layer combines the two pairs with add()/sub().
                    static void dit4(ModIntType in_out[])
                    {
                        auto a = in_out[0].largeNorm();
                        auto b = in_out[1].largeNorm();
                        auto c = in_out[2].largeNorm();
                        auto d = in_out[3].largeNorm();

                        // First layer.
                        transform2(a, b);
                        auto cd_sum = c.add(d);
                        auto cd_diff = c.sub(d);
                        c = cd_sum.largeNorm();
                        d = mul_w41<ROOT, ModIntType>(cd_diff);

                        // Second layer.
                        in_out[0] = a.add(c);
                        in_out[1] = b.add(d);
                        in_out[2] = a.sub(c);
                        in_out[3] = b.sub(d);
                    }
                    // In-place length-4 DIF transform on in_out[0..3].
                    // First layer: the pair (0, 2) is combined with transform2; the pair (1, 3) becomes
                    // largeNorm(sum) and mul_w41(difference). Second layer: pairs (0, 1) and (2, 3) are
                    // combined with operator+ / operator-.
                    static void dif4(ModIntType in_out[])
                    {
                        auto a = in_out[0];
                        auto b = in_out[1];
                        auto c = in_out[2];
                        auto d = in_out[3];

                        // First layer.
                        transform2(a, c);
                        auto bd_sum = b.add(d);
                        auto bd_diff = b.sub(d);
                        b = bd_sum.largeNorm();
                        d = mul_w41<ROOT, ModIntType>(bd_diff);

                        // Second layer.
                        in_out[0] = a + b;
                        in_out[1] = a - b;
                        in_out[2] = c + d;
                        in_out[3] = c - d;
                    }

                    // In-place length-8 DIT transform via the two-vector kernel dit8X2.
                    // Only the first vector carries data; the second one is a placeholder whose
                    // result is discarded.
                    static void dit8(ModIntType in_out[])
                    {
                        ModIntX8 data, spare;
                        data.load(in_out);
                        dit8X2(data, spare);
                        data.store(in_out);
                    }
                    // In-place length-8 DIF transform via the two-vector kernel dif8X2.
                    // Only the first vector carries data; the second one is a placeholder whose
                    // result is discarded.
                    static void dif8(ModIntType in_out[])
                    {
                        ModIntX8 data, spare;
                        data.load(in_out);
                        dif8X2(data, spare);
                        data.store(in_out);
                    }
                    // In-place length-16 DIT transform: two length-8 DIT transforms (dit8X2) on the
                    // halves, followed by one butterfly layer with the length-16 twiddle factors.
                    static void dit16(ModIntType in_out[])
                    {
                        ModIntX8 lower, upper, twiddle;
                        lower.load(&in_out[0]);
                        upper.load(&in_out[8]);
                        twiddle.load(W16_IT);

                        dit8X2(lower, upper);

                        // Final layer: lower +/- upper * w.
                        lower = lower.largeNorm();
                        upper = upper * twiddle;
                        lower.add(upper).store(&in_out[0]);
                        lower.sub(upper).store(&in_out[8]);
                    }
                    // In-place length-16 DIF transform: one butterfly layer with the length-16 twiddle
                    // factors, followed by two length-8 DIF transforms (dif8X2) on the halves.
                    static void dif16(ModIntType in_out[])
                    {
                        ModIntX8 lower, upper, twiddle, total, delta;
                        lower.load(&in_out[0]);
                        upper.load(&in_out[8]);
                        twiddle.load(W16_IT);

                        // First layer: (lower + upper) and (lower - upper) * w.
                        total = lower.add(upper);
                        delta = lower.sub(upper);
                        lower = total.largeNorm();
                        upper = delta * twiddle;

                        dif8X2(lower, upper);

                        lower.store(&in_out[0]);
                        upper.store(&in_out[8]);
                    }

                    // Four length-2 butterflies on the adjacent pairs (a, b) of an 8-lane vector.
                    // Lane notes follow the original author's annotations: the shifted-up copy minus the
                    // input carries a - b, the shifted-down copy plus the input carries a + b, and
                    // cross32 merges the two into one vector. The result is not passed through
                    // largeNorm() (the original keeps that step commented out for the input as well).
                    static ModIntX8 dit2X4(ModIntX8 in)
                    {
                        ModIntX8 diff_part = in.lshift32In64(); // 0, a
                        ModIntX8 sum_part = in.rshift32In64();  // b, 0
                        diff_part = diff_part.sub(in);          // X, a - b + mod2
                        sum_part = sum_part.add(in);            // a + b, X
                        return ModIntX8::cross32(sum_part, diff_part);
                    }

                    // Four length-2 butterflies on the adjacent pairs (a, b) of an 8-lane vector, with
                    // the merged result passed through largeNorm().
                    // Lane notes follow the original author's annotations: the shifted-up copy minus the
                    // input carries a - b, the shifted-down copy plus the input carries a + b, and
                    // cross32 merges the two into one vector.
                    static ModIntX8 dif2X4(ModIntX8 in)
                    {
                        ModIntX8 diff_part = in.lshift32In64(); // 0, a
                        ModIntX8 sum_part = in.rshift32In64();  // b, 0
                        diff_part = diff_part.sub(in);          // X, a - b + mod2
                        sum_part = sum_part.add(in);            // a + b, X
                        return ModIntX8::cross32(sum_part, diff_part).largeNorm();
                    }

                    // Length-4 DIT transforms on every group of four lanes of A and B (four groups in total).
                    // Steps: length-2 butterflies via dit2X4, regroup the 64-bit halves of A and B so that
                    // butterfly partners sit in the same lanes, apply the twiddles {1, W_8_2}, add/sub,
                    // and regroup back. Trailing lane annotations are the original author's.
                    static void dit4X4(ModIntX8 &A, ModIntX8 &B)
                    {
                        // Twiddle pattern for the second layer, repeated for each pair of lanes.
                        alignas(32) constexpr ModIntType w_arr[8]{ModIntType(1), W_8_2, ModIntType(1), W_8_2, ModIntType(1), W_8_2, ModIntType(1), W_8_2};

                        ModIntX8 temp0, temp1, temp2, temp3, omega;
                        temp0 = dit2X4(A); // A0,A1,A2,A3,A4,A5,A6,A7
                        temp1 = dit2X4(B); // B0,B1,B2,B3,B4,B5,B6,B7

                        omega.load(w_arr);
                        temp2 = temp0.rshift64In128(); // A2,A3,X,X,A6,A7,X,X
                        temp3 = temp1.lshift64In128(); // X,X,B0,B1,X,X,B4,B5

                        // Gather the "upper" pairs in temp0 and the "lower" pairs in temp1.
                        temp0 = ModIntX8::cross64(temp0, temp3); // A0,A1,B0,B1,A4,A5,B4,B5
                        temp1 = ModIntX8::cross64(temp2, temp1); // A2,A3,B2,B3,A6,A7,B6,B7

                        temp0 = temp0.largeNorm();
                        temp1 = temp1 * omega; // (A2,A3,B2,B3,A6,A7,B6,B7)*w

                        temp2 = temp0.add(temp1); // A0,A1,B0,B1,A4,A5,B4,B5
                        temp3 = temp0.sub(temp1); // A2,A3,B2,B3,A6,A7,B6,B7

                        temp0 = temp2.rshift64In128(); // B0,B1,X,X,B4,B5,X,X
                        temp1 = temp3.lshift64In128(); // X,X,A2,A3,X,X,A6,A7

                        // Scatter the results back into the original A / B lane layout.
                        A = ModIntX8::cross64(temp2, temp1); // A0,A1,A2,A3,A4,A5,A6,A7
                        B = ModIntX8::cross64(temp0, temp3); // B0,B1,B2,B3,B4,B5,B6,B7
                    }
                    // Length-4 DIF transforms on every group of four lanes of A and B (four groups in total).
                    // Steps: regroup the 64-bit halves of A and B so that butterfly partners sit in the
                    // same lanes, add/sub, apply the twiddles {1, W_8_2} to the differences, regroup back,
                    // and finish with length-2 butterflies via dif2X4. Trailing lane annotations are the
                    // original author's.
                    static void dif4X4(ModIntX8 &A, ModIntX8 &B)
                    {

                        // Twiddle pattern for the first layer, repeated for each pair of lanes.
                        alignas(32) constexpr ModIntType w_arr[8]{ModIntType(1), W_8_2, ModIntType(1), W_8_2, ModIntType(1), W_8_2, ModIntType(1), W_8_2};

                        ModIntX8 temp0, temp1, temp2, temp3, omega;
                        temp2 = A.rshift64In128(); // A2,A3,X,X,A6,A7,X,X
                        temp3 = B.lshift64In128(); // X,X,B0,B1,X,X,B4,B5

                        omega.load(w_arr);
                        // Gather the "upper" pairs in temp0 and the "lower" pairs in temp1.
                        temp0 = ModIntX8::cross64(A, temp3); // A0,A1,B0,B1,A4,A5,B4,B5
                        temp1 = ModIntX8::cross64(temp2, B); // A2,A3,B2,B3,A6,A7,B6,B7

                        temp2 = temp0.add(temp1); // A0,A1,B0,B1,A4,A5,B4,B5
                        temp3 = temp0.sub(temp1);
                        temp2 = temp2.largeNorm();
                        temp3 = temp3 * omega; // (A2,A3,B2,B3,A6,A7,B6,B7)*w

                        temp0 = temp2.rshift64In128(); // B0,B1,X,X,B4,B5,X,X
                        temp1 = temp3.lshift64In128(); // X,X,A2,A3,X,X,A6,A7

                        // Scatter the results back into the original A / B lane layout.
                        temp2 = ModIntX8::cross64(temp2, temp1); // A0,A1,A2,A3,A4,A5,A6,A7
                        temp3 = ModIntX8::cross64(temp0, temp3); // B0,B1,B2,B3,B4,B5,B6,B7

                        A = dif2X4(temp2); // A
                        B = dif2X4(temp3); // B
                    }

                    // Length-8 DIT transform on each of the two vectors A and B.
                    // After the length-4 stage (dit4X4), the lower 128-bit halves of A and B are packed
                    // into one vector and the upper halves into another, so the last butterfly layer
                    // (twiddles 1, W_8_1, W_8_2, W_8_3) runs for both transforms at once.
                    static void dit8X2(ModIntX8 &A, ModIntX8 &B)
                    {
                        alignas(32) constexpr ModIntType twiddles[8]{ModIntType(1), W_8_1, W_8_2, W_8_3, ModIntType(1), W_8_1, W_8_2, W_8_3};

                        dit4X4(A, B); // A0..A7; B0..B7

                        ModIntX8 low_halves, high_halves, sums, diffs, omega;
                        omega.load(twiddles);
                        low_halves = ModIntX8::packLo128(A, B);  // A0,A1,A2,A3,B0,B1,B2,B3
                        high_halves = ModIntX8::packHi128(A, B); // A4,A5,A6,A7,B4,B5,B6,B7

                        low_halves = low_halves.largeNorm();
                        high_halves = high_halves * omega; // (A4..A7,B4..B7) * w

                        sums = low_halves.add(high_halves);  // A0..A3,B0..B3
                        diffs = low_halves.sub(high_halves); // A4..A7,B4..B7

                        A = ModIntX8::packLo128(sums, diffs); // A0..A7
                        B = ModIntX8::packHi128(sums, diffs); // B0..B7
                    }
                    // Length-8 DIF transform on each of the two vectors A and B.
                    // The lower 128-bit halves of A and B are packed into one vector and the upper halves
                    // into another, so the first butterfly layer (twiddles 1, W_8_1, W_8_2, W_8_3) runs
                    // for both transforms at once; the length-4 stage (dif4X4) finishes the job.
                    static void dif8X2(ModIntX8 &A, ModIntX8 &B)
                    {
                        alignas(32) constexpr ModIntType twiddles[8]{ModIntType(1), W_8_1, W_8_2, W_8_3, ModIntType(1), W_8_1, W_8_2, W_8_3};

                        ModIntX8 low_halves, high_halves, sums, diffs, omega;
                        low_halves = ModIntX8::packLo128(A, B);  // A0,A1,A2,A3,B0,B1,B2,B3
                        high_halves = ModIntX8::packHi128(A, B); // A4,A5,A6,A7,B4,B5,B6,B7
                        omega.load(twiddles);

                        sums = low_halves.add(high_halves);  // A0..A3,B0..B3
                        diffs = low_halves.sub(high_halves); // A4..A7,B4..B7
                        sums = sums.largeNorm();
                        diffs = diffs * omega; // (A4..A7,B4..B7) * w

                        A = ModIntX8::packLo128(sums, diffs); // A0..A7
                        B = ModIntX8::packHi128(sums, diffs); // B0..B7
                        dif4X4(A, B);
                    }
                    // Array overload: two independent length-8 DIT transforms on in_out[0..7] and
                    // in_out[8..15], performed in place.
                    static void dit8X2(ModIntType in_out[])
                    {
                        ModIntX8 first, second;
                        first.load(in_out);
                        second.load(in_out + 8);
                        dit8X2(first, second);
                        first.store(in_out);
                        second.store(in_out + 8);
                    }
                    // Array overload: two independent length-8 DIF transforms on in_out[0..7] and
                    // in_out[8..15], performed in place.
                    static void dif8X2(ModIntType in_out[])
                    {
                        ModIntX8 first, second;
                        first.load(in_out);
                        second.load(in_out + 8);
                        dif8X2(first, second);
                        first.store(in_out);
                        second.store(in_out + 8);
                    }

                    // Length-32 convolution kernel kept entirely in eight vector registers.
                    // temp0..temp3 hold the 32 elements of `in_out`, temp4..temp7 the 32 elements of `in`.
                    // Both operands go through a forward DIF transform together, are multiplied pointwise
                    // (scaled by ntt_len_inv_r), and the product is brought back with the INTT twiddles.
                    // `in` is only read; the result is stored to `in_out`.
                    // OUT2P selects the tag passed to the last dit_butterfly2_i24 layer.
                    template <bool OUT2P>
                    static void convolution32(ModIntType in_out[], const ModIntType in[], ModIntType ntt_len_inv_r)
                    {
                        ModIntX8 temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7, omega;
                        temp0.load(&in_out[0]), temp1.load(&in_out[8]), temp2.load(&in_out[16]), temp3.load(&in_out[24]);
                        temp4.load(&in[0]), temp5.load(&in[8]), temp6.load(&in[16]), temp7.load(&in[24]);
                        // DIF32X2
                        omega.load(W32_IT);
                        dif_butterfly2(temp0, temp2, omega);
                        dif_butterfly2(temp4, temp6, omega);
                        omega.load(W32_IT + 8);
                        dif_butterfly2(temp1, temp3, omega);
                        dif_butterfly2(temp5, temp7, omega);
                        // DIF16X4
                        omega.load(W16_IT);
                        dif_butterfly2(temp0, temp1, omega);
                        dif_butterfly2(temp2, temp3, omega);
                        dif_butterfly2(temp4, temp5, omega);
                        dif_butterfly2(temp6, temp7, omega);
                        // DIF8X8
                        // After the transpose the remaining length-8 butterflies act between whole
                        // registers, so each layer below uses one broadcast twiddle per register pair.
                        transpose8x8(temp0.data, temp1.data, temp2.data, temp3.data, temp4.data, temp5.data, temp6.data, temp7.data);

                        omega = ModIntX8(W_8_2);
                        transform2(temp0, temp4);
                        dif_butterfly2(temp1, temp5, ModIntX8(W_8_1));
                        dif_butterfly2(temp2, temp6, omega);
                        dif_butterfly2(temp3, temp7, ModIntX8(W_8_3));

                        transform2(temp0, temp2);
                        dif_butterfly2(temp1, temp3, omega);
                        transform2(temp4, temp6);
                        dif_butterfly2(temp5, temp7, omega);

                        transform2(temp0, temp1);
                        transform2(temp2, temp3);
                        transform2(temp4, temp5);
                        transform2(temp6, temp7);

                        // Transpose back so temp0..temp3 / temp4..temp7 again hold the two operands.
                        transpose8x8(temp0.data, temp1.data, temp2.data, temp3.data, temp4.data, temp5.data, temp6.data, temp7.data);
                        // DOT MUL
                        omega = ModIntX8(ntt_len_inv_r);
                        temp0 *= (temp4 * omega), temp1 *= (temp5 * omega);
                        temp2 *= (temp6 * omega), temp3 *= (temp7 * omega);
                        // DIT8X4
                        INTT::dit8X2(temp0, temp1);
                        INTT::dit8X2(temp2, temp3);
                        // DIT16X2
                        omega.load(INTT::W16_IT);
                        temp0 = temp0.largeNorm();
                        dit_butterfly2_i24(temp0, temp1, omega, std::true_type{});
                        temp2 = temp2.largeNorm();
                        dit_butterfly2_i24(temp2, temp3, omega, std::false_type{});
                        // DIT32
                        using Out22Ty = std::integral_constant<bool, OUT2P>;
                        omega.load(INTT::W32_IT);
                        dit_butterfly2_i24(temp0, temp2, omega, Out22Ty{});
                        omega.load(INTT::W32_IT + 8);
                        dit_butterfly2_i24(temp1, temp3, omega, Out22Ty{});

                        temp0.store(&in_out[0]), temp1.store(&in_out[8]), temp2.store(&in_out[16]), temp3.store(&in_out[24]);
                    }
                    // Length-64 convolution kernel built on top of convolution32.
                    //
                    // 1. Applies the length-64 forward butterfly layer to both in_out and in
                    //    (elements 32 apart are paired); both arrays are overwritten.
                    // 2. Convolves the two 32-element halves with convolution32<true> / <false>.
                    // 3. Applies the length-64 inverse butterfly layer to in_out.
                    // The result is left in in_out[0..63]; OUT2P is forwarded as a tag type to the
                    // final dit_butterfly2_i24 calls.
                    template <bool OUT2P>
                    static void convolution64(ModIntType in_out[], ModIntType in[], ModIntType ntt_len_inv_r)
                    {
                        constexpr size_t HALF = 32, LANES = 8;
                        ModIntX8 a_lo, a_hi, b_lo, b_hi, twiddle;

                        // Forward outer layer, one 8-lane column at a time, on both operands.
                        for (size_t off = 0; off < HALF; off += LANES)
                        {
                            a_lo.load(&in_out[off]);
                            a_hi.load(&in_out[off + HALF]);
                            b_lo.load(&in[off]);
                            b_hi.load(&in[off + HALF]);

                            twiddle.load(W64_IT + off);
                            dif_butterfly2(a_lo, a_hi, twiddle);
                            dif_butterfly2(b_lo, b_hi, twiddle);

                            a_lo.store(&in_out[off]);
                            a_hi.store(&in_out[off + HALF]);
                            b_lo.store(&in[off]);
                            b_hi.store(&in[off + HALF]);
                        }

                        // Each half is now an independent length-32 problem.
                        convolution32<true>(in_out, in, ntt_len_inv_r);
                        convolution32<false>(in_out + HALF, in + HALF, ntt_len_inv_r);

                        // Inverse outer layer on the product only.
                        using Out22Ty = std::integral_constant<bool, OUT2P>;
                        for (size_t off = 0; off < HALF; off += LANES)
                        {
                            a_lo.load(&in_out[off]);
                            a_hi.load(&in_out[off + HALF]);

                            twiddle.load(INTT::W64_IT + off);
                            dit_butterfly2_i24(a_lo, a_hi, twiddle, Out22Ty{});

                            a_lo.store(&in_out[off]);
                            a_hi.store(&in_out[off + HALF]);
                        }
                    }
                };
                // Namespace-scope definitions of NTTShort's static data members (the lookup table
                // and the constants / table pointers declared inside the class). Before C++17 a
                // static data member that is odr-used needs such a definition even when it is
                // declared constexpr with an in-class initializer.
                template <size_t LEN, uint32_t ROOT, typename ModIntType>
                const typename NTTShort<LEN, ROOT, ModIntType>::TableType NTTShort<LEN, ROOT, ModIntType>::table;
                template <size_t LEN, uint32_t ROOT, typename ModIntType>
                constexpr size_t NTTShort<LEN, ROOT, ModIntType>::NTT_LEN;
                template <size_t LEN, uint32_t ROOT, typename ModIntType>
                constexpr int NTTShort<LEN, ROOT, ModIntType>::LOG_LEN;
                template <size_t LEN, uint32_t ROOT, typename ModIntType>
                constexpr ModIntType NTTShort<LEN, ROOT, ModIntType>::W_8_1;
                template <size_t LEN, uint32_t ROOT, typename ModIntType>
                constexpr ModIntType NTTShort<LEN, ROOT, ModIntType>::W_8_2;
                template <size_t LEN, uint32_t ROOT, typename ModIntType>
                constexpr ModIntType NTTShort<LEN, ROOT, ModIntType>::W_8_3;
                template <size_t LEN, uint32_t ROOT, typename ModIntType>
                constexpr const ModIntType *NTTShort<LEN, ROOT, ModIntType>::W16_IT;
                template <size_t LEN, uint32_t ROOT, typename ModIntType>
                constexpr const ModIntType *NTTShort<LEN, ROOT, ModIntType>::W32_IT;
                template <size_t LEN, uint32_t ROOT, typename ModIntType>
                constexpr const ModIntType *NTTShort<LEN, ROOT, ModIntType>::W64_IT;

                // Polynomial multiplication modulo MOD via a number-theoretic transform, working
                // on 8-lane vectors (ModIntX8) of lazily reduced Montgomery integers.
                //
                // ROOT is the root the twiddle factors are derived from; INTT is the same
                // template instantiated with the modular inverse of ROOT and provides the
                // inverse-direction twiddles. Transforms of at most LONG_THRESHOLD elements are
                // delegated to the NTTShort kernels; longer ones are split radix-4 style into
                // four quarter-length sub-problems.
                template <uint32_t MOD, uint32_t ROOT>
                struct NTT
                {
                    static constexpr uint32_t mod()
                    {
                        return MOD;
                    }
                    static constexpr uint32_t root()
                    {
                        return ROOT;
                    }
                    // Modular inverse of ROOT, computed at compile time.
                    static constexpr uint32_t rootInv()
                    {
                        constexpr uint32_t IROOT = mod_inv<int64_t>(ROOT, MOD);
                        return IROOT;
                    }

                    static_assert(root() < mod(), "ROOT must be smaller than MOD");
                    static_assert(check_inv<uint64_t>(root(), rootInv(), mod()), "IROOT * ROOT % MOD must be 1");
                    static constexpr int MOD_BITS = hint_log2(mod()) + 1;
                    // Number of factors of two in MOD - 1, i.e. log2 of the longest power-of-two transform.
                    static constexpr int MAX_LOG_LEN = hint_ctz(mod() - 1);

                    // Longest supported transform length, clamped so the shift cannot overflow size_t.
                    static constexpr size_t getMaxLen()
                    {
                        if (MAX_LOG_LEN < sizeof(size_t) * CHAR_BIT)
                        {
                            return size_t(1) << MAX_LOG_LEN;
                        }
                        return size_t(1) << (sizeof(size_t) * CHAR_BIT - 1);
                    }
                    static constexpr size_t NTT_MAX_LEN = getMaxLen();

                    using INTT = NTT<mod(), rootInv()>;
                    using ModIntType = MontInt32Lazy<MOD>;
                    using ModIntX8 = MontInt32X8<ModIntType>;

                    // Largest length handled directly by the NTTShort kernels: two operands of this
                    // many elements occupy L2_BYTE bytes together (capped by NTT_MAX_LEN).
                    static constexpr size_t LONG_THRESHOLD = std::min(L2_BYTE / (2 * sizeof(ModIntType)), NTT_MAX_LEN);
                    using NTTTemplate = NTTShort<LONG_THRESHOLD, root(), ModIntType>;

                    // Broadcasts u^8 to all lanes, where u = root_in^((MOD - 1) / ntt_len * factor).
                    // Multiplying an omegax8() vector by it advances every lane by eight powers.
                    static ModIntX8 unitx8(size_t ntt_len, int factor, uint32_t root_in = root())
                    {
                        return ModIntX8(qpow(ModIntType(root_in), (mod() - 1) / ntt_len * factor * 8));
                    }
                    // Returns the vector (u^0, u^1, ..., u^7) with u = root_in^((MOD - 1) / ntt_len * factor).
                    static ModIntX8 omegax8(size_t ntt_len, int factor, uint32_t root_in = root())
                    {
                        alignas(32) ModIntType w_arr[8]{};
                        ModIntType w(1), unit(qpow(ModIntType(root_in), (mod() - 1) / ntt_len * factor));
                        for (auto &&i : w_arr)
                        {
                            i = w;
                            w = w * unit;
                        }
                        return ModIntX8(w_arr);
                    }

                    // Last two (fused) butterfly layers of a length-ntt_len transform, combined with
                    // the final norm() and the copy into the caller's uint32_t array.
                    //
                    // Only the first `len` results are written, so `out` needs just `len` elements.
                    // Requires len > ntt_len / 2 (asserted) and ntt_len / 4 to be a multiple of 8.
                    // convolution() calls this as INTT::ditOutLayer, so root() here is the inverse
                    // of the root used by the forward layers.
                    static void ditOutLayer(uint32_t out[], size_t len, const ModIntType in[], size_t ntt_len)
                    {
                        assert(len > ntt_len / 2);
                        ModIntX8 omega0 = omegax8(ntt_len, 1, root());
                        ModIntX8 omega1 = mul_w41<root(), ModIntType>(omega0);
                        ModIntX8 omega_last = omega0 * omega0;
                        ModIntX8 unit1X8 = unitx8(ntt_len, 1, root()), unit_last = unit1X8 * unit1X8;
                        ModIntX8 temp0, temp1, temp2, temp3;
                        size_t len_4 = ntt_len / 4;
                        auto it0 = in, it1 = in + len_4, it2 = in + len_4 * 2, it3 = in + len_4 * 3;
                        auto it4 = out, it5 = out + len_4, it6 = out + len_4 * 2;
                        if (len > len_4 * 3)
                        {
                            // The output reaches into the fourth quarter: len1 results are wanted there.
                            // The fourth-quarter pointer is formed only here, where it is inside `out`.
                            const auto it7 = out + len_4 * 3;
                            const size_t len1 = len - len_4 * 3, rem_len = len1 - len1 % 8;
                            size_t i = 0;
                            // Whole vectors for which all four quarters are stored.
                            for (; i < rem_len; i += 8)
                            {
                                temp0.load(&it0[i]), temp1.load(&it1[i]), temp2.load(&it2[i]), temp3.load(&it3[i]);

                                dit_butterfly2_i2424_2layer<true>(temp0, temp1, temp2, temp3, omega0, omega1, omega_last);

                                temp0 = temp0.norm(), temp1 = temp1.norm(), temp2 = temp2.norm(), temp3 = temp3.norm();
                                temp0.storeu(&it4[i]), temp1.storeu(&it5[i]), temp2.storeu(&it6[i]), temp3.storeu(&it7[i]);
                                omega0 *= unit1X8, omega1 *= unit1X8, omega_last *= unit_last;
                            }
                            // Vector that straddles the end of `out`: the fourth quarter is stored partially.
                            if (len1 % 8 > 0)
                            {
                                temp0.load(&it0[i]), temp1.load(&it1[i]), temp2.load(&it2[i]), temp3.load(&it3[i]);

                                dit_butterfly2_i2424_2layer<true>(temp0, temp1, temp2, temp3, omega0, omega1, omega_last);

                                temp0 = temp0.norm(), temp1 = temp1.norm(), temp2 = temp2.norm(), temp3 = temp3.norm();
                                temp0.storeu(&it4[i]), temp1.storeu(&it5[i]), temp2.storeu(&it6[i]), temp3.storeN(&it7[i], len1 % 8);
                                omega0 *= unit1X8, omega1 *= unit1X8, omega_last *= unit_last;
                                i += 8;
                            }
                            // Remaining columns: the fourth-quarter result lies past `len` and is not stored.
                            for (; i < len_4; i += 8)
                            {
                                temp0.load(&it0[i]), temp1.load(&it1[i]), temp2.load(&it2[i]), temp3.load(&it3[i]);

                                dit_butterfly2_2layer_out3(temp0, temp1, temp2, temp3, omega0, omega1, omega_last);

                                temp0 = temp0.norm(), temp1 = temp1.norm(), temp2 = temp2.norm();
                                temp0.storeu(&it4[i]), temp1.storeu(&it5[i]), temp2.storeu(&it6[i]);
                                omega0 *= unit1X8, omega1 *= unit1X8, omega_last *= unit_last;
                            }
                        }
                        else
                        {
                            // The output ends inside the third quarter: len1 results are wanted there.
                            const size_t len1 = len - len_4 * 2, rem_len = len1 - len1 % 8;
                            size_t i = 0;
                            // Whole vectors for which three quarters are stored.
                            for (; i < rem_len; i += 8)
                            {
                                temp0.load(&it0[i]), temp1.load(&it1[i]), temp2.load(&it2[i]), temp3.load(&it3[i]);

                                dit_butterfly2_2layer_out3(temp0, temp1, temp2, temp3, omega0, omega1, omega_last);

                                temp0 = temp0.norm(), temp1 = temp1.norm(), temp2 = temp2.norm();
                                temp0.storeu(&it4[i]), temp1.storeu(&it5[i]), temp2.storeu(&it6[i]);
                                omega0 *= unit1X8, omega1 *= unit1X8, omega_last *= unit_last;
                            }
                            // Vector that straddles the end of `out`: the third quarter is stored partially.
                            if (len1 % 8 > 0)
                            {
                                temp0.load(&it0[i]), temp1.load(&it1[i]), temp2.load(&it2[i]), temp3.load(&it3[i]);

                                dit_butterfly2_2layer_out3(temp0, temp1, temp2, temp3, omega0, omega1, omega_last);

                                temp0 = temp0.norm(), temp1 = temp1.norm(), temp2 = temp2.norm();
                                temp0.storeu(&it4[i]), temp1.storeu(&it5[i]), temp2.storeN(&it6[i], len1 % 8);
                                omega0 *= unit1X8, omega1 *= unit1X8, omega_last *= unit_last;
                                i += 8;
                            }
                            // Remaining columns: only the first two quarters are stored.
                            for (; i < len_4; i += 8)
                            {
                                temp0.load(&it0[i]), temp1.load(&it1[i]), temp2.load(&it2[i]), temp3.load(&it3[i]);

                                dit_butterfly2_2layer_out2(temp0, temp1, temp2, temp3, omega0, omega1, omega_last);

                                temp0 = temp0.norm(), temp1 = temp1.norm();
                                temp0.storeu(&it4[i]), temp1.storeu(&it5[i]);
                                omega0 *= unit1X8, omega1 *= unit1X8, omega_last *= unit_last;
                            }
                        }
                    }
                    // First two (fused) butterfly layers of a length-ntt_len transform, reading the
                    // raw uint32_t input directly instead of a zero-padded copy.
                    //
                    // Only the `len` existing input elements are read; all ntt_len elements of
                    // `out` are written. Requires ntt_len / 4 < len <= ntt_len / 2 (asserted) and
                    // ntt_len / 4 to be a multiple of 8.
                    // NOTE(review): the _in2 / _in1 helpers presumably treat the quarters that are
                    // not loaded as zero - they are defined outside this block.
                    static void difInLayer(const uint32_t in[], size_t len, ModIntType out[], size_t ntt_len)
                    {
                        assert(ntt_len / 4 < len && len <= ntt_len / 2);
                        ModIntX8 omega0 = omegax8(ntt_len, 1, root());
                        ModIntX8 omega1 = mul_w41<root(), ModIntType>(omega0);
                        ModIntX8 omega_last = omega0 * omega0;
                        ModIntX8 unit1X8 = unitx8(ntt_len, 1, root()), unit_last = unit1X8 * unit1X8;
                        ModIntX8 temp0, temp1, temp2, temp3;
                        size_t len_4 = ntt_len / 4;
                        // Only the first two input quarters can hold data, so no pointer is formed beyond them.
                        auto it0 = in, it1 = in + len_4;
                        auto it4 = out, it5 = out + len_4, it6 = out + len_4 * 2, it7 = out + len_4 * 3;
                        // len1 = number of input elements that fall into the second quarter.
                        const size_t len1 = len - len_4, rem_len = len1 - len1 % 8;
                        size_t i = 0;
                        // Columns where both the first and the second quarter hold a full vector.
                        for (; i < rem_len; i += 8)
                        {
                            temp0.loadu(&it0[i]), temp1.loadu(&it1[i]);

                            dif_butterfly2_2layer_in2(temp0, temp1, temp2, temp3, omega0, omega1, omega_last);

                            temp0.store(&it4[i]), temp1.store(&it5[i]), temp2.store(&it6[i]), temp3.store(&it7[i]);
                            omega0 *= unit1X8, omega1 *= unit1X8, omega_last *= unit_last;
                        }
                        // Column where the input ends: only len1 % 8 lanes of the second quarter are read.
                        if (len1 % 8 > 0)
                        {
                            temp0.loadu(&it0[i]), temp1.loadN(&it1[i], len1 % 8);

                            dif_butterfly2_2layer_in2(temp0, temp1, temp2, temp3, omega0, omega1, omega_last);

                            temp0.store(&it4[i]), temp1.store(&it5[i]), temp2.store(&it6[i]), temp3.store(&it7[i]);
                            omega0 *= unit1X8, omega1 *= unit1X8, omega_last *= unit_last;
                            i += 8;
                        }
                        // Remaining columns: only the first quarter has input.
                        for (; i < len_4; i += 8)
                        {
                            temp0.loadu(&it0[i]);

                            dif_butterfly2_2layer_in1(temp0, temp1, temp2, temp3, omega0, omega1, omega_last);

                            temp0.store(&it4[i]), temp1.store(&it5[i]), temp2.store(&it6[i]), temp3.store(&it7[i]);
                            omega0 *= unit1X8, omega1 *= unit1X8, omega_last *= unit_last;
                        }
                    }

                    // Inner recursion: in-place convolution of two ntt_len-element buffers.
                    //
                    // The result is left in in_out1; in_out2 is overwritten with intermediate data.
                    // len_inv_r is passed unchanged down to the short kernels. When `norm` is true
                    // every result element goes through norm() before it is stored. OUT2P selects
                    // the output variant of the last butterfly layer when `norm` is false.
                    template <bool OUT2P = true>
                    static void convolutionRecursion(ModIntType in_out1[], ModIntType in_out2[], size_t ntt_len, ModIntType len_inv_r, bool norm = false)
                    {
                        if (ntt_len <= LONG_THRESHOLD)
                        {
                            // Base case: hand over to the short-length kernels.
                            NTTTemplate::template convolution<OUT2P>(in_out1, in_out2, ntt_len, len_inv_r);
                            if (norm)
                            {
                                for (size_t i = 0; i < ntt_len; i++)
                                {
                                    in_out1[i] = in_out1[i].norm();
                                }
                            }
                            return;
                        }
                        ModIntX8 temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
                        ModIntX8 omega0 = omegax8(ntt_len, 1, root());
                        ModIntX8 omega1 = mul_w41<root(), ModIntType>(omega0);
                        ModIntX8 omega_last = omega0 * omega0;
                        ModIntX8 unit1X8 = unitx8(ntt_len, 1, root()), unit_last = unit1X8 * unit1X8;
                        size_t len_4 = ntt_len / 4;
                        auto it0 = in_out1, it1 = in_out1 + len_4, it2 = in_out1 + len_4 * 2, it3 = in_out1 + len_4 * 3;
                        auto it4 = in_out2, it5 = in_out2 + len_4, it6 = in_out2 + len_4 * 2, it7 = in_out2 + len_4 * 3;
                        // Two fused forward layers on both operands, one 8-lane column per iteration.
                        for (size_t i = 0; i < len_4; i += 8)
                        {
                            temp0.load(&it0[i]), temp1.load(&it1[i]), temp2.load(&it2[i]), temp3.load(&it3[i]);
                            temp4.load(&it4[i]), temp5.load(&it5[i]), temp6.load(&it6[i]), temp7.load(&it7[i]);

                            dif_butterfly2_2layer(temp0, temp1, temp2, temp3, omega0, omega1, omega_last);
                            dif_butterfly2_2layer(temp4, temp5, temp6, temp7, omega0, omega1, omega_last);

                            temp0.store(&it0[i]), temp1.store(&it1[i]), temp2.store(&it2[i]), temp3.store(&it3[i]);
                            temp4.store(&it4[i]), temp5.store(&it5[i]), temp6.store(&it6[i]), temp7.store(&it7[i]);
                            omega0 *= unit1X8, omega1 *= unit1X8, omega_last *= unit_last;
                        }
                        // Each quarter is an independent quarter-length convolution.
                        convolutionRecursion<true>(it0, it4, len_4, len_inv_r);
                        convolutionRecursion<false>(it1, it5, len_4, len_inv_r);
                        convolutionRecursion<true>(it2, it6, len_4, len_inv_r);
                        convolutionRecursion<false>(it3, it7, len_4, len_inv_r);
                        // Switch the twiddles to the inverse root for the two fused inverse layers.
                        omega0 = omegax8(ntt_len, 1, rootInv());
                        omega1 = mul_w41<rootInv(), ModIntType>(omega0);
                        omega_last = omega0 * omega0;
                        unit1X8 = unitx8(ntt_len, 1, rootInv());
                        unit_last = unit1X8 * unit1X8;
                        if (norm)
                        {
                            for (size_t i = 0; i < len_4; i += 8)
                            {
                                temp0.load(&it0[i]), temp1.load(&it1[i]), temp2.load(&it2[i]), temp3.load(&it3[i]);

                                dit_butterfly2_i2424_2layer<true>(temp0, temp1, temp2, temp3, omega0, omega1, omega_last);
                                temp0 = temp0.norm(), temp1 = temp1.norm(), temp2 = temp2.norm(), temp3 = temp3.norm();

                                temp0.store(&it0[i]), temp1.store(&it1[i]), temp2.store(&it2[i]), temp3.store(&it3[i]);
                                omega0 *= unit1X8, omega1 *= unit1X8, omega_last *= unit_last;
                            }
                        }
                        else
                        {
                            for (size_t i = 0; i < len_4; i += 8)
                            {
                                temp0.load(&it0[i]), temp1.load(&it1[i]), temp2.load(&it2[i]), temp3.load(&it3[i]);

                                dit_butterfly2_i2424_2layer<OUT2P>(temp0, temp1, temp2, temp3, omega0, omega1, omega_last);

                                temp0.store(&it0[i]), temp1.store(&it1[i]), temp2.store(&it2[i]), temp3.store(&it3[i]);
                                omega0 *= unit1X8, omega1 *= unit1X8, omega_last *= unit_last;
                            }
                        }
                    }

                    // Outer recursion: multiplies the polynomials in1[0..len1) and in2[0..len2) and
                    // writes the len1 + len2 - 1 product coefficients to `out`.
                    //
                    // Nothing is written when either operand is empty. The fused input/output layers
                    // (difInLayer / ditOutLayer) skip the zero-padded copies, but difInLayer only
                    // accepts lengths in (ntt_len / 4, ntt_len / 2]; any other shape, and every
                    // transform that fits the short kernels, goes through the zero-padded path.
                    static void convolution(const uint32_t in1[], size_t len1, const uint32_t in2[], size_t len2, uint32_t out[])
                    {
                        if (len1 == 0 || len2 == 0)
                        {
                            // An empty operand has an empty product; len1 + len2 - 1 would also wrap around.
                            return;
                        }
                        const size_t conv_len = len1 + len2 - 1, ntt_len = int_ceil2(conv_len);
                        const size_t len_4 = ntt_len / 4, len_2 = ntt_len / 2;
                        const ModIntType len_inv_r = ModIntType(ntt_len).inv().mulR();
                        auto ntt_p1 = reinterpret_cast<ModIntType *>(_mm_malloc(sizeof(ModIntType) * ntt_len, 32));
                        auto ntt_p2 = reinterpret_cast<ModIntType *>(_mm_malloc(sizeof(ModIntType) * ntt_len, 32));
                        // Both operands must satisfy difInLayer's precondition for the fused path.
                        const bool fused_layers = ntt_len > LONG_THRESHOLD &&
                                                  len_4 < len1 && len1 <= len_2 &&
                                                  len_4 < len2 && len2 <= len_2;
                        if (!fused_layers)
                        {
                            // Zero-padded path: copy, pad, convolve in place, then normalise while copying out.
                            std::memcpy(ntt_p1, in1, sizeof(ModIntType) * len1);
                            std::memcpy(ntt_p2, in2, sizeof(ModIntType) * len2);
                            std::memset(ntt_p1 + len1, 0, sizeof(ModIntType) * (ntt_len - len1));
                            std::memset(ntt_p2 + len2, 0, sizeof(ModIntType) * (ntt_len - len2));
                            convolutionRecursion<true>(ntt_p1, ntt_p2, ntt_len, len_inv_r);
                            size_t i = 0;
                            for (const size_t rem_len = conv_len - conv_len % 8; i < rem_len; i += 8)
                            {
                                ModIntX8 temp0;
                                temp0.load(&ntt_p1[i]);
                                temp0.norm().storeu(&out[i]);
                            }
                            for (; i < conv_len; i++)
                            {
                                ModIntType temp = ntt_p1[i];
                                out[i] = temp.norm().data;
                            }
                            _mm_free(ntt_p1);
                            _mm_free(ntt_p2);
                            return;
                        }
                        // Fused path: the first two layers read the raw inputs, the last two write `out`.
                        difInLayer(in1, len1, ntt_p1, ntt_len);
                        difInLayer(in2, len2, ntt_p2, ntt_len);
                        convolutionRecursion<true>(ntt_p1, ntt_p2, len_4, len_inv_r);
                        convolutionRecursion<false>(ntt_p1 + len_4, ntt_p2 + len_4, len_4, len_inv_r);
                        convolutionRecursion<true>(ntt_p1 + len_4 * 2, ntt_p2 + len_4 * 2, len_4, len_inv_r);
                        convolutionRecursion<false>(ntt_p1 + len_4 * 3, ntt_p2 + len_4 * 3, len_4, len_inv_r);
                        INTT::ditOutLayer(out, conv_len, ntt_p1, ntt_len);
                        _mm_free(ntt_p1);
                        _mm_free(ntt_p2);
                    }
                };
                // Namespace-scope definitions of NTT's static constexpr data members; before
                // C++17 these are required whenever such a member is odr-used.
                template <uint32_t MOD, uint32_t ROOT>
                constexpr int NTT<MOD, ROOT>::MOD_BITS;
                template <uint32_t MOD, uint32_t ROOT>
                constexpr int NTT<MOD, ROOT>::MAX_LOG_LEN;
                template <uint32_t MOD, uint32_t ROOT>
                constexpr size_t NTT<MOD, ROOT>::NTT_MAX_LEN;
                template <uint32_t MOD, uint32_t ROOT>
                constexpr size_t NTT<MOD, ROOT>::LONG_THRESHOLD;
            }
        }
    }
}

void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c)
{
    using NTT = hint::transform::ntt::radix2_avx::NTT<998244353, 3>;
    NTT::convolution(a, n + 1, b, m + 1, c);
}

// Output buffer that collects space-separated tokens in memory and prints them as a
// single line with put().
//
// The text is kept in a std::string, so it is always NUL-terminated (also after printing
// a 0 or when nothing was printed), grows on demand instead of overflowing, and can be
// copied or destroyed safely.
class QPrint
{
private:
    std::string data; // everything appended so far, without a trailing newline

    // Inserts the token separator before every token except the first one.
    void separate()
    {
        if (!data.empty())
        {
            data.push_back(' ');
        }
    }

public:
    // max_len: expected size of the final output in bytes. It is only used to reserve
    // storage up front; longer output still works.
    QPrint(size_t max_len)
    {
        data.reserve(max_len);
    }
    // Appends n in decimal.
    void operator<<(uint64_t n)
    {
        separate();
        char digits[20]; // UINT64_MAX has 20 decimal digits
        char *const last = digits + sizeof(digits);
        char *first = last;
        // do/while so that n == 0 still produces the single digit '0'.
        do
        {
            --first;
            *first = static_cast<char>('0' + n % 10);
            n /= 10;
        } while (n > 0);
        data.append(first, static_cast<size_t>(last - first));
    }
    // Appends s verbatim.
    void operator<<(const std::string &s)
    {
        separate();
        data.append(s);
    }
    // Writes the collected text followed by a newline to stdout.
    void put() const
    {
        puts(data.c_str());
    }
};
// Reads the next non-negative decimal integer from stdin.
//
// Every character before the first digit is skipped; reading stops at the first
// non-digit after the number (that character is consumed). Returns 0 when the end of
// input is reached before any digit was found, instead of waiting forever.
// No overflow check is made: the value must fit into an int.
inline int ReadNum()
{
    int ch = getchar();
    // Skip separators, but stop at end of input: EOF is negative and would otherwise
    // always look like "not a digit".
    while (ch != EOF && (ch < '0' || '9' < ch))
    {
        ch = getchar();
    }
    int res = 0;
    while ('0' <= ch && ch <= '9')
    {
        res = res * 10 + (ch - '0');
        ch = getchar();
    }
    return res;
}
int main()
{
    size_t m = 4, n = 4;
    m = ReadNum();
    n = ReadNum();
    size_t len1 = m + 1, len2 = n + 1;
    QPrint qout(20000000);
    using NTT = hint::transform::ntt::radix2_avx::NTT<998244353, 3>;
    using ModInt = NTT::ModIntType;
    size_t conv_len = len1 + len2 - 1, ntt_len = hint::int_ceil2(conv_len);
    ModInt len_inv_r = ModInt(ntt_len).inv().mulR();
    auto ntt_p1 = (ModInt *)_mm_malloc(ntt_len * sizeof(ModInt), 32);
    auto ntt_p2 = (ModInt *)_mm_malloc(ntt_len * sizeof(ModInt), 32);
    for (size_t i = 0; i < len1; i++)
    {
        ntt_p1[i].data = ReadNum();
    }
    for (size_t i = 0; i < len2; i++)
    {
        ntt_p2[i].data = ReadNum();
    }
    std::memset(ntt_p1 + len1, 0, (ntt_len - len1) * sizeof(ModInt));
    std::memset(ntt_p2 + len2, 0, (ntt_len - len2) * sizeof(ModInt));
    NTT::convolutionRecursion<true>(ntt_p1, ntt_p2, ntt_len, len_inv_r, true);
    for (size_t i = 0; i < len1 + len2 - 1; i++)
    {
        qout << ntt_p1[i].data;
    }
    qout.put();
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Subtask #1 Testcase #1 | 67.26 us | 104 KB | Accepted | Score: 100

Subtask #1 Testcase #2 | 6.768 ms | 4 MB + 964 KB | Accepted | Score: 0

Subtask #1 Testcase #3 | 2.335 ms | 1 MB + 668 KB | Accepted | Score: 0

Subtask #1 Testcase #4 | 2.441 ms | 1 MB + 644 KB | Accepted | Score: 0

Subtask #1 Testcase #5 | 61.25 us | 104 KB | Accepted | Score: 0

Subtask #1 Testcase #6 | 60.09 us | 104 KB | Accepted | Score: 0

Subtask #1 Testcase #7 | 60.98 us | 104 KB | Accepted | Score: 0

Subtask #1 Testcase #8 | 6.197 ms | 4 MB + 428 KB | Accepted | Score: 0

Subtask #1 Testcase #9 | 6.208 ms | 4 MB + 428 KB | Accepted | Score: 0

Subtask #1 Testcase #10 | 5.527 ms | 3 MB + 912 KB | Accepted | Score: 0

Subtask #1 Testcase #11 | 6.927 ms | 5 MB + 100 KB | Accepted | Score: 0

Subtask #1 Testcase #12 | 4.434 ms | 2 MB + 884 KB | Accepted | Score: 0

Subtask #1 Testcase #13 | 60.31 us | 104 KB | Accepted | Score: 0


Judge Duck Online | 评测鸭在线
Server Time: 2025-03-01 01:50:20 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠