提交记录 27948


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1002. 测测你的多项式乘法 Compile Error 0 0 ns 0 KB C++14 31.45 KB
提交时间 评测时间
2025-02-19 16:08:21 2025-02-19 16:08:22
#include <algorithm>
#include <array>
#include <cassert>
#include <climits>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>
#include <immintrin.h>
#pragma GCC optimize("inline")
#pragma GCC target("avx2")

namespace hint
{
    // Floating-point and complex aliases.
    using Float32 = float;
    using Float64 = double;
    using Complex32 = std::complex<float>;
    using Complex64 = std::complex<double>;
    // Byte budgets named after the cache levels; change them if you know your cache sizes.
    // NTT32AVX divides these by 2 * sizeof(uint32_t) to obtain its L1_LEN / L2_LEN limits.
    // NOTE(review): 1 << 24 is 16 MiB, which is not a typical L1 size - confirm the intended value.
    constexpr size_t L1_BYTE = static_cast<size_t>(1) << 24;
    constexpr size_t L2_BYTE = static_cast<size_t>(1) << 20;

    // Pi and two Pi as Float64 constants.
    constexpr Float64 HINT_PI = 3.141592653589793238462643;
    constexpr Float64 HINT_2PI = 2 * HINT_PI;
    // Returns a T whose lowest `bits` bits are set, i.e. 2^bits - 1.
    // The value is assembled from the top bit so that `bits` equal to the full
    // width of T does not require shifting by the whole width.
    // `bits` must be at least 1 (bits - 1 is used as a shift count).
    template <typename T>
    constexpr T all_one(int bits)
    {
        const T top_bit = T(1) << (bits - 1);
        return (top_bit - 1) + top_bit;
    }

    // Counts the leading zero bits of x within the width of IntTy.
    // Returns sizeof(IntTy) * CHAR_BIT when x == 0.
    // Works by binary search: each step tests the upper half of the remaining
    // window, starting from half the type's width, so any power-of-two width
    // is handled (the previous fixed 32-bit mask ignored bits above bit 31).
    // x is expected to be non-negative.
    template <typename IntTy>
    constexpr int hint_clz(IntTy x)
    {
        int res = sizeof(IntTy) * CHAR_BIT;
        for (int shift = res / 2; shift > 0; shift /= 2)
        {
            const IntTy upper = x >> shift;
            if (upper != 0)
            {
                // The highest set bit lives in the upper part: drop the lower part.
                res -= shift;
                x = upper;
            }
        }
        // x is now 0 or 1; a remaining 1 accounts for the last significant bit.
        return res - static_cast<int>(x);
    }
    // Leading zero count for a 64-bit value, built from the 32-bit count:
    // if the high word is non-zero its count is the answer, otherwise the
    // answer is 32 plus the count of the low word.
    constexpr int hint_clz(uint64_t x)
    {
        const uint32_t high_word = uint32_t(x >> 32);
        return high_word != 0 ? hint_clz(high_word)
                              : hint_clz(uint32_t(x)) + 32;
    }

    // Number of bits needed to represent x: 0 for x == 0, otherwise the
    // width of IntTy minus the leading zero count reported by hint_clz.
    template <typename IntTy>
    constexpr int hint_bit_length(IntTy x)
    {
        return x == 0 ? 0
                      : static_cast<int>(sizeof(IntTy) * CHAR_BIT - hint_clz(x));
    }

    // Floor of log2(x): index of the highest set bit, computed as
    // (width of IntTy - 1) - hint_clz(x). Yields -1 when hint_clz reports the
    // full width (x == 0).
    template <typename IntTy>
    constexpr int hint_log2(IntTy x)
    {
        const int top_index = static_cast<int>(sizeof(IntTy) * CHAR_BIT - 1);
        return top_index - hint_clz(x);
    }

    // Counts the trailing zero bits of a 32-bit value.
    // Returns 32 for x == 0 (no set bit), matching the convention of
    // std::countr_zero; previously zero produced 31, which is also the answer
    // for 0x80000000.
    constexpr int hint_ctz(uint32_t x)
    {
        if (x == 0)
        {
            return 32;
        }
        // Isolate the lowest set bit (two's-complement negate, then AND).
        x &= (~x + 1);
        // Start from the highest index and subtract each power of two whose
        // mask shows the isolated bit sits in the lower part of that split.
        int r = 31;
        if (x & 0x0000FFFF)
        {
            r -= 16;
        }
        if (x & 0x00FF00FF)
        {
            r -= 8;
        }
        if (x & 0x0F0F0F0F)
        {
            r -= 4;
        }
        if (x & 0x33333333)
        {
            r -= 2;
        }
        if (x & 0x55555555)
        {
            r -= 1;
        }
        return r;
    }

    // Trailing zero count for a 64-bit value, built from the 32-bit count:
    // a non-zero low word decides the answer, otherwise the answer is 32
    // plus the count of the high word.
    constexpr int hint_ctz(uint64_t x)
    {
        const uint32_t low_word = uint32_t(x);
        if (low_word != 0)
        {
            return hint_ctz(low_word);
        }
        return 32 + hint_ctz(uint32_t(x >> 32));
    }

    // Exponentiation by squaring: returns m raised to the power n.
    // Any n <= 0 yields T(1). No overflow checks are performed.
    template <typename T, typename T1>
    constexpr T qpow(T m, T1 n)
    {
        T acc = 1;
        for (; n > 0; n >>= 1)
        {
            // Fold in the current square when the lowest exponent bit is set.
            if ((n & 1) != 0)
            {
                acc *= m;
            }
            m *= m;
        }
        return acc;
    }

    // Modular exponentiation by squaring: returns m^n reduced modulo `mod`.
    // Products are formed in T before reduction, so T must be wide enough to
    // hold (mod - 1)^2. Any n <= 0 yields T(1).
    template <typename T, typename T1>
    constexpr T qpow(T m, T1 n, T mod)
    {
        T acc = 1;
        for (; n > 0; n >>= 1)
        {
            if ((n & 1) != 0)
            {
                acc = (acc * m) % mod;
            }
            m = (m * m) % mod;
        }
        return acc;
    }

    // Largest power of two that is not greater than n.
    // Smears the highest set bit into every lower position, then keeps only
    // the top bit. Note that n == 0 produces 1.
    template <typename T>
    constexpr T int_floor2(T n)
    {
        constexpr int width = sizeof(T) * CHAR_BIT;
        int shift = 1;
        while (shift < width)
        {
            n |= (n >> shift);
            shift *= 2;
        }
        return (n >> 1) + 1;
    }

    // Smallest power of two that is not less than n.
    // Decrements first so exact powers of two map to themselves, smears the
    // highest set bit downwards, then adds one.
    template <typename T>
    constexpr T int_ceil2(T n)
    {
        constexpr int width = sizeof(T) * CHAR_BIT;
        --n;
        int shift = 1;
        while (shift < width)
        {
            n |= (n >> shift);
            shift *= 2;
        }
        return n + 1;
    }

    // Half adder: returns x + y (wrapping in UintTy) and writes the carry-out
    // to cf. The incoming value of cf is ignored.
    template <typename UintTy>
    constexpr UintTy add_half(UintTy x, UintTy y, bool &cf)
    {
        const UintTy sum = x + y;
        // A wrapped sum is smaller than either operand.
        cf = sum < y;
        return sum;
    }

    // Half subtractor: returns x - y (wrapping in UintTy) and writes the
    // borrow-out to bf. The incoming value of bf is ignored.
    template <typename UintTy>
    constexpr UintTy sub_half(UintTy x, UintTy y, bool &bf)
    {
        const UintTy diff = x - y;
        // A wrapped difference is larger than the minuend.
        bf = diff > x;
        return diff;
    }

    // Full adder: returns x + y + cf (wrapping in UintTy). cf supplies the
    // carry-in and receives the carry-out.
    template <typename UintTy>
    constexpr UintTy add_carry(UintTy x, UintTy y, bool &cf)
    {
        const UintTy with_carry_in = x + cf;
        const bool first_overflow = with_carry_in < x; // adding the carry-in wrapped
        const UintTy total = with_carry_in + y;
        cf = first_overflow || (total < y);            // or adding y wrapped
        return total;
    }

    // Full subtractor: returns x - y - bf (wrapping in UintTy). bf supplies
    // the borrow-in and receives the borrow-out.
    template <typename UintTy>
    constexpr UintTy sub_borrow(UintTy x, UintTy y, bool &bf)
    {
        const UintTy after_borrow_in = x - bf;
        const bool first_underflow = after_borrow_in > x; // taking the borrow-in wrapped
        const UintTy result = after_borrow_in - y;
        bf = first_underflow || (result > after_borrow_in); // or subtracting y wrapped
        return result;
    }

    // Extended Euclidean algorithm (recursive).
    // Returns g = gcd(a, b) and stores coefficients in x and y such that
    // a * x + b * y == g. When b == 0 the result is a with x = 1, y = 0.
    template <typename IntTy>
    constexpr IntTy exgcd(IntTy a, IntTy b, IntTy &x, IntTy &y)
    {
        if (b != 0)
        {
            const IntTy quotient = a / b;
            // Recurse on (b, a mod b) with the coefficient slots swapped,
            // then correct y for the quotient that was peeled off.
            const IntTy gcd = exgcd(b, a - quotient * b, y, x);
            y -= quotient * x;
            return gcd;
        }
        x = 1;
        y = 0;
        return a;
    }

    // Modular inverse of n modulo `mod`, obtained from the exgcd coefficient
    // of n. The coefficient is shifted by one multiple of `mod` when it falls
    // outside [0, mod). Meaningful only when gcd(n, mod) == 1; that is not
    // checked here.
    template <typename IntTy>
    constexpr IntTy mod_inv(IntTy n, IntTy mod)
    {
        IntTy coef_n = 0, coef_mod = 0;
        exgcd(n % mod, mod, coef_n, coef_mod);
        if (coef_n < 0)
        {
            return coef_n + mod;
        }
        if (coef_n >= mod)
        {
            return coef_n - mod;
        }
        return coef_n;
    }

    // Inverse of n modulo 2^pow by Newton iteration: the estimate is
    // multiplied by (2 - estimate * n) until estimate * n == 1 under the mask.
    // n must be odd; an even n has no inverse and the loop would not end.
    constexpr uint64_t inv_mod2pow(uint64_t n, int pow)
    {
        const uint64_t mask = all_one<uint64_t>(pow);
        uint64_t estimate = 1;
        for (uint64_t residue = n & mask; residue != 1; residue = (estimate * n) & mask)
        {
            estimate *= (2 - residue);
        }
        return estimate & mask;
    }
    namespace simd
    {
        // Value wrapper around one 256-bit AVX2 integer register (__m256i).
        // Each method forwards to a single intrinsic; "32" / "64" in a name is
        // the lane width the operation works on.
        class Int256
        {
        public:
            // All bits zero.
            Int256() : data(_mm256_setzero_si256()) {}
            // Wraps a raw register. Implicit on purpose: the methods below
            // return intrinsic results directly.
            Int256(__m256i data) : data(data) {}
            // Broadcasts one 32-bit value to all eight 32-bit lanes.
            Int256(int data) : data(_mm256_set1_epi32(data)) {}

            // Lane-wise wrapping add / subtract on 32-bit lanes.
            Int256 add32(Int256 input) const
            {
                return _mm256_add_epi32(data, input.data);
            }
            Int256 sub32(Int256 input) const
            {
                return _mm256_sub_epi32(data, input.data);
            }
            // Lane-wise wrapping add / subtract on 64-bit lanes.
            Int256 add64(Int256 input) const
            {
                return _mm256_add_epi64(data, input.data);
            }
            Int256 sub64(Int256 input) const
            {
                return _mm256_sub_epi64(data, input.data);
            }
            // Lane-wise min / max on 32-bit lanes, unsigned (U32) or signed (I32).
            Int256 minU32(Int256 input) const
            {
                return _mm256_min_epu32(data, input.data);
            }
            Int256 maxU32(Int256 input) const
            {
                return _mm256_max_epu32(data, input.data);
            }
            Int256 minI32(Int256 input) const
            {
                return _mm256_min_epi32(data, input.data);
            }
            Int256 maxI32(Int256 input) const
            {
                return _mm256_max_epi32(data, input.data);
            }
            // Multiplies the low unsigned 32 bits of every 64-bit lane and
            // keeps the full 64-bit products (_mm256_mul_epu32).
            Int256 mullo32To64(Int256 input) const
            {
                return _mm256_mul_epu32(data, input.data);
            }
            // Keeps 32-bit lanes 0, 2, 4, 6 and zeroes lanes 1, 3, 5, 7.
            Int256 evenEle32() const
            {
                return blend32<0b10101010>(*this, Int256{});
            }
            // Logical shifts of every 64-bit lane by the constant N.
            template <int N>
            Int256 lShift64() const
            {
                return _mm256_slli_epi64(data, N);
            }
            template <int N>
            Int256 rShift64() const
            {
                return _mm256_srli_epi64(data, N);
            }
            // Per 32-bit lane select: bit i of M set takes lane i from b,
            // otherwise from a.
            template <int M>
            static Int256 blend32(Int256 a, Int256 b)
            {
                return _mm256_blend_epi32(a.data, b.data, M);
            }
            // Loads 32 bytes from p; `load` requires 32-byte alignment,
            // `loadu` does not.
            template <typename T>
            void loadu(const T *p)
            {
                data = _mm256_loadu_si256((const __m256i *)p);
            }
            template <typename T>
            void load(const T *p)
            {
                data = _mm256_load_si256((const __m256i *)p);
            }
            // Stores 32 bytes to p; `store` requires 32-byte alignment,
            // `storeu` does not.
            template <typename T>
            void storeu(T *p) const
            {
                _mm256_storeu_si256((__m256i *)p, data);
            }
            template <typename T>
            void store(T *p) const
            {
                _mm256_store_si256((__m256i *)p, data);
            }
            operator __m256i() const
            {
                return data;
            }
            // Returns 32-bit lane i (0..7).
            // The extract intrinsics need a compile-time constant index, so a
            // run-time index is served by spilling the register to an array.
            uint32_t nthU32(size_t i) const
            {
                alignas(32) uint32_t lanes[8];
                _mm256_store_si256(reinterpret_cast<__m256i *>(lanes), data);
                return lanes[i];
            }
            // Returns 64-bit lane i (0..3); same spill approach as nthU32.
            uint64_t nthU64(size_t i) const
            {
                alignas(32) uint64_t lanes[4];
                _mm256_store_si256(reinterpret_cast<__m256i *>(lanes), data);
                return lanes[i];
            }
            // Debug helpers: print the lanes as "[a,b,...]" followed by a newline.
            void printU32() const
            {
                std::cout << "[" << nthU32(0) << "," << nthU32(1)
                          << "," << nthU32(2) << "," << nthU32(3)
                          << "," << nthU32(4) << "," << nthU32(5)
                          << "," << nthU32(6) << "," << nthU32(7) << "]" << std::endl;
            }
            void printU64() const
            {
                std::cout << "[" << nthU64(0) << "," << nthU64(1)
                          << "," << nthU64(2) << "," << nthU64(3) << "]" << std::endl;
            }

        private:
            __m256i data;
        };
    }
    namespace modint
    {
        // Montgomery arithmetic modulo an odd 32-bit modulus with R = 2^32.
        // Scalar overloads work on uint32_t, vector overloads on eight 32-bit
        // lanes (Ui32X8). Operations without "Norm" in the name are lazy:
        // they skip the final reduction, and the caller must keep values in a
        // range that does not overflow 32 bits.
        // NOTE(review): the lazy paths add values below 2*mod, so they appear
        // to need mod < 2^30 - confirm before using other moduli.
        class Montgomery32
        {
        public:
            using Ui32X8 = hint::simd::Int256;
            // Precomputes R mod m, R^2 mod m and -m^-1 mod 2^32.
            // m must be odd: an even modulus has no inverse modulo 2^32 and
            // inv_mod2pow would never terminate.
            // Initializers are listed in member declaration order.
            Montgomery32(uint32_t m) : modx8(m), mod2x8(m * 2), mod(m)
            {
                assert((m & 1) != 0);
                uint32_t inv = inv_mod2pow(mod, 32);
                r = (uint64_t(1) << 32) % m;
                r2 = uint64_t(r) * r % m;
                ninv = (uint64_t(1) << 32) - inv; // truncates to -inv mod 2^32
                assert(inv * mod == 1);
                assert(inv + ninv == 0);
                r2x8 = Ui32X8(r2);
                ninvx8 = Ui32X8(ninv);
            }

            // x -> x * R (lazy result).
            uint32_t toMontgomery(uint32_t x) const
            {
                return redcLazy(uint64_t(x) * r2);
            }
            // x -> x / R, fully reduced.
            uint32_t fromMontgomery(uint32_t x) const
            {
                return redc(x);
            }
            // Plain sum without reduction.
            uint32_t add(uint32_t x, uint32_t y) const
            {
                return x + y;
            }
            // x - y offset by 2*mod, without reduction.
            uint32_t sub(uint32_t x, uint32_t y) const
            {
                return x - y + mod * 2;
            }
            // x + y with one conditional subtraction of 2*mod.
            uint32_t addNorm2(uint32_t x, uint32_t y) const
            {
                return norm2(x + y);
            }
            // x - y, adding 2*mod back when the subtraction wrapped.
            uint32_t subNorm2(uint32_t x, uint32_t y) const
            {
                y = x - y;
                return y > x ? y + mod * 2 : y;
            }
            // One conditional subtraction of mod.
            uint32_t norm(uint32_t x) const
            {
                return x >= mod ? x - mod : x;
            }
            // One conditional subtraction of 2*mod.
            uint32_t norm2(uint32_t x) const
            {
                return x >= mod * 2 ? x - mod * 2 : x;
            }
            // Montgomery product x * y / R, lazy.
            uint32_t mul(uint32_t x, uint32_t y) const
            {
                return redcLazy(uint64_t(x) * y);
            }
            // Montgomery product x * y / R followed by norm().
            uint32_t mulNorm(uint32_t x, uint32_t y) const
            {
                return redc(uint64_t(x) * y);
            }
            // Montgomery reduction without the final conditional subtraction:
            // (x + ((x * ninv) mod 2^32) * mod) >> 32.
            uint32_t redcLazy(uint64_t x) const
            {
                uint32_t prod = uint32_t(x) * ninv;
                return (uint64_t(prod) * mod + x) >> 32;
            }
            // redcLazy followed by norm().
            uint32_t redc(uint64_t x) const
            {
                return norm(redcLazy(x));
            }
            // Inverse of a Montgomery-form value via x^(mod - 2); this is an
            // inverse only when mod is prime.
            uint32_t inv(uint32_t x) const
            {
                return pow(x, mod - 2);
            }
            // Square-and-multiply in Montgomery form; index 0 yields montOne().
            template <typename T, typename Ti>
            T pow(T x, Ti index) const
            {
                T res = montOne();
                while (true)
                {
                    if (index & 1)
                    {
                        res = mul(res, x);
                    }
                    index >>= 1;
                    if (index == 0)
                    {
                        break;
                    }
                    x = mul(x, x);
                }
                return res;
            }

            // R mod m, i.e. the Montgomery form of 1.
            uint32_t montOne() const
            {
                return r;
            }
            // Returns the stored r2 = R^2 mod m (note: not R itself).
            uint32_t montR() const
            {
                return r2;
            }
            uint32_t getMod() const
            {
                return mod;
            }

            // ---- eight-lane versions of the operations above ----

            Ui32X8 toMontgomery(Ui32X8 x) const
            {
                return mul(x, r2x8);
            }
            Ui32X8 fromMontgomery(Ui32X8 x) const
            {
                // Widen even and odd 32-bit lanes into 64-bit lanes for redc.
                return redc(x.evenEle32(), x.rShift64<32>());
            }
            Ui32X8 add(Ui32X8 x, Ui32X8 y) const
            {
                return x.add32(y);
            }
            Ui32X8 sub(Ui32X8 x, Ui32X8 y) const
            {
                return x.sub32(y).add32(mod2x8);
            }
            Ui32X8 addNorm2(Ui32X8 x, Ui32X8 y) const
            {
                return norm2(x.add32(y));
            }
            Ui32X8 subNorm2(Ui32X8 x, Ui32X8 y) const
            {
                return negNorm2(x.sub32(y));
            }
            // Conditional subtraction of mod: unsigned min of x and x - mod
            // (the difference wraps to a large value when x < mod).
            Ui32X8 norm(Ui32X8 x) const
            {
                Ui32X8 dif = x.sub32(modx8);
                return x.minU32(dif);
            }
            // Same trick with 2*mod.
            Ui32X8 norm2(Ui32X8 x) const
            {
                Ui32X8 dif = x.sub32(mod2x8);
                return x.minU32(dif);
            }
            // Adds 2*mod back to lanes that wrapped below zero: unsigned min
            // of x and x + 2*mod.
            Ui32X8 negNorm2(Ui32X8 x) const
            {
                Ui32X8 sum = x.add32(mod2x8);
                return x.minU32(sum);
            }
            Ui32X8 mul(Ui32X8 x, Ui32X8 y) const
            {
                // prodo holds the 64-bit products of 32-bit lanes 0,2,4,6;
                // prode those of lanes 1,3,5,7 (shifted down first).
                Ui32X8 prodo = x.mullo32To64(y);
                Ui32X8 prode = x.rShift64<32>().mullo32To64(y.rShift64<32>());
                return redcLazy(prodo, prode);
            }
            // Lazy reduction of four 64-bit values from even lanes (e) and
            // four from odd lanes (o); results are re-interleaved into eight
            // 32-bit lanes.
            Ui32X8 redcLazy(Ui32X8 e, Ui32X8 o) const
            {
                Ui32X8 prod0 = e.mullo32To64(ninvx8);
                Ui32X8 prod1 = o.mullo32To64(ninvx8);
                prod0 = prod0.mullo32To64(modx8).add64(e);
                prod1 = prod1.mullo32To64(modx8).add64(o);
                // Even results move to the low halves; odd results already sit
                // in the high halves, which the blend selects.
                prod0 = prod0.rShift64<32>();
                return Ui32X8::blend32<0b10101010>(prod0, prod1);
            }
            Ui32X8 redc(Ui32X8 e, Ui32X8 o) const
            {
                return norm(redcLazy(e, o));
            }

        private:
            Ui32X8 modx8, ninvx8, mod2x8, r2x8; // broadcast mod, -mod^-1, 2*mod, R^2
            uint32_t mod, ninv, r, r2;          // mod, -mod^-1 mod 2^32, R mod m, R^2 mod m
        };
    }
    namespace transform
    {
        namespace ntt
        {
            using namespace simd;
            using namespace modint;
            namespace radix2_avx
            {
                // Multiplies n by ROOT^((mod - 1) / 4), evaluated at compile time
                // from ModIntType (presumably the 4th root of unity when ROOT is a
                // primitive root - confirm against the ModIntType in use).
                template <uint32_t ROOT, typename ModIntType, typename T>
                inline T mul_w41(const T &n)
                {
                    constexpr auto exponent = (ModIntType::mod() - 1) / 4;
                    constexpr ModIntType quarter_root = qpow(ModIntType(ROOT), exponent);
                    return n * T(quarter_root);
                }
                // Multiplies n by ROOT^((mod - 1) / 8), evaluated at compile time
                // from ModIntType (presumably the 8th root of unity when ROOT is a
                // primitive root - confirm against the ModIntType in use).
                template <uint32_t ROOT, typename ModIntType, typename T>
                inline T mul_w81(const T &n)
                {
                    constexpr auto exponent = (ModIntType::mod() - 1) / 8;
                    constexpr ModIntType eighth_root = qpow(ModIntType(ROOT), exponent);
                    return n * T(eighth_root);
                }
                // Multiplies n by ROOT^((mod - 1) / 8 * 3), evaluated at compile
                // time from ModIntType (presumably the cube of the 8th root of
                // unity - confirm against the ModIntType in use).
                template <uint32_t ROOT, typename ModIntType, typename T>
                inline T mul_w83(const T &n)
                {
                    constexpr auto exponent = (ModIntType::mod() - 1) / 8 * 3;
                    constexpr ModIntType eighth_root_cubed = qpow(ModIntType(ROOT), exponent);
                    return n * T(eighth_root_cubed);
                }

                // Radix-2 butterfly that multiplies the second input by omega
                // before combining:
                //   in_out0 <- largeNorm(in_out0) + in_out1 * omega
                //   in_out1 <- largeNorm(in_out0) - in_out1 * omega
                // Range notes from the original author (p = modulus):
                //   in:  in_out0 < 4p, in_out1 < 4p
                //   out: in_out0 < 4p, in_out1 < 4p
                template <typename ModIntType>
                inline void dit_butterfly2(ModIntType &in_out0, ModIntType &in_out1, ModIntType omega)
                {
                    auto reduced_first = in_out0.largeNorm();
                    auto rotated_second = in_out1 * omega;
                    in_out0 = reduced_first.add(rotated_second);
                    in_out1 = reduced_first.sub(rotated_second);
                }

                // Butterfly variant selected by std::false_type: combines with the
                // member functions add() / sub():
                //   in_out0 <- in_out0.add(in_out1 * omega)
                //   in_out1 <- in_out0.sub(in_out1 * omega)
                // Range notes from the original author (p = modulus):
                //   in:  in_out0 < 2p, in_out1 < 4p
                //   out: in_out0 < 4p, in_out1 < 4p
                template <typename ModIntType>
                inline void dit_butterfly2_i24(ModIntType &in_out0, ModIntType &in_out1, ModIntType omega, std::false_type)
                {
                    auto first = in_out0;
                    auto rotated_second = in_out1 * omega;
                    in_out0 = first.add(rotated_second);
                    in_out1 = first.sub(rotated_second);
                }

                // Butterfly variant selected by std::true_type: combines with the
                // operators + and - of ModIntType:
                //   in_out0 <- in_out0 + in_out1 * omega
                //   in_out1 <- in_out0 - in_out1 * omega
                // Range notes from the original author (p = modulus):
                //   in:  in_out0 < 2p, in_out1 < 4p
                //   out: in_out0 < 2p, in_out1 < 2p
                template <typename ModIntType>
                inline void dit_butterfly2_i24(ModIntType &in_out0, ModIntType &in_out1, ModIntType omega, std::true_type)
                {
                    auto first = in_out0;
                    auto rotated_second = in_out1 * omega;
                    in_out0 = first + rotated_second;
                    in_out1 = first - rotated_second;
                }

                // Two stacked butterfly layers over four values.
                // Layer 1 pairs (0,1) and (2,3) with omega_last; layer 2 pairs
                // (0,2) with omega0 and (1,3) with omega1. OUT2P picks which
                // dit_butterfly2_i24 overload performs layer 2.
                // Range notes from the original author (p = modulus):
                //   in:  in_out0 < 2p, in_out1 < 4p, in_out2 < 2p, in_out3 < 4p
                //   out: every value < 2p or < 4p, depending on OUT2P
                template <bool OUT2P, typename ModIntType>
                inline void dit_butterfly2_i2424_2layer(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                                        ModIntType omega0, ModIntType omega1, ModIntType omega_last)
                {
                    using SecondLayerTag = std::integral_constant<bool, OUT2P>;
                    // Layer 1
                    dit_butterfly2_i24(in_out0, in_out1, omega_last, std::true_type{});
                    dit_butterfly2_i24(in_out2, in_out3, omega_last, std::false_type{});
                    // Layer 2
                    dit_butterfly2_i24(in_out0, in_out2, omega0, SecondLayerTag{});
                    dit_butterfly2_i24(in_out1, in_out3, omega1, SecondLayerTag{});
                }

                // Two-layer butterfly that skips the last output: after layer 1
                // on (0,1) and (2,3), the pair (0,2) gets a full butterfly, while
                // for the pair (1,3) only the sum is stored into in_out1.
                // in_out3 keeps its layer-1 value.
                // Range notes from the original author (p = modulus):
                //   in:  in_out0 < 2p, in_out1 < 4p, in_out2 < 2p, in_out3 < 4p
                //   out: every value < 2p
                template <typename ModIntType>
                inline void dit_butterfly2_2layer_out3(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                                       ModIntType omega0, ModIntType omega1, ModIntType omega_last)
                {
                    // Layer 1
                    dit_butterfly2_i24(in_out0, in_out1, omega_last, std::true_type{});
                    dit_butterfly2_i24(in_out2, in_out3, omega_last, std::false_type{});
                    // Layer 2: full butterfly on (0,2), sum only on (1,3)
                    dit_butterfly2_i24(in_out0, in_out2, omega0, std::true_type{});
                    auto rotated_third = in_out3 * omega1;
                    in_out1 = in_out1 + rotated_third;
                }

                // Two-layer butterfly that produces only the first two outputs:
                // after layer 1 on (0,1) and (2,3), only the sums of layer 2 are
                // stored into in_out0 and in_out1. in_out2 and in_out3 keep their
                // layer-1 values.
                // Range notes from the original author (p = modulus):
                //   in:  in_out0 < 2p, in_out1 < 4p, in_out2 < 2p, in_out3 < 4p
                //   out: every value < 2p
                template <typename ModIntType>
                inline void dit_butterfly2_2layer_out2(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                                       ModIntType omega0, ModIntType omega1, ModIntType omega_last)
                {
                    // Layer 1
                    dit_butterfly2_i24(in_out0, in_out1, omega_last, std::true_type{});
                    dit_butterfly2_i24(in_out2, in_out3, omega_last, std::false_type{});
                    // Layer 2: sums only
                    auto rotated_second = in_out2 * omega0;
                    in_out0 = in_out0 + rotated_second;
                    auto rotated_third = in_out3 * omega1;
                    in_out1 = in_out1 + rotated_third;
                }

                // Radix-2 butterfly that combines first and multiplies afterwards:
                //   in_out0 <- largeNorm(in_out0.add(in_out1))
                //   in_out1 <- in_out0.sub(in_out1) * omega
                // Range notes from the original author (p = modulus):
                //   in:  in_out0 < 2p, in_out1 < 2p
                //   out: in_out0 < 2p, in_out1 < 2p
                template <typename ModIntType>
                inline void dif_butterfly2(ModIntType &in_out0, ModIntType &in_out1, ModIntType omega)
                {
                    auto sum = in_out0.add(in_out1);
                    auto difference = in_out0.sub(in_out1);
                    in_out0 = sum.largeNorm();
                    in_out1 = difference * omega;
                }

                // Two stacked dif_butterfly2 layers over four values.
                // Layer 1 pairs (0,2) with omega0 and (1,3) with omega1; layer 2
                // pairs (0,1) and (2,3), both with omega_last.
                // Range notes from the original author (p = modulus):
                //   in:  every value < 2p
                //   out: every value < 2p
                template <typename ModIntType>
                inline void dif_butterfly2_2layer(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                                  ModIntType omega0, ModIntType omega1, ModIntType omega_last)
                {
                    // Layer 1: distance-two pairs
                    dif_butterfly2(in_out0, in_out2, omega0);
                    dif_butterfly2(in_out1, in_out3, omega1);

                    // Layer 2: adjacent pairs
                    dif_butterfly2(in_out0, in_out1, omega_last);
                    dif_butterfly2(in_out2, in_out3, omega_last);
                }

                // Two-layer variant in which only in_out0 and in_out1 are read as
                // inputs: layer 1 overwrites in_out2 and in_out3 with
                // in_out0 * omega0 and in_out1 * omega1, then layer 2 runs
                // dif_butterfly2 on (0,1) and (2,3) with omega_last.
                // Range notes from the original author (p = modulus):
                //   in:  every value < 2p
                //   out: every value < 2p
                template <typename ModIntType>
                inline void dif_butterfly2_2layer_in2(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                                      ModIntType omega0, ModIntType omega1, ModIntType omega_last)
                {
                    // Layer 1: the lower pair is a scaled copy of the upper pair
                    in_out2 = in_out0 * omega0;
                    in_out3 = in_out1 * omega1;

                    // Layer 2
                    dif_butterfly2(in_out0, in_out1, omega_last);
                    dif_butterfly2(in_out2, in_out3, omega_last);
                }

                // Two-layer variant in which only in_out0 is read as input; the
                // other three values are overwritten:
                //   in_out2 <- in_out0 * omega0
                //   in_out1 <- in_out0 * omega_last
                //   in_out3 <- in_out2 * omega_last
                // omega1 is accepted for signature parity with the sibling
                // variants but is not used.
                // Range notes from the original author (p = modulus):
                //   in:  every value < 2p
                //   out: every value < 2p
                template <typename ModIntType>
                inline void dif_butterfly2_2layer_in1(ModIntType &in_out0, ModIntType &in_out1, ModIntType &in_out2, ModIntType &in_out3,
                                                      ModIntType omega0, ModIntType omega1, ModIntType omega_last)
                {
                    // Layer 1
                    in_out2 = in_out0 * omega0;

                    // Layer 2
                    in_out1 = in_out0 * omega_last;
                    in_out3 = in_out2 * omega_last;
                }

                // template <typename ModIntType, uint32_t ROOT>
                // static auto omegax8(size_t ntt_len, int factor, size_t begin = 0, bool inv = false)
                // {
                //     using ModIntX8 = MontInt32X8<ModIntType>;
                //     alignas(32) ModIntType w_arr[8]{};
                //     ModIntType unit(qpow(ModIntType(ROOT), (ModIntType::mod() - 1) / ntt_len * factor));
                //     if (inv)
                //     {
                //         unit = unit.inv();
                //     }
                //     ModIntType w(qpow(unit, begin));
                //     for (auto &&i : w_arr)
                //     {
                //         i = w;
                //         w = w * unit;
                //     }
                //     return ModIntX8(w_arr);
                // }

                // Radix-2 number-theoretic transform over uint32_t values using
                // Montgomery32 arithmetic. Despite the name, the routines in this
                // struct are scalar. Twiddle factors are kept in Montgomery form,
                // so the data itself may stay in plain form.
                struct NTT32AVX
                {
                    using Int = uint32_t;
                    // Largest transform length accepted by dit() / dif().
                    static constexpr size_t L1_LEN = L1_BYTE / (2 * sizeof(Int));
                    static constexpr size_t L2_LEN = L2_BYTE / (2 * sizeof(Int));
                    static constexpr int LOG_L1_LEN = hint_log2(L1_LEN);

                    // Lazily filled table of twiddle factors for transforms of up
                    // to 2^LOG_LEN points. Entry 0 is one; entry i (a power of two)
                    // is root^((mod-1)/LEN * LEN/(4i)); every other entry is the
                    // product of the entry at its highest set bit and the entry at
                    // the remaining bits (presumably the bit-reversed ordering the
                    // class name refers to).
                    template <int LOG_LEN>
                    class BinRevTable
                    {
                    public:
                        static constexpr size_t LEN = size_t(1) << LOG_LEN;
                        static constexpr size_t TABLE_LEN = LEN / 2;

                        // Stores the power-of-two entries for `root` (given in
                        // plain form). With init_all the whole table is filled
                        // right away, otherwise init() extends it on demand.
                        BinRevTable(Int root, const Montgomery32 &mont, bool init_all) : cur_len(2)
                        {
                            table[0] = mont.montOne();
                            root = mont.toMontgomery(root);
                            for (size_t i = 1; i < TABLE_LEN; i *= 2)
                            {
                                table[i] = getOmega(LEN, LEN / 4 / i, root, mont);
                            }
                            if (init_all)
                            {
                                init(LEN, mont);
                            }
                        }

                        // Makes sure the entries needed by a transform of `len`
                        // points (indices below len / 2) are available.
                        void init(size_t len, const Montgomery32 &mont)
                        {
                            // size_t(LEN) passes a copy, so the static constexpr
                            // member is not odr-used (it has no out-of-class
                            // definition, which matters before C++17).
                            const size_t n = std::min(len, size_t(LEN)) / 2;
                            size_t begin = cur_len;
                            for (; begin < n; begin *= 2)
                            {
                                Int unit = table[begin];
                                for (size_t i = begin + 1; i < begin * 2; i++)
                                {
                                    table[i] = mont.mul(unit, table[i - begin]);
                                }
                            }
                            // Record how far the table is filled. This never
                            // shrinks: storing n directly let a request with
                            // len < 2 reset cur_len to 0, after which the loop
                            // above (begin *= 2 from 0) could not terminate.
                            cur_len = begin;
                        }

                        Int getRevOmega(size_t i) const
                        {
                            return table[i];
                        }

                        // root^((mod - 1) / n * index) in Montgomery form; `root`
                        // must already be in Montgomery form.
                        Int getOmega(size_t n, size_t index, uint32_t root, const Montgomery32 &mont)
                        {
                            return mont.pow(root, (mont.getMod() - 1) / n * index);
                        }

                    private:
                        Int table[TABLE_LEN];
                        size_t cur_len; // entries [0, cur_len) are valid; always a power of two >= 2
                    };

                    Montgomery32 mont;
                    BinRevTable<LOG_L1_LEN> table;  // twiddles for dif()
                    BinRevTable<LOG_L1_LEN> itable; // twiddles of the inverse root, for dit()

                    // root: root of the transform in plain form; mod: the modulus.
                    // With init_all == false the tables grow on first use.
                    NTT32AVX(Int root, Int mod, bool init_all = true) : mont(mod), table(root, mont, init_all), itable(mod_inv<int64_t>(root, mod), mont, init_all) {}

                    // In-place transform with the inverse-root table: layers run
                    // from rank 2 upwards, each butterfly adds / subtracts first
                    // and then multiplies the difference by the twiddle. Outputs
                    // are fully reduced to [0, mod). No 1 / ntt_len scaling is
                    // applied here. ntt_len is expected to be a power of two.
                    void dit(Int in_out[], size_t ntt_len)
                    {
                        assert(ntt_len <= L1_LEN);
                        if (ntt_len < 2)
                        {
                            // No butterfly runs for a single element; still hand
                            // back a fully reduced value like every longer input.
                            for (size_t i = 0; i < ntt_len; i++)
                            {
                                in_out[i] = mont.norm(mont.norm2(in_out[i]));
                            }
                            return;
                        }
                        itable.init(ntt_len, mont);
                        for (size_t rank = 2; rank < ntt_len; rank *= 2)
                        {
                            size_t gap = rank / 2, omega_index = 1;
                            // First block: twiddle is one, so no multiplication.
                            for (size_t i = 0; i < gap; i++)
                            {
                                Int x = in_out[i], y = in_out[gap + i];
                                in_out[i] = mont.addNorm2(x, y);
                                in_out[gap + i] = mont.subNorm2(x, y);
                            }
                            // Remaining blocks: one twiddle per block.
                            for (auto it = in_out + rank; it < in_out + ntt_len; it += rank, omega_index++)
                            {
                                const Int omega = itable.getRevOmega(omega_index);
                                for (size_t j = 0; j < gap; j++)
                                {
                                    Int x = it[j], y = it[gap + j];
                                    it[j] = mont.addNorm2(x, y);
                                    it[gap + j] = mont.mul(mont.subNorm2(x, y), omega);
                                }
                            }
                        }
                        // Last layer (rank == ntt_len): a single block with twiddle
                        // one, followed by the final reduction to [0, mod).
                        for (size_t i = 0; i < ntt_len / 2; i++)
                        {
                            Int x = in_out[i], y = in_out[ntt_len / 2 + i];
                            in_out[i] = mont.norm(mont.addNorm2(x, y));
                            in_out[ntt_len / 2 + i] = mont.norm(mont.subNorm2(x, y));
                        }
                    }

                    // In-place transform with the forward-root table: layers run
                    // from rank ntt_len down to 2, each butterfly multiplies the
                    // upper element by the twiddle and then adds / subtracts.
                    // Outputs are left in the lazy range produced by addNorm2 /
                    // subNorm2. ntt_len is expected to be a power of two.
                    void dif(Int in_out[], size_t ntt_len)
                    {
                        assert(ntt_len <= L1_LEN);
                        table.init(ntt_len, mont);
                        for (size_t rank = ntt_len; rank >= 2; rank /= 2)
                        {
                            size_t gap = rank / 2, omega_index = 1;
                            // First block: twiddle is one, so no multiplication.
                            for (size_t i = 0; i < gap; i++)
                            {
                                Int x = in_out[i], y = in_out[gap + i];
                                in_out[i] = mont.addNorm2(x, y);
                                in_out[gap + i] = mont.subNorm2(x, y);
                            }
                            // Remaining blocks: one twiddle per block.
                            for (auto it = in_out + rank; it < in_out + ntt_len; it += rank, omega_index++)
                            {
                                const Int omega = table.getRevOmega(omega_index);
                                for (size_t j = 0; j < gap; j++)
                                {
                                    Int x = it[j], y = mont.mul(it[gap + j], omega);
                                    it[j] = mont.addNorm2(x, y);
                                    it[gap + j] = mont.subNorm2(x, y);
                                }
                            }
                        }
                    }

                    // void difL1X2(Int in_out1[], Int in_out2[], size_t ntt_len)
                    // {
                    // }
                    // void ditL1(Int in_out[], size_t ntt_len, size_t rank)
                    // {
                    //     for (; rank <= ntt_len; rank *= 4)
                    //     {
                    //         size_t gap = rank / 4;
                    //         for (size_t i = 0; i < ntt_len; i += rank)
                    //         {
                    //         }
                    //     }
                    // }
                    // void conv32(Int in1_out[], Int in2[])
                    // {
                    // }
                    // void conv64(Int in1_out[], Int in2[])
                    // {
                    // }
                    // void convL1(Int in1_out[], Int in2[], size_t ntt_len)
                    // {
                    //     assert(ntt_len <= L1_LEN && ntt_len >= 32);
                    //     difL1X2(in1_out, in2, ntt_len);
                    //     conv32(in1_out, in2);
                    //     ditL1(in1_out, ntt_len);
                    // }
                };
            }
        }
    }
}

// Multiplies two polynomials with coefficients modulo 998244353.
//   a: n + 1 coefficients (degree n), b: m + 1 coefficients (degree m)
//   c: receives the n + m + 1 coefficients of the product
// n and m must be non-negative, and the padded transform length must not
// exceed NTT32AVX::L1_LEN (checked by an assert inside the transform).
// NOTE(review): input coefficients are assumed to be already reduced below
// the modulus - confirm with the caller.
void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c)
{
    using namespace hint;
    using namespace transform;
    using namespace ntt::radix2_avx;
    // One shared transform object; its twiddle tables grow on demand.
    static NTT32AVX ntt(3, 998244353, false);
    // Lengths are computed in size_t so that n + m + 1 cannot overflow int.
    const size_t len_a = static_cast<size_t>(n) + 1;
    const size_t len_b = static_cast<size_t>(m) + 1;
    const size_t conv_len = len_a + len_b - 1;
    const size_t ntt_len = int_ceil2(conv_len);
    // Zero-filled work buffers owned by vectors: they are released on every
    // exit path, including when the second allocation throws.
    std::vector<uint32_t> ntt_a(ntt_len), ntt_b(ntt_len);
    std::memcpy(ntt_a.data(), a, len_a * sizeof(unsigned));
    std::memcpy(ntt_b.data(), b, len_b * sizeof(unsigned));
    ntt.dif(ntt_a.data(), ntt_len);
    ntt.dif(ntt_b.data(), ntt_len);
    // Scale factor applied to each pointwise product: ntt_len^-1 carrying two
    // extra factors of R, which cancel the 1/R introduced by each of the two
    // Montgomery multiplications in the loop below (montR() returns R^2).
    uint32_t len_inv_r = ntt.mont.toMontgomery(ntt_len);
    len_inv_r = ntt.mont.inv(len_inv_r);
    len_inv_r = ntt.mont.mul(len_inv_r, ntt.mont.montR());
    for (size_t i = 0; i < ntt_len; i++)
    {
        const uint32_t product = ntt.mont.mul(ntt_a[i], ntt_b[i]);
        ntt_a[i] = ntt.mont.mul(product, len_inv_r);
    }
    ntt.dit(ntt_a.data(), ntt_len);
    std::memcpy(c, ntt_a.data(), conv_len * sizeof(uint32_t));
}

#include "stopwatch.hpp"
void test_convolution()
{
    int m, n;
    // std::cin >> m >> n;
    int len1 = 1 << 22, len2 = len1;
    unsigned *a = new unsigned[len1];
    unsigned *b = new unsigned[len2];
    unsigned *c = new unsigned[len1 + len2 - 1]{};
    uint64_t ele = 5;
    for (size_t i = 0; i < len1; i++)
    {
        // scanf("%d", &a[i]);
        a[i] = 2;
    }
    for (size_t i = 0; i < len2; i++)
    {
        // scanf("%d", &b[i]);
        b[i] = 5;
    }
    StopWatch w(1000);
    w.start();
    poly_multiply(a, len1 - 1, b, len2 - 1, c);
    w.stop();
    std::cout << w.duration() << "ms" << std::endl;
    for (size_t i = 0; i < len1 + len2 - 1; i++)
    {
        // std::cout << c[i] << " ";
    }
    delete[] a;
    delete[] b;
    delete[] c;
}

Compilation | N/A | N/A | Compile Error | Score: N/A


Judge Duck Online | 评测鸭在线
Server Time: 2025-02-22 02:18:14 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠