提交记录 28647


用户 题目 状态 得分 用时 内存 语言 代码长度
jrjyy 1002. 测测你的多项式乘法 Accepted 100 152.127 ms 48056 KB C++17 21.46 KB
提交时间 评测时间
2025-10-19 18:12:51 2025-10-19 18:12:55
#include <bits/stdc++.h>

using namespace std;

using u32 = unsigned;
using i64 = long long;
using u64 = unsigned long long;

using i128 = __int128;
using u128 = unsigned __int128;

template <typename T>
constexpr T power(T a, u64 b, T res = T(1)) {
    // Binary exponentiation: multiplies a^b into `res` (default 1) and returns
    // it. Works for any T with `*=` (modints, IComplexZ, plain integers).
    for (; b != 0; b >>= 1, a *= a) {
        if (b & 1) {
            res *= a;
        }
    }
    return res;
}

// (a + b) mod P for residues a, b in [0, P). Requires P < 2^(bits(U) - 1):
// the sign bit of (sum - P) then tells whether sum >= P.
template <auto P, typename U = decltype(P)>
constexpr enable_if_t<is_unsigned<U>::value, U> addMod(U a, U b) {
    const U sum = a + b;
    return make_signed_t<U>(sum - P) >= 0 ? U(sum - P) : sum;
}
// (a - b) mod P for residues a, b in [0, P). A wrapped (negative when read as
// signed) difference is corrected by adding P once.
template <auto P, typename U = decltype(P)>
constexpr enable_if_t<is_unsigned<U>::value, U> subMod(U a, U b) {
    const U diff = a - b;
    return make_signed_t<U>(diff) < 0 ? U(diff + P) : diff;
}
// a * b mod P for 32-bit residues, via a 64-bit intermediate product.
template <u32 P>
constexpr u32 mulMod(u32 a, u32 b) {
    const u64 wide = static_cast<u64>(a) * b;
    return static_cast<u32>(wide % P);
}
// a * b mod P for 64-bit residues, via an exact 128-bit product.
// (A long-double quotient estimate is a possible faster alternative; the
// exact u128 path is what is used.)
template <u64 P>
constexpr u64 mulMod(u64 a, u64 b) {
    const u128 wide = static_cast<u128>(a) * b;
    return static_cast<u64>(wide % P);
}
// x mod m mapped into [0, m) (C++ `%` keeps the sign of x).
constexpr i64 safeMod(i64 x, i64 m) {
    const i64 r = x % m;
    return r < 0 ? r + m : r;
}

// Extended Euclid. Returns {g, c} where g = gcd(a, b) and c is the unique value
// in [0, b / g) with a * c == g (mod b); for a == 0 (mod b) returns {b, 0}.
constexpr pair<i64, i64> gcdInv(i64 a, i64 b) {
    const i64 r = safeMod(a, b);
    if (r == 0) {
        return {b, 0};
    }

    // Invariant: curCoef * r == cur (mod b), prevCoef * r == prev (mod b).
    i64 prev = b, cur = r;
    i64 prevCoef = 0, curCoef = 1;
    while (cur != 0) {
        const i64 q = prev / cur;
        const i64 nextRem = prev - cur * q;
        const i64 nextCoef = prevCoef - curCoef * q;
        prev = cur;
        cur = nextRem;
        prevCoef = curCoef;
        curCoef = nextCoef;
    }

    if (prevCoef < 0) {
        prevCoef += b / prev;
    }
    return {prev, prevCoef};
}

// Integer modulo the compile-time constant P, stored in the unsigned type U.
// Every constructor reduces its argument, so x is kept in [0, P); fromNorm()
// skips the reduction and trusts the caller.
// The static_assert requires 2 * P not to wrap in U (P < 2^(bits(U) - 1)),
// which addMod/subMod rely on.
template <typename U, U P, enable_if_t<is_unsigned<U>::value, int> = 0>
class ModIntBase {
    static_assert(P * 2 > P, "ModIntBase: mod is too large");
    // U x;
  public:
    // Residue in [0, P). Public so callers such as the NTT kernels and the
    // fast reader can access the raw value directly.
    U x;
    constexpr ModIntBase() : x{0} {}
    // Unsigned input: plain remainder.
    template <typename T, enable_if_t<is_unsigned<T>::value, int> = 0>
    constexpr ModIntBase(T x_) : x(x_ % mod()) {}
    // Signed input: remainder corrected into [0, P) (e.g. -1 becomes P - 1).
    template <typename T, enable_if_t<is_signed<T>::value, int> = 0>
    constexpr ModIntBase(T x_)
        : x([&]() {
              using S = make_signed_t<U>;
              S v = x_ % S(mod());
              if (v < 0)
                  v += mod();
              return U(v);
          }()) {}
    // Wrap a value already known to lie in [0, P) without reducing it.
    static constexpr ModIntBase fromNorm(U x) {
        ModIntBase v;
        v.x = x;
        return v;
    }
    constexpr static U mod() {
        return P;
    }
    constexpr U val() const {
        return x;
    }

    inline constexpr ModIntBase operator-() const {
        ModIntBase res;
        res.x = x == 0 ? x : mod() - x;
        return res;
    }
    // Multiplicative inverse via extended Euclid; asserts gcd(x, P) == 1.
    inline constexpr ModIntBase inv() const {
        auto v = gcdInv(x, mod());
        assert(v.first == 1);
        return ModIntBase::fromNorm(v.second);
    }
    inline constexpr ModIntBase &operator+=(ModIntBase rhs) {
        x = addMod<mod()>(x, rhs.val());
        return *this;
    }
    inline constexpr ModIntBase &operator-=(ModIntBase rhs) {
        x = subMod<mod()>(x, rhs.val());
        return *this;
    }
    inline constexpr ModIntBase &operator*=(ModIntBase rhs) {
        x = mulMod<mod()>(x, rhs.val());
        return *this;
    }
    // Division multiplies by rhs.inv(), so rhs must be coprime to P.
    inline constexpr ModIntBase &operator/=(ModIntBase rhs) {
        return *this *= rhs.inv();
    }
    friend inline constexpr ModIntBase operator+(ModIntBase lhs, ModIntBase rhs) {
        return lhs += rhs;
    }
    friend inline constexpr ModIntBase operator-(ModIntBase lhs, ModIntBase rhs) {
        return lhs -= rhs;
    }
    friend inline constexpr ModIntBase operator*(ModIntBase lhs, ModIntBase rhs) {
        return lhs *= rhs;
    }
    friend inline constexpr ModIntBase operator/(ModIntBase lhs, ModIntBase rhs) {
        return lhs /= rhs;
    }
    friend inline constexpr bool operator==(ModIntBase lhs, ModIntBase rhs) {
        return lhs.val() == rhs.val();
    }
    friend inline constexpr bool operator!=(ModIntBase lhs, ModIntBase rhs) {
        return lhs.val() != rhs.val();
    }
    // Reads a signed 64-bit integer and reduces it with the signed constructor.
    friend constexpr istream &operator>>(istream &is, ModIntBase &a) {
        i64 x = 0;
        is >> x;
        a = x;
        return is;
    }
    friend constexpr ostream &operator<<(ostream &os, const ModIntBase &a) {
        return os << a.val();
    }
};

// Modular integer backed by 32-bit storage, for a compile-time modulus P.
template <u32 P>
using ModInt = ModIntBase<u32, P>;
// Modular integer backed by 64-bit storage (products go through u128).
template <u64 P>
using ModInt64 = ModIntBase<u64, P>;

// Lazily grown tables modulo Z::mod():
// fac_[i] = i!, ifac_[i] = 1 / i!, inv_[i] = 1 / i (inv_[0] = 0),
// all valid for indices 0..n.
template <typename Z>
struct IComb {
    int n;
    vector<Z> fac_, ifac_, inv_;

    IComb() : n(0), fac_{1}, ifac_{1}, inv_{0} {}
    IComb(int n) : IComb{} {
        init(n);
    }
    // One shared instance per Z.
    static IComb &instance() {
        static IComb comb;
        return comb;
    }
    // Extend all tables to cover indices 0..m; no-op if they already do.
    // Requires m < Z::mod() (m! is 0 and has no inverse beyond that).
    void init(int m) {
        if (m <= n) {
            return;
        }
        assert(u32(m) < Z::mod());
        fac_.resize(m + 1);
        ifac_.resize(m + 1);
        inv_.resize(m + 1);

        for (int i = n + 1; i <= m; ++i) {
            fac_[i] = fac_[i - 1] * i;
        }
        // One modular inversion, then walk the inverse factorials downwards.
        ifac_[m] = fac_[m].inv();
        for (int i = m; i > n + 1; --i) {
            ifac_[i - 1] = ifac_[i] * i;
        }
        // 1 / i = (i - 1)! / i!
        for (int i = m; i > n; --i) {
            inv_[i] = ifac_[i] * fac_[i - 1];
        }
        n = m;
    }
    // Make index m valid. Grows to 2m to amortise repeated growth, but never
    // beyond Z::mod() - 1 (the largest index init() accepts) or INT_MAX, so a
    // valid m near a small modulus no longer trips init()'s assertion. An m that
    // is itself >= Z::mod() still reaches init(m) and asserts.
    void ensure(int m) {
        assert(m >= 0);
        if (n < m) {
            const i64 cap = min<i64>({2LL * m, i64(Z::mod()) - 1, i64(numeric_limits<int>::max())});
            init(int(max<i64>(m, cap)));
        }
    }

    Z fac(int m) {
        ensure(m);
        return fac_[m];
    }
    Z ifac(int m) {
        ensure(m);
        return ifac_[m];
    }
    Z inv(int m) {
        ensure(m);
        return inv_[m];
    }
    // Binomial coefficient C(n, m); 0 when m < 0 or m > n.
    Z binom(int n, int m) {
        if (n < m || m < 0) {
            return Z{};
        }
        return fac(n) * ifac(m) * ifac(n - m);
    }
};

template <typename Z>
struct IComplexZ {
    static Z i2;
    Z a, b;
    IComplexZ(Z a_ = 1) : IComplexZ(a_, 0) {}
    IComplexZ(Z a_, Z b_) : a(a_), b(b_) {}
    IComplexZ &operator*=(const IComplexZ &y) {
        *this = IComplexZ(a * y.a + i2 * b * y.b, a * y.b + b * y.a);
        return *this;
    }
};
template <typename Z>
Z IComplexZ<Z>::i2;

// Square roots of n modulo Z::mod() (assumed prime) by Cipolla's algorithm.
// Returns {} when n is a non-residue, {0} for n == 0, {1} when the modulus is 2,
// otherwise the roots sorted by val() (one entry if both roots coincide).
template <typename Z>
vector<Z> sqrtMod(Z n) {
    static mt19937 rnd(random_device{}());
    const auto p = Z::mod();
    const auto half = (p - 1) / 2;
    if (n == 0) {
        return {0};
    }
    if (p == 2) {
        return {1};
    }
    // Euler's criterion: n is a residue iff n^((p-1)/2) == 1.
    if (power(n, half) != 1) {
        return {};
    }
    // Pick random a until a^2 - n is not a nonzero residue.
    Z a;
    for (;;) {
        a = rnd() % (p - 1) + 1;
        if (power(a * a - n, half) != 1) {
            break;
        }
    }
    IComplexZ<Z>::i2 = a * a - n;
    Z lo = power(IComplexZ<Z>(a, 1), (p + 1) / 2).a;
    Z hi = p - lo;
    if (lo.val() > hi.val()) {
        swap(lo, hi);
    }
    return lo == hi ? vector<Z>{lo} : vector<Z>{lo, hi};
}

namespace Polynomial {
// Smallest power of two >= x (1 for x <= 1).
constexpr int bitceil(int x) {
    assert(x >= 0);
    if (x <= 1) {
        return 1;
    }
    return 2 << __lg(x - 1);
}

// Per-field twiddle table, grown on demand by initRoots; starts as {1}.
template <typename Z>
vector<Z> roots(1, Z(1));

// Returns the smallest quadratic non-residue g >= 2 modulo P = Z::mod().
//
// g is not necessarily a primitive root of the whole multiplicative group,
// but it is all initRoots needs. Write P - 1 = q * 2^v with q odd. Euler's
// criterion gives g^((P-1)/2) = -1, so the 2-part of ord(g) is exactly 2^v.
// Hence g^((P-1)/(2l)) has order exactly 2l for every power of two 2l that
// divides P - 1.
// (No further exponentiation is needed: the former (P-1) >> __lg(P-1) is
// always 1 because __lg is floor(log2), so the returned value is unchanged.)
template <typename Z>
constexpr Z findPrimitiveRoot() {
    for (Z i = 2;; i += 1) {
        if (power(i, (Z::mod() - 1) / 2) != 1) {
            return i;
        }
    }
}

// Generator used for all NTT twiddle factors of field Z.
template <typename Z>
constexpr Z primitiveRoot = findPrimitiveRoot<Z>();

template <typename Z>
void initRoots(int n) {
    assert((n & -n) == n && ((Z::mod() - 1) & (n - 1)) == 0);
    if (int(roots<Z>.size()) >= n) {
        return;
    }
    roots<Z>.reserve(n);
    while (int(roots<Z>.size()) < n) {
        int l = int(roots<Z>.size());
        roots<Z>.resize(2 * l);
        auto w = roots<Z>.begin() + l;
        w[0] = 1;
        Z x = power(primitiveRoot<Z>, (Z::mod() - 1) / l / 2);
        for (int i = 1; i < l; ++i) {
            w[i] = w[i - 1] * x;
        }
    }
}

// Decimation-in-frequency butterfly on raw residues:
// (x, y) -> (x + y, (x - y) * w) mod P.
template <auto P, typename U = decltype(P)>
inline void applyDft(U &x, U &y, U w) {
    const U diff = subMod<P>(x, y);
    x = addMod<P>(x, y);
    y = mulMod<P>(diff, w);
}
// Decimation-in-time butterfly on raw residues:
// (x, y) -> (x + y * w, x - y * w) mod P.
template <auto P, typename U = decltype(P)>
inline void applyIdft(U &x, U &y, U w) {
    const U twisted = mulMod<P>(y, w);
    const U base = x;
    x = addMod<P>(base, twisted);
    y = subMod<P>(base, twisted);
}
// Apply one butterfly layer of half-length l over the whole array: for each
// chunk start i (step 2l) and each j in [0, l) it calls
//     F(a[i + j].x, a[i + j + l].x, roots<Z>[l + j].x)
// on the raw residues. l = 1, 2, 4 are hand-specialised with fixed twiddle
// indices and unroll pragmas (presumably for speed on the short inner loops);
// they compute exactly the same calls as the general branch.
// Assumes a.size() is a multiple of 2l and roots<Z> has at least 2l entries.
template <typename Z, typename U, void F(U &x, U &y, U w),
    typename enable_if<is_unsigned<U>::value, int>::type = 0>
inline void applyTrans(vector<Z> &a, int l) {
    //     for (auto p = a.begin(), q = p + 2 * l; p != a.end(); p = q, q += 2 * l) {
    // #pragma GCC unroll 4
    //         for (auto i = p, j = p + l, w = roots<Z>.begin() + l; j != q; ++i, ++j, ++w) {
    //             F(i->x, j->x, w->x);
    //         }
    //     }
    if (l == 1) {
#pragma GCC unroll 8
        for (int i = 0; i < int(a.size()); i += 2) {
            F(a[i].x, a[i + 1].x, roots<Z>[1].x);
        }
    } else if (l == 2) {
#pragma GCC unroll 8
        for (int i = 0; i < int(a.size()); i += 4) {
            F(a[i].x, a[i + 2].x, roots<Z>[2].x);
            F(a[i + 1].x, a[i + 3].x, roots<Z>[3].x);
        }
    } else if (l == 4) {
#pragma GCC unroll 4
        for (int i = 0; i < int(a.size()); i += 8) {
            F(a[i].x, a[i + 4].x, roots<Z>[4].x);
            F(a[i + 1].x, a[i + 5].x, roots<Z>[5].x);
            F(a[i + 2].x, a[i + 6].x, roots<Z>[6].x);
            F(a[i + 3].x, a[i + 7].x, roots<Z>[7].x);
        }
    } else {
        // General case: same butterfly pattern with the twiddle taken from
        // roots<Z>[l + j].
        for (int i = 0; i < int(a.size()); i += 2 * l) {
#pragma GCC unroll 8
            for (int j = 0; j < l; ++j) {
                F(a[i + j].x, a[i + j + l].x, roots<Z>[l + j].x);
            }
        }
    }
}

// Forward NTT in place (a.size() must be a power of two dividing 2^v).
// Decimation in frequency with no bit-reversal pass, so the output is in
// bit-reversed order; idft consumes exactly that order.
template <typename Z, typename U = decltype(Z::mod())>
void dft(vector<Z> &a) {
    const int len = int(a.size());
    initRoots<Z>(len);
    int half = len / 2;
    while (half > 0) {
        applyTrans<Z, U, applyDft<Z::mod()>>(a, half);
        half /= 2;
    }
}
// Inverse of dft, in place: bit-reversed input, natural-order output.
// It runs the butterflies with the forward twiddles, so the index reversal of
// a[1..) turns w into w^-1. It then scales by 1/n.
template <typename Z, typename U = decltype(Z::mod())>
void idft(vector<Z> &a) {
    const int len = int(a.size());
    initRoots<Z>(len);
    for (int half = 1; half < len; half *= 2) {
        applyTrans<Z, U, applyIdft<Z::mod()>>(a, half);
    }
    reverse(a.begin() + 1, a.end());
    // (1 - P) / n is an exact integer (n | P - 1) and is congruent to 1/n mod P.
    const Z scale = (1 - make_signed_t<U>(Z::mod())) / len;
    for (Z &coef : a) {
        coef *= scale;
    }
}

// Dense polynomial / truncated formal power series over Z; element i is the
// coefficient of x^i.
// The series operations inv, log, exp, pow and sqrt return results truncated
// to this->size() terms. sqrt returns an empty IPoly when no square root exists.
template <typename Z>
struct IPoly : public vector<Z> {
    IPoly() = default;
    explicit IPoly(size_t n) : vector<Z>(n) {}
    explicit IPoly(const vector<Z> &a) : vector<Z>{a} {}
    IPoly(initializer_list<Z> a) : vector<Z>{a} {}
    template <typename InputIt, typename = _RequireInputIter<InputIt>>
    explicit IPoly(InputIt first, InputIt last) : vector<Z>(first, last) {}
    // Size-n polynomial whose i-th coefficient is f(i).
    template <typename F>
    explicit IPoly(size_t n, F &&f) : IPoly(n) {
        generate(this->begin(), this->end(), [&, i = 0]() mutable {
            return f(i++);
        });
    }

    // Multiply by x^k; for k < 0, drop the lowest -k coefficients instead.
    IPoly shift(int k) const {
        if (k >= 0) {
            auto f = *this;
            f.insert(f.begin(), k, 0);
            return f;
        } else if (int(this->size()) <= -k) {
            return IPoly{};
        } else {
            return IPoly(this->begin() + -k, this->end());
        }
    }
    // First k coefficients, zero-padded when k > size().
    IPoly trunc(int k) const {
        if (k < int(this->size())) {
            return IPoly(this->begin(), this->begin() + k);
        } else {
            auto f = *this;
            f.resize(k);
            return f;
        }
    }
    // Coefficient p, or 0 when p is out of range.
    Z get(int p) const {
        if (p < 0 || int(this->size()) <= p) {
            return 0;
        }
        return (*this)[p];
    }

    IPoly &operator+=(const IPoly &b) {
        if (this->size() < b.size()) {
            this->resize(b.size());
        }
        for (int i = 0; i < int(b.size()); ++i) {
            (*this)[i] += b[i];
        }
        return *this;
    }
    IPoly &operator-=(const IPoly &b) {
        if (this->size() < b.size()) {
            this->resize(b.size());
        }
        for (int i = 0; i < int(b.size()); ++i) {
            (*this)[i] -= b[i];
        }
        return *this;
    }
    friend IPoly operator+(IPoly a, const IPoly &b) {
        return a += b;
    }
    friend IPoly operator-(IPoly a, const IPoly &b) {
        return a -= b;
    }
    friend IPoly operator-(const IPoly &a) {
        IPoly c(a.size());
        for (int i = 0; i < int(a.size()); ++i) {
            c[i] = -a[i];
        }
        return c;
    }
    // Full product, f.size() + g.size() - 1 coefficients (empty if either is
    // empty). Schoolbook when the result has at most 128 coefficients, NTT
    // otherwise; squaring (f == g) saves one forward transform.
    // Uses function-local static scratch buffers: not reentrant or thread-safe.
    friend IPoly operator*(const IPoly &f, const IPoly &g) {
        static IPoly a, b;
        if (f.empty() || g.empty()) {
            return IPoly{};
        }
        int n = int(f.size()) + int(g.size()) - 1, len = bitceil(n);
        if (n <= 128) {
            IPoly c(n);
            for (int i = 0; i < int(f.size()); ++i) {
                for (int j = 0; j < int(g.size()); ++j) {
                    c[i + j] += f[i] * g[j];
                }
            }
            return c;
        }
        bool eq = f == g;
        a.resize(len);
        copy(f.begin(), f.end(), a.begin());
        fill(a.begin() + f.size(), a.end(), 0);
        if (!eq) {
            b.resize(len);
            copy(g.begin(), g.end(), b.begin());
            fill(b.begin() + g.size(), b.end(), 0);
        }
        assert(((Z::mod() - 1) & (len - 1)) == 0);
        dft(a);
        if (eq) {
            for (int i = 0; i < len; ++i) {
                a[i] *= a[i];
            }
        } else {
            dft(b);
            for (int i = 0; i < len; ++i) {
                a[i] *= b[i];
            }
        }
        idft(a);
        a.resize(n);
        return a;
    }
    friend IPoly operator*(IPoly a, Z b) {
        for (int i = 0; i < int(a.size()); ++i) {
            a[i] *= b;
        }
        return a;
    }
    friend IPoly operator*(Z a, const IPoly &b) {
        return b * a;
    }
    friend IPoly operator/(IPoly a, Z b) {
        return a * b.inv();
    }
    IPoly &operator*=(const IPoly &b) {
        return *this = *this * b;
    }
    IPoly &operator*=(Z b) {
        return *this = *this * b;
    }
    IPoly &operator/=(Z b) {
        return *this = *this / b;
    }

    // Formal derivative (size() - 1 coefficients; empty stays empty).
    IPoly deriv() const {
        if (this->empty()) {
            return IPoly{};
        }
        IPoly f(this->size() - 1);
        for (int i = 1; i < int(this->size()); ++i) {
            f[i - 1] = (*this)[i] * i;
        }
        return f;
    }
    // Formal integral with zero constant term (size() + 1 coefficients).
    IPoly integr() const {
        IPoly f(this->size() + 1);
        auto &comb = IComb<Z>::instance();
        comb.init(int(this->size()));
        for (int i = 0; i < int(this->size()); ++i) {
            f[i + 1] = (*this)[i] * comb.inv_[i + 1];
        }
        return f;
    }
    // Multiplicative inverse mod x^size() by Newton iteration, doubling the
    // precision l each round. front() must be invertible.
    // Uses function-local static scratch buffers.
    IPoly inv() const {
        static IPoly a, b;
        assert(!this->empty());
        int len = bitceil(int(this->size()));
        IPoly res{this->front().inv()};
        res.resize(len);
        a.reserve(len);
        b.reserve(len);
        for (int l = 2; l <= len; l *= 2) {
            a.resize(l);
            copy(this->begin(), this->begin() + min(l, int(this->size())), a.begin());
            dft(a);
            b.resize(l);
            copy(res.begin(), res.begin() + l, b.begin());
            dft(b);
            for (int i = 0; i < l; ++i) {
                a[i] *= b[i];
            }
            idft(a);
            // f * g = 1 + O(x^(l/2)); keep only the error terms.
            fill(a.begin(), a.begin() + l / 2, 0);
            dft(a);
            for (int i = 0; i < l; ++i) {
                a[i] *= -b[i];
            }
            idft(a);
            copy(a.begin() + l / 2, a.end(), res.begin() + l / 2);
        }
        return res.trunc(int(this->size()));
    }
    // log(f) mod x^size(); requires front() == 1.
    IPoly log() const {
        assert(!this->empty() && this->front() == 1);
        return (deriv() * inv()).trunc(int(this->size()) - 1).integr();
    }
    // Online convolution. Fills res[0..size()) in increasing index order. Before
    // relax(m, s) is called, s = sum_{0 <= i < m} res[i] * (*this)[m - i], using
    // only already-finalised terms; res[m] is then set to relax(m, s).
    // Offsets inside a 64-aligned block are accumulated naively; each larger
    // power-of-two block is pushed forward with one length-2l NTT product.
    template <typename F>
    IPoly semiRelaxedConv(F &&relax) const {
        static constexpr int Block = 64;
        static IPoly a;
        assert((Block & (Block - 1)) == 0);
        int len = bitceil(int(this->size()));
        IPoly res(len);
        vector<IPoly> d(__lg(len));
        a.reserve(len);
        auto next = [&](int m) {
            int l = m & -m;
            if (l < Block) {
                for (int i = m & -Block; i < m; ++i) {
                    res[m] += res[i] * (*this)[m - i];
                }
            } else {
                a.resize(2 * l);
                copy(res.begin() + m - l, res.begin() + m, a.begin());
                fill(a.begin() + l, a.end(), 0);
                dft(a);
                auto &b = d[__lg(l)];
                if (b.empty()) {
                    b = trunc(2 * l);
                    dft(b);
                }
                for (int i = 0; i < 2 * l; ++i) {
                    a[i] *= b[i];
                }
                idft(a);
                for (int i = m; i < m + l; ++i) {
                    res[i] += a[i - m + l];
                }
            }
            res[m] = relax(m, res[m]);
        };
        for (int i = 0; i < int(this->size()); ++i) {
            next(i);
        }
        return res.trunc(int(this->size()));
    }
    // exp(f) mod x^size(); requires front() == 0. Solves m*g[m] =
    // sum_{i<m} g[i] * (m - i) f[m - i] online, with g[0] = 1.
    IPoly exp() const {
        assert(!this->empty() && this->front() == 0);
        auto &comb = IComb<Z>::instance();
        comb.init(int(this->size()));
        return deriv().shift(1).semiRelaxedConv([&inv = comb.inv_](int p, Z x) {
            return p == 0 ? 1 : x * inv[p];
        });
    }
    // f^k mod x^size() for k >= 0. Factors out the lowest nonzero term c * x^i
    // and computes c^k * x^(i k) * exp(k log(f / (c x^i))).
    IPoly pow(i64 k) const {
        assert(k >= 0);
        if (k == 0) {
            return IPoly{1}.trunc(int(this->size()));
        }
        // Without this guard an empty input with k >= 2 slips past the
        // all-zero check below ((-1) / k + 1 == 1) and reads (*this)[0].
        if (this->empty()) {
            return IPoly{};
        }
        int i = find_if(this->begin(), this->end(), [](Z x) {
            return x != 0;
        }) - this->begin();
        // Result is identically zero mod x^size() once i * k >= size().
        if (i >= (int(this->size()) - 1) / k + 1) {
            return IPoly(int(this->size()));
        }
        Z x = (*this)[i];
        auto f = shift(-i).trunc(int(this->size()) - i * k) / x;
        return (f.log() * k).exp().shift(i * k) * power(x, k);
    }
    // Square root mod x^size() via Newton iteration.
    // Returns all zeros for the zero series. Returns an empty IPoly when no
    // root exists: odd lowest degree, or a non-residue lowest coefficient.
    IPoly sqrt() const {
        int i = find_if(this->begin(), this->end(), [](Z x) {
            return x != 0;
        }) - this->begin();
        if (i == int(this->size())) {
            return IPoly(this->size());
        }
        if (i % 2) {
            return IPoly{};
        }
        auto f = shift(-i);
        auto sq = sqrtMod(f.front());
        if (sq.empty()) {
            return IPoly{};
        }
        IPoly g{sq.front()};
        int k = 1;
        while (k < int(this->size())) {
            k *= 2;
            g = (g + (f.trunc(k) * g.trunc(k).inv()).trunc(k)) / 2;
        }
        return g.trunc(int(this->size()) - i / 2).shift(i / 2);
    }
    // Evaluate this polynomial at every point of x (multipoint evaluation with
    // a subproduct tree and transposed multiplication). Returns x.size() values.
    vector<Z> eval(vector<Z> x) const {
        if (this->empty()) {
            return vector<Z>(x.size());
        }

        vector<Z> ans(x.size());
        const int n = int(max(x.size(), this->size()));
        vector<IPoly> s(n << 2);
        // Pre-reserve tree nodes so later products do not reallocate.
        auto init = [&](auto self, int u, int l, int r) -> void {
            s[u].reserve(bitceil(r - l));
            if (r - l == 1) {
                return;
            }
            int m = (l + r) / 2;
            self(self, u << 1, l, m);
            self(self, u << 1 | 1, m, r);
        };
        init(init, 1, 0, n);

        x.resize(n);
        // s[u] = prod over the node's points of (1 - x_j * t).
        auto build = [&](auto self, int u, int l, int r) -> void {
            if (r - l == 1) {
                s[u] = IPoly{1, -x[l]};
                return;
            }
            int m = (l + r) / 2;
            self(self, u << 1, l, m);
            self(self, u << 1 | 1, m, r);
            s[u] = s[u << 1] * s[u << 1 | 1];
        };
        build(build, 1, 0, n);

        // Transposed product: `a` is already in the NTT domain.
        auto mulT = [&](const IPoly &a, IPoly b) {
            if (b.empty()) {
                return IPoly{};
            }
            int n = int(b.size());
            reverse(b.begin(), b.end());
            b.resize(a.size());
            dft(b);
            for (int i = 0; i < int(a.size()); ++i) {
                b[i] *= a[i];
            }
            idft(b);
            return b.shift(-(n - 1));
        };
        auto work = [&](auto self, int u, int l, int r, IPoly v) -> void {
            v.resize(r - l);
            if (r - l == 1) {
                if (l < int(ans.size())) {
                    ans[l] = v[0];
                }
                return;
            }
            int m = (l + r) / 2;
            v.resize(bitceil(v.size()));
            dft(v);
            self(self, u << 1, l, m, mulT(v, s[u << 1 | 1]));
            self(self, u << 1 | 1, m, r, mulT(v, s[u << 1]));
        };
        auto d = *this;
        d.resize(bitceil(d.size() + n + 1));
        dft(d);
        work(work, 1, 0, n, mulT(d, s[1].inv()));

        return ans;
    }
};
} // namespace Polynomial

using Polynomial::IPoly;

// Working field: the NTT-friendly prime 998244353 = 119 * 2^23 + 1.
using Z = ModInt<998244353>;
using Poly = IPoly<Z>;
// Shared factorial / inverse tables for Z.
auto &comb = IComb<Z>::instance();

constexpr int BufSize = 1 << 20;

// Next byte of standard input, read in BufSize chunks directly from cin's
// streambuf (captured on the first call). Once input is exhausted it returns
// char(EOF) on this and every later call.
inline char readChar() {
    static char buf[BufSize], *p1, *p2;
    static streambuf *inbuf = cin.rdbuf();
    return p1 == p2 && (p2 = (p1 = buf) + inbuf->sgetn(buf, BufSize), p1 == p2) ? EOF : *p1++;
}
// Read the next integer from standard input into x. Non-digit characters are
// skipped; any '-' seen while skipping makes the value negative.
// If input ends before any digit, x is left as 0. (Previously this spun
// forever on EOF.) A literal 0xFF byte is indistinguishable from EOF here.
template <typename T>
void read(T &x) {
    constexpr char Eof = char(EOF);
    // Explicit range test: std::isdigit on a negative char is undefined.
    auto isDigit = [](char ch) {
        return '0' <= ch && ch <= '9';
    };
    x = 0;
    char c = readChar();
    bool f = false;
    while (!isDigit(c)) {
        if (c == Eof) {
            return;
        }
        f |= c == '-';
        c = readChar();
    }
    while (isDigit(c)) {
        x = x * 10 + (c - '0');
        c = readChar();
    }
    if (f) {
        x *= -1;
    }
}

// Write one byte straight to cout's streambuf (captured on the first call).
inline void printChar(char c) {
    static streambuf *outbuf = cout.rdbuf();
    outbuf->sputc(c);
}
// Write the decimal representation of integer x (with '-' when negative).
// Negative values are converted digit by digit while still negative, because
// negating first overflows for the most negative value of a signed type.
template <typename T>
void print(T x) {
    static array<char, 50> stk;
    int top = 0;
    if (x < 0) {
        printChar('-');
        do {
            stk[top++] = char('0' - x % 10);
            x /= 10;
        } while (x != 0);
    } else {
        do {
            stk[top++] = char('0' + x % 10);
            x /= 10;
        } while (x > 0);
    }
    while (top--) {
        printChar(stk[top]);
    }
}
// Alias of printChar, kept for existing callers.
inline void print_char(char c) {
    printChar(c);
}

// Read an integer from standard input and reduce it modulo P.
// It is read as a signed 64-bit value so that negative inputs reduce correctly
// (e.g. -1 becomes P - 1). Values are assumed to fit in i64.
void read(Z &x) {
    i64 v = 0;
    read(v);
    x = Z(v);
}
// Print the canonical residue in [0, P).
void print(Z x) {
    print(x.val());
}

// Simple stopwatch for profiling: step() reports the seconds elapsed since the
// last checkpoint to stderr and, unless update == false, starts a new one.
struct Clock {
    string name;
    chrono::steady_clock::time_point last;
    Clock(const string &name_) : name(name_), last(chrono::steady_clock::now()) {}
    void step(const string &msg, bool update = true) {
        const auto now = chrono::steady_clock::now();
        const double seconds = chrono::duration<double>(now - last).count();
        cerr << name << "(" << msg << "): " << seconds << "s\n";
        if (!update) {
            return;
        }
        last = now;
    }
};

void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c) {
    ++n, ++m;
    Poly f(n), g(m);
    auto p = (Z *)a, q = (Z *)b, r = (Z *)c;
    copy(p, p + n, f.begin());
    copy(q, q + m, g.begin());
    f *= g;
    copy(f.begin(), f.begin() + n + m - 1, r);
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 152.127 ms | 46 MB + 952 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2025-10-20 18:36:38 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠