提交记录 29053


用户 题目 状态 得分 用时 内存 语言 代码长度
saffah_bot 1002. 测测你的多项式乘法 Runtime Error 0 54.378 ms 24588 KB C++ 22.01 KB
提交时间 评测时间
2026-06-27 14:18:20 2026-06-27 14:18:22
#pragma GCC target("avx2")
#include <immintrin.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#ifdef LOCAL_PROFILE
#include <stdio.h>
#include <time.h>
#endif

/*
 * Modulus and Montgomery parameters.
 *
 * MOD = 25 * 2^22 + 1 is an NTT-friendly prime (the code inverts with
 * pow_mod(x, MOD - 2)), so roots of unity of every power-of-two order up to
 * 2^22 exist.  Residues are kept lazily in [0, 2*MOD) and grow further inside
 * the vector butterflies, hence the precomputed multiples of MOD.
 */
static const uint32_t MOD = 25u * (1u << 22) + 1u;          /* 104857601 */
static const uint32_t TWO_MOD = 2u * 104857601u;
static const uint32_t FOUR_MOD = 4u * 104857601u;
static const uint32_t EIGHT_MOD = 8u * 104857601u;
/* 3 is a quadratic non-residue mod MOD, so 3^((MOD-1)/2^k) has order 2^k. */
static const uint32_t G = 3u;
/* -MOD^{-1} mod 2^32, the multiplier mont_reduce() needs. */
static const uint32_t QINV = 104857601u - 2u;
/* R^2 mod MOD; multiplying by it in Montgomery form enters Montgomery form. */
static const uint32_t R2 = 45971250u;
/* 2^32 mod MOD: the Montgomery representation of 1. */
static const uint32_t R = 100663256u;

/* Lazy modular addition: for a, b in [0, 2*MOD) returns a + b (mod MOD) in [0, 2*MOD). */
static inline uint32_t add_mod(uint32_t a, uint32_t b) {
    uint32_t sum = a + b;
    if (sum >= TWO_MOD) {
        sum -= TWO_MOD;
    }
    return sum;
}

/* Lazy modular subtraction: for a, b in [0, 2*MOD) returns a - b (mod MOD) in [0, 2*MOD). */
static inline uint32_t sub_mod(uint32_t a, uint32_t b) {
    /* Unsigned wrap-around makes (a - b) + TWO_MOD equal a + TWO_MOD - b. */
    uint32_t diff = (a - b) + TWO_MOD;
    if (diff >= TWO_MOD) {
        diff -= TWO_MOD;
    }
    return diff;
}

/*
 * Montgomery reduction: returns x * 2^-32 (mod MOD).
 * q = x * QINV (mod 2^32) makes x + q*MOD divisible by 2^32; for
 * x < 2^32 * MOD the quotient lies in [0, 2*MOD).
 */
static inline uint32_t mont_reduce(uint64_t x) {
    const uint32_t q = (uint32_t)x * QINV;
    const uint64_t t = x + (uint64_t)q * MOD;
    return (uint32_t)(t >> 32);
}

/* Montgomery product a * b * 2^-32 (mod MOD). */
static inline uint32_t mont_mul(uint32_t a, uint32_t b) {
    const uint64_t prod = (uint64_t)a * (uint64_t)b;
    return mont_reduce(prod);
}

/* Converts a plain residue into Montgomery form (x * 2^32 mod MOD, lazily reduced). */
static inline uint32_t to_mont(uint32_t x) {
    return mont_reduce((uint64_t)x * R2);
}

/* 8-lane lazy addition: lanes in [0, 2*MOD) -> (a + b) mod MOD in [0, 2*MOD). */
static inline __m256i add_vec(__m256i a, __m256i b) {
    const __m256i two_mod = _mm256_set1_epi32((int)TWO_MOD);
    const __m256i limit = _mm256_set1_epi32((int)(TWO_MOD - 1));
    const __m256i sum = _mm256_add_epi32(a, b);
    /* Signed compare is fine: sums stay below 4*MOD < 2^31. */
    const __m256i over = _mm256_cmpgt_epi32(sum, limit);
    const __m256i fix = _mm256_and_si256(over, two_mod);
    return _mm256_sub_epi32(sum, fix);
}

/* 8-lane lazy subtraction: lanes in [0, 2*MOD) -> (a - b) mod MOD in [0, 2*MOD). */
static inline __m256i sub_vec(__m256i a, __m256i b) {
    const __m256i two_mod = _mm256_set1_epi32((int)TWO_MOD);
    const __m256i limit = _mm256_set1_epi32((int)(TWO_MOD - 1));
    const __m256i biased = _mm256_add_epi32(a, two_mod);
    const __m256i diff = _mm256_sub_epi32(biased, b);
    const __m256i over = _mm256_cmpgt_epi32(diff, limit);
    const __m256i fix = _mm256_and_si256(over, two_mod);
    return _mm256_sub_epi32(diff, fix);
}

/* Unreduced lane-wise sum; callers track how large the lanes may grow. */
static inline __m256i add_raw(__m256i a, __m256i b) {
    return _mm256_add_epi32(a, b);
}

/* a + 2*MOD - b without reduction (callers pass b < 2*MOD). */
static inline __m256i sub2_raw(__m256i a, __m256i b) {
    const __m256i bias = _mm256_set1_epi32((int)TWO_MOD);
    return _mm256_sub_epi32(_mm256_add_epi32(a, bias), b);
}

/* a + 4*MOD - b without reduction (callers pass b < 4*MOD). */
static inline __m256i sub4_raw(__m256i a, __m256i b) {
    const __m256i bias = _mm256_set1_epi32((int)FOUR_MOD);
    return _mm256_sub_epi32(_mm256_add_epi32(a, bias), b);
}

/* a + 8*MOD - b without reduction (callers pass b < 8*MOD). */
static inline __m256i sub8_raw(__m256i a, __m256i b) {
    const __m256i bias = _mm256_set1_epi32((int)EIGHT_MOD);
    return _mm256_sub_epi32(_mm256_add_epi32(a, bias), b);
}

/* Subtracts `bound` from every lane that is >= bound (signed lane compare). */
static inline __m256i cond_sub_vec(__m256i x, uint32_t bound) {
    const __m256i vbound = _mm256_set1_epi32((int)bound);
    const __m256i ge = _mm256_cmpgt_epi32(x, _mm256_set1_epi32((int)(bound - 1)));
    return _mm256_sub_epi32(x, _mm256_and_si256(ge, vbound));
}

/* Folds lanes from [0, 16*MOD) down to [0, 2*MOD): subtract 8*MOD, 4*MOD, 2*MOD when possible. */
static inline __m256i reduce16_to2_vec(__m256i x) {
    x = cond_sub_vec(x, EIGHT_MOD);
    x = cond_sub_vec(x, FOUR_MOD);
    return cond_sub_vec(x, TWO_MOD);
}

/*
 * Lane-wise Montgomery product: a[i] * b[i] * 2^-32 (mod MOD) for all eight
 * 32-bit lanes.  _mm256_mul_epu32 only multiplies the low 32 bits of each
 * 64-bit slot, so even lanes and odd lanes are reduced separately and then
 * interleaved back together.
 */
static inline __m256i mont_mul_vec(__m256i a, __m256i b) {
    const __m256i vmod = _mm256_set1_epi32((int)MOD);
    const __m256i vqinv = _mm256_set1_epi32((int)QINV);

    /* q = (a*b mod 2^32) * QINV mod 2^32 in every lane. */
    const __m256i q = _mm256_mullo_epi32(_mm256_mullo_epi32(a, b), vqinv);

    /* Even lanes: (a*b + q*MOD) >> 32 lands in the low half of each 64-bit slot. */
    const __m256i even_sum = _mm256_add_epi64(_mm256_mul_epu32(a, b), _mm256_mul_epu32(q, vmod));
    const __m256i even = _mm256_srli_epi64(even_sum, 32);

    /* Odd lanes: shift them into even position, then do the same. */
    const __m256i a_odd = _mm256_srli_epi64(a, 32);
    const __m256i b_odd = _mm256_srli_epi64(b, 32);
    const __m256i q_odd = _mm256_srli_epi64(q, 32);
    const __m256i odd_sum = _mm256_add_epi64(_mm256_mul_epu32(a_odd, b_odd), _mm256_mul_epu32(q_odd, vmod));
    const __m256i odd = _mm256_srli_epi64(odd_sum, 32);

    /* `even` has zero upper halves; move `odd` up and take lanes 1,3,5,7 from it. */
    return _mm256_blend_epi32(even, _mm256_slli_epi64(odd, 32), 0xAA);
}

/* Canonicalizes lanes from [0, 2*MOD) to [0, MOD) by one conditional subtraction. */
static inline __m256i reduce_mod_vec(__m256i a) {
    const __m256i vmod = _mm256_set1_epi32((int)MOD);
    const __m256i needs_sub = _mm256_cmpgt_epi32(a, _mm256_set1_epi32((int)(MOD - 1)));
    const __m256i delta = _mm256_and_si256(needs_sub, vmod);
    return _mm256_sub_epi32(a, delta);
}

/* Plain (non-Montgomery) a^e mod MOD by right-to-left square-and-multiply; returns 1 for e == 0. */
static uint32_t pow_mod(uint32_t a, uint32_t e) {
    uint64_t result = 1;
    uint64_t base = a;
    for (; e != 0; e >>= 1) {
        if (e & 1u) {
            result = result * base % MOD;
        }
        base = base * base % MOD;
    }
    return (uint32_t)result;
}

/*
 * Twiddle table shared by both transform directions; poly_multiply() rebuilds
 * it before each transform.  A length n = 8^k transform uses
 * 7 * (n/8 + n/64 + ... + 1) = n - 1 entries, so 2^21 entries cover n <= 8^7.
 */
static uint32_t roots[1 << 21];

/* Returns the lanes {step^0, ..., step^7} (step and result in Montgomery form). */
static inline __m256i make_wvec(uint32_t step) {
    uint32_t lanes[8];
    uint32_t w = R;
    for (int k = 0; k < 8; ++k) {
        lanes[k] = w;
        w = mont_mul(w, step);
    }
    return _mm256_loadu_si256((const __m256i *)lanes);
}

/* Returns step^8 in Montgomery form: the stride between consecutive 8-lane twiddle vectors. */
static inline uint32_t step_pow8(uint32_t step) {
    uint32_t acc = R;
    int k = 8;
    while (k-- > 0) {
        acc = mont_mul(acc, step);
    }
    return acc;
}

/*
 * Writes the twiddles of one radix-8 stage with q butterflies into
 * roots[off .. off + 7*q): row k (k = 1..7) starts at off + (k-1)*q and holds
 * step^(j*k) for j = 0..q-1, everything in Montgomery form.
 */
static void fill_stage_roots(int off, int q, uint32_t step) {
    if (q >= 8) {
        /* Vector path writes 8 consecutive j at a time, so q must be a multiple of 8. */
        __m256i w1 = make_wvec(step);
        __m256i vstep8 = _mm256_set1_epi32((int)step_pow8(step));
        for (int i = 0; i < q; i += 8) {
            __m256i w2 = mont_mul_vec(w1, w1);
            __m256i w3 = mont_mul_vec(w2, w1);
            __m256i w4 = mont_mul_vec(w2, w2);
            __m256i w5 = mont_mul_vec(w4, w1);
            __m256i w6 = mont_mul_vec(w4, w2);
            __m256i w7 = mont_mul_vec(w4, w3);
            _mm256_storeu_si256((__m256i *)(roots + off + i), w1);
            _mm256_storeu_si256((__m256i *)(roots + off + q + i), w2);
            _mm256_storeu_si256((__m256i *)(roots + off + q * 2 + i), w3);
            _mm256_storeu_si256((__m256i *)(roots + off + q * 3 + i), w4);
            _mm256_storeu_si256((__m256i *)(roots + off + q * 4 + i), w5);
            _mm256_storeu_si256((__m256i *)(roots + off + q * 5 + i), w6);
            _mm256_storeu_si256((__m256i *)(roots + off + q * 6 + i), w7);
            w1 = mont_mul_vec(w1, vstep8);
        }
    } else {
        /* Short stages: fill every j < q (previously only j == 0 was written,
         * leaving rows stale for q == 2 or 4). */
        uint32_t wj = R;
        for (int j = 0; j < q; ++j) {
            uint32_t wjk = wj;
            for (int k = 0; k < 7; ++k) {
                roots[off + k * q + j] = wjk;
                wjk = mont_mul(wjk, wj);
            }
            wj = mont_mul(wj, step);
        }
    }
}

/*
 * Builds the twiddles consumed by ntt_forward(n): stages len = n, n/8, ..., 8,
 * each using the len-th root of unity G^((MOD-1)/len).  n is expected to be a
 * power of 8 (the transforms are pure radix-8).
 */
static void build_roots_forward(int n) {
    int off = 0;
    for (int len = n; len >= 8; len >>= 3) {
        const int q = len >> 3;
        fill_stage_roots(off, q, to_mont(pow_mod(G, (MOD - 1) / (uint32_t)len)));
        off += q * 7;
    }
}

/*
 * Builds the twiddles consumed by ntt_inverse(n): stages len = 8, 64, ..., n,
 * each using the inverse of G^((MOD-1)/len).  n is expected to be a power of 8.
 */
static void build_roots_inverse(int n) {
    int off = 0;
    for (int len = 8; len <= n; len <<= 3) {
        const int q = len >> 3;
        const uint32_t base = pow_mod(G, (MOD - 1) / (uint32_t)len);
        fill_stage_roots(off, q, to_mont(pow_mod(base, MOD - 2)));
        off += q * 7;
    }
}

/*
 * In-place forward NTT of length n: radix-8, decimation in frequency.
 *
 * Stages run len = n, n/8, ..., 8.  In each stage, every length-len block is
 * viewed as 8 strided sub-blocks of q = len/8 elements.  An 8-point DFT is
 * applied across them, then output k is multiplied by twiddle row r_k (written
 * by build_roots_forward(); the off/q layout must match).  len shrinks by 8 per
 * stage, so this is a complete transform only when n is a power of 8.  Outputs
 * are left in the butterfly's natural (digit-reversed) order; poly_multiply()
 * feeds them to ntt_inverse() without any reordering.
 *
 * Values are in Montgomery form.  Assumes inputs in [0, 2*MOD); outputs are
 * again in [0, 2*MOD).  The vector path reduces lazily: intermediates may reach
 * 16*MOD (< 2^31), which reduce16_to2_vec()/mont_mul_vec() bring back down.
 */
static void ntt_forward(uint32_t *a, int n) {
    int off = 0;
    /* zeta = G^((MOD-1)/8), an 8th root of unity, in Montgomery form;
     * imag = zeta^2 and zeta3 = zeta^3 are the other 8-point DFT constants. */
    uint32_t z8_plain = pow_mod(G, (MOD - 1) / 8);
    const uint32_t zeta = to_mont(z8_plain);
    const uint32_t imag = mont_mul(zeta, zeta);
    const uint32_t zeta3 = mont_mul(imag, zeta);
    const __m256i vz = _mm256_set1_epi32((int)zeta);
    const __m256i vi = _mm256_set1_epi32((int)imag);
    const __m256i vz3 = _mm256_set1_epi32((int)zeta3);

    for (int len = n; len >= 8; len >>= 3) {
        int q = len >> 3;
        /* Twiddle rows for butterfly outputs 1..7 of this stage. */
        uint32_t *r1 = roots + off;
        uint32_t *r2 = r1 + q;
        uint32_t *r3 = r2 + q;
        uint32_t *r4 = r3 + q;
        uint32_t *r5 = r4 + q;
        uint32_t *r6 = r5 + q;
        uint32_t *r7 = r6 + q;
        for (int i = 0; i < n; i += len) {
            uint32_t *p = a + i;
            int j = 0;
            /* Vector path: 8 butterflies per iteration.  It uses aligned
             * loads/stores, so it assumes a is 32-byte aligned and q is a
             * multiple of 8. */
            for (; j + 8 <= q; j += 8) {
                __m256i a0 = _mm256_load_si256((const __m256i *)(p + j));
                __m256i a1 = _mm256_load_si256((const __m256i *)(p + j + q));
                __m256i a2 = _mm256_load_si256((const __m256i *)(p + j + q + q));
                __m256i a3 = _mm256_load_si256((const __m256i *)(p + j + q + q + q));
                __m256i a4 = _mm256_load_si256((const __m256i *)(p + j + q * 4));
                __m256i a5 = _mm256_load_si256((const __m256i *)(p + j + q * 5));
                __m256i a6 = _mm256_load_si256((const __m256i *)(p + j + q * 6));
                __m256i a7 = _mm256_load_si256((const __m256i *)(p + j + q * 7));

                /* Even inputs: 4-point DFT of a0, a2, a4, a6 (imag is a 4th root).
                 * With inputs < 2*MOD: E0, E2, E3 < 8*MOD and E1 < 6*MOD. */
                __m256i e04 = add_raw(a0, a4);
                __m256i f04 = sub2_raw(a0, a4);
                __m256i e26 = add_raw(a2, a6);
                __m256i f26 = mont_mul_vec(sub2_raw(a2, a6), vi);
                __m256i E0 = add_raw(e04, e26);
                __m256i E1 = add_raw(f04, f26);
                __m256i E2 = sub4_raw(e04, e26);
                __m256i E3 = sub4_raw(f04, f26);

                /* Odd inputs: 4-point DFT of a1, a3, a5, a7, output k scaled by
                 * zeta^k.  O0 < 8*MOD; O1..O3 leave mont_mul_vec below 2*MOD. */
                __m256i o15 = add_raw(a1, a5);
                __m256i g15 = sub2_raw(a1, a5);
                __m256i o37 = add_raw(a3, a7);
                __m256i g37 = mont_mul_vec(sub2_raw(a3, a7), vi);
                __m256i O0 = add_raw(o15, o37);
                __m256i O1 = mont_mul_vec(add_raw(g15, g37), vz);
                __m256i O2 = mont_mul_vec(sub4_raw(o15, o37), vi);
                __m256i O3 = mont_mul_vec(sub4_raw(g15, g37), vz3);

                __m256i w1 = _mm256_loadu_si256((const __m256i *)(r1 + j));
                __m256i w2 = _mm256_loadu_si256((const __m256i *)(r2 + j));
                __m256i w3 = _mm256_loadu_si256((const __m256i *)(r3 + j));
                __m256i w4 = _mm256_loadu_si256((const __m256i *)(r4 + j));
                __m256i w5 = _mm256_loadu_si256((const __m256i *)(r5 + j));
                __m256i w6 = _mm256_loadu_si256((const __m256i *)(r6 + j));
                __m256i w7 = _mm256_loadu_si256((const __m256i *)(r7 + j));

                /* X_k = E_k + O_k and X_{k+4} = E_k - O_k (zeta^4 = -1); every
                 * sum is below 16*MOD.  X_0 has no twiddle and is folded back by
                 * reduce16_to2_vec; the twiddle multiply reduces the others. */
                _mm256_store_si256((__m256i *)(p + j), reduce16_to2_vec(add_raw(E0, O0)));
                _mm256_store_si256((__m256i *)(p + j + q), mont_mul_vec(add_raw(E1, O1), w1));
                _mm256_store_si256((__m256i *)(p + j + q * 2), mont_mul_vec(add_raw(E2, O2), w2));
                _mm256_store_si256((__m256i *)(p + j + q * 3), mont_mul_vec(add_raw(E3, O3), w3));
                _mm256_store_si256((__m256i *)(p + j + q * 4), mont_mul_vec(sub8_raw(E0, O0), w4));
                _mm256_store_si256((__m256i *)(p + j + q * 5), mont_mul_vec(sub8_raw(E1, O1), w5));
                _mm256_store_si256((__m256i *)(p + j + q * 6), mont_mul_vec(sub8_raw(E2, O2), w6));
                _mm256_store_si256((__m256i *)(p + j + q * 7), mont_mul_vec(sub8_raw(E3, O3), w7));
            }
            /* Scalar tail for q < 8 (only q == 1 when n is a power of 8); same
             * butterfly, but every step is reduced back to [0, 2*MOD). */
            for (; j < q; ++j) {
                uint32_t a0 = p[j], a1 = p[j + q], a2 = p[j + q * 2], a3 = p[j + q * 3];
                uint32_t a4 = p[j + q * 4], a5 = p[j + q * 5], a6 = p[j + q * 6], a7 = p[j + q * 7];
                uint32_t e04 = add_mod(a0, a4), f04 = sub_mod(a0, a4);
                uint32_t e26 = add_mod(a2, a6), f26 = mont_mul(sub_mod(a2, a6), imag);
                uint32_t E0 = add_mod(e04, e26), E1 = add_mod(f04, f26);
                uint32_t E2 = sub_mod(e04, e26), E3 = sub_mod(f04, f26);
                uint32_t o15 = add_mod(a1, a5), g15 = sub_mod(a1, a5);
                uint32_t o37 = add_mod(a3, a7), g37 = mont_mul(sub_mod(a3, a7), imag);
                uint32_t O0 = add_mod(o15, o37);
                uint32_t O1 = mont_mul(add_mod(g15, g37), zeta);
                uint32_t O2 = mont_mul(sub_mod(o15, o37), imag);
                uint32_t O3 = mont_mul(sub_mod(g15, g37), zeta3);
                p[j] = add_mod(E0, O0);
                p[j + q] = mont_mul(add_mod(E1, O1), r1[j]);
                p[j + q * 2] = mont_mul(add_mod(E2, O2), r2[j]);
                p[j + q * 3] = mont_mul(add_mod(E3, O3), r3[j]);
                p[j + q * 4] = mont_mul(sub_mod(E0, O0), r4[j]);
                p[j + q * 5] = mont_mul(sub_mod(E1, O1), r5[j]);
                p[j + q * 6] = mont_mul(sub_mod(E2, O2), r6[j]);
                p[j + q * 7] = mont_mul(sub_mod(E3, O3), r7[j]);
            }
        }
        off += q * 7;
    }
}

/*
 * In-place inverse NTT of length n, radix-8, decimation in time.  It does NOT
 * apply the final 1/n scaling; poly_multiply() does that afterwards.
 *
 * Stages run len = 8, 64, ..., n, mirroring ntt_forward().  Each stage first
 * multiplies input k of every butterfly by twiddle row r_k (written by
 * build_roots_inverse(); the off/q layout must match).  It then applies an
 * 8-point DFT using the inverse 8th root of unity.  It takes ntt_forward()'s
 * output order unchanged and, like the forward transform, is complete only
 * when n is a power of 8.
 *
 * Values are in Montgomery form.  Assumes inputs in [0, 2*MOD); add_vec/sub_vec
 * and the Montgomery multiplies keep every value in [0, 2*MOD).
 */
static void ntt_inverse(uint32_t *a, int n) {
    int off = 0;
    /* zeta = (G^((MOD-1)/8))^-1 in Montgomery form; imag = zeta^2, zeta3 = zeta^3. */
    uint32_t z8_plain = pow_mod(G, (MOD - 1) / 8);
    uint32_t inv_z8_plain = pow_mod(z8_plain, MOD - 2);
    const uint32_t zeta = to_mont(inv_z8_plain);
    const uint32_t imag = mont_mul(zeta, zeta);
    const uint32_t zeta3 = mont_mul(imag, zeta);
    const __m256i vz = _mm256_set1_epi32((int)zeta);
    const __m256i vi = _mm256_set1_epi32((int)imag);
    const __m256i vz3 = _mm256_set1_epi32((int)zeta3);

    for (int len = 8; len <= n; len <<= 3) {
        int q = len >> 3;
        /* Twiddle rows for butterfly inputs 1..7 of this stage (input 0's twiddle is 1). */
        uint32_t *r1 = roots + off;
        uint32_t *r2 = r1 + q;
        uint32_t *r3 = r2 + q;
        uint32_t *r4 = r3 + q;
        uint32_t *r5 = r4 + q;
        uint32_t *r6 = r5 + q;
        uint32_t *r7 = r6 + q;
        for (int i = 0; i < n; i += len) {
            uint32_t *p = a + i;
            int j = 0;
            /* Vector path: 8 butterflies per iteration; aligned loads/stores
             * assume a is 32-byte aligned and q is a multiple of 8. */
            for (; j + 8 <= q; j += 8) {
                /* Twist inputs 1..7 before the butterfly (decimation in time). */
                __m256i a0 = _mm256_load_si256((const __m256i *)(p + j));
                __m256i a1 = mont_mul_vec(_mm256_load_si256((const __m256i *)(p + j + q)),
                                           _mm256_loadu_si256((const __m256i *)(r1 + j)));
                __m256i a2 = mont_mul_vec(_mm256_load_si256((const __m256i *)(p + j + q * 2)),
                                           _mm256_loadu_si256((const __m256i *)(r2 + j)));
                __m256i a3 = mont_mul_vec(_mm256_load_si256((const __m256i *)(p + j + q * 3)),
                                           _mm256_loadu_si256((const __m256i *)(r3 + j)));
                __m256i a4 = mont_mul_vec(_mm256_load_si256((const __m256i *)(p + j + q * 4)),
                                           _mm256_loadu_si256((const __m256i *)(r4 + j)));
                __m256i a5 = mont_mul_vec(_mm256_load_si256((const __m256i *)(p + j + q * 5)),
                                           _mm256_loadu_si256((const __m256i *)(r5 + j)));
                __m256i a6 = mont_mul_vec(_mm256_load_si256((const __m256i *)(p + j + q * 6)),
                                           _mm256_loadu_si256((const __m256i *)(r6 + j)));
                __m256i a7 = mont_mul_vec(_mm256_load_si256((const __m256i *)(p + j + q * 7)),
                                           _mm256_loadu_si256((const __m256i *)(r7 + j)));

                /* Even inputs: 4-point DFT of a0, a2, a4, a6. */
                __m256i e04 = add_vec(a0, a4);
                __m256i f04 = sub_vec(a0, a4);
                __m256i e26 = add_vec(a2, a6);
                __m256i f26 = mont_mul_vec(sub_vec(a2, a6), vi);
                __m256i E0 = add_vec(e04, e26);
                __m256i E1 = add_vec(f04, f26);
                __m256i E2 = sub_vec(e04, e26);
                __m256i E3 = sub_vec(f04, f26);

                /* Odd inputs: 4-point DFT of a1, a3, a5, a7, output k scaled by zeta^k. */
                __m256i o15 = add_vec(a1, a5);
                __m256i g15 = sub_vec(a1, a5);
                __m256i o37 = add_vec(a3, a7);
                __m256i g37 = mont_mul_vec(sub_vec(a3, a7), vi);
                __m256i O0 = add_vec(o15, o37);
                __m256i O1 = mont_mul_vec(add_vec(g15, g37), vz);
                __m256i O2 = mont_mul_vec(sub_vec(o15, o37), vi);
                __m256i O3 = mont_mul_vec(sub_vec(g15, g37), vz3);

                /* X_k = E_k + O_k and X_{k+4} = E_k - O_k (zeta^4 = -1). */
                _mm256_store_si256((__m256i *)(p + j), add_vec(E0, O0));
                _mm256_store_si256((__m256i *)(p + j + q), add_vec(E1, O1));
                _mm256_store_si256((__m256i *)(p + j + q * 2), add_vec(E2, O2));
                _mm256_store_si256((__m256i *)(p + j + q * 3), add_vec(E3, O3));
                _mm256_store_si256((__m256i *)(p + j + q * 4), sub_vec(E0, O0));
                _mm256_store_si256((__m256i *)(p + j + q * 5), sub_vec(E1, O1));
                _mm256_store_si256((__m256i *)(p + j + q * 6), sub_vec(E2, O2));
                _mm256_store_si256((__m256i *)(p + j + q * 7), sub_vec(E3, O3));
            }
            /* Scalar tail for q < 8 (only q == 1 when n is a power of 8). */
            for (; j < q; ++j) {
                uint32_t a0 = p[j];
                uint32_t a1 = mont_mul(p[j + q], r1[j]);
                uint32_t a2 = mont_mul(p[j + q * 2], r2[j]);
                uint32_t a3 = mont_mul(p[j + q * 3], r3[j]);
                uint32_t a4 = mont_mul(p[j + q * 4], r4[j]);
                uint32_t a5 = mont_mul(p[j + q * 5], r5[j]);
                uint32_t a6 = mont_mul(p[j + q * 6], r6[j]);
                uint32_t a7 = mont_mul(p[j + q * 7], r7[j]);
                uint32_t e04 = add_mod(a0, a4), f04 = sub_mod(a0, a4);
                uint32_t e26 = add_mod(a2, a6), f26 = mont_mul(sub_mod(a2, a6), imag);
                uint32_t E0 = add_mod(e04, e26), E1 = add_mod(f04, f26);
                uint32_t E2 = sub_mod(e04, e26), E3 = sub_mod(f04, f26);
                uint32_t o15 = add_mod(a1, a5), g15 = sub_mod(a1, a5);
                uint32_t o37 = add_mod(a3, a7), g37 = mont_mul(sub_mod(a3, a7), imag);
                uint32_t O0 = add_mod(o15, o37);
                uint32_t O1 = mont_mul(add_mod(g15, g37), zeta);
                uint32_t O2 = mont_mul(sub_mod(o15, o37), imag);
                uint32_t O3 = mont_mul(sub_mod(g15, g37), zeta3);
                p[j] = add_mod(E0, O0);
                p[j + q] = add_mod(E1, O1);
                p[j + q * 2] = add_mod(E2, O2);
                p[j + q * 3] = add_mod(E3, O3);
                p[j + q * 4] = sub_mod(E0, O0);
                p[j + q * 5] = sub_mod(E1, O1);
                p[j + q * 6] = sub_mod(E2, O2);
                p[j + q * 7] = sub_mod(E3, O3);
            }
        }
        off += q * 7;
    }
}

void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c) {
#ifdef LOCAL_PROFILE
    clock_t t0 = clock();
#endif
    int need = n + m + 1;
    int len = 1;
    while (len < need) len <<= 1;

    uint32_t *raw = (uint32_t *)malloc(((size_t)len * 2 + 8) * sizeof(uint32_t));
    uint32_t *fa = (uint32_t *)(((uintptr_t)raw + 31u) & ~(uintptr_t)31u);
    uint32_t *fb = fa + len;

    uint32_t digit[10];
    for (int i = 0; i < 10; ++i) digit[i] = to_mont((uint32_t)i);
    for (int i = 0; i <= n; ++i) fa[i] = digit[a[i]];
    for (int i = 0; i <= m; ++i) fb[i] = digit[b[i]];
    memset(fa + n + 1, 0, ((size_t)len - (size_t)n - 1) * sizeof(uint32_t));
    memset(fb + m + 1, 0, ((size_t)len - (size_t)m - 1) * sizeof(uint32_t));
#ifdef LOCAL_PROFILE
    clock_t t1 = clock();
#endif

    build_roots_forward(len);
#ifdef LOCAL_PROFILE
    clock_t t2 = clock();
#endif
    ntt_forward(fa, len);
#ifdef LOCAL_PROFILE
    clock_t t3 = clock();
#endif
    ntt_forward(fb, len);
#ifdef LOCAL_PROFILE
    clock_t t4 = clock();
#endif

    int i = 0;
    for (; i + 8 <= len; i += 8) {
        __m256i x = _mm256_load_si256((const __m256i *)(fa + i));
        __m256i y = _mm256_load_si256((const __m256i *)(fb + i));
        _mm256_store_si256((__m256i *)(fa + i), mont_mul_vec(x, y));
    }
    for (; i < len; ++i) fa[i] = mont_mul(fa[i], fb[i]);
#ifdef LOCAL_PROFILE
    clock_t t5 = clock();
#endif

    build_roots_inverse(len);
#ifdef LOCAL_PROFILE
    clock_t t6 = clock();
#endif
    ntt_inverse(fa, len);
#ifdef LOCAL_PROFILE
    clock_t t7 = clock();
#endif

    uint32_t inv_n = pow_mod((uint32_t)len, MOD - 2);
    __m256i vinv = _mm256_set1_epi32((int)inv_n);
    i = 0;
    for (; i + 8 <= need; i += 8) {
        __m256i x = _mm256_load_si256((const __m256i *)(fa + i));
        _mm256_store_si256((__m256i *)(c + i), reduce_mod_vec(mont_mul_vec(x, vinv)));
    }
    for (; i < need; ++i) {
        uint32_t v = mont_mul(fa[i], inv_n);
        c[i] = v >= MOD ? v - MOD : v;
    }
#ifdef LOCAL_PROFILE
    clock_t t8 = clock();
    fprintf(stderr,
            "fill %.3f rootsF %.3f fwdA %.3f fwdB %.3f mul %.3f rootsI %.3f inv %.3f out %.3f total %.3f\n",
            1000.0 * (double)(t1 - t0) / CLOCKS_PER_SEC,
            1000.0 * (double)(t2 - t1) / CLOCKS_PER_SEC,
            1000.0 * (double)(t3 - t2) / CLOCKS_PER_SEC,
            1000.0 * (double)(t4 - t3) / CLOCKS_PER_SEC,
            1000.0 * (double)(t5 - t4) / CLOCKS_PER_SEC,
            1000.0 * (double)(t6 - t5) / CLOCKS_PER_SEC,
            1000.0 * (double)(t7 - t6) / CLOCKS_PER_SEC,
            1000.0 * (double)(t8 - t7) / CLOCKS_PER_SEC,
            1000.0 * (double)(t8 - t0) / CLOCKS_PER_SEC);
#endif

    free(raw);
}

#ifdef LOCAL_TEST
#include <stdio.h>
static unsigned aa[1 << 20], bb[1 << 20], cc[1 << 21], dd[4096];
/*
 * Brute-force cross-check of poly_multiply().  Small degrees (0..39) make the
 * padded transform length 1, 8, 64 or 512, so several transform sizes are
 * exercised.  Results are written to cc + 1 so the output pointer is
 * (typically) not 32-byte aligned, like an arbitrary caller buffer.
 */
int main() {
    unsigned *out = cc + 1;
    for (int n = 0; n < 40; ++n) {
        for (int m = 0; m < 40; ++m) {
            for (int i = 0; i <= n; ++i) aa[i] = (unsigned)((i * 7 + n) % 10);
            for (int i = 0; i <= m; ++i) bb[i] = (unsigned)((i * 5 + m) % 10);
            poly_multiply(aa, n, bb, m, out);
            for (int i = 0; i <= n + m; ++i) dd[i] = 0;
            for (int i = 0; i <= n; ++i)
                for (int j = 0; j <= m; ++j)
                    dd[i + j] += aa[i] * bb[j];
            for (int i = 0; i <= n + m; ++i) {
                if (out[i] != dd[i]) {
                    printf("bad n=%d m=%d i=%d got=%u want=%u\n", n, m, i, out[i], dd[i]);
                    return 1;
                }
            }
        }
    }
    puts("ok");
    return 0;
}
#endif

#ifdef LOCAL_BENCH
#include <stdio.h>
#include <time.h>
static unsigned aa[1000001], bb[1000001], cc[2000001];
/* Times a single degree-10^6 by degree-10^6 product and prints a sparse checksum of the result. */
int main() {
    const int deg = 1000000;
    for (int i = 0; i <= deg; ++i) {
        aa[i] = (unsigned)((i * 7 + 3) % 10);
        bb[i] = (unsigned)((i * 5 + 1) % 10);
    }
    const clock_t start = clock();
    poly_multiply(aa, deg, bb, deg, cc);
    const clock_t stop = clock();
    unsigned long long sample = 0;
    for (int i = 0; i <= 2 * deg; i += 137) {
        sample += cc[i];
    }
    const double elapsed_ms = 1000.0 * (double)(stop - start) / CLOCKS_PER_SEC;
    printf("%.3f ms sample=%llu\n", elapsed_ms, sample);
    return 0;
}
#endif

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 54.378 ms | 24 MB + 12 KB | Runtime Error | Score: 0


Judge Duck Online | 评测鸭在线
Server Time: 2026-07-01 06:42:33 | Loaded in 2 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠