// 给个STAR吧
// 再不济CV代码把这些链接留下吧,球球了
// https://github.com/With-Sky/HintFFT
// https://github.com/With-Sky/FFT-Benchmark
// https://github.com/With-Sky/HyperInt-mini
// https://space.bilibili.com/511540153
#include <algorithm>
#include <complex>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <limits>
#include <string>
#include <tuple>
#include <immintrin.h>
#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP
// Let the compiler use FMA for the SIMD code below. NOTE(review): this pragma
// presumably also enables the AVX instructions the _mm256_* intrinsics need;
// confirm for the target compiler.
#pragma GCC target("fma")
// Use AVX
// 256-bit SIMD: one __m256d holds two complex doubles.
// Scalar floating-point type used throughout the FFT code.
using HintFloat = double;
// Scalar complex type; Complex2 below packs two of these into one register.
using Complex = std::complex<HintFloat>;
// Two complex doubles processed in parallel: one 256-bit AVX register laid out
// as [re0, im0, re1, im1], i.e. the same memory layout as Complex[2].
struct Complex2
{
    __m256d data;
    Complex2()
    {
        data = _mm256_setzero_pd();
    }
    // Broadcasts one scalar into all four lanes.
    Complex2(double input)
    {
        data = _mm256_set1_pd(input);
    }
    Complex2(__m256d input)
    {
        data = input;
    }
    Complex2(const Complex2 &input) = default;
    // Loads four consecutive doubles (two complex numbers); no alignment required.
    Complex2(double const *ptr)
    {
        data = _mm256_loadu_pd(ptr);
    }
    // Broadcasts one complex number into both halves.
    Complex2(Complex a)
    {
        // vbroadcastf128 has no alignment requirement. std::complex<double> may be
        // accessed as double[2] (array-oriented access guarantee).
        data = _mm256_broadcast_pd(reinterpret_cast<const __m128d *>(&a));
    }
    // Packs a into the low half and b into the high half.
    Complex2(Complex a, Complex b)
    {
        // Unaligned 128-bit loads: a Complex is only guaranteed alignof(double),
        // so dereferencing it as an (aligned) __m128d, as before, was undefined.
        const __m128d lo = _mm_loadu_pd(reinterpret_cast<const double *>(&a));
        const __m128d hi = _mm_loadu_pd(reinterpret_cast<const double *>(&b));
        data = _mm256_set_m128d(hi, lo);
    }
    // Loads two consecutive complex numbers; no alignment required.
    Complex2(const Complex *ptr)
    {
        data = _mm256_loadu_pd((const double *)ptr);
    }
    void clr()
    {
        data = _mm256_setzero_pd();
    }
    // Stores both complex numbers to ptr[0], ptr[1]; no alignment required.
    void store(Complex *a) const
    {
        _mm256_storeu_pd((double *)a, data);
    }
    void print() const
    {
        double ary[4];
        _mm256_storeu_pd(ary, data);
        printf("(%lf,%lf) (%lf,%lf)\n", ary[0], ary[1], ary[2], ary[3]);
    }
    // Negates lane k (k = 0..3) when bit k of M is set.
    template <int M>
    Complex2 element_mask_neg() const
    {
        // Plain local instead of a function-local static: a static would need a
        // thread-safe initialization guard check on every call of this helper.
        const __m256d neg_mask = _mm256_castsi256_pd(
            _mm256_set_epi64x((M & 8ull) << 60, (M & 4ull) << 61, (M & 2ull) << 62, (M & 1ull) << 63));
        return _mm256_xor_pd(data, neg_mask);
    }
    template <int M>
    Complex2 element_permute() const
    {
        return _mm256_permute_pd(data, M);
    }
    // NOTE: _mm256_permute4x64_pd is an AVX2 instruction.
    template <int M>
    Complex2 element_permute64() const
    {
        return _mm256_permute4x64_pd(data, M);
    }
    // (re0, re0, re1, re1)
    Complex2 all_real() const
    {
        return _mm256_unpacklo_pd(data, data);
    }
    // (im0, im0, im1, im1)
    Complex2 all_imag() const
    {
        return _mm256_unpackhi_pd(data, data);
    }
    // Swaps real and imaginary parts of both numbers.
    Complex2 swap() const
    {
        return _mm256_shuffle_pd(data, data, 5);
    }
    // Multiplies both numbers by -i: (re, im) -> (im, -re).
    Complex2 mul_neg_i() const
    {
        static const __m256d subber{};
        return Complex2(_mm256_addsub_pd(subber, data)).swap();
    }
    // Complex conjugate of both numbers (negates lanes 1 and 3).
    Complex2 conj() const
    {
        return element_mask_neg<10>();
    }
    // Lane-wise product (not a complex product).
    Complex2 linear_mul(Complex2 input) const
    {
        return _mm256_mul_pd(data, input.data);
    }
    Complex2 operator+(Complex2 input) const
    {
        return _mm256_add_pd(data, input.data);
    }
    Complex2 operator-(Complex2 input) const
    {
        return _mm256_sub_pd(data, input.data);
    }
    // Complex product of both pairs:
    // (ar*br - ai*bi, ar*bi + ai*br) via one addsub.
    Complex2 operator*(Complex2 input) const
    {
        const __m256d a_rr = all_real().data;
        const __m256d a_ii = all_imag().data;
        const __m256d b_ir = input.swap().data;
        return _mm256_addsub_pd(_mm256_mul_pd(a_rr, input.data), _mm256_mul_pd(a_ii, b_ir));
    }
    // Lane-wise quotient (not a complex division).
    Complex2 operator/(Complex2 input) const
    {
        return _mm256_div_pd(data, input.data);
    }
};
#endif
#include <vector>
#include <complex>
#include <iostream>
#include <future>
#include <ctime>
#include <cstring>
// #include "stopwatch.hpp"
// Build-time switches.
// TABLE_ENABLE: 1 = the FFT templates read twiddle factors from the shared
// TABLE; 0 = they generate twiddles on the fly by repeated multiplication.
#define TABLE_ENABLE 0
#define MULTITHREAD 0 // multithreading: 0 means no, 1 means yes
#define TABLE_PRELOAD 0 // fill the whole twiddle table in its constructor: 0 means no, 1 means yes
namespace hint
{
// Fixed-width integer aliases used across the library.
using UINT_8 = uint8_t;
using UINT_16 = uint16_t;
using UINT_32 = uint32_t;
using UINT_64 = uint64_t;
using INT_32 = int32_t;
using INT_64 = int64_t;
using ULONG = unsigned long;
using LONG = long;

// Bit widths of the basic integer types.
constexpr UINT_64 HINT_CHAR_BIT = 8;
constexpr UINT_64 HINT_SHORT_BIT = 16;
constexpr UINT_64 HINT_INT_BIT = 32;

// Masks / powers of two, spelled as shifts:
//   *_0XFF = all ones, *_0X10 = one past the maximum,
//   *_0X80 = top bit only, *_0X7F = every bit but the top one.
constexpr UINT_64 HINT_INT8_0XFF = (UINT_64(1) << 8) - 1;
constexpr UINT_64 HINT_INT8_0X10 = UINT_64(1) << 8;
constexpr UINT_64 HINT_INT16_0XFF = (UINT_64(1) << 16) - 1;
constexpr UINT_64 HINT_INT16_0X10 = UINT_64(1) << 16;
constexpr UINT_64 HINT_INT32_0XFF = (UINT_64(1) << 32) - 1;
constexpr UINT_64 HINT_INT32_0X01 = 1;
constexpr UINT_64 HINT_INT32_0X80 = UINT_64(1) << 31;
constexpr UINT_64 HINT_INT32_0X7F = (UINT_64(1) << 31) - 1;
constexpr UINT_64 HINT_INT32_0X10 = UINT_64(1) << 32;
constexpr UINT_64 HINT_INT64_0X80 = UINT_64(1) << 63;
constexpr UINT_64 HINT_INT64_0X7F = (UINT_64(1) << 63) - 1;
constexpr UINT_64 HINT_INT64_0XFF = ~UINT_64(0);

// Floating-point constants.
constexpr HintFloat HINT_PI = 3.1415926535897932384626433832795;
constexpr HintFloat HINT_2PI = HINT_PI * 2;
constexpr HintFloat HINT_HSQ_ROOT2 = 0.70710678118654752440084436210485; // sqrt(2) / 2

// Number-theoretic-transform moduli with their primitive roots.
constexpr UINT_64 NTT_MOD1 = 3221225473;
constexpr UINT_64 NTT_ROOT1 = 5;
constexpr UINT_64 NTT_MOD2 = 3489660929;
constexpr UINT_64 NTT_ROOT2 = 3;
constexpr size_t NTT_MAX_LEN = size_t(1) << 28;
/// @brief Largest power of two not greater than n.
/// @param n value to round down (any integer type)
/// @return the largest power of two <= n, or 0 when n < 1
template <typename T>
constexpr T max_2pow(T n)
{
    // Start from the highest power of two representable in T. digits excludes
    // the sign bit, so signed T no longer starts at a negative value (the old
    // `1 << (sizeof(T) * 8 - 1)`), which is never > n and made the function
    // return the type's minimum instead of a power of two.
    T res = T(1) << (std::numeric_limits<T>::digits - 1);
    while (res > n)
    {
        res /= 2;
    }
    return res;
}
/// @brief Smallest power of two not less than n.
/// @param n value to round up (any integer type)
/// @return the smallest power of two >= n (1 for n <= 1), or 0 when no power
///         of two representable in T is >= n
template <typename T>
constexpr T min_2pow(T n)
{
    // Largest power of two T can hold. Doubling past it wraps to 0 (unsigned)
    // or overflows (signed, UB), which previously made this loop never end.
    constexpr T top = T(1) << (std::numeric_limits<T>::digits - 1);
    T res = 1;
    while (res < n)
    {
        if (res >= top)
        {
            return 0;
        }
        res *= 2;
    }
    return res;
}
// Floor of log2(n): how many times n can be halved before it drops to 1 or
// below. Returns 0 for n <= 1.
template <typename T>
constexpr size_t hint_log2(T n)
{
    size_t halvings = 0;
    for (; n > 1; n /= 2)
    {
        ++halvings;
    }
    return halvings;
}
// Binary exponentiation: returns m^n using O(log n) multiplications.
// T needs operator* and construction from 1; overflow follows T's operator*.
template <typename T>
constexpr T qpow(T m, UINT_64 n)
{
    T result = 1;
    while (n > 0)
    {
        if ((n & 1) != 0)
        {
            result = result * m;
        }
        n >>= 1;
        // Square only while exponent bits remain. The final squaring of the
        // original loop was never used and could overflow (UB for signed T,
        // a hard error in constant evaluation) even when m^n itself fits.
        if (n > 0)
        {
            m = m * m;
        }
    }
    return result;
}
// Quotient and remainder of a / b, returned as (quotient, remainder).
template <typename T>
constexpr std::pair<T, T> div_mod(T a, T b)
{
    const T quotient = a / b;
    const T remainder = a % b;
    return std::pair<T, T>(quotient, remainder);
}
// Thread-count globals, compiled only when MULTITHREAD == 1 (it is 0 above).
// NOTE(review): std::thread, std::atomic and std::ceil come from <thread>,
// <atomic> and <cmath>, none of which this file includes directly; confirm
// they are available before enabling this branch.
// NOTE: hint_log2 already returns floor(log2(x)) as an integer, so std::ceil
// here does not round up; log2_threads is the floor.
#if MULTITHREAD == 1
const UINT_32 hint_threads = std::thread::hardware_concurrency();
const UINT_32 log2_threads = std::ceil(hint_log2(hint_threads));
std::atomic<UINT_32> cur_ths;
#endif
// Copies len elements from source to target.
// Does nothing when len == 0 or target == source. Throws a C-string (like the
// rest of this file) when len * sizeof(T) would not fit in an int64_t.
template <typename T>
void ary_copy(T *target, const T *source, size_t len)
{
    if (len == 0 || target == source)
    {
        return;
    }
    // The old guard `len >= INT64_MAX` ignored sizeof(T), so the byte count
    // could overflow for any T wider than one byte.
    if (len >= static_cast<size_t>(INT64_MAX) / sizeof(T))
    {
        throw("Ary too long\n");
    }
    // std::copy is well defined for non-trivially-copyable T (memcpy is UB
    // there) and lowers to memmove for trivially copyable T.
    std::copy(source, source + len, target);
}
// Packs two real sequences into one complex array: source1[i] goes to the real
// part and source2[i] to the imaginary part of target[i]. Past the shorter
// sequence only the component that still has data is written; the other
// component of target[i] is left untouched.
template <typename T>
inline void com_ary_combine_copy(Complex *target, const T &source1, size_t len1, const T &source2, size_t len2)
{
    const size_t common = len1 < len2 ? len1 : len2;
    size_t idx = 0;
    for (; idx < common; ++idx)
    {
        target[idx] = Complex(source1[idx], source2[idx]);
    }
    for (; idx < len1; ++idx)
    {
        target[idx].real(source1[idx]);
    }
    for (; idx < len2; ++idx)
    {
        target[idx].imag(source2[idx]);
    }
}
// FFT与类FFT变换的命名空间
namespace hint_transform
{
// Point on the unit circle with argument theta (radians): e^(i*theta).
static Complex unit_root(HintFloat theta)
{
    return std::polar<HintFloat>(1.0, theta);
}
// The n-th of m points splitting the unit circle evenly: e^(2*pi*i*n/m).
static Complex unit_root(size_t m, size_t n)
{
    const HintFloat angle = (HINT_2PI * n) / m;
    return unit_root(angle);
}
// Lookup table of FFT twiddle factors for transform lengths len = 1 << shift.
// For each built level shift >= 3, with w = e^(-2*pi*i/len) and
// 0 <= n < len / 4:
//   table1[shift][n] = w^n
//   table3[shift][n] = w^(3n)
// Levels 0..2 hold just {1}. Higher levels are built on demand by expand(),
// or all at once in the constructor when TABLE_PRELOAD == 1.
class ComplexTableY
{
private:
    std::vector<std::vector<Complex>> table1; // w^n per level
    std::vector<std::vector<Complex>> table3; // w^(3n) per level
    INT_32 max_log_size = 2;                  // largest level expand() may build
    INT_32 cur_log_size = 2;                  // largest level built so far
    static constexpr size_t FAC = 1;
    ComplexTableY(const ComplexTableY &) = delete;
    ComplexTableY &operator=(const ComplexTableY &) = delete;

public:
    ~ComplexTableY() {}
    // Prepares a table able to serve transforms up to length 1 << max_shift.
    ComplexTableY(UINT_32 max_shift)
    {
        // Levels 0, 1 and 2 are always written below, so at least three levels
        // must exist. The previous clamp to 1 allocated only two and wrote
        // table1[2]/table3[2] out of bounds whenever max_shift <= 1.
        max_shift = std::max<UINT_32>(max_shift, 2);
        max_log_size = max_shift;
        table1.resize(max_shift + 1);
        table3.resize(max_shift + 1);
        table1[0] = table1[1] = table3[0] = table3[1] = std::vector<Complex>{1};
        table1[2] = table3[2] = std::vector<Complex>{1};
#if TABLE_PRELOAD == 1
        expand(max_shift);
#endif
    }
    // Builds every level up to `shift` (clamped to >= 2) not built yet.
    // Throws a C-string when shift exceeds the constructor's max_shift.
    void expand(INT_32 shift)
    {
        shift = std::max<INT_32>(shift, 2);
        if (shift > max_log_size)
        {
            throw("FFT length too long for lut\n");
        }
        for (INT_32 i = cur_log_size + 1; i <= shift; i++)
        {
            size_t len = 1ull << i, vec_size = len * FAC / 4;
            table1[i].resize(vec_size);
            table3[i].resize(vec_size);
            table1[i][0] = table3[i][0] = Complex(1, 0);
            // Even entries are copied from the previous level; odd ones are
            // computed directly, and their mirror entries at vec_size - pos
            // are obtained by symmetry.
            for (size_t pos = 0; pos < vec_size / 2; pos++)
            {
                table1[i][pos * 2] = table1[i - 1][pos];
                if (pos % 2 == 1)
                {
                    Complex tmp = unit_root(-HINT_2PI * pos / len);
                    table1[i][pos] = tmp;
                    table1[i][vec_size - pos] = -Complex(tmp.imag(), tmp.real());
                }
            }
            table1[i][vec_size / 2] = std::conj(unit_root(8, 1));
            for (size_t pos = 0; pos < vec_size / 2; pos++)
            {
                table3[i][pos * 2] = table3[i - 1][pos];
                if (pos % 2 == 1)
                {
                    Complex tmp = get_omega(i, pos * 3);
                    table3[i][pos] = tmp;
                    table3[i][vec_size - pos] = Complex(tmp.imag(), tmp.real());
                }
            }
            table3[i][vec_size / 2] = std::conj(unit_root(8, 3));
        }
        cur_log_size = std::max(cur_log_size, shift);
    }
    // Point on the unit circle with argument theta (radians).
    static Complex unit_root(double theta)
    {
        return std::polar<double>(1.0, theta);
    }
    // The n-th of m points splitting the unit circle evenly.
    static Complex unit_root(size_t m, size_t n)
    {
        return unit_root((HINT_2PI * n) / m);
    }
    // w^n for a circle split into 1 << shift parts. Valid for n < len / 2:
    // indices in (len/4, len/2) are derived from table1 by symmetry.
    Complex get_omega(UINT_32 shift, size_t n) const
    {
        size_t vec_size = (size_t(1) << shift) / 4;
        if (n < vec_size)
        {
            return table1[shift][n];
        }
        else if (n > vec_size)
        {
            Complex tmp = table1[shift][vec_size * 2 - n];
            return Complex(-tmp.real(), tmp.imag());
        }
        else
        {
            return Complex(0, -1);
        }
    }
    // w^(3n) for a circle split into 1 << shift parts.
    Complex get_omega3(UINT_32 shift, size_t n) const
    {
        return table3[shift][n];
    }
    // Packs (w^n, w^(n+1)).
    Complex2 get_omegaX2(UINT_32 shift, size_t n) const
    {
        return Complex2(table1[shift].data() + n);
    }
    // Packs (w^(3n), w^(3n+3)).
    Complex2 get_omega3X2(UINT_32 shift, size_t n) const
    {
        return Complex2(table3[shift].data() + n);
    }
    // Pointer to w^n within the level's table.
    const Complex *get_omega_ptr(UINT_32 shift, size_t n) const
    {
        return table1[shift].data() + n;
    }
    // Pointer to w^(3n) within the level's table.
    const Complex *get_omega3_ptr(UINT_32 shift, size_t n) const
    {
        return table3[shift].data() + n;
    }
};
// log2 of the largest transform length supported (2^21 points). It also sizes
// the twiddle table below and the compile-time FFT dispatch.
constexpr size_t lut_max_rank{21};
// Shared twiddle table; the FFT code reads it only when TABLE_ENABLE == 1.
static ComplexTableY TABLE(lut_max_rank);
// Permutes ary[0..len) in place into bit-reversed index order.
// len is expected to be a power of two (e.g. from max_2pow / min_2pow).
template <typename T>
void binary_reverse_swap(T &ary, size_t len)
{
    // Nothing to permute for 0 or 1 elements. This also avoids the size_t
    // underflow of `len - 1` at len == 0, which ran the loop far past the array.
    if (len < 2)
    {
        return;
    }
    size_t i = 0; // bit-reversed counterpart of j, maintained incrementally
    for (size_t j = 1; j < len - 1; j++)
    {
        // Increment i in reversed-bit arithmetic: flip bits from the top down
        // until a 0 bit has been turned into a 1.
        size_t k = len >> 1;
        i ^= k;
        while (k > i)
        {
            k >>= 1;
            i ^= k;
        }
        if (j < i) // swap each pair only once
        {
            std::swap(ary[i], ary[j]);
        }
    }
}
// Radix-2 butterfly in place: (sum, diff) <- (sum + diff, sum - diff).
template <typename T>
inline void fft_2point(T &sum, T &diff)
{
    const T lhs = sum;
    const T rhs = diff;
    sum = lhs + rhs;
    diff = lhs - rhs;
}
// 4-point forward DFT butterfly on input[0], input[rank], input[2*rank] and
// input[3*rank] (twiddle -i), written back to the same slots.
inline void fft_4point(Complex *input, size_t rank = 1)
{
    Complex *const p0 = input;
    Complex *const p1 = p0 + rank;
    Complex *const p2 = p1 + rank;
    Complex *const p3 = p2 + rank;
    Complex even_sum = *p0, odd_sum = *p1, even_diff = *p2, odd_diff = *p3;
    fft_2point(even_sum, even_diff);                          // x0 +/- x2
    fft_2point(odd_sum, odd_diff);                            // x1 +/- x3
    odd_diff = Complex(odd_diff.imag(), -odd_diff.real());    // times -i
    *p0 = even_sum + odd_sum;
    *p1 = even_diff + odd_diff;
    *p2 = even_sum - odd_sum;
    *p3 = even_diff - odd_diff;
}
// Forward 4-point DIT FFT of input[0..3] in place (AVX, two complex numbers
// per register). Expects bit-reversed input (x0, x2, x1, x3) and writes
// X0..X3 in natural order; the forward twiddle is -i.
inline void fft_dit_4point_avx(Complex *input)
{
// Sign mask for lane 3 only (imag part of the upper complex).
static const __m256d neg_mask = _mm256_castsi256_pd(
_mm256_set_epi64x(INT64_MIN, 0, 0, 0));
__m256d tmp0 = _mm256_loadu_pd(reinterpret_cast<double *>(input)); // c0,c1
__m256d tmp1 = _mm256_loadu_pd(reinterpret_cast<double *>(input + 2)); // c2,c3
__m256d tmp2 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20); // c0,c2
__m256d tmp3 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31); // c1,c3
tmp0 = _mm256_add_pd(tmp2, tmp3); // c0+c1,c2+c3
tmp1 = _mm256_sub_pd(tmp2, tmp3); // c0-c1,c2-c3
tmp2 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20); // c0+c1,c0-c1;(A,B)
tmp3 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31); // c2+c3,c2-c3
// Swap re/im of the upper complex only (low lane unchanged).
tmp3 = _mm256_permute_pd(tmp3, 0b0110);
// Negating the new imag part completes (c2-c3) * -i: (C,D) = (c2+c3, -i*(c2-c3)).
tmp3 = _mm256_xor_pd(tmp3, neg_mask); // (C,D)
tmp0 = _mm256_add_pd(tmp2, tmp3); // A+C,B+D = X0,X1
tmp1 = _mm256_sub_pd(tmp2, tmp3); // A-C,B-D = X2,X3
_mm256_storeu_pd(reinterpret_cast<double *>(input), tmp0);
_mm256_storeu_pd(reinterpret_cast<double *>(input + 2), tmp1);
}
// Forward 8-point DIT FFT of input[0..7] in place (AVX). Expects bit-reversed
// input and writes X0..X7 in natural order. Structure:
// - two 4-point DIT FFTs (same shuffles as fft_dit_4point_avx): E on
//   input[0..3] and O on input[4..7];
// - one radix-2 stage: X_k = E_k + w^k O_k, X_{k+4} = E_k - w^k O_k, with
//   w = e^(-2*pi*i/8).
inline void fft_dit_8point_avx(Complex *input)
{
// Sign mask for lane 3 only (imag part of the upper complex).
static const __m256d neg_mask = _mm256_castsi256_pd(_mm256_set_epi64x(INT64_MIN, 0, 0, 0));
// Lanes (0, 0, h, h), h = cos(pi/4): keeps only the scaled upper complex.
static const __m256d mul1 = _mm256_set_pd(0.70710678118654752440084436210485, 0.70710678118654752440084436210485, 0, 0);
// Lanes (1, -1, -h, -h): completes the -i and w^3 twiddles below.
static const __m256d mul2 = _mm256_set_pd(-0.70710678118654752440084436210485, -0.70710678118654752440084436210485, -1, 1);
__m256d tmp0 = _mm256_loadu_pd(reinterpret_cast<double *>(input)); // c0,c1
__m256d tmp1 = _mm256_loadu_pd(reinterpret_cast<double *>(input + 2)); // c2,c3
__m256d tmp2 = _mm256_loadu_pd(reinterpret_cast<double *>(input + 4)); // c4,c5
__m256d tmp3 = _mm256_loadu_pd(reinterpret_cast<double *>(input + 6)); // c6,c7
__m256d tmp4 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20); // c0,c2
__m256d tmp5 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31); // c1,c3
__m256d tmp6 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20); // c4,c6
__m256d tmp7 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31); // c5,c7
tmp0 = _mm256_add_pd(tmp4, tmp5); // c0+c1,c2+c3
tmp1 = _mm256_sub_pd(tmp4, tmp5); // c0-c1,c2-c3
tmp2 = _mm256_add_pd(tmp6, tmp7); // c4+c5,c6+c7
tmp3 = _mm256_sub_pd(tmp6, tmp7); // c4-c5,c6-c7
tmp4 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20); // c0+c1,c0-c1;(A,B)
tmp5 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31); // c2+c3,c2-c3
tmp6 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20); // c4+c5,c4-c5;(A,B)
tmp7 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31); // c6+c7,c6-c7
// Swap the upper complex's re/im, then negate: upper complex times -i.
tmp5 = _mm256_permute_pd(tmp5, 0b0110);
tmp5 = _mm256_xor_pd(tmp5, neg_mask); // (C,D)
tmp7 = _mm256_permute_pd(tmp7, 0b0110);
tmp7 = _mm256_xor_pd(tmp7, neg_mask); // (C,D)
tmp0 = _mm256_add_pd(tmp4, tmp5); // A+C,B+D = E0,E1
tmp1 = _mm256_sub_pd(tmp4, tmp5); // A-C,B-D = E2,E3
tmp2 = _mm256_add_pd(tmp6, tmp7); // A+C,B+D = O0,O1
tmp3 = _mm256_sub_pd(tmp6, tmp7); // A-C,B-D = O2,O3
// 2X4point-done: both 4-point FFTs finished, now apply twiddles to O.
// tmp2 = (O0, O1) -> (O0, O1 * w), where O1 * w = h * (re + im, im - re).
tmp6 = _mm256_permute_pd(tmp2, 0b0110);
tmp6 = _mm256_addsub_pd(tmp6, tmp2);
tmp6 = _mm256_permute_pd(tmp6, 0b0110);
tmp6 = _mm256_mul_pd(tmp6, mul1);
tmp2 = _mm256_blend_pd(tmp2, tmp6, 0b1100);
// tmp3 = (O2, O3) -> (O2 * w^2, O3 * w^3):
// O2 * -i = (im, -re), O3 * w^3 = -h * (re - im, im + re).
tmp7 = _mm256_permute_pd(tmp3, 0b0101);
tmp3 = _mm256_addsub_pd(tmp3, tmp7);
tmp3 = _mm256_blend_pd(tmp7, tmp3, 0b1100);
tmp3 = _mm256_mul_pd(tmp3, mul2);
// Radix-2 merge.
tmp4 = _mm256_add_pd(tmp0, tmp2); // X0,X1
tmp5 = _mm256_add_pd(tmp1, tmp3); // X2,X3
tmp6 = _mm256_sub_pd(tmp0, tmp2); // X4,X5
tmp7 = _mm256_sub_pd(tmp1, tmp3); // X6,X7
_mm256_storeu_pd(reinterpret_cast<double *>(input), tmp4);
_mm256_storeu_pd(reinterpret_cast<double *>(input + 2), tmp5);
_mm256_storeu_pd(reinterpret_cast<double *>(input + 4), tmp6);
_mm256_storeu_pd(reinterpret_cast<double *>(input + 6), tmp7);
}
// Forward 4-point DIF FFT of input[0..3] in place (AVX). Takes natural-order
// input and writes the outputs in bit-reversed order (X0, X2, X1, X3).
inline void fft_dif_4point_avx(Complex *input)
{
__m256d tmp0 = _mm256_loadu_pd(reinterpret_cast<double *>(input)); // c0,c1
__m256d tmp1 = _mm256_loadu_pd(reinterpret_cast<double *>(input + 2)); // c2,c3
__m256d tmp2 = _mm256_add_pd(tmp0, tmp1); // c0+c2,c1+c3;
__m256d tmp3 = _mm256_sub_pd(tmp0, tmp1); // c0-c2,c1-c3;
tmp3 = _mm256_permute_pd(tmp3, 0b0110); // c0-c2,r(c1-c3); (upper complex re/im swapped)
// Sign mask for lane 3; with the swap above this multiplies c1-c3 by -i.
static const __m256d neg_mask = _mm256_castsi256_pd(
_mm256_set_epi64x(INT64_MIN, 0, 0, 0));
tmp3 = _mm256_xor_pd(tmp3, neg_mask);
tmp0 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20); // A,C = c0+c2, c0-c2
tmp1 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31); // B,D = c1+c3, -i*(c1-c3)
tmp2 = _mm256_add_pd(tmp0, tmp1); // A+B,C+D = X0,X1
tmp3 = _mm256_sub_pd(tmp0, tmp1); // A-B,C-D = X2,X3
tmp0 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20); // X0,X2
tmp1 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31); // X1,X3
_mm256_storeu_pd(reinterpret_cast<double *>(input), tmp0);
_mm256_storeu_pd(reinterpret_cast<double *>(input + 2), tmp1);
}
// Forward 8-point DIF FFT of input[0..7] in place (AVX). Takes natural-order
// input and writes bit-reversed output. Structure:
// - first stage: e_k = c_k + c_{k+4} and d_k = (c_k - c_{k+4}) * w^k,
//   with w = e^(-2*pi*i/8) and k = 0..3;
// - a 4-point DIF on e (-> input[0..3]) and on d (-> input[4..7]).
inline void fft_dif_8point_avx(Complex *input)
{
// Sign mask for lane 3 only (imag part of the upper complex).
static const __m256d neg_mask = _mm256_castsi256_pd(_mm256_set_epi64x(INT64_MIN, 0, 0, 0));
// Lanes (0, 0, h, h), h = cos(pi/4).
static const __m256d mul1 = _mm256_set_pd(0.70710678118654752440084436210485, 0.70710678118654752440084436210485, 0, 0);
// Lanes (1, -1, -h, -h).
static const __m256d mul2 = _mm256_set_pd(-0.70710678118654752440084436210485, -0.70710678118654752440084436210485, -1, 1);
__m256d tmp0 = _mm256_loadu_pd(reinterpret_cast<double *>(input)); // c0,c1
__m256d tmp1 = _mm256_loadu_pd(reinterpret_cast<double *>(input + 2)); // c2,c3
__m256d tmp2 = _mm256_loadu_pd(reinterpret_cast<double *>(input + 4)); // c4,c5
__m256d tmp3 = _mm256_loadu_pd(reinterpret_cast<double *>(input + 6)); // c6,c7
__m256d tmp4 = _mm256_add_pd(tmp0, tmp2); // e0,e1
__m256d tmp5 = _mm256_add_pd(tmp1, tmp3); // e2,e3
__m256d tmp6 = _mm256_sub_pd(tmp0, tmp2); // c0-c4,c1-c5
__m256d tmp7 = _mm256_sub_pd(tmp1, tmp3); // c2-c6,c3-c7
// tmp6 -> (d0, d1) with d1 = (c1-c5) * w = h * (re + im, im - re).
tmp2 = _mm256_permute_pd(tmp6, 0b0110);
tmp2 = _mm256_addsub_pd(tmp2, tmp6);
tmp2 = _mm256_permute_pd(tmp2, 0b0110);
tmp2 = _mm256_mul_pd(tmp2, mul1);
tmp6 = _mm256_blend_pd(tmp6, tmp2, 0b1100);
// tmp7 -> (d2, d3): d2 = (c2-c6) * -i, d3 = (c3-c7) * w^3.
tmp3 = _mm256_permute_pd(tmp7, 0b0101);
tmp7 = _mm256_addsub_pd(tmp7, tmp3);
tmp7 = _mm256_blend_pd(tmp3, tmp7, 0b1100);
tmp7 = _mm256_mul_pd(tmp7, mul2);
// 2X4point: 4-point DIF on (tmp4, tmp5) and on (tmp6, tmp7).
tmp0 = _mm256_add_pd(tmp4, tmp5);
tmp1 = _mm256_sub_pd(tmp4, tmp5);
tmp1 = _mm256_permute_pd(tmp1, 0b0110);
tmp1 = _mm256_xor_pd(tmp1, neg_mask); // upper complex times -i
tmp2 = _mm256_add_pd(tmp6, tmp7);
tmp3 = _mm256_sub_pd(tmp6, tmp7);
tmp3 = _mm256_permute_pd(tmp3, 0b0110);
tmp3 = _mm256_xor_pd(tmp3, neg_mask); // upper complex times -i
tmp4 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20);
tmp5 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31);
tmp6 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20);
tmp7 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31);
tmp0 = _mm256_add_pd(tmp4, tmp5);
tmp1 = _mm256_sub_pd(tmp4, tmp5);
tmp2 = _mm256_add_pd(tmp6, tmp7);
tmp3 = _mm256_sub_pd(tmp6, tmp7);
// Interleave so each half is stored in bit-reversed order.
tmp4 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20);
tmp5 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31);
tmp6 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20);
tmp7 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31);
_mm256_storeu_pd(reinterpret_cast<double *>(input), tmp4);
_mm256_storeu_pd(reinterpret_cast<double *>(input + 2), tmp5);
_mm256_storeu_pd(reinterpret_cast<double *>(input + 4), tmp6);
_mm256_storeu_pd(reinterpret_cast<double *>(input + 6), tmp7);
}
// Split-radix DIT butterfly (scalar).
// input[0] and input[rank] come from the half-length transform. The two
// quarter-length outputs at input[2*rank] and input[3*rank] are twiddled by
// omega and omega_cube, then combined with them.
inline void fft_split_radix_dit_butterfly(Complex omega, Complex omega_cube,
                                          Complex *input, size_t rank)
{
    Complex *const q0 = input;
    Complex *const q1 = q0 + rank;
    Complex *const q2 = q1 + rank;
    Complex *const q3 = q2 + rank;
    const Complex even0 = *q0;
    const Complex even1 = *q1;
    Complex odd_sum = *q2 * omega;
    Complex odd_diff = *q3 * omega_cube;
    fft_2point(odd_sum, odd_diff);
    odd_diff = Complex(odd_diff.imag(), -odd_diff.real()); // times -i
    *q0 = even0 + odd_sum;
    *q1 = even1 + odd_diff;
    *q2 = even0 - odd_sum;
    *q3 = even1 - odd_diff;
}
// Split-radix DIF butterfly (scalar), the mirror of the DIT version.
// Sums stay in input[0] and input[rank]. The differences are combined, then
// scaled by omega and omega_cube into input[2*rank] and input[3*rank].
inline void fft_split_radix_dif_butterfly(Complex omega, Complex omega_cube,
                                          Complex *input, size_t rank)
{
    Complex a = input[0], b = input[rank];
    Complex c = input[rank * 2], d = input[rank * 3];
    fft_2point(a, c);
    fft_2point(b, d);
    d = Complex(d.imag(), -d.real()); // times -i
    input[0] = a;
    input[rank] = b;
    input[rank * 2] = (c + d) * omega;
    input[rank * 3] = (c - d) * omega_cube;
}
// fft分裂基时间抽取蝶形变换
inline void fft_split_radix_dit_butterfly(const Complex2 &omega, const Complex2 &omega_cube,
Complex *input, size_t rank)
{
Complex2 tmp0 = input;
Complex2 tmp1 = input + rank;
Complex2 tmp2 = Complex2(input + rank * 2) * omega;
Complex2 tmp3 = Complex2(input + rank * 3) * omega_cube;
fft_2point(tmp2, tmp3);
tmp3 = tmp3.mul_neg_i();
(tmp0 + tmp2).store(input);
(tmp1 + tmp3).store(input + rank);
(tmp0 - tmp2).store(input + rank * 2);
(tmp1 - tmp3).store(input + rank * 3);
}
// fft分裂基时间抽取蝶形变换
inline void fft_split_radix_dit_butterfly(const Complex *omega, const Complex *omega_cube,
Complex *input, size_t rank)
{
Complex2 tmp0 = input;
Complex2 tmp4 = input + 2;
Complex2 tmp1 = input + rank;
Complex2 tmp5 = input + rank + 2;
Complex2 tmp2 = Complex2(input + rank * 2) * Complex2(omega);
Complex2 tmp6 = Complex2(input + rank * 2 + 2) * Complex2(omega + 2);
Complex2 tmp3 = Complex2(input + rank * 3) * Complex2(omega_cube);
Complex2 tmp7 = Complex2(input + rank * 3 + 2) * Complex2(omega_cube + 2);
fft_2point(tmp2, tmp3);
fft_2point(tmp6, tmp7);
tmp3 = tmp3.mul_neg_i();
tmp7 = tmp7.mul_neg_i();
(tmp0 + tmp2).store(input);
(tmp4 + tmp6).store(input + 2);
(tmp1 + tmp3).store(input + rank);
(tmp5 + tmp7).store(input + rank + 2);
(tmp0 - tmp2).store(input + rank * 2);
(tmp4 - tmp6).store(input + rank * 2 + 2);
(tmp1 - tmp3).store(input + rank * 3);
(tmp5 - tmp7).store(input + rank * 3 + 2);
}
// fft分裂基频率抽取蝶形变换
inline void fft_split_radix_dif_butterfly(const Complex *omega, const Complex *omega_cube,
Complex *input, size_t rank)
{
Complex2 tmp0 = input;
Complex2 tmp4 = input + 2;
Complex2 tmp1 = input + rank;
Complex2 tmp5 = input + rank + 2;
Complex2 tmp2 = input + rank * 2;
Complex2 tmp6 = input + rank * 2 + 2;
Complex2 tmp3 = input + rank * 3;
Complex2 tmp7 = input + rank * 3 + 2;
fft_2point(tmp0, tmp2);
fft_2point(tmp1, tmp3);
fft_2point(tmp4, tmp6);
fft_2point(tmp5, tmp7);
tmp3 = tmp3.mul_neg_i();
tmp7 = tmp7.mul_neg_i();
tmp0.store(input);
tmp4.store(input + 2);
tmp1.store(input + rank);
tmp5.store(input + rank + 2);
((tmp2 + tmp3) * Complex2(omega)).store(input + rank * 2);
((tmp6 + tmp7) * Complex2(omega + 2)).store(input + rank * 2 + 2);
((tmp2 - tmp3) * Complex2(omega_cube)).store(input + rank * 3);
((tmp6 - tmp7) * Complex2(omega_cube + 2)).store(input + rank * 3 + 2);
}
// fft分裂基频率抽取蝶形变换
inline void fft_split_radix_dif_butterfly(const Complex2 &omega, const Complex2 &omega_cube,
Complex *input, size_t rank)
{
Complex2 tmp0 = (input);
Complex2 tmp1 = (input + rank);
Complex2 tmp2 = (input + rank * 2);
Complex2 tmp3 = (input + rank * 3);
fft_2point(tmp0, tmp2);
fft_2point(tmp1, tmp3);
tmp3 = tmp3.mul_neg_i();
tmp0.store(input);
tmp1.store(input + rank);
((tmp2 + tmp3) * omega).store(input + rank * 2);
((tmp2 - tmp3) * omega_cube).store(input + rank * 3);
}
// Conjugates every element and divides it by div (helper for inverse FFT).
inline void fft_conj(Complex *input, size_t fft_len, HintFloat div = 1)
{
    Complex *const end = input + fft_len;
    for (Complex *it = input; it != end; ++it)
    {
        *it = std::conj(*it) / div;
    }
}
// Divides every element by fft_len (normalization for inverse FFT).
inline void fft_normalize(Complex *input, size_t fft_len)
{
    const HintFloat scale = static_cast<HintFloat>(fft_len);
    Complex *const end = input + fft_len;
    for (Complex *it = input; it != end; ++it)
    {
        *it /= scale;
    }
}
// Twiddle constants for the 16-point split-radix specializations, with
// w = e^(-2*pi*i/16): cos_1_8 = cos(pi/4), cos_1_16 = cos(pi/8), sin_1_16 = sin(pi/8).
static constexpr HintFloat cos_1_8 = 0.70710678118654752440084436210485;
static constexpr HintFloat cos_1_16 = 0.92387953251128675612818318939679;
static constexpr HintFloat sin_1_16 = 0.3826834323650897717284599840304;
// w1 = w^1, w3 = w^3, w9 = w^9.
static constexpr Complex w1(cos_1_16, -sin_1_16), w3(sin_1_16, -cos_1_16), w9(-cos_1_16, sin_1_16);
// omega1_table[n] = w^n and omega3_table[n] = w^(3n), for n = 0..3.
static constexpr Complex omega1_table[4] = {Complex(1), w1, Complex(cos_1_8, -cos_1_8), w3};
static constexpr Complex omega3_table[4] = {Complex(1), w3, Complex(-cos_1_8, -cos_1_8), w9};
// The same tables packed two per register for the AVX butterflies:
// omega0 = (w^0, w^1), omega1 = (w^2, w^3); omega_cu0/omega_cu1 likewise for w^(3n).
static const Complex2 omega0(omega1_table), omega1(omega1_table + 2);
static const Complex2 omega_cu0(omega3_table), omega_cu1(omega3_table + 2);
// Split-radix decimation-in-time FFT of compile-time length LEN (forward,
// w = e^(-2*pi*i/LEN)). Expects bit-reversed input: the first half holds the
// even samples, the third and fourth quarters the 1 mod 4 and 3 mod 4 samples.
// Each part is transformed recursively, then merged by split-radix butterflies.
template <size_t LEN>
void fft_split_radix_dit_template(Complex *input)
{
    constexpr size_t half_len = LEN / 2;
    constexpr size_t quarter_len = LEN / 4;
    Complex *const odd1 = input + half_len;
    Complex *const odd3 = odd1 + quarter_len;
    fft_split_radix_dit_template<half_len>(input);
    fft_split_radix_dit_template<quarter_len>(odd1);
    fft_split_radix_dit_template<quarter_len>(odd3);
#if TABLE_ENABLE == 1
    constexpr size_t log_len = hint_log2(LEN);
    // Four butterflies per step, twiddles read from the shared table.
    for (size_t i = 0; i < quarter_len; i += 4)
    {
        const Complex *omega = TABLE.get_omega_ptr(log_len, i);
        const Complex *omega_cube = TABLE.get_omega3_ptr(log_len, i);
        fft_split_radix_dit_butterfly(omega, omega_cube, input + i, quarter_len);
    }
#else
    // Two butterflies per step. The twiddle pairs (w^i, w^(i+1)) and
    // (w^(3i), w^(3i+3)) are advanced by w^2 and w^6 after every step.
    static const Complex unit1 = std::conj(unit_root(LEN, 1));
    static const Complex unit3 = std::conj(unit_root(LEN, 3));
    static const Complex unit2 = std::conj(unit_root(LEN, 2));
    static const Complex unit6 = std::conj(unit_root(LEN, 6));
    static const Complex2 step(unit2, unit2);
    static const Complex2 step_cube(unit6, unit6);
    Complex2 omega(Complex(1, 0), unit1);
    Complex2 omega_cube(Complex(1, 0), unit3);
    for (size_t i = 0; i < quarter_len; i += 2)
    {
        fft_split_radix_dit_butterfly(omega, omega_cube, input + i, quarter_len);
        omega = omega * step;
        omega_cube = omega_cube * step_cube;
    }
#endif
}
// Lengths 0 and 1: identity.
template <>
void fft_split_radix_dit_template<0>(Complex *) {}
template <>
void fft_split_radix_dit_template<1>(Complex *) {}
template <>
void fft_split_radix_dit_template<2>(Complex *input)
{
    fft_2point(input[0], input[1]);
}
template <>
void fft_split_radix_dit_template<4>(Complex *input)
{
    fft_dit_4point_avx(input);
}
template <>
void fft_split_radix_dit_template<8>(Complex *input)
{
    fft_dit_8point_avx(input);
}
// Length 16: 8-point + two 4-point kernels, then butterflies with the fixed
// twiddle tables omega0/omega1 and omega_cu0/omega_cu1.
template <>
void fft_split_radix_dit_template<16>(Complex *input)
{
    fft_dit_8point_avx(input);
    fft_dit_4point_avx(input + 8);
    fft_dit_4point_avx(input + 12);
    fft_split_radix_dit_butterfly(omega0, omega_cu0, input, 4);
    fft_split_radix_dit_butterfly(omega1, omega_cu1, input + 2, 4);
}
// Split-radix decimation-in-frequency FFT of compile-time length LEN
// (forward, w = e^(-2*pi*i/LEN)).
// Applies the split-radix butterflies first, then recurses on the first half
// and on the two trailing quarters. The result is left in bit-reversed order.
template <size_t LEN>
void fft_split_radix_dif_template(Complex *input)
{
    constexpr size_t half_len = LEN / 2;
    constexpr size_t quarter_len = LEN / 4;
#if TABLE_ENABLE == 1
    constexpr size_t log_len = hint_log2(LEN);
    // Four butterflies per step, twiddles read from the shared table.
    for (size_t i = 0; i < quarter_len; i += 4)
    {
        const Complex *omega = TABLE.get_omega_ptr(log_len, i);
        const Complex *omega_cube = TABLE.get_omega3_ptr(log_len, i);
        fft_split_radix_dif_butterfly(omega, omega_cube, input + i, quarter_len);
    }
#else
    // Two butterflies per step. The twiddle pairs (w^i, w^(i+1)) and
    // (w^(3i), w^(3i+3)) are advanced by w^2 and w^6 after every step.
    static const Complex unit1 = std::conj(unit_root(LEN, 1));
    static const Complex unit3 = std::conj(unit_root(LEN, 3));
    static const Complex unit2 = std::conj(unit_root(LEN, 2));
    static const Complex unit6 = std::conj(unit_root(LEN, 6));
    static const Complex2 step(unit2, unit2);
    static const Complex2 step_cube(unit6, unit6);
    Complex2 omega(Complex(1, 0), unit1);
    Complex2 omega_cube(Complex(1, 0), unit3);
    for (size_t i = 0; i < quarter_len; i += 2)
    {
        fft_split_radix_dif_butterfly(omega, omega_cube, input + i, quarter_len);
        omega = omega * step;
        omega_cube = omega_cube * step_cube;
    }
#endif
    Complex *const odd1 = input + half_len;
    Complex *const odd3 = odd1 + quarter_len;
    fft_split_radix_dif_template<half_len>(input);
    fft_split_radix_dif_template<quarter_len>(odd1);
    fft_split_radix_dif_template<quarter_len>(odd3);
}
// Lengths 0 and 1: identity.
template <>
void fft_split_radix_dif_template<0>(Complex *) {}
template <>
void fft_split_radix_dif_template<1>(Complex *) {}
template <>
void fft_split_radix_dif_template<2>(Complex *input)
{
    fft_2point(input[0], input[1]);
}
template <>
void fft_split_radix_dif_template<4>(Complex *input)
{
    fft_dif_4point_avx(input);
}
template <>
void fft_split_radix_dif_template<8>(Complex *input)
{
    fft_dif_8point_avx(input);
}
// Length 16: butterflies with the fixed twiddle tables first, then the
// 8-point and two 4-point kernels.
template <>
void fft_split_radix_dif_template<16>(Complex *input)
{
    fft_split_radix_dif_butterfly(omega0, omega_cu0, input, 4);
    fft_split_radix_dif_butterfly(omega1, omega_cu1, input + 2, 4);
    fft_dif_8point_avx(input);
    fft_dif_4point_avx(input + 8);
    fft_dif_4point_avx(input + 12);
}
// Runtime -> compile-time length dispatch. Starting from LEN, halve LEN until
// LEN <= fft_len, then run the fixed-size transform of length LEN. For a
// power-of-two fft_len <= the starting LEN this picks LEN == fft_len.
template <size_t LEN = 1>
void fft_split_radix_dit_template_alt(Complex *input, size_t fft_len)
{
    if (fft_len >= LEN)
    {
        fft_split_radix_dit_template<LEN>(input);
    }
    else
    {
        fft_split_radix_dit_template_alt<LEN / 2>(input, fft_len);
    }
}
// Recursion end, reached only for fft_len == 0: nothing to transform.
template <>
void fft_split_radix_dit_template_alt<0>(Complex *, size_t) {}
// Same dispatch for the decimation-in-frequency transform.
template <size_t LEN = 1>
void fft_split_radix_dif_template_alt(Complex *input, size_t fft_len)
{
    if (fft_len >= LEN)
    {
        fft_split_radix_dif_template<LEN>(input);
    }
    else
    {
        fft_split_radix_dif_template_alt<LEN / 2>(input, fft_len);
    }
}
template <>
void fft_split_radix_dif_template_alt<0>(Complex *, size_t) {}
// Runtime entry points; the largest length they dispatch to is 1 << lut_max_rank.
auto fft_split_radix_dit = fft_split_radix_dit_template_alt<size_t(1) << lut_max_rank>;
auto fft_split_radix_dif = fft_split_radix_dif_template_alt<size_t(1) << lut_max_rank>;
/// @brief Decimation-in-time split-radix FFT (forward, w = e^(-2*pi*i/N)), in place.
/// @param input complex array
/// @param fft_len array length; rounded down to a power of two
/// @param bit_rev whether to bit-reverse the input first. Pass false when the
///        data is already in bit-reversed order (e.g. straight out of fft_dif).
/// @throws const char* when the rounded length exceeds 1 << lut_max_rank
inline void fft_dit(Complex *input, size_t fft_len, bool bit_rev = true)
{
    fft_len = max_2pow(fft_len);
    if (fft_len < 2)
    {
        return; // length 0 or 1: identity (and avoids bit-reversing 0 elements)
    }
#if TABLE_ENABLE == 1
    TABLE.expand(hint_log2(fft_len));
#endif
    // The compile-time dispatch tops out at 1 << lut_max_rank. A longer input
    // used to be silently transformed over only its first 2^21 points.
    if (fft_len > (size_t(1) << lut_max_rank))
    {
        throw("FFT length too long\n");
    }
    if (bit_rev)
    {
        binary_reverse_swap(input, fft_len);
    }
    fft_split_radix_dit(input, fft_len);
}
/// @brief Decimation-in-frequency split-radix FFT (forward, w = e^(-2*pi*i/N)), in place.
/// @param input complex array
/// @param fft_len array length; rounded down to a power of two
/// @param bit_rev whether to bit-reverse the output back to natural order.
///        Pass false to keep bit-reversed order (e.g. before an fft_dit call).
/// @throws const char* when the rounded length exceeds 1 << lut_max_rank
inline void fft_dif(Complex *input, size_t fft_len, bool bit_rev = true)
{
    fft_len = max_2pow(fft_len);
    if (fft_len < 2)
    {
        return; // length 0 or 1: identity (and avoids bit-reversing 0 elements)
    }
#if TABLE_ENABLE == 1
    TABLE.expand(hint_log2(fft_len));
#endif
    // The compile-time dispatch tops out at 1 << lut_max_rank. A longer input
    // used to be silently transformed over only its first 2^21 points.
    if (fft_len > (size_t(1) << lut_max_rank))
    {
        throw("FFT length too long\n");
    }
    fft_split_radix_dif(input, fft_len);
    if (bit_rev)
    {
        binary_reverse_swap(input, fft_len);
    }
}
// Formats the lowest `digits` decimal digits of input, left-padded with '0'.
inline std::string ui64to_string(UINT_64 input, UINT_8 digits)
{
    std::string text(digits, '0');
    for (size_t pos = digits; pos > 0; --pos)
    {
        text[pos - 1] = static_cast<char>('0' + input % 10);
        input /= 10;
    }
    return text;
}
// Parses the first `dig` characters of s as a decimal number (no validation).
constexpr UINT_64 stoui64(char *s, size_t dig = 4)
{
    UINT_64 value = 0;
    for (char *p = s, *end = s + dig; p != end; ++p)
    {
        value = value * 10 + (*p - '0');
    }
    return value;
}
constexpr UINT_64 BASE = 10;
// Parses exactly DIG decimal characters starting at s, unrolled at compile
// time: s[0] * BASE^(DIG-1) plus the value of the remaining DIG-1 characters.
template <size_t DIG>
constexpr UINT_64 stoui64(char *s)
{
    return (*s - '0') * qpow(BASE, DIG - 1) + stoui64<DIG - 1>(s + 1);
}
// Recursion end: zero characters parse to 0.
template <>
constexpr UINT_64 stoui64<0>(char *)
{
    return 0;
}
// Parses exactly 4 decimal characters at s (no validation), Horner form.
constexpr UINT_64 stobase10000(char *s)
{
    return (((s[0] - '0') * 10 + (s[1] - '0')) * 10 + (s[2] - '0')) * 10 + (s[3] - '0');
}
// Parses exactly 5 decimal characters at s (no validation), Horner form.
constexpr UINT_64 stobase100000(char *s)
{
    return ((((s[0] - '0') * 10 + (s[1] - '0')) * 10 + (s[2] - '0')) * 10 + (s[3] - '0')) * 10 + (s[4] - '0');
}
static constexpr INT_64 DIGIT = 4;
inline size_t char_to_real(char *buffer, Complex *comary, size_t str_len)
{
hint::INT_64 len = str_len, pos = len, i = 0;
len = (len + DIGIT - 1) / DIGIT;
while (pos - DIGIT > 0)
{
hint::UINT_64 tmp = stoui64<DIGIT>(buffer + pos - DIGIT);
// hint::UINT_64 tmp = stobase10000(buffer + pos - DIGIT);
comary[i].real(tmp);
i++;
pos -= DIGIT;
}
if (pos > 0)
{
hint::UINT_64 tmp = stoui64(buffer, pos);
comary[i].real(tmp);
}
return len;
}
inline size_t char_to_imag(char *buffer, Complex *comary, size_t str_len)
{
hint::INT_64 len = str_len, pos = len, i = 0;
len = (len + DIGIT - 1) / DIGIT;
while (pos - DIGIT > 0)
{
hint::UINT_64 tmp = stoui64<DIGIT>(buffer + pos - DIGIT);
// hint::UINT_64 tmp = stobase10000(buffer + pos - DIGIT);
comary[i].imag(tmp);
i++;
pos -= DIGIT;
}
if (pos > 0)
{
hint::UINT_64 tmp = stoui64(buffer, pos);
comary[i].imag(tmp);
}
return len;
}
// Writes num as exactly DIGIT decimal characters into s (zero-padded, no
// terminator). Digits above the lowest DIGIT are dropped.
inline void num_to_s(char *s, UINT_64 num)
{
    for (int i = DIGIT - 1; i >= 0; --i)
    {
        s[i] = static_cast<char>('0' + num % 10);
        num /= 10;
    }
}
// Writes the lowest 4 decimal digits of num into s[0..3] (zero-padded, no terminator).
constexpr void num_to_s_base10000(char *s, UINT_64 num)
{
    for (int i = 3; i >= 0; --i, num /= 10)
    {
        s[i] = '0' + num % 10;
    }
}
// Writes the lowest 5 decimal digits of num into s[0..4] (zero-padded, no terminator).
constexpr void num_to_s_base100000(char *s, UINT_64 num)
{
    for (int i = 4; i >= 0; --i, num /= 10)
    {
        s[i] = '0' + num % 10;
    }
}
}
}
using namespace std;
using namespace hint;
using namespace hint_transform;
void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c)
{
size_t len1 = n + 1, len2 = m + 1, out_len = len1 + len2 - 1;
size_t fft_len = min_2pow(out_len);
// static Complex2 avx_ary[1 << 20];
// Complex2 *avx_ary = new Complex2[fft_len / 2];
static Complex fft_ary[1 << 21];
com_ary_combine_copy(fft_ary, a, len1, b, len2);
fft_dif(fft_ary, fft_len, false); // 优化FFT
HintFloat inv = -0.5 / fft_len;
Complex2 invx4(inv);
for (size_t i = 0; i < fft_len; i += 2)
{
// Complex tmp = fft_ary[i];
// fft_ary[i] = std::conj(tmp * tmp) * inv;
Complex2 tmp = fft_ary + i;
tmp = tmp * tmp.linear_mul(invx4);
(tmp.conj()).store(fft_ary + i);
}
fft_dit(fft_ary, fft_len, false); // 优化FFT
for (size_t i = 0; i < out_len; i++)
{
c[i] = unsigned(fft_ary[i].imag() + 0.5);
}
}
// Buffered writer for space-separated output. Tokens are appended to one
// preallocated char buffer and emitted with a single puts() call.
// NOTE: writes are not bounds-checked; max_len must cover every character written.
class QPrint
{
private:
    char *data = nullptr; // owned buffer, NUL-terminated at data[pos]
    size_t pos = 0;       // number of characters written so far

public:
    QPrint(size_t max_len)
    {
        // One extra byte so the terminator always fits, even for max_len == 0.
        data = new char[max_len + 1];
        // put()/c_str() before any write must see an empty string, not garbage.
        data[0] = '\0';
    }
    // The buffer is owned through a raw pointer, so a copy would double-free it.
    QPrint(const QPrint &) = delete;
    QPrint &operator=(const QPrint &) = delete;
    ~QPrint()
    {
        delete[] data;
    }
    // Appends n in decimal, preceded by a space unless it is the first token.
    void operator<<(uint64_t n)
    {
        if (pos != 0)
        {
            data[pos] = ' ';
            pos++;
        }
        if (n == 0)
        {
            data[pos] = '0';
            pos++;
            // Previously missing: a trailing 0 left the buffer unterminated.
            data[pos] = '\0';
            return;
        }
        // Count the digits first, then fill them in from the right.
        size_t digs = pos;
        uint64_t tmp = n;
        while (n > 0)
        {
            n /= 10;
            digs++;
        }
        pos = digs;
        while (tmp > 0)
        {
            digs--;
            data[digs] = tmp % 10 + '0';
            tmp /= 10;
        }
        data[pos] = '\0';
    }
    // Appends s, preceded by a space unless it is the first token.
    void operator<<(const std::string &s)
    {
        if (pos != 0)
        {
            data[pos] = ' ';
            pos++;
        }
        memcpy(data + pos, s.data(), s.size());
        pos += s.size();
        data[pos] = '\0';
    }
    // The text buffered so far (NUL-terminated).
    const char *c_str() const
    {
        return data;
    }
    // Writes the buffered text plus a newline to stdout.
    void put() const
    {
        puts(data);
    }
};
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it. Returns 0 if EOF is reached before a digit
// (the previous version spun forever on EOF).
inline int ReadNum()
{
    int ch = getchar();
    while (ch != EOF && (ch < '0' || '9' < ch))
    {
        ch = getchar();
    }
    int res = 0;
    // EOF is negative, so it also ends this loop.
    while ('0' <= ch && ch <= '9')
    {
        res = res * 10 + (ch - '0');
        ch = getchar();
    }
    return res;
}
int main()
{
size_t m = 4, n = 4;
m = ReadNum();
n = ReadNum();
size_t len1 = m + 1, len2 = n + 1;
QPrint qout(20000000);
size_t fft_len = min_2pow(len1 + len2 - 1);
Complex *fft_ary = new Complex[fft_len];
for (size_t i = 0; i < len1; i++)
{
int c = ReadNum();
fft_ary[i].real(c);
}
for (size_t i = 0; i < len2; i++)
{
int c = ReadNum();
fft_ary[i].imag(c);
}
fft_dif(fft_ary, fft_len, false);
HintFloat inv = -0.5 / fft_len;
Complex2 invx4(inv);
for (size_t i = 0; i < fft_len; i += 2)
{
Complex2 tmp = fft_ary + i;
tmp = tmp * tmp.linear_mul(invx4);
(tmp.conj()).store(fft_ary + i);
}
fft_dit(fft_ary, fft_len, false);
for (size_t i = 0; i < len1 + len2 - 1; i++)
{
qout << int(fft_ary[i].imag() + 0.5);
}
// delete[] fft_ary;
qout.put();
}
// int main()
// {
// constexpr size_t STR_LEN = 4005;
// constexpr uint64_t BASE = hint::qpow(10ull, DIGIT);
// static char out[STR_LEN];
// // static Complex2 avx_ary[1 << 18];
// static Complex fft_ary[1 << 10];
// size_t len_a = 0, len_b = 0;
// scanf("%s", out);
// while (isdigit(out[len_a]))
// {
// len_a++;
// }
// if (len_a == 1 && out[0] == '0')
// {
// printf("0");
// return 0;
// }
// size_t len1 = char_to_real(out, fft_ary, len_a);
// scanf("%s", out);
// while (isdigit(out[len_b]))
// {
// len_b++;
// }
// if (len_b == 1 && out[0] == '0')
// {
// printf("0");
// return 0;
// }
// size_t len2 = char_to_imag(out, fft_ary, len_b);
// size_t fft_len = min_2pow(len1 + len2 - 1);
// fft_dif(fft_ary, fft_len, false); // 优化FFT
// HintFloat inv = -0.5 / fft_len;
// Complex2 invx4(inv);
// for (size_t i = 0; i < fft_len; i++)
// {
// Complex tmp = fft_ary[i];
// fft_ary[i] = std::conj(tmp * tmp) * inv;
// // Complex2 tmp = fft_ary + i;
// // tmp = tmp * tmp;
// // (tmp.conj()).store(fft_ary + i);
// }
// fft_dit(fft_ary, fft_len, false); // 优化FFT
// UINT_64 carry = 0;
// size_t pos = STR_LEN - 1;
// for (size_t i = 0; i < len1 + len2 - 1; i++)
// {
// carry += UINT_64(fft_ary[i].imag() + 0.5);
// UINT_64 num = 0;
// std::tie(carry, num) = div_mod<UINT_64>(carry, BASE);
// // num_to_s_base10000(out + pos - DIGIT, num);
// num_to_s(out + pos - DIGIT, num);
// pos -= DIGIT;
// }
// num_to_s(out + pos - DIGIT, carry);
// pos -= DIGIT;
// while (out[pos] == '0')
// {
// pos++;
// }
// puts(out + pos);
// }
// Online-judge results for this submission (pasted output, kept as a comment
// so the file still compiles):
// Compilation | N/A | N/A | Compile OK | Score: N/A | 显示更多 |
// Subtask #1 Testcase #1 | 38.54 us | 48 KB | Accepted | Score: 0 | 显示更多 |
// Subtask #1 Testcase #2 | 8.528 ms | 6 MB + 908 KB | Accepted | Score: 100 | 显示更多 |
// Subtask #1 Testcase #3 | 3.32 ms | 2 MB + 612 KB | Accepted | Score: 0 | 显示更多 |
// Subtask #1 Testcase #4 | 3.411 ms | 2 MB + 588 KB | Accepted | Score: 0 | 显示更多 |
// Subtask #1 Testcase #5 | 38.63 us | 48 KB | Accepted | Score: 0 | 显示更多 |
// Subtask #1 Testcase #6 | 38.12 us | 48 KB | Accepted | Score: 0 | 显示更多 |
// Subtask #1 Testcase #7 | 36.87 us | 48 KB | Accepted | Score: 0 | 显示更多 |
// Subtask #1 Testcase #8 | 7.792 ms | 6 MB + 372 KB | Accepted | Score: 0 | 显示更多 |
// Subtask #1 Testcase #9 | 7.777 ms | 6 MB + 372 KB | Accepted | Score: 0 | 显示更多 |
// Subtask #1 Testcase #10 | 7.081 ms | 5 MB + 860 KB | Accepted | Score: 0 | 显示更多 |
// Subtask #1 Testcase #11 | 8.624 ms | 7 MB + 48 KB | Accepted | Score: 0 | 显示更多 |
// Subtask #1 Testcase #12 | 5.84 ms | 4 MB + 828 KB | Accepted | Score: 0 | 显示更多 |
// Subtask #1 Testcase #13 | 38.49 us | 48 KB | Accepted | Score: 0 | 显示更多 |