提交记录 19382


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1002. 测测你的多项式乘法 Wrong Answer 0 228.839 ms 18372 KB C++14 33.84 KB
提交时间 评测时间
2023-05-02 12:59:46 2023-05-02 12:59:51
#include <vector>
#include <complex>
#include <iostream>
#include <future>
#include <ctime>
#include <cstring>
// #include "stopwatch.hpp"

namespace hint
{
    using UINT_8 = uint8_t;
    using UINT_16 = uint16_t;
    using UINT_32 = uint32_t;
    using UINT_64 = uint64_t;
    using INT_32 = int32_t;
    using INT_64 = int64_t;
    using ULONG = unsigned long;
    using LONG = long;

    constexpr double PI = 3.1415926535897932384626433832795;
    constexpr double PI_2 = PI * 2;

    constexpr UINT_64 NTT_MOD1 = 3221225473;
    constexpr UINT_64 NTT_ROOT1 = 5;
    constexpr UINT_64 NTT_MOD2 = 3489660929;
    constexpr UINT_64 NTT_ROOT2 = 3;
    constexpr size_t NTT_MAX_LEN1 = 1ull << 28;

    constexpr UINT_64 NTT_MOD3 = 79164837199873;
    constexpr UINT_64 NTT_ROOT3 = 5;
    constexpr UINT_64 NTT_MOD4 = 96757023244289;
    constexpr UINT_64 NTT_ROOT4 = 3;
    constexpr size_t NTT_MAX_LEN2 = 1ull << 43;

    template <typename T>
    constexpr bool is_odd(T x)
    {
        return static_cast<bool>(x & 1);
    }
    template <typename T>
    constexpr T max_2pow(T n)
    {
        T res = 1;
        res <<= (sizeof(T) * 8 - 1);
        while (res > n)
        {
            res /= 2;
        }
        return res;
    }
    template <typename T>
    constexpr T min_2pow(T n)
    {
        T res = 1;
        while (res < n)
        {
            res *= 2;
        }
        return res;
    }
    template <typename T>
    constexpr T hint_log2(T n)
    {
        T res = 0;
        while (n > 1)
        {
            n /= 2;
            res++;
        }
        return res;
    }
    // 模板快速幂
    template <typename T>
    constexpr T qpow(T m, UINT_64 n)
    {
        T result = 1;
        while (n > 0)
        {
            if ((n & 1) != 0)
            {
                result = result * m;
            }
            m = m * m;
            n >>= 1;
        }
        return result;
    }
    // 取模快速幂
    constexpr UINT_64 qpow(UINT_64 m, UINT_64 n, UINT_64 mod)
    {
        if (m <= 1)
        {
            return m;
        }
        UINT_64 result = 1;
        while (n > 0)
        {
            if ((n & 1) != 0)
            {
                result = result * m % mod;
            }
            m = m * m % mod;
            n >>= 1;
        }
        return result;
    }
    // 模意义下的逆元
    constexpr UINT_64 mod_inv(UINT_64 n, UINT_64 mod)
    {
        return qpow(n, mod - 2, mod);
    }
    template <typename T>
    inline void ary_mul(const T in1[], const T in2[], T out[], size_t len)
    {
        for (size_t i = 0; i < len; i++)
        {
            out[i] = in1[i] * in2[i];
        }
    }
    namespace hint_transform
    {
        // 二进制逆序
        template <typename T>
        void binary_reverse_swap(T &ary, size_t len)
        {
            size_t i = 0;
            for (size_t j = 1; j < len - 1; j++)
            {
                size_t k = len >> 1;
                i ^= k;
                while (k > i)
                {
                    k >>= 1;
                    i ^= k;
                };
                if (j < i)
                {
                    std::swap(ary[i], ary[j]);
                }
            }
        }
        // 四进制逆序
        template <typename SizeType = UINT_32, typename T>
        void quaternary_reverse_swap(T &ary, size_t len)
        {
            SizeType log_n = hint_log2(len);
            SizeType *rev = new SizeType[len / 4];
            rev[0] = 0;
            for (SizeType i = 1; i < len; i++)
            {
                SizeType index = (rev[i >> 2] >> 2) | ((i & 3) << (log_n - 2)); // 求rev交换数组
                if (i < len / 4)
                {
                    rev[i] = index;
                }
                if (i < index)
                {
                    std::swap(ary[i], ary[index]);
                }
            }
            delete[] rev;
        }
        namespace hint_ntt
        {
            template <typename DataTy, UINT_64 MOD>
            struct ModInt
            {
                // 实际存放的整数
                DataTy data = 0;
                // 等价复数的i
                constexpr ModInt() noexcept {}
                constexpr ModInt(DataTy num) noexcept
                {
                    data = num;
                }
                constexpr ModInt operator+(ModInt n) const
                {
                    UINT_64 sum = UINT_64(data) + n.data;
                    return sum < MOD ? sum : sum - MOD;
                }
                constexpr ModInt operator-(ModInt n) const
                {
                    return data < n.data ? MOD + data - n.data : data - n.data;
                }
                constexpr ModInt operator*(ModInt n) const
                {
                    return static_cast<UINT_64>(data) * n.data % MOD;
                }
                constexpr ModInt operator/(ModInt n) const
                {
                    return *this * inv();
                }
                constexpr ModInt inv() const
                {
                    return qpow(*this, MOD - 2u);
                }
                constexpr static ModInt mod()
                {
                    return MOD;
                }
            };
            template <UINT_64 MOD>
            struct ModInt<UINT_64, MOD>
            {
                // 实际存放的整数
                UINT_64 data = 0;
                // 等价复数的i
                constexpr ModInt() noexcept {}
                constexpr ModInt(UINT_64 num) noexcept
                {
                    data = num;
                }
                constexpr ModInt operator+(ModInt n) const
                {
                    return (n.data += data) < MOD ? n.data : n.data - MOD;
                }
                constexpr ModInt operator-(ModInt n) const
                {
                    return data < n.data ? MOD + data - n.data : data - n.data;
                }
                constexpr ModInt operator*(ModInt n) const
                {
                    UINT_64 b1 = n.data & 0xffff;
                    UINT_64 b2 = (n.data >> 16) & 0xffff;
                    UINT_64 b3 = n.data >> 32;
                    b1 = data * b1;
                    b2 = data * b2;
                    b3 = (((data * b3) % MOD) << 16) + b2;
                    b3 = ((b3 % MOD) << 16) + b1;
                    return b3 % MOD;
                    // UINT_64 b2 = n.data >> 20;
                    // n.data &= 0xfffff;
                    // n.data *= data;
                    // b2 = data * b2 % MOD;
                    // return (n.data + (b2 << 20)) % MOD;
                }
                constexpr ModInt operator/(ModInt n) const
                {
                    return *this * inv();
                }
                constexpr ModInt inv() const
                {
                    return qpow(*this, MOD - 2ull);
                }
                constexpr static ModInt mod()
                {
                    return MOD;
                }
            };
            template <UINT_64 MOD, UINT_64 ROOT>
            struct NTT_BASIC
            {
                static constexpr UINT_64 mod()
                {
                    return MOD;
                }
                static constexpr UINT_64 root()
                {
                    return ROOT;
                }

                using NTTModInt32 = ModInt<UINT_32, MOD>;
                using NTTModInt64 = ModInt<UINT_64, MOD>;

                template <typename T>
                static constexpr void ntt_normalize(T *input, size_t ntt_len)
                {
                    const T inv = T(ntt_len).inv();
                    size_t mod4 = ntt_len % 4;
                    ntt_len -= mod4;
                    for (size_t i = 0; i < ntt_len; i += 4)
                    {
                        input[i] = inv * input[i];
                        input[i + 1] = inv * input[i + 1];
                        input[i + 2] = inv * input[i + 2];
                        input[i + 3] = inv * input[i + 3];
                    }
                    for (size_t i = ntt_len; i < ntt_len + mod4; i++)
                    {
                        input[i] = inv * input[i];
                    }
                }
                // 2点NTT
                template <typename T>
                static constexpr void ntt_2point(T &sum, T &diff)
                {
                    T tmp1 = sum;
                    T tmp2 = diff;
                    sum = tmp1 + tmp2;
                    diff = tmp1 - tmp2;
                }
                template <typename T>
                static constexpr void ntt_dit_4point(T *input, size_t rank = 1)
                {
                    constexpr T W_4_1 = qpow(T(ROOT), (MOD - 1) / 4); // 等价于复数i
                    T tmp0 = input[0];
                    T tmp1 = input[rank];
                    T tmp2 = input[rank * 2];
                    T tmp3 = input[rank * 3];

                    ntt_2point(tmp0, tmp1);
                    ntt_2point(tmp2, tmp3);
                    tmp3 = tmp3 * W_4_1;

                    input[0] = tmp0 + tmp2;
                    input[rank] = tmp1 + tmp3;
                    input[rank * 2] = tmp0 - tmp2;
                    input[rank * 3] = tmp1 - tmp3;
                }
                template <typename T>
                static constexpr void ntt_dit_8point(T *input, size_t rank = 1)
                {
                    constexpr T W_8_1 = qpow(T(ROOT), (MOD - 1) / 8);
                    constexpr T W_8_2 = qpow(W_8_1, 2);
                    constexpr T W_8_3 = qpow(W_8_1, 3);
                    T tmp0 = input[0];
                    T tmp1 = input[rank];
                    T tmp2 = input[rank * 2];
                    T tmp3 = input[rank * 3];
                    T tmp4 = input[rank * 4];
                    T tmp5 = input[rank * 5];
                    T tmp6 = input[rank * 6];
                    T tmp7 = input[rank * 7];
                    ntt_2point(tmp0, tmp1);
                    ntt_2point(tmp2, tmp3);
                    ntt_2point(tmp4, tmp5);
                    ntt_2point(tmp6, tmp7);
                    tmp3 = tmp3 * W_8_2;
                    tmp7 = tmp7 * W_8_2;

                    ntt_2point(tmp0, tmp2);
                    ntt_2point(tmp1, tmp3);
                    ntt_2point(tmp4, tmp6);
                    ntt_2point(tmp5, tmp7);
                    tmp5 = tmp5 * W_8_1;
                    tmp6 = tmp6 * W_8_2;
                    tmp7 = tmp7 * W_8_3;

                    input[0] = tmp0 + tmp4;
                    input[rank] = tmp1 + tmp5;
                    input[rank * 2] = tmp2 + tmp6;
                    input[rank * 3] = tmp3 + tmp7;
                    input[rank * 4] = tmp0 - tmp4;
                    input[rank * 5] = tmp1 - tmp5;
                    input[rank * 6] = tmp2 - tmp6;
                    input[rank * 7] = tmp3 - tmp7;
                }
                template <typename T>
                static constexpr void ntt_dit_16point(T *input, size_t rank = 1)
                {
                    constexpr T W_16_1 = qpow(T(ROOT), (MOD - 1) / 16);
                    constexpr T W_16_2 = qpow(W_16_1, 2);
                    constexpr T W_16_3 = qpow(W_16_1, 3);
                    constexpr T W_16_4 = qpow(W_16_1, 4);
                    constexpr T W_16_5 = qpow(W_16_1, 5);
                    constexpr T W_16_6 = qpow(W_16_1, 6);
                    constexpr T W_16_7 = qpow(W_16_1, 7);

                    ntt_dit_8point(input, rank);
                    ntt_dit_8point(input + rank * 8, rank);

                    T tmp0 = input[0];
                    T tmp1 = input[rank];
                    T tmp2 = input[rank * 8];
                    T tmp3 = input[rank * 9] * W_16_1;
                    input[0] = tmp0 + tmp2;
                    input[rank] = tmp1 + tmp3;
                    input[rank * 8] = tmp0 - tmp2;
                    input[rank * 9] = tmp1 - tmp3;

                    tmp0 = input[rank * 2];
                    tmp1 = input[rank * 3];
                    tmp2 = input[rank * 10] * W_16_2;
                    tmp3 = input[rank * 11] * W_16_3;
                    input[rank * 2] = tmp0 + tmp2;
                    input[rank * 3] = tmp1 + tmp3;
                    input[rank * 10] = tmp0 - tmp2;
                    input[rank * 11] = tmp1 - tmp3;

                    tmp0 = input[rank * 4];
                    tmp1 = input[rank * 5];
                    tmp2 = input[rank * 12] * W_16_4;
                    tmp3 = input[rank * 13] * W_16_5;
                    input[rank * 4] = tmp0 + tmp2;
                    input[rank * 5] = tmp1 + tmp3;
                    input[rank * 12] = tmp0 - tmp2;
                    input[rank * 13] = tmp1 - tmp3;

                    tmp0 = input[rank * 6];
                    tmp1 = input[rank * 7];
                    tmp2 = input[rank * 14] * W_16_6;
                    tmp3 = input[rank * 15] * W_16_7;
                    input[rank * 6] = tmp0 + tmp2;
                    input[rank * 7] = tmp1 + tmp3;
                    input[rank * 14] = tmp0 - tmp2;
                    input[rank * 15] = tmp1 - tmp3;
                }
                template <typename T>
                static constexpr void ntt_dif_4point(T *input, size_t rank = 1)
                {
                    constexpr T W_4_1 = qpow(T(ROOT), (MOD - 1) / 4);
                    T tmp0 = input[0];
                    T tmp1 = input[rank];
                    T tmp2 = input[rank * 2];
                    T tmp3 = input[rank * 3];

                    ntt_2point(tmp0, tmp2);
                    ntt_2point(tmp1, tmp3);
                    tmp3 = tmp3 * W_4_1;

                    input[0] = tmp0 + tmp1;
                    input[rank] = tmp0 - tmp1;
                    input[rank * 2] = tmp2 + tmp3;
                    input[rank * 3] = tmp2 - tmp3;
                }
                template <typename T>
                static constexpr void ntt_dif_8point(T *input, size_t rank = 1)
                {
                    constexpr T W_8_1 = qpow(T(ROOT), (MOD - 1) / 8);
                    constexpr T W_8_2 = qpow(W_8_1, 2);
                    constexpr T W_8_3 = qpow(W_8_1, 3);
                    T tmp0 = input[0];
                    T tmp1 = input[rank];
                    T tmp2 = input[rank * 2];
                    T tmp3 = input[rank * 3];
                    T tmp4 = input[rank * 4];
                    T tmp5 = input[rank * 5];
                    T tmp6 = input[rank * 6];
                    T tmp7 = input[rank * 7];

                    ntt_2point(tmp0, tmp4);
                    ntt_2point(tmp1, tmp5);
                    ntt_2point(tmp2, tmp6);
                    ntt_2point(tmp3, tmp7);
                    tmp5 = tmp5 * W_8_1;
                    tmp6 = tmp6 * W_8_2;
                    tmp7 = tmp7 * W_8_3;

                    ntt_2point(tmp0, tmp2);
                    ntt_2point(tmp1, tmp3);
                    ntt_2point(tmp4, tmp6);
                    ntt_2point(tmp5, tmp7);
                    tmp3 = tmp3 * W_8_2;
                    tmp7 = tmp7 * W_8_2;

                    input[0] = tmp0 + tmp1;
                    input[rank] = tmp0 - tmp1;
                    input[rank * 2] = tmp2 + tmp3;
                    input[rank * 3] = tmp2 - tmp3;
                    input[rank * 4] = tmp4 + tmp5;
                    input[rank * 5] = tmp4 - tmp5;
                    input[rank * 6] = tmp6 + tmp7;
                    input[rank * 7] = tmp6 - tmp7;
                }
                template <typename T>
                static constexpr void ntt_dif_16point(T *input, size_t rank = 1)
                {
                    constexpr T W_16_1 = qpow(T(ROOT), (MOD - 1) / 16);
                    constexpr T W_16_2 = qpow(W_16_1, 2);
                    constexpr T W_16_3 = qpow(W_16_1, 3);
                    constexpr T W_16_4 = qpow(W_16_1, 4);
                    constexpr T W_16_5 = qpow(W_16_1, 5);
                    constexpr T W_16_6 = qpow(W_16_1, 6);
                    constexpr T W_16_7 = qpow(W_16_1, 7);

                    T tmp0 = input[0];
                    T tmp1 = input[rank];
                    T tmp2 = input[rank * 8];
                    T tmp3 = input[rank * 9];
                    input[0] = tmp0 + tmp2;
                    input[rank] = tmp1 + tmp3;
                    input[rank * 8] = tmp0 - tmp2;
                    input[rank * 9] = (tmp1 - tmp3) * W_16_1;

                    tmp0 = input[rank * 2];
                    tmp1 = input[rank * 3];
                    tmp2 = input[rank * 10];
                    tmp3 = input[rank * 11];
                    input[rank * 2] = tmp0 + tmp2;
                    input[rank * 3] = tmp1 + tmp3;
                    input[rank * 10] = (tmp0 - tmp2) * W_16_2;
                    input[rank * 11] = (tmp1 - tmp3) * W_16_3;

                    tmp0 = input[rank * 4];
                    tmp1 = input[rank * 5];
                    tmp2 = input[rank * 12];
                    tmp3 = input[rank * 13];
                    input[rank * 4] = tmp0 + tmp2;
                    input[rank * 5] = tmp1 + tmp3;
                    input[rank * 12] = (tmp0 - tmp2) * W_16_4;
                    input[rank * 13] = (tmp1 - tmp3) * W_16_5;

                    tmp0 = input[rank * 6];
                    tmp1 = input[rank * 7];
                    tmp2 = input[rank * 14];
                    tmp3 = input[rank * 15];
                    input[rank * 6] = tmp0 + tmp2;
                    input[rank * 7] = tmp1 + tmp3;
                    input[rank * 14] = (tmp0 - tmp2) * W_16_6;
                    input[rank * 15] = (tmp1 - tmp3) * W_16_7;

                    ntt_dif_8point(input, rank);
                    ntt_dif_8point(input + rank * 8, rank);
                }
                // 基2时间抽取ntt蝶形
                template <typename T>
                static constexpr void ntt_radix2_dit_butterfly(T omega, T *input, size_t rank)
                {
                    T tmp1 = input[0];
                    T tmp2 = input[rank] * omega;
                    input[0] = tmp1 + tmp2;
                    input[rank] = tmp1 - tmp2;
                }
                // 基2频率抽取ntt蝶形
                template <typename T>
                static constexpr void ntt_radix2_dif_butterfly(T omega, T *input, size_t rank)
                {
                    T tmp1 = input[0];
                    T tmp2 = input[rank];
                    input[0] = tmp1 + tmp2;
                    input[rank] = (tmp1 - tmp2) * omega;
                }
                // ntt分裂基时间抽取蝶形变换
                template <typename T>
                static constexpr void ntt_split_radix_dit_butterfly(T omega, T omega_cube,
                                                                    T *input, size_t rank)
                {
                    constexpr T W_4_1 = qpow(T(ROOT), (MOD - 1) / 4); // 等价于复数i
                    T tmp0 = input[0];
                    T tmp1 = input[rank];
                    T tmp2 = input[rank * 2] * omega;
                    T tmp3 = input[rank * 3] * omega_cube;

                    ntt_2point(tmp2, tmp3);
                    tmp3 = tmp3 * W_4_1;

                    input[0] = tmp0 + tmp2;
                    input[rank] = tmp1 + tmp3;
                    input[rank * 2] = tmp0 - tmp2;
                    input[rank * 3] = tmp1 - tmp3;
                }
                // ntt分裂基频率抽取蝶形变换
                template <typename T>
                static constexpr void ntt_split_radix_dif_butterfly(T omega, T omega_cube,
                                                                    T *input, size_t rank)
                {
                    constexpr T W_4_1 = qpow(T(ROOT), (MOD - 1) / 4); // 等价于复数i
                    T tmp0 = input[0];
                    T tmp1 = input[rank];
                    T tmp2 = input[rank * 2];
                    T tmp3 = input[rank * 3];

                    ntt_2point(tmp0, tmp2);
                    ntt_2point(tmp1, tmp3);
                    tmp3 = tmp3 * W_4_1;

                    input[0] = tmp0;
                    input[rank] = tmp1;
                    input[rank * 2] = (tmp2 + tmp3) * omega;
                    input[rank * 3] = (tmp2 - tmp3) * omega_cube;
                }
            };
            // 模板递归分裂基NTT
            template <size_t LEN, UINT_64 MOD, UINT_64 ROOT>
            struct SPLIT_RADIX_NTT
            {
                using ntt_basic = NTT_BASIC<MOD, ROOT>;
                static constexpr size_t half_len = LEN / 2, quarter_len = LEN / 4;
                // 模板化时间抽取分裂基ntt
                template <typename T>
                static constexpr void ntt_split_radix_dit_template(T input[])
                {
                    SPLIT_RADIX_NTT<half_len, MOD, ROOT>::ntt_split_radix_dit_template(input);
                    SPLIT_RADIX_NTT<quarter_len, MOD, ROOT>::ntt_split_radix_dit_template(input + half_len);
                    SPLIT_RADIX_NTT<quarter_len, MOD, ROOT>::ntt_split_radix_dit_template(input + half_len + quarter_len);
                    constexpr T unit = qpow(T(ROOT), (MOD - 1) / LEN);
                    constexpr T unit_cube = qpow(unit, 3);
                    T omega(1), omega_cube(1);
                    for (size_t i = 0; i < quarter_len; i++)
                    {
                        ntt_basic::ntt_split_radix_dit_butterfly(omega, omega_cube, input + i, quarter_len);
                        omega = omega * unit;
                        omega_cube = omega_cube * unit_cube;
                    }
                }
                // 模板化频率抽取分裂基ntt
                template <typename T>
                static constexpr void ntt_split_radix_dif_template(T input[])
                {
                    constexpr T unit = qpow(T(ROOT), (MOD - 1) / LEN);
                    constexpr T unit_cube = qpow(unit, 3);
                    T omega(1), omega_cube(1);
                    for (size_t i = 0; i < quarter_len; i++)
                    {
                        ntt_basic::ntt_split_radix_dif_butterfly(omega, omega_cube, input + i, quarter_len);
                        omega = omega * unit;
                        omega_cube = omega_cube * unit_cube;
                    }
                    SPLIT_RADIX_NTT<half_len, MOD, ROOT>::ntt_split_radix_dif_template(input);
                    SPLIT_RADIX_NTT<quarter_len, MOD, ROOT>::ntt_split_radix_dif_template(input + half_len);
                    SPLIT_RADIX_NTT<quarter_len, MOD, ROOT>::ntt_split_radix_dif_template(input + half_len + quarter_len);
                }
            };
            template <UINT_64 MOD, UINT_64 ROOT>
            struct SPLIT_RADIX_NTT<0, MOD, ROOT>
            {
                template <typename T>
                static constexpr void ntt_split_radix_dit_template(T input[]) {}
                template <typename T>
                static constexpr void ntt_split_radix_dif_template(T input[]) {}
            };
            template <UINT_64 MOD, UINT_64 ROOT>
            struct SPLIT_RADIX_NTT<1, MOD, ROOT>
            {
                template <typename T>
                static constexpr void ntt_split_radix_dit_template(T input[]) {}
                template <typename T>
                static constexpr void ntt_split_radix_dif_template(T input[]) {}
            };
            template <UINT_64 MOD, UINT_64 ROOT>
            struct SPLIT_RADIX_NTT<2, MOD, ROOT>
            {
                using ntt_basic = NTT_BASIC<MOD, ROOT>;
                template <typename T>
                static constexpr void ntt_split_radix_dit_template(T input[])
                {
                    ntt_basic::ntt_2point(input[0], input[1]);
                }
                template <typename T>
                static constexpr void ntt_split_radix_dif_template(T input[])
                {
                    ntt_basic::ntt_2point(input[0], input[1]);
                }
            };
            template <UINT_64 MOD, UINT_64 ROOT>
            struct SPLIT_RADIX_NTT<4, MOD, ROOT>
            {
                using ntt_basic = NTT_BASIC<MOD, ROOT>;
                template <typename T>
                static constexpr void ntt_split_radix_dit_template(T input[])
                {
                    ntt_basic::ntt_dit_4point(input, 1);
                }
                template <typename T>
                static constexpr void ntt_split_radix_dif_template(T input[])
                {
                    ntt_basic::ntt_dif_4point(input, 1);
                }
            };
            template <UINT_64 MOD, UINT_64 ROOT>
            struct SPLIT_RADIX_NTT<8, MOD, ROOT>
            {
                using ntt_basic = NTT_BASIC<MOD, ROOT>;
                template <typename T>
                static constexpr void ntt_split_radix_dit_template(T input[])
                {
                    ntt_basic::ntt_dit_8point(input, 1);
                }
                template <typename T>
                static constexpr void ntt_split_radix_dif_template(T input[])
                {
                    ntt_basic::ntt_dif_8point(input, 1);
                }
            };
            template <UINT_64 MOD, UINT_64 ROOT>
            struct SPLIT_RADIX_NTT<16, MOD, ROOT>
            {
                using ntt_basic = NTT_BASIC<MOD, ROOT>;
                template <typename T>
                static constexpr void ntt_split_radix_dit_template(T input[])
                {
                    ntt_basic::ntt_dit_16point(input, 1);
                }
                template <typename T>
                static constexpr void ntt_split_radix_dif_template(T input[])
                {
                    ntt_basic::ntt_dif_16point(input, 1);
                }
            };
            // 分裂基辅助选择类
            template <size_t LEN, UINT_64 MOD, UINT_64 ROOT>
            struct NTT_ALT
            {
                template <typename T>
                static constexpr void ntt_dit_template(T input[], size_t ntt_len)
                {
                    if (ntt_len > LEN)
                    {
                        NTT_ALT<LEN * 2, MOD, ROOT>::ntt_dit_template(input, ntt_len);
                        return;
                    }
                    SPLIT_RADIX_NTT<LEN, MOD, ROOT>::ntt_split_radix_dit_template(input);
                }
                template <typename T>
                static constexpr void ntt_dif_template(T input[], size_t ntt_len)
                {
                    if (ntt_len > LEN)
                    {
                        NTT_ALT<LEN * 2, MOD, ROOT>::ntt_dif_template(input, ntt_len);
                        return;
                    }
                    SPLIT_RADIX_NTT<LEN, MOD, ROOT>::ntt_split_radix_dif_template(input);
                }
            };
            template <UINT_64 MOD, UINT_64 ROOT>
            struct NTT_ALT<size_t(1) << 43, MOD, ROOT>
            {
                template <typename T>
                static constexpr void ntt_dit_template(T input[], size_t ntt_len) {}
                template <typename T>
                static constexpr void ntt_dif_template(T input[], size_t ntt_len) {}
            };
            template <UINT_64 MOD, UINT_64 ROOT>
            struct NTT
            {
                using ntt_basic = NTT_BASIC<MOD, ROOT>;
                using NTTModInt32 = ModInt<UINT_32, MOD>;
                using NTTModInt64 = ModInt<UINT_64, MOD>;
                static constexpr UINT_64 mod()
                {
                    return MOD;
                }
                static constexpr UINT_64 root()
                {
                    return ROOT;
                }
                static constexpr UINT_64 iroot()
                {
                    return NTTModInt64(ROOT).inv().data;
                }
                // 基2时间抽取ntt,学习用
                template <typename T>
                static void ntt_radix2_dit(T *input, size_t ntt_len, bool bit_rev = true)
                {
                    ntt_len = max_2pow(ntt_len);
                    if (ntt_len <= 1)
                    {
                        return;
                    }
                    if (ntt_len == 2)
                    {
                        ntt_basic::ntt_2point(input[0], input[1]);
                        return;
                    }
                    if (bit_rev)
                    {
                        binary_reverse_swap(input, ntt_len);
                    }
                    for (size_t pos = 0; pos < ntt_len; pos += 2)
                    {
                        ntt_basic::ntt_2point(input[pos], input[pos + 1]);
                    }
                    for (size_t rank = 2; rank < ntt_len / 2; rank *= 2)
                    {
                        size_t gap = rank * 2;
                        T unit_omega = qpow(T(ROOT), (MOD - 1) / gap);
                        for (size_t begin = 0; begin < ntt_len; begin += (gap * 2))
                        {
                            ntt_basic::ntt_2point(input[begin], input[begin + rank]);
                            ntt_basic::ntt_2point(input[begin + gap], input[begin + rank + gap]);
                            T omega = unit_omega;
                            for (size_t pos = begin + 1; pos < begin + rank; pos++)
                            {
                                ntt_basic::ntt_radix2_dit_butterfly(omega, input + pos, rank);
                                ntt_basic::ntt_radix2_dit_butterfly(omega, input + pos + gap, rank);
                                omega = omega * unit_omega;
                            }
                        }
                    }
                    T omega = 1, unit_omega = qpow(T(ROOT), (MOD - 1) / ntt_len);
                    ntt_len /= 2;
                    for (size_t pos = 0; pos < ntt_len; pos++)
                    {
                        ntt_basic::ntt_radix2_dit_butterfly(omega, input + pos, ntt_len);
                        omega = omega * unit_omega;
                    }
                }
                // 基2频率抽取ntt,学习用
                template <typename T>
                static void ntt_radix2_dif(T *input, size_t ntt_len, bool bit_rev = true)
                {
                    ntt_len = max_2pow(ntt_len);
                    if (ntt_len <= 1)
                    {
                        return;
                    }
                    if (ntt_len == 2)
                    {
                        ntt_basic::ntt_2point(input[0], input[1]);
                        return;
                    }
                    T unit_omega = qpow(T(ROOT), (MOD - 1) / ntt_len);
                    T omega = 1;
                    for (size_t pos = 0; pos < ntt_len / 2; pos++)
                    {
                        ntt_basic::ntt_radix2_dif_butterfly(omega, input + pos, ntt_len / 2);
                        omega = omega * unit_omega;
                    }
                    unit_omega = unit_omega * unit_omega;
                    for (size_t rank = ntt_len / 4; rank > 1; rank /= 2)
                    {
                        size_t gap = rank * 2;
                        for (size_t begin = 0; begin < ntt_len; begin += (gap * 2))
                        {
                            ntt_basic::ntt_2point(input[begin], input[begin + rank]);
                            ntt_basic::ntt_2point(input[begin + gap], input[begin + rank + gap]);
                            T omega = unit_omega;
                            for (size_t pos = begin + 1; pos < begin + rank; pos++)
                            {
                                ntt_basic::ntt_radix2_dif_butterfly(omega, input + pos, rank);
                                ntt_basic::ntt_radix2_dif_butterfly(omega, input + pos + gap, rank);
                                omega = omega * unit_omega;
                            }
                        }
                        unit_omega = unit_omega * unit_omega;
                    }
                    for (size_t pos = 0; pos < ntt_len; pos += 2)
                    {
                        ntt_basic::ntt_2point(input[pos], input[pos + 1]);
                    }
                    if (bit_rev)
                    {
                        binary_reverse_swap(input, ntt_len);
                    }
                }
                template <typename T>
                static constexpr void ntt_split_radix_dit(T input[], size_t ntt_len)
                {
                    NTT_ALT<1, MOD, ROOT>::ntt_dit_template(input, ntt_len);
                }
                template <typename T>
                static constexpr void ntt_split_radix_dif(T input[], size_t ntt_len)
                {
                    NTT_ALT<1, MOD, ROOT>::ntt_dif_template(input, ntt_len);
                }
                template <typename T>
                static constexpr void ntt_dit(T input[], size_t ntt_len)
                {
                    ntt_split_radix_dit(input, ntt_len);
                }
                template <typename T>
                static constexpr void ntt_dif(T input[], size_t ntt_len)
                {
                    ntt_split_radix_dif(input, ntt_len);
                }
                using intt = NTT<mod(), iroot()>;
            };
            using ntt1 = NTT<NTT_MOD1, NTT_ROOT1>;
            using intt1 = NTT<ntt1::mod(), ntt1::iroot()>;

            using ntt2 = NTT<NTT_MOD2, NTT_ROOT2>;
            using intt2 = NTT<ntt2::mod(), ntt2::iroot()>;

            using ntt3 = NTT<NTT_MOD3, NTT_ROOT3>;
            using intt3 = NTT<ntt3::mod(), ntt3::iroot()>;

            using ntt4 = NTT<NTT_MOD4, NTT_ROOT4>;
            using intt4 = NTT<ntt4::mod(), ntt4::iroot()>;
        }
    }
}

using namespace std;
using namespace hint;
using namespace hint_transform;
using namespace hint_ntt;

void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c)
{
    size_t len1 = n + 1, len2 = m + 1, out_len = len1 + len2;
    size_t ntt_len = min_2pow(out_len);

    using ntt = ntt1;
    using intt = ntt::intt;

    auto mod_ary1 = new ntt::NTTModInt32[ntt_len]{};
    auto mod_ary2 = new ntt::NTTModInt32[ntt_len]{};

    std::memcpy(reinterpret_cast<unsigned *>(mod_ary1), a, len1);
    std::memcpy(reinterpret_cast<unsigned *>(mod_ary2), b, len2);

    ntt::ntt_dif(mod_ary1, ntt_len);
    ntt::ntt_dif(mod_ary2, ntt_len);
    ary_mul(mod_ary1, mod_ary2, mod_ary1, ntt_len);
    intt::ntt_dit(mod_ary1, ntt_len);
    intt::ntt_basic::ntt_normalize(mod_ary1, ntt_len);
    std::memcpy(reinterpret_cast<ntt::NTTModInt32 *>(c), mod_ary1, len1 + len2 - 1);
}

CompilationN/AN/ACompile OKScore: N/A

Testcase #1228.839 ms17 MB + 964 KBWrong AnswerScore: 0


Judge Duck Online | 评测鸭在线
Server Time: 2025-09-15 13:34:11 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠