提交记录 21710


用户 题目 状态 得分 用时 内存 语言 代码长度
Xiaohuba 1002. 测测你的多项式乘法 Accepted 100 152.188 ms 52336 KB C++ 5.01 KB
提交时间 评测时间
2024-05-02 20:50:11 2024-05-02 20:50:14
#include <bits/stdc++.h>

using namespace std;
// Fixed-width unsigned aliases used by the code below.
using u32 = uint32_t;
using u64 = uint64_t;

// Declaration prefix for small compile-time helpers inside a class: forced
// inlining (GCC/Clang attribute), constexpr, inline, static.
// NOTE(review): identifiers starting with an underscore are reserved to the
// implementation at global scope; consider renaming these two macros.
#define _static_helper __attribute__((always_inline)) constexpr inline static
// Defines the binary `operator op` of type `tp` in terms of its compound
// assignment (`op## =` token-pastes e.g. `+` and `=` into `+=`): copy *this,
// apply `op=` with rhs, and return the result by value. `attr` is an optional
// prefix such as `constexpr` (may be empty).
#define _set_op(tp, op, attr)                                                  \
  attr tp operator op(const tp &rhs) const {                                   \
    tp lhs = *this;                                                            \
    return lhs op## = rhs;                                                     \
  }

template <u32 mod> class Modint {
  u32 _v;
  _static_helper u32 __get_inv_r(u32 x) {
    u32 y = x;
    y *= 2ull - x * y, y *= 2ull - x * y, y *= 2ull - x * y, y *= 2ull - x * y;
    return y;
  }
  _static_helper u32 __shrk(u32 x) { return min(x, x - mod); }
  _static_helper u32 __shrk2(u32 x) { return min(x, x - 2 * mod); }
  _static_helper u32 __dilt2(u32 x) { return min(x, x + 2 * mod); }
  _static_helper u32 __reduce(u64 x) {
    return (x + u64(u32(x) * -mod_inv) * mod) >> 32;
  }

public:
  constexpr static inline u32 r2 = (1ull << 32) % mod * (1ull << 32) % mod,
                              mod_inv = __get_inv_r(mod);
  static_assert(mod && (mod < (1u << 30)) && (mod & 1));
  static_assert((mod_inv * mod) == 1u);
  static_assert(__reduce(r2) == (1ull << 32) % mod);

  constexpr Modint(u32 v = 0) : _v(__reduce(u64(v) * r2)) {}
  constexpr Modint &operator+=(Modint rhs) {
    return _v = __shrk2(_v + rhs._v), *this;
  }
  constexpr Modint &operator-=(Modint rhs) {
    return _v = __dilt2(_v - rhs._v), *this;
  }
  constexpr Modint &operator*=(Modint rhs) {
    return _v = __reduce(u64(_v) * rhs._v), *this;
  }
  constexpr Modint pow(u64 y) const {
    Modint ans = 1, x = *this;
    while (y) {
      if (y & 1)
        ans *= x;
      x *= x, y >>= 1;
    }
    return ans;
  }
  constexpr Modint inv() const { return this->pow(mod - 2); }
  constexpr u32 value() const { return __shrk(__reduce(_v)); }
  constexpr Modint &operator/=(Modint rhs) { return (*this) *= rhs.inv(); }
  constexpr operator bool() const { return _v; }
  _set_op(Modint, +, constexpr);
  _set_op(Modint, -, constexpr);
  _set_op(Modint, *, constexpr);
  _set_op(Modint, /, constexpr);
};

// Reads an unsigned integer from `x` and stores it into `y`; the constructor
// reduces it modulo `mod`. The temporary is zero-initialized because an
// extraction from a stream that is already in a failed state leaves it
// untouched, which previously meant reading an uninitialized value.
template <u32 mod> istream &operator>>(istream &x, Modint<mod> &y) {
  u32 raw = 0;
  x >> raw;
  y = raw;
  return x;
}
// Writes the canonical residue in [0, mod). `y` is taken by const reference:
// with a non-const reference, const values and temporaries (e.g.
// `cout << a * b`) could not bind here, and overload resolution would
// silently fall back to `ostream << bool` through Modint's implicit
// operator bool, printing 0/1 instead of the value.
template <u32 mod> ostream &operator<<(ostream &x, const Modint<mod> &y) {
  return x << y.value();
}

using Z = Modint<998244353>;
/// Dense polynomial over Z (integers modulo 998244353); coefficient i is at
/// index i. Multiplication uses a number-theoretic transform (NTT).
///
/// NOTE(review): the twiddle tables `w`/`iw` are static and grown lazily, so
/// using Poly from several threads at once is not safe without external
/// locking. Poly also inherits publicly from vector<Z>, which has no virtual
/// destructor; never delete a Poly through a vector<Z>*.
class Poly : public vector<Z> {
  // Generator used to derive roots of unity; 3 is a primitive root of
  // 998244353.
  constexpr static inline Z G = 3;
  // 998244353 - 1 = 119 * 2^23, so NTT lengths are limited to 2^23.
  constexpr static inline int max_len = 1 << 23;

  // floor(log2(x)); requires x > 0.
  _static_helper int lg(int x) { return 31 ^ __builtin_clz(x); }
  // Smallest power of two >= x. Inputs <= 1 map to 1 explicitly, because
  // lg(0) would evaluate __builtin_clz(0), which is undefined.
  _static_helper int ceil2pow(int x) {
    return x <= 1 ? 1 : 1 << (lg(x - 1) + 1);
  }

  // Twiddle factors for DFT (w) and IDFT (iw). Entry i + cur is entry i times
  // a primitive (4 * cur)-th root of unity (resp. its inverse), which lays the
  // roots out in bit-reversed order.
  static inline vector<Z> w = {1}, iw = {1};
  // Grows w/iw to hold at least new_sz entries.
  static inline void extend(int new_sz) {
    int cur = w.size();
    if (cur >= new_sz)
      return;
    // The doubling loop only produces power-of-two sizes; round the request
    // up so it can never write past the end of the tables.
    new_sz = ceil2pow(new_sz);
    // The largest layer needs a root of order 2 * new_sz, which must divide
    // 2^23 for the exponent below to be exact.
    assert(new_sz <= max_len / 2);
    w.resize(new_sz), iw.resize(new_sz);
    while (cur < new_sz) {
      Z _wn = G.pow(998244352 / cur / 4), _iwn = _wn.inv();
#pragma GCC unroll(4)
      for (int i = 0; i < cur; i++)
        w[i + cur] = w[i] * _wn, iw[i + cur] = iw[i] * _iwn;
      cur <<= 1;
    }
  }
  // In-place forward transform. No bit-reversal permutation is applied, so
  // the output order is only meaningful to IDFT: use the two as a pair, with
  // pointwise operations in between. size() must be a power of two <= 2^23.
  inline void DFT() {
    int n = this->size();
    assert(n > 0 && (n & (n - 1)) == 0 && n <= max_len);
    extend(n >> 1); // the butterflies only read w[0, n / 2)
    for (int i = n >> 1; i; i >>= 1)
      for (int j = 0; j < n / i / 2; j++)
#pragma GCC unroll(4)
        for (int k = j * i * 2; k < j * i * 2 + i; k++) {
          Z u = (*this)[k], v = (*this)[k + i] * w[j];
          (*this)[k] = u + v, (*this)[k + i] = u - v;
        }
  }
  // In-place inverse of DFT, including the final division by n.
  // size() must be a power of two <= 2^23.
  inline void IDFT() {
    int n = this->size();
    assert(n > 0 && (n & (n - 1)) == 0 && n <= max_len);
    extend(n >> 1); // the butterflies only read iw[0, n / 2)
    for (int i = 1; i <= n >> 1; i <<= 1)
      for (int j = 0; j < n / i / 2; j++)
#pragma GCC unroll(4)
        for (int k = j * i * 2; k < j * i * 2 + i; k++) {
          Z u = (*this)[k], v = (*this)[k + i];
          (*this)[k] = u + v, (*this)[k + i] = (u - v) * iw[j];
        }
    Z inv = Z(n).inv();
#pragma GCC unroll(4)
    for (int i = 0; i < n; i++)
      (*this)[i] *= inv;
  }

public:
  // Coefficient-wise addition; *this is zero-extended to rhs's length first.
  Poly &operator+=(const Poly &rhs) {
    if (this->size() < rhs.size())
      this->resize(rhs.size(), Z{});
#pragma GCC unroll(4)
    for (size_t i = 0; i < rhs.size(); i++)
      (*this)[i] += rhs[i];
    return *this;
  }
  // Coefficient-wise subtraction; *this is zero-extended to rhs's length.
  Poly &operator-=(const Poly &rhs) {
    if (this->size() < rhs.size())
      this->resize(rhs.size(), Z{});
#pragma GCC unroll(4)
    for (size_t i = 0; i < rhs.size(); i++)
      (*this)[i] -= rhs[i];
    return *this;
  }
  // Multiplies every coefficient by the scalar num.
  friend Poly &operator*=(Poly &x, Z num) {
#pragma GCC unroll(4)
    for (auto &i : x)
      i *= num;
    return x;
  }
  // Polynomial product via NTT. The result keeps the power-of-two transform
  // length: size() + rhs.size() - 1 meaningful coefficients followed by
  // zeros. Multiplying by an empty (zero) polynomial yields an empty one.
  Poly &operator*=(const Poly &rhs) {
    if (this->empty() || rhs.empty()) {
      this->clear();
      return *this;
    }
    int n = ceil2pow(this->size() + rhs.size());
    // Copy rhs before *this is modified so that `p *= p` still works.
    Poly tmp = rhs;
    this->resize(n), this->DFT();
    tmp.resize(n), tmp.DFT();
#pragma GCC unroll(4)
    for (int i = 0; i < n; i++)
      (*this)[i] *= tmp[i];
    return this->IDFT(), *this;
  }
  _set_op(Poly, +, );
  _set_op(Poly, -, );
  _set_op(Poly, *, );
};

// Scratch operands used by every poly_multiply call. Being global, their
// storage persists between calls (presumably to avoid reallocating).
Poly A, B;

// Multiplies a (degree n, i.e. n + 1 coefficients) by b (degree m, m + 1
// coefficients) modulo 998244353 and writes the n + m + 1 coefficients of the
// product to c. Input values are reduced modulo 998244353 on conversion.
// NOTE(review): the signature presumably matches an externally fixed judge
// interface, so a and b stay non-const pointers — confirm before changing.
void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c) {
  const int len_a = n + 1;
  const int len_b = m + 1;
  A.resize(len_a);
  B.resize(len_b);
  std::copy(a, a + len_a, A.begin());
  std::copy(b, b + len_b, B.begin());

  A *= B; // A now holds the zero-padded product

  const int len_c = n + m + 1;
  for (int k = 0; k < len_c; ++k)
    c[k] = A[k].value();
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 152.188 ms | 51 MB + 112 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2024-11-23 05:12:54 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠