#include <bits/stdc++.h>
using namespace std;
typedef long long ll;               // wide integer type used for carry propagation
constexpr int N = 1000000 + 5;      // array capacity shared by the global buffers
const double PI = std::acos(-1.0);  // pi, computed once at start-up
// Minimal complex number over double, used as the FFT coefficient type.
// No user-declared copy constructor, copy assignment or destructor (Rule of
// Zero): the implicit ones perform the same member-wise copy, and leaving them
// implicit keeps the type trivially copyable and gives it implicit moves.
struct Complex {
  using valueType = double;
  valueType real, imag;
  Complex(valueType a = 0, valueType b = 0) : real(a), imag(b) {}
  // Complex conjugate: (real, -imag).
  Complex conj() const { return {real, -imag}; }
  Complex &operator+=(const Complex &rhs) {
    real += rhs.real, imag += rhs.imag;
    return *this;
  }
  Complex &operator-=(const Complex &rhs) {
    real -= rhs.real, imag -= rhs.imag;
    return *this;
  }
  // (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
  // Both parts are computed before assigning, so `x *= x` is safe.
  Complex &operator*=(const Complex &rhs) {
    const valueType re = real * rhs.real - imag * rhs.imag;
    const valueType im = real * rhs.imag + imag * rhs.real;
    real = re, imag = im;
    return *this;
  }
  // (a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2).
  // The denominator is computed once; division by zero is not checked.
  Complex &operator/=(const Complex &rhs) {
    const valueType norm = rhs.real * rhs.real + rhs.imag * rhs.imag;
    const valueType re = (real * rhs.real + imag * rhs.imag) / norm;
    const valueType im = (imag * rhs.real - real * rhs.imag) / norm;
    real = re, imag = im;
    return *this;
  }
  friend Complex operator+(const Complex &lhs, const Complex &rhs) {
    Complex res(lhs);
    res += rhs;
    return res;
  }
  friend Complex operator-(const Complex &lhs, const Complex &rhs) {
    Complex res(lhs);
    res -= rhs;
    return res;
  }
  friend Complex operator*(const Complex &lhs, const Complex &rhs) {
    Complex res(lhs);
    res *= rhs;
    return res;
  }
  friend Complex operator/(const Complex &lhs, const Complex &rhs) {
    Complex res(lhs);
    res /= rhs;
    return res;
  }
} eps[N],  // twiddle-factor table, filled by init()
    a[N];  // FFT working array
void init(int n) {
int l = n >> 1;
for (int i = 0; i < l; ++i) {
eps[l + i].real = cos(i * PI / l);
eps[l + i].imag = sin(i * PI / l);
}
for (int i = l - 1; i > 0; --i) eps[i] = eps[i << 1];
}
// Inverse of dft(). It consumes a spectrum in bit-reversed order (as dft()
// leaves it) and writes the natural-order sequence back into x[0..n),
// divided by n.
// Conjugation trick: conj(input) -> forward decimation-in-time butterflies
// using the same eps[] twiddles -> conj(output), scaled by 1/n.
// Requires n to be a power of two and init(n) to have been called.
// A transform of size n < 2 is the identity, so x is left untouched.
void idft(int n, Complex x[]) { // decimation in time
  if (n < 2) return;  // the size-2 pair loop below would otherwise touch x[1]
  for (int i = 0; i < n; ++i) x[i].imag = -x[i].imag;
  Complex u, v;
  // First stage (blocks of 2): the twiddle is 1, so no multiply.
  for (int j = 0; j < n; j += 2) {
    u = x[j], v = x[j + 1];
    x[j] = u + v;
    x[j + 1] = u - v;
  }
  // Remaining stages: blocks of i = 4, 8, ..., n, each with halves of length l.
  for (int i = 4; i <= n; i <<= 1) {
    for (int j = 0, l = i >> 1; j < n; j += i) {
      for (int k = 0; k < l; ++k) {
        u = x[k + j], v = eps[l + k] * x[k + j + l];
        x[k + j] = u + v;
        x[k + j + l] = u - v;
      }
    }
  }
  // Undo the conjugation and apply the 1/n normalisation in one pass.
  double c = 1.0 / n;
  for (int i = 0; i < n; ++i) x[i].real *= c, x[i].imag *= -c;
}
// Forward transform: decimation in frequency, in place, WITHOUT the final
// bit-reversal permutation. On return x[] holds the spectrum in bit-reversed
// order, which is exactly the input order idft() expects.
// Butterflies multiply by eps[l + k] = (cos(k*PI/l), sin(k*PI/l)), so init(n)
// must have been called first and n must be a power of two.
// A transform of size n < 2 is the identity, so x is left untouched.
void dft(int n, Complex x[]) { // decimation in frequency
  if (n < 2) return;  // the size-2 pair loop below would otherwise touch x[1]
  Complex u, v;
  // Stages on blocks of i = n, n/2, ..., 4, each split into halves of length l.
  for (int i = n; i >= 4; i >>= 1) {
    for (int j = 0, l = i >> 1; j < n; j += i) {
      for (int k = 0; k < l; ++k) {
        u = x[k + j], v = x[k + j + l];
        x[k + j] = u + v;
        x[k + j + l] = (u - v) * eps[l + k];
      }
    }
  }
  // Last stage (blocks of 2): the twiddle is 1, so no multiply.
  for (int j = 0; j < n; j += 2) {
    u = x[j], v = x[j + 1];
    x[j] = u + v;
    x[j + 1] = u - v;
  }
}
char t[4 * N];  // global char buffer of 4 * N bytes (static storage, zero-initialized)
int main() {
#ifdef LOCAL
freopen("..\\in", "r", stdin), freopen("..\\out", "w", stdout);
#endif
ios::sync_with_stdio(false);
cin.tie(0);
const int block = 4; // 压位
cin >> (t + 1);
int len1 = strlen(t + 1); // 字符串长度
for (int i = len1, j = 0; i > 0; i -= block, ++j) {
for (int k = block - 1; ~k; --k) {
if (i - k > 0) {
a[j].real = a[j].real * 10 + t[i - k] - '0';
}
}
}
cin >> (t + 1);
int len2 = strlen(t + 1); // 字符串长度
for (int i = len2, j = 0; i > 0; i -= block, ++j) {
for (int k = block - 1; ~k; --k) {
if (i - k > 0) {
a[j].imag = a[j].imag * 10 + t[i - k] - '0';
}
}
}
int len = 1;
while (len < (len1 + len2 - 1) / block + 1) len <<= 1; // 这边是压位后的长度
init(len);
dft(len, a);
for (int i = 0; i < len; ++i) a[i] *= a[i];
idft(len, a);
int top = 0; // 将t想象成一个栈,存放答案
ll tmp = 0;
for (int i = 0; i < (len1 + len2 - 1) / block || tmp; ++i) {
tmp += static_cast<ll>(a[i].imag / 2.0 + 0.5);
ll j = tmp % 10000; // 压多少位这边多少个0
for (int i = 1; i <= block; ++i) {
t[top++] = j % 10 + '0';
j /= 10;
}
tmp /= 10000; // 压多少位这边多少个0
}
while (top && t[top - 1] == '0') --top; // 去前导0
if(!top)cout<<'0';
while(top)cout<<t[--top];
return 0;
}
// Judge log: Compilation | N/A | N/A | Compile OK | Score: N/A
// Judge log: Testcase #1 | 110.607 ms | 34 MB + 412 KB | Accepted | Score: 100