#include <math.h>
#include <stdint.h>
#include <stdlib.h>
/*
 * Double-precision complex value used by the FFT below.
 *
 * Declared through a typedef and constructed with c_make() (instead of C++
 * brace initialisers) so this code compiles both as C and as C++.
 */
typedef struct C {
    double r, i;
} C;

static inline C c_make(double r, double i) {
    C z;
    z.r = r;
    z.i = i;
    return z;
}
static inline C addc(C a, C b) { return c_make(a.r + b.r, a.i + b.i); }
static inline C subc(C a, C b) { return c_make(a.r - b.r, a.i - b.i); }
static inline C mulc(C a, C b) { return c_make(a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r); }
static inline C conjc(C a) { return c_make(a.r, -a.i); }

/*
 * Fill roots[k] = exp(+2*pi*i*k/n) for 0 <= k < n/2.
 * Every root is evaluated directly with cos/sin rather than by repeated
 * multiplication, so its rounding error does not grow with k.
 */
static void fft_roots(C *roots, int n) {
    const double pi = 3.14159265358979323846;
    for (int k = 0; k < n / 2; ++k) {
        double ang = 2.0 * pi * (double)k / (double)n;
        roots[k] = c_make(cos(ang), sin(ang));
    }
}

/*
 * In-place iterative radix-2 FFT of a[0..n-1]; n must be a power of two.
 * roots must hold the n/2 values written by fft_roots(roots, n).
 * inv == 0 uses the kernel exp(+2*pi*i*j*k/n); inv != 0 uses its conjugate
 * and scales by 1/n, so fft(a, n, roots, 1) undoes fft(a, n, roots, 0).
 */
static void fft(C *a, int n, const C *roots, int inv) {
    /* Bit-reversal permutation so the butterflies below can work in place. */
    for (int i = 1, j = 0; i < n; ++i) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) {
            C t = a[i];
            a[i] = a[j];
            a[j] = t;
        }
    }
    const double sign = inv ? -1.0 : 1.0;
    for (int len = 2; len <= n; len <<= 1) {
        int h = len >> 1;
        int step = n / len; /* exp(2*pi*i*j/len) == roots[j * step] */
        for (int i = 0; i < n; i += len) {
            for (int j = 0; j < h; ++j) {
                C w = roots[j * step];
                w.i *= sign; /* inverse transform uses the conjugate root */
                C u = a[i + j];
                C v = mulc(a[i + j + h], w);
                a[i + j] = addc(u, v);
                a[i + j + h] = subc(u, v);
            }
        }
    }
    if (inv) {
        double z = 1.0 / (double)n;
        for (int i = 0; i < n; ++i) {
            a[i].r *= z;
            a[i].i *= z;
        }
    }
}

/* Largest transform length: keeps every index in fft() within int range. */
enum { POLY_FFT_MAX_LEN = 1 << 30 };

/*
 * Multiply two polynomials with unsigned coefficients using a single complex FFT.
 *
 * a[0..n] and b[0..m] hold the coefficients (index = power). The n + m + 1
 * product coefficients are written to c[0..n+m]. All of a and b is read
 * before c is written, so c may overlap the inputs. If n + m + 1 <= 0,
 * nothing is written.
 *
 * Coefficients are recovered by rounding a double-precision transform, so they
 * are exact only while the accumulated rounding error stays below 0.5.
 * For example, 10^6 coefficients in 0..9 give products up to 8.1e7. That is
 * beyond the 2^24 range in which float represents every integer, but far
 * inside double's 2^53. Each result is reduced modulo 2^32, matching unsigned
 * arithmetic.
 *
 * Calls abort() if n + m + 1 exceeds POLY_FFT_MAX_LEN or if the work buffer
 * cannot be allocated (the interface has no way to report failure).
 */
void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c) {
    long long need = (long long)n + (long long)m + 1; /* cannot overflow, unlike int */
    if (need <= 0) return;
    if (need > POLY_FFT_MAX_LEN) abort();
    int len = 1;
    while (len < need) len <<= 1;

    /* One allocation: len transform slots (zero-padded by calloc), then len/2 roots. */
    size_t slots = (size_t)len + (size_t)len / 2;
    C *f = (C *)calloc(slots, sizeof *f); /* cast kept so this also compiles as C++ */
    if (f == NULL) abort();
    C *roots = f + len;
    fft_roots(roots, len);

    /* Pack both real inputs into one complex signal: f = a + i*b. */
    for (int i = 0; i <= n; ++i) f[i].r = (double)a[i];
    for (int i = 0; i <= m; ++i) f[i].i = (double)b[i];
    fft(f, len, roots, 0);

    /*
     * f holds F = FFT(a + i*b). Because a and b are real,
     *   FFT(a)[k] = (F[k] + conj(F[len-k])) / 2,
     *   FFT(b)[k] = (F[k] - conj(F[len-k])) / (2i).
     * Their product is the spectrum of a*b, which is Hermitian. So only
     * k <= len/2 is computed, and the mirrored bin gets the conjugate.
     * Bin 0: FFT(a)[0] = F[0].r and FFT(b)[0] = F[0].i.
     */
    f[0] = c_make(f[0].r * f[0].i, 0.0);
    for (int k = 1; k <= len / 2; ++k) {
        int j = len - k;
        C x = f[k];
        C y = conjc(f[j]);
        C fa = c_make((x.r + y.r) * 0.5, (x.i + y.i) * 0.5);
        C fb = c_make((x.i - y.i) * 0.5, (y.r - x.r) * 0.5);
        C p = mulc(fa, fb);
        f[k] = p;
        if (k != j) f[j] = conjc(p);
    }
    fft(f, len, roots, 1);

    /*
     * Round to nearest. Converting a negative or >= 2^32 double straight to
     * unsigned is undefined behaviour. Going long long -> unsigned long long
     * -> unsigned is well defined and wraps modulo 2^32.
     */
    for (int i = 0; i < (int)need; ++i)
        c[i] = (unsigned)(unsigned long long)llround(f[i].r);
    free(f);
}
#ifdef LOCAL_TEST
#include <stdio.h>
/*
 * Self-test: compare poly_multiply against schoolbook multiplication.
 * Build with -DLOCAL_TEST. Prints "ok" and exits 0 on success; otherwise
 * prints the first mismatch and exits 1.
 */
enum { TEST_MAX_DEG = 300 }; /* largest degree exercised below; sizes the buffers */
static unsigned aa[TEST_MAX_DEG + 1], bb[TEST_MAX_DEG + 1];
static unsigned cc[2 * TEST_MAX_DEG + 1], dd[2 * TEST_MAX_DEG + 1];

/*
 * Fill deterministic inputs of degree n and m with coefficients below `mod`,
 * multiply them with poly_multiply and with the schoolbook method (unsigned,
 * i.e. mod 2^32), and compare. Returns 0 on a match, or 1 after printing the
 * first differing coefficient. Requires 0 <= n, m <= TEST_MAX_DEG.
 */
static int check_case(int n, int m, unsigned mod) {
    for (int i = 0; i <= n; ++i) aa[i] = (unsigned)(i * 7 + n) % mod;
    for (int i = 0; i <= m; ++i) bb[i] = (unsigned)(i * 5 + m) % mod;
    poly_multiply(aa, n, bb, m, cc);
    for (int i = 0; i <= n + m; ++i) dd[i] = 0;
    for (int i = 0; i <= n; ++i)
        for (int j = 0; j <= m; ++j) dd[i + j] += aa[i] * bb[j];
    for (int i = 0; i <= n + m; ++i) {
        if (cc[i] != dd[i]) {
            printf("bad n=%d m=%d i=%d got=%u want=%u\n", n, m, i, cc[i], dd[i]);
            return 1;
        }
    }
    return 0;
}

int main(void) {
    /* Exhaustive sweep over small degrees with single-digit coefficients. */
    for (int n = 0; n < 64; ++n)
        for (int m = 0; m < 64; ++m)
            if (check_case(n, m, 10)) return 1;

    /*
     * Larger degrees with coefficients up to 999. Product coefficients reach
     * about 3e8, beyond the 2^24 range in which float represents every
     * integer, so these cases catch a transform that has lost precision.
     */
    static const int degs[] = {64, 100, 255, 256, TEST_MAX_DEG};
    const size_t ndegs = sizeof degs / sizeof degs[0];
    for (size_t x = 0; x < ndegs; ++x)
        for (size_t y = 0; y < ndegs; ++y)
            if (check_case(degs[x], degs[y], 1000)) return 1;

    puts("ok");
    return 0;
}
#endif
#ifdef LOCAL_BENCH
#include <stdio.h>
#include <time.h>
/*
 * Benchmark: time one product of two degree-BENCH_DEG polynomials, then
 * spot-check every CHECK_STRIDE-th output coefficient against a direct
 * convolution sum. Build with -DLOCAL_BENCH; the exit status is nonzero if any
 * checked coefficient differs.
 */
enum { BENCH_DEG = 1000000, CHECK_STRIDE = 123457, SAMPLE_STRIDE = 137 };
static unsigned aa[BENCH_DEG + 1], bb[BENCH_DEG + 1], cc[2 * BENCH_DEG + 1];

/* Coefficient `pos` of aa*bb, summed directly and truncated to unsigned (mod 2^32). */
static unsigned direct_at(int pos) {
    long long lo = pos > BENCH_DEG ? (long long)pos - BENCH_DEG : 0;
    long long hi = pos < BENCH_DEG ? pos : BENCH_DEG;
    unsigned long long acc = 0;
    for (long long k = lo; k <= hi; ++k)
        acc += (unsigned long long)aa[k] * bb[pos - k];
    return (unsigned)acc;
}

int main(void) {
    for (int k = 0; k <= BENCH_DEG; ++k) {
        aa[k] = (unsigned)((k * 7 + 3) % 10);
        bb[k] = (unsigned)((k * 5 + 1) % 10);
    }

    clock_t t0 = clock();
    poly_multiply(aa, BENCH_DEG, bb, BENCH_DEG, cc);
    clock_t t1 = clock();

    int bad = 0;
    for (int pos = 0; pos <= 2 * BENCH_DEG; pos += CHECK_STRIDE) {
        unsigned want = direct_at(pos);
        if (cc[pos] == want) continue;
        printf("bad pos=%d got=%u want=%u\n", pos, cc[pos], want);
        ++bad;
    }

    /* Checksum over a strided sample so the whole result is observed. */
    unsigned long long sample = 0;
    for (int k = 0; k <= 2 * BENCH_DEG; k += SAMPLE_STRIDE) sample += cc[k];

    double ms = 1000.0 * (double)(t1 - t0) / CLOCKS_PER_SEC;
    printf("%.3f ms sample=%llu bad=%d\n", ms, sample, bad);
    return bad != 0;
}
#endif