提交记录 21635


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY test. 自定义测试 Runtime Error 0 50.671 ms 98396 KB C++14 55.67 KB
提交时间 评测时间
2024-04-24 15:36:05 2024-04-24 15:36:09
#pragma GCC target("avx2")
#pragma GCC optimize("O2")
#include <vector>
#include <complex>
#include <iostream>
#include <future>
#include <array>
#include <ctime>
#include <cstring>
#include <immintrin.h>
// Fixed-size array of LEN elements whose storage starts on a 4096-byte boundary.
// Provides indexing, raw-pointer access and begin()/end() for range-for, in both
// mutable and const flavours.
template <typename T, size_t LEN>
class AlignAry
{
private:
    alignas(4096) T ary[LEN];

public:
    constexpr AlignAry() {}
    constexpr T& operator[](size_t index)
    {
        return ary[index];
    }
    constexpr const T& operator[](size_t index) const
    {
        return ary[index];
    }
    // Number of elements (the compile-time LEN).
    static constexpr size_t size()
    {
        return LEN;
    }
    // The array decays to T*; no cast is needed.
    T* data()
    {
        return ary;
    }
    const T* data() const
    {
        return ary;
    }
    T* begin()
    {
        return ary;
    }
    // Const overloads so a const AlignAry can be iterated with range-for.
    const T* begin() const
    {
        return ary;
    }
    T* end()
    {
        return ary + LEN;
    }
    const T* end() const
    {
        return ary + LEN;
    }
    // Reinterpret the storage as a Ty array (e.g. for SIMD loads).
    template <typename Ty>
    Ty* cast_ptr()
    {
        return reinterpret_cast<Ty*>(ary);
    }
    template <typename Ty>
    const Ty* cast_ptr() const
    {
        return reinterpret_cast<const Ty*>(ary);
    }
};
namespace hint
{
    // Largest power of two that is <= n; returns 0 for n == 0.
    // Smearing the highest set bit (index k) rightwards turns n into 2^(k+1) - 1.
    // Subtracting (n >> 1) == 2^k - 1 then leaves exactly 2^k.
    // The former "(n >> 1) + 1" gave the same values except n == 0, where it returned 1.
    template <typename T>
    constexpr T int_floor2(T n)
    {
        constexpr int bits = sizeof(n) * 8;
        for (int i = 1; i < bits; i *= 2)
        {
            n |= (n >> i);
        }
        return n - (n >> 1);
    }

    // Smallest power of two that is >= n. For unsigned T, n == 0 wraps to
    // all-ones and comes back as 0.
    template <typename T>
    constexpr T int_ceil2(T n)
    {
        // Start from n - 1 so that exact powers of two map to themselves.
        T v = n - 1;
        int shift = 1;
        while (shift < int(sizeof(T) * 8))
        {
            v |= v >> shift;
            shift <<= 1;
        }
        return v + 1;
    }

    // Value with the low `bits` bits set, i.e. 2^bits - 1.
    // Built as 2^(bits-1) - 1 + 2^(bits-1) so that bits == width of T does not
    // shift by the full width. bits <= 0 yields 0; previously that shifted by a
    // negative amount (UB). bits greater than the width of T is still undefined.
    template <typename T>
    constexpr T all_one(int bits)
    {
        if (bits <= 0)
        {
            return T(0);
        }
        T tmp = T(1) << (bits - 1);
        return tmp - 1 + tmp;
    }

    // Integer floor(log2(n)) for unsigned n; returns 0 for both n == 0 and n == 1.
    // Binary search over the bit width: at each step, if n still has a set bit at
    // or above `step`, shift it down by `step` and accumulate `step`.
    template <typename UintTy>
    constexpr int hint_log2(UintTy n)
    {
        int result = 0;
        for (int step = int(8 * sizeof(UintTy)) / 2; step > 0; step /= 2)
        {
            const UintTy upper = n >> step;
            if (upper != 0)
            {
                result += step;
                n = upper;
            }
        }
        return result;
    }
    // Returns n^-1 mod 2^pow via Newton (Hensel) iteration x <- x * (2 - n * x),
    // which doubles the number of correct low bits each round.
    // Only odd n are invertible modulo a power of two. For even n the loop below
    // could never reach t == 1 (it used to hang, or exhaust constexpr limits), so
    // 0 is returned instead. 0 * n is never 1, so callers that verify the inverse
    // (e.g. via a static_assert) still detect the failure.
    // pow <= 0 means the modulus is 1, where the only residue is 0.
    constexpr uint64_t inv_mod2pow(uint64_t n, int pow)
    {
        if (pow <= 0 || (n & 1) == 0)
        {
            return 0;
        }
        const uint64_t mask = all_one<uint64_t>(pow);
        uint64_t xn = 1, t = n & mask;
        while (t != 1)
        {
            xn = (xn * (2 - t));
            t = (xn * n) & mask;
        }
        return xn & mask;
    }
    // Binary exponentiation: returns m^n using only T's operator* and construction from 1.
    // The base is squared only while exponent bits remain. The former version squared
    // once more after the last bit, which wasted a multiplication and could overflow
    // (UB, and a constant-evaluation failure) for signed built-in T even when m^n fits.
    template <typename T>
    constexpr T qpow(T m, uint32_t n)
    {
        T result = 1;
        while (n > 0)
        {
            if ((n & 1) != 0)
            {
                result = result * m;
            }
            n >>= 1;
            if (n > 0)
            {
                m = m * m;
            }
        }
        return result;
    }
    // Element-wise product: out[i] = in1[i] * in2[i] for i in [0, len).
    // The bulk is processed four elements per iteration, then the 0-3 leftovers.
    // out may be the same array as in1 or in2 (each element is read before it is written).
    template <typename T>
    inline void ary_mul(const T in1[], const T in2[], T out[], size_t len)
    {
        const size_t leftover = len % 4;
        const size_t blocked = len - leftover;
        size_t idx = 0;
        while (idx < blocked)
        {
            out[idx + 0] = in1[idx + 0] * in2[idx + 0];
            out[idx + 1] = in1[idx + 1] * in2[idx + 1];
            out[idx + 2] = in1[idx + 2] * in2[idx + 2];
            out[idx + 3] = in1[idx + 3] * in2[idx + 3];
            idx += 4;
        }
        for (; idx < len; ++idx)
        {
            out[idx] = in1[idx] * in2[idx];
        }
    }
    // FFT与类FFT变换的命名空间
    namespace hint_transform
    {
        // Radix-2 butterfly in place: (sum, diff) <- (sum + diff, sum - diff).
        // Both inputs are copied before either output is written, so the call is
        // well defined even when sum and diff refer to the same object.
        template <typename T>
        inline void transform2(T& sum, T& diff)
        {
            const T lhs(sum), rhs(diff);
            sum = lhs + rhs;
            diff = lhs - rhs;
        }

        // Returns the point on the unit circle with argument theta, as std::complex<FloatTy>.
        template <typename FloatTy>
        inline auto unit_root(FloatTy theta)
        {
            const FloatTy radius = 1;
            return std::polar(radius, theta);
        }

        // Bit-reversal permutation: reorders [begin, end) in place so that the
        // element at index i moves to index rev(i), where rev reverses the low
        // log2(len) bits of i.
        // NOTE(review): the index arithmetic assumes len is a power of two; nothing
        // here checks that - TODO confirm callers guarantee it.
        // NOTE(review): for len == 0 the short path forms begin + 1 and end - 1,
        // which lie outside the range.
        template <typename It>
        void binary_reverse_swap(It begin, It end)
        {
            const size_t len = end - begin;
            // Swap only when the left index is smaller than the right one, so each
            // (i, rev(i)) pair is swapped exactly once.
            auto smaller_swap = [=](It it_left, It it_right)
            {
                if (it_left < it_right)
                {
                    std::swap(it_left[0], it_right[0]);
                }
            };
            // Given the iterator to rev(i), return the iterator to rev(i + 1): an
            // increment performed on the reversed bits, with the carry propagating
            // downwards from the top bit (len / 2).
            auto get_next_bitrev = [=](It last)
            {
                size_t k = len / 2, indx = last - begin;
                indx ^= k;
                while (k > indx)
                {
                    k >>= 1;
                    indx ^= k;
                };
                return begin + indx;
            };
            // Short inputs: plain walk over indices 1 .. len - 2 (index 0 and len - 1
            // map to themselves). j tracks rev(i), starting at rev(1) == len / 2.
            if (len <= 16)
            {
                for (auto i = begin + 1, j = begin + len / 2; i < end - 1; i++)
                {
                    smaller_swap(i, j);
                    j = get_next_bitrev(j);
                }
                return;
            }
            // Longer inputs: eight indices per step. For i0 in [1, len / 8) with
            // j = rev(i0), the indices i0, i0 + len/2, i0 + len/4, i0 + 3len/4 and
            // those four plus len/8 pair with j, j+1, ..., j+7 (power-of-two len).
            // i0 == 0 is skipped: for power-of-two len > 16, every pair in that group
            // is either a fixed point or has its smaller index in a later group,
            // where smaller_swap handles it.
            const size_t len_8 = len / 8;
            const auto last = begin + len_8;
            auto i0 = begin + 1, i1 = i0 + len / 2, i2 = i0 + len / 4, i3 = i1 + len / 4;
            for (auto j = begin + len / 2; i0 < last; i0++, i1++, i2++, i3++)
            {
                smaller_swap(i0, j);
                smaller_swap(i1, j + 1);
                smaller_swap(i2, j + 2);
                smaller_swap(i3, j + 3);
                smaller_swap(i0 + len_8, j + 4);
                smaller_swap(i1 + len_8, j + 5);
                smaller_swap(i2 + len_8, j + 6);
                smaller_swap(i3 + len_8, j + 7);
                j = get_next_bitrev(j);
            }
        }

        // Bit-reversal permutation of ary[0, len); forwards to the iterator-range overload.
        template <typename T>
        void binary_reverse_swap(T ary, const size_t len)
        {
            const T last = ary + len;
            binary_reverse_swap(ary, last);
        }
        namespace hint_ntt
        {
            // Residue modulo MOD stored directly (not in Montgomery form) in `data`.
            // Provides modular +, -, * and Shoup-style multiplication by a fixed
            // operand (getW1 / mulModShoup).
            template <uint32_t MOD>
            struct ModInt32
            {
                uint32_t data;
                // Zero-initialise so a default-constructed value is well defined. A C++14
                // constexpr constructor must also initialise every member; the previous
                // empty body left `data` indeterminate.
                constexpr ModInt32() : data(0) {}
                constexpr ModInt32(uint32_t in) : data(in) {}

                // Subtract MOD once if data >= MOD (maps [0, 2 * MOD) onto [0, MOD)).
                constexpr ModInt32 largeNorm() const
                {
                    return data < MOD ? data : data - MOD;
                }
                // Full 64-bit product of the two stored values.
                constexpr uint64_t mul64(ModInt32 in) const
                {
                    return uint64_t(data) * uint64_t(in.data);
                }
                // Shoup precomputation floor(data * 2^32 / MOD); fits in 32 bits when data < MOD.
                constexpr ModInt32 getW1() const
                {
                    return (uint64_t(data) << 32) / MOD;
                }
                // data * w mod MOD, given w1 = w.getW1(). The final conditional subtraction
                // is disabled (commented largeNorm() below), so the result is not guaranteed
                // to be below MOD.
                constexpr ModInt32 mulModShoup(ModInt32 w, ModInt32 w1) const
                {
                    uint64_t q = (uint64_t(data) * uint32_t(w1.data)) >> 32;
                    ModInt32 res = uint64_t(data) * w.data - q * MOD;
                    // return res.largeNorm();
                    return res;
                }
                // Modular addition of two reduced residues without overflowing 32 bits.
                constexpr ModInt32 operator+(ModInt32 in) const
                {
                    uint32_t diff = MOD - in.data;
                    return data < diff ? data + in.data : data - diff;
                }
                // Modular subtraction; a wrapped (larger) difference means borrow, so add MOD back.
                constexpr ModInt32 operator-(ModInt32 in) const
                {
                    in.data = data - in.data;
                    return in.data > data ? in.data + MOD : in.data;
                }
                constexpr ModInt32 operator*(ModInt32 in) const
                {
                    return mul64(in) % MOD;
                }
                constexpr ModInt32& operator+=(ModInt32 in)
                {
                    return *this = *this + in;
                }
                constexpr ModInt32& operator-=(ModInt32 in)
                {
                    return *this = *this - in;
                }
                constexpr ModInt32& operator*=(ModInt32 in)
                {
                    return *this = *this * in;
                }
                constexpr operator uint32_t() const
                {
                    return data;
                }
                static constexpr uint32_t mod()
                {
                    return MOD;
                }
            };

            // Scalar residue modulo MOD kept in Montgomery form (value * 2^32 mod MOD).
            // Constructing from uint32_t converts into that form; toInt() / the
            // uint32_t conversion convert back out.
            template <uint32_t MOD>
            struct MontInt32
            {
                // Montgomery radix R = 2^32 and the constants derived from it.
                static constexpr uint64_t R = uint64_t(1) << 32;
                static constexpr uint32_t R_MASK = R - 1;
                static constexpr uint32_t MOD_INV = inv_mod2pow(MOD, 32); // MOD^-1 mod 2^32
                static constexpr uint32_t MOD_INV_NEG = R - MOD_INV;      // -MOD^-1 mod 2^32
                static constexpr uint32_t MOD2 = MOD * 2;
                // hint_log2(MOD) <= 30 accepts every MOD below 2^31.
                static_assert(hint_log2(MOD) <= 30, "MOD can't be larger than 30 bits");
                static_assert(uint32_t(MOD_INV * MOD) == 1, "Montgomery32 modulus is not correct");

                uint32_t data; // Montgomery representation
                constexpr MontInt32() : data(0) {}
                constexpr MontInt32(uint32_t n) : data(toMont(n)) {}

                // n -> n * 2^32 mod MOD.
                static constexpr uint32_t toMont(uint32_t n)
                {
                    const uint64_t shifted = uint64_t(n) << 32;
                    return uint32_t(shifted % MOD);
                }
                // Montgomery reduction: input * 2^-32 mod MOD, with one conditional
                // subtraction of MOD at the end.
                static constexpr uint32_t redc(uint64_t input)
                {
                    const uint32_t m = uint32_t(input) * MOD_INV_NEG; // wraps mod 2^32 on purpose
                    const uint64_t t = (uint64_t(m) * MOD + input) >> 32;
                    return uint32_t(t >= MOD ? t - MOD : t);
                }
                // Montgomery form -> plain integer.
                static constexpr uint32_t toInt(uint32_t n)
                {
                    return redc(uint64_t(n));
                }
                static constexpr uint32_t addMont(uint32_t m, uint32_t n)
                {
                    const uint32_t s = m + n;
                    return s >= MOD ? s - MOD : s;
                }
                static constexpr uint32_t subMont(uint32_t m, uint32_t n)
                {
                    // Borrow exactly when n > m; add MOD back in that case.
                    return m < n ? m - n + MOD : m - n;
                }
                static constexpr uint32_t mulMont(uint32_t m, uint32_t n)
                {
                    const uint64_t product = uint64_t(m) * n;
                    return redc(product);
                }
                constexpr void fromInt(uint32_t n)
                {
                    data = toMont(n);
                }
                constexpr uint32_t toInt() const
                {
                    return toInt(data);
                }
                constexpr operator uint32_t() const
                {
                    return toInt();
                }
                constexpr MontInt32 operator+(MontInt32 rhs) const
                {
                    MontInt32 res;
                    res.data = addMont(data, rhs.data);
                    return res;
                }
                constexpr MontInt32 operator-(MontInt32 rhs) const
                {
                    MontInt32 res;
                    res.data = subMont(data, rhs.data);
                    return res;
                }
                constexpr MontInt32 operator*(MontInt32 rhs) const
                {
                    MontInt32 res;
                    res.data = mulMont(data, rhs.data);
                    return res;
                }
                constexpr MontInt32& operator+=(const MontInt32& rhs)
                {
                    return *this = *this + rhs;
                }
                constexpr MontInt32& operator-=(const MontInt32& rhs)
                {
                    return *this = *this - rhs;
                }
                constexpr MontInt32& operator*=(const MontInt32& rhs)
                {
                    return *this = *this * rhs;
                }
                static constexpr uint32_t mod()
                {
                    return MOD;
                }
            };

            // Eight MontInt32<MOD> residues (Montgomery form) packed into one __m256i,
            // one per 32-bit lane, with AVX2 versions of the scalar operations.
            template <uint32_t MOD>
            struct MontInt32X8
            {
                using MontInt = MontInt32<MOD>;
                // largeNorm / smallNorm / smallNorm2 use the *signed* compare
                // _mm256_cmpgt_epi32. Sums from addMont and results of montRedcLazy lie in
                // [0, 2 * MOD), and they only compare correctly while 2 * MOD - 1 fits in a
                // signed 32-bit lane, i.e. MOD < 2^30. MontInt32's own check
                // (hint_log2(MOD) <= 30) still admits moduli in [2^30, 2^31). For those,
                // the vector routines would silently return unreduced values, so they are
                // rejected at compile time here.
                static_assert(MOD < (uint32_t(1) << 30), "MontInt32X8 requires MOD < 2^30 (lane compares are signed 32-bit)");
                __m256i data;

                MontInt32X8() { data = _mm256_setzero_si256(); }
                // Broadcast one residue to all eight lanes.
                MontInt32X8(MontInt x) { data = _mm256_set1_epi32(x.data); }
                // Lane i receives xi (x0 is the lowest lane).
                MontInt32X8(int32_t x0, int32_t x1, int32_t x2, int32_t x3, int32_t x4, int32_t x5, int32_t x6, int32_t x7)
                {
                    data = _mm256_set_epi32(x7, x6, x5, x4, x3, x2, x1, x0);
                }
                MontInt32X8(__m256i rhs) : data(rhs) {}
                // Unaligned load of 32 bytes from p.
                template <typename T>
                MontInt32X8(const T* p)
                {
                    loadu(p);
                }
                static constexpr uint32_t mod()
                {
                    return MOD;
                }
                static MontInt32X8 zeroX8()
                {
                    return _mm256_setzero_si256();
                }
                static MontInt32X8 modX8()
                {
                    return _mm256_set1_epi32(MOD);
                }
                static MontInt32X8 mod1X8()
                {
                    return _mm256_set1_epi32(MOD - 1);
                }
                static MontInt32X8 mod2X8()
                {
                    return _mm256_set1_epi32(MOD * 2);
                }
                static MontInt32X8 modNX8()
                {
                    return _mm256_set1_epi32(MontInt::MOD_INV_NEG);
                }

                // 32x32->64 products of the even lanes (low half of each 64-bit element).
                MontInt32X8 mul64(MontInt32X8 rhs) const
                {
                    return _mm256_mul_epu32(data, rhs.data);
                }
                MontInt32X8 lShift64(int n) const
                {
                    return _mm256_slli_epi64(data, n);
                }
                MontInt32X8 rShift64(int n) const
                {
                    return _mm256_srli_epi64(data, n);
                }
                // Keep lanes 0,2,4,6 and zero lanes 1,3,5,7 (each 64-bit element = its even lane).
                MontInt32X8 evenElements() const
                {
                    return blend<0b10101010>(data, zeroX8());
                }
                // Keep lanes 1,3,5,7 (still in the high halves) and zero lanes 0,2,4,6.
                MontInt32X8 oddElements() const
                {
                    return blend<0b01010101>(data, zeroX8());
                }
                // first: 64-bit products of the even lanes; second: of the odd lanes.
                std::pair<MontInt32X8, MontInt32X8> mul64hl(MontInt32X8 rhs) const
                {
                    return std::make_pair(mul64(rhs), rShift64(32).mul64(rhs.rShift64(32)));
                }
                // MontInt32X8 getW1() const
                // {
                //     alignas(32) uint32_t temp[8];
                //     store(temp);
                //     for (auto &&i : temp)
                //     {
                //         i = (uint64_t(i) << 32) / mod;
                //     }
                //     return MontInt32X8(temp);
                // }
                // MontInt32X8 mulModShoup(MontInt32X8 w, MontInt32X8 w1) const
                // {
                //     MontInt32X8 q0, q1, t0, t1;
                //     std::tie(q0, q1) = mul64hl(w1);
                //     std::tie(t0, t1) = mul64hl(w);
                //     q0 = q0.rShift64(32), q1 = q1.rShift64(32);
                //     q0 = q0.mul64(MontInt32X8(mod)), q1 = q1.mul64(MontInt32X8(mod));
                //     t0 = t0.rawSub64(q0), t1 = t1.rawSub64(q1);
                //     return t0 | t1.lShift64(32);
                // }
                // Montgomery reduction of four even-lane and four odd-lane 64-bit products,
                // without the final conditional subtraction (result lanes below 2 * MOD).
                // The even results are shifted down into the low halves; the odd results
                // are taken from the high halves, and blend interleaves them.
                static MontInt32X8 montRedcLazy(MontInt32X8 even64, MontInt32X8 odd64)
                {
                    MontInt32X8 p0 = even64.mul64(modNX8());
                    MontInt32X8 p1 = odd64.mul64(modNX8());
                    p0 = p0.mul64(modX8()).rawAdd64(even64).rShift64(32);
                    p1 = p1.mul64(modX8()).rawAdd64(odd64);
                    return blend<0b10101010>(p0, p1);
                }
                static MontInt32X8 montRedc(MontInt32X8 even64, MontInt32X8 odd64)
                {
                    return montRedcLazy(even64, odd64).largeNorm();
                }
                // Plain integers -> Montgomery form, computed lane by lane in scalar code.
                MontInt32X8 toMont() const
                {
                    alignas(32) uint32_t temp[8];
                    store(temp);
                    for (auto&& i : temp)
                    {
                        i = (uint64_t(i) << 32) % MOD;
                    }
                    return MontInt32X8(temp);
                }
                // Montgomery form -> plain integers (REDC of each lane value).
                MontInt32X8 toInt() const
                {
                    MontInt32X8 e = evenElements();
                    MontInt32X8 o = rShift64(32);
                    return montRedc(e, o);
                }
                // Lane i comes from b if bit i of N is set, otherwise from a.
                template <int N>
                static MontInt32X8 blend(MontInt32X8 a, MontInt32X8 b)
                {
                    return _mm256_blend_epi32(a.data, b.data, N);
                }
                // Subtract MOD from lanes >= MOD (signed compare; see the static_assert above).
                MontInt32X8 largeNorm() const
                {
                    MontInt32X8 sub = (*this > mod1X8()) & modX8();
                    return rawSub(sub);
                }
                // Add MOD to lanes that are negative as signed 32-bit values.
                MontInt32X8 smallNorm() const
                {
                    MontInt32X8 add = (zeroX8() > *this) & modX8();
                    return rawAdd(add);
                }
                // Add 2 * MOD to lanes that are negative as signed 32-bit values.
                MontInt32X8 smallNorm2() const
                {
                    MontInt32X8 add = (zeroX8() > *this) & mod2X8();
                    return rawAdd(add);
                }
                MontInt32X8 addMont(MontInt32X8 rhs) const
                {
                    return rawAdd(rhs).largeNorm();
                }
                MontInt32X8 subMont(MontInt32X8 rhs) const
                {
                    return rawSub(rhs).smallNorm();
                }
                MontInt32X8 mulMont(MontInt32X8 rhs) const
                {
                    auto mulhl = mul64hl(rhs);
                    return montRedc(mulhl.first, mulhl.second);
                }

                MontInt32X8 operator+(MontInt32X8 rhs) const
                {
                    return addMont(rhs);
                }
                MontInt32X8 operator-(MontInt32X8 rhs) const
                {
                    return subMont(rhs);
                }
                MontInt32X8 operator*(MontInt32X8 rhs) const
                {
                    return mulMont(rhs);
                }
                MontInt32X8 rawAdd(MontInt32X8 rhs) const
                {
                    return _mm256_add_epi32(data, rhs.data);
                }
                MontInt32X8 rawSub(MontInt32X8 rhs) const
                {
                    return _mm256_sub_epi32(data, rhs.data);
                }
                MontInt32X8 rawAdd64(MontInt32X8 rhs) const
                {
                    return _mm256_add_epi64(data, rhs.data);
                }
                MontInt32X8 rawSub64(MontInt32X8 rhs) const
                {
                    return _mm256_sub_epi64(data, rhs.data);
                }

                // Lane-wise comparisons producing all-ones / all-zeros masks (signed 32-bit).
                MontInt32X8 operator>(MontInt32X8 n) const
                {
                    return _mm256_cmpgt_epi32(data, n.data);
                }
                MontInt32X8 operator<(MontInt32X8 n) const
                {
                    return n > *this;
                }
                MontInt32X8 operator==(MontInt32X8 n) const
                {
                    return _mm256_cmpeq_epi32(data, n.data);
                }
                MontInt32X8 operator&(MontInt32X8 n) const
                {
                    return _mm256_and_si256(data, n.data);
                }
                MontInt32X8 operator|(MontInt32X8 n) const
                {
                    return _mm256_or_si256(data, n.data);
                }
                MontInt32X8 operator^(MontInt32X8 n) const
                {
                    return _mm256_xor_si256(data, n.data);
                }
                void set1(int32_t n)
                {
                    data = _mm256_set1_epi32(n);
                }
                // loadu/storeu accept any address; load/store require 32-byte alignment.
                template <typename T>
                void loadu(const T* p)
                {
                    data = _mm256_loadu_si256((const __m256i*)p);
                }
                template <typename T>
                void load(const T* p)
                {
                    data = _mm256_load_si256((const __m256i*)p);
                    // data = *reinterpret_cast<const __m256i *>(p);
                }
                template <typename T>
                void storeu(T* p) const
                {
                    _mm256_storeu_si256((__m256i*)p, data);
                }
                template <typename T>
                void store(T* p) const
                {
                    _mm256_store_si256((__m256i*)p, data);
                    // *reinterpret_cast<__m256i *>(p) = data;
                }
                // Debug helpers: print the raw lane contents.
                void printI32() const
                {
                    alignas(32) int32_t v[8];
                    store(v);
                    std::cout << "[" << v[0] << "," << v[1]
                        << "," << v[2] << "," << v[3]
                        << "," << v[4] << "," << v[5]
                        << "," << v[6] << "," << v[7] << "]" << std::endl;
                }
                void printU32() const
                {
                    alignas(32) uint32_t v[8];
                    store(v);
                    std::cout << "[" << v[0] << "," << v[1]
                        << "," << v[2] << "," << v[3]
                        << "," << v[4] << "," << v[5]
                        << "," << v[6] << "," << v[7] << "]" << std::endl;
                }
                void printU64() const
                {
                    alignas(32) uint64_t v[4];
                    store(v);
                    std::cout << "[" << v[0] << "," << v[1]
                        << "," << v[2] << "," << v[3] << "]" << std::endl;
                }
            };
            namespace split_radix_avx
            {
                // n * ROOT^((mod - 1) / 4): multiplication by a 4th root of unity (a primitive
                // one if ROOT is a primitive root modulo mod() - assumed). The constant is
                // folded at compile time.
                template <uint32_t ROOT, typename ModIntType>
                inline ModIntType mul_w41(ModIntType n)
                {
                    constexpr uint32_t exponent = (ModIntType::mod() - 1) / 4;
                    constexpr ModIntType root4 = qpow(ModIntType(ROOT), exponent);
                    return n * root4;
                }
                // n * ROOT^((mod - 1) / 8): multiplication by an 8th root of unity derived
                // from ROOT, folded at compile time.
                template <uint32_t ROOT, typename ModIntType>
                inline ModIntType mul_w81(ModIntType n)
                {
                    constexpr uint32_t exponent = (ModIntType::mod() - 1) / 8;
                    constexpr ModIntType root8 = qpow(ModIntType(ROOT), exponent);
                    return n * root8;
                }
                // n * ROOT^(3 * ((mod - 1) / 8)): multiplication by the cube of the 8th root
                // of unity used in mul_w81, folded at compile time.
                template <uint32_t ROOT, typename ModIntType>
                inline ModIntType mul_w83(ModIntType n)
                {
                    constexpr uint32_t exponent = ((ModIntType::mod() - 1) / 8) * 3;
                    constexpr ModIntType root8_cubed = qpow(ModIntType(ROOT), exponent);
                    return n * root8_cubed;
                }
                // template <uint32_t MOD, uint32_t ROOT>
                // inline void dit_butterfly244(ModIntType it[], ModIntType omega1, ModIntType omega3, size_t rank)
                // {
                //     auto temp2 = it[rank * 2] * omega1;
                //     auto temp3 = it[rank * 3] * omega3;

                //     transform2(temp2, temp3);
                //     temp3 = mul_w41<ROOT>(temp3);

                //     auto temp0 = it[0];
                //     auto temp1 = it[rank];
                //     it[0] = temp0 + temp2;
                //     it[rank] = temp1 + temp3;
                //     it[rank * 2] = temp0 - temp2;
                //     it[rank * 3] = temp1 - temp3;
                // }
                // template <uint32_t MOD, uint32_t ROOT>
                // inline void dif_butterfly244(ModIntType it[], ModIntType omega1, ModIntType omega3, size_t rank)
                // {
                //     auto temp0 = it[0];
                //     auto temp1 = it[rank];
                //     auto temp2 = it[rank * 2];
                //     auto temp3 = it[rank * 3];
                //     it[0] = temp0 + temp2;
                //     it[rank] = temp1 + temp3;

                //     temp2 = temp0 - temp2;
                //     temp3 = temp1 - temp3;
                //     temp3 = mul_w41<ROOT>(temp3);
                //     transform2(temp2, temp3);

                //     it[rank * 2] = temp2 * omega1;
                //     it[rank * 3] = temp3 * omega3;
                // }
                // Decimation-in-time split-radix (2-4-4) butterfly on eight consecutive
                // columns. With x_k = it[k * rank .. k * rank + 7]:
                //   t2 = x2 * omega1, t3 = x3 * omega3, (t2, t3) <- (t2 + t3, t2 - t3),
                //   t3 *= W_4_1 (= ROOT^((MOD - 1) / 4)),
                //   x0 <- x0 + t2, x1 <- x1 + t3, x2 <- x0 - t2, x3 <- x1 - t3.
                // Uses unaligned loads/stores like dif_butterfly244_avx2. `it` comes from
                // the caller and is not guaranteed to be 32-byte aligned, and
                // _mm256_load_si256 / _mm256_store_si256 fault on a misaligned address.
                template <uint32_t ROOT, uint32_t MOD>
                inline void dit_butterfly244_avx2(MontInt32<MOD> it[], MontInt32X8<MOD> omega1, MontInt32X8<MOD> omega3, size_t rank)
                {
                    MontInt32X8<MOD> temp0, temp1, temp2, temp3;
                    temp2.loadu(&it[rank * 2]);
                    temp3.loadu(&it[rank * 3]);
                    temp2 = temp2 * omega1;
                    temp3 = temp3 * omega3;

                    transform2(temp2, temp3);
                    constexpr MontInt32<MOD> W_4_1 = qpow(MontInt32<MOD>(ROOT), (MOD - 1) / 4);
                    MontInt32X8<MOD> Wx8;
                    Wx8.set1(W_4_1.data);
                    temp3 = temp3 * Wx8;

                    temp0.loadu(&it[0]);
                    temp1.loadu(&it[rank]);
                    (temp0 + temp2).storeu(&it[0]);
                    (temp1 + temp3).storeu(&it[rank]);
                    (temp0 - temp2).storeu(&it[rank * 2]);
                    (temp1 - temp3).storeu(&it[rank * 3]);
                }
                // Decimation-in-frequency split-radix (2-4-4) butterfly on eight consecutive
                // columns, with unaligned loads/stores. With x_k = it[k * rank .. +7]:
                //   x0 <- x0 + x2, x1 <- x1 + x3,
                //   lo = x0 - x2, hi = (x1 - x3) * W_4_1, (lo, hi) <- (lo + hi, lo - hi),
                //   x2 <- lo * omega1, x3 <- hi * omega3.
                template <uint32_t ROOT, uint32_t MOD>
                inline void dif_butterfly244_avx2(MontInt32<MOD> it[], MontInt32X8<MOD> omega1, MontInt32X8<MOD> omega3, size_t rank)
                {
                    using VecType = MontInt32X8<MOD>;
                    MontInt32<MOD>* const p0 = it;
                    MontInt32<MOD>* const p1 = it + rank;
                    MontInt32<MOD>* const p2 = it + rank * 2;
                    MontInt32<MOD>* const p3 = it + rank * 3;

                    VecType a, b, c, d;
                    a.loadu(p0);
                    b.loadu(p1);
                    c.loadu(p2);
                    d.loadu(p3);

                    // ROOT^((MOD - 1) / 4) broadcast to every lane.
                    constexpr MontInt32<MOD> W_4_1 = qpow(MontInt32<MOD>(ROOT), (MOD - 1) / 4);
                    VecType w4;
                    w4.set1(W_4_1.data);

                    const VecType sum02 = a + c;
                    const VecType sum13 = b + d;
                    VecType lo = a - c;
                    VecType hi = (b - d) * w4;
                    transform2(lo, hi);

                    sum02.storeu(p0);
                    sum13.storeu(p1);
                    (lo * omega1).storeu(p2);
                    (hi * omega3).storeu(p3);
                }

                // template <uint64_t ROOT, uint32_t MOD>
                // static void dit_butterfly2488_avx(MontInt32<MOD> input[],
                //                                   MontInt32X8<MOD> omega, MontInt32X8<MOD> omega3, MontInt32X8<MOD> omega7,
                //                                   size_t rank)
                // {
                //     MontInt32X8<MOD> temp0 = input[0];
                //     MontInt32X8<MOD> temp1 = input[rank];
                //     MontInt32X8<MOD> temp2 = input[rank * 2];
                //     MontInt32X8<MOD> temp3 = input[rank * 3];
                //     MontInt32X8<MOD> temp4 = input[rank * 4] * omega;
                //     MontInt32X8<MOD> temp5 = input[rank * 5] * omega;
                //     MontInt32X8<MOD> temp6 = input[rank * 6] * omega3;
                //     MontInt32X8<MOD> temp7 = input[rank * 7] * omega7;
                //     transform2(temp6, temp7);
                //     transform2(temp4, temp6);
                //     temp6 = mul_w41<ROOT>(temp6);
                //     temp7 = mul_w41<ROOT>(temp7);
                //     transform2(temp5, temp7);
                //     temp5 = mul_w81<ROOT>(temp5);
                //     temp7 = mul_w83<ROOT>(temp7);
                //     input[0] = temp0 + temp4;
                //     input[rank] = temp1 + temp5;
                //     input[rank * 2] = temp2 + temp6;
                //     input[rank * 3] = temp3 + temp7;
                //     input[rank * 4] = temp0 - temp4;
                //     input[rank * 5] = temp1 - temp5;
                //     input[rank * 6] = temp2 - temp6;
                //     input[rank * 7] = temp3 - temp7;
                // }
                // template <uint64_t ROOT, uint32_t MOD>
                // static void dif_butterfly2488_avx(MontInt32<MOD> input[],
                //                                   MontInt32X8<MOD> omega, MontInt32X8<MOD> omega3, MontInt32X8<MOD> omega7,
                //                                   size_t rank)
                // {
                //     MontInt32X8<MOD> temp0 = input[0];
                //     MontInt32X8<MOD> temp1 = input[rank];
                //     MontInt32X8<MOD> temp2 = input[rank * 2];
                //     MontInt32X8<MOD> temp3 = input[rank * 3];
                //     MontInt32X8<MOD> temp4 = input[rank * 4];
                //     MontInt32X8<MOD> temp5 = input[rank * 5];
                //     MontInt32X8<MOD> temp6 = input[rank * 6];
                //     MontInt32X8<MOD> temp7 = input[rank * 7];
                //     transform2(temp0, temp4);
                //     transform2(temp1, temp5);
                //     transform2(temp2, temp6);
                //     transform2(temp3, temp7);
                //     temp5 = mul_w81<ROOT>(temp5);
                //     temp7 = mul_w83<ROOT>(temp7);
                //     transform2(temp5, temp7);
                //     temp6 = mul_w41<ROOT>(temp6);
                //     temp7 = mul_w41<ROOT>(temp7);
                //     transform2(temp4, temp6);
                //     transform2(temp6, temp7);
                //     input[0] = temp0;
                //     input[rank] = temp1;
                //     input[rank * 2] = temp2;
                //     input[rank * 3] = temp3;
                //     input[rank * 4] = temp4 * omega;
                //     input[rank * 5] = temp5 * omega;
                //     input[rank * 6] = temp6 * omega3;
                //     input[rank * 7] = temp7 * omega7;
                // }
                template <size_t LEN, uint32_t MOD, uint32_t ROOT>
                struct NTTShort
                {
                    static constexpr size_t ntt_len = LEN;
                    static constexpr size_t half_len = ntt_len / 2;
                    static constexpr size_t quarter_len = ntt_len / 4;
                    static constexpr size_t octant_len = ntt_len / 8;
                    static constexpr int log_len = hint_log2(ntt_len);

                    using ModIntX8 = MontInt32X8<MOD>;
                    using ModIntType = typename ModIntX8::MontInt;
                    using HalfNTT = NTTShort<half_len, MOD, ROOT>;
                    using QuarterNTT = NTTShort<quarter_len, MOD, ROOT>;
                    using TableType = AlignAry<ModIntType, quarter_len>;

                    static constexpr TableType getNTTTable(int factor)
                    {
                        ModIntType root = qpow(ModIntType(ROOT), ((ModIntType::mod() - 1) / LEN) * factor);
                        ModIntType omega(1);
                        TableType res;
                        for (auto&& i : res)
                        {
                            i = omega;
                            omega *= root;
                        }
                        return res;
                    }

                    static TableType table1;
                    static TableType table3;

                    static constexpr uint64_t mod()
                    {
                        return ModIntType::mod();
                    }

                    static constexpr uint64_t root()
                    {
                        return ROOT;
                    }
                    // Full LEN-point transform in DIT order, in place.
                    // The two trailing quarters and the leading half are transformed
                    // recursively first. The results are then merged eight entries at a time
                    // by dit_butterfly244_avx2, using twiddles from table1 (w^k) and table3 (w^(3k)).
                    // Requires quarter_len to be a multiple of 8 (true for LEN >= 32, which is
                    // where this primary template is used; smaller lengths are specialized).
                    static void dit(ModIntType in_out[])
                    {
                        ModIntType *const upper_half = in_out + half_len;
                        QuarterNTT::dit(upper_half + quarter_len);
                        QuarterNTT::dit(upper_half);
                        HalfNTT::dit(in_out);

                        ModIntX8 w1_lanes, w3_lanes;
                        for (size_t k = 0; k < quarter_len; k += 8)
                        {
                            w1_lanes.loadu(&table1[k]);
                            w3_lanes.loadu(&table3[k]);
                            dit_butterfly244_avx2<ROOT>(in_out + k, w1_lanes, w3_lanes, quarter_len);
                        }
                    }

                    // Full LEN-point transform in DIF order, in place; mirror image of dit().
                    // The array is first combined eight entries at a time by
                    // dif_butterfly244_avx2 with the table1/table3 twiddles. The leading half
                    // and the two trailing quarters are then transformed recursively.
                    static void dif(ModIntType in_out[])
                    {
                        ModIntX8 w1_lanes, w3_lanes;
                        for (size_t k = 0; k < quarter_len; k += 8)
                        {
                            w1_lanes.loadu(&table1[k]);
                            w3_lanes.loadu(&table3[k]);
                            dif_butterfly244_avx2<ROOT>(in_out + k, w1_lanes, w3_lanes, quarter_len);
                        }

                        ModIntType *const upper_half = in_out + half_len;
                        HalfNTT::dif(in_out);
                        QuarterNTT::dif(upper_half);
                        QuarterNTT::dif(upper_half + quarter_len);
                    }
                    // Length-dispatching DIT entry point.
                    // When len >= LEN, the full LEN-point transform runs. Otherwise the call is
                    // delegated to the half-length transform with the same len.
                    // NOTE(review): len is presumably a power of two. That is not validated, and
                    // len > LEN still transforms only LEN entries.
                    static void dit(ModIntType in_out[], size_t len)
                    {
                        if (len >= LEN)
                        {
                            dit(in_out);
                        }
                        else
                        {
                            HalfNTT::dit(in_out, len);
                        }
                    }
                    // Length-dispatching DIF entry point.
                    // When len >= LEN, the full LEN-point transform runs. Otherwise the call is
                    // delegated to the half-length transform with the same len.
                    // NOTE(review): same unchecked power-of-two assumption as dit(in_out, len).
                    static void dif(ModIntType in_out[], size_t len)
                    {
                        if (len >= LEN)
                        {
                            dif(in_out);
                        }
                        else
                        {
                            HalfNTT::dif(in_out, len);
                        }
                    }
                };
                // Out-of-class definitions of the primary template's twiddle tables.
                // The initializer of a static data member is looked up in class scope, so the
                // unqualified getNTTTable here is NTTShort<LEN, MOD, ROOT>::getNTTTable.
                // Factor 1 fills table1 with w^k; factor 3 fills table3 with w^(3k).
                template <size_t LEN, uint32_t MOD, uint32_t ROOT>
                typename NTTShort<LEN, MOD, ROOT>::TableType NTTShort<LEN, MOD, ROOT>::table1 = getNTTTable(1);
                template <size_t LEN, uint32_t MOD, uint32_t ROOT>
                typename NTTShort<LEN, MOD, ROOT>::TableType NTTShort<LEN, MOD, ROOT>::table3 = getNTTTable(3);

                // Degenerate length-0 transform: there is nothing to do, so every entry point is a no-op.
                template <uint32_t MOD, uint32_t ROOT>
                struct NTTShort<0, MOD, ROOT>
                {
                    using ModIntType = MontInt32<MOD>;

                    // Parameters are deliberately unnamed: none of them is used.
                    static void dit(ModIntType[]) {}
                    static void dif(ModIntType[]) {}
                    static void dit(ModIntType[], size_t) {}
                    static void dif(ModIntType[], size_t) {}
                };

                // Length-1 transform: the identity, so every entry point leaves the data untouched.
                template <uint32_t MOD, uint32_t ROOT>
                struct NTTShort<1, MOD, ROOT>
                {
                    using ModIntType = MontInt32<MOD>;

                    // Parameters are deliberately unnamed: none of them is used.
                    static void dit(ModIntType[]) {}
                    static void dif(ModIntType[]) {}
                    static void dit(ModIntType[], size_t) {}
                    static void dif(ModIntType[], size_t) {}
                };

                template <uint32_t MOD, uint32_t ROOT>
                struct NTTShort<2, MOD, ROOT>
                {
                    using ModIntType = MontInt32<MOD>;

                    static void dit(ModIntType in_out[])
                    {
                        transform2(in_out[0], in_out[1]);
                    }
                    static void dif(ModIntType in_out[])
                    {
                        transform2(in_out[0], in_out[1]);
                    }
                    static void dit(ModIntType in_out[], size_t len)
                    {
                        if (len < 2)
                        {
                            return;
                        }
                        dit(in_out);
                    }
                    static void dif(ModIntType in_out[], size_t len)
                    {
                        if (len < 2)
                        {
                            return;
                        }
                        dif(in_out);
                    }
                };

                // Length-4 transform, hand-unrolled: two length-2 butterflies, a mul_w41
                // twist on one of their outputs, and a final combining layer.
                template <uint32_t MOD, uint32_t ROOT>
                struct NTTShort<4, MOD, ROOT>
                {
                    using ModIntType = MontInt32<MOD>;

                    // DIT order: butterflies on the pairs (0,1) and (2,3) first, then combine.
                    static void dit(ModIntType in_out[])
                    {
                        auto a0 = in_out[0], a1 = in_out[1];
                        auto b0 = in_out[2], b1 = in_out[3];

                        transform2(a0, a1);
                        transform2(b0, b1);
                        b1 = mul_w41<ROOT>(b1);

                        in_out[0] = a0 + b0;
                        in_out[2] = a0 - b0;
                        in_out[1] = a1 + b1;
                        in_out[3] = a1 - b1;
                    }
                    // DIF order: butterflies on the pairs (0,2) and (1,3) first, then combine neighbours.
                    static void dif(ModIntType in_out[])
                    {
                        auto a0 = in_out[0], a1 = in_out[1];
                        auto b0 = in_out[2], b1 = in_out[3];

                        transform2(a0, b0);
                        transform2(a1, b1);
                        b1 = mul_w41<ROOT>(b1);

                        in_out[0] = a0 + a1;
                        in_out[1] = a0 - a1;
                        in_out[2] = b0 + b1;
                        in_out[3] = b0 - b1;
                    }
                    // Length dispatch: shorter requests fall through to the 2-point transform.
                    static void dit(ModIntType in_out[], size_t len)
                    {
                        if (len >= 4)
                        {
                            dit(in_out);
                        }
                        else
                        {
                            NTTShort<2, MOD, ROOT>::dit(in_out, len);
                        }
                    }
                    static void dif(ModIntType in_out[], size_t len)
                    {
                        if (len >= 4)
                        {
                            dif(in_out);
                        }
                        else
                        {
                            NTTShort<2, MOD, ROOT>::dif(in_out, len);
                        }
                    }
                };

                // Length-8 transform, hand-unrolled into three butterfly layers.
                // w1 = ROOT^((mod-1)/8); w2 and w3 are its square and cube. The lower half
                // (entries 0..3) is held in lo0..lo3 and the upper half (4..7) in hi0..hi3.
                template <uint32_t MOD, uint32_t ROOT>
                struct NTTShort<8, MOD, ROOT>
                {
                    using ModIntType = MontInt32<MOD>;

                    static void dit(ModIntType in_out[])
                    {
                        static constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
                        static constexpr ModIntType w2 = qpow(w1, 2);
                        static constexpr ModIntType w3 = qpow(w1, 3);
                        auto lo0 = in_out[0], lo1 = in_out[1], lo2 = in_out[2], lo3 = in_out[3];
                        auto hi0 = in_out[4], hi1 = in_out[5], hi2 = in_out[6], hi3 = in_out[7];

                        // Layer 1: butterflies on adjacent pairs, then the mul_w41 twist.
                        transform2(lo0, lo1);
                        transform2(lo2, lo3);
                        transform2(hi0, hi1);
                        transform2(hi2, hi3);
                        lo3 = mul_w41<ROOT>(lo3);
                        hi3 = mul_w41<ROOT>(hi3);

                        // Layer 2: butterflies at distance 2 within each half, then twist the upper half.
                        transform2(lo0, lo2);
                        transform2(lo1, lo3);
                        transform2(hi0, hi2);
                        transform2(hi1, hi3);
                        hi1 = hi1 * w1;
                        hi2 = hi2 * w2;
                        hi3 = hi3 * w3;

                        // Layer 3: combine the halves.
                        in_out[0] = lo0 + hi0;
                        in_out[4] = lo0 - hi0;
                        in_out[1] = lo1 + hi1;
                        in_out[5] = lo1 - hi1;
                        in_out[2] = lo2 + hi2;
                        in_out[6] = lo2 - hi2;
                        in_out[3] = lo3 + hi3;
                        in_out[7] = lo3 - hi3;
                    }
                    static void dif(ModIntType in_out[])
                    {
                        static constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 8);
                        static constexpr ModIntType w2 = qpow(w1, 2);
                        static constexpr ModIntType w3 = qpow(w1, 3);
                        auto lo0 = in_out[0], lo1 = in_out[1], lo2 = in_out[2], lo3 = in_out[3];
                        auto hi0 = in_out[4], hi1 = in_out[5], hi2 = in_out[6], hi3 = in_out[7];

                        // Layer 1: butterflies across the halves, then twist the upper half.
                        transform2(lo0, hi0);
                        transform2(lo1, hi1);
                        transform2(lo2, hi2);
                        transform2(lo3, hi3);
                        hi1 = hi1 * w1;
                        hi2 = hi2 * w2;
                        hi3 = hi3 * w3;

                        // Layer 2: butterflies at distance 2 within each half, then the mul_w41 twist.
                        transform2(lo0, lo2);
                        transform2(lo1, lo3);
                        transform2(hi0, hi2);
                        transform2(hi1, hi3);
                        lo3 = mul_w41<ROOT>(lo3);
                        hi3 = mul_w41<ROOT>(hi3);

                        // Layer 3: butterflies on adjacent pairs, written straight to the output.
                        in_out[0] = lo0 + lo1;
                        in_out[1] = lo0 - lo1;
                        in_out[2] = lo2 + lo3;
                        in_out[3] = lo2 - lo3;
                        in_out[4] = hi0 + hi1;
                        in_out[5] = hi0 - hi1;
                        in_out[6] = hi2 + hi3;
                        in_out[7] = hi2 - hi3;
                    }
                    // Length dispatch: shorter requests fall through to the 4-point transform.
                    static void dit(ModIntType in_out[], size_t len)
                    {
                        if (len >= 8)
                        {
                            dit(in_out);
                        }
                        else
                        {
                            NTTShort<4, MOD, ROOT>::dit(in_out, len);
                        }
                    }
                    static void dif(ModIntType in_out[], size_t len)
                    {
                        if (len >= 8)
                        {
                            dif(in_out);
                        }
                        else
                        {
                            NTTShort<4, MOD, ROOT>::dif(in_out, len);
                        }
                    }
                };

                // Length-16 split-radix transform built from one length-8 and two length-4
                // sub-transforms. Viewing the array as columns j = 0..3 made of
                // in_out[j], in_out[4+j], in_out[8+j], in_out[12+j], each column is combined
                // with twiddles w^j and w^(3j), where w = ROOT^((mod-1)/16). Column 0 needs no twiddle.
                template <uint32_t MOD, uint32_t ROOT>
                struct NTTShort<16, MOD, ROOT>
                {
                    using ModIntType = MontInt32<MOD>;

                    using NTT4 = NTTShort<4, MOD, ROOT>;
                    using NTT8 = NTTShort<8, MOD, ROOT>;

                    // DIT order: sub-transforms first (length 8 on [0,8), length 4 on [8,12)
                    // and on [12,16)), then the twiddled column combinations.
                    static void dit(ModIntType in_out[])
                    {
                        // w1 = ROOT^((mod-1)/16); w2, w3, w6 and w9 are the powers needed by columns 1..3.
                        constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 16);
                        constexpr ModIntType w2 = qpow(w1, 2);
                        constexpr ModIntType w3 = qpow(w1, 3);
                        constexpr ModIntType w6 = qpow(w1, 6);
                        constexpr ModIntType w9 = qpow(w1, 9);
                        NTT4::dit(in_out + 12);
                        NTT4::dit(in_out + 8);
                        NTT8::dit(in_out);

                        // Column 0: the two quarter outputs enter without a twiddle.
                        ModIntType temp0, temp1, temp2, temp3;
                        temp2 = in_out[8];
                        temp3 = in_out[12];
                        transform2(temp2, temp3);
                        temp3 = mul_w41<ROOT>(temp3);
                        temp0 = in_out[0];
                        temp1 = in_out[4];
                        in_out[0] = temp0 + temp2;
                        in_out[4] = temp1 + temp3;
                        in_out[8] = temp0 - temp2;
                        in_out[12] = temp1 - temp3;

                        // Column 1: quarter outputs scaled by w^1 and w^3.
                        temp2 = in_out[9] * w1;
                        temp3 = in_out[13] * w3;
                        transform2(temp2, temp3);
                        temp3 = mul_w41<ROOT>(temp3);
                        temp0 = in_out[1];
                        temp1 = in_out[5];
                        in_out[1] = temp0 + temp2;
                        in_out[5] = temp1 + temp3;
                        in_out[9] = temp0 - temp2;
                        in_out[13] = temp1 - temp3;

                        // Column 2: quarter outputs scaled by w^2 and w^6.
                        temp2 = in_out[10] * w2;
                        temp3 = in_out[14] * w6;
                        transform2(temp2, temp3);
                        temp3 = mul_w41<ROOT>(temp3);
                        temp0 = in_out[2];
                        temp1 = in_out[6];
                        in_out[2] = temp0 + temp2;
                        in_out[6] = temp1 + temp3;
                        in_out[10] = temp0 - temp2;
                        in_out[14] = temp1 - temp3;

                        // Column 3: quarter outputs scaled by w^3 and w^9.
                        temp2 = in_out[11] * w3;
                        temp3 = in_out[15] * w9;
                        transform2(temp2, temp3);
                        temp3 = mul_w41<ROOT>(temp3);
                        temp0 = in_out[3];
                        temp1 = in_out[7];
                        in_out[3] = temp0 + temp2;
                        in_out[7] = temp1 + temp3;
                        in_out[11] = temp0 - temp2;
                        in_out[15] = temp1 - temp3;
                    }
                    // DIF order: the mirror image of dit(). Column combinations come first
                    // (sums to [0,8), twiddled differences to [8,16)), then the length-8 and
                    // the two length-4 sub-transforms.
                    static void dif(ModIntType in_out[])
                    {
                        // Same twiddle powers as in dit(): w1 = ROOT^((mod-1)/16).
                        constexpr ModIntType w1 = qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / 16);
                        constexpr ModIntType w2 = qpow(w1, 2);
                        constexpr ModIntType w3 = qpow(w1, 3);
                        constexpr ModIntType w6 = qpow(w1, 6);
                        constexpr ModIntType w9 = qpow(w1, 9);
                        // Column 0: the differences are stored without a twiddle.
                        ModIntType temp0, temp1, temp2, temp3;
                        temp0 = in_out[0];
                        temp1 = in_out[4];
                        temp2 = in_out[8];
                        temp3 = in_out[12];
                        in_out[0] = temp0 + temp2;
                        in_out[4] = temp1 + temp3;
                        temp2 = temp0 - temp2;
                        temp3 = temp1 - temp3;
                        temp3 = mul_w41<ROOT>(temp3);
                        transform2(temp2, temp3);
                        in_out[8] = temp2;
                        in_out[12] = temp3;

                        // Column 1: differences scaled by w^1 and w^3.
                        temp0 = in_out[1];
                        temp1 = in_out[5];
                        temp2 = in_out[9];
                        temp3 = in_out[13];
                        in_out[1] = temp0 + temp2;
                        in_out[5] = temp1 + temp3;
                        temp2 = temp0 - temp2;
                        temp3 = temp1 - temp3;
                        temp3 = mul_w41<ROOT>(temp3);
                        transform2(temp2, temp3);
                        in_out[9] = temp2 * w1;
                        in_out[13] = temp3 * w3;

                        // Column 2: differences scaled by w^2 and w^6.
                        temp0 = in_out[2];
                        temp1 = in_out[6];
                        temp2 = in_out[10];
                        temp3 = in_out[14];
                        in_out[2] = temp0 + temp2;
                        in_out[6] = temp1 + temp3;
                        temp2 = temp0 - temp2;
                        temp3 = temp1 - temp3;
                        temp3 = mul_w41<ROOT>(temp3);
                        transform2(temp2, temp3);
                        in_out[10] = temp2 * w2;
                        in_out[14] = temp3 * w6;

                        // Column 3: differences scaled by w^3 and w^9.
                        temp0 = in_out[3];
                        temp1 = in_out[7];
                        temp2 = in_out[11];
                        temp3 = in_out[15];
                        in_out[3] = temp0 + temp2;
                        in_out[7] = temp1 + temp3;
                        temp2 = temp0 - temp2;
                        temp3 = temp1 - temp3;
                        temp3 = mul_w41<ROOT>(temp3);
                        transform2(temp2, temp3);
                        in_out[11] = temp2 * w3;
                        in_out[15] = temp3 * w9;

                        // Sub-transforms on the half and the two quarters.
                        NTT8::dif(in_out);
                        NTT4::dif(in_out + 8);
                        NTT4::dif(in_out + 12);
                    }
                    // Length dispatch: shorter requests fall through to the 8-point transform.
                    static void dit(ModIntType in_out[], size_t len)
                    {
                        if (len < 16)
                        {
                            NTTShort<8, MOD, ROOT>::dit(in_out, len);
                            return;
                        }
                        dit(in_out);
                    }
                    static void dif(ModIntType in_out[], size_t len)
                    {
                        if (len < 16)
                        {
                            NTTShort<8, MOD, ROOT>::dif(in_out, len);
                            return;
                        }
                        dif(in_out);
                    }
                };
            };
        }
    }
}

using namespace std;
using namespace hint;
using namespace hint_transform;
// using namespace hint_ntt::split_radix_avx;

// template <typename T>
// vector<T> poly_multiply(const vector<T> &in1, const vector<T> &in2)
// {
//     size_t len1 = in1.size(), len2 = in2.size(), out_len = len1 + len2;
//     vector<T> result(out_len);
//     size_t ntt_len = int_ceil2(out_len);

//     using ntt = NTTShort<1 << 23, 998244353, 3>;
//     using intt = NTTShort<1 << 23, 998244353, 3>;

//     auto mod_ary1 = new ntt::NTTModInt32[ntt_len]();
//     auto mod_ary2 = new ntt::NTTModInt32[ntt_len]();

//     for (size_t i = 0; i < len1; i++)
//     {
//         mod_ary1[i] = in1[i];
//     }
//     for (size_t i = 0; i < len2; i++)
//     {
//         mod_ary2[i] = in2[i];
//     }
//     ntt::ntt_dif(mod_ary1, ntt_len);
//     ntt::ntt_dif(mod_ary2, ntt_len);
//     ary_mul(mod_ary1, mod_ary2, mod_ary1, ntt_len);
//     intt::ntt_dit(mod_ary1, ntt_len);
//     intt::ntt_basic::ntt_normalize(mod_ary1, ntt_len);
//     for (size_t i = 0; i < out_len; i++)
//     {
//         result[i] = mod_ary1[i].data;
//     }
//     delete[] mod_ary1;
//     delete[] mod_ary2;
//     return result;
// }

// template <uint32_t MOD, uint32_t G_ROOT>
// void poly_inv(uint32_t *in, uint32_t *out, size_t len)
// {
//     using ntt = NTT<MOD, G_ROOT, 1 << 24>;
//     using intt = typename ntt::intt;
//     using NttInt = typename ntt::NTTModInt32;
//     std::vector<NttInt> ntt_ary(len * 2);
//     auto in_ntt = ntt_ary.data();
//     auto out_ntt = reinterpret_cast<NttInt *>(out);
//     out[0] = mod_inv(in[0], MOD);
//     for (size_t rank = 2; rank <= len; rank *= 2)
//     {
//         size_t gap = rank * 2;
//         std::copy(in, in + rank, in_ntt);
//         std::fill(in_ntt + rank, in_ntt + gap, 0);
//         std::fill(out + rank / 2, out + gap, 0);
//         // std::cout << gap << "\n";
//         auto t1 = std::chrono::high_resolution_clock::now();
//         ntt::ntt_dif(in_ntt, gap);
//         ntt::ntt_dif(out_ntt, gap);
//         for (size_t i = 0; i < gap; i++)
//         {
//             uint32_t a = in_ntt[i].data, b = out[i];
//             out[i] = MOD - ((b * b % MOD) * a - b * 2 + MOD) % MOD;
//         }
//         intt::ntt_dit(out_ntt, gap);
//         auto t2 = std::chrono::high_resolution_clock::now();
//         // std::cout << "ntt time: " << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count() << "us\n";
//         uint32_t inv = mod_inv(gap, MOD);
//         for (size_t i = 0; i < rank; i++)
//         {
//             out[i] = out[i] * inv % MOD;
//         }
//     }
// }

// template <typename T>
// void result_test(const vector<T> &res, uint32_t ele)
// {
//     size_t len = res.size();
//     for (size_t i = 0; i < len / 2; i++)
//     {
//         uint64_t x = (i + 1) * ele * ele;
//         uint64_t y = res[i];
//         if (x != y)
//         {
//             cout << "fail:" << i << "\t" << (i + 1) * ele * ele << "\t" << y << "\n";
//             return;
//         }
//     }
//     for (size_t i = len / 2; i < len; i++)
//     {
//         uint64_t x = (len - i - 1) * ele * ele;
//         uint64_t y = res[i];
//         if (x != y)
//         {
//             cout << "fail:" << i << "\t" << x << "\t" << y << "\n";
//             return;
//         }
//     }
//     std::cout << "success\n";
// }

// int main()
// {
//     StopWatch w(1000);
//     int n = 18;
//     cin >> n;
//     size_t len = 1 << n; // 变换长度
//     cout << "fft len:" << len << "\n";
//     uint64_t ele = 5;
//     vector<uint32_t> in1(len / 2, ele);
//     vector<uint32_t> in2(len / 2, ele); // 计算两个长度为len/2,每个元素为ele的卷积
//     w.start();
//     vector<uint32_t> res = poly_multiply(in1, in2);
//     // poly_inv<998244353, 3>(in1.data(), in2.data(), 1 << 16);
//     w.stop();
//     result_test(res, ele); // 结果校验
//     cout << w.duration() << "ms\n";
// }

// Naive in-place radix-2 transform in DIT order, used as a reference implementation.
// Butterfly spans grow 2, 4, ..., ntt_len. Within each span the upper half is multiplied
// by successive powers of ROOT^((mod-1)/span) before the add/subtract butterfly.
// No bit-reversal permutation is performed here.
// NOTE(review): ntt_len is presumably a power of two; this is not checked.
template <uint64_t ROOT, typename ModInt>
void ntt_dit(ModInt in_out[], size_t ntt_len)
{
    for (size_t span = 2; span <= ntt_len; span <<= 1)
    {
        const size_t half = span >> 1;
        ModInt step = hint::qpow(ModInt(ROOT), (ModInt::mod() - 1) / span);
        for (size_t base = 0; base < ntt_len; base += span)
        {
            ModInt w = 1;
            for (size_t j = base; j < base + half; j++)
            {
                ModInt lo = in_out[j];
                ModInt hi = in_out[j + half] * w;
                in_out[j] = lo + hi;
                in_out[j + half] = lo - hi;
                w = w * step;
            }
        }
    }
}

// Naive in-place radix-2 transform in DIF order, used as a reference implementation.
// Butterfly spans shrink ntt_len, ntt_len/2, ..., 2. Within each span the add/subtract
// butterfly runs first, and the difference is then multiplied by successive powers of
// ROOT^((mod-1)/span). No bit-reversal permutation is performed here.
// NOTE(review): ntt_len is presumably a power of two; this is not checked.
template <uint64_t ROOT, typename ModInt>
void ntt_dif(ModInt in_out[], size_t ntt_len)
{
    for (size_t span = ntt_len; span >= 2; span >>= 1)
    {
        const size_t half = span >> 1;
        ModInt step = hint::qpow(ModInt(ROOT), (ModInt::mod() - 1) / span);
        for (size_t base = 0; base < ntt_len; base += span)
        {
            ModInt w = 1;
            for (size_t j = base; j < base + half; j++)
            {
                ModInt lo = in_out[j];
                ModInt hi = in_out[j + half];
                in_out[j] = lo + hi;
                in_out[j + half] = (lo - hi) * w;
                w = w * step;
            }
        }
    }
}

void avx2_test()
{
    using namespace hint;
    using namespace hint_transform;
    using namespace hint_ntt;
    constexpr size_t len = 1 << 5;
    constexpr uint32_t mod = 998244353;
    using NTTX8 = MontInt32X8<mod>;
    using ModInt = typename NTTX8::MontInt;
    alignas(64) static ModInt a[len];
    for (size_t i = 0; i < len; i++)
    {
        a[i] = i;
        // b[i] = i;
    }
    size_t times = 1; // std::max<size_t>(1, (1 << 25) / len);
    auto t1 = std::chrono::steady_clock::now();
    for (size_t i = 0; i < times; i++)
    {
        // ntt::dif(a);
        NTTX8 x;
        x.loadu(a);
        x = x * x;
        x.storeu(a);
        // ntt::dif(b.data(), len);
        // ntt::dit(b.data(), len);
    }
    auto t2 = std::chrono::steady_clock::now();
    auto time1 = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count();
    for (size_t i = 0; i < std::min<size_t>(len, 1024); i++)
    {
        std::cout << i << ":\t" << uint32_t(a[i]) << "\n";
    }
    std::cout << time1 << "\n";
}

// Cross-checks the AVX2 split-radix NTTShort (length 2^23, mod 469762049, root 3) against
// the naive reference templates ntt_dif / ntt_dit. Both arrays start as 0, 1, ..., len-1 and
// go through the same dif, dif, dit sequence. Every output element is compared, not just a
// prefix, so a fault confined to later butterflies is still caught. On the first mismatch
// it prints the index with the raw `.data` fields and the converted values, then returns.
// Otherwise it prints the optimized time, the reference time (both in seconds) and the
// reference/optimized ratio.
void ntt_check()
{
    using namespace hint;
    using namespace hint_transform;
    using namespace hint_ntt;
    using namespace split_radix_avx;
    constexpr size_t len = 1 << 23;
    constexpr uint32_t mod = 469762049, root = 3;
    using ntt = NTTShort<len, mod, root>;
    using ModInt = ntt::ModIntType;
    // static: each array holds 2^23 elements, far too large for the stack.
    static AlignAry<ModInt, len> a;
    static AlignAry<ModInt, len> b;
    for (size_t i = 0; i < len; i++)
    {
        a[i] = i;
        b[i] = i;
    }
    size_t times = 1; // std::max<size_t>(1, (1 << 25) / len);
    auto t1 = std::chrono::steady_clock::now();
    for (size_t i = 0; i < times; i++)
    {
        ntt::dif(a.data());
        ntt::dif(a.data());
        ntt::dit(a.data());
    }
    auto t2 = std::chrono::steady_clock::now();
    for (size_t i = 0; i < times; i++)
    {
        ntt_dif<root>(b.data(), len);
        ntt_dif<root>(b.data(), len);
        ntt_dit<root>(b.data(), len);
    }
    auto t3 = std::chrono::steady_clock::now();
    auto time1 = std::chrono::duration_cast<std::chrono::duration<double>>(t2 - t1).count();
    auto time2 = std::chrono::duration_cast<std::chrono::duration<double>>(t3 - t2).count();
    // Compare the whole output; this costs far less than the naive transforms above.
    for (size_t i = 0; i < len; i++)
    {
        if (uint32_t(a[i]) != uint32_t(b[i]))
        {
            std::cout << i << ":\t" << uint32_t(a[i].data) << "\t" << uint32_t(b[i].data) << "\n";
            std::cout << i << ":\t" << uint32_t(a[i]) << "\t" << uint32_t(b[i]) << "\n";
            return;
        }
    }
    std::cout << time1 << "\t" << time2 << "\t" << time2 / time1 << "\n";
}
// Entry point: runs the NTT cross-check. The AVX2 smoke test is kept for manual use.
int main()
{
    ntt_check();
    // avx2_test();
    return 0;
}

Compilation N/A N/A Compile OK Score: N/A

Testcase #1 50.671 ms 96 MB + 92 KB Runtime Error Score: 0


Judge Duck Online | 评测鸭在线
Server Time: 2025-07-18 14:24:39 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠