// 给个STAR吧
// 再不济CV代码把这些链接留下吧,球球了
// https://github.com/With-Sky/HintFFT
// https://github.com/With-Sky/FFT-Benchmark
// https://github.com/With-Sky/HyperInt-mini
// https://space.bilibili.com/511540153
#include <tuple>
#include <iostream>
#include <complex>
#include <cstring>
#include <immintrin.h>
#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP
#pragma GCC target("fma")
#pragma GCC target("avx2")
namespace hint_simd
{
    // 256-bit AVX helpers.
    using HintFloat = double;
    using Complex = std::complex<HintFloat>;

    /// Two double-precision complex numbers packed into one 256-bit register.
    /// Lane layout (low to high): {re0, im0, re1, im1}.
    /// Pointer constructors and store() use aligned loads/stores, so the
    /// addresses passed to them must be 32-byte aligned.
    struct Complex2
    {
        __m256d data;

        /// Both complex numbers are zero.
        Complex2() : data(_mm256_setzero_pd()) {}
        /// Broadcasts `input` into all four lanes.
        Complex2(double input) : data(_mm256_set1_pd(input)) {}
        /// Wraps an existing register.
        Complex2(__m256d input) : data(input) {}
        Complex2(const Complex2 &) = default;
        Complex2 &operator=(const Complex2 &) = default;
        /// Packs two separate complex values: `lo` in lanes 0-1, `hi` in lanes 2-3.
        /// This overload is what the TABLE_ENABLE == 0 code paths construct.
        Complex2(const Complex &lo, const Complex &hi)
            : data(_mm256_set_pd(hi.imag(), hi.real(), lo.imag(), lo.real())) {}
        /// Loads four consecutive doubles (32-byte aligned).
        Complex2(double const *ptr) : data(_mm256_load_pd(ptr)) {}
        /// Loads two consecutive complex numbers (32-byte aligned).
        Complex2(const Complex *ptr)
            : data(_mm256_load_pd(reinterpret_cast<const double *>(ptr))) {}

        /// Resets both complex numbers to zero.
        void clr()
        {
            data = _mm256_setzero_pd();
        }
        /// Stores both complex numbers to `a` (32-byte aligned).
        void store(Complex *a) const
        {
            _mm256_store_pd(reinterpret_cast<double *>(a), data);
        }
        /// Prints both complex numbers to stdout.
        void print() const
        {
            double ary[4];
            _mm256_storeu_pd(ary, data);
            printf("(%lf,%lf) (%lf,%lf)\n", ary[0], ary[1], ary[2], ary[3]);
        }
        /// Flips the sign of every lane whose bit is set in M (bit i -> lane i).
        template <int M>
        Complex2 element_mask_neg() const
        {
            // Each selected lane gets only its sign bit (bit 63) set.
            const __m256d neg_mask = _mm256_castsi256_pd(
                _mm256_set_epi64x((M & 8ull) << 60, (M & 4ull) << 61, (M & 2ull) << 62, (M & 1ull) << 63));
            return _mm256_xor_pd(data, neg_mask);
        }
        /// In-lane permutation, see _mm256_permute_pd.
        template <int M>
        Complex2 element_permute() const
        {
            return _mm256_permute_pd(data, M);
        }
        /// Cross-lane permutation, see _mm256_permute4x64_pd.
        template <int M>
        Complex2 element_permute64() const
        {
            return _mm256_permute4x64_pd(data, M);
        }
        /// {re0, re0, re1, re1}
        Complex2 all_real() const
        {
            return _mm256_unpacklo_pd(data, data);
        }
        /// {im0, im0, im1, im1}
        Complex2 all_imag() const
        {
            return _mm256_unpackhi_pd(data, data);
        }
        /// Swaps real and imaginary parts: {im0, re0, im1, re1}.
        Complex2 swap() const
        {
            return _mm256_shuffle_pd(data, data, 5);
        }
        /// Multiplies both numbers by -i: (re, im) -> (im, -re).
        Complex2 mul_neg_i() const
        {
            // addsub(0, x) yields {-re, +im}; swapping gives {im, -re}.
            const __m256d zero = _mm256_setzero_pd();
            return Complex2(_mm256_addsub_pd(zero, data)).swap();
        }
        /// Complex conjugate of both numbers (negates lanes 1 and 3).
        Complex2 conj() const
        {
            return element_mask_neg<10>();
        }
        /// Lane-wise (not complex) product.
        Complex2 linear_mul(Complex2 input) const
        {
            return _mm256_mul_pd(data, input.data);
        }
        /// Complex square of both numbers: (re+im)(re-im) + i*2*re*im.
        Complex2 square() const
        {
            const __m256d rr = all_real().data;
            const __m256d ir = swap().data;
            const __m256d add = _mm256_add_pd(rr, ir); // {re+im, 2re, ...}
            const __m256d sub = _mm256_sub_pd(rr, ir); // {re-im, 0, ...}
            // Odd lanes take `im` from the original data so that 2re * im is formed.
            return _mm256_mul_pd(add, _mm256_blend_pd(sub, data, 0b1010));
        }
        Complex2 operator+(const Complex2 &input) const
        {
            return _mm256_add_pd(data, input.data);
        }
        Complex2 operator-(const Complex2 &input) const
        {
            return _mm256_sub_pd(data, input.data);
        }
        /// Complex product of the two pairs.
        Complex2 operator*(const Complex2 &input) const
        {
            const __m256d a_rr = all_real().data;
            const __m256d a_ii = all_imag().data;
            const __m256d b_ir = input.swap().data;
            // {a.re*b.re - a.im*b.im, a.re*b.im + a.im*b.re}
            return _mm256_addsub_pd(_mm256_mul_pd(a_rr, input.data), _mm256_mul_pd(a_ii, b_ir));
        }
    };
}
#endif
#include <array>
#include <vector>
#include <iostream>
#include <future>
#include <ctime>
#include <cstring>
// #include "stopwatch.hpp"
#define TABLE_ENABLE 1
#define MULTITHREAD 0 // 多线程 0 means no, 1 means yes
#define TABLE_PRELOAD 0 // 是否提前初始化表 0 means no, 1 means yes
using namespace hint_simd;
namespace hint
{
// Fixed-width integer aliases used throughout this file.
using UINT_8 = uint8_t;
using UINT_16 = uint16_t;
using UINT_32 = uint32_t;
using UINT_64 = uint64_t;
using INT_32 = int32_t;
using INT_64 = int64_t;
using ULONG = unsigned long;
using LONG = long;
// Floating-point type used by the transforms and its complex counterpart.
using HintFloat = double;
using Complex = std::complex<HintFloat>;
// Bit widths of 8/16/32-bit integers.
constexpr UINT_64 HINT_CHAR_BIT = 8;
constexpr UINT_64 HINT_SHORT_BIT = 16;
constexpr UINT_64 HINT_INT_BIT = 32;
// Limits of the fixed-width integer types, all widened to 64 bits.
// *_0XFF is the all-ones value, *_0X10 is one past it (2^bits),
// *_0X80 is the sign bit and *_0X7F the largest signed value.
constexpr UINT_64 HINT_INT8_0XFF = UINT8_MAX;
constexpr UINT_64 HINT_INT8_0X10 = (UINT8_MAX + 1ull);
constexpr UINT_64 HINT_INT16_0XFF = UINT16_MAX;
constexpr UINT_64 HINT_INT16_0X10 = (UINT16_MAX + 1ull);
constexpr UINT_64 HINT_INT32_0XFF = UINT32_MAX;
constexpr UINT_64 HINT_INT32_0X01 = 1;
constexpr UINT_64 HINT_INT32_0X80 = 0X80000000ull;
constexpr UINT_64 HINT_INT32_0X7F = INT32_MAX;
constexpr UINT_64 HINT_INT32_0X10 = (UINT32_MAX + 1ull);
constexpr UINT_64 HINT_INT64_0X80 = INT64_MIN;
constexpr UINT_64 HINT_INT64_0X7F = INT64_MAX;
constexpr UINT_64 HINT_INT64_0XFF = UINT64_MAX;
// pi, 2*pi and sqrt(2)/2.
constexpr HintFloat HINT_PI = 3.1415926535897932384626433832795;
constexpr HintFloat HINT_2PI = HINT_PI * 2;
constexpr HintFloat HINT_HSQ_ROOT2 = 0.70710678118654752440084436210485;
// NTT parameters (3 * 2^30 + 1 and 13 * 2^28 + 1); not used by the code visible in this file.
// NOTE(review): the ROOT values are presumably primitive roots of the moduli - confirm before use.
constexpr UINT_64 NTT_MOD1 = 3221225473;
constexpr UINT_64 NTT_ROOT1 = 5;
constexpr UINT_64 NTT_MOD2 = 3489660929;
constexpr UINT_64 NTT_ROOT2 = 3;
constexpr size_t NTT_MAX_LEN = 1ull << 28;
/// Fixed-size array of LEN elements of T whose storage is aligned to 128 bytes,
/// so it can be handed to aligned SIMD loads/stores.
/// The default constructor does not initialise the elements; objects with
/// static storage duration are zero-initialised by the language.
template <typename T, size_t LEN>
class AlignAry
{
private:
    alignas(128) T ary[LEN];

public:
    constexpr AlignAry() {}
    /// Unchecked element access.
    constexpr T &operator[](size_t index)
    {
        return ary[index];
    }
    constexpr const T &operator[](size_t index) const
    {
        return ary[index];
    }
    /// Pointer to the first element.
    T *get_ptr()
    {
        return reinterpret_cast<T *>(ary);
    }
    const T *get_ptr() const
    {
        return reinterpret_cast<const T *>(ary);
    }
    /// Pointer to the storage reinterpreted as an array of Ty.
    template <typename Ty>
    Ty *cast_ptr()
    {
        return reinterpret_cast<Ty *>(ary);
    }
    /// Const counterpart of cast_ptr(); previously only reachable as get_ptr<Ty>().
    template <typename Ty>
    const Ty *cast_ptr() const
    {
        return reinterpret_cast<const Ty *>(ary);
    }
    /// Legacy spelling of the const cast_ptr<Ty>(), kept for existing callers.
    template <typename Ty>
    const Ty *get_ptr() const
    {
        return reinterpret_cast<const Ty *>(ary);
    }
};
/// @brief Largest power of two that does not exceed n.
/// @param n Value to round down.
/// @return The largest power of two <= n, or 0 when n < 1.
template <typename T>
constexpr T max_2pow(T n)
{
    if (n < 1)
    {
        return 0;
    }
    // Grow from 1 instead of shrinking from the top bit: the top bit of a
    // signed T is the sign bit, which made the old loop return a negative value.
    // Comparing against n / 2 keeps the doubling from ever overflowing.
    T res = 1;
    while (res <= n / 2)
    {
        res *= 2;
    }
    return res;
}
/// Rounds n down to a power of two by smearing its highest set bit into all
/// lower positions and keeping only the top one. Returns 1 for n == 0.
template <typename T>
constexpr T int_floor2(T n)
{
    constexpr int width = sizeof(T) * 8;
    int shift = 1;
    while (shift < width)
    {
        n |= (n >> shift);
        shift += shift;
    }
    // n is now of the form 0b0..01..1; halving and adding one isolates the top bit.
    return (n >> 1) + 1;
}
/// Rounds n up to a power of two: n - 1 has its highest set bit smeared into
/// every lower position, then one is added back.
template <typename T>
constexpr T int_ceil2(T n)
{
    constexpr int width = sizeof(T) * 8;
    --n;
    int shift = 1;
    while (shift < width)
    {
        n |= (n >> shift);
        shift += shift;
    }
    return n + 1;
}
/// Integer base-2 logarithm rounded down: counts how many times n can be
/// halved before it drops to 1 or below. Returns 0 for n <= 1.
template <typename T>
constexpr T hint_log2(T n)
{
    T exponent = 0;
    for (; n > 1; n /= 2)
    {
        exponent++;
    }
    return exponent;
}
// Generic exponentiation by squaring: returns m raised to the power n.
// T only needs to be constructible from 1 and support operator*.
template <typename T>
constexpr T qpow(T m, UINT_64 n)
{
    T acc = 1;
    for (; n != 0; n >>= 1)
    {
        if (n & 1)
        {
            acc = acc * m;
        }
        m = m * m;
    }
    return acc;
}
/// Returns {a / b, a % b} as a pair (quotient first, remainder second).
template <typename T>
constexpr std::pair<T, T> div_mod(T a, T b)
{
    const T quotient = a / b;
    const T remainder = a % b;
    return std::pair<T, T>(quotient, remainder);
}
// Globals for the multithreaded build; compiled only when MULTITHREAD is defined as 1.
#if MULTITHREAD == 1
// Number of hardware threads reported by the standard library.
const UINT_32 hint_threads = std::thread::hardware_concurrency();
// log2 of the thread count; hint_log2 already returns an integer, so std::ceil rounds nothing here.
const UINT_32 log2_threads = std::ceil(hint_log2(hint_threads));
// Atomic counter; not referenced anywhere else in the visible source.
std::atomic<UINT_32> cur_ths;
#endif
// FFT与类FFT变换的命名空间
namespace hint_transform
{
// In-place bit-reversal permutation of ary[0 .. len).
// `ary` is anything indexable with operator[] (raw pointer, vector, ...).
// Intended for power-of-two `len`; lengths below 2 are left untouched.
template <typename T>
void binary_reverse_swap(T &ary, size_t len)
{
    // Guard first: with len == 0 the bound `len - 1` below would wrap around
    // to SIZE_MAX and the loop would spin over the whole index range.
    if (len < 2)
    {
        return;
    }
    size_t i = 0; // bit-reversed counterpart of j, updated incrementally
    for (size_t j = 1; j < len - 1; j++)
    {
        // Add one to i in bit-reversed order: flip bits from the top down
        // until a bit goes from 0 to 1.
        size_t k = len >> 1;
        i ^= k;
        while (k > i)
        {
            k >>= 1;
            i ^= k;
        }
        // Swap each pair only once.
        if (j < i)
        {
            std::swap(ary[i], ary[j]);
        }
    }
}
namespace hint_fft
{
// Point on the unit circle whose argument is theta (radians).
static Complex unit_root(HintFloat theta)
{
    const HintFloat radius = 1.0;
    return std::polar<HintFloat>(radius, theta);
}
// n-th of the m points that divide the unit circle evenly, i.e. argument 2*pi*n/m.
static Complex unit_root(size_t m, size_t n)
{
    const HintFloat theta = (HINT_2PI * n) / m;
    return unit_root(theta);
}
/// Lazily filled lookup table of FFT twiddle factors for transform lengths up
/// to 2^MAX_SHIFT.
///
/// With w = exp(-2*pi*i / 2^shift), the block for a given `shift` occupies
/// indices [2^shift / 4, 2^shift / 2) of each table:
///   table1[2^shift / 4 + n] = w^n       for 0 <= n < 2^shift / 4
///   table3[2^shift / 4 + n] = w^(3n)    for 0 <= n < 2^shift / 4
/// Blocks that have not been expanded yet keep whatever the storage was
/// initialised with (zeros for an object with static storage duration).
template <UINT_32 MAX_SHIFT>
class ComplexTableC
{
private:
    enum
    {
        TABLE_LEN = (size_t(1) << MAX_SHIFT) / 2 // number of complex entries per table
    };
    AlignAry<HintFloat, TABLE_LEN * 2> table1_avx; // aligned storage behind table1
    AlignAry<HintFloat, TABLE_LEN * 2> table3_avx; // aligned storage behind table3
    Complex *table1 = nullptr, *table3 = nullptr;
    INT_32 max_log_size = 2; // largest shift the storage can hold
    INT_32 cur_log_size = 2; // largest shift filled so far
    static constexpr size_t FAC = 1;
    ComplexTableC(const ComplexTableC &) = delete;
    ComplexTableC &operator=(const ComplexTableC &) = delete;

public:
    // Sets up the pointers into the aligned storage and the first two entries;
    // everything else is filled on demand by expand().
    constexpr ComplexTableC()
    {
        table1 = table1_avx.template cast_ptr<Complex>();
        table3 = table3_avx.template cast_ptr<Complex>();
        max_log_size = std::max<size_t>(MAX_SHIFT, 1);
        table1[0] = table1[1] = table3[0] = table3[1] = Complex(1);
        // expand(max_log_size);
    }
    // Mutable access to w^n of the block for `shift`.
    constexpr Complex &table1_access(int shift, size_t n)
    {
        return table1[(1 << shift) / 4 + n];
    }
    // Mutable access to w^(3n) of the block for `shift`.
    constexpr Complex &table3_access(int shift, size_t n)
    {
        return table3[(1 << shift) / 4 + n];
    }
    // Makes sure all blocks up to `shift` are filled.
    constexpr void expand(INT_32 shift)
    {
        expand_topdown(shift);
        // expand_bottomup(shift);
    }
    // Fills the blocks from cur_log_size + 1 up to `shift`, deriving the even
    // entries of each block from the previous (smaller) block.
    // Throws a string literal when `shift` exceeds the table capacity.
    constexpr void expand_bottomup(INT_32 shift)
    {
        shift = std::max<INT_32>(shift, 2);
        if (shift > max_log_size)
        {
            throw("FFT length too long for lut\n");
        }
        for (INT_32 i = cur_log_size + 1; i <= shift; i++)
        {
            size_t len = 1ull << i, vec_size = len * FAC / 4;
            table1_access(i, 0) = table3_access(i, 0) = Complex(1, 0);
            for (size_t pos = 0; pos < vec_size / 2; pos++)
            {
                // Even entries are copies of the previous block.
                table1_access(i, pos * 2) = table1_access(i - 1, pos);
                if (pos % 2 == 1)
                {
                    // Odd entries are computed directly and mirrored about the
                    // middle of the block.
                    Complex tmp = unit_root(-HINT_2PI * pos / len);
                    table1_access(i, pos) = tmp;
                    table1_access(i, vec_size - pos) = -Complex(tmp.imag(), tmp.real());
                }
            }
            table1_access(i, vec_size / 2) = std::conj(unit_root(8, 1));
            for (size_t pos = 0; pos < vec_size / 2; pos++)
            {
                table3_access(i, pos * 2) = table3_access(i - 1, pos);
                if (pos % 2 == 1)
                {
                    Complex tmp = get_omega(i, pos * 3);
                    table3_access(i, pos) = tmp;
                    table3_access(i, vec_size - pos) = Complex(tmp.imag(), tmp.real());
                }
            }
            table3_access(i, vec_size / 2) = std::conj(unit_root(8, 3));
        }
        cur_log_size = std::max(cur_log_size, shift);
    }
    // Fills the block for `shift` directly, then derives every smaller block
    // that is still missing by taking every second entry of the block above it.
    // A `shift` beyond the capacity is silently clamped to max_log_size.
    constexpr void expand_topdown(INT_32 shift)
    {
        shift = std::min(shift, max_log_size);
        if (shift <= cur_log_size)
        {
            return;
        }
        size_t len = 1ull << shift, vec_size = len * FAC / 4;
        table1_access(shift, 0) = table3_access(shift, 0) = Complex(1, 0);
        const HintFloat inv = -HINT_2PI / len;
        // Power-of-two positions are evaluated with the trigonometric functions ...
        for (size_t pos = 1; pos < vec_size / 2; pos *= 2)
        {
            table1_access(shift, pos) = unit_root(inv * pos);
        }
        // ... and every other position is the product of two earlier entries
        // (pos with its lowest set bit cleared, times that lowest bit).
        for (size_t pos = 1; pos < vec_size / 2; pos++)
        {
            size_t sub_pos = pos & (pos - 1);
            table1_access(shift, pos) = table1_access(shift, sub_pos) * table1_access(shift, pos - sub_pos);
        }
        // Second half of the block by mirroring about its middle.
        for (size_t pos = 1; pos < vec_size / 2; pos++)
        {
            Complex tmp = table1_access(shift, pos);
            table1_access(shift, vec_size - pos) = -Complex(tmp.imag(), tmp.real());
        }
        for (size_t pos = 1; pos < vec_size / 2; pos++)
        {
            Complex tmp = get_omega(shift, pos * 3);
            table3_access(shift, pos) = tmp;
            table3_access(shift, vec_size - pos) = Complex(tmp.imag(), tmp.real());
        }
        table1_access(shift, vec_size / 2) = std::conj(unit_root(8, 1));
        table3_access(shift, vec_size / 2) = std::conj(unit_root(8, 3));
        for (INT_32 log = shift - 1; log > cur_log_size; log--)
        {
            len = 1ull << log, vec_size = len / 4;
            for (size_t pos = 0; pos < vec_size; pos++)
            {
                table1_access(log, pos) = table1_access(log + 1, pos * 2);
            }
            for (size_t pos = 1; pos < vec_size / 2; pos++)
            {
                Complex tmp = get_omega(log, pos * 3);
                table3_access(log, pos) = tmp;
                table3_access(log, vec_size - pos) = Complex(tmp.imag(), tmp.real());
            }
            table3_access(log, 0) = Complex(1, 0);
            table3_access(log, vec_size / 2) = std::conj(unit_root(8, 3));
        }
        cur_log_size = std::max(cur_log_size, shift);
    }
    // Returns w^n for the circle divided into 1 << shift parts.
    // Positions past the stored quarter are derived by mirroring a stored
    // entry; the quarter point itself is -i.
    constexpr Complex get_omega(UINT_32 shift, size_t n) const
    {
        size_t vec_size = (size_t(1) << shift) / 4;
        if (n < vec_size)
        {
            return table1[vec_size + n];
        }
        else if (n > vec_size)
        {
            Complex tmp = table1[vec_size + vec_size * 2 - n];
            return Complex(-tmp.real(), tmp.imag());
        }
        else
        {
            return Complex(0, -1);
        }
    }
    // Returns the stored w^(3n) for the circle divided into 1 << shift parts.
    // Indexes table3 directly: table3_access() is a non-const member and
    // cannot be called from this const function.
    Complex get_omega3(UINT_32 shift, size_t n) const
    {
        return table3[(1 << shift) / 4 + n];
    }
    // Loads w^n and w^(n+1) as one vector; the address must be 32-byte aligned.
    Complex2 get_omegaX2(UINT_32 shift, size_t n) const
    {
        return Complex2(table1 + (1 << (shift - 2)) + n);
    }
    // Loads w^(3n) and w^(3(n+1)) as one vector; the address must be 32-byte aligned.
    Complex2 get_omega3X2(UINT_32 shift, size_t n) const
    {
        return Complex2(table3 + (1 << (shift - 2)) + n);
    }
    // Pointer to the stored w^n of the block for `shift`.
    const Complex *get_omega_ptr(UINT_32 shift, size_t n) const
    {
        return table1 + (1 << (shift - 2)) + n;
    }
    // Pointer to the stored w^(3n) of the block for `shift`.
    const Complex *get_omega3_ptr(UINT_32 shift, size_t n) const
    {
        return table3 + (1 << (shift - 2)) + n;
    }
};
// log2 of the largest transform length served by the twiddle table (2^13 points).
constexpr size_t lut_max_rank = 13;
// static ComplexTableY TABLE(lut_max_rank);
// Shared twiddle table; its blocks are filled on demand through TABLE.expand().
static ComplexTableC<lut_max_rank> TABLE;
// 2-point butterfly: replaces (sum, diff) with (sum + diff, sum - diff).
template <typename T>
inline void fft_2point(T &sum, T &diff)
{
    // Snapshot both inputs before either one is overwritten.
    const T first = sum;
    const T second = diff;
    sum = first + second;
    diff = first - second;
}
// Scalar 4-point decimation-in-time butterfly on the elements
// input[0], input[rank], input[2 * rank] and input[3 * rank].
inline void fft_dit_4point(Complex *input, size_t rank = 1)
{
    Complex *const slot0 = input;
    Complex *const slot1 = input + rank;
    Complex *const slot2 = input + rank * 2;
    Complex *const slot3 = input + rank * 3;
    // Read everything before the first write.
    const Complex a = *slot0;
    const Complex b = *slot1;
    const Complex c = *slot2;
    const Complex d = *slot3;
    const Complex ab_sum = a + b;
    const Complex ab_diff = a - b;
    const Complex cd_sum = c + d;
    const Complex cd_diff = c - d;
    // Multiply (c - d) by -i: (re, im) -> (im, -re).
    const Complex cd_rot(cd_diff.imag(), -cd_diff.real());
    *slot0 = ab_sum + cd_sum;
    *slot1 = ab_diff + cd_rot;
    *slot2 = ab_sum - cd_sum;
    *slot3 = ab_diff - cd_rot;
}
// Scalar 4-point decimation-in-frequency butterfly on the elements
// input[0], input[rank], input[2 * rank] and input[3 * rank].
inline void fft_dif_4point(Complex *input, size_t rank = 1)
{
    Complex *const slot0 = input;
    Complex *const slot1 = input + rank;
    Complex *const slot2 = input + rank * 2;
    Complex *const slot3 = input + rank * 3;
    // Read everything before the first write.
    const Complex a = *slot0;
    const Complex b = *slot1;
    const Complex c = *slot2;
    const Complex d = *slot3;
    const Complex ac_sum = a + c;
    const Complex ac_diff = a - c;
    const Complex bd_sum = b + d;
    const Complex bd_diff = b - d;
    // Multiply (b - d) by -i: (re, im) -> (im, -re).
    const Complex bd_rot(bd_diff.imag(), -bd_diff.real());
    *slot0 = ac_sum + bd_sum;
    *slot1 = ac_sum - bd_sum;
    *slot2 = ac_diff + bd_rot;
    *slot3 = ac_diff - bd_rot;
}
// AVX 4-point decimation-in-time butterfly on input[0..3], in place.
// Computes the same values as fft_dit_4point(input, 1) with two complex
// numbers per register. Aligned loads/stores are used, so `input` must be
// 32-byte aligned. In the comments c0..c3 are the four inputs.
inline void fft_dit_4point_avx(Complex *input)
{
// Sign-bit mask for the highest double only (imaginary lane of the upper complex).
static const __m256d neg_mask = _mm256_castsi256_pd(
_mm256_set_epi64x(INT64_MIN, 0, 0, 0));
__m256d tmp0 = _mm256_load_pd(reinterpret_cast<double *>(input)); // c0,c1
__m256d tmp1 = _mm256_load_pd(reinterpret_cast<double *>(input + 2)); // c2,c3
__m256d tmp2 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20); // c0,c2
__m256d tmp3 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31); // c1,c3
tmp0 = _mm256_add_pd(tmp2, tmp3); // c0+c1,c2+c3
tmp1 = _mm256_sub_pd(tmp2, tmp3); // c0-c1,c2-c3
tmp2 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20); // c0+c1,c0-c1;(A,B)
tmp3 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31); // c2+c3,c2-c3
// Swap re/im of the upper complex only, then negate its new imaginary part:
// together this multiplies (c2-c3) by -i and leaves (c2+c3) untouched.
tmp3 = _mm256_permute_pd(tmp3, 0b0110);
tmp3 = _mm256_xor_pd(tmp3, neg_mask); // (C,D)
tmp0 = _mm256_add_pd(tmp2, tmp3); // A+C,B+D
tmp1 = _mm256_sub_pd(tmp2, tmp3); // A-C,B-D
_mm256_store_pd(reinterpret_cast<double *>(input), tmp0);
_mm256_store_pd(reinterpret_cast<double *>(input + 2), tmp1);
}
// AVX 8-point decimation-in-time transform on input[0..7], in place.
// Two 4-point DIT butterflies (on input[0..3] and input[4..7]) are computed
// side by side, the second result is multiplied by the 8-point twiddle
// factors, and a final add/sub stage combines both halves.
// Aligned loads/stores are used, so `input` must be 32-byte aligned.
inline void fft_dit_8point_avx(Complex *input)
{
// Sign-bit mask for the highest double only (imaginary lane of the upper complex).
static const __m256d neg_mask = _mm256_castsi256_pd(_mm256_set_epi64x(INT64_MIN, 0, 0, 0));
// Lane values low to high: {0, 0, h, h} and {1, -1, -h, -h} with h = sqrt(2)/2.
static const __m256d mul1 = _mm256_set_pd(0.70710678118654752440084436210485, 0.70710678118654752440084436210485, 0, 0);
static const __m256d mul2 = _mm256_set_pd(-0.70710678118654752440084436210485, -0.70710678118654752440084436210485, -1, 1);
__m256d tmp0 = _mm256_load_pd(reinterpret_cast<double *>(input)); // c0,c1
__m256d tmp1 = _mm256_load_pd(reinterpret_cast<double *>(input + 2)); // c2,c3
__m256d tmp2 = _mm256_load_pd(reinterpret_cast<double *>(input + 4)); // c4,c5
__m256d tmp3 = _mm256_load_pd(reinterpret_cast<double *>(input + 6)); // c6,c7
__m256d tmp4 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20); // c0,c2
__m256d tmp5 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31); // c1,c3
__m256d tmp6 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20); // c4,c6
__m256d tmp7 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31); // c5,c7
tmp0 = _mm256_add_pd(tmp4, tmp5); // c0+c1,c2+c3
tmp1 = _mm256_sub_pd(tmp4, tmp5); // c0-c1,c2-c3
tmp2 = _mm256_add_pd(tmp6, tmp7); // c4+c5,c6+c7
tmp3 = _mm256_sub_pd(tmp6, tmp7); // c4-c5,c6-c7
tmp4 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20); // c0+c1,c0-c1;(A,B)
tmp5 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31); // c2+c3,c2-c3
tmp6 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20); // c4+c5,c4-c5;(A,B)
tmp7 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31); // c6+c7,c6-c7
// Multiply the upper complex of each register by -i (swap re/im, negate new im).
tmp5 = _mm256_permute_pd(tmp5, 0b0110);
tmp5 = _mm256_xor_pd(tmp5, neg_mask); // (C,D)
tmp7 = _mm256_permute_pd(tmp7, 0b0110);
tmp7 = _mm256_xor_pd(tmp7, neg_mask); // (C,D)
tmp0 = _mm256_add_pd(tmp4, tmp5); // A+C,B+D
tmp1 = _mm256_sub_pd(tmp4, tmp5); // A-C,B-D
tmp2 = _mm256_add_pd(tmp6, tmp7); // A+C,B+D
tmp3 = _mm256_sub_pd(tmp6, tmp7); // A-C,B-D
// Both 4-point butterflies are done: tmp0,tmp1 hold the first, tmp2,tmp3 the second.
// Twiddle for tmp2: lower complex unchanged, upper complex times (1-i)*sqrt(2)/2,
// built as {h*(re+im), h*(im-re)}.
tmp6 = _mm256_permute_pd(tmp2, 0b0110);
tmp6 = _mm256_addsub_pd(tmp6, tmp2);
tmp6 = _mm256_permute_pd(tmp6, 0b0110);
tmp6 = _mm256_mul_pd(tmp6, mul1);
tmp2 = _mm256_blend_pd(tmp2, tmp6, 0b1100);
// Twiddle for tmp3: lower complex times -i, upper complex times (-1-i)*sqrt(2)/2.
tmp7 = _mm256_permute_pd(tmp3, 0b0101);
tmp3 = _mm256_addsub_pd(tmp3, tmp7);
tmp3 = _mm256_blend_pd(tmp7, tmp3, 0b1100);
tmp3 = _mm256_mul_pd(tmp3, mul2);
// Final radix-2 stage combining the two halves.
tmp4 = _mm256_add_pd(tmp0, tmp2);
tmp5 = _mm256_add_pd(tmp1, tmp3);
tmp6 = _mm256_sub_pd(tmp0, tmp2);
tmp7 = _mm256_sub_pd(tmp1, tmp3);
_mm256_store_pd(reinterpret_cast<double *>(input), tmp4);
_mm256_store_pd(reinterpret_cast<double *>(input + 2), tmp5);
_mm256_store_pd(reinterpret_cast<double *>(input + 4), tmp6);
_mm256_store_pd(reinterpret_cast<double *>(input + 6), tmp7);
}
// AVX 4-point decimation-in-frequency butterfly on input[0..3], in place.
// Computes the same values as fft_dif_4point(input, 1) with two complex
// numbers per register. Aligned loads/stores are used, so `input` must be
// 32-byte aligned. In the comments c0..c3 are the four inputs.
inline void fft_dif_4point_avx(Complex *input)
{
__m256d tmp0 = _mm256_load_pd(reinterpret_cast<double *>(input)); // c0,c1
__m256d tmp1 = _mm256_load_pd(reinterpret_cast<double *>(input + 2)); // c2,c3
__m256d tmp2 = _mm256_add_pd(tmp0, tmp1); // c0+c2,c1+c3;
__m256d tmp3 = _mm256_sub_pd(tmp0, tmp1); // c0-c2,c1-c3;
// Swap re/im of the upper complex; the XOR below negates its new imaginary
// part, which multiplies (c1-c3) by -i.
tmp3 = _mm256_permute_pd(tmp3, 0b0110); // c0-c2,r(c1-c3);
// Sign-bit mask for the highest double only.
static const __m256d neg_mask = _mm256_castsi256_pd(
_mm256_set_epi64x(INT64_MIN, 0, 0, 0));
tmp3 = _mm256_xor_pd(tmp3, neg_mask);
tmp0 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20); // A,C
tmp1 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31); // B,D
tmp2 = _mm256_add_pd(tmp0, tmp1); // A+B,C+D
tmp3 = _mm256_sub_pd(tmp0, tmp1); // A-B,C-D
// Regroup into output order: (A+B, A-B) and (C+D, C-D).
tmp0 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20);
tmp1 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31);
_mm256_store_pd(reinterpret_cast<double *>(input), tmp0);
_mm256_store_pd(reinterpret_cast<double *>(input + 2), tmp1);
}
// AVX 8-point decimation-in-frequency transform on input[0..7], in place.
// A radix-2 stage splits the data into sums and differences, the differences
// are multiplied by the 8-point twiddle factors, and two 4-point DIF
// butterflies are then computed side by side.
// Aligned loads/stores are used, so `input` must be 32-byte aligned.
inline void fft_dif_8point_avx(Complex *input)
{
// Sign-bit mask for the highest double only (imaginary lane of the upper complex).
static const __m256d neg_mask = _mm256_castsi256_pd(_mm256_set_epi64x(INT64_MIN, 0, 0, 0));
// Lane values low to high: {0, 0, h, h} and {1, -1, -h, -h} with h = sqrt(2)/2.
static const __m256d mul1 = _mm256_set_pd(0.70710678118654752440084436210485, 0.70710678118654752440084436210485, 0, 0);
static const __m256d mul2 = _mm256_set_pd(-0.70710678118654752440084436210485, -0.70710678118654752440084436210485, -1, 1);
__m256d tmp0 = _mm256_load_pd(reinterpret_cast<double *>(input)); // c0,c1
__m256d tmp1 = _mm256_load_pd(reinterpret_cast<double *>(input + 2)); // c2,c3
__m256d tmp2 = _mm256_load_pd(reinterpret_cast<double *>(input + 4)); // c4,c5
__m256d tmp3 = _mm256_load_pd(reinterpret_cast<double *>(input + 6)); // c6,c7
// First radix-2 stage: sums feed the first 4-point block, differences the second.
__m256d tmp4 = _mm256_add_pd(tmp0, tmp2);
__m256d tmp5 = _mm256_add_pd(tmp1, tmp3);
__m256d tmp6 = _mm256_sub_pd(tmp0, tmp2);
__m256d tmp7 = _mm256_sub_pd(tmp1, tmp3);
// Twiddle for tmp6: lower complex unchanged, upper complex times (1-i)*sqrt(2)/2.
tmp2 = _mm256_permute_pd(tmp6, 0b0110);
tmp2 = _mm256_addsub_pd(tmp2, tmp6);
tmp2 = _mm256_permute_pd(tmp2, 0b0110);
tmp2 = _mm256_mul_pd(tmp2, mul1);
tmp6 = _mm256_blend_pd(tmp6, tmp2, 0b1100);
// Twiddle for tmp7: lower complex times -i, upper complex times (-1-i)*sqrt(2)/2.
tmp3 = _mm256_permute_pd(tmp7, 0b0101);
tmp7 = _mm256_addsub_pd(tmp7, tmp3);
tmp7 = _mm256_blend_pd(tmp3, tmp7, 0b1100);
tmp7 = _mm256_mul_pd(tmp7, mul2);
// Two 4-point DIF butterflies: (tmp4,tmp5) and (tmp6,tmp7).
tmp0 = _mm256_add_pd(tmp4, tmp5);
tmp1 = _mm256_sub_pd(tmp4, tmp5);
// Multiply the upper complex of the difference by -i (swap re/im, negate new im).
tmp1 = _mm256_permute_pd(tmp1, 0b0110);
tmp1 = _mm256_xor_pd(tmp1, neg_mask);
tmp2 = _mm256_add_pd(tmp6, tmp7);
tmp3 = _mm256_sub_pd(tmp6, tmp7);
tmp3 = _mm256_permute_pd(tmp3, 0b0110);
tmp3 = _mm256_xor_pd(tmp3, neg_mask);
tmp4 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20);
tmp5 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31);
tmp6 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20);
tmp7 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31);
tmp0 = _mm256_add_pd(tmp4, tmp5);
tmp1 = _mm256_sub_pd(tmp4, tmp5);
tmp2 = _mm256_add_pd(tmp6, tmp7);
tmp3 = _mm256_sub_pd(tmp6, tmp7);
// Regroup the results into output order before storing.
tmp4 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20);
tmp5 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31);
tmp6 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20);
tmp7 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31);
_mm256_store_pd(reinterpret_cast<double *>(input), tmp4);
_mm256_store_pd(reinterpret_cast<double *>(input + 2), tmp5);
_mm256_store_pd(reinterpret_cast<double *>(input + 4), tmp6);
_mm256_store_pd(reinterpret_cast<double *>(input + 6), tmp7);
}
// fft分裂基时间抽取蝶形变换
inline void fft_split_radix_dit_butterfly(const Complex2 &omega, const Complex2 &omega_cube,
Complex *input, size_t rank)
{
Complex2 tmp0 = input;
Complex2 tmp1 = input + rank;
Complex2 tmp2 = Complex2(input + rank * 2) * omega;
Complex2 tmp3 = Complex2(input + rank * 3) * omega_cube;
fft_2point(tmp2, tmp3);
tmp3 = tmp3.mul_neg_i();
(tmp0 + tmp2).store(input);
(tmp1 + tmp3).store(input + rank);
(tmp0 - tmp2).store(input + rank * 2);
(tmp1 - tmp3).store(input + rank * 3);
}
// fft分裂基频率抽取蝶形变换
inline void fft_split_radix_dif_butterfly(const Complex2 &omega, const Complex2 &omega_cube,
Complex *input, size_t rank)
{
Complex2 tmp0 = (input);
Complex2 tmp1 = (input + rank);
Complex2 tmp2 = (input + rank * 2);
Complex2 tmp3 = (input + rank * 3);
fft_2point(tmp0, tmp2);
fft_2point(tmp1, tmp3);
tmp3 = tmp3.mul_neg_i();
tmp0.store(input);
tmp1.store(input + rank);
((tmp2 + tmp3) * omega).store(input + rank * 2);
((tmp2 - tmp3) * omega_cube).store(input + rank * 3);
}
// fft分裂基时间抽取蝶形变换
inline void fft_split_radix_dit_butterfly(const Complex *omega, const Complex *omega_cube,
Complex *input, size_t rank)
{
Complex2 tmp0 = input;
Complex2 tmp4 = input + 2;
Complex2 tmp1 = input + rank;
Complex2 tmp5 = input + rank + 2;
Complex2 tmp2 = Complex2(input + rank * 2) * Complex2(omega);
Complex2 tmp6 = Complex2(input + rank * 2 + 2) * Complex2(omega + 2);
Complex2 tmp3 = Complex2(input + rank * 3) * Complex2(omega_cube);
Complex2 tmp7 = Complex2(input + rank * 3 + 2) * Complex2(omega_cube + 2);
fft_2point(tmp2, tmp3);
fft_2point(tmp6, tmp7);
tmp3 = tmp3.mul_neg_i();
tmp7 = tmp7.mul_neg_i();
(tmp0 + tmp2).store(input);
(tmp4 + tmp6).store(input + 2);
(tmp1 + tmp3).store(input + rank);
(tmp5 + tmp7).store(input + rank + 2);
(tmp0 - tmp2).store(input + rank * 2);
(tmp4 - tmp6).store(input + rank * 2 + 2);
(tmp1 - tmp3).store(input + rank * 3);
(tmp5 - tmp7).store(input + rank * 3 + 2);
}
// fft分裂基频率抽取蝶形变换
inline void fft_split_radix_dif_butterfly(const Complex *omega, const Complex *omega_cube,
Complex *input, size_t rank)
{
Complex2 tmp0 = input;
Complex2 tmp4 = input + 2;
Complex2 tmp1 = input + rank;
Complex2 tmp5 = input + rank + 2;
Complex2 tmp2 = input + rank * 2;
Complex2 tmp6 = input + rank * 2 + 2;
Complex2 tmp3 = input + rank * 3;
Complex2 tmp7 = input + rank * 3 + 2;
fft_2point(tmp0, tmp2);
fft_2point(tmp1, tmp3);
fft_2point(tmp4, tmp6);
fft_2point(tmp5, tmp7);
tmp3 = tmp3.mul_neg_i();
tmp7 = tmp7.mul_neg_i();
tmp0.store(input);
tmp4.store(input + 2);
tmp1.store(input + rank);
tmp5.store(input + rank + 2);
((tmp2 + tmp3) * Complex2(omega)).store(input + rank * 2);
((tmp6 + tmp7) * Complex2(omega + 2)).store(input + rank * 2 + 2);
((tmp2 - tmp3) * Complex2(omega_cube)).store(input + rank * 3);
((tmp6 - tmp7) * Complex2(omega_cube + 2)).store(input + rank * 3 + 2);
}
// Hard-coded twiddle factors for the 16-point split-radix kernels below.
// cos(pi/4), cos(pi/8) and sin(pi/8):
static constexpr HintFloat cos_1_8 = 0.70710678118654752440084436210485;
static constexpr HintFloat cos_1_16 = 0.92387953251128675612818318939679;
static constexpr HintFloat sin_1_16 = 0.3826834323650897717284599840304;
// With w = exp(-2*pi*i/16): w1 = w^1, w3 = w^3, w9 = w^9.
static constexpr Complex w1(cos_1_16, -sin_1_16), w3(sin_1_16, -cos_1_16), w9(-cos_1_16, sin_1_16);
// omega1_table[n] = w^n and omega3_table[n] = w^(3n) for n = 0..3.
static constexpr Complex omega1_table[4] = {Complex(1), w1, Complex(cos_1_8, -cos_1_8), w3};
static constexpr Complex omega3_table[4] = {Complex(1), w3, Complex(-cos_1_8, -cos_1_8), w9};
// The same factors packed two per vector: entries 0-1 and entries 2-3.
static const Complex2 omega0(omega1_table), omega1(omega1_table + 2);
static const Complex2 omega_cu0(omega3_table), omega_cu1(omega3_table + 2);
// Recursive split-radix decimation-in-time transform of LEN points, in place.
// The two quarter-length sub-transforms and the half-length sub-transform run
// first; the butterflies then merge them. No bit-reversal is performed here.
// Lengths 0, 1, 2, 4, 8 and 16 are handled by explicit specializations.
template <size_t LEN>
void fft_split_radix_dit_template(Complex *input)
{
constexpr size_t log_len = hint_log2(LEN);
constexpr size_t half_len = LEN / 2, quarter_len = LEN / 4;
fft_split_radix_dit_template<quarter_len>(input + half_len + quarter_len);
fft_split_radix_dit_template<quarter_len>(input + half_len);
fft_split_radix_dit_template<half_len>(input);
#if TABLE_ENABLE == 1
// Twiddle factors come from the shared table; the block for log_len must
// already have been filled via TABLE.expand().
auto omega = TABLE.get_omega_ptr(log_len, 0);
auto omega_cube = TABLE.get_omega3_ptr(log_len, 0);
for (size_t i = 0; i < quarter_len; i += 4)
{
fft_split_radix_dit_butterfly(omega + i, omega_cube + i, input + i, quarter_len);
}
#else
// Table-free variant: twiddle factors are advanced by repeated multiplication.
// This branch builds Complex2 from two Complex values.
static const Complex unit1 = std::conj(unit_root(LEN, 1));
static const Complex unit2 = std::conj(unit_root(LEN, 2));
static const Complex unit3 = std::conj(unit_root(LEN, 3));
static const Complex unit6 = std::conj(unit_root(LEN, 6));
static const Complex unit9 = std::conj(unit_root(LEN, 9));
static const Complex unit4 = std::conj(unit_root(LEN, 4));
static const Complex unit12 = std::conj(unit_root(LEN, 12));
// Step factors that move four columns ahead: w^4 and w^12.
static const Complex2 unit(unit4, unit4);
static const Complex2 unit_cube(unit12, unit12);
Complex2 omega1(Complex(1, 0), unit1);
Complex2 omega2(unit2, unit3);
Complex2 omega_cube1(Complex(1, 0), unit3);
Complex2 omega_cube2(unit6, unit9);
for (size_t i = 0; i < quarter_len; i += 4)
{
fft_split_radix_dit_butterfly(omega1, omega_cube1, input + i, quarter_len);
fft_split_radix_dit_butterfly(omega2, omega_cube2, input + i + 2, quarter_len);
omega1 = omega1 * unit;
omega2 = omega2 * unit;
omega_cube1 = omega_cube1 * unit_cube;
omega_cube2 = omega_cube2 * unit_cube;
}
#endif
}
// Base cases of the split-radix DIT recursion.
// Lengths 0 and 1: nothing to transform.
template <>
inline void fft_split_radix_dit_template<0>(Complex *) {}
template <>
inline void fft_split_radix_dit_template<1>(Complex *) {}
// Length 2: a single butterfly.
template <>
inline void fft_split_radix_dit_template<2>(Complex *input)
{
    fft_2point(*input, *(input + 1));
}
// Lengths 4 and 8: dedicated AVX kernels.
template <>
inline void fft_split_radix_dit_template<4>(Complex *input)
{
    fft_dit_4point_avx(input);
}
template <>
inline void fft_split_radix_dit_template<8>(Complex *input)
{
    fft_dit_8point_avx(input);
}
// Length 16: two 4-point quarters and one 8-point half, merged with the
// hard-coded 16-point twiddle factors.
template <>
inline void fft_split_radix_dit_template<16>(Complex *input)
{
    Complex *const upper_half = input + 8;
    fft_dit_4point_avx(upper_half + 4);
    fft_dit_4point_avx(upper_half);
    fft_dit_8point_avx(input);
    fft_split_radix_dit_butterfly(omega0, omega_cu0, input, 4);
    fft_split_radix_dit_butterfly(omega1, omega_cu1, input + 2, 4);
}
// Recursive split-radix decimation-in-frequency transform of LEN points, in place.
// The butterflies run first; the half-length sub-transform and the two
// quarter-length sub-transforms follow. No bit-reversal is performed here.
// Lengths 0, 1, 2, 4, 8 and 16 are handled by explicit specializations.
template <size_t LEN>
void fft_split_radix_dif_template(Complex *input)
{
constexpr size_t log_len = hint_log2(LEN);
constexpr size_t half_len = LEN / 2, quarter_len = LEN / 4;
#if TABLE_ENABLE == 1
// Twiddle factors come from the shared table; the block for log_len must
// already have been filled via TABLE.expand().
auto omega = TABLE.get_omega_ptr(log_len, 0);
auto omega_cube = TABLE.get_omega3_ptr(log_len, 0);
for (size_t i = 0; i < quarter_len; i += 4)
{
fft_split_radix_dif_butterfly(omega + i, omega_cube + i, input + i, quarter_len);
}
#else
// Table-free variant: twiddle factors are advanced by repeated multiplication.
// This branch builds Complex2 from two Complex values.
static const Complex unit1 = std::conj(unit_root(LEN, 1));
static const Complex unit2 = std::conj(unit_root(LEN, 2));
static const Complex unit3 = std::conj(unit_root(LEN, 3));
static const Complex unit6 = std::conj(unit_root(LEN, 6));
static const Complex unit9 = std::conj(unit_root(LEN, 9));
static const Complex unit4 = std::conj(unit_root(LEN, 4));
static const Complex unit12 = std::conj(unit_root(LEN, 12));
// Step factors that move four columns ahead: w^4 and w^12.
static const Complex2 unit(unit4, unit4);
static const Complex2 unit_cube(unit12, unit12);
Complex2 omega1(Complex(1, 0), unit1);
Complex2 omega2(unit2, unit3);
Complex2 omega_cube1(Complex(1, 0), unit3);
Complex2 omega_cube2(unit6, unit9);
for (size_t i = 0; i < quarter_len; i += 4)
{
fft_split_radix_dif_butterfly(omega1, omega_cube1, input + i, quarter_len);
fft_split_radix_dif_butterfly(omega2, omega_cube2, input + i + 2, quarter_len);
omega1 = omega1 * unit;
omega2 = omega2 * unit;
omega_cube1 = omega_cube1 * unit_cube;
omega_cube2 = omega_cube2 * unit_cube;
}
#endif
fft_split_radix_dif_template<half_len>(input);
fft_split_radix_dif_template<quarter_len>(input + half_len);
fft_split_radix_dif_template<quarter_len>(input + half_len + quarter_len);
}
// Base cases of the split-radix DIF recursion.
// Lengths 0 and 1: nothing to transform.
template <>
inline void fft_split_radix_dif_template<0>(Complex *) {}
template <>
inline void fft_split_radix_dif_template<1>(Complex *) {}
// Length 2: a single butterfly.
template <>
inline void fft_split_radix_dif_template<2>(Complex *input)
{
    fft_2point(*input, *(input + 1));
}
// Lengths 4 and 8: dedicated AVX kernels.
template <>
inline void fft_split_radix_dif_template<4>(Complex *input)
{
    fft_dif_4point_avx(input);
}
template <>
inline void fft_split_radix_dif_template<8>(Complex *input)
{
    fft_dif_8point_avx(input);
}
// Length 16: butterflies with the hard-coded 16-point twiddle factors,
// followed by one 8-point half and two 4-point quarters.
template <>
inline void fft_split_radix_dif_template<16>(Complex *input)
{
    fft_split_radix_dif_butterfly(omega0, omega_cu0, input, 4);
    fft_split_radix_dif_butterfly(omega1, omega_cu1, input + 2, 4);
    Complex *const upper_half = input + 8;
    fft_dif_8point_avx(input);
    fft_dif_4point_avx(upper_half);
    fft_dif_4point_avx(upper_half + 4);
}
// Run-time dispatcher: halves LEN until it no longer exceeds fft_len, then
// runs the compile-time DIT transform of that length.
template <size_t LEN = 1>
void fft_split_radix_dit_template_alt(Complex *input, size_t fft_len)
{
    if (fft_len >= LEN)
    {
        fft_split_radix_dit_template<LEN>(input);
    }
    else
    {
        fft_split_radix_dit_template_alt<LEN / 2>(input, fft_len);
    }
}
// End of the halving chain: nothing to do.
template <>
void fft_split_radix_dit_template_alt<0>(Complex *, size_t) {}
// Run-time dispatcher: halves LEN until it no longer exceeds fft_len, then
// runs the compile-time DIF transform of that length.
template <size_t LEN = 1>
void fft_split_radix_dif_template_alt(Complex *input, size_t fft_len)
{
    if (fft_len >= LEN)
    {
        fft_split_radix_dif_template<LEN>(input);
    }
    else
    {
        fft_split_radix_dif_template_alt<LEN / 2>(input, fft_len);
    }
}
// End of the halving chain: nothing to do.
template <>
void fft_split_radix_dif_template_alt<0>(Complex *, size_t) {}
// Run-time entry points (function pointers) for the split-radix transforms.
// The dispatch starts at 2^lut_max_rank points, so that is the longest
// transform these can select.
auto fft_split_radix_dit = fft_split_radix_dit_template_alt<size_t(1) << lut_max_rank>;
auto fft_split_radix_dif = fft_split_radix_dif_template_alt<size_t(1) << lut_max_rank>;
/// @brief In-place decimation-in-time FFT (split-radix kernels).
/// @param input Array of complex samples; the SIMD kernels use aligned
///              loads/stores, so it must be 32-byte aligned.
/// @param fft_len Number of samples; rounded down to a power of two.
///                Lengths below 2 leave the data untouched.
/// @param bit_rev When true, the input is bit-reverse permuted before the transform.
/// @note The dispatcher selects at most 2^lut_max_rank points, so longer
///       inputs are not fully transformed.
inline void fft_dit(Complex *input, size_t fft_len, bool bit_rev = true)
{
    fft_len = max_2pow(fft_len);
    if (fft_len < 2)
    {
        // A 0- or 1-point transform is the identity. Returning here also keeps
        // a zero length away from binary_reverse_swap, whose loop bound
        // `len - 1` would wrap around.
        return;
    }
#if TABLE_ENABLE == 1
    TABLE.expand(hint_log2(fft_len));
#endif
    if (bit_rev)
    {
        binary_reverse_swap(input, fft_len);
    }
    fft_split_radix_dit(input, fft_len);
}
/// @brief In-place decimation-in-frequency FFT (split-radix kernels).
/// @param input Array of complex samples; the SIMD kernels use aligned
///              loads/stores, so it must be 32-byte aligned.
/// @param fft_len Number of samples; rounded down to a power of two.
///                Lengths below 2 leave the data untouched.
/// @param bit_rev When true, the output is bit-reverse permuted after the transform.
/// @note The dispatcher selects at most 2^lut_max_rank points, so longer
///       inputs are not fully transformed.
inline void fft_dif(Complex *input, size_t fft_len, bool bit_rev = true)
{
    fft_len = max_2pow(fft_len);
    if (fft_len < 2)
    {
        // A 0- or 1-point transform is the identity. Returning here also keeps
        // a zero length away from binary_reverse_swap, whose loop bound
        // `len - 1` would wrap around.
        return;
    }
#if TABLE_ENABLE == 1
    TABLE.expand(hint_log2(fft_len));
#endif
    fft_split_radix_dif(input, fft_len);
    if (bit_rev)
    {
        binary_reverse_swap(input, fft_len);
    }
}
}
}
// Formats the lowest `digits` decimal digits of `input` as a zero-padded
// string of exactly `digits` characters; higher digits are dropped.
inline std::string ui64to_string(UINT_64 input, UINT_8 digits)
{
    std::string text(digits, '0');
    // Fill from the least significant digit, i.e. from the back of the string.
    for (auto it = text.rbegin(); it != text.rend(); ++it)
    {
        *it = static_cast<char>(input % 10 + '0');
        input /= 10;
    }
    return text;
}
// Parses the first `dig` characters of `s` as an unsigned decimal number.
// The characters are not validated.
constexpr UINT_64 stoui64(const char *s, size_t dig = 4)
{
    UINT_64 value = 0;
    const char *const stop = s + dig;
    for (const char *cur = s; cur != stop; ++cur)
    {
        value = value * 10 + (*cur - '0');
    }
    return value;
}
// Converts exactly four decimal characters to their numeric value (0..9999).
// The '0' offsets of all four characters are removed in one subtraction.
constexpr UINT_32 stobase10000(const char *s)
{
    // Horner form of s[0]*1000 + s[1]*100 + s[2]*10 + s[3].
    return ((s[0] * 10 + s[1]) * 10 + s[2]) * 10 + s[3] - '0' * 1111;
}
// Converts exactly five decimal characters to their numeric value (0..99999).
// The '0' offsets of all five characters are removed in one subtraction.
constexpr UINT_32 stobase100000(const char *s)
{
    // Horner form of s[0]*10000 + s[1]*1000 + s[2]*100 + s[3]*10 + s[4].
    return (((s[0] * 10 + s[1]) * 10 + s[2]) * 10 + s[3]) * 10 + s[4] - '0' * 11111;
}
// Converts exactly six decimal characters to their numeric value (0..999999).
// The '0' offsets of all six characters are removed in one subtraction.
constexpr UINT_32 stobase1000000(const char *s)
{
    // Horner form of s[0]*100000 + s[1]*10000 + ... + s[4]*10 + s[5].
    return ((((s[0] * 10 + s[1]) * 10 + s[2]) * 10 + s[3]) * 10 + s[4]) * 10 + s[5] - '0' * 111111;
}
// Number of decimal digits packed into one FFT coefficient (limb base 10^DIGIT).
static constexpr INT_64 DIGIT = 4;
inline size_t char_to_real(const char *buffer, Complex *comary, size_t str_len)
{
hint::INT_64 len = str_len, pos = len, i = 0;
len = (len + DIGIT - 1) / DIGIT;
while (pos - DIGIT > 0)
{
// hint::UINT_64 tmp = stoui64<DIGIT>(buffer + pos - DIGIT);
hint::UINT_32 tmp = stobase10000(buffer + pos - DIGIT);
comary[i].real(tmp);
i++;
pos -= DIGIT;
}
if (pos > 0)
{
hint::UINT_32 tmp = stoui64(buffer, pos);
comary[i].real(tmp);
}
return len;
}
// Splits the decimal string buffer[0 .. str_len) into base-10^DIGIT limbs,
// least significant limb first, and writes limb i into the imaginary part of
// comary[i]. Real parts are left untouched.
// The buffer is only read, so it is taken as const char * like char_to_real;
// existing callers passing char * still work.
// Returns the number of limbs, ceil(str_len / DIGIT).
inline size_t char_to_imag(const char *buffer, Complex *comary, size_t str_len)
{
    hint::INT_64 len = str_len, pos = len, i = 0;
    len = (len + DIGIT - 1) / DIGIT;
    // Full DIGIT-character groups, taken from the end of the string.
    while (pos - DIGIT > 0)
    {
        hint::UINT_32 tmp = stobase10000(buffer + pos - DIGIT);
        comary[i].imag(tmp);
        i++;
        pos -= DIGIT;
    }
    // Leading group of 1..DIGIT characters.
    if (pos > 0)
    {
        hint::UINT_32 tmp = stoui64(buffer, pos);
        comary[i].imag(tmp);
    }
    return len;
}
/// Lookup table that converts a number in [0, 10000) to its four ASCII digits
/// packed into one uint32_t. The thousands digit sits in the lowest byte, so
/// on a little-endian machine the bytes in memory read in the usual order.
class ItoStrBase10000
{
private:
    uint32_t table[10000]{};

public:
    /// Packs the four decimal digits of `num` as ASCII characters.
    static constexpr uint32_t itosbase10000(uint32_t num)
    {
        uint32_t res = '0' * 0x1010101; // '0' in every byte
        res += (num / 1000 % 10) | ((num / 100 % 10) << 8) |
               ((num / 10 % 10) << 16) | ((num % 10) << 24);
        return res;
    }
    /// Fills the table for every value below 10000.
    constexpr ItoStrBase10000()
    {
        for (size_t i = 0; i < 10000; i++)
        {
            table[i] = itosbase10000(i);
        }
    }
    /// Writes the four packed characters of `num` (< 10000, unchecked) to `str`.
    /// memcpy replaces the former store through a cast uint32_t pointer, which
    /// violated strict aliasing and required `str` to be 4-byte aligned.
    void tostr(char *str, uint32_t num) const
    {
        std::memcpy(str, &table[num], sizeof(uint32_t));
    }
    /// Returns the four packed characters of `num` (< 10000, unchecked).
    uint32_t tostr(uint32_t num) const
    {
        return table[num];
    }
};
}
using namespace std;
using namespace hint;
using namespace hint_transform;
using namespace hint_fft;
// Reads two non-negative decimal integers from stdin and writes their product
// to stdout (no trailing newline).
//
// Method: the limbs of the first number go into the real parts and the limbs
// of the second into the imaginary parts of one complex array. After a
// forward transform, squaring gives (a + ib)^2 = (a^2 - b^2) + 2iab, so half
// of the imaginary part of the inverse transform is the convolution a * b.
// The inverse is computed as a second forward transform of the conjugated,
// scaled data, which is why `inv` is negative.
//
// Returns 1 without output when an operand is missing or the input does not
// fit the fixed buffers; the original code read out of bounds in those cases.
int main()
{
    constexpr size_t STR_LEN = 20008;
    constexpr uint64_t BASE = hint::qpow(10ull, DIGIT);
    static constexpr ItoStrBase10000 transfer;
    static AlignAry<char, STR_LEN> out;            // input text, later reused for the output digits
    static AlignAry<HintFloat, 1 << 14> fft_arr;   // 2^13 complex values, zero-initialised
    Complex *fft_ary = fft_arr.cast_ptr<Complex>();
    uint32_t *ary = out.cast_ptr<uint32_t>();

    // Only the bytes actually read are scanned below.
    const size_t in_len = fread(out.get_ptr(), 1, STR_LEN, stdin);
    char *p = out.get_ptr();
    char *const in_end = p + in_len;
    /*
    struct stat sb;
    int fd = fileno(stdin);
    fstat(fd, &sb);
    p = (char *)mmap(0, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
    madvise(p, sb.st_size, MADV_SEQUENTIAL);
    */
    // First operand: everything up to the first character below '0'.
    size_t len_a = 0;
    while (p + len_a < in_end && p[len_a] >= '0')
    {
        len_a++;
    }
    if (len_a == 0)
    {
        return 1;
    }
    if (len_a == 1 && p[0] == '0')
    {
        puts("0");
        return 0;
    }
    // Skip the separator, then measure the second operand.
    char *b = p + len_a;
    while (b < in_end && !isdigit(static_cast<unsigned char>(*b)))
    {
        b++;
    }
    size_t len_b = 0;
    while (b + len_b < in_end && b[len_b] >= '0')
    {
        len_b++;
    }
    if (len_b == 0)
    {
        return 1;
    }
    if (len_b == 1 && b[0] == '0')
    {
        puts("0");
        return 0;
    } // 0.46ms (author's timing)
    size_t len1 = char_to_real(p, fft_ary, len_a);
    size_t len2 = char_to_imag(b, fft_ary, len_b); // 1.67ms (author's timing)
    // The product needs len1 + len2 four-character groups in `out`.
    if (len1 + len2 > STR_LEN / 4)
    {
        return 1;
    }
    size_t fft_len = int_ceil2(len1 + len2 - 1);
    TABLE.expand(hint_log2(fft_len)); // 3.5ms (author's timing)
    // The 2^13-point template is always used. Twiddle blocks above
    // log2(fft_len) are never filled and stay zero, and the data beyond the
    // operands is zero as well.
    fft_split_radix_dif_template<1 << 13>(fft_ary);
    HintFloat inv = -0.5 / fft_len;
    Complex2 invx4(inv);
    for (size_t i = 0; i < fft_len; i += 2)
    {
        Complex2 tmp = fft_ary + i;
        tmp = tmp.square().linear_mul(invx4);
        (tmp.conj()).store(fft_ary + i);
    }
    fft_split_radix_dit_template<1 << 13>(fft_ary); // 6ms (author's timing)
    // Carry propagation; digits are written into `out` from the back, four
    // characters per limb.
    UINT_64 carry = 0;
    size_t pos = STR_LEN / 4 - 1;
    for (size_t i = 0; i < len1 + len2 - 1; i++)
    {
        carry += UINT_64(fft_ary[i].imag() + 0.5);
        UINT_64 num = 0;
        std::tie(carry, num) = div_mod<UINT_64>(carry, BASE);
        ary[pos] = transfer.tostr(num);
        pos--;
    }
    ary[pos] = transfer.tostr(carry);
    pos *= 4;
    // Strip leading zeros but always keep the last digit.
    while (pos + 1 < STR_LEN && out[pos] == '0')
    {
        pos++;
    } // 0.8ms (author's timing)
    fwrite(out.get_ptr() + pos, 1, STR_LEN - pos, stdout);
}
// Online-judge result pasted after the program (not part of the source):
// Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// Testcase #1 | 168.86 us | 328 KB | Accepted | Score: 100 | Show more |