提交记录 29029


用户 题目 状态 得分 用时 内存 语言 代码长度
saffah_bot 1002. 测测你的多项式乘法 Accepted 100 239.556 ms 32404 KB C++ 4.20 KB
提交时间 评测时间
2026-06-27 02:23:55 2026-06-27 02:23:56
#include <stdint.h>
#include <stdlib.h>

/* NTT-friendly prime 119 * 2^23 + 1 (= 998244353): MOD - 1 is divisible by
 * 2^23, so Z_MOD contains roots of unity of every power-of-two order up to 2^23. */
static const uint32_t MOD = 119u * (1u << 23) + 1u;
/* Generator used to derive the roots of unity; 3 is a primitive root mod 998244353. */
static const uint32_t G = 3u;

/* (a + b) mod MOD for operands already reduced into [0, MOD).
 * The raw sum cannot wrap because 2 * MOD < 2^32. */
static inline uint32_t add_mod(uint32_t a, uint32_t b) {
    uint32_t sum = a + b;
    if (sum >= MOD) {
        sum -= MOD;
    }
    return sum;
}

/* (a - b) mod MOD for operands already reduced into [0, MOD); borrows MOD
 * instead of going negative when b exceeds a. */
static inline uint32_t sub_mod(uint32_t a, uint32_t b) {
    if (a < b) {
        return a + (MOD - b);
    }
    return a - b;
}

/* (a * b) mod MOD, widening to 64 bits so the product of two 32-bit values
 * cannot overflow before the reduction. */
static inline uint32_t mul_mod(uint32_t a, uint32_t b) {
    uint64_t product = (uint64_t)a * (uint64_t)b;
    return (uint32_t)(product % MOD);
}

/* a^e mod MOD by right-to-left binary exponentiation (returns 1 when e == 0). */
static uint32_t pow_mod(uint32_t a, uint32_t e) {
    uint32_t result = 1;
    for (uint32_t base = a; e != 0; e >>= 1) {
        if (e & 1u) {
            result = mul_mod(result, base);
        }
        base = mul_mod(base, base);
    }
    return result;
}

/* Twiddle-factor table shared by both transforms. Its 2^21 entries bound the
 * largest transform length that build_roots can fill. */
static uint32_t roots[1 << 21];

/* Fill roots[0..n-1] with successive powers of w = G^((MOD-1)/n). When n is a
 * power of two dividing MOD - 1 (not checked here), w is a primitive n-th root
 * of unity. roots[0] is always set to 1. */
static void build_roots(int n) {
    const uint32_t w = pow_mod(G, (MOD - 1) / (uint32_t)n);
    uint32_t cur = 1;
    roots[0] = cur;
    for (int k = 1; k < n; ++k) {
        cur = mul_mod(cur, w);
        roots[k] = cur;
    }
}

/*
 * In-place forward NTT of length n using decimation-in-frequency butterflies
 * (sum goes to the low half, twiddled difference to the high half).
 *
 * No bit-reversal permutation is applied, so the output is left in
 * bit-reversed index order. ntt_inverse_dit consumes that order directly.
 * roots[] must hold the powers of a primitive n-th root of unity, as written
 * by build_roots(n). Within a stage of block length len, the twiddle for
 * position j is roots[j * (n / len)].
 */
static void ntt_forward_dif(uint32_t *a, int n) {
    for (int len = n; len > 1; len >>= 1) {
        const int half = len >> 1;
        const int stride = n / len;
        for (int base = 0; base < n; base += len) {
            uint32_t *lo = a + base;
            uint32_t *hi = lo + half;
            for (int j = 0; j < half; ++j) {
                const uint32_t u = lo[j];
                const uint32_t v = hi[j];
                lo[j] = add_mod(u, v);
                hi[j] = mul_mod(sub_mod(u, v), roots[j * stride]);
            }
        }
    }
}

/*
 * In-place inverse NTT of length n using decimation-in-time butterflies
 * (twiddle the high element, then sum/difference).
 *
 * Input is expected in the bit-reversed order produced by ntt_forward_dif.
 * Output is in natural order and scaled by n^{-1} mod MOD.
 * roots[] must hold the powers of the INVERSE primitive n-th root of unity;
 * poly_multiply rewrites the table that way before calling this.
 */
static void ntt_inverse_dit(uint32_t *a, int n) {
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        const int stride = n / len;
        for (int base = 0; base < n; base += len) {
            uint32_t *lo = a + base;
            uint32_t *hi = lo + half;
            for (int j = 0; j < half; ++j) {
                const uint32_t u = lo[j];
                const uint32_t t = mul_mod(hi[j], roots[j * stride]);
                lo[j] = add_mod(u, t);
                hi[j] = sub_mod(u, t);
            }
        }
    }

    /* Divide by n via Fermat's little theorem: n^{-1} = n^(MOD-2) mod MOD. */
    const uint32_t n_inv = pow_mod((uint32_t)n, MOD - 2);
    for (int k = n - 1; k >= 0; --k) {
        a[k] = mul_mod(a[k], n_inv);
    }
}

/*
 * Multiply a (degree n, coefficients a[0..n]) by b (degree m, coefficients
 * b[0..m]) and write the n + m + 1 product coefficients to c[0..n+m].
 *
 * All arithmetic is modulo MOD. A true product coefficient is reproduced
 * exactly only when it is below MOD; otherwise c holds it reduced mod MOD.
 * Input coefficients are reduced mod MOD first.
 *
 * The transform length is the smallest power of two >= n + m + 1. It must
 * fit in the file-level roots[] table; if it does not, the program aborts
 * instead of writing past the table. Allocation failure also aborts.
 *
 * Not reentrant: roots[] is shared state rewritten on every call.
 */
void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c) {
    const int root_cap = (int)(sizeof roots / sizeof roots[0]);
    const int need = n + m + 1;

    /* Grow len up to need. Refusing to step past the table capacity also keeps
     * len << 1 from overflowing int. */
    int len = 1;
    while (len < need) {
        if (len >= root_cap) abort();
        len <<= 1;
    }

    /* The casts are kept so the file stays valid C++ (the submission is judged as C++). */
    uint32_t *fa = (uint32_t *)calloc((size_t)len, sizeof *fa);
    uint32_t *fb = (uint32_t *)calloc((size_t)len, sizeof *fb);
    if (fa == NULL || fb == NULL) abort();

    /* add_mod/sub_mod assume operands already lie in [0, MOD). */
    for (int i = 0; i <= n; ++i) fa[i] = a[i] % MOD;
    for (int i = 0; i <= m; ++i) fb[i] = b[i] % MOD;

    /* Forward transforms (bit-reversed output), then pointwise product. */
    build_roots(len);
    ntt_forward_dif(fa, len);
    ntt_forward_dif(fb, len);
    for (int i = 0; i < len; ++i) fa[i] = mul_mod(fa[i], fb[i]);

    /* Rewrite the table as powers of the inverse root for the inverse transform. */
    roots[0] = 1;
    const uint32_t inv_root = pow_mod(pow_mod(G, (MOD - 1) / (uint32_t)len), MOD - 2);
    for (int i = 1; i < len; ++i) roots[i] = mul_mod(roots[i - 1], inv_root);
    ntt_inverse_dit(fa, len);

    for (int i = 0; i < need; ++i) c[i] = fa[i];
    free(fa);
    free(fb);
}

#ifdef LOCAL_TEST
#include <stdio.h>
#include <time.h>

static unsigned aa[1 << 20], bb[1 << 20], cc[1 << 21], dd[4096];

/* Reference schoolbook product of aa[0..n] and bb[0..m] into dd[0..n+m]. */
static void naive_multiply(int n, int m) {
    for (int k = 0; k <= n + m; ++k) {
        dd[k] = 0;
    }
    for (int i = 0; i <= n; ++i) {
        for (int j = 0; j <= m; ++j) {
            dd[i + j] += aa[i] * bb[j];
        }
    }
}

/* Self-check: compare poly_multiply with the schoolbook product for every
 * degree pair (n, m) with n, m < 64. Reports the first mismatch and exits 1;
 * prints "ok" and exits 0 if all agree. */
int main(void) {
    for (int n = 0; n < 64; ++n) {
        for (int m = 0; m < 64; ++m) {
            for (int i = 0; i <= n; ++i) {
                aa[i] = (unsigned)((i * 7 + n) % 10);
            }
            for (int i = 0; i <= m; ++i) {
                bb[i] = (unsigned)((i * 5 + m) % 10);
            }
            poly_multiply(aa, n, bb, m, cc);
            naive_multiply(n, m);
            for (int k = 0; k <= n + m; ++k) {
                if (cc[k] == dd[k]) {
                    continue;
                }
                printf("bad n=%d m=%d i=%d got=%u want=%u\n", n, m, k, cc[k], dd[k]);
                return 1;
            }
        }
    }
    puts("ok");
    return 0;
}
#endif

#ifdef LOCAL_BENCH
#include <stdio.h>
#include <time.h>

static unsigned aa[1000001], bb[1000001], cc[2000001];

/* Time one degree-1e6 by degree-1e6 multiplication with clock(). Prints the
 * elapsed CPU milliseconds and a sum of every 137th product coefficient as a
 * quick sanity value. */
int main(void) {
    const int deg = 1000000;
    for (int i = 0; i <= deg; ++i) {
        aa[i] = (unsigned)((i * 7 + 3) % 10);
        bb[i] = (unsigned)((i * 5 + 1) % 10);
    }

    const clock_t start = clock();
    poly_multiply(aa, deg, bb, deg, cc);
    const clock_t stop = clock();

    unsigned long long sample = 0;
    for (int k = 0; k <= 2 * deg; k += 137) {
        sample += cc[k];
    }

    const double elapsed_ms = 1000.0 * (double)(stop - start) / CLOCKS_PER_SEC;
    printf("%.3f ms sample=%llu\n", elapsed_ms, sample);
    return 0;
}
#endif

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 239.556 ms | 31 MB + 660 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2026-06-28 06:19:25 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠