// http://uoj.ac/submission/206573
// 从 UOJ 上抄代码好开心啊……
#include <math.h>
// Debug print to stderr. NOTE(review): needs <stdio.h>, which the visible source
// does not include; it is also unused in the visible source.
#define debug(...) fprintf(stderr, __VA_ARGS__)
#define PI 3.14159265358979323846
// Every later use of `float` (Complex's fields, inv_n in poly_multiply) is compiled
// as double. NOTE(review): #defining a keyword is not permitted by the C++ standard
// in a translation unit that includes standard headers; a typedef would be cleaner.
#define float double
// Minimal complex number: x is the real part, y the imaginary part.
// (Under the file's `#define float double`, both fields are doubles.)
struct Complex {
    float x, y;
    // Complex conjugate: x - iy. const so it also works on const objects/references.
    Complex conj() const {
        Complex r = {x, -y};  // aggregate init; compound literals are not ISO C++
        return r;
    }
    // Negated conjugate: -x + iy, i.e. -conj().
    Complex conj2() const {
        Complex r = {-x, y};
        return r;
    }
};
// Complex sum, component-wise.
inline Complex operator + (const Complex &lhs, const Complex &rhs) {
    Complex sum;
    sum.x = lhs.x + rhs.x;
    sum.y = lhs.y + rhs.y;
    return sum;
}
// Complex difference, component-wise.
inline Complex operator - (const Complex &lhs, const Complex &rhs) {
    Complex diff;
    diff.x = lhs.x - rhs.x;
    diff.y = lhs.y - rhs.y;
    return diff;
}
// Complex product: (p + iq)(r + is) = (pr - qs) + i(ps + qr).
inline Complex operator * (const Complex &lhs, const Complex &rhs) {
    Complex prod;
    prod.x = lhs.x * rhs.x - lhs.y * rhs.y;
    prod.y = lhs.x * rhs.y + lhs.y * rhs.x;
    return prod;
}
// Global work buffers for poly_multiply: each input is packed two real coefficients
// per complex slot. The "+ 1" leaves room for the sentinel copy A[N] = A[0]
// (and B[N] = B[0]), so transforms may use at most N = 1 << 21 slots.
Complex A[(1 << 21) + 1], B[(1 << 21) + 1];
// Lookup table: bitrev[v] is the 11-bit value v with its bits mirrored
// (bit j moves to bit 10 - j). get_bitrev() combines two lookups.
int bitrev[1 << 11];
// Fills bitrev[] for every 11-bit value; must run before get_bitrev() is used.
void bitrev_init() {
    const int width = 11;
    for (int value = 0; value < (1 << width); ++value) {
        int rest = value;
        int mirrored = 0;
        // Shift the low bits of `rest` into `mirrored` from the right, so the
        // first bit taken (bit 0) ends up in the highest position (bit 10).
        for (int step = 0; step < width; ++step) {
            mirrored = (mirrored << 1) | (rest & 1);
            rest >>= 1;
        }
        bitrev[value] = mirrored;
    }
}
// Reverses the low `len` bits of x (len <= 22) using the 11-bit bitrev[] table:
// mirror each 11-bit half, swap the halves to get a 22-bit reversal, then drop
// the (22 - len) surplus low bits.
inline int get_bitrev(int x, int len) {
    const int low_mask = (1 << 11) - 1;
    int low_half_rev = bitrev[x & low_mask];
    int high_half_rev = bitrev[x >> 11];
    int full_rev = (low_half_rev << 11) | high_half_rev;
    return full_rev >> (22 - len);
}
// In-place radix-2 FFT of a[0 .. 2^lg_n - 1].
// Butterflies run in decimation-in-frequency order (largest span first), which
// leaves the spectrum bit-reversed; the last pass permutes it back to natural order.
// rev == false uses twiddles exp(-i*pi/half); rev == true uses exp(+i*pi/half).
// No 1/n scaling is applied -- callers normalise the inverse transform themselves.
void FFT(Complex *a, int lg_n, bool rev) {
    const int n = 1 << lg_n;
    for (int half = n >> 1; half > 0; half >>= 1) {
        const double angle = PI / half;
        // Rotation between consecutive twiddles of this stage.
        Complex step = (Complex){cos(angle), rev ? sin(angle) : -sin(angle)};
        for (int base = 0; base < n; base += 2 * half) {
            Complex *blk = a + base;
            Complex tw = (Complex){1.0, 0.0};
            for (int k = 0; k < half; ++k) {
                Complex lo = blk[k];
                Complex hi = blk[k + half];
                blk[k] = lo + hi;
                blk[k + half] = (lo - hi) * tw;
                tw = tw * step;
            }
        }
    }
    // Undo the bit-reversed output order; swap each pair only once.
    for (int i = 0; i < n; ++i) {
        const int j = get_bitrev(i, lg_n);
        if (j <= i) continue;
        Complex saved = a[i];
        a[i] = a[j];
        a[j] = saved;
    }
}
// Same transform as FFT(a, lg_n, rev), applied to two arrays in one fused pass.
// Both arrays get exactly the same butterflies and twiddle sequence; fusing only
// shares the twiddle computation. Output is in natural order, unscaled.
void FFT(Complex *a, Complex *b, int lg_n, bool rev) {
    const int n = 1 << lg_n;
    for (int half = n >> 1; half > 0; half >>= 1) {
        const double angle = PI / half;
        Complex step = (Complex){cos(angle), rev ? sin(angle) : -sin(angle)};
        for (int base = 0; base < n; base += 2 * half) {
            Complex *pa = a + base;
            Complex *pb = b + base;
            Complex tw = (Complex){1.0, 0.0};
            for (int k = 0; k < half; ++k) {
                Complex hi_a = pa[k + half];
                Complex hi_b = pb[k + half];
                pa[k + half] = (pa[k] - hi_a) * tw;
                pb[k + half] = (pb[k] - hi_b) * tw;
                pa[k] = pa[k] + hi_a;
                pb[k] = pb[k] + hi_b;
                tw = tw * step;
            }
        }
    }
    // Undo the bit-reversed output order in both arrays.
    for (int i = 0; i < n; ++i) {
        const int j = get_bitrev(i, lg_n);
        if (j <= i) continue;
        Complex saved = a[i];
        a[i] = a[j];
        a[j] = saved;
        saved = b[i];
        b[i] = b[j];
        b[j] = saved;
    }
}
void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c) {
int lg_n = 3;
while (1 << lg_n < n + 1 || 1 << lg_n < m + 1) ++lg_n;
bitrev_init();
int N = 1 << lg_n;
int N2 = N >> 1;
for (int i = 0; i <= n; i++) {
// *0.5 is to avoid *2 in later caclulation
(i & 1 ? A[i >> 1].y : A[i >> 1].x) = a[i] * 0.5;
}
for (int i = 0; i <= m; i++) {
(i & 1 ? B[i >> 1].y : B[i >> 1].x) = b[i] * 0.5;
}
FFT(A, B, lg_n, false);
A[N] = A[0];
B[N] = B[0];
const Complex w = (Complex){cos(2 * PI / N), sin(2 * PI / N)};
Complex w_product = (Complex){1, 0};
for (int i = 0; i <= N2; i++) {
Complex a1 = A[i] + A[N - i];
Complex a2 = A[i] - A[N - i];
Complex b1 = B[i] + B[N - i];
Complex b2 = B[i] - B[N - i];
Complex a = (Complex){a1.x, a2.y};
Complex b = (Complex){a2.x, a1.y};
Complex c = (Complex){b1.x, b2.y};
Complex d = (Complex){b2.x, b1.y};
// a * c - b * d * delta(x - 1) + a * d + b * c
// @Signal Processing
A[i] = a * c + a * d + b * c - b * d * w_product.conj();
A[N - i] = (a * c).conj() + a.conj() * d.conj2() + b.conj2() * c.conj() - b.conj2() * d.conj2() * w_product;
w_product = w_product * w;
}
FFT(A, lg_n, true);
float inv_n = 1.0 / N;
for (int i = 0; i <= n + m; i++) {
c[i] = (int)((i & 1 ? A[i >> 1].y : A[i >> 1].x) * inv_n + 0.5);
}
}
// Judge result for the submission linked at the top of this file:
//   Compilation | N/A        | N/A            | Compile OK | Score: N/A
//   Testcase #1 | 197.685 ms | 39 MB + 672 KB | Accepted   | Score: 100