提交记录 19668


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1002i. 【模板题】多项式乘法 Compile Error 0 0 ns 0 KB C++14 45.26 KB
提交时间 评测时间
2023-07-11 20:51:44 2023-07-11 20:51:45
#include <cstdint>
using namespace std;
using i64 = int64_t;
using u32 = uint32_t;
using u64 = uint64_t;
#include <cctype>
#include <concepts>
#include <cstdio>
#include <cstring>
#include <string>
#include <type_traits>
#include <vector>
namespace detail
{
    template <class Buf>
    struct FastI : Buf
    {
        using Buf::pop;
        using Buf::top;
        FastI(FILE *f, u32 size = 1 << 18) : Buf(f, size) {}
        void skipSpace()
        {
            while (top() <= ' ')
                pop();
        }
        FastI &operator>>(char &x)
        {
            skipSpace();
            x = pop();
            return *this;
        }
        FastI &operator>>(string &x)
        {
            x.resize(0);
            skipSpace();

            while (isgraph(top()))
                x.push_back(pop());

            return *this;
        }
        template <unsigned_integral T>
        FastI &operator>>(T &x)
        {
            x = 0;
            skipSpace();

            while (top() >= '0')
                x = x * 10 + (pop() & 0xf);

            return *this;
        }
        template <signed_integral T>
        FastI &operator>>(T &x)
        {
            bool neg = false;
            x = 0;
            skipSpace();

            if (top() == '-')
                neg = true, pop();

            while (top() >= '0')
                x = x * 10 + (pop() & 0xf);

            x = neg ? -x : x;
            return *this;
        }
    };
    template <class Buf>
    struct FastO : Buf
    {
        using Buf::push;
        using Buf::push_uncheck;
        using Buf::puts;
        vector<u32> pre;
        FastO(FILE *f, u32 size = 1 << 18) : Buf(f, size), pre(u64(1E4))
        {
            for (int i = 0; i < u64(1E4); ++i)
            {
                int ti = i;

                for (int j = 0; j < 4; ++j)
                {
                    pre[i] = pre[i] << 8 | ti % 10 | 0x30;
                    ti /= 10;
                }
            }
        }
        ~FastO()
        {
            Buf::flush();
        }
        template <signed_integral T>
        FastO &operator<<(T x)
        {
            if (x < 0)
                push('-'), x = -x;

            return *this << make_unsigned<T>::type(x);
        }
        void output4(int t)
        {
            auto tp = (const char *)&pre[t];

            if (t >= u64(1E2))
            {
                if (t >= u64(1E3))
                    push_uncheck(tp, 4);
                else
                    push_uncheck(tp + 1, 3);
            }
            else
            {
                if (t >= u64(1E1))
                    push_uncheck(tp + 2, 2);
                else
                    push_uncheck(t | 0x30);
            }
        };
        template <unsigned_integral T>
        FastO &operator<<(T x)
        {
            Buf::reserve(32);

            if (x >= u64(1E8))
            {
                u64 q0 = x / u64(1E8), r0 = x % u64(1E8);

                if (x >= u64(1E16))
                {
                    u64 q1 = q0 / u64(1E8), r1 = q0 % u64(1E8);
                    output4(q1);
                    push_uncheck(&pre[r1 / u64(1E4)], 4);
                    push_uncheck(&pre[r1 % u64(1E4)], 4);
                }
                else if (x >= u64(1E12))
                {
                    output4(q0 / u64(1E4));
                    push_uncheck(&pre[q0 % u64(1E4)], 4);
                }
                else
                {
                    output4(q0);
                }

                push_uncheck(&pre[r0 / u64(1E4)], 4);
                push_uncheck(&pre[r0 % u64(1E4)], 4);
            }
            else
            {
                if (x >= u64(1E4))
                {
                    output4(x / u64(1E4));
                    push_uncheck(&pre[x % u64(1E4)], 4);
                }
                else
                {
                    output4(x);
                }
            }

            return *this;
        }
        FastO &operator<<(char x)
        {
            return push(x), *this;
        }
        FastO &operator<<(const char *x)
        {
            return puts(x), *this;
        }
        template <size_t N>
        FastO &operator<<(const char x[N])
        {
            return push(x, N), *this;
        }
        FastO &operator<<(const string &x)
        {
            return push(x.c_str(), x.size()), *this;
        }
    };
    /// Fixed-capacity output buffer in front of a C stream.
    ///
    /// `beg` is the start of the heap buffer, `p` the write cursor and `end`
    /// points at the last byte of the allocation, so one byte of slack is
    /// always available for `push(char)`. Nothing is flushed on destruction;
    /// call `flush()` (FastO's destructor does).
    struct BufO
    {
        FILE *f;
        char *beg, *end, *p;
        /// @param f_ destination stream.
        /// @param sz buffer size in bytes.
        BufO(FILE *f_, u32 sz) : f(f_), beg(new char[sz]), end(beg + sz - 1), p(beg) {}
        // The buffer is uniquely owned; a copy would delete it twice.
        BufO(const BufO &) = delete;
        BufO &operator=(const BufO &) = delete;
        ~BufO()
        {
            delete[] beg;
        }
        /// Writes the buffered bytes to the stream and rewinds the cursor.
        void flush()
        {
            fwrite(beg, 1, p - beg, f);
            p = beg;
        }
        /// Flushes unless more than `len` bytes are still free.
        void reserve(u32 len)
        {
            if (end - p <= static_cast<long long>(len))
                flush();
        }
        /// Appends one byte, flushing afterwards if the buffer became full.
        void push(char s)
        {
            *p++ = s;
            reserve(0);
        }
        /// Appends `len` bytes starting at `s`.
        void push(const char *s, u32 len)
        {
            // A chunk that can never fit would overrun the buffer even after
            // a flush: emit what is pending, then write it straight through.
            if (len >= static_cast<u32>(end - beg))
            {
                flush();
                fwrite(s, 1, len, f);
                return;
            }

            reserve(len);
            push_uncheck(s, len);
        }
        /// Appends one byte without any capacity check.
        void push_uncheck(char s) { *p++ = s; }
        /// Appends `len` bytes without any capacity check.
        void push_uncheck(const void *s, u32 len)
        {
            memcpy(p, s, len);
            p += len;
        }
        /// Appends a NUL-terminated string (terminator excluded).
        void puts(const char *s)
        {
            while (*s)
                push(*s++);
        }
    };
}
#include <sys/mman.h>
#include <sys/stat.h>
namespace detail
{
    /// Character source that exposes the whole input as one contiguous range.
    ///
    /// The stream's file is memory-mapped when possible. If it cannot be
    /// mapped (empty file, pipe, terminal, ...) the remaining stream content
    /// is read into an owned string instead. No end-of-input check is done by
    /// `pop()`/`top()`.
    struct BufI
    {
        struct stat sb;
        char *p;               // read cursor
        char *base = nullptr;  // start of the mapping, nullptr when not mapped
        string fallback;       // owns the input when mmap is unavailable
        /// @param f input stream; the second argument (size hint) is ignored.
        BufI(FILE *f, u32)
        {
            const int fd = fileno(f);
            void *mapped = MAP_FAILED;

            if (fstat(fd, &sb) == 0 && sb.st_size > 0)
                mapped = mmap(0, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);

            if (mapped != MAP_FAILED)
            {
                base = static_cast<char *>(mapped);
                madvise(base, sb.st_size, MADV_SEQUENTIAL);
                p = base;
                return;
            }

            char chunk[1 << 14];
            size_t got;

            while ((got = fread(chunk, 1, sizeof chunk, f)) > 0)
                fallback.append(chunk, got);

            // operator[] at size() refers to the terminating '\0', so this is
            // valid even when nothing was read.
            p = &fallback[0];
        }
        // The mapping / fallback storage is uniquely owned and `p` points into it.
        BufI(const BufI &) = delete;
        BufI &operator=(const BufI &) = delete;
        ~BufI()
        {
            // Unmap from the mapping start: `p` has been advanced by pop().
            if (base)
                munmap(base, sb.st_size);
        }
        /// Returns the current character and advances.
        char pop() { return *p++; }
        /// Returns the current character without advancing.
        char top() const { return *p; }
    };
}
// Concrete reader/writer types used by the rest of the file:
// FastI reads through detail::BufI, FastO writes through detail::BufO.
using FastI = detail::FastI<detail::BufI>;
using FastO = detail::FastO<detail::BufO>;
#include <type_traits>
template <class T, T MOD>
struct MontgomerySpace;
/// Arithmetic policy for a compile-time odd modulus below 2^30, using
/// Montgomery form with radix 2^32.
///
/// `trans` maps a plain residue to Montgomery form and `val` maps back to
/// [0, MOD). `add`/`sub` keep values in [0, 2*MOD) when their inputs are in
/// that range.
template <u32 MOD>
struct MontgomerySpace<u32, MOD>
{
    static_assert(2 < MOD && MOD < u32(1) << 30, "mod must in [3, 2^30)");
    static_assert(MOD % 2 == 1, "mod must be odd");
    using ValueT = u32;
    using TransT = u32;
    using rawU32 = false_type;
    using isMontgomery = true_type;
    /// Inverse of MOD modulo 2^32 by Newton iteration: x = 1 is correct
    /// modulo 2 (MOD is odd) and each step doubles the number of valid bits.
    constexpr static u32 get_nr()
    {
        u32 x = 1;

        for (int i = 0; i < 5; ++i)
            x *= 2 - x * MOD;

        return x;
    }
    /// The modulus. constexpr, matching BasicModSpace::mod().
    constexpr static u32 mod()
    {
        return MOD;
    }
    enum : u32
    {
        // `%` binds tighter than `<<`: the shift must be parenthesised to
        // obtain 2^32 mod MOD rather than 1 << (32 % MOD).
        R = u32((u64(1) << 32) % MOD),
        IR = u32(-get_nr()), // -MOD^{-1} mod 2^32
        MOD2 = MOD * 2,
    };
    /// Plain residue -> Montgomery form (x * 2^32 mod MOD).
    constexpr static TransT trans(ValueT x)
    {
        return (u64(x) << 32) % MOD;
    }
    /// Montgomery reduction: a value congruent to x * 2^-32 modulo MOD.
    constexpr static u32 reduce(u64 x)
    {
        return (x + u64(u32(x) * IR) * MOD) >> 32;
    }
    /// Adds MOD when the top bit is set (i.e. n wrapped below zero).
    constexpr static u32 reduce_m(u32 n)
    {
        return n >> 31 ? n + MOD : n;
    }
    /// Adds 2*MOD when the top bit is set (i.e. n wrapped below zero).
    constexpr static u32 reduce_2m(u32 n)
    {
        return n >> 31 ? n + MOD2 : n;
    }
    constexpr static u32 add(u32 a, u32 b)
    {
        return reduce_2m(a + b - MOD2);
    }
    constexpr static u32 sub(u32 a, u32 b)
    {
        return reduce_2m(a - b);
    }
    constexpr static u32 mul(u32 a, u32 b)
    {
        return reduce(u64(a) * b);
    }
    /// Reduces an arbitrary signed value to a plain residue in [0, MOD);
    /// the result is NOT in Montgomery form.
    constexpr static u32 safe(i64 x)
    {
        return reduce_m(x % MOD);
    }
    /// Montgomery form -> plain residue in [0, MOD).
    constexpr static ValueT val(TransT x)
    {
        return reduce_m(reduce(x) - MOD);
    }
    /// Leaves Montgomery form via reduce(), then halves modulo MOD by making
    /// the value even first. The result is a plain value, not Montgomery form.
    constexpr static u32 shift2(u32 x)
    {
        x = reduce(x);
        return (x & 1 ? x + MOD : x) >> 1;
    }
};
/// Modular exponentiation by repeated squaring: returns a^b mod m.
constexpr u32 qpow(u32 a, u64 b, u32 m)
{
    u32 r = 1;

    for (; b > 0; b /= 2)
    {
        if (b % 2 == 1)
            r = u64(a) * r % m;

        a = u64(a) * a % m;
    }

    return r;
}
#include <algorithm>
#include <cassert>
#include <optional>
/// Euler's criterion: returns a^((p-1)/2) mod p. For an odd prime p this is
/// 1 for a non-zero quadratic residue and p - 1 for a non-residue.
u32 legendre(u32 a, u32 p)
{
    return qpow(a, (p - 1) / 2, p);
}
/// Square root modulo a prime p (Cipolla's algorithm).
///
/// @return the smaller of the two roots of x^2 = n (mod p), 0 when n == 0,
///         or nullopt when legendre(n, p) != 1.
optional<int> cipola(u32 n, u32 p)
{
    if (n == 0)
        return 0;

    if (legendre(n, p) != 1)
        return nullopt;

    if (p == 2)
        return 1;

    // Element c0 + c1 * omega of F_p[omega] / (omega^2 - w).
    using FP2 = pair<u64, u64>;

    for (u32 a = 0; a < p; a++)
    {
        // w = a^2 - n (mod p). Computed in 64 bits: a * a does not fit in
        // 32 bits once a >= 2^16, and n is reduced first in case n >= p.
        const u64 w = (u64(a) * a % p + p - n % p) % p;

        if (legendre(static_cast<u32>(w), p) != p - 1)
            continue; // w must be a non-residue

        // Every partial product is reduced before summing so that nothing
        // overflows 64 bits for any 32-bit p.
        auto mul = [p, w](const FP2 &l, const FP2 &r)
        {
            const u64 re = (l.first * r.first % p + l.second * r.second % p * w % p) % p;
            const u64 im = (l.second * r.first % p + l.first * r.second % p) % p;
            return FP2{re, im};
        };

        // (a + omega)^((p+1)/2) is a square root of n; its omega part is 0.
        FP2 x = {1, 0}, u = {a, 1};

        for (u64 b = (u64(p) + 1) / 2; b != 0; b /= 2)
        {
            if (b % 2 == 1)
                x = mul(x, u);

            u = mul(u, u);
        }

        return static_cast<int>(min<u64>(x.first, p - x.first));
    }

    return nullopt;
}
#include <type_traits>
template <class T, T MOD>
struct BasicModSpace;
/// Arithmetic policy for a compile-time modulus below 2^31 that stores
/// residues as plain values: `trans` is the identity and `mul` uses a 64-bit
/// product followed by `%`.
template <u32 MOD>
struct BasicModSpace<u32, MOD>
{
    static_assert(2 < MOD && MOD < u32(1) << 31, "mod must in [3, 2^31)");
    using ValueT = u32;
    using TransT = u32;
    using rawU32 = true_type;
    using isMontgomery = false_type;
    enum : u32
    {
        MOD2 = MOD * 2,
    };
    /// The modulus.
    constexpr static u32 mod() { return MOD; }
    /// Plain value -> stored form (no change).
    constexpr static TransT trans(ValueT x) { return x; }
    /// Stored form -> plain value, folding a wrapped value back by MOD.
    constexpr static ValueT val(TransT x) { return reduce_m(x); }
    /// Adds MOD when the top bit of n is set (n wrapped below zero).
    constexpr static u32 reduce_m(ValueT n)
    {
        if (n >> 31)
            return n + MOD;

        return n;
    }
    /// Adds 2*MOD when the top bit of n is set.
    constexpr static u32 reduce_2m(u32 n)
    {
        if (n >> 31)
            return n + MOD2;

        return n;
    }
    constexpr static u32 add(u32 a, u32 b)
    {
        const u32 shifted = a + b - MOD;
        return reduce_m(shifted);
    }
    constexpr static u32 sub(u32 a, u32 b)
    {
        const u32 diff = a - b;
        return reduce_m(diff);
    }
    constexpr static u32 mul(u32 a, u32 b)
    {
        const u64 wide = u64(a) * b;
        return static_cast<u32>(wide % MOD);
    }
    /// Reduces an arbitrary signed value into [0, MOD).
    constexpr static u32 safe(i64 x)
    {
        const i64 rem = x % MOD; // in (-MOD, MOD)
        return reduce_m(static_cast<ValueT>(rem));
    }
    /// Halves x modulo MOD: an odd x is made even by adding MOD first.
    constexpr static u32 shift2(u32 x)
    {
        const u32 even = (x & 1) ? x + MOD : x;
        return even >> 1;
    }
};
#include <iostream>
// 封装 Modint,功能由 Space 提供
template <class Space_>
struct StaticModint
{
    using Space = Space_;
    using ValueT = typename Space::ValueT;
    using TransT = typename Space::TransT;
    using isStatic = true_type;
    using rawU32 = typename Space::rawU32;
    using isMontgomery = typename Space::isMontgomery;
    TransT v;
    constexpr StaticModint() = default;
    constexpr StaticModint(ValueT v_) : v(Space::trans(v_)) {}
    using Self = StaticModint;
    explicit operator ValueT() const
    {
        return val();
    }
    constexpr static Self safe(i64 v)
    {
        return Self(Space::safe(v));
    }
    constexpr ValueT val() const
    {
        return Space::val(v);
    }
    constexpr TransT raw() const
    {
        return v;
    }
    constexpr static ValueT mod()
    {
        return Space::mod();
    }
    constexpr Self &operator+=(const Self &rhs)
    {
        v = Space::add(v, rhs.v);
        return *this;
    }
    constexpr Self &operator-=(const Self &rhs)
    {
        v = Space::sub(v, rhs.v);
        return *this;
    }
    constexpr Self &operator*=(const Self &rhs)
    {
        v = Space::mul(v, rhs.v);
        return *this;
    }
    friend constexpr inline Self operator+(const Self &lhs, const Self &rhs)
    {
        return Self(lhs) += rhs;
    }
    friend constexpr inline Self operator-(const Self &lhs, const Self &rhs)
    {
        return Self(lhs) -= rhs;
    }
    friend constexpr inline Self operator*(const Self &lhs, const Self &rhs)
    {
        return Self(lhs) *= rhs;
    }
    constexpr Self pow(u64 n) const
    {
        Self r(1), a(*this);

        for (; n > 0; n /= 2)
        {
            if (n % 2 == 1)
                r *= a;

            a *= a;
        }

        return r;
    }
    constexpr Self inv() const
    {
        return pow(Space::mod() - 2);
    }
    constexpr Self &operator/=(const Self &rhs)
    {
        return *this *= rhs.inv();
    }
    friend constexpr inline Self operator/(const Self &lhs, const Self &rhs)
    {
        return Self(lhs) /= rhs;
    }
    constexpr Self operator-() const
    {
        return Self() -= *this;
    }
    constexpr optional<Self> sqrt() const
    {
        return cipola(val(), mod());
    }
    constexpr Self shift2() const
    {
        return Space::shift2(v);
    }
    friend inline istream &operator>>(istream &is, Self &m)
    {
        i64 x;
        is >> x;
        m = Self::safe(x);
        return is;
    }
    friend inline ostream &operator<<(ostream &os, const Self &m)
    {
        return os << m.val();
    }
    friend inline bool operator==(const Self &lhs, const Self &rhs)
    {
        return lhs.val() == rhs.val();
    }
    friend inline bool operator!=(const Self &lhs, const Self &rhs)
    {
        return !(lhs == rhs);
    }
};
// Convenience alias: a StaticModint whose arithmetic is provided by
// BasicModSpace<T, MOD>.
template <class T, T MOD>
using BasicStaticModint = StaticModint<BasicModSpace<T, MOD>>;
/// Reads a signed 64-bit integer from `is` and stores its residue in `m`.
///
/// Uses StaticModint::safe, like the std::istream overload, so negative or
/// out-of-range inputs are reduced instead of being truncated to ValueT.
template <class Space>
inline FastI &operator>>(FastI &is, StaticModint<Space> &m)
{
    i64 x;
    is >> x;
    m = StaticModint<Space>::safe(x);
    return is;
}
/// Writes the plain value of `m` (as returned by val()) to `os`.
template <class Space>
inline FastO &operator<<(FastO &os, const StaticModint<Space> &m)
{
    const auto plain = m.val();
    return os << plain;
}
#include <type_traits>
// Concepts built from the boolean tag types a modint exposes
// (isStatic, rawU32, isMontgomery).

// Satisfied when ModT::isStatic::value is true.
template <class ModT>
concept static_modint_concept = ModT::isStatic::value;
// Satisfied when ModT::rawU32::value is true.
template <class ModT>
concept raw32_modint_concept = ModT::rawU32::value;
// Both of the above.
template <class ModT>
concept static_raw32_modint_concept = static_modint_concept<ModT> && raw32_modint_concept<ModT>;
// Satisfied when ModT::isStatic::value is false.
template <class ModT>
concept runtime_modint_concept = !
ModT::isStatic::value;
// Satisfied when ModT::isMontgomery::value is true.
template <class ModT>
concept montgomery_modint_concept = ModT::isMontgomery::value;
// Static modint that is not in Montgomery form.
template <class ModT>
concept static_basic_modint_concept = !
montgomery_modint_concept<ModT> &&static_modint_concept<ModT>;
#include <algorithm>
#include <bit>
#include <cassert>
#include <span>
#include <vector>
namespace detail
{
    // NOTE(review): not referenced anywhere in the visible part of this file;
    // presumably a counter/size record used further down — confirm.
    u32 ntt_size = 0;
} // namespace detail
// #include "ntt-twisted-radix-2-basic.hpp"
// #include "ntt-barrett.hpp"
#include <algorithm>
#include <bit>
#include <cassert>
#include <span>
#include <vector>
namespace detail
{
    /// Precomputed roots of unity for the radix-2 transforms.
    ///
    /// rt[i] / irt[i] are obtained by repeatedly squaring g^((P-1) >> rank2)
    /// and its inverse, so rt[i] has order dividing 2^i. rate2 / irate2 are the
    /// per-block twiddle multipliers consumed by the butterfly loops.
    /// g = 3 is assumed to be a primitive root of P.
    template <static_modint_concept ModT>
    struct NttClassicalInfo
    {
        using ValueT = typename ModT::ValueT;
        static constexpr ValueT P = ModT::mod();
        static constexpr ValueT g = 3;
        static constexpr int rank2 = countr_zero(P - 1);
        array<ModT, rank2 + 1> rt, irt;
        array<ModT, max<int>(0, rank2 - 1)> rate2, irate2;
        constexpr NttClassicalInfo()
        {
            rt[rank2] = ModT(g).pow((P - 1) >> rank2);
            irt[rank2] = rt[rank2].inv();

            // Each lower level is the square of the one above it.
            int level = rank2;

            while (level >= 1)
            {
                rt[level - 1] = rt[level] * rt[level];
                irt[level - 1] = irt[level] * irt[level];
                --level;
            }

            ModT fwd_acc = 1, inv_acc = 1;

            // idx walks rt[2..rank2]; the table slot is idx - 2.
            for (int idx = 2; idx <= rank2; ++idx)
            {
                rate2[idx - 2] = fwd_acc * rt[idx];
                irate2[idx - 2] = inv_acc * irt[idx];
                fwd_acc *= irt[idx];
                inv_acc *= rt[idx];
            }
        }
    };
    /// In-place forward transform of `f` using radix-2 butterflies.
    /// The butterfly loops assume f.size() is a power of two.
    template <static_modint_concept ModT>
    static void ntt_classical_basic(span<ModT> f)
    { // dif
        static constexpr NttClassicalInfo<ModT> info;
        int n = f.size();

        for (int l = n / 2; l > 0; l /= 2)
        {
            ModT r = 1; // twiddle of the current block

            for (int i = 0, k = 0; i < n; i += l * 2, ++k)
            {
                for (int j = 0; j < l; ++j)
                {
                    ModT x = f[i + j], y = f[i + j + l] * r;
                    f[i + j] = x + y;
                    f[i + j + l] = x - y;
                }

                // The twiddle is only needed if another block follows. Skipping
                // the final update also avoids reading rate2 one past its end
                // when n == 2^rank2 (or when the table is empty).
                if (i + l * 2 < n)
                    r *= info.rate2[countr_one<u32>(k)];
            }
        }
    }
    /// In-place inverse transform of `f` using radix-2 butterflies, followed
    /// by multiplication with the inverse of f.size().
    /// The butterfly loops assume f.size() is a power of two.
    template <static_modint_concept ModT>
    static void intt_classical_basic(span<ModT> f)
    { // dit
        static constexpr NttClassicalInfo<ModT> info;
        int n = f.size();

        for (int l = 1; l < n; l *= 2)
        {
            ModT r = 1; // inverse twiddle of the current block

            for (int i = 0, k = 0; i < n; i += l * 2, ++k)
            {
                for (int j = 0; j < l; ++j)
                {
                    ModT x = f[i + j], y = f[i + j + l];
                    f[i + j] = x + y;
                    f[i + j + l] = r * (x - y);
                }

                // Only advance when another block follows; this keeps the
                // irate2 index in range when n == 2^rank2.
                if (i + l * 2 < n)
                    r *= info.irate2[countr_one<u32>(k)];
            }
        }

        const ModT ivn = ModT(n).inv();

        for (int i = 0; i < n; i++)
            f[i] *= ivn;
    }
} // namespace detail
#include <algorithm>
#include <bit>
#include <cassert>
#include <span>
#include <vector>
namespace detail
{
    /// Precomputed roots of unity for the radix-4 transforms.
    ///
    /// Same construction as the radix-2 tables (rt / irt by repeated squaring,
    /// rate2 / irate2 from rt[2..]), plus rate3 / irate3 built the same way
    /// from rt[3..]. g = 3 is assumed to be a primitive root of P.
    template <static_modint_concept ModT>
    struct NttClassicalInfo4
    {
        using ValueT = typename ModT::ValueT;
        static constexpr ValueT P = ModT::mod();
        static constexpr ValueT g = 3;
        static constexpr int rank2 = countr_zero(P - 1);
        array<ModT, rank2 + 1> rt, irt;
        array<ModT, max<int>(0, rank2 - 1)> rate2, irate2;
        array<ModT, max<int>(0, rank2 - 2)> rate3, irate3;
        constexpr NttClassicalInfo4()
        {
            rt[rank2] = ModT(g).pow((P - 1) >> rank2);
            irt[rank2] = rt[rank2].inv();

            // Each lower level is the square of the one above it.
            int level = rank2;

            while (level >= 1)
            {
                rt[level - 1] = rt[level] * rt[level];
                irt[level - 1] = irt[level] * irt[level];
                --level;
            }

            ModT fwd_acc = 1, inv_acc = 1;

            // rate2 slot idx - 2 is derived from rt[idx], idx in [2, rank2].
            for (int idx = 2; idx <= rank2; ++idx)
            {
                rate2[idx - 2] = fwd_acc * rt[idx];
                irate2[idx - 2] = inv_acc * irt[idx];
                fwd_acc *= irt[idx];
                inv_acc *= rt[idx];
            }

            fwd_acc = 1, inv_acc = 1;

            // rate3 slot idx - 3 is derived from rt[idx], idx in [3, rank2].
            for (int idx = 3; idx <= rank2; ++idx)
            {
                rate3[idx - 3] = fwd_acc * rt[idx];
                irate3[idx - 3] = inv_acc * irt[idx];
                fwd_acc *= irt[idx];
                inv_acc *= rt[idx];
            }
        }
    };
    /// In-place forward transform of `f` using radix-4 butterflies.
    ///
    /// When log2(n) is odd, one radix-2 pass over the whole array is done
    /// first so that the remaining levels pair up. The loops assume f.size()
    /// is a power of two.
    template <static_modint_concept ModT>
    static void ntt_classical_basic4(span<ModT> f)
    { // dif
        static constexpr NttClassicalInfo4<ModT> info;
        int n = f.size(), l = n / 2, n_4b = countr_zero<u32>(n) & 1;
        if (n_4b)
        {
            // Single top-level radix-2 pass (one block, twiddle 1).
            for (int j = 0; j < l; ++j)
            {
                ModT x = f[j], y = f[j + l];
                f[j] = x + y;
                f[j + l] = x - y;
            }
            l >>= 1;
        }
        for (l /= 2; l >= 1; l /= 4)
        {
            ModT r = 1, img = info.rt[2];

            for (int i = 0, k = 0; i < n; i += l * 4, ++k)
            {
                ModT r2 = r * r, r3 = r2 * r;

                for (int j = 0; j < l; ++j)
                {
                    ModT x0 = f[i + j + 0 * l];
                    ModT x1 = f[i + j + 1 * l] * r;
                    ModT x2 = f[i + j + 2 * l] * r2;
                    ModT x3 = f[i + j + 3 * l] * r3;
                    ModT x1x3 = (x1 - x3) * img;
                    f[i + j + 0 * l] = x0 + x2 + x1 + x3;
                    f[i + j + 1 * l] = x0 + x2 - x1 - x3;
                    f[i + j + 2 * l] = x0 - x2 + x1x3;
                    f[i + j + 3 * l] = x0 - x2 - x1x3;
                }
                // Advance only when another block follows; this keeps the
                // rate3 index in range when n == 2^rank2.
                if (i + l * 4 < n)
                    r *= info.rate3[countr_one<u32>(k)];
            }
        }
    }
    /// In-place inverse transform of `f` using radix-4 butterflies, followed
    /// by multiplication with the inverse of f.size().
    ///
    /// When log2(n) is odd, the radix-4 passes stop one level early and a
    /// final radix-2 pass over the whole array finishes the job. The loops
    /// assume f.size() is a power of two.
    template <static_modint_concept ModT>
    static void intt_classical_basic4(span<ModT> f)
    { // dit
        static constexpr NttClassicalInfo4<ModT> info;
        int n = f.size(), l = 1, n_4b = countr_zero<u32>(n) & 1;
        for (; l < (n_4b ? n / 2 : n); l *= 4)
        {
            ModT r = 1, img = info.irt[2];

            for (int i = 0, k = 0; i < n; i += l * 4, ++k)
            {
                ModT r2 = r * r, r3 = r2 * r;
                for (int j = 0; j < l; ++j)
                {
                    ModT x0 = f[i + j + 0 * l];
                    ModT x1 = f[i + j + 1 * l];
                    ModT x2 = f[i + j + 2 * l];
                    ModT x3 = f[i + j + 3 * l];
                    ModT x2x3 = (x2 - x3) * img;
                    f[i + j + 0 * l] = x0 + x1 + x2 + x3;
                    f[i + j + 1 * l] = (x0 - x1 + x2x3) * r;
                    f[i + j + 2 * l] = (x0 + x1 - x2 - x3) * r2;
                    f[i + j + 3 * l] = (x0 - x1 - x2x3) * r3;
                }

                // Advance only when another block follows; this keeps the
                // irate3 index in range when n == 2^rank2.
                if (i + l * 4 < n)
                    r *= info.irate3[countr_one<u32>(k)];
            }
        }

        if (n_4b)
        {
            // Remaining top-level radix-2 pass; here l == n / 2.
            for (int j = 0; j < l; ++j)
            {
                ModT x = f[j], y = f[j + l];
                f[j] = x + y;
                f[j + l] = x - y;
            }
        }

        const ModT ivn = ModT(n).inv();

        for (int i = 0; i < n; i++)
            f[i] *= ivn;
    }
} // namespace detail
// #include "ntt-twisted-radix-2-avx.hpp"
#include <algorithm>
#include <bit>
#include <cassert>
#include <span>
#include <vector>
#include <type_traits>
// https://judge.yosupo.jp/submission/92714
#pragma GCC target("avx2")
#include <immintrin.h>
#include <array>
/// Thin named wrappers around the AVX2 intrinsics used by the vectorised
/// transforms, grouped by the lane layout they operate on.
namespace simd
{
    using I256 = __m256i;
    /// Whole-register operations.
    namespace i256
    {
        inline I256 loadu(const I256 *p) { return _mm256_loadu_si256(p); }
        /// Requires p to be 32-byte aligned.
        inline I256 load(const I256 *p) { return _mm256_load_si256(p); }
        /// Requires p to be 32-byte aligned.
        inline void store(I256 *p, const I256 &v) { _mm256_store_si256(p, v); }
        inline void storeu(I256 *p, const I256 &v) { _mm256_storeu_si256(p, v); }
        /// Copies the register into a std::array of T covering all 32 bytes.
        template <class T>
        inline auto to_array(const I256 &v)
        {
            constexpr u32 sizeT = sizeof(T);
            static_assert(sizeof(I256) % sizeT == 0);
            // The element count is 32 / sizeof(T); using sizeof(T) itself
            // would make the array too small for the 32-byte store below.
            alignas(32) array<T, sizeof(I256) / sizeT> arr;
            _mm256_store_si256((I256 *)arr.data(), v);
            return arr;
        }
        inline I256 bit_and(const I256 &a, const I256 &b) { return _mm256_and_si256(a, b); }
    }
    /// Operations on the two 128-bit halves.
    namespace i128x2
    {
        template <int imm>
        inline I256 permute(const I256 &a, const I256 &b)
        {
            return _mm256_permute2x128_si256(a, b, imm);
        }
        template <int imm>
        inline I256 shuffle(const I256 &a)
        {
            return permute<imm>(a, a);
        }
    } // namespace i128x2
    /// Four 64-bit lanes.
    namespace i64x4
    {
        inline I256 add(const I256 &a, const I256 &b)
        {
            return _mm256_add_epi64(a, b);
        }
    } // namespace i64x4
    /// Eight signed 32-bit lanes.
    namespace i32x8
    {
        /// Broadcasts v to all lanes.
        inline I256 from(int v)
        {
            return _mm256_set1_epi32(v);
        }
        inline I256 add(const I256 &a, const I256 &b)
        {
            return _mm256_add_epi32(a, b);
        }
        inline I256 sub(const I256 &a, const I256 &b)
        {
            return _mm256_sub_epi32(a, b);
        }
        /// Signed 32x32 -> 64 multiply of the even-indexed lanes.
        inline I256 mul(const I256 &a, const I256 &b) { return _mm256_mul_epi32(a, b); }
        template <int imm>
        inline I256 shuffle(const I256 &a) { return _mm256_shuffle_epi32(a, imm); }
        template <int imm>
        inline I256 blend(const I256 &a, const I256 &b) { return _mm256_blend_epi32(a, b, imm); }
        inline I256 zero() { return _mm256_setzero_si256(); }
        /// All-ones in every lane where a is negative, zero elsewhere.
        inline I256 sign(const I256 &a) { return _mm256_cmpgt_epi32(zero(), a); }
        /// 64-bit products of the even lanes (first) and of the odd lanes
        /// (second); the odd lanes are moved to even positions before the
        /// multiply.
        inline pair<I256, I256> mul_0246_1357(const I256 &a, const I256 &b)
        {
            auto x0246 = mul(a, b);
            auto x1357 = mul(shuffle<0b11110101>(a), shuffle<0b11110101>(b));
            return {x0246, x1357};
        }
        inline I256 abs(const I256 &a) { return _mm256_abs_epi32(a); }
    }
    /// Eight unsigned 32-bit lanes.
    namespace u32x8
    {
        /// Unsigned 32x32 -> 64 multiply of the even-indexed lanes.
        inline I256 mul(const I256 &a, const I256 &b) { return _mm256_mul_epu32(a, b); }
    }
}
namespace simd
{
    /// Eight 32-bit modular residues of ModT packed in one 256-bit register.
    ///
    /// Lanes hold the raw (Space-transformed) representation. `+=` and `-=`
    /// fold the result back by 2*MOD when it goes negative; `*=` applies a
    /// Montgomery reduction using ModT::Space::IR and the modulus.
    template <class ModT>
    struct M32x8
    {
        I256 v;
        M32x8() : v() {}
        M32x8(const I256 &a) : v(a) {}
        template <class S>
        M32x8(const M32x8<S> &a) : v(a.v) {}
        /// Loads eight 4-byte elements from a std::array.
        template <class U32>
        M32x8(const array<U32, 8> &a)
        {
            static_assert(sizeof(U32) == 4);
            // A std::array of 4-byte elements is not guaranteed to be
            // 32-byte aligned, so the unaligned load is required here.
            v = i256::loadu((const I256 *)a.data());
        }
        /// Loads from p; `aligned` selects the aligned or unaligned intrinsic.
        template <bool aligned = false>
        static M32x8 load(const I256 *p)
        {
            M32x8 r;

            if constexpr (aligned)
            {
                r = i256::load(p);
            }
            else
            {
                r = i256::loadu(p);
            }

            return r;
        }
        /// Broadcasts a raw lane value.
        static M32x8 from(int v)
        {
            return i32x8::from(v);
        }
        /// Broadcasts the raw representation of a scalar modint.
        static M32x8 from(ModT v)
        {
            return from(v.raw());
        }
        inline static I256 Rx8 = i32x8::from(ModT::Space::R);
        inline static I256 IRx8 = i32x8::from(ModT::Space::IR);
        inline static I256 MOD2x8 = i32x8::from(ModT::Space::MOD2);
        inline static I256 MODx8 = i32x8::from(ModT::Space::mod());
        M32x8 &operator+=(const M32x8 &rhs)
        {
            v = i32x8::add(v, rhs.v);
            v = i32x8::sub(v, MOD2x8);
            // Add 2*MOD back in the lanes that went negative.
            I256 sign = i32x8::sign(v);
            v = i32x8::add(v, i256::bit_and(sign, MOD2x8));
            return *this;
        }
        M32x8 &operator-=(const M32x8 &rhs)
        {
            v = i32x8::sub(v, rhs.v);
            // Add 2*MOD back in the lanes that went negative.
            I256 sign = i32x8::sign(v);
            v = i32x8::add(v, i256::bit_and(sign, MOD2x8));
            return *this;
        }
        /// Montgomery-reduces the 64-bit products of the even lanes (x0246)
        /// and odd lanes (x1357) and interleaves the high halves back into
        /// eight 32-bit lanes.
        static I256 reduce(const I256 &x0246, const I256 &x1357)
        {
            auto km0246 = u32x8::mul(u32x8::mul(x0246, IRx8), MODx8);
            auto km1357 = u32x8::mul(u32x8::mul(x1357, IRx8), MODx8);
            auto z0246 = i64x4::add(x0246, km0246);
            z0246 = i32x8::shuffle<0b11110101>(z0246);
            auto z1357 = i64x4::add(x1357, km1357);
            z1357 = i32x8::shuffle<0b11110101>(z1357);
            return i32x8::blend<0b10101010>(z0246, z1357);
        }
        M32x8 &operator*=(const M32x8 &rhs)
        {
            auto [x0246, x1357] = i32x8::mul_0246_1357(v, rhs.v);
            v = reduce(x0246, x1357);
            return *this;
        }
        friend M32x8 operator+(const M32x8 &lhs, const M32x8 &rhs)
        {
            return M32x8(lhs) += rhs;
        }
        friend M32x8 operator-(const M32x8 &lhs, const M32x8 &rhs)
        {
            return M32x8(lhs) -= rhs;
        }
        friend M32x8 operator*(const M32x8 &lhs, const M32x8 &rhs)
        {
            return M32x8(lhs) *= rhs;
        }
        I256 raw() const
        {
            return v;
        }
        /// Replaces v by |v - 2*MOD| in the lanes selected by `imm` and by |v|
        /// in the others, i.e. negates the selected lanes modulo MOD.
        template <int imm>
        M32x8 neg() const
        {
            auto m2 = i32x8::blend<imm>(i32x8::zero(), MOD2x8);
            return i32x8::abs(i32x8::sub(v, m2));
        }
        /// Stores to p; `aligned` selects the aligned or unaligned intrinsic.
        template <bool aligned = false>
        void store(I256 *p)
        {
            if constexpr (aligned)
            {
                i256::store(p, v);
            }
            else
            {
                i256::storeu(p, v);
            }
        }
        auto to_array() const
        {
            return i256::to_array<u32>(v);
        }
        /// 32-bit lane shuffle within each 128-bit half.
        template <int imm>
        M32x8 shuffle() const
        {
            return i32x8::shuffle<imm>(v);
        }
        /// Permutation of the two 128-bit halves.
        template <int imm>
        M32x8 shufflex4() const
        {
            return i128x2::shuffle<imm>(v);
        }
    };
} // namespace simd
namespace detail
{
    /// Precomputed twiddle tables for the AVX2 transforms.
    ///
    /// rt / irt and rate2 / irate2 are built as in the scalar tables;
    /// rate2x8 / irate2x8 are their 8-lane broadcasts. rate4 / irate4 are
    /// derived from rt[4..], and rate4ix8[i] / irate4ix8[i] hold the powers
    /// 0..7 of rate4[i] / irate4[i], one per lane.
    /// g = 3 is assumed to be a primitive root of P.
    template <montgomery_modint_concept ModT>
    struct NttClassicalInfoAvx
    {
        using X8 = simd::M32x8<ModT>;
        using ValueT = typename ModT::ValueT;
        static constexpr ValueT P = ModT::mod();
        static constexpr ValueT g = 3;
        static constexpr int rank2 = countr_zero(P - 1);
        array<ModT, rank2 + 1> rt, irt;
        array<ModT, max<int>(0, rank2 - 1)> rate2, irate2;
        array<ModT, max<int>(0, rank2 - 3)> rate4, irate4;
        array<X8, max<int>(0, rank2 - 1)> rate2x8, irate2x8;
        array<X8, max<int>(0, rank2 - 3)> rate4ix8, irate4ix8;
        constexpr NttClassicalInfoAvx()
        {
            rt[rank2] = ModT(g).pow((P - 1) >> rank2);
            irt[rank2] = rt[rank2].inv();

            // Each lower level is the square of the one above it.
            int level = rank2;

            while (level >= 1)
            {
                rt[level - 1] = rt[level] * rt[level];
                irt[level - 1] = irt[level] * irt[level];
                --level;
            }

            ModT fwd_acc = 1, inv_acc = 1;

            // rate2 slot idx - 2 comes from rt[idx], idx in [2, rank2].
            for (int idx = 2; idx <= rank2; ++idx)
            {
                const int slot = idx - 2;
                rate2[slot] = fwd_acc * rt[idx];
                irate2[slot] = inv_acc * irt[idx];
                fwd_acc *= irt[idx];
                inv_acc *= rt[idx];
                rate2x8[slot] = X8::from(rate2[slot]);
                irate2x8[slot] = X8::from(irate2[slot]);
            }

            fwd_acc = 1, inv_acc = 1;

            // rate4 slot idx - 4 comes from rt[idx], idx in [4, rank2].
            for (int idx = 4; idx <= rank2; ++idx)
            {
                const int slot = idx - 4;
                rate4[slot] = fwd_acc * rt[idx];
                irate4[slot] = inv_acc * irt[idx];
                fwd_acc *= irt[idx];
                inv_acc *= rt[idx];

                // Lane e holds the e-th power of the rate.
                array<ModT, 8> fwd_pows, inv_pows;

                for (int e = 0; e < 8; ++e)
                {
                    fwd_pows[e] = rate4[slot].pow(e);
                    inv_pows[e] = irate4[slot].pow(e);
                }

                rate4ix8[slot] = fwd_pows;
                irate4ix8[slot] = inv_pows;
            }
        }
        /// Per-lane twiddles for the in-register butterfly of span L:
        /// L == 2 puts rt[2] in lanes 3 and 7, L == 4 puts rt[3]^1..3 in
        /// lanes 5..7; every other lane is 1.
        template <int L>
        X8 rt_small()
        {
            array<ModT, 8> lanes;
            lanes.fill(ModT(1));

            if constexpr (L == 2)
            {
                lanes[3] = lanes[7] = rt[2];
            }
            else if constexpr (L == 4)
            {
                int lane = 5;

                while (lane < 8)
                {
                    lanes[lane] = lanes[lane - 1] * rt[3];
                    ++lane;
                }
            }

            return lanes;
        }
        /// Same layout as rt_small, built from the inverse roots irt.
        template <int L>
        X8 irt_small()
        {
            array<ModT, 8> lanes;
            lanes.fill(ModT(1));

            if constexpr (L == 2)
            {
                lanes[3] = lanes[7] = irt[2];
            }
            else if constexpr (L == 4)
            {
                int lane = 5;

                while (lane < 8)
                {
                    lanes[lane] = lanes[lane - 1] * irt[3];
                    ++lane;
                }
            }

            return lanes;
        }
    };
    /// In-place forward transform of `f0`, eight residues per AVX2 register.
    ///
    /// The array is viewed as n = size / 8 vectors. Radix-2 butterflies are
    /// first applied between whole vectors; the last three levels are then
    /// done inside each vector with lane negation and shuffles.
    /// Requires f0.size() to be a multiple of 16 (asserted); `aligned` selects
    /// aligned or unaligned loads and stores.
    template <montgomery_modint_concept ModT, bool aligned>
    static void ntt_classical_avx(span<ModT> f0)
    { // dif
        using X8 = simd::M32x8<ModT>;
        static NttClassicalInfoAvx<ModT> info;
        int n8 = f0.size(), n = n8 / 8;
        assert(n8 % 16 == 0);
        span<simd::I256> f{(simd::I256 *)f0.data(), u32(n)};
        static X8 rt2 = info.template rt_small<2>();
        static X8 rt4 = info.template rt_small<4>();

        for (int l = n / 2; l >= 1 * 1; l /= 2)
        {
            X8 r = X8::from(ModT(1));

            for (int i = 0, k = 0; i < n; i += l * 2, ++k)
            {
                for (int j = 0; j < l; ++j)
                {
                    X8 fx = X8::template load<aligned>(&f[i + j]);
                    X8 fy = X8::template load<aligned>(&f[i + j + l]) * r;
                    X8 rx = fx + fy;
                    X8 ry = fx - fy;
                    rx.template store<aligned>(&f[i + j]);
                    ry.template store<aligned>(&f[i + j + l]);
                }

                r *= info.rate2x8[countr_one<u32>(k)];
            }
        }

        X8 rti = X8::from(ModT(1));

        for (int i = 0; i < n; ++i)
        {
            X8 fi = X8::template load<aligned>(&f[i]);
            fi *= rti;
            // Butterfly between lanes 4 apart (the two 128-bit halves).
            fi = fi.template neg<0b11110000>() + fi.template shufflex4<0b01>();
            fi *= rt4;
            // Butterfly between lanes 2 apart.
            fi = fi.template neg<0b11001100>() + fi.template shuffle<0b01001110>();
            fi *= rt2;
            // Butterfly between adjacent lanes.
            fi = fi.template neg<0b10101010>() + fi.template shuffle<0b10110001>();
            fi.template store<aligned>(&f[i]);

            // Advance only when another vector follows: for the last one the
            // index reaches log2(n), which is past the end of rate4ix8 when
            // f0.size() == 2^rank2.
            if (i + 1 < n)
                rti *= info.rate4ix8[countr_one<u32>(i)];
        }
    }
    /// In-place inverse transform of `f0`, eight residues per AVX2 register.
    ///
    /// The three innermost levels are undone inside each vector first, then
    /// radix-2 butterflies are applied between whole vectors, and finally
    /// every element is multiplied by the inverse of f0.size().
    /// Requires f0.size() to be a multiple of 16 (asserted); `aligned` selects
    /// aligned or unaligned loads and stores.
    template <montgomery_modint_concept ModT, bool aligned>
    static void intt_classical_avx(span<ModT> f0)
    { // dit
        using X8 = simd::M32x8<ModT>;
        static NttClassicalInfoAvx<ModT> info;
        int n8 = f0.size(), n = n8 / 8;
        assert(n8 % 16 == 0);
        span<simd::I256> f{(simd::I256 *)f0.data(), u32(n)};
        static X8 rt2 = info.template irt_small<2>();
        static X8 rt4 = info.template irt_small<4>();
        X8 rti = X8::from(ModT(1));

        for (int i = 0; i < n; ++i)
        {
            X8 fi = X8::template load<aligned>(&f[i]);
            // Butterfly between adjacent lanes.
            fi = fi.template neg<0b10101010>() + fi.template shuffle<0b10110001>();
            fi *= rt2;
            // Butterfly between lanes 2 apart.
            fi = fi.template neg<0b11001100>() + fi.template shuffle<0b01001110>();
            fi *= rt4;
            // Butterfly between lanes 4 apart (the two 128-bit halves).
            fi = fi.template neg<0b11110000>() + fi.template shufflex4<0b01>();
            fi *= rti;
            fi.template store<aligned>(&f[i]);

            // Advance only when another vector follows: for the last one the
            // index reaches log2(n), which is past the end of irate4ix8 when
            // f0.size() == 2^rank2.
            if (i + 1 < n)
                rti *= info.irate4ix8[countr_one<u32>(i)];
        }

        for (i64 l = 1; l < n; l *= 2)
        {
            X8 r = X8::from(ModT(1));

            for (int i = 0, k = 0; i < n; i += l * 2, ++k)
            {
                for (int j = 0; j < l; ++j)
                {
                    X8 fx = X8::template load<aligned>(&f[i + j]);
                    X8 fy = X8::template load<aligned>(&f[i + j + l]);
                    X8 rx = fx + fy;
                    X8 ry = r * (fx - fy);
                    rx.template store<aligned>(&f[i + j]);
                    ry.template store<aligned>(&f[i + j + l]);
                }

                r *= info.irate2x8[countr_one<u32>(k)];
            }
        }

        // Scale by 1 / (number of scalar elements).
        X8 ivn8 = X8::from(ModT(n8).inv());

        for (int i = 0; i < n; ++i)
        {
            X8 fi = X8::template load<aligned>(&f[i]);
            fi *= ivn8;
            fi.template store<aligned>(&f[i]);
        }
    }
} // namespace detail
#include <algorithm>
#include <bit>
#include <cassert>
#include <span>
#include <vector>
#include <iostream>
namespace detail
{
    /// Twiddle-factor tables used by ntt_classical_avx4 / intt_classical_avx4.
    ///
    /// Everything is derived from g = 3, which is assumed (not checked) to be a
    /// primitive root modulo P.
    template <montgomery_modint_concept ModT>
    struct NttClassicalInfoAvx4
    {
        using X8 = simd::M32x8<ModT>;
        using ValueT = typename ModT::ValueT;
        static constexpr ValueT P = ModT::mod();
        static constexpr ValueT g = 3;
        // Number of factors of two in P - 1.
        static constexpr int rank2 = countr_zero(P - 1);
        // rt[k] = g^((P - 1) / 2^k); irt[k] is its inverse.
        array<ModT, rank2 + 1> rt;
        array<ModT, rank2 + 1> irt;
        // rateN[i] = rt[i + N] * irt[N] * ... * irt[i + N - 1];
        // irateN[i] is the same with rt and irt swapped.
        array<ModT, max<int>(0, rank2 - 1)> rate2;
        array<ModT, max<int>(0, rank2 - 1)> irate2;
        array<ModT, max<int>(0, rank2 - 2)> rate3;
        array<ModT, max<int>(0, rank2 - 2)> irate3;
        array<ModT, max<int>(0, rank2 - 3)> rate4;
        array<ModT, max<int>(0, rank2 - 3)> irate4;
        // X8::from(...) of the scalar entries above (presumably a broadcast).
        array<X8, max<int>(0, rank2 - 1)> rate2x8;
        array<X8, max<int>(0, rank2 - 1)> irate2x8;
        array<X8, max<int>(0, rank2 - 2)> rate3x8;
        array<X8, max<int>(0, rank2 - 2)> irate3x8;
        // Lane j holds the j-th power of rate4[i] / irate4[i].
        array<X8, max<int>(0, rank2 - 3)> rate4ix8;
        array<X8, max<int>(0, rank2 - 3)> irate4ix8;

        /// Fills every table.
        constexpr NttClassicalInfoAvx4()
        {
            // Start from the root of the largest order and square downwards.
            rt[rank2] = ModT(g).pow((P - 1) >> rank2);
            irt[rank2] = rt[rank2].inv();

            for (int k = rank2 - 1; k >= 0; --k)
            {
                rt[k] = rt[k + 1] * rt[k + 1];
                irt[k] = irt[k + 1] * irt[k + 1];
            }

            // Offset 2.
            {
                ModT acc = 1, iacc = 1;

                for (u32 k = 0; k != rate2.size(); ++k)
                {
                    rate2[k] = acc * rt[k + 2];
                    rate2x8[k] = X8::from(rate2[k]);
                    irate2[k] = iacc * irt[k + 2];
                    irate2x8[k] = X8::from(irate2[k]);
                    acc *= irt[k + 2];
                    iacc *= rt[k + 2];
                }
            }

            // Offset 3.
            {
                ModT acc = 1, iacc = 1;

                for (u32 k = 0; k != rate3.size(); ++k)
                {
                    rate3[k] = acc * rt[k + 3];
                    rate3x8[k] = X8::from(rate3[k]);
                    irate3[k] = iacc * irt[k + 3];
                    irate3x8[k] = X8::from(irate3[k]);
                    acc *= irt[k + 3];
                    iacc *= rt[k + 3];
                }
            }

            // Offset 4, plus the per-lane power vectors.
            {
                ModT acc = 1, iacc = 1;

                for (u32 k = 0; k != rate4.size(); ++k)
                {
                    rate4[k] = acc * rt[k + 4];
                    irate4[k] = iacc * irt[k + 4];
                    acc *= irt[k + 4];
                    iacc *= rt[k + 4];

                    array<ModT, 8> powers, ipowers;

                    for (int e = 0; e < 8; ++e)
                    {
                        powers[e] = rate4[k].pow(e);
                        ipowers[e] = irate4[k].pow(e);
                    }

                    rate4ix8[k] = powers;
                    irate4ix8[k] = ipowers;
                }
            }
        }

        /// Lane multipliers applied between the in-register butterfly layers of
        /// the forward transform.
        /// L == 2: lanes 3 and 7 hold rt[2], all other lanes hold 1.
        /// L == 4: lanes 4..7 hold rt[3]^0 .. rt[3]^3, lanes 0..3 hold 1.
        /// Any other L: every lane holds 1.
        template <int L>
        X8 rt_small()
        {
            array<ModT, 8> lanes;

            for (ModT &lane : lanes)
                lane = 1;

            if constexpr (L == 2)
            {
                lanes[7] = rt[2];
                lanes[3] = lanes[7];
            }
            else if constexpr (L == 4)
            {
                lanes[5] = lanes[4] * rt[3];
                lanes[6] = lanes[5] * rt[3];
                lanes[7] = lanes[6] * rt[3];
            }

            return lanes;
        }

        /// Same layout as rt_small, built from the inverse roots irt[2] / irt[3].
        template <int L>
        X8 irt_small()
        {
            array<ModT, 8> lanes;

            for (ModT &lane : lanes)
                lane = 1;

            if constexpr (L == 2)
            {
                lanes[7] = irt[2];
                lanes[3] = lanes[7];
            }
            else if constexpr (L == 4)
            {
                lanes[5] = lanes[4] * irt[3];
                lanes[6] = lanes[5] * irt[3];
                lanes[7] = lanes[6] * irt[3];
            }

            return lanes;
        }
    };
    /// In-place forward NTT (decimation in frequency) on 8-lane AVX vectors,
    /// using radix-4 butterflies across vectors and three in-register layers.
    ///
    /// @tparam ModT     Montgomery modint element type.
    /// @tparam aligned  Forwarded to X8::load / X8::store (presumably selects
    ///                  32-byte aligned accesses -- confirm against simd::M32x8).
    /// @param  f0       Coefficients, overwritten with the transform. The length
    ///                  must be a multiple of 16 (asserted); ntt() additionally
    ///                  asserts a power of two before dispatching here.
    template <montgomery_modint_concept ModT, bool aligned>
    static void ntt_classical_avx4(span<ModT> f0)
    { // dif
        using X8 = simd::M32x8<ModT>;
        static NttClassicalInfoAvx4<ModT> info;
        // n8: scalars, n: 8-lane vectors, l: butterfly stride in vectors,
        // n_4b: 1 when log2(n) is odd, i.e. one radix-2 layer is needed before
        // the remaining layers can be paired into radix-4 steps.
        int n8 = f0.size(), n = n8 / 8, l = n / 2, n_4b = countr_zero<u32>(n) & 1;
        assert(n8 % 16 == 0);
        span<simd::I256> f{(simd::I256 *)f0.data(), u32(n)};
        static X8 rt2 = info.template rt_small<2>();
        static X8 rt4 = info.template rt_small<4>();

        if (n_4b)
        {
            // Top radix-2 layer: a single block, so no twiddle multiplication.
            for (int j = 0; j < l; ++j)
            {
                X8 fx = X8::template load<aligned>(&f[j]);
                X8 fy = X8::template load<aligned>(&f[j + l]);
                X8 rx = fx + fy;
                X8 ry = fx - fy;
                rx.template store<aligned>(&f[j]);
                ry.template store<aligned>(&f[j + l]);
            }

            l /= 2;
        }

        // Radix-4 layers; l is the quarter-block size in vectors.
        for (l /= 2; l >= 1; l /= 4)
        {
            X8 r = X8::from(ModT(1)), img = X8::from(info.rt[2]);

            for (int i = 0, k = 0; i < n; i += l * 4, ++k)
            {
                X8 r2 = r * r, r3 = r2 * r;

                for (int j = 0; j < l; ++j)
                {
                    X8 x0 = X8::template load<aligned>(&f[i + j + 0 * l]);
                    X8 x1 = X8::template load<aligned>(&f[i + j + 1 * l]) * r;
                    X8 x2 = X8::template load<aligned>(&f[i + j + 2 * l]) * r2;
                    X8 x3 = X8::template load<aligned>(&f[i + j + 3 * l]) * r3;
                    X8 x1x3 = (x1 - x3) * img;
                    X8 y0 = x0 + x2 + x1 + x3;
                    X8 y1 = x0 + x2 - x1 - x3;
                    X8 y2 = x0 - x2 + x1x3;
                    X8 y3 = x0 - x2 - x1x3;
                    y0.template store<aligned>(&f[i + j + 0 * l]);
                    y1.template store<aligned>(&f[i + j + 1 * l]);
                    y2.template store<aligned>(&f[i + j + 2 * l]);
                    y3.template store<aligned>(&f[i + j + 3 * l]);
                }

                // Advance the per-block twiddle by the rate entry selected by
                // the number of trailing one bits of the block counter.
                r *= info.rate3x8[countr_one<u32>(k)];
            }
        }

        X8 rti = X8::from(ModT(1));

        // Last three layers inside every vector (lane distance 4, 2, 1).
        // NOTE(review): neg<mask> presumably negates the lanes selected by mask
        // and shuffle / shufflex4 presumably fetch the partner lanes -- confirm.
        for (int i = 0; i < n; ++i)
        {
            X8 fi = X8::template load<aligned>(&f[i]);
            fi *= rti;
            fi = fi.template neg<0b11110000>() + fi.template shufflex4<0b01>();
            fi *= rt4;
            fi = fi.template neg<0b11001100>() + fi.template shuffle<0b01001110>();
            fi *= rt2;
            fi = fi.template neg<0b10101010>() + fi.template shuffle<0b10110001>();
            fi.template store<aligned>(&f[i]);

            // rate4ix8 has rank2 - 3 entries. For the largest length
            // (n == 2^(rank2 - 3)) the last vector has rank2 - 3 trailing one
            // bits, which would index one past the end. That final update is
            // never consumed, so it is skipped.
            if (i + 1 < n)
                rti *= info.rate4ix8[countr_one<u32>(i)];
        }
    }
    /// In-place inverse NTT (decimation in time) on 8-lane AVX vectors: three
    /// in-register layers first, then radix-4 butterflies across vectors, then
    /// scaling by the inverse of the length.
    ///
    /// @tparam ModT     Montgomery modint element type.
    /// @tparam aligned  Forwarded to X8::load / X8::store (presumably selects
    ///                  32-byte aligned accesses -- confirm against simd::M32x8).
    /// @param  f0       Coefficients, overwritten with the result. The length
    ///                  must be a multiple of 16 (asserted); intt() additionally
    ///                  asserts a power of two before dispatching here.
    template <montgomery_modint_concept ModT, bool aligned>
    static void intt_classical_avx4(span<ModT> f0)
    { // dit
        using X8 = simd::M32x8<ModT>;
        static NttClassicalInfoAvx4<ModT> info;
        // n8: scalars, n: 8-lane vectors, l: butterfly stride in vectors,
        // n_4b: 1 when log2(n) is odd, i.e. one radix-2 layer is left over
        // after the radix-4 layers.
        int n8 = f0.size(), n = n8 / 8, l = 1, n_4b = countr_zero<u32>(n) & 1;
        assert(n8 % 16 == 0);
        span<simd::I256> f{(simd::I256 *)f0.data(), u32(n)};
        static X8 rt2 = info.template irt_small<2>();
        static X8 rt4 = info.template irt_small<4>();
        X8 rti = X8::from(ModT(1));

        // First three layers inside every vector (lane distance 1, 2, 4).
        // NOTE(review): neg<mask> presumably negates the lanes selected by mask
        // and shuffle / shufflex4 presumably fetch the partner lanes -- confirm.
        for (int i = 0; i < n; ++i)
        {
            X8 fi = X8::template load<aligned>(&f[i]);
            fi = fi.template neg<0b10101010>() + fi.template shuffle<0b10110001>();
            fi *= rt2;
            fi = fi.template neg<0b11001100>() + fi.template shuffle<0b01001110>();
            fi *= rt4;
            fi = fi.template neg<0b11110000>() + fi.template shufflex4<0b01>();
            fi *= rti;
            fi.template store<aligned>(&f[i]);

            // irate4ix8 has rank2 - 3 entries. For the largest length
            // (n == 2^(rank2 - 3)) the last vector has rank2 - 3 trailing one
            // bits, which would index one past the end. That final update is
            // never consumed, so it is skipped.
            if (i + 1 < n)
                rti *= info.irate4ix8[countr_one<u32>(i)];
        }

        // Radix-4 layers; stop one layer early when a radix-2 layer remains.
        for (; l < (n_4b ? n / 2 : n); l *= 4)
        {
            X8 r = X8::from(ModT(1)), img = X8::from(info.irt[2]);

            for (int i = 0, k = 0; i < n; i += l * 4, ++k)
            {
                X8 r2 = r * r, r3 = r2 * r;

                for (int j = 0; j < l; ++j)
                {
                    X8 x0 = X8::template load<aligned>(&f[i + j + 0 * l]);
                    X8 x1 = X8::template load<aligned>(&f[i + j + 1 * l]);
                    X8 x2 = X8::template load<aligned>(&f[i + j + 2 * l]);
                    X8 x3 = X8::template load<aligned>(&f[i + j + 3 * l]);
                    X8 x2x3 = (x2 - x3) * img;
                    X8 y0 = x0 + x1 + x2 + x3;
                    X8 y1 = (x0 - x1 + x2x3) * r;
                    X8 y2 = (x0 + x1 - x2 - x3) * r2;
                    X8 y3 = (x0 - x1 - x2x3) * r3;
                    y0.template store<aligned>(&f[i + j + 0 * l]);
                    y1.template store<aligned>(&f[i + j + 1 * l]);
                    y2.template store<aligned>(&f[i + j + 2 * l]);
                    y3.template store<aligned>(&f[i + j + 3 * l]);
                }

                // Advance the per-block twiddle by the rate entry selected by
                // the number of trailing one bits of the block counter.
                r *= info.irate3x8[countr_one<u32>(k)];
            }
        }

        if (n_4b)
        {
            // Remaining radix-2 layer; here l == n / 2, a single block, so no
            // twiddle multiplication is applied.
            for (int j = 0; j < l; ++j)
            {
                X8 fx = X8::template load<aligned>(&f[j]);
                X8 fy = X8::template load<aligned>(&f[j + l]);
                X8 rx = fx + fy;
                X8 ry = fx - fy;
                rx.template store<aligned>(&f[j]);
                ry.template store<aligned>(&f[j + l]);
            }
        }

        // Final scaling by the inverse of the transform length.
        X8 ivn8 = X8::from(ModT(n8).inv());

        for (int i = 0; i < n; ++i)
        {
            X8 fi = X8::template load<aligned>(&f[i]);
            fi *= ivn8;
            fi.template store<aligned>(&f[i]);
        }
    }
}
/// Forward-transform dispatcher: picks the kernel for the element type, the
/// length and the alignment of the buffer.
template <static_modint_concept ModT>
void ntt_classical(span<ModT> f)
{
    if constexpr (!montgomery_modint_concept<ModT>)
    {
        // Raw 32-bit and every other modint flavour share the scalar kernel.
        detail::ntt_classical_basic4(f);
    }
    else
    {
        if (f.size() >= 16)
        {
            // The low five address bits are zero exactly when the buffer
            // starts on a 32-byte boundary.
            const bool on_32_byte_boundary = (u64(f.data()) & 0x1f) == 0;

            if (on_32_byte_boundary)
                detail::ntt_classical_avx4<ModT, true>(f);
            else
                detail::ntt_classical_avx4<ModT, false>(f);
        }
        else
        {
            // Shorter than 16 elements: use the scalar kernel.
            detail::ntt_classical_basic4(f);
        }
    }
}
/// Inverse-transform dispatcher: picks the kernel for the element type, the
/// length and the alignment of the buffer.
template <static_modint_concept ModT>
void intt_classical(span<ModT> f)
{
    if constexpr (!montgomery_modint_concept<ModT>)
    {
        // Raw 32-bit and every other modint flavour share the scalar kernel.
        detail::intt_classical_basic4(f);
    }
    else
    {
        if (f.size() >= 16)
        {
            // The low five address bits are zero exactly when the buffer
            // starts on a 32-byte boundary.
            const bool on_32_byte_boundary = (u64(f.data()) & 0x1f) == 0;

            if (on_32_byte_boundary)
                detail::intt_classical_avx4<ModT, true>(f);
            else
                detail::intt_classical_avx4<ModT, false>(f);
        }
        else
        {
            // Shorter than 16 elements: use the scalar kernel.
            detail::intt_classical_basic4(f);
        }
    }
}
/// Forward NTT entry point: asserts a power-of-two length, adds the length to
/// detail::ntt_size and runs the classical kernel in place.
template <static_modint_concept ModT>
void ntt(span<ModT> f)
{
    const auto len = f.size();
    assert(has_single_bit<u32>(len));
    detail::ntt_size += len;
    // Alternative kernel, currently disabled: ntt_twisted(f);
    ntt_classical(f);
}
/// Inverse NTT entry point: asserts a power-of-two length, adds the length to
/// detail::ntt_size and runs the classical inverse kernel in place.
template <static_modint_concept ModT>
void intt(span<ModT> f)
{
    const auto len = f.size();
    assert(has_single_bit<u32>(len));
    detail::ntt_size += len;
    // Alternative kernel, currently disabled: intt_twisted(f);
    intt_classical(f);
}
#include <span>
/// Scalar element-wise product into a separate buffer:
/// dst[i] = f[i] * g[i] for every index of dst.
/// f and g are indexed up to dst.size() without a bounds check.
template <static_modint_concept ModT>
static void dot_basic(span<ModT> f, span<const ModT> g, span<ModT> dst)
{
    const u32 count = dst.size();
    u32 idx = 0;

    while (idx < count)
    {
        dst[idx] = f[idx] * g[idx];
        ++idx;
    }
}
/// Scalar in-place element-wise product: f[i] *= g[i] for every index of f.
/// g is indexed up to f.size() without a bounds check.
template <static_modint_concept ModT>
static void dot_basic(span<ModT> f, span<const ModT> g)
{
    const u32 count = f.size();
    u32 idx = 0;

    while (idx < count)
    {
        f[idx] *= g[idx];
        ++idx;
    }
}
/// In-place element-wise product f[i] *= g[i], eight elements per step through
/// simd::M32x8, with a scalar loop for the remaining tail.
/// g is indexed up to f.size() without a bounds check.
template <montgomery_modint_concept ModT>
static void dot_avx(span<ModT> f, span<const ModT> g)
{
    using X8 = simd::M32x8<ModT>;
    const u32 total = f.size();
    u32 pos = 0;

    // Full groups of eight.
    while (pos + 7 < total)
    {
        X8 lhs = X8::load((simd::I256 *)&f[pos]);
        X8 rhs = X8::load((simd::I256 *)&g[pos]);
        lhs *= rhs;
        lhs.store((simd::I256 *)&f[pos]);
        pos += 8;
    }

    // Fewer than eight elements left.
    while (pos < total)
    {
        f[pos] *= g[pos];
        ++pos;
    }
}
/// Element-wise product into a separate buffer: dst[i] = f[i] * g[i], eight
/// elements per step through simd::M32x8, with a scalar loop for the tail.
/// f and g are indexed up to dst.size() without a bounds check.
///
/// The first parameter is a span of scalars, matching dot_basic and the call
/// in dot(): the body indexes f in scalar units and multiplies f[i] by g[i],
/// which a span of simd::I256 cannot provide.
template <montgomery_modint_concept ModT>
static void dot_avx(span<ModT> f, span<const ModT> g, span<ModT> dst)
{
    u32 n = dst.size();
    u32 i = 0;
    using X8 = simd::M32x8<ModT>;

    // Full groups of eight.
    for (; i + 7 < n; i += 8)
    {
        X8 fi = X8::load((simd::I256 *)&f[i]);
        X8 gi = X8::load((simd::I256 *)&g[i]);
        X8 di = fi * gi;
        di.store((simd::I256 *)&dst[i]);
    }

    // Fewer than eight elements left.
    for (; i < n; i++)
        dst[i] = f[i] * g[i];
}
/// Element-wise product dst[i] = f[i] * g[i]; Montgomery modints take the AVX
/// kernel, every other modint type the scalar one.
template <static_modint_concept ModT>
static void dot(span<ModT> f, span<const ModT> g, span<ModT> dst)
{
    if constexpr (!montgomery_modint_concept<ModT>)
        dot_basic(f, g, dst);
    else
        dot_avx(f, g, dst);
}
/// In-place element-wise product f[i] *= g[i]; Montgomery modints take the AVX
/// kernel, every other modint type the scalar one.
template <static_modint_concept ModT>
static void dot(span<ModT> f, span<const ModT> g)
{
    if constexpr (!montgomery_modint_concept<ModT>)
    {
        dot_basic(f, g);
    }
    else
    {
        dot_avx(f, g);
    }
}
using Space = MontgomerySpace<u32, 998244353>;
using ModT = StaticModint<Space>;
main()
{
    FastI fin(stdin);
    FastO fout(stdout);
    u32 n, m;
    fin >> n >> m;
    ++n, ++m;
    u32 L = bit_ceil(n + m - 1);
    ModT *f = new (align_val_t(32)) ModT[L];
    ModT *g = new (align_val_t(32)) ModT[L];
    for (int i = 0; i < n; ++i)
        fin >> f[i];
    for (int i = 0; i < m; ++i)
        fin >> g[i];
    ntt<ModT>({f, L}), ntt<ModT>({g, L});
    dot<ModT>({f, L}, {g, L});
    intt<ModT>({f, L});
    for (int i = 0; i < n + m - 1; ++i)
        fout << f[i] << ' ';
}

Compilation | N/A | N/A | Compile Error | Score: N/A


Judge Duck Online | 评测鸭在线
Server Time: 2025-09-14 23:07:40 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠