// TSKY 2025/9/16
#include <algorithm>
#include <array>
#include <cassert>
#include <climits>
#include <cmath>
#include <complex>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <type_traits>
#include <immintrin.h>
#define __FMA__
#define __AVX2__
#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP
#pragma GCC target("avx2")
#pragma GCC target("fma")
namespace hint_simd
{
// Fixed-size array of LEN elements whose storage begins on a 4096-byte
// boundary, so any 32-byte aligned SIMD access at a multiple-of-32 offset works.
template <typename T, size_t LEN>
class AlignAry
{
private:
    alignas(4096) T ary[LEN];

public:
    constexpr AlignAry() {}

    // Element access without bounds checking.
    constexpr T &operator[](size_t index) { return ary[index]; }
    constexpr const T &operator[](size_t index) const { return ary[index]; }

    // Pointer to the first element.
    T *data() { return &ary[0]; }
    const T *data() const { return &ary[0]; }

    // View the storage as Ty elements; the caller must ensure Ty is layout- and
    // alignment-compatible with T.
    template <typename Ty>
    Ty *cast_ptr() { return reinterpret_cast<Ty *>(&ary[0]); }
    template <typename Ty>
    const Ty *cast_ptr() const { return reinterpret_cast<const Ty *>(&ary[0]); }
};
// Split -> interleaved complex layout for two YMM registers:
//   row0 = r0 r1 r2 r3, row1 = i0 i1 i2 i3
//   becomes row0 = r0 i0 r1 i1, row1 = r2 i2 r3 i3.
template <typename YMM>
inline void transpose64_2X4(YMM &row0, YMM &row1)
{
    const __m256d a = __m256d(row0);
    const __m256d b = __m256d(row1);
    const __m256d lo_pairs = _mm256_unpacklo_pd(a, b); // a0 b0 a2 b2
    const __m256d hi_pairs = _mm256_unpackhi_pd(a, b); // a1 b1 a3 b3
    row0 = YMM(_mm256_permute2f128_pd(lo_pairs, hi_pairs, 0x20)); // a0 b0 a1 b1
    row1 = YMM(_mm256_permute2f128_pd(lo_pairs, hi_pairs, 0x31)); // a2 b2 a3 b3
}
// Inverse of transpose64_2X4: interleaved -> split complex layout.
//   row0 = r0 i0 r1 i1, row1 = r2 i2 r3 i3
//   becomes row0 = r0 r1 r2 r3, row1 = i0 i1 i2 i3.
// Lane comments below number the 8 input elements 0..7 (row0 = 0..3, row1 = 4..7).
template <typename YMM>
inline void transpose64_4X2(YMM &row0, YMM &row1)
{
    auto t0 = _mm256_permute2f128_pd(__m256d(row0), __m256d(row1), 0x20); // 0,1,2,3 4,5,6,7 -> 0,1,4,5
    auto t1 = _mm256_permute2f128_pd(__m256d(row0), __m256d(row1), 0x31); // 0,1,2,3 4,5,6,7 -> 2,3,6,7
    row0 = YMM(_mm256_unpacklo_pd(t0, t1));                              // 0,1,4,5 2,3,6,7 -> 0,2,4,6
    row1 = YMM(_mm256_unpackhi_pd(t0, t1));                              // 0,1,4,5 2,3,6,7 -> 1,3,5,7
}
// In-place 4x4 transpose of 64-bit lanes: element j of row i moves to
// element i of row j.
template <typename YMM>
inline void transpose64_4X4(YMM &row0, YMM &row1, YMM &row2, YMM &row3)
{
    const __m256d a = __m256d(row0), b = __m256d(row1);
    const __m256d c = __m256d(row2), d = __m256d(row3);
    // Interleave pairs of rows inside each 128-bit half.
    const __m256d ab_even = _mm256_unpacklo_pd(a, b); // a0 b0 a2 b2
    const __m256d ab_odd = _mm256_unpackhi_pd(a, b);  // a1 b1 a3 b3
    const __m256d cd_even = _mm256_unpacklo_pd(c, d); // c0 d0 c2 d2
    const __m256d cd_odd = _mm256_unpackhi_pd(c, d);  // c1 d1 c3 d3
    // Stitch 128-bit halves together to form the columns.
    row0 = YMM(_mm256_permute2f128_pd(ab_even, cd_even, 0x20)); // a0 b0 c0 d0
    row1 = YMM(_mm256_permute2f128_pd(ab_odd, cd_odd, 0x20));   // a1 b1 c1 d1
    row2 = YMM(_mm256_permute2f128_pd(ab_even, cd_even, 0x31)); // a2 b2 c2 d2
    row3 = YMM(_mm256_permute2f128_pd(ab_odd, cd_odd, 0x31));   // a3 b3 c3 d3
}
// Four packed doubles held in one AVX register (__m256d).
// Pointer-based construction, load/store and fromMem use the aligned
// intrinsics (_mm256_load_pd / _mm256_store_pd) and need 32-byte aligned
// addresses; loadu/storeu/fromUMem accept any alignment.
class Float64X4
{
public:
    using F64 = double;
    using F64X4 = Float64X4;
    // All four lanes 0.0.
    Float64X4() : data(_mm256_setzero_pd()) {}
    // Wrap a raw register (implicit, so intrinsic results convert back to F64X4).
    Float64X4(__m256d in_data) : data(in_data) {}
    // Broadcast one scalar into every lane.
    Float64X4(F64 in_data) : data(_mm256_set1_pd(in_data)) {}
    // Aligned load of in_data[0..3]; in_data must be 32-byte aligned.
    Float64X4(const F64 *in_data) : data(_mm256_load_pd(in_data)) {}
    // Lane-wise arithmetic.
    F64X4 operator+(const F64X4 &other) const
    {
        return _mm256_add_pd(data, other.data);
    }
    F64X4 operator-(const F64X4 &other) const
    {
        return _mm256_sub_pd(data, other.data);
    }
    F64X4 operator*(const F64X4 &other) const
    {
        return _mm256_mul_pd(data, other.data);
    }
    F64X4 operator/(const F64X4 &other) const
    {
        return _mm256_div_pd(data, other.data);
    }
    F64X4 &operator+=(const F64X4 &other)
    {
        return *this = *this + other;
    }
    F64X4 &operator-=(const F64X4 &other)
    {
        return *this = *this - other;
    }
    F64X4 &operator*=(const F64X4 &other)
    {
        return *this = *this * other;
    }
    F64X4 &operator/=(const F64X4 &other)
    {
        return *this = *this / other;
    }
    // Lane-wise floor.
    F64X4 floor() const
    {
        return _mm256_floor_pd(data);
    }
    // a * b + c (single rounding when FMA is available).
    static F64X4 fmadd(const F64X4 &a, const F64X4 &b, const F64X4 &c)
    {
#ifdef __FMA__
        return _mm256_fmadd_pd(a.data, b.data, c.data);
#else
#pragma message("No FMA support")
        return a * b + c;
#endif
    }
    // a * b - c (single rounding when FMA is available).
    static F64X4 fmsub(const F64X4 &a, const F64X4 &b, const F64X4 &c)
    {
#ifdef __FMA__
        return _mm256_fmsub_pd(a.data, b.data, c.data);
#else
#pragma message("No FMA support")
        return a * b - c;
#endif
    }
#ifdef __AVX2__
    // Cross-lane permutation: result lane k = this lane ((N >> 2k) & 3).
    template <int N>
    F64X4 permute4x64() const
    {
        return _mm256_permute4x64_pd(data, N);
    }
#else
    // Scalar fallback with the same lane selection as _mm256_permute4x64_pd.
    template <int N>
    F64X4 permute4x64() const
    {
        alignas(32) uint64_t arr[4];
        alignas(32) uint64_t dst[4];
        this->store(reinterpret_cast<F64 *>(arr));
        dst[0] = arr[(N >> 0) & 3];
        dst[1] = arr[(N >> 2) & 3];
        dst[2] = arr[(N >> 4) & 3];
        dst[3] = arr[(N >> 6) & 3];
        return fromMem(reinterpret_cast<const F64 *>(dst));
    }
#endif
    // Even-indexed elements of the sequence in0[0..3], in1[0..3]:
    // returns (in0[0], in0[2], in1[0], in1[2]).
    static F64X4 extractEven64X4(const F64X4 &in0, const F64X4 &in1)
    {
        F64X4 result = _mm256_unpacklo_pd(in0.data, in1.data); // 0,1,2,3 4,5,6,7 -> 0,4,2,6
        return result.permute4x64<0b11011000>();               // 0,4,2,6 -> 0,2,4,6
    }
    // In-128-bit-lane permutation via _mm256_permute_pd with immediate N.
    template <int N>
    F64X4 permute() const
    {
        return _mm256_permute_pd(data, N);
    }
    // Lanes in reverse order (3, 2, 1, 0).
    F64X4 reverse() const
    {
        return permute4x64<0b00011011>();
    }
    // Aligned load of p[0..3].
    void load(const F64 *p)
    {
        data = _mm256_load_pd(p);
    }
    // Unaligned load of p[0..3].
    void loadu(const F64 *p)
    {
        data = _mm256_loadu_pd(p);
    }
    // Broadcast *p into every lane.
    void load1(const F64 *p)
    {
        data = _mm256_broadcast_sd(p);
    }
    static F64X4 fromMem(const F64 *p)
    {
        return _mm256_load_pd(p);
    }
    static F64X4 fromUMem(const F64 *p)
    {
        return _mm256_loadu_pd(p);
    }
    // Aligned store to p[0..3].
    void store(F64 *p) const
    {
        _mm256_store_pd(p, data);
    }
    // Unaligned store to p[0..3].
    void storeu(F64 *p) const
    {
        _mm256_storeu_pd(p, data);
    }
    operator __m256d() const
    {
        return data;
    }
#ifdef __AVX2__
    // Convert positive double to int64.
    // Truncates by shifting the mantissa (with its implicit leading 1) right by
    // (1075 - biased exponent). Valid for 0 <= x < 2^53; negative, NaN and
    // larger inputs produce meaningless results.
    __m256i toI64X4() const
    {
        constexpr uint64_t mask = (uint64_t(1) << 52) - 1;
        constexpr uint64_t offset = (uint64_t(1) << 10) - 1;
        const __m256i f64bits = _mm256_castpd_si256(data);
        __m256i tail = _mm256_and_si256(f64bits, _mm256_set1_epi64x(mask));
        tail = _mm256_or_si256(tail, _mm256_set1_epi64x(mask + 1));
        __m256i exp = _mm256_srli_epi64(f64bits, 52);
        exp = _mm256_sub_epi64(_mm256_set1_epi64x(offset + 52), exp);
        return _mm256_srlv_epi64(tail, exp);
    }
#else
#pragma message("No AVX2 support")
    // Scalar fallback: truncating conversion of each lane to int64.
    __m256i toI64X4() const
    {
        alignas(32) F64 arr[4];
        alignas(32) int64_t i64_arr[4];
        this->store(arr);
        i64_arr[0] = arr[0];
        i64_arr[1] = arr[1];
        i64_arr[2] = arr[2];
        i64_arr[3] = arr[3];
        return _mm256_load_si256(reinterpret_cast<const __m256i *>(i64_arr));
    }
#endif
    // Lane N (0..3) as a scalar. The bits are moved with memcpy rather than
    // through a union, which is undefined behavior in ISO C++.
    template <int N>
    F64 nthEle() const
    {
        static_assert(N >= 0 && N < 4, "lane index must be in [0, 4)");
        const int64_t bits = _mm256_extract_epi64(_mm256_castpd_si256(data), N);
        F64 result;
        std::memcpy(&result, &bits, sizeof(result));
        return result;
    }
    // Writes "[a,b,c,d]" followed by std::endl to std::cout.
    void print() const
    {
        std::cout << "[" << nthEle<0>() << "," << nthEle<1>()
                  << "," << nthEle<2>() << "," << nthEle<3>() << "]" << std::endl;
    }

private:
    __m256d data;
};
// Four complex doubles in split form: `real` holds the four real parts and
// `imag` the four imaginary parts. The default memory layout is one group of
// 8 doubles: 4 reals followed by 4 imaginaries (RRRRIIII). The std::true_type
// overloads of load/store convert from/to interleaved r0 i0 r1 i1 ... order.
// Pointer loads/stores without a "u" suffix are aligned (32-byte) accesses.
struct Complex64X4
{
    using C64X4 = Complex64X4;
    using F64X4 = Float64X4;
    using F64 = double;
    Complex64X4() {}
    Complex64X4(F64X4 real, F64X4 imag) : real(real), imag(imag) {}
    // Aligned load of one RRRRIIII group starting at p.
    Complex64X4(const F64 *p) : real(p), imag(p + 4) {}
    // Aligned loads of real parts from p_real and imaginary parts from p_imag.
    Complex64X4(const F64 *p_real, const F64 *p_imag) : real(p_real), imag(p_imag) {}
    C64X4 operator+(const C64X4 &other) const
    {
        return C64X4(real + other.real, imag + other.imag);
    }
    C64X4 operator-(const C64X4 &other) const
    {
        return C64X4(real - other.real, imag - other.imag);
    }
    // Scale real and imaginary parts lane-wise by a real vector.
    C64X4 operator*(const F64X4 &other) const
    {
        return C64X4(real * other, imag * other);
    }
    // Lane-wise complex product: this * other.
    C64X4 mul(const C64X4 &other) const
    {
        const F64X4 ii = imag * other.imag;
        const F64X4 ri = real * other.imag;
        const F64X4 r = F64X4::fmsub(real, other.real, ii);
        const F64X4 i = F64X4::fmadd(imag, other.real, ri);
        return C64X4(r, i);
    }
    // Lane-wise product with the conjugate: this * conj(other).
    C64X4 mulConj(const C64X4 &other) const
    {
        const F64X4 ii = imag * other.imag;
        const F64X4 ri = real * other.imag;
        const F64X4 r = F64X4::fmadd(real, other.real, ii);
        const F64X4 i = F64X4::fmsub(imag, other.real, ri);
        return C64X4(r, i);
    }
    // Lanes in reverse order.
    C64X4 reverse() const
    {
        return C64X4(real.reverse(), imag.reverse());
    }
    // exp(i * (begin + theta * k)) for k in {0,1,2,3}.
    static C64X4 omegaSeq0To3(F64 theta, F64 begin = 0)
    {
        // F64X4(const F64 *) performs an aligned 32-byte load, so the staging
        // arrays must be 32-byte aligned (a plain double array is only 8-byte aligned).
        alignas(32) F64 real_arr[4] = {std::cos(begin), std::cos(theta + begin), std::cos(2 * theta + begin), std::cos(3 * theta + begin)};
        alignas(32) F64 imag_arr[4] = {std::sin(begin), std::sin(theta + begin), std::sin(2 * theta + begin), std::sin(3 * theta + begin)};
        return C64X4(F64X4(real_arr), F64X4(imag_arr));
    }
    // Load one RRRRIIII group (tag selects no permutation).
    template <typename T>
    void load(const T *p, std::false_type)
    {
        this->load(p);
    }
    // Load interleaved r0 i0 r1 i1 r2 i2 r3 i3 and split it into real/imag.
    template <typename T>
    void load(const T *p, std::true_type)
    {
        this->load(p);
        *this = this->toRRIIPermu();
    }
    // Aligned load: reals from p[0..3], imaginaries from p[4..7] (as doubles).
    template <typename T>
    void load(const T *p)
    {
        real.load(reinterpret_cast<const F64 *>(p));
        imag.load(reinterpret_cast<const F64 *>(p) + 4);
    }
    // Unaligned variant of load(p).
    template <typename T>
    void loadu(const T *p)
    {
        real.loadu(reinterpret_cast<const F64 *>(p));
        imag.loadu(reinterpret_cast<const F64 *>(p) + 4);
    }
    // Broadcast one complex value (*real_p, *imag_p) into every lane.
    void load1(const F64 *real_p, const F64 *imag_p)
    {
        real.load1(real_p);
        imag.load1(imag_p);
    }
    // Store one RRRRIIII group (tag selects no permutation).
    template <typename T>
    void store(T *p, std::false_type) const
    {
        this->store(p);
    }
    // Store as interleaved r0 i0 r1 i1 r2 i2 r3 i3.
    template <typename T>
    void store(T *p, std::true_type) const
    {
        this->toRIRIPermu().store(p);
    }
    // Aligned store: reals to p[0..3], imaginaries to p[4..7] (as doubles).
    template <typename T>
    void store(T *p) const
    {
        real.store(reinterpret_cast<F64 *>(p));
        imag.store(reinterpret_cast<F64 *>(p) + 4);
    }
    // Unaligned variant of store(p).
    template <typename T>
    void storeu(T *p) const
    {
        real.storeu(reinterpret_cast<F64 *>(p));
        imag.storeu(reinterpret_cast<F64 *>(p) + 4);
    }
    // Lane-wise z^2.
    C64X4 square() const
    {
        const F64X4 ii = imag * imag;
        const F64X4 ri = real * imag;
        const F64X4 r = F64X4::fmsub(real, real, ii);
        const F64X4 i = ri + ri;
        return C64X4(r, i);
    }
    // Lane-wise z^3 = a(a^2 - 3b^2) + i b(3a^2 - b^2).
    C64X4 cube() const
    {
        const F64X4 rr = real * real;
        const F64X4 ii = imag * imag;
        const F64X4 rr3 = rr + rr + rr;
        const F64X4 ii3 = ii + ii + ii;
        const F64X4 r = real * (rr - ii3);
        const F64X4 i = imag * (rr3 - ii);
        return C64X4(r, i);
    }
    // Rearrange so that real/imag hold interleaved data (r0 i0 r1 i1 / r2 i2 r3 i3).
    C64X4 toRIRIPermu() const
    {
        C64X4 res = *this;
        transpose64_2X4(res.real, res.imag);
        return res;
    }
    // Inverse of toRIRIPermu: interleaved registers back to split real/imag.
    C64X4 toRRIIPermu() const
    {
        C64X4 res = *this;
        transpose64_4X2(res.real, res.imag);
        return res;
    }
    void print() const
    {
        alignas(32) F64 real_arr[4]{}, imag_arr[4]{};
        real.storeu(real_arr);
        imag.storeu(imag_arr);
        std::cout << "[(" << real_arr[0] << ", " << imag_arr[0] << "), ("
                  << real_arr[1] << ", " << imag_arr[1] << "), ("
                  << real_arr[2] << ", " << imag_arr[2] << "), ("
                  << real_arr[3] << ", " << imag_arr[3] << ")]" << std::endl;
    }
    // No conversion requested: returns *this unchanged.
    C64X4 transToI64(std::false_type) const
    {
        return *this;
    }
    // Round each component to the nearest integer by adding 0.5 and truncating
    // with Float64X4::toI64X4 (documented there for positive inputs). The int64
    // bit patterns are returned reinterpreted as double lanes.
    C64X4 transToI64(std::true_type) const
    {
        const F64X4 half(0.5);
        const __m256i real_i64 = (real + half).toI64X4();
        const __m256i imag_i64 = (imag + half).toI64X4();
        return C64X4(_mm256_castsi256_pd(real_i64), _mm256_castsi256_pd(imag_i64));
    }
    F64X4 real, imag;
};
}
#endif
namespace hint
{
using Float32 = float;
using Float64 = double;
using Complex32 = std::complex<Float32>;
using Complex64 = std::complex<Float64>;
// pi and 2*pi in double precision.
constexpr Float64 HINT_PI = 3.141592653589793238462643;
constexpr Float64 HINT_2PI = HINT_PI * 2;
// NOTE(review): the value is cos(pi/4) = sqrt(2)/2 (the real part of the 8th
// root of unity exp(2*pi*i/8)), not cos(pi/8) ~= 0.9239 as the name suggests.
// Kept unchanged because code outside this view may rely on it; confirm before renaming.
constexpr Float64 COS_PI_8 = 0.707106781186547524400844;
// Not referenced in the visible code; presumably the largest supported FFT
// length (units unclear, possibly complex points) -- TODO confirm against callers.
constexpr size_t FFT_MAX_LEN = size_t(1) << 23;
// Largest power of two not exceeding n (same contract as std::bit_floor).
// Returns 0 for n == 0, since no power of two is <= 0. Intended for n >= 0.
template <typename T>
constexpr T int_floor2(T n)
{
    if (n == 0)
    {
        return 0;
    }
    constexpr int bits = sizeof(n) * 8;
    // Smear the highest set bit (2^k) into every lower position: n becomes 2^(k+1) - 1.
    for (int i = 1; i < bits; i *= 2)
    {
        n |= (n >> i);
    }
    // (2^(k+1) - 1) >> 1 == 2^k - 1, so adding one yields 2^k.
    return (n >> 1) + 1;
}
// Smallest power of two not less than n (same contract as std::bit_ceil).
// Returns 1 for n == 0. Wraps to 0 when the result is not representable in T.
template <typename T>
constexpr T int_ceil2(T n)
{
    if (n == 0)
    {
        return 1;
    }
    constexpr int bits = sizeof(n) * 8;
    // Decrement first so exact powers of two map to themselves.
    n--;
    // Smear the highest set bit downwards, giving 2^k - 1.
    for (int i = 1; i < bits; i *= 2)
    {
        n |= (n >> i);
    }
    return n + 1;
}
// True iff n is a positive power of two, i.e. exactly one bit is set.
template <typename IntTy>
constexpr bool is_2pow(IntTy n)
{
    if (n == 0)
    {
        return false;
    }
    // Clearing the lowest set bit leaves zero only for single-bit values.
    return (n & (n - 1)) == 0;
}
// Integer base-2 logarithm: floor(log2(n)) for n >= 1; returns -1 for n <= 0.
template <typename T>
constexpr int hint_log2(T n)
{
    // Binary search for the largest exponent e with (T(1) << e) <= n.
    // Invariant: lo satisfies it (or is -1); hi does not (or is the bit width).
    int lo = -1;
    int hi = int(sizeof(n) * 8);
    while (hi - lo > 1)
    {
        const int mid = (lo + hi) / 2;
        if (n < (T(1) << mid))
        {
            hi = mid;
        }
        else
        {
            lo = mid;
        }
    }
    return lo;
}
// Count trailing zero bits of a 32-bit value; returns 32 for x == 0.
constexpr int hint_ctz(uint32_t x)
{
    // After isolating the lowest set bit, bit k of its index is 0 exactly when
    // the bit lies inside masks[k] (the positions whose index has bit k clear).
    constexpr uint32_t masks[5] = {0x55555555, 0x33333333, 0x0F0F0F0F, 0x00FF00FF, 0x0000FFFF};
    const uint32_t lowest = x & (uint32_t(0) - x);
    int index = 31;
    for (int k = 0; k < 5; k++)
    {
        if (lowest & masks[k])
        {
            index &= ~(1 << k);
        }
    }
    return lowest == 0 ? 32 : index;
}
// Count trailing zero bits of a 64-bit value; returns 64 for x == 0.
constexpr int hint_ctz(uint64_t x)
{
    // After isolating the lowest set bit, bit k of its index is 0 exactly when
    // the bit lies inside masks[k] (the positions whose index has bit k clear).
    constexpr uint64_t masks[6] = {0x5555555555555555, 0x3333333333333333, 0x0F0F0F0F0F0F0F0F,
                                   0x00FF00FF00FF00FF, 0x0000FFFF0000FFFF, 0x00000000FFFFFFFF};
    const uint64_t lowest = x & (uint64_t(0) - x);
    int index = 63;
    for (int k = 0; k < 6; k++)
    {
        if (lowest & masks[k])
        {
            index &= ~(1 << k);
        }
    }
    return lowest == 0 ? 64 : index;
}
// Number of set bits in a 32-bit value (SWAR reduction).
constexpr int hint_popcnt(uint32_t n)
{
    // Step k adds neighbouring (1 << k)-bit partial counts, doubling the field width.
    constexpr uint32_t masks[5] = {0x55555555, 0x33333333, 0x0f0f0f0f, 0x00ff00ff, 0x0000ffff};
    for (int k = 0; k < 5; k++)
    {
        const int shift = 1 << k;
        n = (n & masks[k]) + ((n >> shift) & masks[k]);
    }
    return int(n);
}
// Number of set bits in a 64-bit value (SWAR reduction).
constexpr int hint_popcnt(uint64_t n)
{
    // Step k adds neighbouring (1 << k)-bit partial counts, doubling the field width.
    constexpr uint64_t masks[6] = {0x5555555555555555, 0x3333333333333333, 0x0f0f0f0f0f0f0f0f,
                                   0x00ff00ff00ff00ff, 0x0000ffff0000ffff, 0x00000000ffffffff};
    for (int k = 0; k < 6; k++)
    {
        const int shift = 1 << k;
        n = (n & masks[k]) + ((n >> shift) & masks[k]);
    }
    return int(n);
}
// Namespace for FFT and FFT-like transforms
namespace transform
{
using namespace hint_simd;
// In-place butterfly: (sum, diff) <- (sum + diff, sum - diff).
// Both inputs are copied first, so the two outputs never read a half-updated value.
template <typename T>
inline void transform2(T &sum, T &diff)
{
    const T a = sum;
    const T b = diff;
    sum = a + b;
    diff = a - b;
}
// Out-of-place butterfly: sum <- a + b, diff <- a - b.
// a and b are taken by value, so the outputs may alias the inputs
// (e.g. transform2(x, y, y, x) is well defined).
template <typename T>
inline void transform2(const T a, const T b, T &sum, T &diff)
{
    const T total = a + b;
    sum = total;
    diff = a - b;
}
namespace fft
{
using F64 = Float64;
using C64 = std::complex<F64>;
using F64X4 = Float64X4;
using C64X4 = Complex64X4;
// Fixed table of OMEGA_LEN complex values w_k = exp(i * theta * k) with
// theta = -2*pi*factor/theta_divider. Stored in groups of `stride` values:
// `stride` cosines followed by the matching `stride` sines.
template <typename Float, size_t OMEGA_LEN>
class TableFix
{
    alignas(64) std::array<Float, OMEGA_LEN * 2> table;

public:
    TableFix(size_t theta_divider, size_t factor, size_t stride)
    {
        const Float theta = -HINT_2PI * factor / theta_divider;
        assert(OMEGA_LEN % stride == 0);
        size_t k = 0; // index of the complex value being generated
        size_t begin = 0;
        while (begin < OMEGA_LEN * 2)
        {
            for (size_t j = 0; j < stride; ++j)
            {
                table[begin + j] = std::cos(theta * k);
                table[begin + stride + j] = std::sin(theta * k);
                ++k;
            }
            begin += stride * 2;
        }
    }
    // Raw access to the interleaved-by-group storage (index in Floats).
    constexpr const Float &operator[](size_t index) const
    {
        return table[index];
    }
    constexpr const Float *getOmegaIt(size_t index) const
    {
        return &table[index];
    }
};
template <typename Float, int LOG_BEGIN, int LOG_END, int DIV>
class TableFixMulti
{
static_assert(LOG_END >= LOG_BEGIN);
static_assert(is_2pow(DIV));
static constexpr size_t TABLE_CPX_LEN = (size_t(1) << (LOG_END + 1)) / DIV;
alignas(64) std::array<Float, TABLE_CPX_LEN * 2> table;
public:
TableFixMulti(size_t factor, size_t stride = 4)
{
assert(((size_t(1) << LOG_BEGIN) / DIV) % stride == 0);
initBottomUp(factor, stride);
}
void initTopDown(size_t factor, size_t stride)
{
static_assert(std::is_same<Float, Float64>::value);
assert(stride == 4);
}
void initBottomUp(size_t factor, size_t stride)
{
static_assert(std::is_same<Float, Float64>::value);
assert(stride == 4);
size_t len = size_t(1) << LOG_BEGIN, cpx_len = len / DIV;
Float theta = -HINT_2PI * factor / len;
auto it = getBeginLog(LOG_BEGIN);
for (size_t i = 0; i < cpx_len; i++)
{
it[0] = std::cos(theta * i), it[stride] = std::sin(theta * i);
it += (i % stride == stride - 1 ? stride + 1 : 1);
}
it = getBeginLog(LOG_BEGIN);
for (int log_len = LOG_BEGIN + 1; log_len <= LOG_END; log_len++)
{
len = size_t(1) << log_len, cpx_len = len / DIV;
theta = -HINT_2PI * factor / len;
auto it = getBeginLog(log_len), it_last = getBeginLog(log_len - 1);
C64X4 unit(std::cos(theta), std::sin(theta));
for (auto end = it + cpx_len * 2; it < end; it += 16, it_last += 8)
{
Complex64X4 omega0, omega1;
omega0.load(it_last);
omega1 = omega0.mul(unit);
transpose64_2X4(omega0.real, omega1.real);
transpose64_2X4(omega0.imag, omega1.imag);
omega0.store(it), omega1.store(it + 8);
}
}
}
constexpr const Float *getBeginLog(int log_rank) const
{
return getBegin(size_t(1) << log_rank);
}
constexpr Float *getBeginLog(int log_rank)
{
return getBegin(size_t(1) << log_rank);
}
constexpr const Float *getBegin(size_t rank) const
{
return &table[rank * 2 / DIV];
}
constexpr Float *getBegin(size_t rank)
{
return &table[rank * 2 / DIV];
}
};
// Generic 4-input butterflies on complex values held as separate real/imag
// operands (x_k = r_k + i*i_k). Float may be a scalar or a SIMD vector, in which
// case every lane is an independent butterfly.
struct FFT
{
// Radix-4 DIF butterfly (forward sign). Outputs are the 4-point DFT in
// bit-reversed order: x0 <- X0, x1 <- X2, x2 <- X1, x3 <- X3.
template <typename Float>
static void dif4(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
{
difSplit(r0, i0, r1, i1, r2, i2, r3, i3);
transform2(r0, r1);
transform2(i0, i1);
}
// Inverse of dif4 without normalization: takes bit-reversed input
// (X0, X2, X1, X3) and produces the unscaled inverse DFT in natural order.
template <typename Float>
static void idit4(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
{
transform2(r0, r1);
transform2(i0, i1);
iditSplit(r0, i0, r1, i1, r2, i2, r3, i3);
}
// Split-radix DIF step:
//   x0 <- x0 + x2,  x1 <- x1 + x3,
//   x2 <- (x0 - x2) - i(x1 - x3),  x3 <- (x0 - x2) + i(x1 - x3).
template <typename Float>
static void difSplit(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
{
transform2(r0, r2); // a0 = x0 + x2, a2 = x0 - x2 (real parts)
transform2(i0, i2);
transform2(r1, r3); // a1 = x1 + x3, a3 = x1 - x3 (real parts)
transform2(i1, i3);
transform2(r2, i3); // r2 <- Re a2 + Im a3, i3 <- Re a2 - Im a3
// r3 <- Im a2 + Re a3, i2 <- Im a2 - Re a3; the by-value overload makes this aliasing safe.
transform2(i2, r3, r3, i2);
std::swap(i3, r3); // x3 = (Re a2 - Im a3, Im a2 + Re a3) = a2 + i*a3
}
// Split-radix DIT step (inverse sign), mirror of difSplit:
//   x0 <- x0 + (x2 + x3),  x2 <- x0 - (x2 + x3),
//   x1 <- x1 + i(x2 - x3),  x3 <- x1 - i(x2 - x3).
template <typename Float>
static void iditSplit(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
{
transform2(r2, r3); // b2 = x2 + x3, b3 = x2 - x3 (real parts)
transform2(i2, i3);
transform2(r0, r2); // x0 + b2, x0 - b2
transform2(i0, i2);
transform2(r1, i3, i3, r1); // i3 <- Re x1 + Im b3, r1 <- Re x1 - Im b3
transform2(i1, r3);         // i1 <- Im x1 + Re b3, r3 <- Im x1 - Re b3
std::swap(i3, r3);          // x3 = (Re x1 + Im b3, Im x1 - Re b3) = x1 - i*b3
}
};
struct FFTAVX : public FFT
{
// log2 size thresholds (in complex points). Sizes <= SHORT_LEN go through
// difIter/iditIter; difMid/iditMid layer down to SHORT_LEN and difLarge/iditLarge
// down to MID_LEN; difRec/iditRec assert fft_len <= MAX_LEN.
static constexpr size_t LOG_SHORT = 10;
static constexpr size_t LOG_MID = 14;
static constexpr size_t LOG_MAX = 18;
static constexpr size_t SHORT_LEN = size_t(1) << LOG_SHORT;
static constexpr size_t MID_LEN = size_t(1) << LOG_MID;
static constexpr size_t MAX_LEN = size_t(1) << LOG_MAX;
// Fixed twiddle tables read by the 8/16/32-point kernels through the it* pointers
// below. Their constructor arguments live outside this view -- confirm there.
static const TableFix<Float64, 4> table_8;
static const TableFix<Float64, 4> table_16_1;
static const TableFix<Float64, 4> table_16_3;
static const TableFix<Float64, 8> table_32_1;
static const TableFix<Float64, 8> table_32_3;
// Multi-length twiddle tables (levels from 2^6 up to 2^LOG_SHORT or 2^LOG_MAX, DIV = 4).
// difIter/iditIter use all three; the layer and recursive passes use only
// multi_table_1. The _1/_2/_3 suffixes presumably denote twiddle factor 1/2/3
// -- TODO confirm against the definitions.
static const TableFixMulti<Float64, 6, LOG_SHORT, 4> multi_table_3;
static const TableFixMulti<Float64, 6, LOG_SHORT, 4> multi_table_2;
static const TableFixMulti<Float64, 6, LOG_MAX, 4> multi_table_1;
// Pointers to the first element of each fixed table (used with Complex64X4::load).
static constexpr const Float64 *it8 = &table_8[0];
static constexpr const Float64 *it16_1 = &table_16_1[0];
static constexpr const Float64 *it16_3 = &table_16_3[0];
static constexpr const Float64 *it32_1 = &table_32_1[0];
static constexpr const Float64 *it32_3 = &table_32_3[0];
// Apply FFT::dif4 to each of four 4-element groups stored row-wise in
// (r0,i0)..(r3,i3): transpose so each group spans the four registers, run the
// butterfly lane-wise, then transpose back.
static void dif4x4(F64X4 &r0, F64X4 &i0, F64X4 &r1, F64X4 &i1, F64X4 &r2, F64X4 &i2, F64X4 &r3, F64X4 &i3)
{
    const auto transpose_both = [&]()
    {
        transpose64_4X4(r0, r1, r2, r3);
        transpose64_4X4(i0, i1, i2, i3);
    };
    transpose_both();
    dif4(r0, i0, r1, i1, r2, i2, r3, i3);
    transpose_both();
}
// Apply FFT::idit4 to each of four 4-element groups stored row-wise in
// (r0,i0)..(r3,i3), using the same transpose sandwich as dif4x4.
static void idit4x4(F64X4 &r0, F64X4 &i0, F64X4 &r1, F64X4 &i1, F64X4 &r2, F64X4 &i2, F64X4 &r3, F64X4 &i3)
{
    const auto transpose_both = [&]()
    {
        transpose64_4X4(r0, r1, r2, r3);
        transpose64_4X4(i0, i1, i2, i3);
    };
    transpose_both();
    idit4(r0, i0, r1, i1, r2, i2, r3, i3);
    transpose_both();
}
// Two independent 8-point DIF passes: (c0, c1) and (c2, c3) each hold 8
// consecutive complex values. A radix-2 stage pairs element k with k+4, the
// difference half is multiplied by omega, then dif4x4 finishes every 4-group.
static void dif8x2(Complex64X4 &c0, Complex64X4 &c1, Complex64X4 &c2, Complex64X4 &c3, const Complex64X4 &omega)
{
    transform2(c0, c1);
    c1 = c1.mul(omega);
    transform2(c2, c3);
    c3 = c3.mul(omega);
    FFTAVX::dif4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
}
// Inverse of dif8x2: idit4x4 on every 4-group, multiply the upper half of each
// 8-point set by conj(omega), then the radix-2 stage pairing k with k+4.
static void idit8x2(Complex64X4 &c0, Complex64X4 &c1, Complex64X4 &c2, Complex64X4 &c3, const Complex64X4 &omega)
{
    FFTAVX::idit4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
    c1 = c1.mulConj(omega);
    transform2(c0, c1);
    c3 = c3.mulConj(omega);
    transform2(c2, c3);
}
// In-place dif8x2 on 32 doubles (four RRRRIIII groups; 32-byte aligned),
// with the twiddles taken from it8.
static void dif8x2(Float64 in_out[])
{
    Complex64X4 omega;
    omega.load(it8);
    Complex64X4 c0, c1, c2, c3;
    c0.load(in_out);
    c1.load(in_out + 8);
    c2.load(in_out + 16);
    c3.load(in_out + 24);
    dif8x2(c0, c1, c2, c3, omega);
    c0.store(in_out);
    c1.store(in_out + 8);
    c2.store(in_out + 16);
    c3.store(in_out + 24);
}
// In-place idit8x2 on 32 doubles (four RRRRIIII groups; 32-byte aligned),
// with the twiddles taken from it8.
static void idit8x2(Float64 in_out[])
{
    Complex64X4 omega;
    omega.load(it8);
    Complex64X4 c0, c1, c2, c3;
    c0.load(in_out);
    c1.load(in_out + 8);
    c2.load(in_out + 16);
    c3.load(in_out + 24);
    idit8x2(c0, c1, c2, c3, omega);
    c0.store(in_out);
    c1.store(in_out + 8);
    c2.store(in_out + 16);
    c3.store(in_out + 24);
}
// In-place 16-point DIF on 32 doubles (four RRRRIIII groups): a radix-4
// butterfly across the groups, twiddles from it8 / it16_1 / it16_3 on groups
// 1..3, then radix-4 butterflies within each group.
static void dif16(Float64 in_out[])
{
    Complex64X4 c0, c1, c2, c3;
    c0.load(in_out);
    c1.load(in_out + 8);
    c2.load(in_out + 16);
    c3.load(in_out + 24);
    dif4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
    const auto twiddle = [](Complex64X4 &c, const Float64 *table)
    {
        Complex64X4 omega;
        omega.load(table);
        c = c.mul(omega);
    };
    twiddle(c1, it8);
    twiddle(c2, it16_1);
    twiddle(c3, it16_3);
    dif4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
    c0.store(in_out);
    c1.store(in_out + 8);
    c2.store(in_out + 16);
    c3.store(in_out + 24);
}
// Inverse of dif16: butterflies within each group, conjugate twiddles on
// groups 1..3, then the radix-4 butterfly across the groups.
static void idit16(Float64 in_out[])
{
    Complex64X4 c0, c1, c2, c3;
    c0.load(in_out);
    c1.load(in_out + 8);
    c2.load(in_out + 16);
    c3.load(in_out + 24);
    idit4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
    const auto untwiddle = [](Complex64X4 &c, const Float64 *table)
    {
        Complex64X4 omega;
        omega.load(table);
        c = c.mulConj(omega);
    };
    untwiddle(c1, it8);
    untwiddle(c2, it16_1);
    untwiddle(c3, it16_3);
    idit4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
    c0.store(in_out);
    c1.store(in_out + 8);
    c2.store(in_out + 16);
    c3.store(in_out + 24);
}
// In-place 32-point split-radix DIF on 64 doubles = 8 RRRRIIII groups g0..g7
// (quarters: g0-g1, g2-g3, g4-g5, g6-g7). One split step is applied across the
// quarters (first groups, then second groups); quarters 2 and 3 are finished as
// 8-point transforms by dif8x2; the first half (g0..g3) by dif16.
static void dif32(Float64 in_out[])
{
Complex64X4 c0, c1, c2, c3, omega;
// Split step on the first group of each quarter: g0, g2, g4, g6.
c0.load(in_out), c1.load(in_out + 16), c2.load(in_out + 32), c3.load(in_out + 48);
difSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
omega.load(it32_1), c2 = c2.mul(omega);
omega.load(it32_3), c3 = c3.mul(omega);
c0.store(in_out), c1.store(in_out + 16), c2.store(in_out + 32), c3.store(in_out + 48);
c0.load(in_out + 8), c1.load(in_out + 24), c2.load(in_out + 40), c3.load(in_out + 56); // 1,3,5,7
difSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
// Twiddles for elements 4..7 of each quarter.
omega.load(it32_1 + 8), c2 = c2.mul(omega);
omega.load(it32_3 + 8), c3 = c3.mul(omega);
// g5 (c2) and g7 (c3) stay in registers; only g1 and g3 are written back.
c0.store(in_out + 8), c1.store(in_out + 24);
c0.load(in_out + 32), c1.load(in_out + 48), omega.load(it8); // 4,6
// 8-point transforms on quarter 2 (g4, g5) and quarter 3 (g6, g7).
dif8x2(c0, c2, c1, c3, omega);
c0.store(in_out + 32), c2.store(in_out + 40), c1.store(in_out + 48), c3.store(in_out + 56);
// 16-point transform on the first half (g0..g3).
dif16(in_out);
}
// Inverse of dif32 on 64 doubles = 8 RRRRIIII groups g0..g7: idit16 on the
// first half, idit8x2 on quarters 2 and 3, then the inverse split step across
// the quarters (first groups, then second groups) with conjugate twiddles.
static void idit32(Float64 in_out[])
{
Complex64X4 c0, c1, c2, c3, omega;
idit16(in_out);
c0.load(in_out + 32), c1.load(in_out + 40), c2.load(in_out + 48), c3.load(in_out + 56), omega.load(it8); // 4,5,6,7
idit8x2(c0, c1, c2, c3, omega);
// g5 and g7 are written back; g4 (c0) and g6 (c2) stay in registers.
c1.store(in_out + 40), c3.store(in_out + 56);
c1.load(in_out), c3.load(in_out + 16);
omega.load(it32_1), c0 = c0.mulConj(omega);
omega.load(it32_3), c2 = c2.mulConj(omega);
// Inverse split step on g0, g2, g4, g6.
iditSplit(c1.real, c1.imag, c3.real, c3.imag, c0.real, c0.imag, c2.real, c2.imag);
c1.store(in_out), c3.store(in_out + 16), c0.store(in_out + 32), c2.store(in_out + 48);
// Inverse split step on g1, g3, g5, g7 (twiddles for elements 4..7).
c0.load(in_out + 8), c1.load(in_out + 24), c2.load(in_out + 40), c3.load(in_out + 56);
omega.load(it32_1 + 8), c2 = c2.mulConj(omega);
omega.load(it32_3 + 8), c3 = c3.mulConj(omega);
iditSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
c0.store(in_out + 8), c1.store(in_out + 24), c2.store(in_out + 40), c3.store(in_out + 56);
}
// Run dif16 on each consecutive 32-double block of in_out[0 .. float_len).
static void dif16(Float64 in_out[], size_t float_len)
{
    assert(float_len >= 32);
    for (size_t offset = 0; offset < float_len; offset += 32)
    {
        dif16(in_out + offset);
    }
}
// Run idit16 on each consecutive 32-double block of in_out[0 .. float_len).
static void idit16(Float64 in_out[], size_t float_len)
{
    assert(float_len >= 32);
    for (size_t offset = 0; offset < float_len; offset += 32)
    {
        idit16(in_out + offset);
    }
}
// Run dif32 on each consecutive 64-double block of in_out[0 .. float_len).
static void dif32(Float64 in_out[], size_t float_len)
{
    assert(float_len >= 64);
    for (size_t offset = 0; offset < float_len; offset += 64)
    {
        dif32(in_out + offset);
    }
}
// Run idit32 on each consecutive 64-double block of in_out[0 .. float_len).
static void idit32(Float64 in_out[], size_t float_len)
{
    assert(float_len >= 64);
    for (size_t offset = 0; offset < float_len; offset += 64)
    {
        idit32(in_out + offset);
    }
}
// Iterative radix-4 DIF over float_len doubles (fft_len = float_len / 2 complex
// points, fft_len <= SHORT_LEN) stored as RRRRIIII groups. Radix-4 layers run
// for rank = fft_len, fft_len/4, ... while rank >= 64; the remaining 16- or
// 32-point sub-transforms are finished by dif16 / dif32, chosen by the parity
// of log2(fft_len).
static void difIter(Float64 in_out[], size_t float_len)
{
    const size_t fft_len = float_len / 2;
    assert(fft_len <= SHORT_LEN);
    for (size_t rank = fft_len; rank >= 64; rank /= 4)
    {
        // Distance between the four butterfly inputs, in doubles.
        const size_t quarter = rank / 2;
        // Twiddle bases for this rank (identical for every block).
        const Float64 *const w1 = multi_table_1.getBegin(rank);
        const Float64 *const w2 = multi_table_2.getBegin(rank);
        const Float64 *const w3 = multi_table_3.getBegin(rank);
        for (Float64 *block = in_out; block < in_out + float_len; block += rank * 2)
        {
            for (size_t off = 0; off < quarter; off += 8)
            {
                Float64 *const p0 = block + off;
                Float64 *const p1 = p0 + quarter;
                Float64 *const p2 = p1 + quarter;
                Float64 *const p3 = p2 + quarter;
                Complex64X4 c0, c1, c2, c3, omega;
                c0.load(p0), c1.load(p1), c2.load(p2), c3.load(p3);
                dif4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                // dif4 leaves its outputs in bit-reversed order, hence table 2 on c1
                // and table 1 on c2.
                omega.load(w2 + off), c1 = c1.mul(omega);
                omega.load(w1 + off), c2 = c2.mul(omega);
                omega.load(w3 + off), c3 = c3.mul(omega);
                c0.store(p0), c1.store(p1), c2.store(p2), c3.store(p3);
            }
        }
    }
    const bool even_log = hint_log2(fft_len) % 2 == 0;
    if (even_log)
    {
        dif16(in_out, float_len);
    }
    else
    {
        dif32(in_out, float_len);
    }
}
// Inverse of difIter (fft_len = float_len / 2 <= SHORT_LEN): finish the 16- or
// 32-point sub-transforms first, then radix-4 DIT layers for rank = 64 or 128,
// multiplied by 4 each time, up to fft_len.
static void iditIter(Float64 in_out[], size_t float_len)
{
    const size_t fft_len = float_len / 2;
    assert(fft_len <= SHORT_LEN);
    size_t rank;
    if (hint_log2(fft_len) % 2 == 0)
    {
        idit16(in_out, float_len);
        rank = 64;
    }
    else
    {
        idit32(in_out, float_len);
        rank = 128;
    }
    while (rank <= fft_len)
    {
        const size_t quarter = rank / 2;
        const Float64 *const w1 = multi_table_1.getBegin(rank);
        const Float64 *const w2 = multi_table_2.getBegin(rank);
        const Float64 *const w3 = multi_table_3.getBegin(rank);
        for (Float64 *block = in_out; block < in_out + float_len; block += rank * 2)
        {
            for (size_t off = 0; off < quarter; off += 8)
            {
                Float64 *const p0 = block + off;
                Float64 *const p1 = p0 + quarter;
                Float64 *const p2 = p1 + quarter;
                Float64 *const p3 = p2 + quarter;
                Complex64X4 c0, c1, c2, c3, omega;
                c0.load(p0), c1.load(p1), c2.load(p2), c3.load(p3);
                omega.load(w2 + off), c1 = c1.mulConj(omega);
                omega.load(w1 + off), c2 = c2.mulConj(omega);
                omega.load(w3 + off), c3 = c3.mulConj(omega);
                idit4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                c0.store(p0), c1.store(p1), c2.store(p2), c3.store(p3);
            }
        }
        rank *= 4;
    }
}
// One radix-4 DIF layer at `rank` over every rank-sized block of in_out.
// Only multi_table_1 is read; the second twiddle is w^2 (square) and the third
// w^3 (w^2 * w). With FROM_RIRI_PERM the inputs are interleaved complex data,
// converted to RRRRIIII groups as they are loaded.
template <bool FROM_RIRI_PERM = false>
static void dif2LayerMid(Float64 in_out[], size_t float_len, size_t rank)
{
    const std::integral_constant<bool, FROM_RIRI_PERM> from_riri{};
    const size_t quarter = rank / 2;
    const Float64 *const table = multi_table_1.getBegin(rank);
    for (size_t base = 0; base < float_len; base += rank * 2)
    {
        Float64 *const block = in_out + base;
        for (size_t off = 0; off < quarter; off += 8)
        {
            Float64 *const p0 = block + off;
            Float64 *const p1 = p0 + quarter;
            Float64 *const p2 = p1 + quarter;
            Float64 *const p3 = p2 + quarter;
            Complex64X4 c0, c1, c2, c3;
            c0.load(p0, from_riri), c1.load(p1, from_riri), c2.load(p2, from_riri), c3.load(p3, from_riri);
            dif4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
            Complex64X4 w1;
            w1.load(table + off);
            const Complex64X4 w2 = w1.square();
            c2 = c2.mul(w1);
            c1 = c1.mul(w2);
            c3 = c3.mul(w2.mul(w1));
            c0.store(p0), c1.store(p1), c2.store(p2), c3.store(p3);
        }
    }
}
// One radix-4 DIT layer at `rank` over every rank-sized block of in_out
// (inverse of dif2LayerMid). Results may be rounded to int64 (TO_INT64) and/or
// written in interleaved order (TO_RIRI_PERM) as they are stored.
template <bool TO_RIRI_PERM = false, bool TO_INT64 = false>
static void idit2LayerMid(Float64 in_out[], size_t float_len, size_t rank)
{
    const std::integral_constant<bool, TO_RIRI_PERM> to_riri{};
    const std::integral_constant<bool, TO_INT64> to_i64{};
    const size_t quarter = rank / 2;
    const Float64 *const table = multi_table_1.getBegin(rank);
    for (size_t base = 0; base < float_len; base += rank * 2)
    {
        Float64 *const block = in_out + base;
        for (size_t off = 0; off < quarter; off += 8)
        {
            Float64 *const p0 = block + off;
            Float64 *const p1 = p0 + quarter;
            Float64 *const p2 = p1 + quarter;
            Float64 *const p3 = p2 + quarter;
            Complex64X4 c0, c1, c2, c3;
            c0.load(p0), c1.load(p1), c2.load(p2), c3.load(p3);
            Complex64X4 w1;
            w1.load(table + off);
            const Complex64X4 w2 = w1.square();
            c2 = c2.mulConj(w1);
            c1 = c1.mulConj(w2);
            c3 = c3.mulConj(w2.mul(w1));
            idit4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
            c0 = c0.transToI64(to_i64), c1 = c1.transToI64(to_i64), c2 = c2.transToI64(to_i64), c3 = c3.transToI64(to_i64);
            c0.store(p0, to_riri), c1.store(p1, to_riri), c2.store(p2, to_riri), c3.store(p3, to_riri);
        }
    }
}
// One radix-4 DIF layer at `rank` used by difLarge. Its body was identical to
// dif2LayerMid, so it forwards there instead of keeping a second copy that
// could silently diverge.
template <bool FROM_RIRI_PERM = false>
static void dif2LayerLarge(Float64 in_out[], size_t float_len, size_t rank)
{
    dif2LayerMid<FROM_RIRI_PERM>(in_out, float_len, rank);
}
// One radix-4 DIT layer at `rank` used by iditLarge. Its body was identical to
// idit2LayerMid (including the optional int64 / RIRI output conversion), so
// it forwards there instead of keeping a second copy.
template <bool TO_RIRI_PERM = false, bool TO_INT64 = false>
static void idit2LayerLarge(Float64 in_out[], size_t float_len, size_t rank)
{
    idit2LayerMid<TO_RIRI_PERM, TO_INT64>(in_out, float_len, rank);
}
// DIF by radix-4 layers from the full length down to at most SHORT_LEN, then
// difIter on each remaining chunk. Only the first layer reads interleaved
// input when FROM_RIRI_PERM is set.
template <bool FROM_RIRI_PERM = false>
static void difMid(Float64 in_out[], size_t float_len)
{
    size_t rank = float_len / 2;
    dif2LayerMid<FROM_RIRI_PERM>(in_out, float_len, rank);
    for (rank /= 4; rank > SHORT_LEN; rank /= 4)
    {
        dif2LayerMid(in_out, float_len, rank);
    }
    const size_t chunk = rank * 2;
    for (size_t offset = 0; offset < float_len; offset += chunk)
    {
        difIter(in_out + offset, chunk);
    }
}
// Inverse counterpart of difMid: iditIter on chunks of at most SHORT_LEN points,
// then radix-4 idit2LayerMid layers up to the full length. The final layer
// optionally rounds to int64 (TO_INT64) and/or interleaves (TO_RIRI_PERM).
// When fft_len <= SHORT_LEN no layer is needed: iditIter does the whole
// transform and the output conversion is applied in a separate pass. Running a
// final layer in that case would use a rank larger than the data and access
// memory out of bounds.
template <bool TO_RIRI_PERM = false, bool TO_INT64 = false>
static void iditMid(Float64 in_out[], size_t float_len)
{
    constexpr size_t SHORT_LEN_RADIX4 = size_t(1) << ((LOG_SHORT) / 2 * 2);
    constexpr size_t SHORT_LEN_RADIX2 = SHORT_LEN_RADIX4 == SHORT_LEN ? SHORT_LEN_RADIX4 / 2 : SHORT_LEN;
    const size_t fft_len = float_len / 2;
    if (fft_len <= SHORT_LEN)
    {
        iditIter(in_out, float_len);
        if (TO_RIRI_PERM || TO_INT64)
        {
            using ToRIRI = std::integral_constant<bool, TO_RIRI_PERM>;
            using ToI64 = std::integral_constant<bool, TO_INT64>;
            for (Float64 *it = in_out, *end = in_out + float_len; it < end; it += 8)
            {
                Complex64X4 c;
                c.load(it);
                c = c.transToI64(ToI64{});
                c.store(it, ToRIRI{});
            }
        }
        return;
    }
    // Chunk size whose log2 has the same parity as log2(fft_len), so repeated
    // multiplication by 4 lands exactly on fft_len. Here rank <= SHORT_LEN < fft_len.
    size_t rank = hint_log2(fft_len) % 2 == 0 ? SHORT_LEN_RADIX4 : SHORT_LEN_RADIX2;
    for (auto it = in_out, end = in_out + float_len; it < end; it += rank * 2)
    {
        iditIter(it, rank * 2);
    }
    rank *= 4;
    for (; rank < fft_len; rank *= 4)
    {
        idit2LayerMid(in_out, float_len, rank);
    }
    idit2LayerMid<TO_RIRI_PERM, TO_INT64>(in_out, float_len, rank);
}
// DIF by radix-4 layers from the full length down to at most MID_LEN, then
// difMid on each remaining chunk. Only the first layer reads interleaved input
// when FROM_RIRI_PERM is set.
template <bool FROM_RIRI_PERM = false>
static void difLarge(Float64 in_out[], size_t float_len)
{
    size_t rank = float_len / 2;
    dif2LayerLarge<FROM_RIRI_PERM>(in_out, float_len, rank);
    for (rank /= 4; rank > MID_LEN; rank /= 4)
    {
        dif2LayerLarge(in_out, float_len, rank);
    }
    const size_t chunk = rank * 2;
    for (size_t offset = 0; offset < float_len; offset += chunk)
    {
        difMid(in_out + offset, chunk);
    }
}
// Inverse counterpart of difLarge: iditMid on chunks of at most MID_LEN points,
// then radix-4 idit2LayerLarge layers up to the full length. The last layer
// applies the optional int64 / RIRI output conversion.
// When fft_len <= MID_LEN no large layer is needed; the whole transform,
// including the output conversion, is delegated to iditMid. Otherwise the
// final layer would run with a rank larger than the data and access memory
// out of bounds.
template <bool TO_RIRI_PERM = false, bool TO_INT64 = false>
static void iditLarge(Float64 in_out[], size_t float_len)
{
    constexpr size_t MID_LEN_RADIX4 = size_t(1) << ((LOG_MID) / 2 * 2);
    constexpr size_t MID_LEN_RADIX2 = MID_LEN_RADIX4 == MID_LEN ? MID_LEN_RADIX4 / 2 : MID_LEN;
    const size_t fft_len = float_len / 2;
    if (fft_len <= MID_LEN)
    {
        iditMid<TO_RIRI_PERM, TO_INT64>(in_out, float_len);
        return;
    }
    // Chunk size whose log2 has the same parity as log2(fft_len), so repeated
    // multiplication by 4 lands exactly on fft_len. Here rank <= MID_LEN < fft_len.
    size_t rank = hint_log2(fft_len) % 2 == 0 ? MID_LEN_RADIX4 : MID_LEN_RADIX2;
    for (auto it = in_out, end = in_out + float_len; it < end; it += rank * 2)
    {
        iditMid(it, rank * 2);
    }
    rank *= 4;
    for (; rank < fft_len; rank *= 4)
    {
        idit2LayerLarge(in_out, float_len, rank);
    }
    idit2LayerLarge<TO_RIRI_PERM, TO_INT64>(in_out, float_len, rank);
}
// Recursive split-radix DIF over float_len doubles (fft_len = float_len / 2).
// Each level applies difSplit across the four quarters, multiplying the third
// quarter by w^k and the fourth by w^{3k} (w from multi_table_1 at fft_len).
// It then recurses on the first half and on each of the last two quarters.
// Sizes <= SHORT_LEN are handed to difIter.
// FROM_RIRI_PERM: input is interleaved complex data and is converted to
// RRRRIIII groups as it is read (only the top-level call sets it; the
// recursive calls use the default).
template <bool FROM_RIRI_PERM = false>
static void difRec(Float64 in_out[], size_t float_len)
{
    using FromRIRI = std::integral_constant<bool, FROM_RIRI_PERM>;
    const size_t fft_len = float_len / 2;
    assert(fft_len <= MAX_LEN);
    if (fft_len <= SHORT_LEN)
    {
        if (FROM_RIRI_PERM)
        {
            // difIter only understands RRRRIIII groups, so de-interleave first
            // instead of silently transforming interleaved data.
            for (Float64 *it = in_out, *end = in_out + float_len; it < end; it += 8)
            {
                Complex64X4 c;
                c.load(it, FromRIRI{});
                c.store(it);
            }
        }
        difIter(in_out, float_len);
        return;
    }
    const size_t stride1 = float_len / 4, stride2 = stride1 * 2, stride3 = stride1 * 3;
    auto table1 = multi_table_1.getBegin(fft_len);
    for (auto end = in_out + stride1, it = in_out; it < end; it += 8, table1 += 8)
    {
        Complex64X4 c0, c1, c2, c3, omega;
        c0.load(it, FromRIRI{}), c1.load(it + stride1, FromRIRI{}), c2.load(it + stride2, FromRIRI{}), c3.load(it + stride3, FromRIRI{});
        difSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
        omega.load(table1), c2 = c2.mul(omega);
        c3 = c3.mul(omega.cube());
        c0.store(it), c1.store(it + stride1), c2.store(it + stride2), c3.store(it + stride3);
    }
    difRec(in_out, stride2);
    difRec(in_out + stride2, stride1);
    difRec(in_out + stride3, stride1);
}
template <bool TO_RIRI_PERM = false, bool TO_INT64 = false>
static void iditRec(Float64 in_out[], size_t float_len)
{
    // Recursive inverse transform of float_len / 2 complex points; the mirror
    // image of difRec. It first recurses on the first half and the last two
    // quarters. Then one pass scales quarter 2 by conj(omega) and quarter 3 by
    // conj(omega.cube()) (via mulConj) and applies iditSplit.
    // Only this top level converts its outputs (TO_INT64) and stores them with
    // the RIRI permutation (TO_RIRI_PERM); recursive calls use the defaults.
    const size_t fft_len = float_len / 2;
    assert(fft_len <= MAX_LEN);
    if (fft_len <= SHORT_LEN)
    {
        iditIter(in_out, float_len);
        // iditIter is called without output options, so this size can honour
        // neither the RIRI permutation nor the int64 conversion. Previously
        // only TO_RIRI_PERM was checked, and a TO_INT64 request was silently
        // dropped, leaving doubles where int64 values were requested.
        assert(!TO_RIRI_PERM && !TO_INT64);
        return;
    }
    using ToRIRI = std::integral_constant<bool, TO_RIRI_PERM>;
    using ToI64 = std::integral_constant<bool, TO_INT64>;
    const size_t stride1 = float_len / 4, stride2 = stride1 * 2, stride3 = stride1 * 3;
    iditRec(in_out, stride2);
    iditRec(in_out + stride2, stride1);
    iditRec(in_out + stride3, stride1);
    auto table1 = multi_table_1.getBegin(fft_len);
    for (auto end = in_out + stride1, it = in_out; it < end; it += 8, table1 += 8)
    {
        Complex64X4 c0, c1, c2, c3, omega;
        c0.load(it), c1.load(it + stride1), c2.load(it + stride2), c3.load(it + stride3);
        omega.load(table1), c2 = c2.mulConj(omega);
        c3 = c3.mulConj(omega.cube());
        iditSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
        c0 = c0.transToI64(ToI64{}), c1 = c1.transToI64(ToI64{}), c2 = c2.transToI64(ToI64{}), c3 = c3.transToI64(ToI64{});
        c0.store(it, ToRIRI{}), c1.store(it + stride1, ToRIRI{}), c2.store(it + stride2, ToRIRI{}), c3.store(it + stride3, ToRIRI{});
    }
}
template <bool FROM_RIRI_PERM = false>
static void difRecS(Float64 in_out[], size_t float_len)
{
    // Forward (DIF) transform with the same half + two-quarter recursion as
    // difRec. Recursion stops at MID_LEN, and the block goes to difMid, which
    // is given FROM_RIRI_PERM directly. Recursive calls use the default
    // (no RIRI permutation) because this level has already consumed the
    // caller's layout.
    const size_t fft_len = float_len / 2;
    assert(fft_len <= MAX_LEN);
    if (fft_len > MID_LEN)
    {
        using FromRIRI = std::integral_constant<bool, FROM_RIRI_PERM>;
        const size_t quarter = float_len / 4;
        Float64 *const head = in_out;
        Float64 *const second = head + quarter;
        Float64 *const third = second + quarter;
        Float64 *const fourth = third + quarter;
        auto factor_ptr = multi_table_1.getBegin(fft_len);
        for (size_t pos = 0; pos < quarter; pos += 8)
        {
            Complex64X4 a, b, c, d, w;
            a.load(head + pos, FromRIRI{});
            b.load(second + pos, FromRIRI{});
            c.load(third + pos, FromRIRI{});
            d.load(fourth + pos, FromRIRI{});
            difSplit(a.real, a.imag, b.real, b.imag, c.real, c.imag, d.real, d.imag);
            w.load(factor_ptr);
            c = c.mul(w);
            d = d.mul(w.cube());
            a.store(head + pos);
            b.store(second + pos);
            c.store(third + pos);
            d.store(fourth + pos);
            factor_ptr += 8;
        }
        difRecS(head, quarter * 2);
        difRecS(third, quarter);
        difRecS(fourth, quarter);
    }
    else
    {
        difMid<FROM_RIRI_PERM>(in_out, float_len);
    }
}
template <bool TO_RIRI_PERM = false, bool TO_INT64 = false>
static void iditRecS(Float64 in_out[], size_t float_len)
{
    // Inverse transform mirroring difRecS. Sizes up to MID_LEN go to iditMid
    // together with the output options. Larger sizes recurse on the first
    // half and the last two quarters, then run one merging pass
    // (conj-twiddles, iditSplit), optional int64 conversion and the
    // optional RIRI store.
    const size_t fft_len = float_len / 2;
    assert(fft_len <= MAX_LEN);
    if (fft_len > MID_LEN)
    {
        using ToRIRI = std::integral_constant<bool, TO_RIRI_PERM>;
        using ToI64 = std::integral_constant<bool, TO_INT64>;
        const size_t quarter = float_len / 4;
        Float64 *const head = in_out;
        Float64 *const second = head + quarter;
        Float64 *const third = second + quarter;
        Float64 *const fourth = third + quarter;
        iditRecS(head, quarter * 2);
        iditRecS(third, quarter);
        iditRecS(fourth, quarter);
        auto factor_ptr = multi_table_1.getBegin(fft_len);
        for (size_t pos = 0; pos < quarter; pos += 8)
        {
            Complex64X4 a, b, c, d, w;
            a.load(head + pos);
            b.load(second + pos);
            c.load(third + pos);
            d.load(fourth + pos);
            w.load(factor_ptr);
            c = c.mulConj(w);
            d = d.mulConj(w.cube());
            iditSplit(a.real, a.imag, b.real, b.imag, c.real, c.imag, d.real, d.imag);
            a = a.transToI64(ToI64{});
            b = b.transToI64(ToI64{});
            c = c.transToI64(ToI64{});
            d = d.transToI64(ToI64{});
            a.store(head + pos, ToRIRI{});
            b.store(second + pos, ToRIRI{});
            c.store(third + pos, ToRIRI{});
            d.store(fourth + pos, ToRIRI{});
            factor_ptr += 8;
        }
    }
    else
    {
        iditMid<TO_RIRI_PERM, TO_INT64>(in_out, float_len);
    }
}
};
// Out-of-class definitions of FFTAVX's static data members.
// The constexpr size constants are re-declared here so that they have a
// definition if they are ODR-used (e.g. bound to a const reference).
// This is required before C++17. From C++17 on, static constexpr data
// members are implicitly inline, so these lines are redundant (deprecated
// but harmless).
constexpr size_t FFTAVX::LOG_SHORT;
constexpr size_t FFTAVX::LOG_MID;
constexpr size_t FFTAVX::LOG_MAX;
constexpr size_t FFTAVX::SHORT_LEN;
constexpr size_t FFTAVX::MID_LEN;
constexpr size_t FFTAVX::MAX_LEN;
// Precomputed factor tables. Within one translation unit, namespace-scope
// objects are initialized in definition order, so keep this order if any
// table's construction ever depends on another.
// The meaning of the constructor arguments is defined by TableFix /
// TableFixMulti, which are not visible here.
// NOTE(review): the names suggest table_N_k holds length-N factors for
// exponent multiplier k - confirm against TableFix.
const TableFix<Float64, 4> FFTAVX::table_8(8, 1, 4);
const TableFix<Float64, 4> FFTAVX::table_16_1(16, 1, 4);
const TableFix<Float64, 4> FFTAVX::table_16_3(16, 3, 4);
const TableFix<Float64, 8> FFTAVX::table_32_1(32, 1, 4);
const TableFix<Float64, 8> FFTAVX::table_32_3(32, 3, 4);
// multi_table_2/3 are sized by LOG_SHORT. multi_table_1 is sized by LOG_MAX
// and is the table the recursive transforms read via getBegin(fft_len).
const TableFixMulti<Float64, 6, FFTAVX::LOG_SHORT, 4> FFTAVX::multi_table_3(3);
const TableFixMulti<Float64, 6, FFTAVX::LOG_SHORT, 4> FFTAVX::multi_table_2(2);
const TableFixMulti<Float64, 6, FFTAVX::LOG_MAX, 4> FFTAVX::multi_table_1(1);
// Reverses the order of all 32 bits of n (bit 0 <-> bit 31, bit 1 <-> bit 30, ...).
// It first swaps ever larger neighbouring groups inside each 16-bit half
// (single bits, pairs, nibbles, bytes), then swaps the two halves.
constexpr uint32_t bitrev32(uint32_t n)
{
    constexpr uint32_t group_masks[] = {0x55555555u, 0x33333333u, 0x0f0f0f0fu, 0x00ff00ffu};
    uint32_t width = 1;
    for (uint32_t mask : group_masks)
    {
        n = ((n >> width) & mask) | ((n & mask) << width);
        width *= 2;
    }
    return (n >> 16) | (n << 16);
}
// Reverses the low `len` bits of n. Bits of n at positions >= len are ignored.
// Precondition: 0 <= len <= 32. bitrev(n, 0) is 0 (reversing an empty bit string).
constexpr uint32_t bitrev(uint32_t n, int len)
{
    assert(len >= 0 && len <= 32);
    // Shifting a uint32_t by 32 (len == 0) or by more (negative len) is undefined
    // behaviour, so those cases are handled before the shift.
    return len <= 0 ? 0u : bitrev32(n) >> (32 - len);
}
// Generator of C64X4 vectors of 4 complex factors, consumed by real_dot_binrev4.
// table[0] holds four base factors. Each iterate() call returns table[0]
// multiplied (lane-wise, via C64X4::mul) by the product of units[j + 2] over the
// set bits j of an internal counter. The counter starts at 0, or at i / BLOCK
// after reset(i).
// table[] is used as a stack of partial products (pop = top). Each returned
// vector is therefore reached from table[0] by one multiplication per set bit of
// the counter, rather than by one long running product.
// NOTE(review): not thread-safe; iterate()/reset() mutate the object.
class BinRevTableC64X4HP
{
public:
using F64 = double;
using C64 = std::complex<F64>;
using C64X4 = hint_simd::Complex64X4;
static constexpr int MAX_LOG_LEN = 32, LOG_BLOCK = 2, BLOCK = 1 << LOG_BLOCK;
static constexpr size_t MAX_LEN = size_t(1) << MAX_LOG_LEN;
// The two logs only set an angle scale:
// factor = 1 / 2^(log_fft_len - log_max_iter) multiplies every angle below.
// units[i] = getOmega(2^(i + 1), 1, factor).
// table[0] = {1, getOmega(4, bitrev(k, 2), factor) for k = 1, 2, 3}. It is written
// through a double pointer as 4 real parts followed by the 4 imaginary parts.
BinRevTableC64X4HP(int log_max_iter_in, int log_fft_len_in)
: index(0), pop(0), log_max_iter(log_max_iter_in), log_fft_len(log_fft_len_in)
{
assert(log_max_iter <= log_fft_len);
assert(log_fft_len <= MAX_LOG_LEN);
const F64 factor = F64(1) / (size_t(1) << (log_fft_len - log_max_iter));
for (int i = 0; i < MAX_LOG_LEN; i++)
{
units[i] = getOmega(size_t(1) << (i + 1), 1, factor);
}
auto fp = reinterpret_cast<F64 *>(table);
fp[0] = 1, fp[BLOCK] = 0;
for (int i = 1; i < BLOCK; i++)
{
C64 omega = getOmega(BLOCK, bitrev(i, LOG_BLOCK), factor);
fp[i] = omega.real(), fp[i + BLOCK] = omega.imag();
}
}
// Restart the counter.
// i == 0: back to the initial state (the next iterate() returns table[0]).
// Otherwise i must be a power of 2 and a multiple of BLOCK (both asserted). The
// counter becomes i / BLOCK, and table[1] is rebuilt as
// broadcast(units[ctz(i / BLOCK) + 2]) * table[0], so the next iterate()
// returns that vector.
// (load1 presumably broadcasts one complex value to all 4 lanes - confirm.)
void reset(size_t i = 0)
{
if (i == 0)
{
pop = 0, index = i;
return;
}
assert((i & (i - 1)) == 0);
assert(i % BLOCK == 0);
pop = 1, index = i / BLOCK;
int zero = hint_ctz(index);
auto fp = reinterpret_cast<F64 *>(&units[zero + 2]);
table[1].load1(fp, fp + 1);
table[1] = table[1].mul(table[0]);
}
// Returns the vector for the current counter value, then advances the counter.
// The trailing-zero count of the new counter value says how many stack levels
// to pop. The next entry is then table[pop] * broadcast(units[ctz + 2]).
// NOTE(review): units has MAX_LOG_LEN entries and is indexed with ctz + 2, so
// the counter must stay below 2^(MAX_LOG_LEN - 2); this is not checked.
C64X4 iterate()
{
C64X4 res = table[pop], unit4;
index++;
int zero = hint_ctz(index);
auto fp = reinterpret_cast<F64 *>(&units[zero + 2]);
unit4.load1(fp, fp + 1);
pop -= zero;
table[pop + 1] = table[pop].mul(unit4);
pop++;
return res;
}
// Returns polar(1, (-HINT_2PI * index / n) * factor). With HINT_2PI presumably
// equal to 2*pi, this is exp(-2*pi*i * factor * index / n).
static C64 getOmega(size_t n, size_t index, F64 factor = 1)
{
F64 theta = -HINT_2PI * index / n;
return std::polar<F64>(1, theta * factor);
}
private:
C64 units[MAX_LOG_LEN]{};  // units[i] = getOmega(2^(i + 1), 1, factor)
C64X4 table[MAX_LOG_LEN]{}; // stack of partial products; table[0] is the base vector
size_t index;               // iteration counter
int pop;                    // index of the entry iterate() returns next
int log_max_iter, log_fft_len; // only read by the constructor
};
// Scalar kernel used by real_dot_binrev4 for one mirrored pair of spectrum entries.
// Each complex entry is stored split: real part at p[0], imaginary part at p[RI_DIFF].
// (inout0, inout1) is the pair from the first array and (in0, in1) the matching pair
// from the second. The result overwrites inout0 / inout1; in0 / in1 are only read.
// omega0 is the rotation used for this pair; factor scales the pointwise products.
// NOTE(review): this looks like the usual real-FFT "split, multiply pointwise,
// re-pack" step of a real convolution - confirm against the transforms.
template <size_t RI_DIFF = 1, typename FloatTy>
inline void dot_rfft(FloatTy *inout0, FloatTy *inout1, const FloatTy *in0, const FloatTy *in1,
const std::complex<FloatTy> &omega0, const FloatTy factor = 1)
{
using Complex = std::complex<FloatTy>;
// Expanding the components: mul1(c0, c1) == -i * c0 * c1.
auto mul1 = [](Complex c0, Complex c1)
{
return Complex(c0.imag() * c1.real() + c0.real() * c1.imag(),
c0.imag() * c1.imag() - c0.real() * c1.real());
};
// Expanding the components: mul2(c0, c1) == i * c0 * conj(c1).
auto mul2 = [](Complex c0, Complex c1)
{
return Complex(c0.real() * c1.imag() - c0.imag() * c1.real(),
c0.real() * c1.real() + c0.imag() * c1.imag());
};
// Conjugate the mirror entry and combine it with in0 via transform2 (defined
// elsewhere; presumably an in-place sum/difference butterfly). Rotate the second
// output with Func(., omega0), then emit out0 = a + b and out1 = conj(a - b).
auto compute2 = [&omega0](Complex in0, Complex in1, Complex &out0, Complex &out1, auto Func)
{
in1 = std::conj(in1);
transform2(in0, in1);
in1 = Func(in1, omega0);
out0 = in0 + in1;
out1 = std::conj(in0 - in1);
};
Complex c0, c1;
{
Complex x0, x1, x2, x3;
// Split both pairs with mul1 (rotation by omega0)...
c0.real(inout0[0]), c0.imag(inout0[RI_DIFF]), c1.real(inout1[0]), c1.imag(inout1[RI_DIFF]);
compute2(c0, c1, x0, x1, mul1);
c0.real(in0[0]), c0.imag(in0[RI_DIFF]), c1.real(in1[0]), c1.imag(in1[RI_DIFF]);
compute2(c0, c1, x2, x3, mul1);
// ...multiply pointwise with the scale factor...
x0 *= x2 * factor;
x1 *= x3 * factor;
// ...and recombine with mul2 (rotation by conj(omega0)).
compute2(x0, x1, c0, c1, mul2);
}
inout0[0] = c0.real(), inout0[RI_DIFF] = c0.imag();
inout1[0] = c1.real(), inout1[RI_DIFF] = c1.imag();
}
// Four-lane counterpart of dot_rfft<4> used by real_dot_binrev4.
// inout0 / in0 point at one group of 4 complex entries; inout1 / in1 point at
// the mirrored group. All are loaded with C64X4::load, so real and imaginary
// parts are presumably kept split. Results overwrite inout0 / inout1; in0 / in1
// are only read.
// omega0 holds one rotation per lane; inv scales the pointwise products.
// NOTE(review): the mirror group is lane-reversed with C64X4::reverse() on load
// and store, presumably so lane j meets its mirror partner - confirm.
inline void dot_rfftX4(F64 *inout0, F64 *inout1, const F64 *in0, const F64 *in1, const C64X4 &omega0, const F64X4 &inv)
{
// Same formula as dot_rfft's mul1 (== -i * c0 * c1), assuming fmadd(a, b, c) = a*b + c
// and fmsub(a, b, c) = a*b - c.
auto mul1 = [](C64X4 c0, C64X4 c1)
{
return C64X4(F64X4::fmadd(c0.imag, c1.real, c0.real * c1.imag),
F64X4::fmsub(c0.imag, c1.imag, c0.real * c1.real));
};
// Same formula as dot_rfft's mul2 (== i * c0 * conj(c1)).
auto mul2 = [](C64X4 c0, C64X4 c1)
{
return C64X4(F64X4::fmsub(c0.real, c1.imag, c0.imag * c1.real),
F64X4::fmadd(c0.real, c1.real, c0.imag * c1.imag));
};
// t0 = c0 + conj(c1), t1 = c0 - conj(c1); rotate t1 with Func;
// out0 = t0 + t1, out1 = conj(t0 - t1).
// c0 / c1 are taken by value, so out0 / out1 may alias the caller's inputs.
auto compute2 = [&omega0](C64X4 c0, C64X4 c1, C64X4 &out0, C64X4 &out1, auto Func)
{
C64X4 t0(c0.real + c1.real, c0.imag - c1.imag), t1(c0.real - c1.real, c0.imag + c1.imag);
t1 = Func(t1, omega0);
out0 = t0 + t1;
out1.real = t0.real - t1.real;
out1.imag = t1.imag - t0.imag;
};
C64X4 c0, c1;
{
C64X4 x0, x1, x2, x3;
c0.load(inout0), c1.load(inout1);
compute2(c0, c1.reverse(), x0, x1, mul1);
c0.load(in0), c1.load(in1);
compute2(c0, c1.reverse(), x2, x3, mul1);
// Scale after the complex product; dot_rfft scales x2 first instead, so the
// two versions may round differently in the last bit.
c0 = x0.mul(x2) * inv;
c1 = x1.mul(x3) * inv;
compute2(c0, c1, c0, c1, mul2);
}
c0.store(inout0), c1.reverse().store(inout1);
}
// Pointwise product, in the transformed domain, of two real sequences of
// float_len doubles. Both arrays are expected to hold the output of the forward
// transforms above (difRec<true> / difLarge<true> in this file). The scaled
// result overwrites in_out; `in` is only read.
// Entries appear to be grouped 8 doubles at a time: 4 real parts, then the 4
// matching imaginary parts (hence the [k] / [k + 4] pairs and dot_rfft<4>) - confirm.
// Requires float_len >= 16 (fixed indices up to 15 below).
// NOTE(review): `table` is a function-local static that reset()/iterate() mutate
// on every call, so concurrent calls race.
inline void real_dot_binrev4(Float64 in_out[], Float64 in[], size_t float_len)
{
using Complex = std::complex<Float64>;
Float64 inv = 2.0 / float_len;
{
// Entries 0 and 4 are combined with transform2 (defined elsewhere; presumably a
// sum/difference butterfly), multiplied as two reals, transformed back, and
// scaled by 0.5 * inv == 1 / float_len.
auto r0 = in_out[0], i0 = in_out[4], r1 = in[0], i1 = in[4];
transform2(r0, i0);
transform2(r1, i1);
r0 *= r1, i0 *= i1;
transform2(r0, i0);
in_out[0] = r0 * 0.5 * inv, in_out[4] = i0 * 0.5 * inv;
}
// Entries (1, 5) form one complex value: plain complex product scaled by inv.
auto temp = Complex(in_out[1], in_out[5]) * Complex(in[1], in[5]) * inv;
in_out[1] = temp.real(), in_out[5] = temp.imag();
// The remaining entries go through dot_rfft / dot_rfftX4 with inv / 8.
// Presumably this compensates for the unnormalized sum/difference steps inside
// them - confirm.
inv /= 8;
static BinRevTableC64X4HP table(31, 32);
dot_rfft<4>(&in_out[2], &in_out[3], &in[2], &in[3], Complex(COS_PI_8, -COS_PI_8), inv);
// cos(pi/8) and sin(pi/8).
constexpr Float64 COS_16_1 = 0.92387953251128675612818318939;
constexpr Float64 SIN_16_1 = 0.38268343236508977172845998403;
dot_rfft<4>(&in_out[8], &in_out[11], &in[8], &in[11], C64(COS_16_1, -SIN_16_1), inv);
dot_rfft<4>(&in_out[9], &in_out[10], &in[9], &in[10], C64(-SIN_16_1, -COS_16_1), inv);
const Float64X4 inv4 = F64X4(inv);
// For each region [begin, 2 * begin) (begin = 16, 32, ...), pair the 8-double
// groups from both ends inward. Rotations come from `table`, restarted at
// begin / 2 for each region.
for (size_t begin = 16; begin < float_len; begin *= 2)
{
table.reset(begin / 2);
auto it0 = in_out + begin, it1 = it0 + begin - 8, it2 = in + begin, it3 = it2 + begin - 8;
for (; it0 < it1; it0 += 8, it1 -= 8, it2 += 8, it3 -= 8)
{
auto omega = table.iterate();
dot_rfftX4(it0, it1, it2, it3, omega, inv4);
}
}
}
template <bool TO_INT = false>
inline void real_conv_avx(F64 *in_out1, F64 *in2, size_t float_len)
{
    // FFT-based convolution of two real sequences of float_len doubles using
    // FFTAVX::difRec / iditRec. The result overwrites in_out1. in2 is left
    // holding its forward transform. With TO_INT the inverse pass is asked to
    // emit int64 values.
    assert(is_2pow(float_len));
    for (F64 *operand : {in_out1, in2})
    {
        FFTAVX::difRec<true>(operand, float_len);
    }
    real_dot_binrev4(in_out1, in2, float_len);
    FFTAVX::iditRec<true, TO_INT>(in_out1, float_len);
}
template <bool TO_INT = false>
inline void real_conv_avxS(F64 *in_out1, F64 *in2, size_t float_len)
{
    // FFT-based convolution of two real sequences of float_len doubles using
    // FFTAVX::difLarge / iditLarge. The result overwrites in_out1. in2 is left
    // holding its forward transform. With TO_INT the inverse pass is asked to
    // emit int64 values.
    assert(is_2pow(float_len));
    for (F64 *operand : {in_out1, in2})
    {
        FFTAVX::difLarge<true>(operand, float_len);
    }
    real_dot_binrev4(in_out1, in2, float_len);
    FFTAVX::iditLarge<true, TO_INT>(in_out1, float_len);
}
}
}
}
namespace string_util{
using namespace hint;
using namespace transform;
using namespace fft;
// Lookup table mapping 0..9999 to its four zero-padded ASCII decimal digits.
// The digits are packed into a uint32_t so that copying it to memory on a
// little-endian target puts them in reading order.
class ItoStrBase10000{
private:
    uint32_t table[10000]{};
public:
    // Packs the last four decimal digits of num as ASCII: thousands digit in the
    // lowest byte, units digit in the highest byte.
    static constexpr uint32_t itosbase10000(uint32_t num){
        uint32_t packed = 0;
        for (int byte_shift = 24; byte_shift >= 0; byte_shift -= 8){
            packed |= (num % 10 + '0') << byte_shift;
            num /= 10;
        }
        return packed;
    }
    constexpr ItoStrBase10000(){
        for (uint32_t value = 0; value < 10000; value++){
            table[value] = itosbase10000(value);
        }
    }
    // Writes the four digit characters of num (must be < 10000) to str, with no terminator.
    void tostr(char *str, uint32_t num) const{
        const uint32_t &digits = table[num];
        std::memcpy(str, &digits, sizeof(digits));
    }
    // Returns the packed digits of num (must be < 10000).
    uint32_t tostr(uint32_t num) const{
        return table[num];
    }
};
// Lookup table that parses two ASCII decimal digits into their value 0..99.
// The two characters are read as one uint16_t (first character in the low byte
// on a little-endian target), and that value is used directly as the table index.
class StrtoIBase100{
private:
    // One entry for every possible uint16_t, so any two input bytes give an
    // in-range index. The former size, 1 << 15, read past the table whenever
    // the second byte had its top bit set (non-ASCII input, or any byte >= 0x80
    // on targets where char is unsigned).
    static constexpr size_t TABLE_SIZE = size_t(1) << 16;
    uint16_t table[TABLE_SIZE]{};
public:
    // Packs num's tens and units digits as two ASCII characters (tens in the low byte).
    static constexpr uint16_t itosbase100(uint16_t num){
        uint16_t res = (num / 10 % 10) | ((num % 10) << 8);
        return res + '0' * 0x0101;}
    constexpr StrtoIBase100(){
        for (size_t i = 0; i < TABLE_SIZE; i++){
            table[i] = UINT16_MAX;
        }
        for (size_t i = 0; i < 100; i++){
            table[itosbase100(i)] = i;
        }}
    // Returns the value of the two characters at str as a decimal number 0..99,
    // or UINT16_MAX if either character is not a decimal digit.
    uint16_t toInt(const char *str) const{
        uint16_t num;
        std::memcpy(&num, str, sizeof(num));
        return table[num];}
};
// Shared lookup tables. They are constexpr, so both are fully built at compile
// time (no start-up cost, no initialization-order concerns). Each is a const
// object with internal linkage.
constexpr ItoStrBase10000 itosbase10000{};
constexpr StrtoIBase100 strtoibase100{};
// Parses the four ASCII decimal digits at s into a value 0..9999 (two table
// lookups, high pair first). A non-digit turns its pair into UINT16_MAX, so the
// result is meaningless for such input.
inline uint32_t stobase10000(const char *s){
    const uint32_t high_pair = strtoibase100.toInt(s);
    const uint32_t low_pair = strtoibase100.toInt(s + 2);
    return high_pair * 100 + low_pair;
}
// Owning, move-only buffer of `len` elements of T. Storage comes from
// _mm_malloc with ALIGN-byte alignment and is released with _mm_free.
// Elements are NOT initialized.
// Copying is deleted because two owners would both _mm_free the same pointer.
template <typename T, size_t ALIGN = 64>
class AlignMem{
public:
    using Ptr = T *;
    using ConstPtr = const T *;
    // Empty buffer: begin() == end() == nullptr, size() == 0.
    AlignMem() = default;
    // Allocates uninitialized, ALIGN-aligned storage for n elements.
    explicit AlignMem(size_t n) : ptr(static_cast<Ptr>(_mm_malloc(n * sizeof(T), ALIGN))), len(n){
        assert(ptr != nullptr || n == 0);
    }
    AlignMem(const AlignMem &) = delete;
    AlignMem &operator=(const AlignMem &) = delete;
    // Takes over other's storage; other is left empty.
    AlignMem(AlignMem &&other) noexcept : ptr(other.ptr), len(other.len){
        other.ptr = nullptr;
        other.len = 0;
    }
    // Frees the current storage, then takes over other's; other is left empty.
    AlignMem &operator=(AlignMem &&other) noexcept{
        if (this != &other){
            release();
            ptr = other.ptr;
            len = other.len;
            other.ptr = nullptr;
            other.len = 0;
        }
        return *this;
    }
    ~AlignMem(){
        release();
    }
    size_t size() const{
        return len;}
    T &operator[](size_t i){
        return ptr[i];}
    const T &operator[](size_t i) const{
        return ptr[i];}
    Ptr begin(){
        return ptr;}
    Ptr end(){
        return ptr + len;}
    ConstPtr begin() const{
        return ptr;}
    ConstPtr end() const{
        return ptr + len;}
private:
    // Frees the storage (if any) and leaves the object empty.
    void release(){
        if (ptr){
            _mm_free(ptr);
        }
        ptr = nullptr;
        len = 0;
    }
    T *ptr = nullptr;
    size_t len = 0;
};
// Sets every element in [begin, end) to all-zero bytes via memset. That is the
// value 0 for integers and +0.0 for IEEE floating point.
// An empty or reversed range (including nullptr, nullptr) is a no-op. The old
// code passed a negative difference to memset as a huge size, and called memset
// on a null pointer.
template <typename T>
void fill_zero(T *begin, T *end){
    static_assert(std::is_trivially_copyable<T>::value, "fill_zero zeroes raw bytes with memset");
    if (begin < end){
        std::memset(begin, 0, static_cast<size_t>(end - begin) * sizeof(T));
    }
}
// Splits a decimal digit string of length len into base-10000 blocks, most
// significant first, written to ary[0 .. ceil(len / 4) - 1].
// A short final block is right-padded with zeros (its value is multiplied by 10
// per missing digit). Returns the number of padding digits added (0..3), which
// the caller must later strip from the end of the product.
template <typename T>
size_t str_num_to_array_base10000(const char *str, size_t len, T *ary){
    constexpr size_t BLOCK = 4;
    const size_t full_blocks = len / BLOCK;
    const size_t tail = len % BLOCK;
    for (size_t b = 0; b < full_blocks; b++){
        ary[b] = stobase10000(str + b * BLOCK);
    }
    if (tail == 0){
        return 0;
    }
    int value = 0;
    for (const char *c = str + full_blocks * BLOCK; c < str + len; c++){
        value = value * 10 + *c - '0';
    }
    for (size_t filled = tail; filled < BLOCK; filled++){
        value *= 10;
    }
    ary[full_blocks] = value;
    return BLOCK - tail;
}
// Turns conv_len base-10000 convolution coefficients into decimal text.
// ary[0] is the most significant coefficient and ary[conv_len - 1] the least.
// Carries are propagated from the least significant end. Digits are written to
// res[0 .. (conv_len + 1) * 4), which must be writable.
// shift is the number of trailing padding digits to drop (see
// str_num_to_array_base10000).
// Returns the offset in res of the first significant digit and sets res_len to
// the number of digits to print. A zero product yields the single digit "0".
template <typename T>
size_t conv_to_str_base10000(const T *ary, size_t conv_len, size_t shift, char *res, size_t &res_len){
    constexpr size_t BLOCK = 4, BASE = 10000;
    res_len = (conv_len + 1) * BLOCK;
    char *const written_end = res + res_len;
    char *pos = written_end;
    uint64_t carry = 0;
    size_t i = conv_len;
    while (i > 0){
        i--;
        pos -= BLOCK;
        carry += ary[i];
        itosbase10000.tostr(pos, static_cast<uint32_t>(carry % BASE));
        carry /= BASE;
    }
    assert(carry < BASE);
    pos -= BLOCK;
    itosbase10000.tostr(pos, static_cast<uint32_t>(carry));
    // Skip leading zeros, but never past the digits written above.
    while (pos < written_end && *pos == '0'){
        pos++;
    }
    if (pos == written_end){
        // Every digit is '0', so the product is zero. Report the last written
        // '0'. The old code kept scanning into the caller's buffer and then
        // underflowed res_len by subtracting `shift`.
        res_len = 1;
        return static_cast<size_t>(written_end - res) - 1;
    }
    const size_t offset = static_cast<size_t>(pos - res);
    res_len -= offset + shift;
    return offset;
}
// Multiplies two non-negative decimal numbers given as digit strings (no
// terminator needed) and writes the product's digits into res.
// res must have room for preserve_strlen(len1, len2) characters. It may overlap
// str1 / str2: both are fully parsed before anything is written to res.
// Returns a pointer to the first digit of the product inside res; res_len
// receives the number of digits.
// If either operand is empty, nothing is written and res_len is 0, which callers
// should print as "0" (as mul() does).
char *big_mul(const char *str1, size_t len1, const char *str2, size_t len2, char *res, size_t &res_len){
    constexpr size_t BLOCK = 4;
    if (len1 == 0 || len2 == 0){
        // With an empty operand, conv_len below would underflow (both empty) or
        // could exceed the transform length (one empty).
        res_len = 0;
        return res;
    }
    const size_t block_len1 = (len1 + BLOCK - 1) / BLOCK, block_len2 = (len2 + BLOCK - 1) / BLOCK;
    const size_t conv_len = block_len1 + block_len2 - 1;
    const size_t fft_len = std::max<size_t>(hint::int_ceil2(conv_len), 256);
    AlignMem<Float64> ary1(fft_len), ary2(fft_len);
    size_t shift = str_num_to_array_base10000(str1, len1, &ary1[0]);
    shift += str_num_to_array_base10000(str2, len2, &ary2[0]);
    fill_zero(ary1.begin() + block_len1, ary1.end());
    fill_zero(ary2.begin() + block_len2, ary2.end());
    // TO_INT = true: the inverse transform is asked to leave integer (int64)
    // values in ary1's storage, which is then read as uint64_t below.
    real_conv_avxS<true>(ary1.begin(), ary2.begin(), fft_len);
    const uint64_t *coefficients = reinterpret_cast<const uint64_t *>(ary1.begin());
    return res + conv_to_str_base10000(coefficients, conv_len, shift, res, res_len);
}
// Number of characters big_mul may write into its result buffer for operands of
// len1 and len2 digits: one 4-digit block per base-10000 block of each operand.
size_t preserve_strlen(size_t len1, size_t len2){
    constexpr size_t BLOCK = 4;
    auto blocks_of = [](size_t digits){ return (digits + BLOCK - 1) / BLOCK; };
    return BLOCK * (blocks_of(len1) + blocks_of(len2));
}
// Returns the length of the run of ASCII decimal digits ('0'..'9') starting at str.
// Stops at the first non-digit, including the terminating '\0'. The old test
// (*str >= '0') also accepted ':' .. '~', which were then parsed as garbage digits.
size_t digit_strlen(const char *str){
    const char *p = str;
    while ('0' <= *p && *p <= '9'){
        p++;
    }
    return static_cast<size_t>(p - str);
}
// Reads two non-negative decimal integers separated by non-digit characters
// from stdin and writes their product to stdout. No trailing newline is
// printed, except for the "0" printed when big_mul reports an empty result.
// The input buffer doubles as the output buffer for big_mul.
void mul(){
    constexpr size_t STR_LEN = 2000008;
    // The initializer is the input used when stdin is empty.
    static char str[STR_LEN] = "0 10";
    // Leave room for a terminator: without one, the digit and separator scans
    // below could run off the end of the buffer.
    const size_t read_len = fread(str, 1, STR_LEN - 1, stdin);
    if (read_len != 0){
        // Terminate right after the data actually read, so leftovers of the
        // initializer are not parsed (e.g. "5 6" over "0 10" used to read as "5 60").
        str[read_len] = '\0';
    }
    // Advances past separator characters, never past the terminator.
    auto skip_separators = [](char *p){
        while (*p != '\0' && *p < '0'){
            p++;
        }
        return p;
    };
    char *s1 = skip_separators(str);
    const size_t len1 = digit_strlen(s1);
    char *s2 = skip_separators(s1 + len1);
    const size_t len2 = digit_strlen(s2);
    size_t res_len = preserve_strlen(len1, len2);
    char *begin = big_mul(s1, len1, s2, len2, str, res_len);
    if (res_len == 0){
        puts("0");
    }
    fwrite(begin, 1, res_len, stdout);
}
}
// Program entry: multiply the two numbers on stdin and print the product.
int main(){
    string_util::mul();
    return 0;
}
// Online-judge result for this submission (pasted from the judge page):
// Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// Testcase #1 | 8.704 ms | 13 MB + 888 KB | Accepted | Score: 100 | Show more |