// TSKY 2025/6/20
#include <vector>
#include <array>
#include <complex>
#include <iostream>
#include <chrono>
#include <string>
#include <bitset>
#include <type_traits>
#include <cstdint>
#include <cfloat>
#include <cmath>
#include <ctime>
#include <cstring>
#include <cassert>
#include <climits>
#include <cstdio>
#include <cctype>
#include <utility>
#include <immintrin.h>
#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP
#pragma GCC target("avx")
#pragma GCC target("fma")
#pragma GCC target("avx2")
namespace hint_simd
{
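// Fixed-size array aligned to a 4096-byte (page) boundary, with typed and reinterpreted access helpers.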
template <typename T, size_t LEN>
class AlignAry
{
private:
alignas(4096) T ary[LEN];
public:
constexpr AlignAry() {}
constexpr T &operator[](size_t index)
{
return ary[index];
}
constexpr const T &operator[](size_t index) const
{
return ary[index];
}
T *data()
{
return reinterpret_cast<T *>(ary);
}
const T *data() const
{
return reinterpret_cast<const T *>(ary);
}
template <typename Ty>
Ty *cast_ptr()
{
return reinterpret_cast<Ty *>(ary);
}
template <typename Ty>
const Ty *cast_ptr() const
{
return reinterpret_cast<const Ty *>(ary);
}
};
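// 64-bit-lane transposes built from unpacklo/unpackhi and permute2f128.
// transpose64_2X4 interleaves two registers (r0 r1 r2 r3 / i0 i1 i2 i3 -> r0 i0 r1 i1 / r2 i2 r3 i3),
// i.e. SoA -> AoS for four complex doubles; transpose64_4X2 is the inverse (AoS -> SoA).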
template <typename YMM>
inline void transpose64_2X4(YMM &row0, YMM &row1)
{
auto t0 = _mm256_unpacklo_pd(__m256d(row0), __m256d(row1)); // 0,1,2,3 4,5,6,7 -> 0,4,2,6
auto t1 = _mm256_unpackhi_pd(__m256d(row0), __m256d(row1)); // 0,1,2,3 4,5,6,7 -> 1,5,3,7
row0 = YMM(_mm256_permute2f128_pd(t0, t1, 0x20)); // 0,4,2,6 1,5,3,7 -> 0,4,1,5
row1 = YMM(_mm256_permute2f128_pd(t0, t1, 0x31)); // 0,4,2,6 1,5,3,7 -> 2,6,3,7
}
template <typename YMM>
inline void transpose64_4X2(YMM &row0, YMM &row1)
{
auto t0 = _mm256_permute2f128_pd(__m256d(row0), __m256d(row1), 0x20); // 0,1,2,3 4,5,6,7 -> 0,1,4,5
auto t1 = _mm256_permute2f128_pd(__m256d(row0), __m256d(row1), 0x31); // 0,1,2,3 4,5,6,7 -> 2,3,6,7
row0 = YMM(_mm256_unpacklo_pd(t0, t1)); // 0,1,4,5 2,3,6,7 -> 0,4,2,6
row1 = YMM(_mm256_unpackhi_pd(t0, t1)); // 0,1,4,5 2,3,6,7 -> 1,5,3,7
}
template <typename YMM>
inline void transpose64_8X4(YMM &row0, YMM &row1, YMM &row2, YMM &row3,
YMM &row4, YMM &row5, YMM &row6, YMM &row7)
{
auto t0 = _mm256_unpacklo_pd(__m256d(row0), __m256d(row1)); // 0,1,2,3 4,5,6,7 -> 0,4,2,6
auto t1 = _mm256_unpackhi_pd(__m256d(row0), __m256d(row1)); // 0,1,2,3 4,5,6,7 -> 1,5,3,7
auto t2 = _mm256_unpacklo_pd(__m256d(row2), __m256d(row3)); // 8,9,10,11 12,13,14,15 -> 8,12,10,14
auto t3 = _mm256_unpackhi_pd(__m256d(row2), __m256d(row3)); // 8,9,10,11 12,13,14,15 -> 9,13,11,15
auto t4 = _mm256_unpacklo_pd(__m256d(row4), __m256d(row5)); // 16,17,18,19 20,21,22,23 -> 16,20,18,22
auto t5 = _mm256_unpackhi_pd(__m256d(row4), __m256d(row5)); // 16,17,18,19 20,21,22,23 -> 17,21,19,23
auto t6 = _mm256_unpacklo_pd(__m256d(row6), __m256d(row7)); // 24,25,26,27 28,29,30,31 -> 24,28,26,30
auto t7 = _mm256_unpackhi_pd(__m256d(row6), __m256d(row7)); // 24,25,26,27 28,29,30,31 -> 25,29,27,31
row0 = YMM(_mm256_permute2f128_pd(t0, t2, 0x20));
row1 = YMM(_mm256_permute2f128_pd(t4, t6, 0x20));
row2 = YMM(_mm256_permute2f128_pd(t1, t3, 0x20));
row3 = YMM(_mm256_permute2f128_pd(t5, t7, 0x20));
row4 = YMM(_mm256_permute2f128_pd(t0, t2, 0x31));
row5 = YMM(_mm256_permute2f128_pd(t4, t6, 0x31));
row6 = YMM(_mm256_permute2f128_pd(t1, t3, 0x31));
row7 = YMM(_mm256_permute2f128_pd(t5, t7, 0x31));
}
template <typename YMM>
inline void transpose64_4X8(YMM &row0, YMM &row1, YMM &row2, YMM &row3,
YMM &row4, YMM &row5, YMM &row6, YMM &row7)
{
auto t0 = _mm256_unpacklo_pd(__m256d(row0), __m256d(row2)); // 0,1,2,3 8,9,10,11 -> 0,8,2,10
auto t1 = _mm256_unpackhi_pd(__m256d(row0), __m256d(row2)); // 0,1,2,3 8,9,10,11 -> 1,9,3,11
auto t2 = _mm256_unpacklo_pd(__m256d(row4), __m256d(row6)); // 16,17,18,19 24,25,26,27 -> 16,24,18,26
auto t3 = _mm256_unpackhi_pd(__m256d(row4), __m256d(row6)); // 16,17,18,19 24,25,26,27 -> 17,25,19,27
auto t4 = _mm256_unpacklo_pd(__m256d(row1), __m256d(row3)); // 4,5,6,7 12,13,14,15 -> 4,12,6,14
auto t5 = _mm256_unpackhi_pd(__m256d(row1), __m256d(row3)); // 4,5,6,7 12,13,14,15 -> 5,13,7,15
auto t6 = _mm256_unpacklo_pd(__m256d(row5), __m256d(row7)); // 20,21,22,23 28,29,30,31 -> 20,28,22,30
auto t7 = _mm256_unpackhi_pd(__m256d(row5), __m256d(row7)); // 20,21,22,23 28,29,30,31 -> 21,29,23,31
row0 = YMM(_mm256_permute2f128_pd(t0, t2, 0x20));
row1 = YMM(_mm256_permute2f128_pd(t1, t3, 0x20));
row2 = YMM(_mm256_permute2f128_pd(t0, t2, 0x31));
row3 = YMM(_mm256_permute2f128_pd(t1, t3, 0x31));
row4 = YMM(_mm256_permute2f128_pd(t4, t6, 0x20));
row5 = YMM(_mm256_permute2f128_pd(t5, t7, 0x20));
row6 = YMM(_mm256_permute2f128_pd(t4, t6, 0x31));
row7 = YMM(_mm256_permute2f128_pd(t5, t7, 0x31));
}
template <typename YMM>
inline void transpose64_4X4(YMM &row0, YMM &row1, YMM &row2, YMM &row3)
{
auto t0 = _mm256_unpacklo_pd(__m256d(row0), __m256d(row1)); // 0,1,2,3 4,5,6,7 -> 0,4,2,6
auto t1 = _mm256_unpackhi_pd(__m256d(row0), __m256d(row1)); // 0,1,2,3 4,5,6,7 -> 1,5,3,7
auto t2 = _mm256_unpacklo_pd(__m256d(row2), __m256d(row3)); // 8,9,10,11 12,13,14,15 -> 8,12,10,14
auto t3 = _mm256_unpackhi_pd(__m256d(row2), __m256d(row3)); // 8,9,10,11 12,13,14,15 -> 9,13,11,15
row0 = YMM(_mm256_permute2f128_pd(t0, t2, 0x20));
row1 = YMM(_mm256_permute2f128_pd(t1, t3, 0x20));
row2 = YMM(_mm256_permute2f128_pd(t0, t2, 0x31));
row3 = YMM(_mm256_permute2f128_pd(t1, t3, 0x31));
}
class Float64X4
{
public:
using F64 = double;
using F64X4 = Float64X4;
Float64X4() : data(_mm256_setzero_pd()) {}
Float64X4(__m256d in_data) : data(in_data) {}
Float64X4(F64 in_data) : data(_mm256_set1_pd(in_data)) {}
Float64X4(const F64 *in_data) : data(_mm256_load_pd(in_data)) {}
F64X4 operator+(const F64X4 &other) const
{
return _mm256_add_pd(data, other.data);
}
F64X4 operator-(const F64X4 &other) const
{
return _mm256_sub_pd(data, other.data);
}
F64X4 operator*(const F64X4 &other) const
{
return _mm256_mul_pd(data, other.data);
}
F64X4 operator/(const F64X4 &other) const
{
return _mm256_div_pd(data, other.data);
}
F64X4 &operator+=(const F64X4 &other)
{
return *this = *this + other;
}
F64X4 &operator-=(const F64X4 &other)
{
return *this = *this - other;
}
F64X4 &operator*=(const F64X4 &other)
{
return *this = *this * other;
}
F64X4 &operator/=(const F64X4 &other)
{
return *this = *this / other;
}
F64X4 floor() const
{
return _mm256_floor_pd(data);
}
// a * b + c
static F64X4 fmadd(const F64X4 &a, const F64X4 &b, const F64X4 &c)
{
return _mm256_fmadd_pd(a.data, b.data, c.data);
}
// a * b - c
static F64X4 fmsub(const F64X4 &a, const F64X4 &b, const F64X4 &c)
{
return _mm256_fmsub_pd(a.data, b.data, c.data);
}
template <int N>
F64X4 permute4x64() const
{
return _mm256_permute4x64_pd(data, N);
}
static F64X4 extractEven64X4(const F64X4 &in0, const F64X4 &in1)
{
F64X4 result = _mm256_unpacklo_pd(in0.data, in1.data); // 0,1,2,3 4,5,6,7 -> 0,4,2,6
return result.permute4x64<0b11011000>(); // 0,4,2,6 -> 0,2,4,6
}
template <int N>
F64X4 permute() const
{
return _mm256_permute_pd(data, N);
}
F64X4 reverse() const
{
return permute4x64<0b00011011>();
}
void load(const F64 *p)
{
data = _mm256_load_pd(p);
}
void loadu(const F64 *p)
{
data = _mm256_loadu_pd(p);
}
void load1(const F64 *p)
{
data = _mm256_broadcast_sd(p);
}
static F64X4 fromMem(const F64 *p)
{
return _mm256_load_pd(p);
}
static F64X4 fromUMem(const F64 *p)
{
return _mm256_loadu_pd(p);
}
void store(F64 *p) const
{
_mm256_store_pd(p, data);
}
void storeu(F64 *p) const
{
_mm256_storeu_pd(p, data);
}
operator __m256d() const
{
return data;
}
// Convert positive double to int64
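// Assumes every lane is non-negative and below 2^53: take the 52-bit mantissa, restore the implicit
// leading bit, then shift right by (1023 + 52 - exponent), which truncates the value to an integer.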
__m256i toI64X4() const
{
constexpr uint64_t mask = (uint64_t(1) << 52) - 1;
constexpr uint64_t offset = (uint64_t(1) << 10) - 1;
const __m256i f64bits = _mm256_castpd_si256(data);
__m256i tail = _mm256_and_si256(f64bits, _mm256_set1_epi64x(mask));
tail = _mm256_or_si256(tail, _mm256_set1_epi64x(mask + 1));
__m256i exp = _mm256_srli_epi64(f64bits, 52);
exp = _mm256_sub_epi64(_mm256_set1_epi64x(offset + 52), exp);
return _mm256_srlv_epi64(tail, exp);
}
template <int N>
F64 nthEle() const
{
union F64I64
{
int64_t i64;
F64 f64;
} temp;
temp.i64 = _mm256_extract_epi64(__m256i(data), N);
return temp.f64;
}
void print() const
{
std::cout << "[" << nthEle<0>() << "," << nthEle<1>()
<< "," << nthEle<2>() << "," << nthEle<3>() << "]" << std::endl;
}
private:
__m256d data;
};
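// Four complex doubles in SoA form: one register of real parts and one of imaginary parts ("RRII" layout).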
struct Complex64X4
{
using C64X4 = Complex64X4;
using F64X4 = Float64X4;
using F64 = double;
Complex64X4() {}
Complex64X4(F64X4 real, F64X4 imag) : real(real), imag(imag) {}
Complex64X4(const F64 *p) : real(p), imag(p + 4) {}
Complex64X4(const F64 *p_real, const F64 *p_imag) : real(p_real), imag(p_imag) {}
C64X4 operator+(const C64X4 &other) const
{
return C64X4(real + other.real, imag + other.imag);
}
C64X4 operator-(const C64X4 &other) const
{
return C64X4(real - other.real, imag - other.imag);
}
C64X4 mul(const C64X4 &other) const
{
const F64X4 ii = imag * other.imag;
const F64X4 ri = real * other.imag;
const F64X4 r = F64X4::fmsub(real, other.real, ii);
const F64X4 i = F64X4::fmadd(imag, other.real, ri);
return C64X4(r, i);
}
C64X4 mulConj(const C64X4 &other) const
{
const F64X4 ii = imag * other.imag;
const F64X4 ri = real * other.imag;
const F64X4 r = F64X4::fmadd(real, other.real, ii);
const F64X4 i = F64X4::fmsub(imag, other.real, ri);
return C64X4(r, i);
}
// exp{i*theta*k},k in {0,1,2,3}
static C64X4 omegaSeq0To3(F64 theta, F64 begin = 0)
{
alignas(32) F64 real_arr[4] = {cos(begin), cos(theta + begin), cos(2 * theta + begin), cos(3 * theta + begin)};
alignas(32) F64 imag_arr[4] = {sin(begin), sin(theta + begin), sin(2 * theta + begin), sin(3 * theta + begin)};
return C64X4(F64X4(real_arr), F64X4(imag_arr));
}
template <typename T>
void load(const T *p, std::false_type)
{
this->load(p);
}
// From RIRI permutation
template <typename T>
void load(const T *p, std::true_type)
{
this->load(p);
*this = this->toRRIIPermu();
}
template <typename T>
void load(const T *p)
{
real.load(reinterpret_cast<const F64 *>(p));
imag.load(reinterpret_cast<const F64 *>(p) + 4);
}
template <typename T>
void loadu(const T *p)
{
real.loadu(reinterpret_cast<const F64 *>(p));
imag.loadu(reinterpret_cast<const F64 *>(p) + 4);
}
void load1(const F64 *real_p, const F64 *imag_p)
{
real.load1(real_p);
imag.load1(imag_p);
}
template <typename T>
void store(T *p, std::false_type) const
{
this->store(p);
}
// To RIRI permutation
template <typename T>
void store(T *p, std::true_type) const
{
this->toRIRIPermu().store(p);
}
template <typename T>
void store(T *p) const
{
real.store(reinterpret_cast<F64 *>(p));
imag.store(reinterpret_cast<F64 *>(p) + 4);
}
template <typename T>
void storeu(T *p) const
{
real.storeu(reinterpret_cast<F64 *>(p));
imag.storeu(reinterpret_cast<F64 *>(p) + 4);
}
C64X4 square() const
{
const F64X4 ii = imag * imag;
const F64X4 ri = real * imag;
const F64X4 r = F64X4::fmsub(real, real, ii);
const F64X4 i = ri + ri;
return C64X4(r, i);
}
C64X4 cube() const
{
const F64X4 rr = real * real;
const F64X4 ii = imag * imag;
const F64X4 rr3 = rr + rr + rr;
const F64X4 ii3 = ii + ii + ii;
const F64X4 r = real * (rr - ii3);
const F64X4 i = imag * (rr3 - ii);
return C64X4(r, i);
}
C64X4 toRIRIPermu() const
{
C64X4 res = *this;
transpose64_2X4(res.real, res.imag);
return res;
}
C64X4 toRRIIPermu() const
{
C64X4 res = *this;
transpose64_4X2(res.real, res.imag);
return res;
}
void print() const
{
alignas(32) F64 real_arr[4]{}, imag_arr[4]{};
real.storeu(real_arr);
imag.storeu(imag_arr);
std::cout << "[(" << real_arr[0] << ", " << imag_arr[0] << "), ("
<< real_arr[1] << ", " << imag_arr[1] << "), ("
<< real_arr[2] << ", " << imag_arr[2] << "), ("
<< real_arr[3] << ", " << imag_arr[3] << ")]" << std::endl;
}
C64X4 transToI64(std::false_type) const
{
return *this;
}
C64X4 transToI64(std::true_type) const
{
constexpr int64_t F1_2 = 4602678819172646912; // magic::bit_cast<int64_t>(0.5);
auto F1_2X4 = F64X4(__m256d(_mm256_set1_epi64x(F1_2)));
auto real_i64 = (real + F1_2X4).toI64X4();
auto imag_i64 = (imag + F1_2X4).toI64X4();
return C64X4(__m256d(real_i64), __m256d(imag_i64));
}
F64X4 real, imag;
};
}
#endif
namespace hint
{
using Float32 = float;
using Float64 = double;
using Complex32 = std::complex<Float32>;
using Complex64 = std::complex<Float64>;
constexpr Float64 HINT_PI = 3.141592653589793238462643;
constexpr Float64 HINT_2PI = HINT_PI * 2;
constexpr size_t FFT_MAX_LEN = size_t(1) << 23;
template <typename T>
constexpr T int_floor2(T n)
{
constexpr int bits = sizeof(n) * 8;
for (int i = 1; i < bits; i *= 2)
{
n |= (n >> i);
}
return (n >> 1) + 1;
}
template <typename T>
constexpr T int_ceil2(T n)
{
constexpr int bits = sizeof(n) * 8;
n--;
for (int i = 1; i < bits; i *= 2)
{
n |= (n >> i);
}
return n + 1;
}
template <typename IntTy>
constexpr bool is_2pow(IntTy n)
{
return (n & (n - 1)) == 0;
}
// Integer logarithm: floor(log2(n))
template <typename T>
constexpr int hint_log2(T n)
{
constexpr int bits = sizeof(n) * 8;
int l = -1, r = bits;
while ((l + 1) != r)
{
int mid = (l + r) / 2;
if ((T(1) << mid) > n)
{
r = mid;
}
else
{
l = mid;
}
}
return l;
}
constexpr int hint_ctz(uint32_t x)
{
int r0 = 31;
x &= (-x);
if (x & 0x55555555)
{
r0 &= ~1;
}
if (x & 0x33333333)
{
r0 &= ~2;
}
if (x & 0x0F0F0F0F)
{
r0 &= ~4;
}
if (x & 0x00FF00FF)
{
r0 &= ~8;
}
if (x & 0x0000FFFF)
{
r0 &= ~16;
}
r0 += (x == 0);
return r0;
}
constexpr int hint_ctz(uint64_t x)
{
int r0 = 63;
x &= (-x);
if (x & 0x5555555555555555)
{
r0 &= ~1; // -1
}
if (x & 0x3333333333333333)
{
r0 &= ~2; // -2
}
if (x & 0x0F0F0F0F0F0F0F0F)
{
r0 &= ~4; // -4
}
if (x & 0x00FF00FF00FF00FF)
{
r0 &= ~8; // -8
}
if (x & 0x0000FFFF0000FFFF)
{
r0 &= ~16; // -16
}
if (x & 0x00000000FFFFFFFF)
{
r0 &= ~32; // -32
}
r0 += (x == 0);
return r0;
}
template <typename T, T N>
struct StaticObject
{
using Type = T;
static constexpr Type value = N;
};
template <size_t N>
using StaticSize = StaticObject<size_t, N>;
template <int N>
using StaticInt = StaticObject<int, N>;
// Namespace for FFT and FFT-like transforms
namespace transform
{
using namespace hint_simd;
template <typename T>
inline void transform2(T &sum, T &diff)
{
T temp0 = sum, temp1 = diff;
sum = temp0 + temp1;
diff = temp0 - temp1;
}
template <typename T>
inline void transform2(const T a, const T b, T &sum, T &diff)
{
sum = a + b;
diff = a - b;
}
// Return the point on the unit circle with argument theta
template <typename FloatTy>
inline auto unit_root(FloatTy theta)
{
return std::polar<FloatTy>(1.0, theta);
}
// Bit-reversal permutation
template <typename It>
void binary_reverse_swap(It begin, It end)
{
const size_t len = end - begin;
// Swap only when the left index is smaller than the right one, to avoid swapping a pair twice
auto smaller_swap = [=](It it_left, It it_right)
{
if (it_left < it_right)
{
std::swap(it_left[0], it_right[0]);
}
};
// Given last, the iterator at the bit-reversal of i, return the iterator at the bit-reversal of i+1
auto get_next_bitrev = [=](It last)
{
size_t k = len / 2, indx = last - begin;
indx ^= k;
while (k > indx)
{
k >>= 1;
indx ^= k;
};
return begin + indx;
};
// Plain bit-reversal for short lengths
if (len <= 16)
{
for (auto i = begin + 1, j = begin + len / 2; i < end - 1; i++)
{
smaller_swap(i, j);
j = get_next_bitrev(j);
}
return;
}
const size_t len_8 = len / 8;
const auto last = begin + len_8;
auto i0 = begin + 1, i1 = i0 + len / 2, i2 = i0 + len / 4, i3 = i1 + len / 4;
for (auto j = begin + len / 2; i0 < last; i0++, i1++, i2++, i3++)
{
smaller_swap(i0, j);
smaller_swap(i1, j + 1);
smaller_swap(i2, j + 2);
smaller_swap(i3, j + 3);
smaller_swap(i0 + len_8, j + 4);
smaller_swap(i1 + len_8, j + 5);
smaller_swap(i2 + len_8, j + 6);
smaller_swap(i3 + len_8, j + 7);
j = get_next_bitrev(j);
}
}
// Bit-reversal permutation
template <typename T>
void binary_reverse_swap(T ary, const size_t len)
{
binary_reverse_swap(ary, ary + len);
}
namespace fft
{
using F64 = Float64;
using C64 = std::complex<F64>;
using F64X4 = Float64X4;
using C64X4 = Complex64X4;
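// Twiddle table holding OMEGA_LEN roots exp(-2*pi*i * factor * k / theta_divider), stored as groups of
// `stride` cosines followed by `stride` sines (the RRII layout that Complex64X4::load expects).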
template <typename Float, size_t OMEGA_LEN>
class TableFix
{
alignas(64) std::array<Float, OMEGA_LEN * 2> table;
public:
TableFix(size_t theta_divider, size_t factor, size_t stride)
{
const Float theta = -HINT_2PI * factor / theta_divider;
assert(OMEGA_LEN % stride == 0);
for (size_t begin = 0, index = 0; begin < OMEGA_LEN * 2; begin += stride * 2)
{
for (size_t j = 0; j < stride; j++, index++)
{
table[begin + j] = std::cos(theta * index);
table[begin + j + stride] = std::sin(theta * index);
}
}
}
constexpr const Float &operator[](size_t index) const
{
return table[index];
}
constexpr const Float *getOmegaIt(size_t index) const
{
return &table[index];
}
};
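// Twiddle tables for every power-of-two length from 2^LOG_BEGIN up to 2^LOG_END, with LEN/DIV roots per
// level; each level is built from the previous one by multiplying with exp(-2*pi*i*factor/len) in SIMD.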
template <typename Float, int LOG_BEGIN, int LOG_END, int DIV>
class TableFixMulti
{
static_assert(LOG_END >= LOG_BEGIN);
static_assert(is_2pow(DIV));
static constexpr size_t TABLE_CPX_LEN = (size_t(1) << (LOG_END + 1)) / DIV;
alignas(64) std::array<Float, TABLE_CPX_LEN * 2> table;
public:
TableFixMulti(size_t factor, size_t stride = 4)
{
assert(((size_t(1) << LOG_BEGIN) / DIV) % stride == 0);
initAVXF64(factor, stride);
}
void initAVXF64(size_t factor, size_t stride)
{
assert((std::is_same<Float, Float64>::value));
assert(stride == 4);
size_t len = size_t(1) << LOG_BEGIN, cpx_len = len / DIV;
Float theta = -HINT_2PI * factor / len;
auto it = getBegin(LOG_BEGIN);
for (size_t i = 0; i < cpx_len; i++)
{
it[0] = std::cos(theta * i), it[stride] = std::sin(theta * i);
it += (i % stride == stride - 1 ? stride + 1 : 1);
}
it = getBegin(LOG_BEGIN);
for (int log_len = LOG_BEGIN + 1; log_len <= LOG_END; log_len++)
{
len = size_t(1) << log_len, cpx_len = len / DIV;
theta = -HINT_2PI * factor / len;
auto it = getBegin(log_len), it_last = getBegin(log_len - 1);
C64X4 unit(std::cos(theta), std::sin(theta));
for (auto end = it + cpx_len * 2; it < end; it += 16, it_last += 8)
{
Complex64X4 omega0, omega1;
omega0.load(it_last);
omega1 = omega0.mul(unit);
transpose64_2X4(omega0.real, omega1.real);
transpose64_2X4(omega0.imag, omega1.imag);
omega0.store(it), omega1.store(it + 8);
}
}
}
// void initEndAVXF64(size_t factor, size_t stride)
// {
// const size_t end_len = (size_t(1) << LOG_END), cpx_len = end_len / DIV;
// const Float theta = -HINT_2PI * factor / end_len;
// auto begin = getBegin(LOG_END);
// for (size_t i = 0; i < stride; i++)
// {
// begin[i] = std::cos(theta * i), begin[i + stride] = std::sin(theta * i);
// begin[i + stride * 2] = std::cos(theta * (i + stride)), begin[i + stride * 3] = std::sin(theta * (i + stride));
// }
// begin[0] = 1, begin[stride] = 0;
// auto last = getBegin(LOG_END - 1);
// C64X4 last0, last1;
// last0.load(begin), last1.load(begin + 8);
// last0.real = F64X4::extractEven64X4(last0.real, last1.real);
// last0.imag = F64X4::extractEven64X4(last0.imag, last1.imag);
// last0.store(last);
// last += 8;
// for (size_t len = stride * 2; len < cpx_len; len *= 2)
// {
// const Float angle = theta * len;
// C64X4 unit(std::cos(angle), std::sin(angle));
// auto it = begin + len * 2;
// for (size_t i = 0; i < len * 2; i += 16, last += 8)
// {
// last0.load(&begin[i]), last1.load(&begin[i + 8]);
// last0 = last0.mul(unit);
// last1 = last1.mul(unit);
// last0.store(&it[i]), last1.store(&it[i + 8]);
// last0.real = F64X4::extractEven64X4(last0.real, last1.real);
// last0.imag = F64X4::extractEven64X4(last0.imag, last1.imag);
// last0.store(last);
// }
// }
// }
// void initAVXF64(size_t factor, size_t stride)
// {
// assert((std::is_same<Float, Float64>::value));
// assert(stride == 4);
// initEndAVXF64(factor, stride);
// for (int log_len = LOG_END - 2; log_len >= LOG_BEGIN; log_len--)
// {
// auto it_src = getBegin(log_len + 1), it = getBegin(log_len);
// size_t cpx_len = (size_t(1) << log_len) / DIV;
// for (auto end = it + cpx_len * 2; it < end; it += 8, it_src += 16)
// {
// Complex64X4 omega0, omega1, omega2, omega3;
// omega0.load(it_src), omega1.load(it_src + 8);
// omega0.real = F64X4::extractEven64X4(omega0.real, omega1.real);
// omega0.imag = F64X4::extractEven64X4(omega0.imag, omega1.imag);
// omega0.store(it);
// }
// }
// }
constexpr const Float *getBegin(int log_len) const
{
size_t shift = (size_t(1) << log_len) / DIV;
return &table[shift * 2];
}
constexpr Float *getBegin(int log_len)
{
size_t shift = (size_t(1) << log_len) / DIV;
return &table[shift * 2];
}
};
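// Fixed-size FFT kernels on SoA complex data. The forward transform is a decimation-in-frequency
// radix-4 / split-radix network whose output is left in bit-reversed order; idit* are the matching
// decimation-in-time inverse kernels. Twiddles come from the tables defined below the struct.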
struct FFTFixed
{
static constexpr size_t LOG_MAX = 18;
static constexpr size_t LOG_SHORT = 10;
static constexpr size_t SHORT_LEN = size_t(1) << LOG_SHORT;
static const TableFix<Float64, 4> table_8;
static const TableFix<Float64, 4> table_16_1;
static const TableFix<Float64, 4> table_16_3;
static const TableFix<Float64, 8> table_32_1;
static const TableFix<Float64, 8> table_32_3;
static const TableFixMulti<Float64, 6, LOG_SHORT, 4> multi_table_2;
static const TableFixMulti<Float64, 6, LOG_SHORT, 4> multi_table_3;
static const TableFixMulti<Float64, 6, LOG_MAX, 4> multi_table_1;
static constexpr const Float64 *it8 = &table_8[0];
static constexpr const Float64 *it16_1 = &table_16_1[0];
static constexpr const Float64 *it16_3 = &table_16_3[0];
static constexpr const Float64 *it32_1 = &table_32_1[0];
static constexpr const Float64 *it32_3 = &table_32_3[0];
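// Length-4 DIF FFT butterfly on four complex values passed as separate real/imag variables (the ±i
// twiddle is realized by real/imag swaps); idit4 is the matching inverse (DIT) butterfly.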
template <typename Float>
static void dif4(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
{
transform2(r0, r2);
transform2(i0, i2);
transform2(r1, r3);
transform2(i1, i3);
transform2(r0, r1);
transform2(i0, i1);
transform2(r2, i3);
transform2(i2, r3, r3, i2);
std::swap(i3, r3);
}
template <typename Float>
static void idit4(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
{
transform2(r0, r1);
transform2(i0, i1);
transform2(r2, r3);
transform2(i2, i3);
transform2(r0, r2);
transform2(i0, i2);
transform2(r1, i3, i3, r1);
transform2(i1, r3);
std::swap(i3, r3);
}
template <typename Float>
static void difSplit(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
{
transform2(r0, r2);
transform2(i0, i2);
transform2(r1, r3);
transform2(i1, i3);
transform2(r2, i3);
transform2(i2, r3, r3, i2);
std::swap(i3, r3);
}
template <typename Float>
static void iditSplit(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
{
transform2(r2, r3);
transform2(i2, i3);
transform2(r0, r2);
transform2(i0, i2);
transform2(r1, i3, i3, r1);
transform2(i1, r3);
std::swap(i3, r3);
}
static void dif4x4(F64X4 &r0, F64X4 &i0, F64X4 &r1, F64X4 &i1, F64X4 &r2, F64X4 &i2, F64X4 &r3, F64X4 &i3)
{
transpose64_4X4(r0, r1, r2, r3);
transpose64_4X4(i0, i1, i2, i3);
dif4(r0, i0, r1, i1, r2, i2, r3, i3);
transpose64_4X4(r0, r1, r2, r3);
transpose64_4X4(i0, i1, i2, i3);
}
static void idit4x4(F64X4 &r0, F64X4 &i0, F64X4 &r1, F64X4 &i1, F64X4 &r2, F64X4 &i2, F64X4 &r3, F64X4 &i3)
{
transpose64_4X4(r0, r1, r2, r3);
transpose64_4X4(i0, i1, i2, i3);
idit4(r0, i0, r1, i1, r2, i2, r3, i3);
transpose64_4X4(r0, r1, r2, r3);
transpose64_4X4(i0, i1, i2, i3);
}
static void dif8x2(Complex64X4 &c0, Complex64X4 &c1, Complex64X4 &c2, Complex64X4 &c3, const Complex64X4 &omega)
{
transform2(c0, c1);
transform2(c2, c3);
c1 = c1.mul(omega);
c3 = c3.mul(omega);
dif4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
}
static void idit8x2(Complex64X4 &c0, Complex64X4 &c1, Complex64X4 &c2, Complex64X4 &c3, const Complex64X4 &omega)
{
idit4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
c1 = c1.mulConj(omega);
c3 = c3.mulConj(omega);
transform2(c0, c1);
transform2(c2, c3);
}
static void dif8x2(Float64 in_out[])
{
Complex64X4 c0, c1, c2, c3, omega;
c0.load(in_out), c1.load(in_out + 8), c2.load(in_out + 16), c3.load(in_out + 24), omega.load(it8);
dif8x2(c0, c1, c2, c3, omega);
c0.store(in_out), c1.store(in_out + 8), c2.store(in_out + 16), c3.store(in_out + 24);
}
static void idit8x2(Float64 in_out[])
{
Complex64X4 c0, c1, c2, c3, omega;
c0.load(in_out), c1.load(in_out + 8), c2.load(in_out + 16), c3.load(in_out + 24), omega.load(it8);
idit8x2(c0, c1, c2, c3, omega);
c0.store(in_out), c1.store(in_out + 8), c2.store(in_out + 16), c3.store(in_out + 24);
}
static void dif16(Float64 in_out[])
{
Complex64X4 c0, c1, c2, c3, omega;
c0.load(in_out), c1.load(in_out + 8), c2.load(in_out + 16), c3.load(in_out + 24);
dif4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
omega.load(it8), c1 = c1.mul(omega);
omega.load(it16_1), c2 = c2.mul(omega);
omega.load(it16_3), c3 = c3.mul(omega);
dif4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
c0.store(in_out), c1.store(in_out + 8), c2.store(in_out + 16), c3.store(in_out + 24);
}
static void idit16(Float64 in_out[])
{
Complex64X4 c0, c1, c2, c3, omega;
c0.load(in_out), c1.load(in_out + 8), c2.load(in_out + 16), c3.load(in_out + 24);
idit4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
omega.load(it8), c1 = c1.mulConj(omega);
omega.load(it16_1), c2 = c2.mulConj(omega);
omega.load(it16_3), c3 = c3.mulConj(omega);
idit4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
c0.store(in_out), c1.store(in_out + 8), c2.store(in_out + 16), c3.store(in_out + 24);
}
static void dif32(Float64 in_out[])
{
Complex64X4 c0, c1, c2, c3, omega;
c0.load(in_out), c1.load(in_out + 16), c2.load(in_out + 32), c3.load(in_out + 48);
difSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
omega.load(it32_1), c2 = c2.mul(omega);
omega.load(it32_3), c3 = c3.mul(omega);
c0.store(in_out), c1.store(in_out + 16), c2.store(in_out + 32), c3.store(in_out + 48);
c0.load(in_out + 8), c1.load(in_out + 24), c2.load(in_out + 40), c3.load(in_out + 56); // 1,3,5,7
difSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
omega.load(it32_1 + 8), c2 = c2.mul(omega);
omega.load(it32_3 + 8), c3 = c3.mul(omega);
c0.store(in_out + 8), c1.store(in_out + 24);
c0.load(in_out + 32), c1.load(in_out + 48), omega.load(it8); // 4,6
dif8x2(c0, c2, c1, c3, omega);
c0.store(in_out + 32), c2.store(in_out + 40), c1.store(in_out + 48), c3.store(in_out + 56);
dif16(in_out);
}
static void idit32(Float64 in_out[])
{
Complex64X4 c0, c1, c2, c3, omega;
idit16(in_out);
c0.load(in_out + 32), c1.load(in_out + 40), c2.load(in_out + 48), c3.load(in_out + 56), omega.load(it8); // 4,5,6,7
idit8x2(c0, c1, c2, c3, omega);
c1.store(in_out + 40), c3.store(in_out + 56);
c1.load(in_out), c3.load(in_out + 16);
omega.load(it32_1), c0 = c0.mulConj(omega);
omega.load(it32_3), c2 = c2.mulConj(omega);
iditSplit(c1.real, c1.imag, c3.real, c3.imag, c0.real, c0.imag, c2.real, c2.imag);
c1.store(in_out), c3.store(in_out + 16), c0.store(in_out + 32), c2.store(in_out + 48);
c0.load(in_out + 8), c1.load(in_out + 24), c2.load(in_out + 40), c3.load(in_out + 56);
omega.load(it32_1 + 8), c2 = c2.mulConj(omega);
omega.load(it32_3 + 8), c3 = c3.mulConj(omega);
iditSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
c0.store(in_out + 8), c1.store(in_out + 24), c2.store(in_out + 40), c3.store(in_out + 56);
}
static void difRec(Float64 in_out[], StaticInt<LOG_SHORT>)
{
difIter<LOG_SHORT>(in_out);
}
static void iditRec(Float64 in_out[], StaticInt<LOG_SHORT>)
{
iditIter<LOG_SHORT>(in_out);
}
static void difRec(Float64 in_out[], StaticInt<LOG_SHORT - 1>)
{
difIter<LOG_SHORT - 1>(in_out);
}
static void iditRec(Float64 in_out[], StaticInt<LOG_SHORT - 1>)
{
iditIter<LOG_SHORT - 1>(in_out);
}
static void difRec(Float64 in_out[], StaticInt<4>)
{
dif16(in_out);
}
static void iditRec(Float64 in_out[], StaticInt<4>)
{
idit16(in_out);
}
static void difRec(Float64 in_out[], StaticInt<5>)
{
dif32(in_out);
}
static void iditRec(Float64 in_out[], StaticInt<5>)
{
idit32(in_out);
}
template <bool FROM_RIRI_PERM = false, int LOG_N>
static void difRec(Float64 in_out[], StaticInt<LOG_N>)
{
using FromRIRI = std::integral_constant<bool, FROM_RIRI_PERM>;
constexpr size_t LEN = size_t(1) << LOG_N;
constexpr size_t STRIDE1 = LEN / 2, STRIDE2 = STRIDE1 * 2, STRIDE3 = STRIDE1 * 3;
auto table1 = multi_table_1.getBegin(LOG_N);
for (auto end = in_out + STRIDE1, it = in_out; it < end; it += 8, table1 += 8)
{
Complex64X4 c0, c1, c2, c3, omega;
c0.load(it, FromRIRI{}), c1.load(it + STRIDE1, FromRIRI{}), c2.load(it + STRIDE2, FromRIRI{}), c3.load(it + STRIDE3, FromRIRI{});
difSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
omega.load(table1), c2 = c2.mul(omega);
c3 = c3.mul(omega.cube()); // third-power twiddle obtained by cubing omega instead of a separate table
c0.store(it), c1.store(it + STRIDE1), c2.store(it + STRIDE2), c3.store(it + STRIDE3);
}
difRec(in_out, StaticInt<LOG_N - 1>{});
difRec(in_out + STRIDE2, StaticInt<LOG_N - 2>{});
difRec(in_out + STRIDE3, StaticInt<LOG_N - 2>{});
}
template <bool TO_RIRI_PERM = false, bool TO_INT64 = false, int LOG_N>
static void iditRec(Float64 in_out[], StaticInt<LOG_N>)
{
using ToRIRI = std::integral_constant<bool, TO_RIRI_PERM>;
using ToI64 = std::integral_constant<bool, TO_INT64>;
constexpr size_t LEN = size_t(1) << LOG_N;
constexpr size_t STRIDE1 = LEN / 2, STRIDE2 = STRIDE1 * 2, STRIDE3 = STRIDE1 * 3;
iditRec(in_out, StaticInt<LOG_N - 1>{});
iditRec(in_out + STRIDE2, StaticInt<LOG_N - 2>{});
iditRec(in_out + STRIDE3, StaticInt<LOG_N - 2>{});
auto table1 = multi_table_1.getBegin(LOG_N);
for (auto end = in_out + STRIDE1, it = in_out; it < end; it += 8, table1 += 8)
{
Complex64X4 c0, c1, c2, c3, omega;
c0.load(it), c1.load(it + STRIDE1), c2.load(it + STRIDE2), c3.load(it + STRIDE3);
omega.load(table1), c2 = c2.mulConj(omega);
c3 = c3.mulConj(omega.cube()); // third-power twiddle obtained by cubing omega instead of a separate table
iditSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
c0 = c0.transToI64(ToI64{}), c1 = c1.transToI64(ToI64{}), c2 = c2.transToI64(ToI64{}), c3 = c3.transToI64(ToI64{});
c0.store(it, ToRIRI{}), c1.store(it + STRIDE1, ToRIRI{}), c2.store(it + STRIDE2, ToRIRI{}), c3.store(it + STRIDE3, ToRIRI{});
}
}
template <int LOG_N>
static void difIter(Float64 in_out[])
{
constexpr size_t LEN = size_t(1) << LOG_N, FLOAT_LEN = LEN * 2;
assert(LEN >= 64);
int log_rank = LOG_N;
for (; log_rank >= 6; log_rank -= 2)
{
size_t rank2 = size_t(1) << (log_rank + 1), stride1 = rank2 / 4, stride2 = stride1 * 2, stride3 = stride1 * 3;
for (auto begin = in_out, end = in_out + FLOAT_LEN; begin < end; begin += rank2)
{
auto table1 = multi_table_1.getBegin(log_rank), table2 = multi_table_2.getBegin(log_rank), table3 = multi_table_3.getBegin(log_rank);
for (auto it = begin; it < begin + stride1; it += 8, table1 += 8, table2 += 8, table3 += 8)
{
Complex64X4 c0, c1, c2, c3, omega;
c0.load(it), c1.load(it + stride1), c2.load(it + stride2), c3.load(it + stride3);
dif4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
omega.load(table2), c1 = c1.mul(omega);
omega.load(table1), c2 = c2.mul(omega);
omega.load(table3), c3 = c3.mul(omega);
c0.store(it), c1.store(it + stride1), c2.store(it + stride2), c3.store(it + stride3);
}
}
}
constexpr int SMALL_FFT_LOG = LOG_N % 2 == 0 ? 4 : 5;
constexpr size_t SMALL_FFT_LEN = size_t(1) << SMALL_FFT_LOG, STRIDE = SMALL_FFT_LEN * 2;
for (auto it = in_out; it < in_out + FLOAT_LEN; it += STRIDE * 4)
{
using SmallSize = StaticInt<SMALL_FFT_LOG>;
difRec(it, SmallSize{});
difRec(it + STRIDE, SmallSize{});
difRec(it + STRIDE * 2, SmallSize{});
difRec(it + STRIDE * 3, SmallSize{});
}
}
template <int LOG_N>
static void iditIter(Float64 in_out[])
{
const size_t LEN = size_t(1) << LOG_N, FLOAT_LEN = LEN * 2;
assert(LEN >= 64);
constexpr int SMALL_FFT_LOG = LOG_N % 2 == 0 ? 4 : 5;
constexpr size_t SMALL_FFT_LEN = size_t(1) << SMALL_FFT_LOG, STRIDE = SMALL_FFT_LEN * 2;
for (auto it = in_out; it < in_out + FLOAT_LEN; it += STRIDE * 4)
{
using SmallSize = StaticInt<SMALL_FFT_LOG>;
iditRec(it, SmallSize{});
iditRec(it + STRIDE, SmallSize{});
iditRec(it + STRIDE * 2, SmallSize{});
iditRec(it + STRIDE * 3, SmallSize{});
}
int log_rank = SMALL_FFT_LOG + 2;
for (; log_rank <= LOG_N; log_rank += 2)
{
size_t rank2 = size_t(1) << (log_rank + 1), stride1 = rank2 / 4, stride2 = stride1 * 2, stride3 = stride1 * 3;
for (auto begin = in_out, end = in_out + FLOAT_LEN; begin < end; begin += rank2)
{
auto table1 = multi_table_1.getBegin(log_rank), table2 = multi_table_2.getBegin(log_rank), table3 = multi_table_3.getBegin(log_rank);
auto it0 = begin, it1 = begin + stride1, it2 = begin + stride2, it3 = begin + stride3;
for (; it0 < begin + stride1; it0 += 8, it1 += 8, it2 += 8, it3 += 8, table1 += 8, table2 += 8, table3 += 8)
{
Complex64X4 c0, c1, c2, c3, omega;
c0.load(it0), c1.load(it1), c2.load(it2), c3.load(it3);
omega.load(table2), c1 = c1.mulConj(omega);
omega.load(table1), c2 = c2.mulConj(omega);
omega.load(table3), c3 = c3.mulConj(omega);
idit4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
c0.store(it0), c1.store(it1), c2.store(it2), c3.store(it3);
}
}
}
}
};
constexpr size_t FFTFixed::LOG_SHORT;
constexpr size_t FFTFixed::SHORT_LEN;
const TableFix<Float64, 4> FFTFixed::table_8(8, 1, 4);
const TableFix<Float64, 4> FFTFixed::table_16_1(16, 1, 4);
const TableFix<Float64, 4> FFTFixed::table_16_3(16, 3, 4);
const TableFix<Float64, 8> FFTFixed::table_32_1(32, 1, 4);
const TableFix<Float64, 8> FFTFixed::table_32_3(32, 3, 4);
const TableFixMulti<Float64, 6, FFTFixed::LOG_SHORT, 4> FFTFixed::multi_table_2(2);
const TableFixMulti<Float64, 6, FFTFixed::LOG_SHORT, 4> FFTFixed::multi_table_3(3);
const TableFixMulti<Float64, 6, FFTFixed::LOG_MAX, 4> FFTFixed::multi_table_1(1);
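// Produces twiddle factors of a length-2^log_fft_len FFT in bit-reversed index order, four at a time:
// iterate() returns the current group and advances by multiplying with a correction root selected
// through hint_ctz of the running index.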
class BinRevTableC64X4
{
public:
using F64 = double;
using C64 = std::complex<F64>;
using C64X4 = hint_simd::Complex64X4;
static constexpr int MAX_LOG_LEN = CHAR_BIT * sizeof(size_t) - 2;
static constexpr size_t MAX_LEN = size_t(1) << MAX_LOG_LEN;
// Constructed from the log of the maximum iteration count and the log of the maximum FFT length
BinRevTableC64X4(int log_max_iter_in, int log_fft_len_in)
: log_max_iter(log_max_iter_in), log_fft_len(log_fft_len_in)
{
assert(log_max_iter <= log_fft_len);
assert(log_fft_len <= MAX_LOG_LEN);
F64 factor = F64(1) / (size_t(1) << (log_fft_len - log_max_iter));
table[0] = getOmega(2, 1, factor);
table[1] = getOmega(4, 1, -factor);
table[2] = getOmega(8, 1, factor);
for (int i = 3; i < MAX_LOG_LEN; i++)
{
const size_t rev_indx = 1;
const size_t last_indx = ((size_t(1) << i) - 1) << 1;
const size_t shift = size_t(6) << (i - 2);
table[i] = getOmega(size_t(1) << (i + 1), last_indx - rev_indx - shift, -factor);
}
reset();
}
inline void reset(size_t i = 0)
{
auto brev = [this, i](size_t j) -> size_t
{
static const int shift = log_max_iter - 2;
static size_t rev4[4]{0, size_t(2) << shift, size_t(1) << shift, size_t(3) << shift};
if (i == 0)
{
return rev4[j];
}
int log_i = hint_log2(i);
return rev4[j] | ((size_t(1) << (log_max_iter - 1 - log_i)));
};
assert(i % 4 == 0);
alignas(32) F64 omegaX4[8];
for (int j = 0; j < 4; j++)
{
auto omega = getOmega(size_t(1) << log_fft_len, brev(j));
omegaX4[j] = omega.real();
omegaX4[j + 4] = omega.imag();
}
cur = C64X4(omegaX4);
index = i;
}
C64X4 iterate()
{
C64X4 diff, res = cur;
index += 4;
auto p = reinterpret_cast<F64 *>(&table[hint_ctz(index)]);
diff.load1(p, p + 1);
cur = cur.mul(diff);
return res;
}
static C64 getOmega(size_t n, size_t index, F64 factor = 1)
{
const F64 theta = -HINT_2PI * factor * index / n;
return std::polar<F64>(1.0, theta);
}
private:
C64X4 cur;
size_t index;
C64 table[MAX_LOG_LEN];
int log_max_iter, log_fft_len;
};
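// Pointwise spectral product for a real convolution carried out through a half-length complex FFT.
// combine2 recovers each real sequence's spectrum at a pair of mirrored bins (using twiddle omega0),
// the two spectra are multiplied, and the separate2 step packs the product back into the same layout.
// RI_DIFF is the distance between a value's real and imaginary parts in memory.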
template <size_t RI_DIFF = 1, typename FloatTy>
inline void dot_rfft(FloatTy *inout0, FloatTy *inout1, const FloatTy *in0, const FloatTy *in1, const std::complex<FloatTy> &omega0)
{
using Complex = std::complex<FloatTy>;
auto combine2 = [&omega0](auto r0, auto i0, auto r1, auto i1, Complex &out0, Complex &out1)
{
auto tr0 = r0 + r1, ti0 = i0 - i1; // sum
auto tr1 = r0 - r1, ti1 = i0 + i1; // diff
r0 = ti1 * omega0.real() + tr1 * omega0.imag();
i0 = ti1 * omega0.imag() - tr1 * omega0.real();
out0.real(tr0 + r0);
out0.imag(ti0 + i0);
out1.real(tr0 - r0);
out1.imag(i0 - ti0);
};
Complex x0, x1, x2, x3;
auto r0 = inout0[0], i0 = inout0[RI_DIFF], r1 = inout1[0], i1 = inout1[RI_DIFF];
combine2(r0, i0, r1, i1, x0, x1);
r0 = in0[0], i0 = in0[RI_DIFF], r1 = in1[0], i1 = in1[RI_DIFF];
combine2(r0, i0, r1, i1, x2, x3);
x0 *= x2;
x1 *= x3;
{ // separate2
r0 = x0.real(), i0 = x0.imag(), r1 = x1.real(), i1 = x1.imag();
auto tr0 = r0 + r1, ti0 = i0 - i1; // sum
auto tr1 = r0 - r1, ti1 = i0 + i1; // diff
auto r = tr1 * omega0.imag() - ti1 * omega0.real();
auto i = tr1 * omega0.real() + ti1 * omega0.imag();
r0 = tr0 + r;
i0 = ti0 + i;
r1 = tr0 - r;
i1 = i - ti0;
}
inout0[0] = r0, inout0[RI_DIFF] = i0, inout1[0] = r1, inout1[RI_DIFF] = i1;
}
inline void dot_rfftX4(F64 *inout0, F64 *inout1, const F64 *in0, const F64 *in1, const C64X4 &omega0, const F64X4 &inv)
{
auto combine2 = [&omega0](C64X4 c0, C64X4 c1, C64X4 &out0, C64X4 &out1)
{
auto tr0 = c0.real + c1.real, ti0 = c0.imag - c1.imag; // sum
auto tr1 = c0.real - c1.real, ti1 = c0.imag + c1.imag; // diff
c0.real = F64X4::fmadd(ti1, omega0.real, tr1 * omega0.imag);
c0.imag = F64X4::fmsub(ti1, omega0.imag, tr1 * omega0.real);
out0.real = tr0 + c0.real;
out0.imag = ti0 + c0.imag;
out1.real = tr0 - c0.real;
out1.imag = c0.imag - ti0;
};
C64X4 x0, x1;
{
C64X4 x2, x3, x4, x5;
x0.load(inout0), x1.load(inout1);
x1.real = x1.real.reverse();
x1.imag = x1.imag.reverse();
combine2(x0, x1, x2, x3);
x0.load(in0), x1.load(in1);
x1.real = x1.real.reverse();
x1.imag = x1.imag.reverse();
combine2(x0, x1, x4, x5);
x0 = x2.mul(x4);
x1 = x3.mul(x5);
}
{ // separate2
auto tr0 = x0.real + x1.real, ti0 = x0.imag - x1.imag; // sum
auto tr1 = x0.real - x1.real, ti1 = x0.imag + x1.imag; // diff
auto r = F64X4::fmsub(tr1, omega0.imag, ti1 * omega0.real);
auto i = F64X4::fmadd(tr1, omega0.real, ti1 * omega0.imag);
x0.real = (tr0 + r) * inv;
x0.imag = (ti0 + i) * inv;
x1.real = (tr0 - r) * inv;
x1.imag = (i - ti0) * inv;
}
x1.real = x1.real.reverse();
x1.imag = x1.imag.reverse();
x0.store(inout0), x1.store(inout1);
}
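// Spectral multiplication of two half-length FFTs (of real data) whose bins are in bit-reversed order.
// The first bins (packed DC/Nyquist and the short leading blocks) are handled with scalar dot_rfft,
// the remaining mirrored blocks with dot_rfftX4; the result is scaled by 0.125 / len_complex.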
inline void real_conv_binrev4(Float64 in_out[], Float64 in[], size_t len_complex)
{
const F64X4 inv4(0.125 / len_complex);
static BinRevTableC64X4 table(22, 23);
auto t0 = C64(in_out[0], in_out[4]), t1 = C64(in[0], in[4]);
auto t2 = (t0.real() + t0.imag()) * (t1.real() + t1.imag());
auto t3 = (t0.real() - t0.imag()) * (t1.real() - t1.imag());
in_out[0] = (t2 + t3) * 4, in_out[4] = (t2 - t3) * 4;
t0 = C64(in_out[1], in_out[5]) * C64(in[1], in[5]) * 8.0;
in_out[1] = t0.real(), in_out[5] = t0.imag();
dot_rfft<4>(in_out + 2, in_out + 3, in + 2, in + 3, C64(0, -1));
if (len_complex <= 4)
{
return;
}
constexpr Float64 COS_16_1 = 0.92387953251128675612818318939;
constexpr Float64 SIN_16_1 = -0.38268343236508977172845998403;
dot_rfft<4>(in_out + 8, in_out + 11, in + 8, in + 11, C64(COS_16_1, SIN_16_1));
dot_rfft<4>(in_out + 9, in_out + 10, in + 9, in + 10, C64(SIN_16_1, -COS_16_1));
for (size_t i = 0; i < 16; i++)
{
in_out[i] *= (0.125 / len_complex);
}
for (size_t len = 8; len < len_complex; len *= 2)
{
size_t begin = len * 2;
table.reset(len);
auto it0 = in_out + begin, it1 = it0 + begin - 8, it2 = in + begin, it3 = it2 + begin - 8;
for (; it0 < it1; it0 += 8, it1 -= 8, it2 += 8, it3 -= 8)
{
auto omega = table.iterate();
dot_rfftX4(it0, it1, it2, it3, omega, inv4);
}
}
}
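// Real convolution of length 2^LOG_CONV: forward DIF FFT of both inputs at half length (treating the
// real data as packed complex values), bit-reversed spectral multiplication, then the inverse DIT FFT.
// With TO_INT the output lanes are rounded and stored as int64 bit patterns for the caller to read back.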
template <int LOG_CONV = 17, bool TO_INT = true>
inline void real_conv(F64 *in_out1, F64 *in2)
{
static_assert((LOG_CONV - 1) <= FFTFixed::LOG_MAX);
FFTFixed::difRec<true>((double *)(in_out1), StaticInt<LOG_CONV - 1>{});
FFTFixed::difRec<true>((double *)(in2), StaticInt<LOG_CONV - 1>{});
real_conv_binrev4(in_out1, in2, size_t(1) << (LOG_CONV - 1));
FFTFixed::iditRec<true, TO_INT>((double *)(in_out1), StaticInt<LOG_CONV - 1>{});
}
template <bool TO_INT = true>
inline void real_conv(F64 *in_out1, F64 *in2, size_t conv_len)
{
#define CASE_CONV(log_conv) \
case log_conv: \
real_conv<log_conv, TO_INT>(in_out1, in2); \
break;
const int log_conv = hint_log2(conv_len);
switch (log_conv)
{
CASE_CONV(11)
CASE_CONV(12)
CASE_CONV(13)
CASE_CONV(14)
CASE_CONV(15)
CASE_CONV(16)
CASE_CONV(17)
CASE_CONV(18)
CASE_CONV(19)
default:
assert(false && "Unsupported convolution length");
}
#undef CASE_CONV
}
}
}
constexpr uint64_t stoui64(const char *s, size_t dig = 4)
{
uint64_t result = 0;
for (size_t i = 0; i < dig; i++)
{
result *= 10;
result += (s[i] - '0');
}
return result;
}
constexpr uint32_t stobase10000(const char *s)
{
return s[0] * 1000 + s[1] * 100 + s[2] * 10 + s[3] - '0' * 1111;
}
constexpr uint32_t stobase100000(const char *s)
{
return s[0] * 10000 + s[1] * 1000 + s[2] * 100 + s[3] * 10 + s[4] - '0' * 11111;
}
static constexpr int DIGIT = 4;
constexpr uint64_t BASE = 10000;
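// Pack a decimal string into base-10000 limbs, one limb per double, least-significant limb first.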
inline size_t char_to_float64(const char *buffer, double *float_ary, size_t str_len)
{
int64_t len = str_len, pos = len, i = 0;
len = (len + DIGIT - 1) / DIGIT;
while (pos - DIGIT > 0)
{
uint32_t tmp = stobase10000(buffer + pos - DIGIT);
float_ary[i] = tmp;
i++;
pos -= DIGIT;
}
if (pos > 0)
{
uint32_t tmp = stoui64(buffer, pos);
float_ary[i] = tmp;
}
return len;
}
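// Lookup table mapping 0..9999 to its zero-padded 4-character decimal form packed into a uint32_t.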
class ItoStrBase10000
{
private:
uint32_t table[10000]{};
public:
static constexpr uint32_t itosbase10000(uint32_t num)
{
uint32_t res = '0' * 0x1010101;
res += (num / 1000 % 10) | ((num / 100 % 10) << 8) |
((num / 10 % 10) << 16) | ((num % 10) << 24);
return res;
}
constexpr ItoStrBase10000()
{
for (size_t i = 0; i < 10000; i++)
{
table[i] = itosbase10000(i);
}
}
void tostr(char *str, uint32_t num) const
{
*reinterpret_cast<uint32_t *>(str) = table[num];
}
uint32_t tostr(uint32_t num) const
{
return table[num];
}
};
// Parse two numeric strings from the input buffer
void read_2num_str(const char *s, const char *&a, size_t &len_a, const char *&b, size_t &len_b)
{
while (!isdigit(*s))
{
s++;
}
a = s;
while (*s >= '0')
{
s++;
}
len_a = s - a;
while (!isdigit(*s))
{
s++;
}
b = s;
len_b = strlen(b);
while (!isdigit(b[len_b - 1]))
{
len_b--;
}
}
}
using namespace hint;
using namespace transform;
using namespace fft;
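// Big-integer multiplication: read two decimal numbers, pack them into base-10000 limbs (one per double),
// multiply via the real FFT convolution, propagate carries in base 10000 and print the product.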
void test_big_mul()
{
constexpr size_t STR_LEN = 2000008;
constexpr int LOG_LEN = 18;
constexpr size_t MAX_FFT_LEN = size_t(1) << LOG_LEN;
constexpr size_t FLOAT_MAX_LEN = MAX_FFT_LEN * 2;
static constexpr ItoStrBase10000 transfer;
static AlignAry<char, STR_LEN> out;
static AlignAry<Float64, FLOAT_MAX_LEN> ary1;
static AlignAry<Float64, FLOAT_MAX_LEN> ary2;
uint32_t *ary = out.template cast_ptr<uint32_t>();
size_t len_a = 0, len_b = 0;
fread(out.data(), 1, STR_LEN, stdin);
// str_fill(out.data(), (STR_LEN - 8) / 2);
const char *a, *b;
read_2num_str(out.data(), a, len_a, b, len_b);
/*
struct stat sb;
int fd = fileno(stdin);
fstat(fd, &sb);
p = (char *)mmap(0, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
madvise(p, sb.st_size, MADV_SEQUENTIAL);
*/
if (len_a == 1 && a[0] == '0')
{
puts("0");
return;
}
if (len_b == 1 && b[0] == '0')
{
puts("0");
return;
}
size_t len2 = char_to_float64(b, ary2.data(), len_b);
size_t len1 = char_to_float64(a, ary1.data(), len_a);
size_t conv_len = len1 + len2 - 1, len = int_ceil2(conv_len);
real_conv(ary1.data(), ary2.data(), len);
auto i64_ary1 = reinterpret_cast<uint64_t *>(ary1.data());
uint64_t carry = 0;
size_t pos = STR_LEN / 4 - 1;
for (size_t i = 0; i < conv_len; i++)
{
carry += i64_ary1[i];
uint64_t num = carry % BASE;
carry /= BASE;
ary[pos] = transfer.tostr(num);
pos--;
}
ary[pos] = transfer.tostr(carry);
pos *= 4;
while (out[pos] == '0')
{
pos++;
} // 0.8ms
fwrite(out.data() + pos, 1, STR_LEN - pos, stdout);
}
int main()
{
test_big_mul();
}