#include <limits.h>
#include <math.h>
#include <stdint.h>
#include <stdlib.h>
// Minimal complex number in Cartesian form.
struct Complex {
    double r;  // real part
    double i;  // imaginary part
};

// Component-wise sum a + b.
static inline Complex addc(Complex a, Complex b) {
    Complex sum;
    sum.r = a.r + b.r;
    sum.i = a.i + b.i;
    return sum;
}

// Component-wise difference a - b.
static inline Complex subc(Complex a, Complex b) {
    Complex diff;
    diff.r = a.r - b.r;
    diff.i = a.i - b.i;
    return diff;
}

// Complex product (a.r + i*a.i) * (b.r + i*b.i).
static inline Complex mulc(Complex a, Complex b) {
    Complex prod;
    prod.r = a.r * b.r - a.i * b.i;
    prod.i = a.r * b.i + a.i * b.r;
    return prod;
}
/*
 * In-place iterative radix-2 FFT over a[0..n-1].
 *
 * n is expected to be a power of two (poly_multiply rounds its length up to
 * one); both the bit-reversal pass and the butterfly stride depend on that.
 * invert == 0 uses twiddles exp(+2*pi*i*k/len); invert != 0 uses
 * exp(-2*pi*i*k/len) and then multiplies every element by 1/n, so a forward
 * pass followed by an inverse pass restores the input up to rounding.
 */
static void fft(Complex *a, int n, int invert) {
    // Permute the elements into bit-reversed index order.
    int rev = 0;
    for (int idx = 1; idx < n; ++idx) {
        int mask = n >> 1;
        while (rev & mask) {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if (idx < rev) {
            const Complex tmp = a[idx];
            a[idx] = a[rev];
            a[rev] = tmp;
        }
    }

    const double pi = 3.141592653589793238462643383279502884;
    // Signed full turn; the per-stage angle is this divided by the span.
    const double turn = (invert ? -2.0 : 2.0) * pi;

    // Butterfly stages: merge pairs of length-h transforms into length-span ones.
    for (int span = 2; span <= n; span <<= 1) {
        const int h = span >> 1;
        const double theta = turn / (double)span;
        const Complex step = {cos(theta), sin(theta)};
        for (int base = 0; base < n; base += span) {
            Complex *lo = a + base;
            Complex *hi = lo + h;
            // Twiddle is advanced by repeated multiplication with `step`.
            Complex tw = {1.0, 0.0};
            for (int k = 0; k < h; ++k) {
                const Complex even = lo[k];
                const Complex odd = mulc(hi[k], tw);
                lo[k] = addc(even, odd);
                hi[k] = subc(even, odd);
                tw = mulc(tw, step);
            }
        }
    }

    if (!invert) {
        return;
    }
    // Inverse transform: normalise by 1/n.
    const double scale = 1.0 / (double)n;
    for (Complex *p = a, *end = a + n; p != end; ++p) {
        p->r *= scale;
        p->i *= scale;
    }
}
/*
 * Multiplies two polynomials with unsigned integer coefficients via FFT.
 *
 *   a : coefficients a[0..n]  (n + 1 entries; n is the degree of a)
 *   b : coefficients b[0..m]  (m + 1 entries)
 *   c : receives c[0..n+m]    (n + m + 1 entries),
 *       c[k] = sum over i+j=k of a[i]*b[j], reduced modulo UINT_MAX + 1
 *       (the same wrap-around as plain unsigned arithmetic).
 *
 * Both inputs are packed into one complex array (a in the real part, b in the
 * imaginary part), so a single forward FFT yields both spectra.  Exactness
 * relies on the double-precision FFT error staying below 0.5 per coefficient;
 * that is not guaranteed for very long inputs or very large coefficients.
 *
 * Aborts if the working buffer cannot be allocated (the API has no error
 * channel).
 */
void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c) {
    int need = n + m + 1;
    int len = 1;
    while (len < need) len <<= 1;

    // calloc zero-fills the padding past both inputs, which the cyclic
    // convolution computed by the FFT requires.
    Complex *z = (Complex *)calloc((size_t)len, sizeof(Complex));
    if (z == NULL) {
        abort();
    }
    for (int i = 0; i <= n; ++i) z[i].r = (double)a[i];
    for (int i = 0; i <= m; ++i) z[i].i = (double)b[i];
    fft(z, len, 0);

    // Unpack the two real spectra from Z = A + i*B:
    //   A[k] = (Z[k] + conj(Z[len-k])) / 2
    //   B[k] = (Z[k] - conj(Z[len-k])) / (2i)
    // Bin 0 holds A[0] + i*B[0] with both parts real.
    z[0] = {z[0].r * z[0].i, 0.0};
    for (int k = 1; k <= (len >> 1); ++k) {
        int j = len - k;
        Complex x = z[k];
        Complex y = {z[j].r, -z[j].i};
        Complex ar = {(x.r + y.r) * 0.5, (x.i + y.i) * 0.5};
        Complex br = {(x.i - y.i) * 0.5, (y.r - x.r) * 0.5};
        Complex prod = mulc(ar, br);
        z[k] = prod;
        // The product spectrum of real sequences is conjugate-symmetric, so
        // the mirror bin is filled directly.  z[j] (j > len/2) is not read
        // again by later iterations.
        if (k != j) {
            z[j] = {prod.r, -prod.i};
        }
    }
    fft(z, len, 1);

    // Converting a negative or out-of-range double to unsigned is undefined
    // behaviour, and exact products can legitimately exceed UINT_MAX.  Round,
    // clamp tiny negative rounding noise (and NaN) to 0, then reduce exactly
    // with fmod to reproduce unsigned wrap-around.
    const double modulus = (double)UINT_MAX + 1.0;
    for (int i = 0; i < need; ++i) {
        double v = floor(z[i].r + 0.5);
        c[i] = v > 0.0 ? (unsigned)fmod(v, modulus) : 0u;
    }
    free(z);
}
#ifdef LOCAL_TEST
#include <stdio.h>
#include <time.h>
// Self-test (build with -DLOCAL_TEST): compares poly_multiply against the
// schoolbook O(n*m) product for every degree pair 0 <= n, m < 64.
static unsigned aa[1 << 20], bb[1 << 20], cc[1 << 21], dd[4096];

// Checks a single (n, m) degree pair.  Prints the first mismatching
// coefficient and returns 0; returns 1 when all n + m + 1 coefficients agree.
static int check_degrees(int n, int m) {
    for (int t = 0; t <= n; ++t) aa[t] = (unsigned)((t * 7 + n) % 10);
    for (int t = 0; t <= m; ++t) bb[t] = (unsigned)((t * 5 + m) % 10);
    poly_multiply(aa, n, bb, m, cc);

    const int deg = n + m;
    for (int k = 0; k <= deg; ++k) dd[k] = 0;
    for (int x = 0; x <= n; ++x) {
        for (int y = 0; y <= m; ++y) {
            dd[x + y] += aa[x] * bb[y];
        }
    }

    for (int k = 0; k <= deg; ++k) {
        if (cc[k] == dd[k]) continue;
        printf("bad n=%d m=%d i=%d got=%u want=%u\n", n, m, k, cc[k], dd[k]);
        return 0;
    }
    return 1;
}

int main() {
    for (int n = 0; n < 64; ++n) {
        for (int m = 0; m < 64; ++m) {
            if (!check_degrees(n, m)) return 1;
        }
    }
    puts("ok");
    return 0;
}
#endif
/*
 * Online-judge result for this submission:
 * | Compilation | N/A       | N/A            | Compile OK | Score: N/A | Show more |
 * | Testcase #1 | 277.13 ms | 39 MB + 660 KB | Accepted   | Score: 100 | Show more |
 */