#pragma GCC target("avx2")
#include <immintrin.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#ifdef LOCAL_PROFILE
#include <stdio.h>
#include <time.h>
#endif
// Modulus of the transform field: 81788929 = 39 * 2^21 + 1. The code inverts
// elements with the exponent MOD - 2 (see build_roots_inverse), i.e. it treats
// MOD as prime.
static const uint32_t MOD = 81788929u;
// 2 * MOD: upper bound that add_mod/sub_mod/add_vec/sub_vec fold results below.
static const uint32_t TWO_MOD = 163577858u;
// 4 * MOD: offset added by sub4_raw before subtracting.
static const uint32_t FOUR_MOD = 327155716u;
// 8 * MOD: offset added by sub8_raw before subtracting.
static const uint32_t EIGHT_MOD = 654311432u;
// Base whose powers G^((MOD - 1) / len) are used as roots of unity when the
// twiddle tables are built (presumably a primitive root modulo MOD).
static const uint32_t G = 7u;
// Montgomery constant: MOD * QINV == -1 (mod 2^32), so x + ((uint32_t)x * QINV) * MOD
// has its low 32 bits cleared in mont_reduce.
static const uint32_t QINV = 81788927u;
// 2^64 mod MOD; a Montgomery multiplication by R2 converts a value to Montgomery form.
static const uint32_t R2 = 56088131u;
// 2^32 mod MOD, i.e. the Montgomery representation of 1.
static const uint32_t R = 41942988u;
// Adds two residues and subtracts TWO_MOD once if the sum reached it.
// For inputs below TWO_MOD the result is again below TWO_MOD.
static inline uint32_t add_mod(uint32_t a, uint32_t b) {
    uint32_t sum = a + b;
    if (sum >= TWO_MOD) {
        sum -= TWO_MOD;
    }
    return sum;
}
// Computes a - b with a TWO_MOD offset so the intermediate never goes negative
// for b <= a + TWO_MOD, then folds the result back below TWO_MOD once.
static inline uint32_t sub_mod(uint32_t a, uint32_t b) {
    uint32_t diff = a + TWO_MOD - b;
    if (diff >= TWO_MOD) {
        diff -= TWO_MOD;
    }
    return diff;
}
// Montgomery reduction of a 64-bit value: picks k so that x + k * MOD is a
// multiple of 2^32, then returns that sum divided by 2^32.
// The result is not folded below MOD here.
static inline uint32_t mont_reduce(uint64_t x) {
    const uint32_t low_word = (uint32_t)x;
    const uint32_t k = low_word * QINV;
    const uint64_t cleared = x + (uint64_t)k * MOD;
    return (uint32_t)(cleared >> 32);
}
// Montgomery product: widens a * b to 64 bits and reduces it with mont_reduce.
static inline uint32_t mont_mul(uint32_t a, uint32_t b) {
    const uint64_t wide_product = (uint64_t)a * b;
    return mont_reduce(wide_product);
}
// Converts x to Montgomery form by Montgomery-multiplying it with R2.
static inline uint32_t to_mont(uint32_t x) {
    const uint32_t converted = mont_mul(x, R2);
    return converted;
}
// Lane-wise a + b over eight 32-bit lanes, followed by one conditional
// subtraction of TWO_MOD in every lane whose sum exceeds TWO_MOD - 1.
// The comparison is the signed 32-bit one provided by _mm256_cmpgt_epi32.
static inline __m256i add_vec(__m256i a, __m256i b) {
    const __m256i bound = _mm256_set1_epi32((int)TWO_MOD);
    const __m256i last_ok = _mm256_set1_epi32((int)(TWO_MOD - 1));
    const __m256i sum = _mm256_add_epi32(a, b);
    const __m256i overflowed = _mm256_cmpgt_epi32(sum, last_ok);
    const __m256i correction = _mm256_and_si256(overflowed, bound);
    return _mm256_sub_epi32(sum, correction);
}
// Lane-wise (a + TWO_MOD) - b over eight 32-bit lanes, followed by one
// conditional subtraction of TWO_MOD in every lane that exceeds TWO_MOD - 1
// (signed 32-bit comparison).
static inline __m256i sub_vec(__m256i a, __m256i b) {
    const __m256i bound = _mm256_set1_epi32((int)TWO_MOD);
    const __m256i last_ok = _mm256_set1_epi32((int)(TWO_MOD - 1));
    const __m256i lifted = _mm256_add_epi32(a, bound);
    const __m256i diff = _mm256_sub_epi32(lifted, b);
    const __m256i overflowed = _mm256_cmpgt_epi32(diff, last_ok);
    const __m256i correction = _mm256_and_si256(overflowed, bound);
    return _mm256_sub_epi32(diff, correction);
}
// Plain lane-wise 32-bit addition; no modular folding is applied.
static inline __m256i add_raw(__m256i a, __m256i b) {
    const __m256i sum = _mm256_add_epi32(a, b);
    return sum;
}
// Lane-wise (a + 2 * MOD) - b with no folding afterwards.
static inline __m256i sub2_raw(__m256i a, __m256i b) {
    const __m256i offset = _mm256_set1_epi32((int)TWO_MOD);
    const __m256i lifted = _mm256_add_epi32(a, offset);
    return _mm256_sub_epi32(lifted, b);
}
// Lane-wise (a + 4 * MOD) - b with no folding afterwards.
static inline __m256i sub4_raw(__m256i a, __m256i b) {
    const __m256i offset = _mm256_set1_epi32((int)FOUR_MOD);
    const __m256i lifted = _mm256_add_epi32(a, offset);
    return _mm256_sub_epi32(lifted, b);
}
// Lane-wise (a + 8 * MOD) - b with no folding afterwards.
static inline __m256i sub8_raw(__m256i a, __m256i b) {
    const __m256i offset = _mm256_set1_epi32((int)EIGHT_MOD);
    const __m256i lifted = _mm256_add_epi32(a, offset);
    return _mm256_sub_epi32(lifted, b);
}
// Three conditional subtractions per lane: 8 * MOD, then 4 * MOD, then 2 * MOD,
// each applied only where the lane is still at or above that amount (signed
// 32-bit compare). A lane in [0, 16 * MOD) therefore ends up in [0, 2 * MOD).
static inline __m256i reduce16_to2_vec(__m256i x) {
    const __m256i step8 = _mm256_set1_epi32((int)EIGHT_MOD);
    const __m256i step4 = _mm256_set1_epi32((int)FOUR_MOD);
    const __m256i step2 = _mm256_set1_epi32((int)TWO_MOD);

    __m256i over = _mm256_cmpgt_epi32(x, _mm256_set1_epi32((int)(EIGHT_MOD - 1)));
    x = _mm256_sub_epi32(x, _mm256_and_si256(over, step8));

    over = _mm256_cmpgt_epi32(x, _mm256_set1_epi32((int)(FOUR_MOD - 1)));
    x = _mm256_sub_epi32(x, _mm256_and_si256(over, step4));

    over = _mm256_cmpgt_epi32(x, _mm256_set1_epi32((int)(TWO_MOD - 1)));
    x = _mm256_sub_epi32(x, _mm256_and_si256(over, step2));
    return x;
}
// Eight-lane Montgomery product, the vector counterpart of mont_mul.
// _mm256_mul_epu32 only multiplies the even 32-bit lanes, so the odd lanes are
// shifted down into even position, processed the same way, and shifted back.
// Results are not folded below MOD.
static inline __m256i mont_mul_vec(__m256i a, __m256i b) {
    const __m256i modulus = _mm256_set1_epi32((int)MOD);
    const __m256i neg_inv = _mm256_set1_epi32((int)QINV);

    // Per-lane reduction factor: (low 32 bits of a * b) * QINV.
    const __m256i low_prod = _mm256_mullo_epi32(a, b);
    const __m256i k = _mm256_mullo_epi32(low_prod, neg_inv);

    // Even lanes: (a * b + k * MOD) >> 32, upper half of each 64-bit slot is zero.
    const __m256i even_prod = _mm256_mul_epu32(a, b);
    const __m256i even_fix = _mm256_mul_epu32(k, modulus);
    const __m256i even_res = _mm256_srli_epi64(_mm256_add_epi64(even_prod, even_fix), 32);

    // Odd lanes, moved into the even slots first.
    const __m256i a_odd = _mm256_srli_epi64(a, 32);
    const __m256i b_odd = _mm256_srli_epi64(b, 32);
    const __m256i k_odd = _mm256_srli_epi64(k, 32);
    const __m256i odd_prod = _mm256_mul_epu32(a_odd, b_odd);
    const __m256i odd_fix = _mm256_mul_epu32(k_odd, modulus);
    const __m256i odd_res = _mm256_srli_epi64(_mm256_add_epi64(odd_prod, odd_fix), 32);

    // Interleave: even results stay low, odd results go back to the high halves.
    return _mm256_or_si256(even_res, _mm256_slli_epi64(odd_res, 32));
}
// Subtracts MOD once from every lane that exceeds MOD - 1 (signed 32-bit
// compare); lanes in [0, 2 * MOD) end up in [0, MOD).
static inline __m256i reduce_mod_vec(__m256i a) {
    const __m256i modulus = _mm256_set1_epi32((int)MOD);
    const __m256i last_ok = _mm256_set1_epi32((int)(MOD - 1));
    const __m256i too_big = _mm256_cmpgt_epi32(a, last_ok);
    const __m256i correction = _mm256_and_si256(too_big, modulus);
    return _mm256_sub_epi32(a, correction);
}
// Applies four rounds of "subtract MOD where the lane exceeds MOD - 1"
// (signed 32-bit compare). A lane in [0, 5 * MOD) ends up in [0, MOD).
static inline __m256i reduce_small_mont_vec(__m256i x) {
    const __m256i modulus = _mm256_set1_epi32((int)MOD);
    const __m256i last_ok = _mm256_set1_epi32((int)(MOD - 1));
    for (int round = 0; round < 4; ++round) {
        const __m256i too_big = _mm256_cmpgt_epi32(x, last_ok);
        x = _mm256_sub_epi32(x, _mm256_and_si256(too_big, modulus));
    }
    return x;
}
// Writes src[i] * R mod MOD (the Montgomery form of src[i]) to dst[i] for
// i in [0, count).
//
// Full groups of eight use a 32-bit lane multiply followed by at most four
// subtractions of MOD, so that path is only exact while src[i] * R < 5 * MOD.
// NOTE(review): presumably src holds decimal digits 0..9 - confirm with callers.
// The vector path stores with the aligned intrinsic, so dst is expected to be
// 32-byte aligned; src is read with unaligned loads.
// The remaining (count % 8) entries use an exact 64-bit multiply and %.
static inline void fill_mont_digits(uint32_t *dst, const unsigned *src, int count) {
    const __m256i r_lanes = _mm256_set1_epi32((int)R);
    int pos = 0;

    while (pos + 8 <= count) {
        const __m256i digits = _mm256_loadu_si256((const __m256i *)(src + pos));
        const __m256i scaled = _mm256_mullo_epi32(digits, r_lanes);
        _mm256_store_si256((__m256i *)(dst + pos), reduce_small_mont_vec(scaled));
        pos += 8;
    }

    while (pos < count) {
        const uint64_t wide = (uint64_t)src[pos] * R;
        dst[pos] = (uint32_t)(wide % MOD);
        ++pos;
    }
}
// Returns a^e mod MOD by binary exponentiation on plain (non-Montgomery)
// residues; each product is widened to 64 bits before the % MOD.
static uint32_t pow_mod(uint32_t a, uint32_t e) {
    uint32_t acc = 1;
    for (; e != 0; e >>= 1) {
        if ((e & 1u) != 0) {
            acc = (uint32_t)((uint64_t)acc * a % MOD);
        }
        a = (uint32_t)((uint64_t)a * a % MOD);
    }
    return acc;
}
// Twiddle-factor table (2^21 entries) filled by build_roots_forward or
// build_roots_inverse. Each radix-8 level with quarter-size q = len / 8 occupies
// 7 * q consecutive entries: seven sub-tables of q Montgomery-form values.
static uint32_t roots[1 << 21];
// Builds the vector {step^0, step^1, ..., step^7} in Montgomery form, where
// lane 0 is R (Montgomery one) and each next lane is the previous times step.
static inline __m256i make_wvec(uint32_t step) {
    uint32_t powers[8];
    powers[0] = R;
    int lane = 1;
    while (lane < 8) {
        powers[lane] = mont_mul(powers[lane - 1], step);
        ++lane;
    }
    return _mm256_loadu_si256((const __m256i *)powers);
}
// Returns step^8 in Montgomery form: starts from R (Montgomery one) and
// Montgomery-multiplies by step eight times.
static inline uint32_t step_pow8(uint32_t step) {
    uint32_t acc = R;
    int remaining = 8;
    while (remaining-- > 0) {
        acc = mont_mul(acc, step);
    }
    return acc;
}
// Fills `roots` for the forward transform, largest level first:
// len = n, n / 8, n / 64, ... while len >= 8.
//
// For a level with q = len / 8, seven sub-tables of q entries are written at
// roots[off + k * q + i] (k = 0..6) holding w^((k + 1) * i) in Montgomery form,
// where w = G^((MOD - 1) / len). When q < 8 only the first entry of each
// sub-table is written, with R (Montgomery one).
// NOTE(review): the q < 8 branch is only complete for q == 1, i.e. n is
// presumably a power of 8 - confirm with callers.
static void build_roots_forward(int n) {
    int off = 0;
    int len = n;
    while (len >= 8) {
        const int q = len >> 3;
        const uint32_t step = to_mont(pow_mod(G, (MOD - 1) / (uint32_t)len));
        uint32_t *const level = roots + off;

        if (q < 8) {
            for (int k = 0; k < 7; ++k) {
                level[q * k] = R;
            }
        } else {
            // p1 holds w^i for eight consecutive i; advancing multiplies by w^8.
            __m256i p1 = make_wvec(step);
            const __m256i advance = _mm256_set1_epi32((int)step_pow8(step));
            for (int i = 0; i < q; i += 8) {
                const __m256i p2 = mont_mul_vec(p1, p1);
                const __m256i p3 = mont_mul_vec(p2, p1);
                const __m256i p4 = mont_mul_vec(p2, p2);
                const __m256i p5 = mont_mul_vec(p4, p1);
                const __m256i p6 = mont_mul_vec(p4, p2);
                const __m256i p7 = mont_mul_vec(p4, p3);

                uint32_t *const dst = level + i;
                _mm256_storeu_si256((__m256i *)(dst), p1);
                _mm256_storeu_si256((__m256i *)(dst + q), p2);
                _mm256_storeu_si256((__m256i *)(dst + q * 2), p3);
                _mm256_storeu_si256((__m256i *)(dst + q * 3), p4);
                _mm256_storeu_si256((__m256i *)(dst + q * 4), p5);
                _mm256_storeu_si256((__m256i *)(dst + q * 5), p6);
                _mm256_storeu_si256((__m256i *)(dst + q * 6), p7);

                p1 = mont_mul_vec(p1, advance);
            }
        }

        off += q * 7;
        len >>= 3;
    }
}
// Fills `roots` for the inverse transform, smallest level first:
// len = 8, 64, 512, ... while len <= n.
//
// The per-level layout matches build_roots_forward (seven sub-tables of
// q = len / 8 entries), but the generator is the modular inverse of
// G^((MOD - 1) / len), obtained by raising it to the power MOD - 2.
// When q < 8 only the first entry of each sub-table is written, with R.
static void build_roots_inverse(int n) {
    int off = 0;
    for (int len = 8; len <= n; len <<= 3) {
        const int q = len >> 3;
        const uint32_t fwd_root = pow_mod(G, (MOD - 1) / (uint32_t)len);
        const uint32_t step = to_mont(pow_mod(fwd_root, MOD - 2));
        uint32_t *const table = roots + off;
        off += q * 7;

        if (q < 8) {
            for (int k = 0; k < 7; ++k) {
                table[q * k] = R;
            }
            continue;
        }

        // cur holds step^i for eight consecutive i; jump8 advances i by eight.
        __m256i cur = make_wvec(step);
        const __m256i jump8 = _mm256_set1_epi32((int)step_pow8(step));
        for (int i = 0; i < q; i += 8) {
            const __m256i sq = mont_mul_vec(cur, cur);
            const __m256i cube = mont_mul_vec(sq, cur);
            const __m256i quad = mont_mul_vec(sq, sq);
            const __m256i pow5 = mont_mul_vec(quad, cur);
            const __m256i pow6 = mont_mul_vec(quad, sq);
            const __m256i pow7 = mont_mul_vec(quad, cube);

            _mm256_storeu_si256((__m256i *)(table + i), cur);
            _mm256_storeu_si256((__m256i *)(table + q + i), sq);
            _mm256_storeu_si256((__m256i *)(table + q * 2 + i), cube);
            _mm256_storeu_si256((__m256i *)(table + q * 3 + i), quad);
            _mm256_storeu_si256((__m256i *)(table + q * 4 + i), pow5);
            _mm256_storeu_si256((__m256i *)(table + q * 5 + i), pow6);
            _mm256_storeu_si256((__m256i *)(table + q * 6 + i), pow7);

            cur = mont_mul_vec(cur, jump8);
        }
    }
}
// In-place transpose of an 8x8 matrix of 32-bit values held as eight row
// vectors: afterwards lane c of r<k> holds what lane k of r<c> held before.
static inline void transpose8_epi32(__m256i &r0, __m256i &r1, __m256i &r2, __m256i &r3,
                                    __m256i &r4, __m256i &r5, __m256i &r6, __m256i &r7) {
    // Stage 1: interleave 32-bit elements of adjacent row pairs.
    const __m256i lo01 = _mm256_unpacklo_epi32(r0, r1);
    const __m256i hi01 = _mm256_unpackhi_epi32(r0, r1);
    const __m256i lo23 = _mm256_unpacklo_epi32(r2, r3);
    const __m256i hi23 = _mm256_unpackhi_epi32(r2, r3);
    const __m256i lo45 = _mm256_unpacklo_epi32(r4, r5);
    const __m256i hi45 = _mm256_unpackhi_epi32(r4, r5);
    const __m256i lo67 = _mm256_unpacklo_epi32(r6, r7);
    const __m256i hi67 = _mm256_unpackhi_epi32(r6, r7);

    // Stage 2: interleave 64-bit pairs, giving four-row column fragments.
    const __m256i top_c0 = _mm256_unpacklo_epi64(lo01, lo23);
    const __m256i top_c1 = _mm256_unpackhi_epi64(lo01, lo23);
    const __m256i top_c2 = _mm256_unpacklo_epi64(hi01, hi23);
    const __m256i top_c3 = _mm256_unpackhi_epi64(hi01, hi23);
    const __m256i bot_c0 = _mm256_unpacklo_epi64(lo45, lo67);
    const __m256i bot_c1 = _mm256_unpackhi_epi64(lo45, lo67);
    const __m256i bot_c2 = _mm256_unpacklo_epi64(hi45, hi67);
    const __m256i bot_c3 = _mm256_unpackhi_epi64(hi45, hi67);

    // Stage 3: join matching 128-bit halves (0x20 = both low, 0x31 = both high).
    r0 = _mm256_permute2x128_si256(top_c0, bot_c0, 0x20);
    r4 = _mm256_permute2x128_si256(top_c0, bot_c0, 0x31);
    r1 = _mm256_permute2x128_si256(top_c1, bot_c1, 0x20);
    r5 = _mm256_permute2x128_si256(top_c1, bot_c1, 0x31);
    r2 = _mm256_permute2x128_si256(top_c2, bot_c2, 0x20);
    r6 = _mm256_permute2x128_si256(top_c2, bot_c2, 0x31);
    r3 = _mm256_permute2x128_si256(top_c3, bot_c3, 0x20);
    r7 = _mm256_permute2x128_si256(top_c3, bot_c3, 0x31);
}
// Runs the 8-point butterfly on every group of eight consecutive elements of
// `a`, 64 elements (eight groups) per iteration while base < n.
//
// The eight aligned row vectors of a block are transposed so that lane g of
// x<k> is element k of group g, the butterfly is applied lane-wise, and the
// result is transposed back and stored. `a + base` must be 32-byte aligned
// because the aligned load/store intrinsics are used.
// vz, vi and vz3 are the constants multiplied into the butterfly terms
// (presumably Montgomery-form powers of an 8th root of unity - set by the
// caller, not visible here).
static inline void ntt_len8_forward(uint32_t *a, int n, __m256i vz, __m256i vi, __m256i vz3) {
    for (int base = 0; base < n; base += 64) {
        __m256i *const rows = (__m256i *)(a + base);

        __m256i x0 = _mm256_load_si256(rows + 0);
        __m256i x1 = _mm256_load_si256(rows + 1);
        __m256i x2 = _mm256_load_si256(rows + 2);
        __m256i x3 = _mm256_load_si256(rows + 3);
        __m256i x4 = _mm256_load_si256(rows + 4);
        __m256i x5 = _mm256_load_si256(rows + 5);
        __m256i x6 = _mm256_load_si256(rows + 6);
        __m256i x7 = _mm256_load_si256(rows + 7);
        transpose8_epi32(x0, x1, x2, x3, x4, x5, x6, x7);

        // Even-indexed inputs. The subN_raw helpers add N * MOD before
        // subtracting and nothing is folded until the final reduce16_to2_vec.
        const __m256i sum04 = add_raw(x0, x4);
        const __m256i dif04 = sub2_raw(x0, x4);
        const __m256i sum26 = add_raw(x2, x6);
        const __m256i rot26 = mont_mul_vec(sub2_raw(x2, x6), vi);
        const __m256i ev0 = add_raw(sum04, sum26);
        const __m256i ev1 = add_raw(dif04, rot26);
        const __m256i ev2 = sub4_raw(sum04, sum26);
        const __m256i ev3 = sub4_raw(dif04, rot26);

        // Odd-indexed inputs, with the extra vz / vi / vz3 factors.
        const __m256i sum15 = add_raw(x1, x5);
        const __m256i dif15 = sub2_raw(x1, x5);
        const __m256i sum37 = add_raw(x3, x7);
        const __m256i rot37 = mont_mul_vec(sub2_raw(x3, x7), vi);
        const __m256i od0 = add_raw(sum15, sum37);
        const __m256i od1 = mont_mul_vec(add_raw(dif15, rot37), vz);
        const __m256i od2 = mont_mul_vec(sub4_raw(sum15, sum37), vi);
        const __m256i od3 = mont_mul_vec(sub4_raw(dif15, rot37), vz3);

        // Combine halves: outputs k and k + 4 are ev_k + od_k and ev_k - od_k.
        x0 = reduce16_to2_vec(add_raw(ev0, od0));
        x4 = reduce16_to2_vec(sub8_raw(ev0, od0));
        x1 = reduce16_to2_vec(add_raw(ev1, od1));
        x5 = reduce16_to2_vec(sub8_raw(ev1, od1));
        x2 = reduce16_to2_vec(add_raw(ev2, od2));
        x6 = reduce16_to2_vec(sub8_raw(ev2, od2));
        x3 = reduce16_to2_vec(add_raw(ev3, od3));
        x7 = reduce16_to2_vec(sub8_raw(ev3, od3));

        transpose8_epi32(x0, x1, x2, x3, x4, x5, x6, x7);
        _mm256_store_si256(rows + 0, x0);
        _mm256_store_si256(rows + 1, x1);
        _mm256_store_si256(rows + 2, x2);
        _mm256_store_si256(rows + 3, x3);
        _mm256_store_si256(rows + 4, x4);
        _mm256_store_si256(rows + 5, x5);
        _mm256_store_si256(rows + 6, x6);
        _mm256_store_si256(rows + 7, x7);
    }
}
// Lane-wise 8-point butterfly on a0..a7, overwriting them with the results.
//
// Intermediate sums and differences are left unfolded (the subN_raw helpers
// add N * MOD before subtracting); every output goes through reduce16_to2_vec.
// vz, vi and vz3 are the constants multiplied into the butterfly terms
// (presumably Montgomery-form powers of an 8th root of unity - TODO confirm
// against the caller).
static inline void dft8_reduce_vec(__m256i &a0, __m256i &a1, __m256i &a2, __m256i &a3,
                                   __m256i &a4, __m256i &a5, __m256i &a6, __m256i &a7,
                                   __m256i vz, __m256i vi, __m256i vz3) {
    // Even-indexed inputs.
    const __m256i sum04 = add_raw(a0, a4);
    const __m256i dif04 = sub2_raw(a0, a4);
    const __m256i sum26 = add_raw(a2, a6);
    const __m256i rot26 = mont_mul_vec(sub2_raw(a2, a6), vi);
    const __m256i ev0 = add_raw(sum04, sum26);
    const __m256i ev1 = add_raw(dif04, rot26);
    const __m256i ev2 = sub4_raw(sum04, sum26);
    const __m256i ev3 = sub4_raw(dif04, rot26);

    // Odd-indexed inputs, with the extra vz / vi / vz3 factors.
    const __m256i sum15 = add_raw(a1, a5);
    const __m256i dif15 = sub2_raw(a1, a5);
    const __m256i sum37 = add_raw(a3, a7);
    const __m256i rot37 = mont_mul_vec(sub2_raw(a3, a7), vi);
    const __m256i od0 = add_raw(sum15, sum37);
    const __m256i od1 = mont_mul_vec(add_raw(dif15, rot37), vz);
    const __m256i od2 = mont_mul_vec(sub4_raw(sum15, sum37), vi);
    const __m256i od3 = mont_mul_vec(sub4_raw(dif15, rot37), vz3);

    // All inputs have been consumed above, so the references can be overwritten.
    a0 = reduce16_to2_vec(add_raw(ev0, od0));
    a4 = reduce16_to2_vec(sub8_raw(ev0, od0));
    a1 = reduce16_to2_vec(add_raw(ev1, od1));
    a5 = reduce16_to2_vec(sub8_raw(ev1, od1));
    a2 = reduce16_to2_vec(add_raw(ev2, od2));
    a6 = reduce16_to2_vec(sub8_raw(ev2, od2));
    a3 = reduce16_to2_vec(add_raw(ev3, od3));
    a7 = reduce16_to2_vec(sub8_raw(ev3, od3));
}
// Forward pass over `a` in blocks of 64 elements (base advances by 64 while
// base < n). Per block: an 8-point butterfly across the eight 8-lane rows,
// multiplication of rows 1..7 by the twiddle vectors at rbase, a transpose,
// a second butterfly (dft8_reduce_vec), a transpose back, and the store.
//
// a      data, read and written in place; accessed with aligned loads/stores,
//        so a + base must be 32-byte aligned.
// n      number of elements to process; the loop works in whole 64-blocks.
// rbase  seven consecutive 8-entry twiddle vectors (read with unaligned loads).
// vz, vi, vz3  constants multiplied into the butterfly terms (presumably
//        Montgomery-form powers of an 8th root of unity - TODO confirm).
static inline void ntt_len64_forward(uint32_t *a, int n, uint32_t *rbase,
__m256i vz, __m256i vi, __m256i vz3) {
// Twiddles are loop-invariant: load them once.
const __m256i w1 = _mm256_loadu_si256((const __m256i *)(rbase));
const __m256i w2 = _mm256_loadu_si256((const __m256i *)(rbase + 8));
const __m256i w3 = _mm256_loadu_si256((const __m256i *)(rbase + 16));
const __m256i w4 = _mm256_loadu_si256((const __m256i *)(rbase + 24));
const __m256i w5 = _mm256_loadu_si256((const __m256i *)(rbase + 32));
const __m256i w6 = _mm256_loadu_si256((const __m256i *)(rbase + 40));
const __m256i w7 = _mm256_loadu_si256((const __m256i *)(rbase + 48));
for (int base = 0; base < n; base += 64) {
__m256i a0 = _mm256_load_si256((const __m256i *)(a + base));
__m256i a1 = _mm256_load_si256((const __m256i *)(a + base + 8));
__m256i a2 = _mm256_load_si256((const __m256i *)(a + base + 16));
__m256i a3 = _mm256_load_si256((const __m256i *)(a + base + 24));
__m256i a4 = _mm256_load_si256((const __m256i *)(a + base + 32));
__m256i a5 = _mm256_load_si256((const __m256i *)(a + base + 40));
__m256i a6 = _mm256_load_si256((const __m256i *)(a + base + 48));
__m256i a7 = _mm256_load_si256((const __m256i *)(a + base + 56));
// First butterfly: rows are combined lane-wise (stride 8 between inputs).
// Sums/differences stay unfolded; subN_raw adds N * MOD before subtracting.
__m256i e04 = add_raw(a0, a4);
__m256i f04 = sub2_raw(a0, a4);
__m256i e26 = add_raw(a2, a6);
__m256i f26 = mont_mul_vec(sub2_raw(a2, a6), vi);
__m256i E0 = add_raw(e04, e26);
__m256i E1 = add_raw(f04, f26);
__m256i E2 = sub4_raw(e04, e26);
__m256i E3 = sub4_raw(f04, f26);
__m256i o15 = add_raw(a1, a5);
__m256i g15 = sub2_raw(a1, a5);
__m256i o37 = add_raw(a3, a7);
__m256i g37 = mont_mul_vec(sub2_raw(a3, a7), vi);
__m256i O0 = add_raw(o15, o37);
__m256i O1 = mont_mul_vec(add_raw(g15, g37), vz);
__m256i O2 = mont_mul_vec(sub4_raw(o15, o37), vi);
__m256i O3 = mont_mul_vec(sub4_raw(g15, g37), vz3);
// Row 0 takes no twiddle, so it is folded explicitly; rows 1..7 are brought
// back into range by the Montgomery multiplication with their twiddle.
a0 = reduce16_to2_vec(add_raw(E0, O0));
a1 = mont_mul_vec(add_raw(E1, O1), w1);
a2 = mont_mul_vec(add_raw(E2, O2), w2);
a3 = mont_mul_vec(add_raw(E3, O3), w3);
a4 = mont_mul_vec(sub8_raw(E0, O0), w4);
a5 = mont_mul_vec(sub8_raw(E1, O1), w5);
a6 = mont_mul_vec(sub8_raw(E2, O2), w6);
a7 = mont_mul_vec(sub8_raw(E3, O3), w7);
// Second butterfly works on eight consecutive elements: transpose so each
// lane holds one such group, transform, then transpose back.
transpose8_epi32(a0, a1, a2, a3, a4, a5, a6, a7);
dft8_reduce_vec(a0, a1, a2, a3, a4, a5, a6, a7, vz, vi, vz3);
transpose8_epi32(a0, a1, a2, a3, a4, a5, a6, a7);
_mm256_store_si256((__m256i *)(a + base), a0);
_mm256_store_si256((__m256i *)(a + base + 8), a1);
_mm256_store_si256((__m256i *)(a + base + 16), a2);
_mm256_store_si256((__m256i *)(a + base + 24), a3);
_mm256_store_si256((__m256i *)(a + base + 32), a4);
_mm256_store_si256((__m256i *)(a + base + 40), a5);
_mm256_store_si256((__m256i *)(a + base + 48), a6);
_mm256_store_si256((__m256i *)(a + base + 56), a7);
}
}
// Inverse-direction pass over `a` in blocks of 64 elements (base advances by
// 64 while base < n). Per block: optional pointwise Montgomery multiplication
// by `mul`, a transposed 8-point butterfly on groups of eight consecutive
// elements, multiplication of rows 1..7 by the twiddle vectors at rbase, and a
// second butterfly across the rows using the folding add_vec/sub_vec helpers.
//
// a      data, read and written in place; aligned loads/stores are used, so
//        a + base must be 32-byte aligned.
// mul    if non-null, a parallel array (same alignment requirement) whose
//        elements are Montgomery-multiplied into `a` before transforming.
// n      number of elements to process; the loop works in whole 64-blocks.
// rbase  seven consecutive 8-entry twiddle vectors (read with unaligned loads).
// vz, vi, vz3  constants multiplied into the butterfly terms (presumably
//        Montgomery-form powers of an inverse 8th root of unity - TODO confirm).
static inline void ntt_len64_inverse(uint32_t *a, uint32_t *mul, int n, uint32_t *rbase,
__m256i vz, __m256i vi, __m256i vz3) {
// Twiddles are loop-invariant: load them once.
const __m256i w1 = _mm256_loadu_si256((const __m256i *)(rbase));
const __m256i w2 = _mm256_loadu_si256((const __m256i *)(rbase + 8));
const __m256i w3 = _mm256_loadu_si256((const __m256i *)(rbase + 16));
const __m256i w4 = _mm256_loadu_si256((const __m256i *)(rbase + 24));
const __m256i w5 = _mm256_loadu_si256((const __m256i *)(rbase + 32));
const __m256i w6 = _mm256_loadu_si256((const __m256i *)(rbase + 40));
const __m256i w7 = _mm256_loadu_si256((const __m256i *)(rbase + 48));
for (int base = 0; base < n; base += 64) {
__m256i a0 = _mm256_load_si256((const __m256i *)(a + base));
__m256i a1 = _mm256_load_si256((const __m256i *)(a + base + 8));
__m256i a2 = _mm256_load_si256((const __m256i *)(a + base + 16));
__m256i a3 = _mm256_load_si256((const __m256i *)(a + base + 24));
__m256i a4 = _mm256_load_si256((const __m256i *)(a + base + 32));
__m256i a5 = _mm256_load_si256((const __m256i *)(a + base + 40));
__m256i a6 = _mm256_load_si256((const __m256i *)(a + base + 48));
__m256i a7 = _mm256_load_si256((const __m256i *)(a + base + 56));
// Fused pointwise product with the second operand, when one is supplied.
if (mul) {
a0 = mont_mul_vec(a0, _mm256_load_si256((const __m256i *)(mul + base)));
a1 = mont_mul_vec(a1, _mm256_load_si256((const __m256i *)(mul + base + 8)));
a2 = mont_mul_vec(a2, _mm256_load_si256((const __m256i *)(mul + base + 16)));
a3 = mont_mul_vec(a3, _mm256_load_si256((const __m256i *)(mul + base + 24)));
a4 = mont_mul_vec(a4, _mm256_load_si256((const __m256i *)(mul + base + 32)));
a5 = mont_mul_vec(a5, _mm256_load_si256((const __m256i *)(mul + base + 40)));
a6 = mont_mul_vec(a6, _mm256_load_si256((const __m256i *)(mul + base + 48)));
a7 = mont_mul_vec(a7, _mm256_load_si256((const __m256i *)(mul + base + 56)));
}
// First butterfly on groups of eight consecutive elements (via transpose).
transpose8_epi32(a0, a1, a2, a3, a4, a5, a6, a7);
dft8_reduce_vec(a0, a1, a2, a3, a4, a5, a6, a7, vz, vi, vz3);
transpose8_epi32(a0, a1, a2, a3, a4, a5, a6, a7);
// Twiddle rows 1..7; row 0 is left as is.
a1 = mont_mul_vec(a1, w1);
a2 = mont_mul_vec(a2, w2);
a3 = mont_mul_vec(a3, w3);
a4 = mont_mul_vec(a4, w4);
a5 = mont_mul_vec(a5, w5);
a6 = mont_mul_vec(a6, w6);
a7 = mont_mul_vec(a7, w7);
// Second butterfly across the rows. add_vec/sub_vec fold after every step,
// so no final reduce16_to2_vec is needed here.
__m256i e04 = add_vec(a0, a4);
__m256i f04 = sub_vec(a0, a4);
__m256i e26 = add_vec(a2, a6);
__m256i f26 = mont_mul_vec(sub_vec(a2, a6), vi);
__m256i E0 = add_vec(e04, e26);
__m256i E1 = add_vec(f04, f26);
__m256i E2 = sub_vec(e04, e26);
__m256i E3 = sub_vec(f04, f26);
__m256i o15 = add_vec(a1, a5);
__m256i g15 = sub_vec(a1, a5);
__m256i o37 = add_vec(a3, a7);
__m256i g37 = mont_mul_vec(sub_vec(a3, a7), vi);
__m256i O0 = add_vec(o15, o37);
__m256i O1 = mont_mul_vec(add_vec(g15, g37), vz);
__m256i O2 = mont_mul_vec(sub_vec(o15, o37), vi);
__m256i O3 = mont_mul_vec(sub_vec(g15, g37), vz3);
_mm256_store_si256((__m256i *)(a + base), add_vec(E0, O0));
_mm256_store_si256((__m256i *)(a + base + 8), add_vec(E1, O1));
_mm256_store_si256((__m256i *)(a + base + 16), add_vec(E2, O2));
_mm256_store_si256((__m256i *)(a + base + 24), add_vec(E3, O3));
_mm256_store_si256((__m256i *)(a + base + 32), sub_vec(E0, O0));
_mm256_store_si256((__m256i *)(a + base + 40), sub_vec(E1, O1));
_mm256_store_si256((__m256i *)(a + base + 48), sub_vec(E2, O2));
_mm256_store_si256((__m256i *)(a + base + 56), sub_vec(E3, O3));
}
}
// Forward pass over `a` in blocks of 512 elements (base advances by 512 while
// base < n). Per block: one radix-8 butterfly layer combining elements 64
// apart, with outputs 1..7 multiplied by twiddles from r512, followed by
// ntt_len64_forward on each of the eight 64-element sub-blocks.
//
// a      data, transformed in place; aligned loads/stores are used.
// n      number of elements to process; the loop works in whole 512-blocks.
// r512   seven sub-tables of 64 twiddles each (indexed r512 + k * 64 + j).
// r64    twiddle table handed on to ntt_len64_forward.
// vz, vi, vz3  constants multiplied into the butterfly terms (presumably
//        Montgomery-form powers of an 8th root of unity - TODO confirm).
static inline void ntt_len512_forward(uint32_t *a, int n, uint32_t *r512, uint32_t *r64,
__m256i vz, __m256i vi, __m256i vz3) {
for (int base = 0; base < n; base += 512) {
// j walks the first 64 positions eight lanes at a time; input k of the
// butterfly is at offset j + 64 * k.
for (int j = 0; j < 64; j += 8) {
__m256i a0 = _mm256_load_si256((const __m256i *)(a + base + j));
__m256i a1 = _mm256_load_si256((const __m256i *)(a + base + j + 64));
__m256i a2 = _mm256_load_si256((const __m256i *)(a + base + j + 128));
__m256i a3 = _mm256_load_si256((const __m256i *)(a + base + j + 192));
__m256i a4 = _mm256_load_si256((const __m256i *)(a + base + j + 256));
__m256i a5 = _mm256_load_si256((const __m256i *)(a + base + j + 320));
__m256i a6 = _mm256_load_si256((const __m256i *)(a + base + j + 384));
__m256i a7 = _mm256_load_si256((const __m256i *)(a + base + j + 448));
// Unfolded butterfly: subN_raw adds N * MOD before subtracting.
__m256i e04 = add_raw(a0, a4);
__m256i f04 = sub2_raw(a0, a4);
__m256i e26 = add_raw(a2, a6);
__m256i f26 = mont_mul_vec(sub2_raw(a2, a6), vi);
__m256i E0 = add_raw(e04, e26);
__m256i E1 = add_raw(f04, f26);
__m256i E2 = sub4_raw(e04, e26);
__m256i E3 = sub4_raw(f04, f26);
__m256i o15 = add_raw(a1, a5);
__m256i g15 = sub2_raw(a1, a5);
__m256i o37 = add_raw(a3, a7);
__m256i g37 = mont_mul_vec(sub2_raw(a3, a7), vi);
__m256i O0 = add_raw(o15, o37);
__m256i O1 = mont_mul_vec(add_raw(g15, g37), vz);
__m256i O2 = mont_mul_vec(sub4_raw(o15, o37), vi);
__m256i O3 = mont_mul_vec(sub4_raw(g15, g37), vz3);
// Position-dependent twiddles for outputs 1..7.
__m256i w1 = _mm256_loadu_si256((const __m256i *)(r512 + j));
__m256i w2 = _mm256_loadu_si256((const __m256i *)(r512 + 64 + j));
__m256i w3 = _mm256_loadu_si256((const __m256i *)(r512 + 128 + j));
__m256i w4 = _mm256_loadu_si256((const __m256i *)(r512 + 192 + j));
__m256i w5 = _mm256_loadu_si256((const __m256i *)(r512 + 256 + j));
__m256i w6 = _mm256_loadu_si256((const __m256i *)(r512 + 320 + j));
__m256i w7 = _mm256_loadu_si256((const __m256i *)(r512 + 384 + j));
// Output 0 has no twiddle and is folded explicitly; the others are reduced
// by their Montgomery multiplication.
_mm256_store_si256((__m256i *)(a + base + j), reduce16_to2_vec(add_raw(E0, O0)));
_mm256_store_si256((__m256i *)(a + base + 64 + j), mont_mul_vec(add_raw(E1, O1), w1));
_mm256_store_si256((__m256i *)(a + base + 128 + j), mont_mul_vec(add_raw(E2, O2), w2));
_mm256_store_si256((__m256i *)(a + base + 192 + j), mont_mul_vec(add_raw(E3, O3), w3));
_mm256_store_si256((__m256i *)(a + base + 256 + j), mont_mul_vec(sub8_raw(E0, O0), w4));
_mm256_store_si256((__m256i *)(a + base + 320 + j), mont_mul_vec(sub8_raw(E1, O1), w5));
_mm256_store_si256((__m256i *)(a + base + 384 + j), mont_mul_vec(sub8_raw(E2, O2), w6));
_mm256_store_si256((__m256i *)(a + base + 448 + j), mont_mul_vec(sub8_raw(E3, O3), w7));
}
// Finish each 64-element sub-block independently.
for (int t = 0; t < 8; ++t) ntt_len64_forward(a + base + t * 64, 64, r64, vz, vi, vz3);
}
}
// Inverse-direction pass over `a` in blocks of 512 elements (base advances by
// 512 while base < n). Per block: ntt_len64_inverse on each of the eight
// 64-element sub-blocks (forwarding the matching slice of `mul`, if any),
// then one radix-8 butterfly layer combining elements 64 apart, with inputs
// 1..7 first multiplied by twiddles from r512.
//
// a      data, transformed in place; aligned loads/stores are used.
// mul    optional pointwise multiplier array, passed through to
//        ntt_len64_inverse; null disables it.
// n      number of elements to process; the loop works in whole 512-blocks.
// r64    twiddle table handed on to ntt_len64_inverse.
// r512   seven sub-tables of 64 twiddles each (indexed r512 + k * 64 + j).
// vz, vi, vz3  constants multiplied into the butterfly terms (presumably
//        Montgomery-form powers of an inverse 8th root of unity - TODO confirm).
static inline void ntt_len512_inverse(uint32_t *a, uint32_t *mul, int n, uint32_t *r64, uint32_t *r512,
__m256i vz, __m256i vi, __m256i vz3) {
for (int base = 0; base < n; base += 512) {
// Sub-transforms first; a null `mul` stays null for every sub-block.
for (int t = 0; t < 8; ++t) {
ntt_len64_inverse(a + base + t * 64, mul ? mul + base + t * 64 : 0, 64, r64, vz, vi, vz3);
}
for (int j = 0; j < 64; j += 8) {
// Input 0 is used as is; inputs 1..7 are twiddled while being loaded.
__m256i a0 = _mm256_load_si256((const __m256i *)(a + base + j));
__m256i a1 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 64 + j)),
_mm256_loadu_si256((const __m256i *)(r512 + j)));
__m256i a2 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 128 + j)),
_mm256_loadu_si256((const __m256i *)(r512 + 64 + j)));
__m256i a3 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 192 + j)),
_mm256_loadu_si256((const __m256i *)(r512 + 128 + j)));
__m256i a4 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 256 + j)),
_mm256_loadu_si256((const __m256i *)(r512 + 192 + j)));
__m256i a5 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 320 + j)),
_mm256_loadu_si256((const __m256i *)(r512 + 256 + j)));
__m256i a6 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 384 + j)),
_mm256_loadu_si256((const __m256i *)(r512 + 320 + j)));
__m256i a7 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 448 + j)),
_mm256_loadu_si256((const __m256i *)(r512 + 384 + j)));
// Butterfly with folding after every add/sub (add_vec/sub_vec).
__m256i e04 = add_vec(a0, a4);
__m256i f04 = sub_vec(a0, a4);
__m256i e26 = add_vec(a2, a6);
__m256i f26 = mont_mul_vec(sub_vec(a2, a6), vi);
__m256i E0 = add_vec(e04, e26);
__m256i E1 = add_vec(f04, f26);
__m256i E2 = sub_vec(e04, e26);
__m256i E3 = sub_vec(f04, f26);
__m256i o15 = add_vec(a1, a5);
__m256i g15 = sub_vec(a1, a5);
__m256i o37 = add_vec(a3, a7);
__m256i g37 = mont_mul_vec(sub_vec(a3, a7), vi);
__m256i O0 = add_vec(o15, o37);
__m256i O1 = mont_mul_vec(add_vec(g15, g37), vz);
__m256i O2 = mont_mul_vec(sub_vec(o15, o37), vi);
__m256i O3 = mont_mul_vec(sub_vec(g15, g37), vz3);
_mm256_store_si256((__m256i *)(a + base + j), add_vec(E0, O0));
_mm256_store_si256((__m256i *)(a + base + j + 64), add_vec(E1, O1));
_mm256_store_si256((__m256i *)(a + base + j + 128), add_vec(E2, O2));
_mm256_store_si256((__m256i *)(a + base + j + 192), add_vec(E3, O3));
_mm256_store_si256((__m256i *)(a + base + j + 256), sub_vec(E0, O0));
_mm256_store_si256((__m256i *)(a + base + j + 320), sub_vec(E1, O1));
_mm256_store_si256((__m256i *)(a + base + j + 384), sub_vec(E2, O2));
_mm256_store_si256((__m256i *)(a + base + j + 448), sub_vec(E3, O3));
}
}
}
// Forward pass over `a` in blocks of 4096 elements (base advances by 4096
// while base < n). Per block: one radix-8 butterfly layer combining elements
// 512 apart, with outputs 1..7 multiplied by twiddles from r4096, followed by
// ntt_len512_forward on each of the eight 512-element sub-blocks.
//
// a      data, transformed in place; aligned loads/stores are used.
// n      number of elements to process; the loop works in whole 4096-blocks.
// r4096  seven sub-tables of 512 twiddles each (indexed r4096 + k * 512 + j).
// r512   twiddles of the 512-level; the 64-level table is taken to start
//        448 (= 7 * 64) entries after it, matching the contiguous
//        largest-level-first layout written by build_roots_forward.
// vz, vi, vz3  constants multiplied into the butterfly terms (presumably
//        Montgomery-form powers of an 8th root of unity - TODO confirm).
static inline void ntt_len4096_forward(uint32_t *a, int n, uint32_t *r4096, uint32_t *r512,
__m256i vz, __m256i vi, __m256i vz3) {
for (int base = 0; base < n; base += 4096) {
// j walks the first 512 positions eight lanes at a time; input k of the
// butterfly is at offset j + 512 * k.
for (int j = 0; j < 512; j += 8) {
__m256i a0 = _mm256_load_si256((const __m256i *)(a + base + j));
__m256i a1 = _mm256_load_si256((const __m256i *)(a + base + j + 512));
__m256i a2 = _mm256_load_si256((const __m256i *)(a + base + j + 1024));
__m256i a3 = _mm256_load_si256((const __m256i *)(a + base + j + 1536));
__m256i a4 = _mm256_load_si256((const __m256i *)(a + base + j + 2048));
__m256i a5 = _mm256_load_si256((const __m256i *)(a + base + j + 2560));
__m256i a6 = _mm256_load_si256((const __m256i *)(a + base + j + 3072));
__m256i a7 = _mm256_load_si256((const __m256i *)(a + base + j + 3584));
// Unfolded butterfly: subN_raw adds N * MOD before subtracting.
__m256i e04 = add_raw(a0, a4);
__m256i f04 = sub2_raw(a0, a4);
__m256i e26 = add_raw(a2, a6);
__m256i f26 = mont_mul_vec(sub2_raw(a2, a6), vi);
__m256i E0 = add_raw(e04, e26);
__m256i E1 = add_raw(f04, f26);
__m256i E2 = sub4_raw(e04, e26);
__m256i E3 = sub4_raw(f04, f26);
__m256i o15 = add_raw(a1, a5);
__m256i g15 = sub2_raw(a1, a5);
__m256i o37 = add_raw(a3, a7);
__m256i g37 = mont_mul_vec(sub2_raw(a3, a7), vi);
__m256i O0 = add_raw(o15, o37);
__m256i O1 = mont_mul_vec(add_raw(g15, g37), vz);
__m256i O2 = mont_mul_vec(sub4_raw(o15, o37), vi);
__m256i O3 = mont_mul_vec(sub4_raw(g15, g37), vz3);
// Position-dependent twiddles for outputs 1..7.
__m256i w1 = _mm256_loadu_si256((const __m256i *)(r4096 + j));
__m256i w2 = _mm256_loadu_si256((const __m256i *)(r4096 + 512 + j));
__m256i w3 = _mm256_loadu_si256((const __m256i *)(r4096 + 1024 + j));
__m256i w4 = _mm256_loadu_si256((const __m256i *)(r4096 + 1536 + j));
__m256i w5 = _mm256_loadu_si256((const __m256i *)(r4096 + 2048 + j));
__m256i w6 = _mm256_loadu_si256((const __m256i *)(r4096 + 2560 + j));
__m256i w7 = _mm256_loadu_si256((const __m256i *)(r4096 + 3072 + j));
// Output 0 has no twiddle and is folded explicitly; the others are reduced
// by their Montgomery multiplication.
_mm256_store_si256((__m256i *)(a + base + j), reduce16_to2_vec(add_raw(E0, O0)));
_mm256_store_si256((__m256i *)(a + base + 512 + j), mont_mul_vec(add_raw(E1, O1), w1));
_mm256_store_si256((__m256i *)(a + base + 1024 + j), mont_mul_vec(add_raw(E2, O2), w2));
_mm256_store_si256((__m256i *)(a + base + 1536 + j), mont_mul_vec(add_raw(E3, O3), w3));
_mm256_store_si256((__m256i *)(a + base + 2048 + j), mont_mul_vec(sub8_raw(E0, O0), w4));
_mm256_store_si256((__m256i *)(a + base + 2560 + j), mont_mul_vec(sub8_raw(E1, O1), w5));
_mm256_store_si256((__m256i *)(a + base + 3072 + j), mont_mul_vec(sub8_raw(E2, O2), w6));
_mm256_store_si256((__m256i *)(a + base + 3584 + j), mont_mul_vec(sub8_raw(E3, O3), w7));
}
// Finish each 512-element sub-block; its 64-level table follows r512.
for (int t = 0; t < 8; ++t) ntt_len512_forward(a + base + t * 512, 512, r512, r512 + 448, vz, vi, vz3);
}
}
// Inverse-direction pass over `a` in blocks of 4096 elements (base advances
// by 4096 while base < n). Per block: ntt_len512_inverse on each of the eight
// 512-element sub-blocks (forwarding the matching slice of `mul`, if any),
// then one radix-8 butterfly layer combining elements 512 apart, with inputs
// 1..7 first multiplied by twiddles from r4096.
//
// a      data, transformed in place; aligned loads/stores are used.
// mul    optional pointwise multiplier array, passed through to
//        ntt_len512_inverse; null disables it.
// n      number of elements to process; the loop works in whole 4096-blocks.
// r512   twiddles of the 512-level; the 64-level table is taken to start
//        56 (= 7 * 8) entries before it, matching the contiguous
//        smallest-level-first layout written by build_roots_inverse.
// r4096  seven sub-tables of 512 twiddles each (indexed r4096 + k * 512 + j).
// vz, vi, vz3  constants multiplied into the butterfly terms (presumably
//        Montgomery-form powers of an inverse 8th root of unity - TODO confirm).
static inline void ntt_len4096_inverse(uint32_t *a, uint32_t *mul, int n, uint32_t *r512, uint32_t *r4096,
__m256i vz, __m256i vi, __m256i vz3) {
for (int base = 0; base < n; base += 4096) {
// Sub-transforms first; a null `mul` stays null for every sub-block.
for (int t = 0; t < 8; ++t) {
ntt_len512_inverse(a + base + t * 512, mul ? mul + base + t * 512 : 0, 512, r512 - 56, r512, vz, vi, vz3);
}
for (int j = 0; j < 512; j += 8) {
// Input 0 is used as is; inputs 1..7 are twiddled while being loaded.
__m256i a0 = _mm256_load_si256((const __m256i *)(a + base + j));
__m256i a1 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 512 + j)),
_mm256_loadu_si256((const __m256i *)(r4096 + j)));
__m256i a2 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 1024 + j)),
_mm256_loadu_si256((const __m256i *)(r4096 + 512 + j)));
__m256i a3 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 1536 + j)),
_mm256_loadu_si256((const __m256i *)(r4096 + 1024 + j)));
__m256i a4 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 2048 + j)),
_mm256_loadu_si256((const __m256i *)(r4096 + 1536 + j)));
__m256i a5 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 2560 + j)),
_mm256_loadu_si256((const __m256i *)(r4096 + 2048 + j)));
__m256i a6 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 3072 + j)),
_mm256_loadu_si256((const __m256i *)(r4096 + 2560 + j)));
__m256i a7 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 3584 + j)),
_mm256_loadu_si256((const __m256i *)(r4096 + 3072 + j)));
// Butterfly with folding after every add/sub (add_vec/sub_vec).
__m256i e04 = add_vec(a0, a4);
__m256i f04 = sub_vec(a0, a4);
__m256i e26 = add_vec(a2, a6);
__m256i f26 = mont_mul_vec(sub_vec(a2, a6), vi);
__m256i E0 = add_vec(e04, e26);
__m256i E1 = add_vec(f04, f26);
__m256i E2 = sub_vec(e04, e26);
__m256i E3 = sub_vec(f04, f26);
__m256i o15 = add_vec(a1, a5);
__m256i g15 = sub_vec(a1, a5);
__m256i o37 = add_vec(a3, a7);
__m256i g37 = mont_mul_vec(sub_vec(a3, a7), vi);
__m256i O0 = add_vec(o15, o37);
__m256i O1 = mont_mul_vec(add_vec(g15, g37), vz);
__m256i O2 = mont_mul_vec(sub_vec(o15, o37), vi);
__m256i O3 = mont_mul_vec(sub_vec(g15, g37), vz3);
_mm256_store_si256((__m256i *)(a + base + j), add_vec(E0, O0));
_mm256_store_si256((__m256i *)(a + base + j + 512), add_vec(E1, O1));
_mm256_store_si256((__m256i *)(a + base + j + 1024), add_vec(E2, O2));
_mm256_store_si256((__m256i *)(a + base + j + 1536), add_vec(E3, O3));
_mm256_store_si256((__m256i *)(a + base + j + 2048), sub_vec(E0, O0));
_mm256_store_si256((__m256i *)(a + base + j + 2560), sub_vec(E1, O1));
_mm256_store_si256((__m256i *)(a + base + j + 3072), sub_vec(E2, O2));
_mm256_store_si256((__m256i *)(a + base + j + 3584), sub_vec(E3, O3));
}
}
}
// Forward pass over `a` in blocks of 32768 elements (base advances by 32768
// while base < n). Per block: one radix-8 butterfly layer combining elements
// 4096 apart, with outputs 1..7 multiplied by twiddles from r32768, followed
// by ntt_len4096_forward on each of the eight 4096-element sub-blocks.
//
// a       data, transformed in place; aligned loads/stores are used.
// n       number of elements to process; the loop works in whole 32768-blocks.
// r32768  seven sub-tables of 4096 twiddles each (indexed r32768 + k * 4096 + j).
// r4096   twiddles of the 4096-level; the 512-level table is taken to start
//         3584 (= 7 * 512) entries after it, matching the contiguous
//         largest-level-first layout written by build_roots_forward.
// vz, vi, vz3  constants multiplied into the butterfly terms (presumably
//         Montgomery-form powers of an 8th root of unity - TODO confirm).
static inline void ntt_len32768_forward(uint32_t *a, int n, uint32_t *r32768, uint32_t *r4096,
__m256i vz, __m256i vi, __m256i vz3) {
for (int base = 0; base < n; base += 32768) {
// j walks the first 4096 positions eight lanes at a time; input k of the
// butterfly is at offset j + 4096 * k.
for (int j = 0; j < 4096; j += 8) {
__m256i a0 = _mm256_load_si256((const __m256i *)(a + base + j));
__m256i a1 = _mm256_load_si256((const __m256i *)(a + base + j + 4096));
__m256i a2 = _mm256_load_si256((const __m256i *)(a + base + j + 8192));
__m256i a3 = _mm256_load_si256((const __m256i *)(a + base + j + 12288));
__m256i a4 = _mm256_load_si256((const __m256i *)(a + base + j + 16384));
__m256i a5 = _mm256_load_si256((const __m256i *)(a + base + j + 20480));
__m256i a6 = _mm256_load_si256((const __m256i *)(a + base + j + 24576));
__m256i a7 = _mm256_load_si256((const __m256i *)(a + base + j + 28672));
// Unfolded butterfly: subN_raw adds N * MOD before subtracting.
__m256i e04 = add_raw(a0, a4);
__m256i f04 = sub2_raw(a0, a4);
__m256i e26 = add_raw(a2, a6);
__m256i f26 = mont_mul_vec(sub2_raw(a2, a6), vi);
__m256i E0 = add_raw(e04, e26);
__m256i E1 = add_raw(f04, f26);
__m256i E2 = sub4_raw(e04, e26);
__m256i E3 = sub4_raw(f04, f26);
__m256i o15 = add_raw(a1, a5);
__m256i g15 = sub2_raw(a1, a5);
__m256i o37 = add_raw(a3, a7);
__m256i g37 = mont_mul_vec(sub2_raw(a3, a7), vi);
__m256i O0 = add_raw(o15, o37);
__m256i O1 = mont_mul_vec(add_raw(g15, g37), vz);
__m256i O2 = mont_mul_vec(sub4_raw(o15, o37), vi);
__m256i O3 = mont_mul_vec(sub4_raw(g15, g37), vz3);
// Position-dependent twiddles for outputs 1..7.
__m256i w1 = _mm256_loadu_si256((const __m256i *)(r32768 + j));
__m256i w2 = _mm256_loadu_si256((const __m256i *)(r32768 + 4096 + j));
__m256i w3 = _mm256_loadu_si256((const __m256i *)(r32768 + 8192 + j));
__m256i w4 = _mm256_loadu_si256((const __m256i *)(r32768 + 12288 + j));
__m256i w5 = _mm256_loadu_si256((const __m256i *)(r32768 + 16384 + j));
__m256i w6 = _mm256_loadu_si256((const __m256i *)(r32768 + 20480 + j));
__m256i w7 = _mm256_loadu_si256((const __m256i *)(r32768 + 24576 + j));
// Output 0 has no twiddle and is folded explicitly; the others are reduced
// by their Montgomery multiplication.
_mm256_store_si256((__m256i *)(a + base + j), reduce16_to2_vec(add_raw(E0, O0)));
_mm256_store_si256((__m256i *)(a + base + 4096 + j), mont_mul_vec(add_raw(E1, O1), w1));
_mm256_store_si256((__m256i *)(a + base + 8192 + j), mont_mul_vec(add_raw(E2, O2), w2));
_mm256_store_si256((__m256i *)(a + base + 12288 + j), mont_mul_vec(add_raw(E3, O3), w3));
_mm256_store_si256((__m256i *)(a + base + 16384 + j), mont_mul_vec(sub8_raw(E0, O0), w4));
_mm256_store_si256((__m256i *)(a + base + 20480 + j), mont_mul_vec(sub8_raw(E1, O1), w5));
_mm256_store_si256((__m256i *)(a + base + 24576 + j), mont_mul_vec(sub8_raw(E2, O2), w6));
_mm256_store_si256((__m256i *)(a + base + 28672 + j), mont_mul_vec(sub8_raw(E3, O3), w7));
}
// Finish each 4096-element sub-block; its 512-level table follows r4096.
for (int t = 0; t < 8; ++t) {
ntt_len4096_forward(a + base + t * 4096, 4096, r4096, r4096 + 3584, vz, vi, vz3);
}
}
}
/*
 * Inverse-direction pass over every 32768-word block of a[0..n).
 *
 * Each block is first handed to ntt_len4096_inverse as eight independent
 * 4096-word sub-blocks, together with the matching slice of mul (or a null
 * pointer when mul itself is null).  The eight sub-blocks are then recombined
 * column by column: rows 1..7 are multiplied by the twiddles stored in
 * r32768 (seven consecutive tables of 4096 words), an 8-point butterfly built
 * from add_vec/sub_vec and the broadcast constants vi, vz, vz3 is applied,
 * and the results are written back in place.
 *
 * a is accessed with aligned 256-bit loads/stores and therefore must be
 * 32-byte aligned; r32768 is read with unaligned loads.  r512 and r4096 are
 * only forwarded to ntt_len4096_inverse.
 */
static inline void ntt_len32768_inverse(uint32_t *a, uint32_t *mul, int n, uint32_t *r512, uint32_t *r4096, uint32_t *r32768,
                                        __m256i vz, __m256i vi, __m256i vz3) {
    enum { SUB_LEN = 4096, BLOCK_LEN = 32768 };

    for (int base = 0; base < n; base += BLOCK_LEN) {
        uint32_t *blk = a + base;

        /* Stage 1: transform the eight sub-blocks independently. */
        for (int t = 0; t < 8; ++t) {
            uint32_t *sub_mul = mul ? mul + base + t * SUB_LEN : 0;
            ntt_len4096_inverse(blk + t * SUB_LEN, sub_mul, SUB_LEN, r512, r4096, vz, vi, vz3);
        }

        /* Stage 2: twiddle rows 1..7, then combine the rows with a radix-8 butterfly. */
        for (int j = 0; j < SUB_LEN; j += 8) {
            uint32_t *col = blk + j;
            const uint32_t *tw = r32768 + j;

            __m256i x0 = _mm256_load_si256((const __m256i *)col);
            __m256i x1 = mont_mul_vec(_mm256_load_si256((const __m256i *)(col + 1 * SUB_LEN)),
                                      _mm256_loadu_si256((const __m256i *)(tw)));
            __m256i x2 = mont_mul_vec(_mm256_load_si256((const __m256i *)(col + 2 * SUB_LEN)),
                                      _mm256_loadu_si256((const __m256i *)(tw + 1 * SUB_LEN)));
            __m256i x3 = mont_mul_vec(_mm256_load_si256((const __m256i *)(col + 3 * SUB_LEN)),
                                      _mm256_loadu_si256((const __m256i *)(tw + 2 * SUB_LEN)));
            __m256i x4 = mont_mul_vec(_mm256_load_si256((const __m256i *)(col + 4 * SUB_LEN)),
                                      _mm256_loadu_si256((const __m256i *)(tw + 3 * SUB_LEN)));
            __m256i x5 = mont_mul_vec(_mm256_load_si256((const __m256i *)(col + 5 * SUB_LEN)),
                                      _mm256_loadu_si256((const __m256i *)(tw + 4 * SUB_LEN)));
            __m256i x6 = mont_mul_vec(_mm256_load_si256((const __m256i *)(col + 6 * SUB_LEN)),
                                      _mm256_loadu_si256((const __m256i *)(tw + 5 * SUB_LEN)));
            __m256i x7 = mont_mul_vec(_mm256_load_si256((const __m256i *)(col + 7 * SUB_LEN)),
                                      _mm256_loadu_si256((const __m256i *)(tw + 6 * SUB_LEN)));

            /* Even half: inputs 0, 2, 4, 6. */
            __m256i s04 = add_vec(x0, x4);
            __m256i d04 = sub_vec(x0, x4);
            __m256i s26 = add_vec(x2, x6);
            __m256i d26 = mont_mul_vec(sub_vec(x2, x6), vi);
            __m256i ev0 = add_vec(s04, s26);
            __m256i ev1 = add_vec(d04, d26);
            __m256i ev2 = sub_vec(s04, s26);
            __m256i ev3 = sub_vec(d04, d26);

            /* Odd half: inputs 1, 3, 5, 7, each output scaled by vz / vi / vz3. */
            __m256i s15 = add_vec(x1, x5);
            __m256i d15 = sub_vec(x1, x5);
            __m256i s37 = add_vec(x3, x7);
            __m256i d37 = mont_mul_vec(sub_vec(x3, x7), vi);
            __m256i od0 = add_vec(s15, s37);
            __m256i od1 = mont_mul_vec(add_vec(d15, d37), vz);
            __m256i od2 = mont_mul_vec(sub_vec(s15, s37), vi);
            __m256i od3 = mont_mul_vec(sub_vec(d15, d37), vz3);

            _mm256_store_si256((__m256i *)(col), add_vec(ev0, od0));
            _mm256_store_si256((__m256i *)(col + 1 * SUB_LEN), add_vec(ev1, od1));
            _mm256_store_si256((__m256i *)(col + 2 * SUB_LEN), add_vec(ev2, od2));
            _mm256_store_si256((__m256i *)(col + 3 * SUB_LEN), add_vec(ev3, od3));
            _mm256_store_si256((__m256i *)(col + 4 * SUB_LEN), sub_vec(ev0, od0));
            _mm256_store_si256((__m256i *)(col + 5 * SUB_LEN), sub_vec(ev1, od1));
            _mm256_store_si256((__m256i *)(col + 6 * SUB_LEN), sub_vec(ev2, od2));
            _mm256_store_si256((__m256i *)(col + 7 * SUB_LEN), sub_vec(ev3, od3));
        }
    }
}
/*
 * Forward-direction pass over every 262144-word block of a[0..n).
 *
 * For each column of eight rows (row stride 32768 words) an 8-point butterfly
 * is evaluated with the lazy add_raw / sub2_raw / sub4_raw / sub8_raw helpers
 * and the broadcast constants vi, vz, vz3.  Output row 0 is brought back into
 * range with reduce16_to2_vec; output rows 1..7 are multiplied by the
 * twiddles held in r262144 (seven consecutive tables of 32768 words).
 * Afterwards each of the eight 32768-word rows is processed by
 * ntt_len32768_forward using r32768 and r32768 + 28672 as its tables.
 *
 * a is accessed with aligned 256-bit loads/stores and therefore must be
 * 32-byte aligned; the twiddle table is read with unaligned loads.
 */
static inline void ntt_len262144_forward(uint32_t *a, int n, uint32_t *r262144, uint32_t *r32768,
                                         __m256i vz, __m256i vi, __m256i vz3) {
    enum { SUB_LEN = 32768, BLOCK_LEN = 262144 };

    for (int base = 0; base < n; base += BLOCK_LEN) {
        uint32_t *blk = a + base;

        for (int j = 0; j < SUB_LEN; j += 8) {
            uint32_t *col = blk + j;
            const uint32_t *tw = r262144 + j;

            /* All loads (data and twiddles) happen before the first store. */
            __m256i x0 = _mm256_load_si256((const __m256i *)(col));
            __m256i x1 = _mm256_load_si256((const __m256i *)(col + 1 * SUB_LEN));
            __m256i x2 = _mm256_load_si256((const __m256i *)(col + 2 * SUB_LEN));
            __m256i x3 = _mm256_load_si256((const __m256i *)(col + 3 * SUB_LEN));
            __m256i x4 = _mm256_load_si256((const __m256i *)(col + 4 * SUB_LEN));
            __m256i x5 = _mm256_load_si256((const __m256i *)(col + 5 * SUB_LEN));
            __m256i x6 = _mm256_load_si256((const __m256i *)(col + 6 * SUB_LEN));
            __m256i x7 = _mm256_load_si256((const __m256i *)(col + 7 * SUB_LEN));
            __m256i t1 = _mm256_loadu_si256((const __m256i *)(tw));
            __m256i t2 = _mm256_loadu_si256((const __m256i *)(tw + 1 * SUB_LEN));
            __m256i t3 = _mm256_loadu_si256((const __m256i *)(tw + 2 * SUB_LEN));
            __m256i t4 = _mm256_loadu_si256((const __m256i *)(tw + 3 * SUB_LEN));
            __m256i t5 = _mm256_loadu_si256((const __m256i *)(tw + 4 * SUB_LEN));
            __m256i t6 = _mm256_loadu_si256((const __m256i *)(tw + 5 * SUB_LEN));
            __m256i t7 = _mm256_loadu_si256((const __m256i *)(tw + 6 * SUB_LEN));

            /* Even half: inputs 0, 2, 4, 6. */
            __m256i s04 = add_raw(x0, x4);
            __m256i d04 = sub2_raw(x0, x4);
            __m256i s26 = add_raw(x2, x6);
            __m256i d26 = mont_mul_vec(sub2_raw(x2, x6), vi);
            __m256i ev0 = add_raw(s04, s26);
            __m256i ev1 = add_raw(d04, d26);
            __m256i ev2 = sub4_raw(s04, s26);
            __m256i ev3 = sub4_raw(d04, d26);

            /* Odd half: inputs 1, 3, 5, 7. */
            __m256i s15 = add_raw(x1, x5);
            __m256i d15 = sub2_raw(x1, x5);
            __m256i s37 = add_raw(x3, x7);
            __m256i d37 = mont_mul_vec(sub2_raw(x3, x7), vi);
            __m256i od0 = add_raw(s15, s37);
            __m256i od1 = mont_mul_vec(add_raw(d15, d37), vz);
            __m256i od2 = mont_mul_vec(sub4_raw(s15, s37), vi);
            __m256i od3 = mont_mul_vec(sub4_raw(d15, d37), vz3);

            /* Row 0 carries no twiddle, so it only needs the range reduction. */
            _mm256_store_si256((__m256i *)(col), reduce16_to2_vec(add_raw(ev0, od0)));
            _mm256_store_si256((__m256i *)(col + 1 * SUB_LEN), mont_mul_vec(add_raw(ev1, od1), t1));
            _mm256_store_si256((__m256i *)(col + 2 * SUB_LEN), mont_mul_vec(add_raw(ev2, od2), t2));
            _mm256_store_si256((__m256i *)(col + 3 * SUB_LEN), mont_mul_vec(add_raw(ev3, od3), t3));
            _mm256_store_si256((__m256i *)(col + 4 * SUB_LEN), mont_mul_vec(sub8_raw(ev0, od0), t4));
            _mm256_store_si256((__m256i *)(col + 5 * SUB_LEN), mont_mul_vec(sub8_raw(ev1, od1), t5));
            _mm256_store_si256((__m256i *)(col + 6 * SUB_LEN), mont_mul_vec(sub8_raw(ev2, od2), t6));
            _mm256_store_si256((__m256i *)(col + 7 * SUB_LEN), mont_mul_vec(sub8_raw(ev3, od3), t7));
        }

        /* Recurse into the eight rows produced above. */
        for (int t = 0; t < 8; ++t) {
            ntt_len32768_forward(blk + t * SUB_LEN, SUB_LEN, r32768, r32768 + 28672, vz, vi, vz3);
        }
    }
}
/*
 * Inverse-direction pass over every 262144-word block of a[0..n).
 *
 * Each block is first handed to ntt_len32768_inverse as eight independent
 * 32768-word sub-blocks.  The sub-blocks are then recombined column by
 * column: rows 1..7 are multiplied by the twiddles stored in r262144 (seven
 * consecutive tables of 32768 words) and an 8-point butterfly built from
 * add_vec/sub_vec and the broadcast constants vi, vz, vz3 is applied in
 * place.
 *
 * mul may be null.  ntt_len32768_inverse already accepts a null mul, so the
 * slice pointer is only formed when mul is non-null; offsetting a null
 * pointer would be undefined and would hand the callee a bogus non-null
 * address for every sub-block after the first.
 *
 * a is accessed with aligned 256-bit loads/stores and therefore must be
 * 32-byte aligned; r262144 is read with unaligned loads.  r512, r4096 and
 * r32768 are only forwarded to ntt_len32768_inverse.
 */
static inline void ntt_len262144_inverse(uint32_t *a, uint32_t *mul, int n, uint32_t *r512, uint32_t *r4096,
                                         uint32_t *r32768, uint32_t *r262144,
                                         __m256i vz, __m256i vi, __m256i vz3) {
    for (int base = 0; base < n; base += 262144) {
        for (int t = 0; t < 8; ++t) {
            /* Keep a null mul null instead of turning it into base + t * 32768. */
            uint32_t *sub_mul = mul ? mul + base + t * 32768 : NULL;
            ntt_len32768_inverse(a + base + t * 32768, sub_mul, 32768, r512, r4096, r32768, vz, vi, vz3);
        }
        for (int j = 0; j < 32768; j += 8) {
            /* Row 0 has no twiddle; rows 1..7 are pre-multiplied by their table entry. */
            __m256i a0 = _mm256_load_si256((const __m256i *)(a + base + j));
            __m256i a1 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 32768 + j)),
                                      _mm256_loadu_si256((const __m256i *)(r262144 + j)));
            __m256i a2 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 65536 + j)),
                                      _mm256_loadu_si256((const __m256i *)(r262144 + 32768 + j)));
            __m256i a3 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 98304 + j)),
                                      _mm256_loadu_si256((const __m256i *)(r262144 + 65536 + j)));
            __m256i a4 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 131072 + j)),
                                      _mm256_loadu_si256((const __m256i *)(r262144 + 98304 + j)));
            __m256i a5 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 163840 + j)),
                                      _mm256_loadu_si256((const __m256i *)(r262144 + 131072 + j)));
            __m256i a6 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 196608 + j)),
                                      _mm256_loadu_si256((const __m256i *)(r262144 + 163840 + j)));
            __m256i a7 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + base + 229376 + j)),
                                      _mm256_loadu_si256((const __m256i *)(r262144 + 196608 + j)));

            /* Even half: inputs 0, 2, 4, 6. */
            __m256i e04 = add_vec(a0, a4);
            __m256i f04 = sub_vec(a0, a4);
            __m256i e26 = add_vec(a2, a6);
            __m256i f26 = mont_mul_vec(sub_vec(a2, a6), vi);
            __m256i E0 = add_vec(e04, e26);
            __m256i E1 = add_vec(f04, f26);
            __m256i E2 = sub_vec(e04, e26);
            __m256i E3 = sub_vec(f04, f26);

            /* Odd half: inputs 1, 3, 5, 7, each output scaled by vz / vi / vz3. */
            __m256i o15 = add_vec(a1, a5);
            __m256i g15 = sub_vec(a1, a5);
            __m256i o37 = add_vec(a3, a7);
            __m256i g37 = mont_mul_vec(sub_vec(a3, a7), vi);
            __m256i O0 = add_vec(o15, o37);
            __m256i O1 = mont_mul_vec(add_vec(g15, g37), vz);
            __m256i O2 = mont_mul_vec(sub_vec(o15, o37), vi);
            __m256i O3 = mont_mul_vec(sub_vec(g15, g37), vz3);

            _mm256_store_si256((__m256i *)(a + base + j), add_vec(E0, O0));
            _mm256_store_si256((__m256i *)(a + base + j + 32768), add_vec(E1, O1));
            _mm256_store_si256((__m256i *)(a + base + j + 65536), add_vec(E2, O2));
            _mm256_store_si256((__m256i *)(a + base + j + 98304), add_vec(E3, O3));
            _mm256_store_si256((__m256i *)(a + base + j + 131072), sub_vec(E0, O0));
            _mm256_store_si256((__m256i *)(a + base + j + 163840), sub_vec(E1, O1));
            _mm256_store_si256((__m256i *)(a + base + j + 196608), sub_vec(E2, O2));
            _mm256_store_si256((__m256i *)(a + base + j + 229376), sub_vec(E3, O3));
        }
    }
}
/*
 * Forward-direction top-level pass for a transform of exactly 2097152 words.
 *
 * The array is treated as eight rows of 262144 words.  For every column an
 * 8-point butterfly is evaluated with the lazy add_raw / sub2_raw / sub4_raw /
 * sub8_raw helpers; output row 0 is range-reduced with reduce16_to2_vec and
 * output rows 1..7 are multiplied by twiddle vectors that are generated on
 * the fly instead of being read from a table: w1 starts as make_wvec(step),
 * w2..w7 are derived from w1 by Montgomery products, and w1 is advanced by
 * the broadcast of step_pow8(step) after each column group.
 * NOTE(review): make_wvec / step_pow8 are defined elsewhere; presumably they
 * yield eight consecutive powers of step and step^8 respectively -- confirm.
 *
 * Finally each 262144-word row is processed by ntt_len262144_forward with
 * r262144 and r262144 + 229376 as its tables.
 *
 * a is accessed with aligned 256-bit loads/stores and therefore must be
 * 32-byte aligned.
 */
static inline void ntt_len2097152_forward(uint32_t *a, uint32_t *r262144,
                                          __m256i vz, __m256i vi, __m256i vz3) {
    enum { SUB_LEN = 262144 };

    const uint32_t step = to_mont(pow_mod(G, (MOD - 1) / 2097152u));
    __m256i w1 = make_wvec(step);
    const __m256i step8 = _mm256_set1_epi32((int)step_pow8(step));

    for (int j = 0; j < SUB_LEN; j += 8) {
        uint32_t *col = a + j;

        __m256i x0 = _mm256_load_si256((const __m256i *)(col));
        __m256i x1 = _mm256_load_si256((const __m256i *)(col + 1 * SUB_LEN));
        __m256i x2 = _mm256_load_si256((const __m256i *)(col + 2 * SUB_LEN));
        __m256i x3 = _mm256_load_si256((const __m256i *)(col + 3 * SUB_LEN));
        __m256i x4 = _mm256_load_si256((const __m256i *)(col + 4 * SUB_LEN));
        __m256i x5 = _mm256_load_si256((const __m256i *)(col + 5 * SUB_LEN));
        __m256i x6 = _mm256_load_si256((const __m256i *)(col + 6 * SUB_LEN));
        __m256i x7 = _mm256_load_si256((const __m256i *)(col + 7 * SUB_LEN));

        /* Even half: inputs 0, 2, 4, 6. */
        __m256i s04 = add_raw(x0, x4);
        __m256i d04 = sub2_raw(x0, x4);
        __m256i s26 = add_raw(x2, x6);
        __m256i d26 = mont_mul_vec(sub2_raw(x2, x6), vi);
        __m256i ev0 = add_raw(s04, s26);
        __m256i ev1 = add_raw(d04, d26);
        __m256i ev2 = sub4_raw(s04, s26);
        __m256i ev3 = sub4_raw(d04, d26);

        /* Odd half: inputs 1, 3, 5, 7. */
        __m256i s15 = add_raw(x1, x5);
        __m256i d15 = sub2_raw(x1, x5);
        __m256i s37 = add_raw(x3, x7);
        __m256i d37 = mont_mul_vec(sub2_raw(x3, x7), vi);
        __m256i od0 = add_raw(s15, s37);
        __m256i od1 = mont_mul_vec(add_raw(d15, d37), vz);
        __m256i od2 = mont_mul_vec(sub4_raw(s15, s37), vi);
        __m256i od3 = mont_mul_vec(sub4_raw(d15, d37), vz3);

        /* Higher twiddle powers; the factor pairs are kept exactly as designed. */
        __m256i w2 = mont_mul_vec(w1, w1);
        __m256i w3 = mont_mul_vec(w2, w1);
        __m256i w4 = mont_mul_vec(w2, w2);
        __m256i w5 = mont_mul_vec(w4, w1);
        __m256i w6 = mont_mul_vec(w4, w2);
        __m256i w7 = mont_mul_vec(w4, w3);

        _mm256_store_si256((__m256i *)(col), reduce16_to2_vec(add_raw(ev0, od0)));
        _mm256_store_si256((__m256i *)(col + 1 * SUB_LEN), mont_mul_vec(add_raw(ev1, od1), w1));
        _mm256_store_si256((__m256i *)(col + 2 * SUB_LEN), mont_mul_vec(add_raw(ev2, od2), w2));
        _mm256_store_si256((__m256i *)(col + 3 * SUB_LEN), mont_mul_vec(add_raw(ev3, od3), w3));
        _mm256_store_si256((__m256i *)(col + 4 * SUB_LEN), mont_mul_vec(sub8_raw(ev0, od0), w4));
        _mm256_store_si256((__m256i *)(col + 5 * SUB_LEN), mont_mul_vec(sub8_raw(ev1, od1), w5));
        _mm256_store_si256((__m256i *)(col + 6 * SUB_LEN), mont_mul_vec(sub8_raw(ev2, od2), w6));
        _mm256_store_si256((__m256i *)(col + 7 * SUB_LEN), mont_mul_vec(sub8_raw(ev3, od3), w7));

        /* Advance the base twiddle vector to the next group of eight columns. */
        w1 = mont_mul_vec(w1, step8);
    }

    for (int t = 0; t < 8; ++t) {
        ntt_len262144_forward(a + t * SUB_LEN, SUB_LEN, r262144, r262144 + 229376, vz, vi, vz3);
    }
}
/*
 * Inverse-direction top-level pass for a transform of exactly 2097152 words.
 *
 * The eight 262144-word rows are first processed by ntt_len262144_inverse.
 * They are then recombined column by column: rows 1..7 are multiplied by
 * twiddle vectors generated on the fly (w1v starts as make_wvec(step), where
 * step is the Montgomery form of the modular inverse of
 * G^((MOD - 1) / 2097152); w2..w7 are Montgomery products derived from w1v;
 * w1v is advanced by the broadcast of step_pow8(step) after every column
 * group), followed by an 8-point butterfly built from add_vec/sub_vec and
 * the broadcast constants vi, vz, vz3.
 * NOTE(review): make_wvec / step_pow8 are defined elsewhere; presumably they
 * yield eight consecutive powers of step and step^8 respectively -- confirm.
 *
 * mul may be null.  The lower-level ntt_len32768_inverse accepts a null mul,
 * so the per-row slice is only formed when mul is non-null; offsetting a null
 * pointer would be undefined and would produce a bogus non-null address.
 *
 * a is accessed with aligned 256-bit loads/stores and therefore must be
 * 32-byte aligned.  r512, r4096, r32768 and r262144 are only forwarded.
 */
static inline void ntt_len2097152_inverse(uint32_t *a, uint32_t *mul, uint32_t *r512, uint32_t *r4096,
                                          uint32_t *r32768, uint32_t *r262144,
                                          __m256i vz, __m256i vi, __m256i vz3) {
    for (int t = 0; t < 8; ++t) {
        /* Keep a null mul null instead of turning it into t * 262144. */
        uint32_t *sub_mul = mul ? mul + t * 262144 : NULL;
        ntt_len262144_inverse(a + t * 262144, sub_mul, 262144, r512, r4096, r32768, r262144, vz, vi, vz3);
    }

    uint32_t base_root = pow_mod(G, (MOD - 1) / 2097152u);
    /* Exponent MOD - 2 gives the modular inverse of base_root. */
    uint32_t step = to_mont(pow_mod(base_root, MOD - 2));
    __m256i w1v = make_wvec(step);
    __m256i vstep8 = _mm256_set1_epi32((int)step_pow8(step));

    for (int j = 0; j < 262144; j += 8) {
        /* Row 0 has no twiddle; rows 1..7 are multiplied as their twiddle becomes available. */
        __m256i a0 = _mm256_load_si256((const __m256i *)(a + j));
        __m256i a1 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + 262144 + j)),
                                  w1v);
        __m256i w2 = mont_mul_vec(w1v, w1v);
        __m256i a2 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + 524288 + j)),
                                  w2);
        __m256i w3 = mont_mul_vec(w2, w1v);
        __m256i a3 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + 786432 + j)),
                                  w3);
        __m256i w4 = mont_mul_vec(w2, w2);
        __m256i a4 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + 1048576 + j)),
                                  w4);
        __m256i w5 = mont_mul_vec(w4, w1v);
        __m256i a5 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + 1310720 + j)),
                                  w5);
        __m256i w6 = mont_mul_vec(w4, w2);
        __m256i a6 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + 1572864 + j)),
                                  w6);
        __m256i w7 = mont_mul_vec(w4, w3);
        __m256i a7 = mont_mul_vec(_mm256_load_si256((const __m256i *)(a + 1835008 + j)),
                                  w7);

        /* Even half: inputs 0, 2, 4, 6. */
        __m256i e04 = add_vec(a0, a4);
        __m256i f04 = sub_vec(a0, a4);
        __m256i e26 = add_vec(a2, a6);
        __m256i f26 = mont_mul_vec(sub_vec(a2, a6), vi);
        __m256i E0 = add_vec(e04, e26);
        __m256i E1 = add_vec(f04, f26);
        __m256i E2 = sub_vec(e04, e26);
        __m256i E3 = sub_vec(f04, f26);

        /* Odd half: inputs 1, 3, 5, 7, each output scaled by vz / vi / vz3. */
        __m256i o15 = add_vec(a1, a5);
        __m256i g15 = sub_vec(a1, a5);
        __m256i o37 = add_vec(a3, a7);
        __m256i g37 = mont_mul_vec(sub_vec(a3, a7), vi);
        __m256i O0 = add_vec(o15, o37);
        __m256i O1 = mont_mul_vec(add_vec(g15, g37), vz);
        __m256i O2 = mont_mul_vec(sub_vec(o15, o37), vi);
        __m256i O3 = mont_mul_vec(sub_vec(g15, g37), vz3);

        _mm256_store_si256((__m256i *)(a + j), add_vec(E0, O0));
        _mm256_store_si256((__m256i *)(a + j + 262144), add_vec(E1, O1));
        _mm256_store_si256((__m256i *)(a + j + 524288), add_vec(E2, O2));
        _mm256_store_si256((__m256i *)(a + j + 786432), add_vec(E3, O3));
        _mm256_store_si256((__m256i *)(a + j + 1048576), sub_vec(E0, O0));
        _mm256_store_si256((__m256i *)(a + j + 1310720), sub_vec(E1, O1));
        _mm256_store_si256((__m256i *)(a + j + 1572864), sub_vec(E2, O2));
        _mm256_store_si256((__m256i *)(a + j + 1835008), sub_vec(E3, O3));

        /* Advance the base twiddle vector to the next group of eight columns. */
        w1v = mont_mul_vec(w1v, vstep8);
    }
}
/*
 * Forward transform driver for a[0..n), operating in place.
 *
 * The butterfly constants are derived from G^((MOD - 1) / 8): zeta is its
 * Montgomery form, imag = zeta^2 and zeta3 = zeta^3; each is also broadcast
 * into a vector (vz, vi, vz3).
 *
 * n == 2097152 is dispatched straight to ntt_len2097152_forward.  Otherwise
 * the layer length starts at n and is divided by 8 per iteration.  As soon as
 * it equals one of 262144, 32768, 4096, 512, 64 or 8 the matching specialised
 * routine finishes the job and the function returns.  Any other length is
 * handled by the generic radix-8 pass below, which reads its seven twiddle
 * tables (q words each) from the global roots array at offset off and then
 * advances off by 7 * q.
 *
 * The generic pass uses an AVX2 loop for full groups of eight columns and a
 * scalar loop (add_mod / sub_mod / mont_mul) for any remaining columns.
 */
static void ntt_forward(uint32_t *a, int n) {
    const uint32_t z8_plain = pow_mod(G, (MOD - 1) / 8);
    const uint32_t zeta = to_mont(z8_plain);
    const uint32_t imag = mont_mul(zeta, zeta);
    const uint32_t zeta3 = mont_mul(imag, zeta);
    const __m256i vz = _mm256_set1_epi32((int)zeta);
    const __m256i vi = _mm256_set1_epi32((int)imag);
    const __m256i vz3 = _mm256_set1_epi32((int)zeta3);

    if (n == 2097152) {
        ntt_len2097152_forward(a, roots, vz, vi, vz3);
        return;
    }

    int off = 0;
    for (int len = n; len >= 8; len >>= 3) {
        uint32_t *tw = roots + off;

        /* Specialised routines complete the remaining layers; nothing follows the loop. */
        switch (len) {
        case 262144:
            ntt_len262144_forward(a, n, tw, tw + 229376, vz, vi, vz3);
            return;
        case 32768:
            ntt_len32768_forward(a, n, tw, tw + 28672, vz, vi, vz3);
            return;
        case 4096:
            ntt_len4096_forward(a, n, tw, tw + 3584, vz, vi, vz3);
            return;
        case 512:
            ntt_len512_forward(a, n, tw, tw + 448, vz, vi, vz3);
            return;
        case 64:
            ntt_len64_forward(a, n, tw, vz, vi, vz3);
            return;
        case 8:
            ntt_len8_forward(a, n, vz, vi, vz3);
            return;
        default:
            break;
        }

        /* Generic radix-8 layer: eight rows of q words per block of len words. */
        const int q = len >> 3;
        for (int i = 0; i < n; i += len) {
            uint32_t *p = a + i;
            int j = 0;

            for (; j + 8 <= q; j += 8) {
                uint32_t *col = p + j;
                const uint32_t *tj = tw + j;

                /* All loads (data and twiddles) happen before the first store. */
                __m256i x0 = _mm256_load_si256((const __m256i *)(col));
                __m256i x1 = _mm256_load_si256((const __m256i *)(col + q));
                __m256i x2 = _mm256_load_si256((const __m256i *)(col + q * 2));
                __m256i x3 = _mm256_load_si256((const __m256i *)(col + q * 3));
                __m256i x4 = _mm256_load_si256((const __m256i *)(col + q * 4));
                __m256i x5 = _mm256_load_si256((const __m256i *)(col + q * 5));
                __m256i x6 = _mm256_load_si256((const __m256i *)(col + q * 6));
                __m256i x7 = _mm256_load_si256((const __m256i *)(col + q * 7));
                __m256i t1 = _mm256_loadu_si256((const __m256i *)(tj));
                __m256i t2 = _mm256_loadu_si256((const __m256i *)(tj + q));
                __m256i t3 = _mm256_loadu_si256((const __m256i *)(tj + q * 2));
                __m256i t4 = _mm256_loadu_si256((const __m256i *)(tj + q * 3));
                __m256i t5 = _mm256_loadu_si256((const __m256i *)(tj + q * 4));
                __m256i t6 = _mm256_loadu_si256((const __m256i *)(tj + q * 5));
                __m256i t7 = _mm256_loadu_si256((const __m256i *)(tj + q * 6));

                /* Even half: inputs 0, 2, 4, 6. */
                __m256i s04 = add_raw(x0, x4);
                __m256i d04 = sub2_raw(x0, x4);
                __m256i s26 = add_raw(x2, x6);
                __m256i d26 = mont_mul_vec(sub2_raw(x2, x6), vi);
                __m256i ev0 = add_raw(s04, s26);
                __m256i ev1 = add_raw(d04, d26);
                __m256i ev2 = sub4_raw(s04, s26);
                __m256i ev3 = sub4_raw(d04, d26);

                /* Odd half: inputs 1, 3, 5, 7. */
                __m256i s15 = add_raw(x1, x5);
                __m256i d15 = sub2_raw(x1, x5);
                __m256i s37 = add_raw(x3, x7);
                __m256i d37 = mont_mul_vec(sub2_raw(x3, x7), vi);
                __m256i od0 = add_raw(s15, s37);
                __m256i od1 = mont_mul_vec(add_raw(d15, d37), vz);
                __m256i od2 = mont_mul_vec(sub4_raw(s15, s37), vi);
                __m256i od3 = mont_mul_vec(sub4_raw(d15, d37), vz3);

                _mm256_store_si256((__m256i *)(col), reduce16_to2_vec(add_raw(ev0, od0)));
                _mm256_store_si256((__m256i *)(col + q), mont_mul_vec(add_raw(ev1, od1), t1));
                _mm256_store_si256((__m256i *)(col + q * 2), mont_mul_vec(add_raw(ev2, od2), t2));
                _mm256_store_si256((__m256i *)(col + q * 3), mont_mul_vec(add_raw(ev3, od3), t3));
                _mm256_store_si256((__m256i *)(col + q * 4), mont_mul_vec(sub8_raw(ev0, od0), t4));
                _mm256_store_si256((__m256i *)(col + q * 5), mont_mul_vec(sub8_raw(ev1, od1), t5));
                _mm256_store_si256((__m256i *)(col + q * 6), mont_mul_vec(sub8_raw(ev2, od2), t6));
                _mm256_store_si256((__m256i *)(col + q * 7), mont_mul_vec(sub8_raw(ev3, od3), t7));
            }

            /* Scalar tail for columns that do not fill a whole vector. */
            for (; j < q; ++j) {
                uint32_t x0 = p[j];
                uint32_t x1 = p[j + q];
                uint32_t x2 = p[j + q * 2];
                uint32_t x3 = p[j + q * 3];
                uint32_t x4 = p[j + q * 4];
                uint32_t x5 = p[j + q * 5];
                uint32_t x6 = p[j + q * 6];
                uint32_t x7 = p[j + q * 7];

                uint32_t s04 = add_mod(x0, x4);
                uint32_t d04 = sub_mod(x0, x4);
                uint32_t s26 = add_mod(x2, x6);
                uint32_t d26 = mont_mul(sub_mod(x2, x6), imag);
                uint32_t ev0 = add_mod(s04, s26);
                uint32_t ev1 = add_mod(d04, d26);
                uint32_t ev2 = sub_mod(s04, s26);
                uint32_t ev3 = sub_mod(d04, d26);

                uint32_t s15 = add_mod(x1, x5);
                uint32_t d15 = sub_mod(x1, x5);
                uint32_t s37 = add_mod(x3, x7);
                uint32_t d37 = mont_mul(sub_mod(x3, x7), imag);
                uint32_t od0 = add_mod(s15, s37);
                uint32_t od1 = mont_mul(add_mod(d15, d37), zeta);
                uint32_t od2 = mont_mul(sub_mod(s15, s37), imag);
                uint32_t od3 = mont_mul(sub_mod(d15, d37), zeta3);

                p[j] = add_mod(ev0, od0);
                p[j + q] = mont_mul(add_mod(ev1, od1), tw[j]);
                p[j + q * 2] = mont_mul(add_mod(ev2, od2), tw[j + q]);
                p[j + q * 3] = mont_mul(add_mod(ev3, od3), tw[j + q * 2]);
                p[j + q * 4] = mont_mul(sub_mod(ev0, od0), tw[j + q * 3]);
                p[j + q * 5] = mont_mul(sub_mod(ev1, od1), tw[j + q * 4]);
                p[j + q * 6] = mont_mul(sub_mod(ev2, od2), tw[j + q * 5]);
                p[j + q * 7] = mont_mul(sub_mod(ev3, od3), tw[j + q * 6]);
            }
        }
        off += q * 7;
    }
}
/*
 * Inverse transform driver for a[0..n), operating in place.
 *
 * The butterfly constants are derived from the modular inverse of
 * G^((MOD - 1) / 8) (obtained by raising it to MOD - 2): zeta is its
 * Montgomery form, imag = zeta^2 and zeta3 = zeta^3, each also broadcast
 * into a vector (vz, vi, vz3).
 *
 * n == 2097152 is dispatched straight to ntt_len2097152_inverse.  Otherwise
 * the largest specialised routine whose size does not exceed n (262144,
 * 32768, 4096, 512 or 64) handles the small layers and receives mul; the
 * remaining layers, from start_len up to n in steps of x8, are handled by
 * the generic radix-8 pass below, reading seven twiddle tables of q words
 * each from the global roots array at offset off.
 *
 * NOTE(review): mul is only ever passed to the specialised routines; when
 * n < 64 none of them runs and mul is never read.
 *
 * The length-8 layer reuses ntt_len8_forward, called here with the inverse
 * constants computed above.
 */
static void ntt_inverse(uint32_t *a, int n, uint32_t *mul) {
    const uint32_t z8_plain = pow_mod(G, (MOD - 1) / 8);
    const uint32_t inv_z8_plain = pow_mod(z8_plain, MOD - 2);
    const uint32_t zeta = to_mont(inv_z8_plain);
    const uint32_t imag = mont_mul(zeta, zeta);
    const uint32_t zeta3 = mont_mul(imag, zeta);
    const __m256i vz = _mm256_set1_epi32((int)zeta);
    const __m256i vi = _mm256_set1_epi32((int)imag);
    const __m256i vz3 = _mm256_set1_epi32((int)zeta3);

    if (n == 2097152) {
        ntt_len2097152_inverse(a, mul, roots + 63, roots + 511, roots + 4095, roots + 32767, vz, vi, vz3);
        return;
    }

    /* Pick the largest specialised routine and note where the generic layers resume. */
    int off = 0;
    int start_len = 8;
    if (n >= 262144) {
        ntt_len262144_inverse(a, mul, n, roots + 63, roots + 511, roots + 4095, roots + 32767, vz, vi, vz3);
        off = 262143;
        start_len = 2097152;
    } else if (n >= 32768) {
        ntt_len32768_inverse(a, mul, n, roots + 63, roots + 511, roots + 4095, vz, vi, vz3);
        off = 32767;
        start_len = 262144;
    } else if (n >= 4096) {
        ntt_len4096_inverse(a, mul, n, roots + 63, roots + 511, vz, vi, vz3);
        off = 4095;
        start_len = 32768;
    } else if (n >= 512) {
        ntt_len512_inverse(a, mul, n, roots + 7, roots + 63, vz, vi, vz3);
        off = 511;
        start_len = 4096;
    } else if (n >= 64) {
        ntt_len64_inverse(a, mul, n, roots + 7, vz, vi, vz3);
        off = 63;
        start_len = 512;
    }

    for (int len = start_len; len <= n; len <<= 3) {
        const int q = len >> 3;

        if (len == 8) {
            /* q == 1 here, so the shared offset update below adds exactly 7. */
            ntt_len8_forward(a, n, vz, vi, vz3);
        } else {
            uint32_t *tw = roots + off;

            for (int i = 0; i < n; i += len) {
                uint32_t *p = a + i;
                int j = 0;

                for (; j + 8 <= q; j += 8) {
                    uint32_t *col = p + j;
                    const uint32_t *tj = tw + j;

                    /* Row 0 has no twiddle; rows 1..7 are pre-multiplied by their table entry. */
                    __m256i x0 = _mm256_load_si256((const __m256i *)(col));
                    __m256i x1 = mont_mul_vec(_mm256_load_si256((const __m256i *)(col + q)),
                                              _mm256_loadu_si256((const __m256i *)(tj)));
                    __m256i x2 = mont_mul_vec(_mm256_load_si256((const __m256i *)(col + q * 2)),
                                              _mm256_loadu_si256((const __m256i *)(tj + q)));
                    __m256i x3 = mont_mul_vec(_mm256_load_si256((const __m256i *)(col + q * 3)),
                                              _mm256_loadu_si256((const __m256i *)(tj + q * 2)));
                    __m256i x4 = mont_mul_vec(_mm256_load_si256((const __m256i *)(col + q * 4)),
                                              _mm256_loadu_si256((const __m256i *)(tj + q * 3)));
                    __m256i x5 = mont_mul_vec(_mm256_load_si256((const __m256i *)(col + q * 5)),
                                              _mm256_loadu_si256((const __m256i *)(tj + q * 4)));
                    __m256i x6 = mont_mul_vec(_mm256_load_si256((const __m256i *)(col + q * 6)),
                                              _mm256_loadu_si256((const __m256i *)(tj + q * 5)));
                    __m256i x7 = mont_mul_vec(_mm256_load_si256((const __m256i *)(col + q * 7)),
                                              _mm256_loadu_si256((const __m256i *)(tj + q * 6)));

                    /* Even half: inputs 0, 2, 4, 6. */
                    __m256i s04 = add_vec(x0, x4);
                    __m256i d04 = sub_vec(x0, x4);
                    __m256i s26 = add_vec(x2, x6);
                    __m256i d26 = mont_mul_vec(sub_vec(x2, x6), vi);
                    __m256i ev0 = add_vec(s04, s26);
                    __m256i ev1 = add_vec(d04, d26);
                    __m256i ev2 = sub_vec(s04, s26);
                    __m256i ev3 = sub_vec(d04, d26);

                    /* Odd half: inputs 1, 3, 5, 7. */
                    __m256i s15 = add_vec(x1, x5);
                    __m256i d15 = sub_vec(x1, x5);
                    __m256i s37 = add_vec(x3, x7);
                    __m256i d37 = mont_mul_vec(sub_vec(x3, x7), vi);
                    __m256i od0 = add_vec(s15, s37);
                    __m256i od1 = mont_mul_vec(add_vec(d15, d37), vz);
                    __m256i od2 = mont_mul_vec(sub_vec(s15, s37), vi);
                    __m256i od3 = mont_mul_vec(sub_vec(d15, d37), vz3);

                    _mm256_store_si256((__m256i *)(col), add_vec(ev0, od0));
                    _mm256_store_si256((__m256i *)(col + q), add_vec(ev1, od1));
                    _mm256_store_si256((__m256i *)(col + q * 2), add_vec(ev2, od2));
                    _mm256_store_si256((__m256i *)(col + q * 3), add_vec(ev3, od3));
                    _mm256_store_si256((__m256i *)(col + q * 4), sub_vec(ev0, od0));
                    _mm256_store_si256((__m256i *)(col + q * 5), sub_vec(ev1, od1));
                    _mm256_store_si256((__m256i *)(col + q * 6), sub_vec(ev2, od2));
                    _mm256_store_si256((__m256i *)(col + q * 7), sub_vec(ev3, od3));
                }

                /* Scalar tail for columns that do not fill a whole vector. */
                for (; j < q; ++j) {
                    uint32_t x0 = p[j];
                    uint32_t x1 = mont_mul(p[j + q], tw[j]);
                    uint32_t x2 = mont_mul(p[j + q * 2], tw[j + q]);
                    uint32_t x3 = mont_mul(p[j + q * 3], tw[j + q * 2]);
                    uint32_t x4 = mont_mul(p[j + q * 4], tw[j + q * 3]);
                    uint32_t x5 = mont_mul(p[j + q * 5], tw[j + q * 4]);
                    uint32_t x6 = mont_mul(p[j + q * 6], tw[j + q * 5]);
                    uint32_t x7 = mont_mul(p[j + q * 7], tw[j + q * 6]);

                    uint32_t s04 = add_mod(x0, x4);
                    uint32_t d04 = sub_mod(x0, x4);
                    uint32_t s26 = add_mod(x2, x6);
                    uint32_t d26 = mont_mul(sub_mod(x2, x6), imag);
                    uint32_t ev0 = add_mod(s04, s26);
                    uint32_t ev1 = add_mod(d04, d26);
                    uint32_t ev2 = sub_mod(s04, s26);
                    uint32_t ev3 = sub_mod(d04, d26);

                    uint32_t s15 = add_mod(x1, x5);
                    uint32_t d15 = sub_mod(x1, x5);
                    uint32_t s37 = add_mod(x3, x7);
                    uint32_t d37 = mont_mul(sub_mod(x3, x7), imag);
                    uint32_t od0 = add_mod(s15, s37);
                    uint32_t od1 = mont_mul(add_mod(d15, d37), zeta);
                    uint32_t od2 = mont_mul(sub_mod(s15, s37), imag);
                    uint32_t od3 = mont_mul(sub_mod(d15, d37), zeta3);

                    p[j] = add_mod(ev0, od0);
                    p[j + q] = add_mod(ev1, od1);
                    p[j + q * 2] = add_mod(ev2, od2);
                    p[j + q * 3] = add_mod(ev3, od3);
                    p[j + q * 4] = sub_mod(ev0, od0);
                    p[j + q * 5] = sub_mod(ev1, od1);
                    p[j + q * 6] = sub_mod(ev2, od2);
                    p[j + q * 7] = sub_mod(ev3, od3);
                }
            }
        }

        /* Seven tables of q words were consumed (or skipped) by this layer. */
        off += q * 7;
    }
}
/*
 * Multiplies two polynomials through a fixed-size transform of 2^21 words.
 *
 * a / n  -- first operand; n + 1 entries are passed to fill_mont_digits.
 * b / m  -- second operand; m + 1 entries are passed to fill_mont_digits.
 * c      -- output; receives n + m + 1 words, each reduced below MOD.
 *
 * The transform length is hard-wired to 2^21 and the output loop reads the
 * first n + m + 1 words of the work buffer, so n + m + 1 must not exceed
 * 2^21.  NOTE(review): this is not checked here -- confirm callers respect it.
 *
 * A single heap block holds both work arrays.  fa starts 118784 bytes past
 * the 32-byte-aligned start of the block, and fb starts 28672 words after
 * the end of fa; both offsets are multiples of 32 bytes, so the aligned
 * vector loads/stores used by the transforms are valid.
 * NOTE(review): the purpose of the padding is not visible here (presumably
 * to stagger the two arrays in cache) -- confirm before changing it.
 *
 * The block is released before returning; the function aborts if the
 * allocation fails, because there is no error channel to report it.
 *
 * With LOCAL_PROFILE defined, per-phase timings are printed to stderr.
 */
void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c) {
#ifdef LOCAL_PROFILE
    clock_t t0 = clock();
#endif
    int need = n + m + 1;
    int len = 1 << 21;

    /* 2 * len data words + 58368 padding words + 8 words of alignment slack. */
    uint32_t *raw = (uint32_t *)malloc(((size_t)len * 2 + 8 + 58368) * sizeof(uint32_t));
    if (raw == NULL) {
        abort();
    }
    uint32_t *fa = (uint32_t *)((((uintptr_t)raw + 31u) & ~(uintptr_t)31u) + 118784u);
    uint32_t *fb = fa + len + 28672;

    fill_mont_digits(fa, a, n + 1);
    fill_mont_digits(fb, b, m + 1);
#ifdef LOCAL_PROFILE
    clock_t t1 = clock();
#endif
    build_roots_forward(262144);
#ifdef LOCAL_PROFILE
    clock_t t2 = clock();
#endif
    ntt_forward(fa, len);
#ifdef LOCAL_PROFILE
    clock_t t3 = clock();
#endif
    ntt_forward(fb, len);
#ifdef LOCAL_PROFILE
    clock_t t4 = clock();
#endif
#ifdef LOCAL_PROFILE
    clock_t t5 = clock();
#endif
    build_roots_inverse(262144);
#ifdef LOCAL_PROFILE
    clock_t t6 = clock();
#endif
    /* fb is handed to the inverse transform as its multiplier array. */
    ntt_inverse(fa, len, fb);
#ifdef LOCAL_PROFILE
    clock_t t7 = clock();
#endif

    /*
     * MOD == 39 * 2^21 + 1, so 81788890 == MOD - 39 is the inverse of 2^21
     * modulo MOD in plain (non-Montgomery) form.  One Montgomery product with
     * it therefore applies the 1/len scaling and removes one factor of R.
     */
    uint32_t inv_n = 81788890u;
    __m256i vinv = _mm256_set1_epi32((int)inv_n);
    int i = 0;
    for (; i + 8 <= need; i += 8) {
        __m256i x = _mm256_load_si256((const __m256i *)(fa + i));
        _mm256_storeu_si256((__m256i *)(c + i), reduce_mod_vec(mont_mul_vec(x, vinv)));
    }
    for (; i < need; ++i) {
        uint32_t v = mont_mul(fa[i], inv_n);
        c[i] = v >= MOD ? v - MOD : v;
    }
#ifdef LOCAL_PROFILE
    clock_t t8 = clock();
    fprintf(stderr,
            "fillAB %.3f rootsF %.3f fwdA %.3f fwdB %.3f gap %.3f rootsI %.3f invmul %.3f out %.3f total %.3f\n",
            1000.0 * (double)(t1 - t0) / CLOCKS_PER_SEC,
            1000.0 * (double)(t2 - t1) / CLOCKS_PER_SEC,
            1000.0 * (double)(t3 - t2) / CLOCKS_PER_SEC,
            1000.0 * (double)(t4 - t3) / CLOCKS_PER_SEC,
            1000.0 * (double)(t5 - t4) / CLOCKS_PER_SEC,
            1000.0 * (double)(t6 - t5) / CLOCKS_PER_SEC,
            1000.0 * (double)(t7 - t6) / CLOCKS_PER_SEC,
            1000.0 * (double)(t8 - t7) / CLOCKS_PER_SEC,
            1000.0 * (double)(t8 - t0) / CLOCKS_PER_SEC);
#endif
    /* fa and fb both live inside raw; nothing refers to it after this point. */
    free(raw);
}
#ifdef LOCAL_TEST
#include <stdio.h>
static unsigned aa[1 << 20], bb[1 << 20], cc[1 << 21], dd[4096];
/*
 * Self-test: for every pair of degrees n, m in [16, 32) fills both operands
 * with small deterministic digits, multiplies them with poly_multiply and
 * compares every coefficient against a schoolbook product.  Prints the first
 * mismatch and returns 1, or prints "ok" and returns 0.
 */
int main() {
    for (int n = 16; n < 32; ++n) {
        for (int m = 16; m < 32; ++m) {
            const int top = n + m;

            for (int k = 0; k <= n; ++k) {
                aa[k] = (unsigned)((k * 7 + n) % 10);
            }
            for (int k = 0; k <= m; ++k) {
                bb[k] = (unsigned)((k * 5 + m) % 10);
            }

            poly_multiply(aa, n, bb, m, cc);

            /* Schoolbook reference product. */
            for (int k = 0; k <= top; ++k) {
                dd[k] = 0;
            }
            for (int p = 0; p <= n; ++p) {
                for (int r = 0; r <= m; ++r) {
                    dd[p + r] += aa[p] * bb[r];
                }
            }

            for (int k = 0; k <= top; ++k) {
                if (cc[k] == dd[k]) {
                    continue;
                }
                printf("bad n=%d m=%d i=%d got=%u want=%u\n", n, m, k, cc[k], dd[k]);
                return 1;
            }
        }
    }
    puts("ok");
    return 0;
}
#endif
#ifdef LOCAL_BENCH
#include <stdio.h>
#include <time.h>
static unsigned aa[1000001], bb[1000001], cc[2000001];
/*
 * Benchmark: multiplies two degree-1000000 polynomials with deterministic
 * single-digit coefficients once, then prints the elapsed CPU time in
 * milliseconds together with a sparse checksum of the result (every 137th
 * coefficient summed).
 */
int main() {
    for (int k = 0; k <= 1000000; ++k) {
        aa[k] = (unsigned)((k * 7 + 3) % 10);
        bb[k] = (unsigned)((k * 5 + 1) % 10);
    }

    const clock_t started = clock();
    poly_multiply(aa, 1000000, bb, 1000000, cc);
    const clock_t finished = clock();

    unsigned long long sample = 0;
    for (int k = 0; k <= 2000000; k += 137) {
        sample += cc[k];
    }

    const double elapsed_ms = 1000.0 * (double)(finished - started) / CLOCKS_PER_SEC;
    printf("%.3f ms sample=%llu\n", elapsed_ms, sample);
    return 0;
}
#endif
/*
 * | Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
 * | Testcase #1 | 33.065 ms | 24 MB + 672 KB | Accepted | Score: 100 | Show more |
 */