#include <algorithm>
#include <cctype>
#include <complex>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <limits>
#include <string>
#include <tuple>
#include <utility>
#include <immintrin.h>
#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP
#pragma GCC target("fma")
#pragma GCC target("avx")
namespace hint_simd
{
// Fixed-size array of LEN elements whose storage starts on a 4096-byte boundary
// (which also satisfies the 32-byte alignment of aligned AVX loads/stores).
template <typename T, size_t LEN>
class AlignAry
{
private:
    alignas(4096) T ary[LEN];

public:
    // Storage is not explicitly initialised here.
    constexpr AlignAry() {}

    constexpr T &operator[](size_t index) { return ary[index]; }
    constexpr const T &operator[](size_t index) const { return ary[index]; }

    // Pointer to the first element.
    T *data() { return &ary[0]; }
    const T *data() const { return &ary[0]; }

    // Reinterprets the storage as an array of Ty. The caller must make sure Ty is suitable
    // for this kind of type punning and stays within sizeof(T) * LEN bytes.
    template <typename Ty>
    Ty *cast_ptr() { return reinterpret_cast<Ty *>(data()); }
    template <typename Ty>
    const Ty *cast_ptr() const { return reinterpret_cast<const Ty *>(data()); }
};
// Use AVX
// 256bit simd
// Scalar floating-point type used by the SIMD helpers in this namespace.
using HintFloat = double;
// Scalar complex type; Complex2 below packs two of these into one __m256d.
using Complex = std::complex<HintFloat>;
// Two std::complex<double> values processed in parallel in one 256-bit AVX register.
// Lane layout of `data`: [re0, im0, re1, im1].
struct Complex2
{
    __m256d data;
    // All four lanes zero.
    Complex2()
    {
        data = _mm256_setzero_pd();
    }
    // Broadcasts `input` to all four lanes (e.g. a real scale factor for linear_mul).
    Complex2(double input)
    {
        data = _mm256_set1_pd(input);
    }
    // Wraps a raw register; lets intrinsic results convert implicitly to Complex2.
    Complex2(__m256d input)
    {
        data = input;
    }
    Complex2(const Complex2 &input) = default;
    // Construct from a contiguous array: loads 4 doubles (two complex numbers).
    // `ptr` must be 32-byte aligned (_mm256_load_pd).
    Complex2(double const *ptr)
    {
        data = _mm256_load_pd(ptr);
    }
    // Loads ptr[0] and ptr[1]; `ptr` must be 32-byte aligned.
    Complex2(const Complex *ptr)
    {
        data = _mm256_load_pd(reinterpret_cast<const double *>(ptr));
    }
    void clr()
    {
        data = _mm256_setzero_pd();
    }
    // Writes both values to a[0] and a[1]; `a` must be 32-byte aligned.
    void store(Complex *a) const
    {
        _mm256_store_pd(reinterpret_cast<double *>(a), data);
    }
    // Packs `a` into the low 128 bits and `b` into the high 128 bits.
    Complex2(const Complex &a, const Complex &b)
    {
        // Unaligned loads: std::complex<double> only guarantees double alignment, so reading
        // it through an __m128d lvalue (16-byte aligned type) is UB and may fault.
        const __m128d lo = _mm_loadu_pd(reinterpret_cast<const double *>(&a));
        const __m128d hi = _mm_loadu_pd(reinterpret_cast<const double *>(&b));
        data = _mm256_set_m128d(hi, lo);
    }
    // Debug helper: prints both complex values.
    void print() const
    {
        double ary[4];
        _mm256_storeu_pd(ary, data);
        printf("(%lf,%lf) (%lf,%lf)\n", ary[0], ary[1], ary[2], ary[3]);
    }
    // Flips the sign of lane k for every bit k set in M (bit 0 = re0 ... bit 3 = im1).
    template <int M>
    Complex2 element_mask_neg() const
    {
        // Plain local constant: it folds at compile time and avoids the thread-safe
        // initialisation guard a function-local static would check on every call.
        const __m256d neg_mask = _mm256_castsi256_pd(
            _mm256_set_epi64x((M & 8ull) << 60, (M & 4ull) << 61, (M & 2ull) << 62, (M & 1ull) << 63));
        return _mm256_xor_pd(data, neg_mask);
    }
    // In-lane permute (_mm256_permute_pd) controlled by M.
    template <int M>
    Complex2 element_permute() const
    {
        return _mm256_permute_pd(data, M);
    }
    // Cross-lane permute of the four doubles.
    // NOTE: _mm256_permute4x64_pd is AVX2; the target pragmas at the top of this file only
    // enable avx and fma, so instantiating this needs AVX2 enabled elsewhere.
    template <int M>
    Complex2 element_permute64() const
    {
        return _mm256_permute4x64_pd(data, M);
    }
    // [re0, re0, re1, re1]
    Complex2 all_real() const
    {
        return _mm256_unpacklo_pd(data, data);
    }
    // [im0, im0, im1, im1]
    Complex2 all_imag() const
    {
        return _mm256_unpackhi_pd(data, data);
    }
    // Swaps real and imaginary part of each value: [im0, re0, im1, re1].
    Complex2 swap() const
    {
        return _mm256_shuffle_pd(data, data, 5);
    }
    // Multiplies both values by -i: (re, im) -> (im, -re).
    Complex2 mul_neg_i() const
    {
        return Complex2(_mm256_addsub_pd(_mm256_setzero_pd(), data)).swap();
    }
    // Complex conjugate of both values (negates lanes 1 and 3).
    Complex2 conj() const
    {
        return element_mask_neg<10>();
    }
    // Lane-wise product of the four doubles (not a complex product).
    Complex2 linear_mul(Complex2 input) const
    {
        return _mm256_mul_pd(data, input.data);
    }
    // Complex square of both values: ((re + im)(re - im), 2 * re * im).
    Complex2 square() const
    {
        const __m256d rr = all_real().data;
        const __m256d ir = swap().data;
        const __m256d add = _mm256_add_pd(rr, ir);
        const __m256d sub = _mm256_sub_pd(rr, ir);
        return _mm256_mul_pd(add, _mm256_blend_pd(sub, data, 0b1010));
    }
    Complex2 operator+(const Complex2 &input) const
    {
        return _mm256_add_pd(data, input.data);
    }
    Complex2 operator-(const Complex2 &input) const
    {
        return _mm256_sub_pd(data, input.data);
    }
    // Complex product of both value pairs, using one FMA add/sub.
    Complex2 operator*(const Complex2 &input) const
    {
        auto imag = _mm256_mul_pd(all_imag().data, input.swap().data);
        return _mm256_fmaddsub_pd(all_real().data, input.data, imag);
    }
};
}
#endif
#include <array>
#include <vector>
#include <iostream>
#include <future>
#include <ctime>
#include <cstring>
// #include "stopwatch.hpp"
#define TABLE_ENABLE 1
#define MULTITHREAD 0 // 多线程 0 means no, 1 means yes
#define TABLE_PRELOAD 0 // 是否提前初始化表 0 means no, 1 means yes
using namespace hint_simd;
namespace hint
{
// Fixed-width integer aliases used throughout namespace hint.
using UINT_8 = uint8_t;
using UINT_16 = uint16_t;
using UINT_32 = uint32_t;
using UINT_64 = uint64_t;
using INT_32 = int32_t;
using INT_64 = int64_t;
using ULONG = unsigned long;
using LONG = long;
// Floating-point scalar and complex types for the FFT code.
using HintFloat = double;
using Complex = std::complex<HintFloat>;
// Bit widths of char, short and int-sized values.
constexpr UINT_64 HINT_CHAR_BIT = 8;
constexpr UINT_64 HINT_SHORT_BIT = 16;
constexpr UINT_64 HINT_INT_BIT = 32;
// Per-width constants. As the values show: *_0XFF is the all-ones value, *_0X10 is 2^width,
// *_0X80 is the sign bit and *_0X7F the largest signed value (HINT_INT32_0X01 is simply 1).
constexpr UINT_64 HINT_INT8_0XFF = UINT8_MAX;
constexpr UINT_64 HINT_INT8_0X10 = (UINT8_MAX + 1ull);
constexpr UINT_64 HINT_INT16_0XFF = UINT16_MAX;
constexpr UINT_64 HINT_INT16_0X10 = (UINT16_MAX + 1ull);
constexpr UINT_64 HINT_INT32_0XFF = UINT32_MAX;
constexpr UINT_64 HINT_INT32_0X01 = 1;
constexpr UINT_64 HINT_INT32_0X80 = 0X80000000ull;
constexpr UINT_64 HINT_INT32_0X7F = INT32_MAX;
constexpr UINT_64 HINT_INT32_0X10 = (UINT32_MAX + 1ull);
constexpr UINT_64 HINT_INT64_0X80 = INT64_MIN;
constexpr UINT_64 HINT_INT64_0X7F = INT64_MAX;
constexpr UINT_64 HINT_INT64_0XFF = UINT64_MAX;
// pi, 2*pi and sqrt(2)/2.
constexpr HintFloat HINT_PI = 3.1415926535897932384626433832795;
constexpr HintFloat HINT_2PI = HINT_PI * 2;
constexpr HintFloat HINT_HSQ_ROOT2 = 0.70710678118654752440084436210485;
// NTT moduli: 3221225473 = 3 * 2^30 + 1 and 3489660929 = 13 * 2^28 + 1.
// NOTE(review): NTT_ROOT1 / NTT_ROOT2 are presumably primitive roots of these moduli; not verified here.
constexpr UINT_64 NTT_MOD1 = 3221225473;
constexpr UINT_64 NTT_ROOT1 = 5;
constexpr UINT_64 NTT_MOD2 = 3489660929;
constexpr UINT_64 NTT_ROOT2 = 3;
// 2^28, the largest power of two dividing NTT_MOD2 - 1 (presumably why it is the NTT length limit).
constexpr size_t NTT_MAX_LEN = 1ull << 28;
/// @brief Largest power of two that is not greater than n.
/// @param n Value to round down.
/// @return The largest power of two <= n, or 0 when n < 1 (no such power exists).
/// For signed T the search starts from the largest positive power of two. Previously it
/// started from the sign bit, which made the result negative for every signed input.
template <typename T>
constexpr T max_2pow(T n)
{
    if (n < T(1))
    {
        return T(0);
    }
    // numeric_limits<T>::digits counts value bits only (excludes the sign bit).
    T res = T(1) << (std::numeric_limits<T>::digits - 1);
    while (res > n)
    {
        res /= 2;
    }
    return res;
}
// Largest power of two <= n, computed by smearing the highest set bit downwards.
// Returns 0 for n == 0, consistent with max_2pow; the smear formula alone would yield 1,
// which is greater than the input.
template <typename T>
constexpr T int_floor2(T n)
{
    if (n == 0)
    {
        return T(0);
    }
    constexpr int bits = sizeof(n) * 8;
    // After this loop every bit below the highest set bit is also set.
    for (int i = 1; i < bits; i *= 2)
    {
        n |= (n >> i);
    }
    return (n >> 1) + 1;
}
// Smallest power of two >= n (like C++20 std::bit_ceil).
// n <= 1 yields 1; previously n == 0 wrapped around and returned 0, which is not a power of two.
// If the result does not fit in T (n > 2^(bits-1)) the value wraps to 0.
template <typename T>
constexpr T int_ceil2(T n)
{
    if (n <= 1)
    {
        return T(1);
    }
    constexpr int bits = sizeof(n) * 8;
    n--;
    // Set every bit below the highest set bit of n - 1; adding 1 then gives the power of two.
    for (int i = 1; i < bits; i *= 2)
    {
        n |= (n >> i);
    }
    return n + 1;
}
// floor(log2(n)) for n >= 1; returns 0 for any n <= 1.
template <typename T>
constexpr T hint_log2(T n)
{
    T steps = 0;
    for (; n > 1; n /= 2)
    {
        ++steps;
    }
    return steps;
}
// Binary exponentiation: returns m^n using T's operator* (T must be constructible from 1).
template <typename T>
constexpr T qpow(T m, UINT_64 n)
{
    T result = 1;
    while (n > 0)
    {
        if ((n & 1) != 0)
        {
            result = result * m;
        }
        n >>= 1;
        // Square only while exponent bits remain. The previous unconditional squaring after
        // the last bit computed a power the result never uses. That can overflow even when
        // m^n itself fits (e.g. qpow(10, 9) for int), which is UB for signed T and fails
        // constant evaluation.
        if (n > 0)
        {
            m = m * m;
        }
    }
    return result;
}
// Returns {a / b, a % b}; b must be non-zero.
template <typename T>
constexpr std::pair<T, T> div_mod(T a, T b)
{
    const T quotient = a / b;
    const T remainder = a % b;
    return {quotient, remainder};
}
// Optional multithreading state, compiled only when MULTITHREAD == 1 (the macro is 0 above).
// NOTE(review): uses std::thread, std::ceil and std::atomic, but <thread>, <cmath> and <atomic>
// are not included explicitly in the visible code; confirm before enabling.
// NOTE(review): hint_log2 already returns an integer floor(log2(n)), so std::ceil does not round
// anything up here; if ceil(log2(threads)) was intended, this computes the floor instead.
#if MULTITHREAD == 1
const UINT_32 hint_threads = std::thread::hardware_concurrency();
const UINT_32 log2_threads = std::ceil(hint_log2(hint_threads));
std::atomic<UINT_32> cur_ths;
#endif
// Copies len elements of trivially copyable T from source to target.
// Overlapping ranges are allowed (memmove); memcpy was UB for them.
// Throws the string literal "Ary too long\n" (a const char *) when len * sizeof(T) could not be
// represented. The bound is scaled by sizeof(T), so the byte count can no longer overflow;
// for sizeof(T) == 1 the condition is the same as before.
template <typename T>
void ary_copy(T *target, const T *source, size_t len)
{
    if (len == 0 || target == source)
    {
        return;
    }
    if (len >= INT64_MAX / sizeof(T))
    {
        throw("Ary too long\n");
    }
    std::memmove(target, source, len * sizeof(T));
}
// Copies two indexable sources into one complex array: source1[i] becomes the real part and
// source2[i] the imaginary part of target[i]. Past min(len1, len2) only the component that
// still has data is written; the other component of target keeps its previous value.
template <typename T>
inline void com_ary_combine_copy(Complex *target, const T &source1, size_t len1, const T &source2, size_t len2)
{
    const size_t both = std::min(len1, len2);
    for (size_t idx = 0; idx < both; idx++)
    {
        target[idx] = Complex(source1[idx], source2[idx]);
    }
    for (size_t idx = both; idx < len1; idx++)
    {
        target[idx].real(source1[idx]);
    }
    for (size_t idx = both; idx < len2; idx++)
    {
        target[idx].imag(source2[idx]);
    }
}
// Namespace for FFT and FFT-like transforms
namespace hint_transform
{
// Bit-reversal permutation: swaps ary[j] with ary[rev(j)], where rev reverses the
// log2(len) index bits. len must be a power of two.
// Lengths below 2 need no permutation. Returning early also avoids `len - 1` wrapping for
// len == 0, which previously made the loop run practically forever.
template <typename T>
void binary_reverse_swap(T &ary, size_t len)
{
    if (len < 2)
    {
        return;
    }
    size_t i = 0; // bit-reversed counterpart of j, maintained incrementally
    for (size_t j = 1; j < len - 1; j++)
    {
        // "Increment" i in bit-reversed order: flip bits from the top down until a 0 becomes 1.
        size_t k = len >> 1;
        i ^= k;
        while (k > i)
        {
            k >>= 1;
            i ^= k;
        }
        // Swap each pair only once.
        if (j < i)
        {
            std::swap(ary[i], ary[j]);
        }
    }
}
namespace hint_fft
{
// Radix-2 butterfly in place: (sum, diff) <- (sum + diff, sum - diff).
template <typename T>
inline void fft_2point(T &sum, T &diff)
{
    const T left = sum;
    const T right = diff;
    diff = left - right;
    sum = left + right;
}
template <size_t LEN>
struct FFT_RADIX2
{
enum
{
max_fft_len = LEN,
table_len = max_fft_len / 2
};
static constexpr Complex mul(Complex a, Complex b)
{
return Complex(a.real() * b.real() - a.imag() * b.imag(), a.imag() * b.real() + a.real() * b.imag());
}
std::array<Complex, table_len> TABLE;
constexpr FFT_RADIX2()
{
table_init();
}
constexpr void table_init()
{
constexpr Complex pre_compute[24] = {
Complex(1, 0),
Complex(0, -1),
Complex(0.7071067811865476, -0.7071067811865476),
Complex(0.9238795325112867, -0.3826834323650898),
Complex(0.9807852804032304, -0.19509032201612825),
Complex(0.9951847266721969, -0.0980171403295606),
Complex(0.9987954562051724, -0.049067674327418015),
Complex(0.9996988186962042, -0.024541228522912288),
Complex(0.9999247018391445, -0.012271538285719925),
Complex(0.9999811752826011, -0.006135884649154475),
Complex(0.9999952938095762, -0.003067956762965976),
Complex(0.9999988234517019, -0.0015339801862847655),
Complex(0.9999997058628822, -0.0007669903187427045),
Complex(0.9999999264657179, -0.00038349518757139556),
Complex(0.9999999816164293, -0.0001917475973107033),
Complex(0.9999999954041073, -9.587379909597734e-05),
Complex(0.9999999988510269, -4.793689960306688e-05),
Complex(0.9999999997127567, -2.396844980841822e-05),
Complex(0.9999999999281892, -1.1984224905069705e-05),
Complex(0.9999999999820472, -5.9921124526424275e-06),
Complex(0.9999999999955118, -2.996056226334661e-06),
Complex(0.999999999998878, -1.4980281131690111e-06),
Complex(0.9999999999997194, -7.490140565847157e-07),
Complex(0.9999999999999298, -3.7450702829238413e-07),
};
TABLE[0] = Complex(1.0, 0);
for (size_t i = 1, index = 1; i < table_len; i *= 2, index++)
{
TABLE[i] = pre_compute[index];
}
for (size_t i = 0; i < table_len; i++)
{
size_t j = i & (i - 1);
TABLE[i] = mul(TABLE[i - j], TABLE[j]);
}
}
void dif_avx(Complex *input, size_t len) const
{
if (len > max_fft_len)
{
return;
}
for (size_t i = 0; i < len / 2; i++)
{
fft_2point(input[i], input[i + len / 2]);
}
for (size_t rank = len / 4; rank > 1; rank /= 2)
{
for (size_t begin = 0, index = 0; begin < len; begin += rank * 2, index++)
{
Complex *p1 = input + begin;
const Complex2 omega(TABLE[index], TABLE[index]);
for (size_t pos = 0; pos < rank; pos += 2)
{
Complex2 tmp1 = p1 + pos;
Complex2 tmp2 = Complex2(p1 + pos + rank) * omega;
(tmp1 + tmp2).store(p1 + pos);
(tmp1 - tmp2).store(p1 + pos + rank);
}
}
}
for (size_t i = 0; i < len; i += 2)
{
Complex tmp1 = input[i];
Complex tmp2 = input[i + 1] * TABLE[i / 2];
input[i] = tmp1 + tmp2;
input[i + 1] = tmp1 - tmp2;
}
}
void dit_avx(Complex *input, size_t len) const
{
if (len > max_fft_len)
{
return;
}
for (size_t i = 0; i < len; i += 2)
{
Complex tmp1 = input[i];
Complex tmp2 = input[i + 1];
input[i] = tmp1 + tmp2;
input[i + 1] = (tmp1 - tmp2) * TABLE[i / 2];
}
for (size_t rank = 2; rank < len / 2; rank *= 2)
{
for (size_t begin = 0, index = 0; begin < len; begin += rank * 2, index++)
{
Complex *p1 = input + begin;
const Complex2 omega(TABLE[index], TABLE[index]);
for (size_t pos = 0; pos < rank; pos += 2)
{
Complex2 tmp1 = p1 + pos;
Complex2 tmp2 = p1 + pos + rank;
(tmp1 + tmp2).store(p1 + pos);
((tmp1 - tmp2) * omega).store(p1 + pos + rank);
}
}
}
for (size_t i = 0; i < len / 2; i++)
{
fft_2point(input[i], input[i + len / 2]);
}
}
};
// Largest transform length supported by the shared FFT instance below (2^19).
constexpr size_t len1 = size_t(1) << 19;
// Shared FFT object; its twiddle table is filled by the FFT_RADIX2 constructor.
static FFT_RADIX2<len1> small_fft;
}
}
// Formats the lowest `digits` decimal digits of input, zero-padded on the left.
inline std::string ui64to_string(UINT_64 input, UINT_8 digits)
{
    std::string result(digits, '0');
    for (auto it = result.rbegin(); it != result.rend(); ++it)
    {
        *it = static_cast<char>('0' + input % 10);
        input /= 10;
    }
    return result;
}
// Parses the first `dig` characters of s as decimal digits (no validation).
constexpr UINT_64 stoui64(const char *s, size_t dig = 4)
{
    UINT_64 value = 0;
    for (const char *it = s; it != s + dig; ++it)
    {
        value = value * 10 + (*it - '0');
    }
    return value;
}
// Parses exactly four decimal digit characters s[0..3] (no validation).
constexpr UINT_32 stobase10000(const char *s)
{
    return (s[0] - '0') * 1000 + (s[1] - '0') * 100 + (s[2] - '0') * 10 + (s[3] - '0');
}
// Parses exactly five decimal digit characters s[0..4] (no validation).
constexpr UINT_32 stobase100000(const char *s)
{
    return (s[0] - '0') * 10000 + (s[1] - '0') * 1000 + (s[2] - '0') * 100 + (s[3] - '0') * 10 + (s[4] - '0');
}
// Decimal digits per FFT coefficient (numbers are processed in base 10^DIGIT).
// NOTE(review): stobase10000, num_to_s_base10000 and ItoStrBase10000 hard-code 4 digits, so
// changing DIGIT alone would desynchronise the conversions.
static constexpr INT_64 DIGIT = 4;
inline size_t char_to_real(const char *buffer, Complex *comary, size_t str_len)
{
hint::INT_64 len = str_len, pos = len, i = 0;
len = (len + DIGIT - 1) / DIGIT;
while (pos - DIGIT > 0)
{
// hint::UINT_64 tmp = stoui64<DIGIT>(buffer + pos - DIGIT);
hint::UINT_32 tmp = stobase10000(buffer + pos - DIGIT);
comary[i].real(tmp);
i++;
pos -= DIGIT;
}
if (pos > 0)
{
hint::UINT_32 tmp = stoui64(buffer, pos);
comary[i].real(tmp);
}
return len;
}
// Splits the decimal string buffer[0, str_len) into base-10^DIGIT groups, least significant
// group first, storing group k in comary[k].imag(); real parts are not touched.
// Returns the number of groups, ceil(str_len / DIGIT).
// The buffer is only read, so it is taken as const char *, matching char_to_real. Callers that
// pass a char * are unaffected.
inline size_t char_to_imag(const char *buffer, Complex *comary, size_t str_len)
{
    hint::INT_64 len = str_len, pos = len, i = 0;
    len = (len + DIGIT - 1) / DIGIT;
    // Full-width groups taken from the end of the string.
    while (pos - DIGIT > 0)
    {
        hint::UINT_32 tmp = stobase10000(buffer + pos - DIGIT);
        comary[i].imag(tmp);
        i++;
        pos -= DIGIT;
    }
    // The leading 1..DIGIT characters form the most significant group.
    if (pos > 0)
    {
        hint::UINT_32 tmp = stoui64(buffer, pos);
        comary[i].imag(tmp);
    }
    return len;
}
inline void num_to_s(char *s, UINT_64 num)
{
char c = '0';
int i = DIGIT;
while (i > 0)
{
i--;
std::tie(num, c) = div_mod<UINT_64>(num, 10);
s[i] = c + '0';
}
}
// Writes the lowest 4 decimal digits of num to s[0..3], zero-padded, no terminator.
constexpr void num_to_s_base10000(char *s, UINT_32 num)
{
    for (int idx = 3; idx >= 0; --idx)
    {
        s[idx] = '0' + num % 10;
        num /= 10;
    }
}
// Writes the lowest 5 decimal digits of num to s[0..4], zero-padded, no terminator.
constexpr void num_to_s_base100000(char *s, UINT_32 num)
{
    for (int idx = 4; idx >= 0; --idx)
    {
        s[idx] = '0' + num % 10;
        num /= 10;
    }
}
// Lookup table for 0..9999 -> four ASCII digits packed into a uint32_t. The thousands digit is
// in the lowest byte, so on a little-endian target storing the word writes the digits in
// reading order.
class ItoStrBase10000
{
private:
    uint32_t table[10000]{};

public:
    // Packs the four zero-padded digits of num (num < 10000) as described above.
    static constexpr uint32_t itosbase10000(uint32_t num)
    {
        uint32_t res = '0' * 0x1010101; // '0' in every byte
        res += (num / 1000 % 10) | ((num / 100 % 10) << 8) |
               ((num / 10 % 10) << 16) | ((num % 10) << 24);
        return res;
    }
    constexpr ItoStrBase10000()
    {
        for (size_t i = 0; i < 10000; i++)
        {
            table[i] = itosbase10000(i);
        }
    }
    // Writes the 4 digit characters of num (num < 10000) to str[0..3]; no terminator.
    void tostr(char *str, uint32_t num) const
    {
        // memcpy instead of storing through a uint32_t*. `str` is an arbitrary char buffer,
        // so that store was unaligned and violated strict aliasing.
        std::memcpy(str, &table[num], sizeof(uint32_t));
    }
    // Returns the packed digit word for num (num < 10000).
    uint32_t tostr(uint32_t num) const
    {
        return table[num];
    }
};
}
using namespace std;
using namespace hint;
using namespace hint_transform;
using namespace hint_fft;
// Reads two non-negative decimal integers (separated by non-digit characters) from stdin and
// writes their product to stdout.
// The numbers are split into base-10^DIGIT groups: A goes into the real parts and B into the
// imaginary parts of one complex array. Squaring (A + iB) in the frequency domain gives
// A^2 - B^2 + 2i*A*B, so the product's coefficients come out of the imaginary parts.
int main()
{
    constexpr size_t STR_LEN = 2000008;
    constexpr uint64_t BASE = hint::qpow(10ull, DIGIT);
    static constexpr ItoStrBase10000 transfer;
    // `out` first holds the input text and is later reused for the output digits, which are
    // written backwards from its end once the input has been fully converted.
    static AlignAry<char, STR_LEN> out;
    // Backing store for the complex FFT array (AlignAry provides the alignment Complex2 needs).
    static AlignAry<HintFloat, 1 << 20> fft_ary_avx;
    Complex *fft_ary = fft_ary_avx.template cast_ptr<Complex>();
    uint32_t *ary = out.template cast_ptr<uint32_t>();
    size_t len_a = 0, len_b = 0;
    fread(out.data(), 1, STR_LEN, stdin);
    char *p = out.data();
    // A number ends at the first byte below '0' (whitespace, or the zero fill of `out`).
    while (p[len_a] >= '0')
    {
        len_a++;
    }
    if (len_a == 1 && p[0] == '0')
    {
        puts("0");
        return 0;
    }
    char *b = p + len_a;
    // Skip the separator. isdigit requires a value representable as unsigned char.
    while (!isdigit(static_cast<unsigned char>(*b)))
    {
        b++;
    }
    while (b[len_b] >= '0')
    {
        len_b++;
    }
    if (len_b == 1 && b[0] == '0')
    {
        puts("0");
        return 0;
    }
    size_t len1 = char_to_real(p, fft_ary, len_a);
    size_t len2 = char_to_imag(b, fft_ary, len_b);
    size_t fft_len = int_ceil2(len1 + len2 - 1);
    // The general path of the radix-2 kernels needs at least 4 points. Zero-padding a short
    // product to 4 leaves the linear convolution unchanged.
    if (fft_len < 4)
    {
        fft_len = 4;
    }
    small_fft.dif_avx(fft_ary, fft_len);
    // Pointwise square. The 0.5 undoes the factor 2 in Im((A + iB)^2) = 2AB, and 1/fft_len is
    // the inverse-DFT scale. The minus sign compensates for conj(), so the forward-twiddle
    // dit_avx below acts as the inverse transform.
    HintFloat inv = -0.5 / fft_len;
    const Complex2 invx4(inv);
    for (size_t i = 0; i < fft_len; i += 2)
    {
        Complex2 tmp = fft_ary + i;
        tmp = tmp.square().linear_mul(invx4);
        (tmp.conj()).store(fft_ary + i);
    }
    small_fft.dit_avx(fft_ary, fft_len);
    // Round the coefficients and propagate carries in base BASE. Each group is written as 4
    // characters, least significant group at the end of `out`.
    UINT_64 carry = 0;
    size_t pos = STR_LEN / 4 - 1;
    for (size_t i = 0; i < len1 + len2 - 1; i++)
    {
        carry += UINT_64(fft_ary[i].imag() + 0.5);
        UINT_64 num = 0;
        std::tie(carry, num) = div_mod<UINT_64>(carry, BASE);
        ary[pos] = transfer.tostr(num);
        pos--;
    }
    ary[pos] = transfer.tostr(carry);
    pos *= 4; // group index -> byte index
    // Strip leading zeros, but never run past the buffer (always keep at least one character).
    while (pos + 1 < STR_LEN && out[pos] == '0')
    {
        pos++;
    }
    fwrite(out.data() + pos, 1, STR_LEN - pos, stdout);
}
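// Online-judge result for this submission (pasted from the judge page; not part of the program):
// Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// Testcase #1 | 23.423 ms | 15 MB + 888 KB | Accepted | Score: 100 | Show more |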