#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <algorithm>
#include <array>
#include <atomic>
#include <complex>
#include <future>
#include <iostream>
#include <stdexcept>
#include <string>
#include <thread>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
// Code-generation pragmas: allow FMA and AVX2 instructions and optimize everything
// below with -Ofast (which implies fast-math) plus inlining.
#pragma GCC target("fma")
#pragma GCC target("avx2")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("inline")
#define TABLE_ENABLE 1 // Whether to use the lookup table (not referenced by the code in this file)
#define MULTITHREAD 0 // Multithreading: 0 means no, 1 means yes
#define TABLE_PRELOAD 0 // Fill the twiddle table eagerly in its constructor: 0 means no, 1 means yes
#if MULTITHREAD == 1
// Force the lookup table on whenever multithreading is enabled.
#define TABLE_ENABLE 1
#endif
namespace hint
{
// ---- Fixed-width integer aliases ----
using UINT_8 = uint8_t;
using UINT_16 = uint16_t;
using UINT_32 = uint32_t;
using UINT_64 = uint64_t;
using INT_32 = int32_t;
using INT_64 = int64_t;
using ULONG = unsigned long;
using LONG = long;
// ---- Scalar and complex types used by the transforms ----
using HintFloat = double;
using Complex = std::complex<HintFloat>;
// ---- Bit widths ----
constexpr UINT_64 HINT_CHAR_BIT = 8;
constexpr UINT_64 HINT_SHORT_BIT = 16;
constexpr UINT_64 HINT_INT_BIT = 32;
// ---- Masks and limits, all widened to UINT_64 ----
// *_0XFF: all-ones mask of that width; *_0X10: 2^width (one past the maximum);
// *_0X80: only the sign bit set; *_0X7F: largest signed value.
constexpr UINT_64 HINT_INT8_0XFF = UINT8_MAX;
constexpr UINT_64 HINT_INT8_0X10 = UINT_64(UINT8_MAX) + 1;
constexpr UINT_64 HINT_INT16_0XFF = UINT16_MAX;
constexpr UINT_64 HINT_INT16_0X10 = UINT_64(UINT16_MAX) + 1;
constexpr UINT_64 HINT_INT32_0XFF = UINT32_MAX;
constexpr UINT_64 HINT_INT32_0X01 = 1;
constexpr UINT_64 HINT_INT32_0X80 = UINT_64(1) << 31;
constexpr UINT_64 HINT_INT32_0X7F = INT32_MAX;
constexpr UINT_64 HINT_INT32_0X10 = UINT_64(UINT32_MAX) + 1;
constexpr UINT_64 HINT_INT64_0X80 = UINT_64(1) << 63;
constexpr UINT_64 HINT_INT64_0X7F = INT64_MAX;
constexpr UINT_64 HINT_INT64_0XFF = UINT64_MAX;
// ---- Floating-point constants ----
constexpr double HINT_PI = 3.1415926535897932384626433832795;
constexpr double HINT_2PI = HINT_PI * 2;
constexpr double HINT_HSQ_ROOT2 = 0.70710678118654752440084436210485; // sqrt(2) / 2
// ---- NTT parameters (presumably modulus / root pairs; not used in this file) ----
constexpr UINT_64 NTT_MOD1 = 3221225473;
constexpr UINT_64 NTT_ROOT1 = 5;
constexpr UINT_64 NTT_MOD2 = 3489660929;
constexpr UINT_64 NTT_ROOT2 = 3;
constexpr size_t NTT_MAX_LEN = size_t(1) << 28;
// Fixed-size array of LEN elements whose storage is aligned to a 128-byte boundary.
template <typename T, size_t LEN>
class AlignAry
{
private:
    alignas(128) T ary[LEN];

public:
    constexpr AlignAry() {}

    // Unchecked element access.
    constexpr T &operator[](size_t index)
    {
        return *(ary + index);
    }
    constexpr const T &operator[](size_t index) const
    {
        return *(ary + index);
    }

    // Pointer to the first element; the array decays directly, no cast needed.
    T *data()
    {
        return ary;
    }
    const T *data() const
    {
        return ary;
    }

    // Reinterpret the storage as Ty*. The caller is responsible for alignment
    // and aliasing validity of whatever it accesses through the result.
    template <typename Ty>
    Ty *cast_ptr()
    {
        return reinterpret_cast<Ty *>(data());
    }
    template <typename Ty>
    const Ty *cast_ptr() const
    {
        return reinterpret_cast<const Ty *>(data());
    }
};
// Largest power of two not greater than n (for n > 0).
// Smears the highest set bit into every lower bit, so n becomes 2^(k+1) - 1 where
// 2^k is the highest set bit; (n >> 1) + 1 is then 2^k. Note: n == 0 yields 1.
template <typename T>
constexpr T int_floor2(T n)
{
    constexpr int width = sizeof(n) * 8;
    int shift = 1;
    while (shift < width)
    {
        n |= (n >> shift);
        shift <<= 1;
    }
    return (n >> 1) + 1;
}
// Smallest power of two not less than n (for n > 0).
// Decrement, smear the highest set bit downward, then add one back.
// For unsigned n == 0 the decrement wraps, so the result wraps back to 0.
template <typename T>
constexpr T int_ceil2(T n)
{
    constexpr int width = sizeof(n) * 8;
    n--;
    for (int shift = 1; shift < width; shift += shift)
    {
        n |= (n >> shift);
    }
    return n + 1;
}
// floor(log2(n)) for integral n > 0; returns -1 for n <= 0.
// The search runs on the unsigned counterpart of T. The original shifted T(1) into
// the sign bit for signed T, which compared as negative and made e.g.
// hint_log2(INT32_MAX) return 31 instead of 30.
template <typename T>
constexpr int hint_log2(T n)
{
    using U = typename std::make_unsigned<T>::type;
    if (n <= T(0))
    {
        return -1;
    }
    const U un = static_cast<U>(n);
    constexpr int bits = sizeof(n) * 8;
    // Invariant: 2^l <= un < 2^r (with 2^-1 treated as 0 and 2^bits as infinity).
    int l = -1, r = bits;
    while ((l + 1) != r)
    {
        int mid = (l + r) / 2;
        if ((U(1) << mid) > un)
        {
            r = mid;
        }
        else
        {
            l = mid;
        }
    }
    return l;
}
// m raised to the n-th power by binary exponentiation. The base is squared once per
// bit of n and multiplied into the result whenever that bit is set. Only T's
// operator* and construction from 1 are used.
template <typename T>
constexpr T qpow(T m, UINT_64 n)
{
    T result = 1;
    for (; n != 0; n >>= 1)
    {
        if (n & 1)
        {
            result = result * m;
        }
        m = m * m;
    }
    return result;
}
// Quotient and remainder of a / b, returned as {a / b, a % b}.
template <typename T>
constexpr std::pair<T, T> div_mod(T a, T b)
{
    return {a / b, a % b};
}
#if MULTITHREAD == 1
// Number of worker threads. std::thread::hardware_concurrency() returns 0 when the
// value is unknown, so clamp to at least one.
const UINT_32 hint_threads = std::max<UINT_32>(std::thread::hardware_concurrency(), 1);
// ceil(log2(hint_threads)). The original applied std::ceil to hint_log2's integer
// (floor) result, which is a no-op. Rounding the count up to a power of two first
// yields the ceiling.
const UINT_32 log2_threads = hint_log2(int_ceil2(hint_threads));
std::atomic<UINT_32> cur_ths;
#endif
// Template array copy: copy `len` elements from `source` into `target`.
// A zero length, or target == source, is a no-op. The ranges are assumed not to
// overlap otherwise (std::copy_n copies front to back).
// Throws a const char* message if len >= INT64_MAX (kept for compatibility).
template <typename T>
void ary_copy(T *target, const T *source, size_t len)
{
    if (len == 0 || target == source)
    {
        return;
    }
    if (len >= INT64_MAX)
    {
        throw("Ary too long\n");
    }
    // The original called std::copy(target, source, len * sizeof(T)), which has the
    // (first, last, d_first) arguments scrambled and passes an integer as the output
    // iterator. Copy exactly `len` elements, not bytes.
    std::copy_n(source, len, target);
}
// FFT与类FFT变换的命名空间
namespace hint_transform
{
namespace hint_fft
{
// Point on the unit circle with argument theta (radians): cos(theta) + i*sin(theta).
static Complex unit_root(HintFloat theta)
{
    constexpr HintFloat radius = 1.0;
    return std::polar<HintFloat>(radius, theta);
}
// The n-th of m equally spaced points on the unit circle (argument 2*pi*n/m).
static Complex unit_root(size_t m, size_t n)
{
    const HintFloat theta = (HINT_2PI * n) / m;
    return unit_root(theta);
}
struct RIPtr
{
HintFloat *real = nullptr;
HintFloat *imag = nullptr;
constexpr RIPtr() {}
constexpr RIPtr(HintFloat *in_real, HintFloat *in_imag)
: real(in_real), imag(in_imag) {}
template <typename DataTy>
void load(DataTy &r, DataTy &i) const
{
r.load(real);
i.load(imag);
}
template <typename DataTy>
void save(const DataTy &r, const DataTy &i)
{
r.save(real);
i.save(imag);
}
void load(HintFloat &r, HintFloat &i) const
{
r = *real;
i = *imag;
}
void save(HintFloat r, HintFloat i)
{
*real = r;
*imag = i;
}
constexpr RIPtr operator+(size_t offset) const
{
return RIPtr(real + offset, imag + offset);
}
constexpr RIPtr operator-(size_t offset) const
{
return RIPtr(real - offset, imag - offset);
}
};
// Real part of (ar + i*ai) * (br + i*bi), i.e. ar*br - ai*bi.
// Works element-wise for lane types such as ComputeEnd2.
template <typename T>
constexpr T complex_mul_real(const T &ar, const T &ai, const T &br, const T &bi)
{
    return (ar * br) - (ai * bi);
}
// Imaginary part of (ar + i*ai) * (br + i*bi), i.e. ar*bi + ai*br.
template <typename T>
constexpr T complex_mul_imag(const T &ar, const T &ai, const T &br, const T &bi)
{
    return (ar * bi) + (ai * br);
}
// Twiddle-factor table for split-radix FFTs of up to 2^MAX_SHIFT points.
// For each level `shift` (transform length len = 2^shift) two sub-tables are kept:
//   table1 level: omega^k      for k in [0, len / 4)
//   table3 level: omega^(3k)   for k in [0, len / 4)
// with omega = e^{-2*pi*i / len}. A level starts at element 2^(shift-2) of its array;
// real parts live in the first half of the array and imaginary parts RI_DIS
// (= TABLE_LEN / 2) elements later. Levels above the initial cur_log_size (2) are
// filled lazily by expand()/expand_topdown().
template <UINT_32 MAX_SHIFT>
class ComplexTableS
{
public:
    enum
    {
        TABLE_LEN = (size_t(1) << MAX_SHIFT),
        RI_DIS = TABLE_LEN / 2
    };

private:
    AlignAry<HintFloat, TABLE_LEN> table1;
    AlignAry<HintFloat, TABLE_LEN> table3;
    INT_32 max_log_size = 2; // highest level expand() may fill
    INT_32 cur_log_size = 2; // highest level already filled
    static constexpr size_t FAC = 1;
    ComplexTableS(const ComplexTableS &) = delete;
    ComplexTableS &operator=(const ComplexTableS &) = delete;

public:
    // Prepare a table able to hold the roots for dividing the circle into 1 << MAX_SHIFT parts.
    constexpr ComplexTableS()
    {
        max_log_size = std::max<size_t>(MAX_SHIFT, 1);
        table1[0] = table1[1] = 1;
        table3[0] = table3[1] = 1;
#if TABLE_PRELOAD == 1
        expand(max_log_size);
#endif
    }
    // Make sure every level up to `shift` (clamped to max_log_size) is filled.
    constexpr void expand(INT_32 shift)
    {
        expand_topdown(shift);
    }
    // Compute level `shift` directly, then derive each lower unfilled level by taking
    // every other entry of the level above (omega_len^k == omega_{2len}^{2k}).
    constexpr void expand_topdown(INT_32 shift)
    {
        shift = std::min(shift, max_log_size);
        if (shift <= cur_log_size)
        {
            return;
        }
        size_t len = 1ull << shift, vec_size = len * FAC / 4;
        RIPtr ptr1 = get_omega_begin(shift);
        RIPtr ptr3 = get_omega3_begin(shift);
        ptr1.save(1.0, 0.0);
        ptr3.save(1.0, 0.0);
        // Exact cos/sin for power-of-two exponents below len / 8.
        for (size_t pos = 1; pos < vec_size / 2; pos *= 2)
        {
            HintFloat theta = -HINT_2PI * pos / len;
            HintFloat real = std::cos(theta), imag = std::sin(theta);
            (ptr1 + pos).save(real, imag);
        }
        // Every other exponent below len / 8: omega^pos = omega^(pos without its lowest
        // set bit) * omega^(lowest set bit). Both factors are already in the table.
        for (size_t pos = 1; pos < vec_size / 2; pos++)
        {
            size_t sub_pos = pos & (pos - 1);
            HintFloat real1, imag1;
            HintFloat real2, imag2;
            HintFloat real3, imag3;
            (ptr1 + sub_pos).load(real1, imag1);
            (ptr1 + pos - sub_pos).load(real2, imag2);
            real3 = complex_mul_real(real1, imag1, real2, imag2);
            imag3 = complex_mul_imag(real1, imag1, real2, imag2);
            (ptr1 + pos).save(real3, imag3);
        }
        // Mirror into (len/8, len/4): omega^(len/4 - pos) = -i * conj(omega^pos),
        // i.e. (re, im) -> (-im, -re).
        for (size_t pos = 1; pos < vec_size / 2; pos++)
        {
            HintFloat real, imag;
            (ptr1 + pos).load(real, imag);
            (ptr1 + vec_size - pos).save(-imag, -real);
        }
        // Cube table: omega^(3*pos) from the first table, and its mirror
        // omega^(3*(len/4 - pos)) = i * conj(omega^(3*pos)), i.e. (re, im) -> (im, re).
        for (size_t pos = 1; pos < vec_size / 2; pos++)
        {
            Complex tmp = get_omega(shift, pos * 3);
            (ptr3 + pos).save(tmp.real(), tmp.imag());
            (ptr3 + vec_size - pos).save(tmp.imag(), tmp.real());
        }
        // Midpoints: omega^(len/8) = e^{-i*pi/4} and omega^(3*len/8) = e^{-3i*pi/4}.
        Complex tmp = std::conj(unit_root(8, 1));
        (ptr1 + vec_size / 2).save(tmp.real(), tmp.imag());
        tmp = std::conj(unit_root(8, 3));
        (ptr3 + vec_size / 2).save(tmp.real(), tmp.imag());
        // Derive the lower levels that are not yet filled.
        for (INT_32 log = shift - 1; log > cur_log_size; log--)
        {
            len = 1ull << log, vec_size = len / 4;
            ptr1 = get_omega_begin(log);
            ptr3 = get_omega3_begin(log);
            RIPtr src1 = get_omega_begin(log + 1);
            RIPtr src3 = get_omega3_begin(log + 1);
            for (size_t pos = 0; pos < vec_size; pos++)
            {
                HintFloat r1, r3, i1, i3;
                (src1 + pos * 2).load(r1, i1);
                (src3 + pos * 2).load(r3, i3);
                (ptr1 + pos).save(r1, i1);
                (ptr3 + pos).save(r3, i3);
            }
        }
        cur_log_size = std::max(cur_log_size, shift);
    }
    // omega^n for the circle divided into 1 << shift parts, for n in [0, len / 2].
    // Only [0, len / 4) is stored; the upper quarter uses
    // omega^(len/2 - m) = -conj(omega^m), and n == len / 4 gives -i.
    constexpr Complex get_omega(UINT_32 shift, size_t n)
    {
        size_t vec_size = (size_t(1) << shift) / 4;
        RIPtr omg_ptr = get_omega_begin(shift);
        HintFloat real, imag;
        if (n < vec_size)
        {
            (omg_ptr + n).load(real, imag);
            return Complex(real, imag);
        }
        else if (n > vec_size)
        {
            (omg_ptr + vec_size * 2 - n).load(real, imag);
            return Complex(-real, imag);
        }
        else
        {
            return Complex(0, -1);
        }
    }
    // Start of level `shift` in table1 (real pointer, imaginary pointer).
    constexpr RIPtr get_omega_begin(UINT_32 shift)
    {
        HintFloat *ptr = table1.data() + (1 << (shift - 2));
        RIPtr ri_ptr(ptr, ptr + TABLE_LEN / 2);
        return ri_ptr;
    }
    // Start of level `shift` in table3 (real pointer, imaginary pointer).
    constexpr RIPtr get_omega3_begin(UINT_32 shift)
    {
        HintFloat *ptr = table3.data() + (1 << (shift - 2));
        RIPtr ri_ptr(ptr, ptr + TABLE_LEN / 2);
        return ri_ptr;
    }
    // Real-part pointer of level SHIFT in table1; imaginary parts are RI_DIS further on.
    template <UINT_32 SHIFT>
    constexpr const HintFloat *get_omega_iter() const
    {
        return table1.data() + (1 << (SHIFT - 2));
    }
    // Real-part pointer of level SHIFT in table3; imaginary parts are RI_DIS further on.
    template <UINT_32 SHIFT>
    constexpr const HintFloat *get_omega3_iter() const
    {
        return table3.data() + (1 << (SHIFT - 2));
    }
};
// log2 of the largest transform length the shared twiddle table supports (2^19 points).
constexpr size_t lut_max_rank{19};
using FFTable = ComplexTableS<lut_max_rank>;
// The shared twiddle table. `static` gives it internal linkage and static storage
// duration, so it is zero-initialized before its constructor runs. Levels are
// filled on demand through TABLE.expand / TABLE.expand_topdown.
static FFTable TABLE;
// In-place radix-2 butterfly: (sum, diff) <- (sum + diff, sum - diff).
template <typename T>
inline void fft_2point(T &sum, T &diff)
{
    const T a = sum;
    const T b = diff;
    diff = a - b;
    sum = a + b;
}
// Four-lane value type: element-wise arithmetic on four consecutive HintFloats.
// save() is const, matching ComputeEnd2::save, so it can be called through const
// references. For example, RIPtr::save(const DataTy &, const DataTy &) needs this.
struct ComputeEnd4
{
    HintFloat f0, f1, f2, f3;
    ComputeEnd4() = default;
    ComputeEnd4(const ComputeEnd4 &in) = default;
    ComputeEnd4(HintFloat fin0, HintFloat fin1, HintFloat fin2, HintFloat fin3)
        : f0(fin0), f1(fin1), f2(fin2), f3(fin3) {}
    // Deliberately implicit: lets callers write `ComputeEnd4 v = ptr;` to load ptr[0..3].
    ComputeEnd4(const HintFloat *ptr)
    {
        load(ptr);
    }
    // Read the four lanes from ptr[0..3].
    void load(const HintFloat *ptr)
    {
        f0 = ptr[0];
        f1 = ptr[1];
        f2 = ptr[2];
        f3 = ptr[3];
    }
    // Write the four lanes to ptr[0..3]; does not modify *this.
    void save(HintFloat *const ptr) const
    {
        ptr[0] = f0;
        ptr[1] = f1;
        ptr[2] = f2;
        ptr[3] = f3;
    }
    ComputeEnd4 operator+(const ComputeEnd4 &in) const
    {
        return ComputeEnd4(
            f0 + in.f0, f1 + in.f1, f2 + in.f2, f3 + in.f3);
    }
    ComputeEnd4 operator-(const ComputeEnd4 &in) const
    {
        return ComputeEnd4(
            f0 - in.f0, f1 - in.f1, f2 - in.f2, f3 - in.f3);
    }
    ComputeEnd4 operator*(const ComputeEnd4 &in) const
    {
        return ComputeEnd4(
            f0 * in.f0, f1 * in.f1, f2 * in.f2, f3 * in.f3);
    }
    ComputeEnd4 operator-() const
    {
        return ComputeEnd4(-f0, -f1, -f2, -f3);
    }
    // Debug helper: print the four lanes on one line.
    void print() const
    {
        std::cout << f0 << " " << f1 << " " << f2 << " " << f3 << "\n";
    }
};
struct ComputeEnd2
{
HintFloat f0, f1;
ComputeEnd2() = default;
ComputeEnd2(const ComputeEnd2 &in) = default;
ComputeEnd2(HintFloat fin0, HintFloat fin1)
: f0(fin0), f1(fin1) {}
ComputeEnd2(const HintFloat *ptr)
{
load(ptr);
}
void load(const HintFloat *ptr)
{
f0 = ptr[0];
f1 = ptr[1];
}
void save(HintFloat *const ptr) const
{
ptr[0] = f0;
ptr[1] = f1;
}
ComputeEnd2 operator+(const ComputeEnd2 &in) const
{
return ComputeEnd2(
f0 + in.f0, f1 + in.f1);
}
ComputeEnd2 operator-(const ComputeEnd2 &in) const
{
return ComputeEnd2(
f0 - in.f0, f1 - in.f1);
}
ComputeEnd2 operator*(const ComputeEnd2 &in) const
{
return ComputeEnd2(
f0 * in.f0, f1 * in.f1);
}
ComputeEnd2 operator-() const
{
return ComputeEnd2(-f0, -f1);
}
void print() const
{
std::cout << f0 << " " << f1 << "\n";
}
};
// Interface description for a transform backend whose elements are real numbers,
// covering real arrays and complex arrays. COMPUTE_END is the packed lane type.
// Not referenced elsewhere in the visible source.
template <typename COMPUTE_END = ComputeEnd4>
struct FFTInterface
{
    using DataTy = COMPUTE_END;
    // Number of HintFloat lanes processed per DataTy value.
    enum
    {
        OFFSET = sizeof(DataTy) / sizeof(HintFloat)
    };
    using IterTy = RIPtr;
};
// Mutable and read-only pointers into a plane of HintFloat values.
using Iter = HintFloat *;
using ConstIter = HintFloat const *;
// Split-radix FFT of compile-time length LEN over split real/imaginary storage:
// real parts at fft_input[0, LEN), imaginary parts RI_DIS elements later.
// The general case splits into one half-length and two quarter-length transforms;
// lengths 0, 1, 2, 4 and 8 are specialized below. LEN is presumably a power of two
// (log_len = hint_log2(LEN) selects the twiddle-table level) - TODO confirm callers.
// The DIF form appears to leave outputs in bit-reversed order, and the DIT form
// consumes that order (as the 4-point kernels do).
template <size_t LEN, size_t RI_DIS>
struct FFT
{
enum
{
FT_DIS = RI_DIS, // distance from a real part to its imaginary part in the data
LUT_DIS = FFTable::RI_DIS // the same distance inside the twiddle table
};
static constexpr size_t log_len = hint_log2(LEN);
static constexpr size_t half_len = LEN / 2, quarter_len = LEN / 4;
// Butterflies process two consecutive indices per call (ComputeEnd2 lanes).
using DataTy = ComputeEnd2;
static constexpr size_t offset = sizeof(DataTy) / sizeof(HintFloat);
using half_fft = FFT<half_len, RI_DIS>;
using quarter_fft = FFT<quarter_len, RI_DIS>;
// Split-radix DIT butterfly for indices k, k+1 within each quarter (qN = quarter N).
// Twiddles w2 = q2 * omega^k and w3 = q3 * omega^(3k), then writes back:
//   q0' = q0 + (w2 + w3),   q2' = q0 - (w2 + w3)
//   q1' = q1 - i(w2 - w3),  q3' = q1 + i(w2 - w3)
static void fft_split_radix_dit_butterfly(ConstIter omega, ConstIter omega_cube, Iter fft_input)
{
DataTy r0, r1, r2, r3, i0, i1, i2, i3;
DataTy tr0, tr1, tr2, tr3, ti0, ti1, ti2, ti3;
r0 = fft_input;
r1 = fft_input + quarter_len;
r2 = fft_input + quarter_len * 2;
r3 = fft_input + quarter_len * 3;
i0 = fft_input + FT_DIS;
i1 = fft_input + FT_DIS + quarter_len;
i2 = fft_input + FT_DIS + quarter_len * 2;
i3 = fft_input + FT_DIS + quarter_len * 3;
tr0 = omega;
tr1 = omega_cube;
ti0 = omega + LUT_DIS;
ti1 = omega_cube + LUT_DIS;
tr2 = complex_mul_real(r2, i2, tr0, ti0);
ti2 = complex_mul_imag(r2, i2, tr0, ti0);
tr3 = complex_mul_real(r3, i3, tr1, ti1);
ti3 = complex_mul_imag(r3, i3, tr1, ti1);
r2 = tr2 + tr3;
i2 = ti2 + ti3;
// (r3, i3) = -i * (w2 - w3)
r3 = ti2 - ti3;
i3 = tr3 - tr2;
tr0 = r0 + r2;
ti0 = i0 + i2;
tr2 = r0 - r2;
ti2 = i0 - i2;
tr1 = r1 + r3;
ti1 = i1 + i3;
tr3 = r1 - r3;
ti3 = i1 - i3;
tr0.save(fft_input);
tr1.save(fft_input + quarter_len);
tr2.save(fft_input + quarter_len * 2);
tr3.save(fft_input + quarter_len * 3);
ti0.save(fft_input + FT_DIS);
ti1.save(fft_input + FT_DIS + quarter_len);
ti2.save(fft_input + FT_DIS + quarter_len * 2);
ti3.save(fft_input + FT_DIS + quarter_len * 3);
}
// Split-radix DIF butterfly for indices k, k+1 within each quarter:
//   q0' = q0 + q2,  q1' = q1 + q3
//   q2' = ((q0 - q2) - i(q1 - q3)) * omega^k
//   q3' = ((q0 - q2) + i(q1 - q3)) * omega^(3k)
static void fft_split_radix_dif_butterfly(ConstIter omega, ConstIter omega_cube, Iter fft_input)
{
DataTy r0, r1, r2, r3, i0, i1, i2, i3;
DataTy tr0, tr1, tr2, tr3, ti0, ti1, ti2, ti3;
r0 = fft_input;
r1 = fft_input + quarter_len;
r2 = fft_input + quarter_len * 2;
r3 = fft_input + quarter_len * 3;
i0 = fft_input + FT_DIS;
i1 = fft_input + FT_DIS + quarter_len;
i2 = fft_input + FT_DIS + quarter_len * 2;
i3 = fft_input + FT_DIS + quarter_len * 3;
tr0 = r0 + r2;
ti0 = i0 + i2;
tr2 = r0 - r2;
ti2 = i0 - i2;
tr1 = r1 + r3;
ti1 = i1 + i3;
// (tr3, ti3) = -i * (q1 - q3)
tr3 = i1 - i3;
ti3 = r3 - r1;
r2 = tr2 + tr3;
i2 = ti2 + ti3;
r3 = tr2 - tr3;
i3 = ti2 - ti3;
r0 = omega;
r1 = omega_cube;
i0 = omega + LUT_DIS;
i1 = omega_cube + LUT_DIS;
tr2 = complex_mul_real(r2, i2, r0, i0);
ti2 = complex_mul_imag(r2, i2, r0, i0);
tr3 = complex_mul_real(r3, i3, r1, i1);
ti3 = complex_mul_imag(r3, i3, r1, i1);
tr0.save(fft_input);
tr1.save(fft_input + quarter_len);
tr2.save(fft_input + quarter_len * 2);
tr3.save(fft_input + quarter_len * 3);
ti0.save(fft_input + FT_DIS);
ti1.save(fft_input + FT_DIS + quarter_len);
ti2.save(fft_input + FT_DIS + quarter_len * 2);
ti3.save(fft_input + FT_DIS + quarter_len * 3);
}
// DIT: transform the two quarter-length sub-sequences (last two quarters) and the
// half-length sub-sequence (first half) first, then combine them with butterflies.
static void fft_split_radix_dit(Iter fft_input)
{
quarter_fft::fft_split_radix_dit(fft_input + half_len + quarter_len);
quarter_fft::fft_split_radix_dit(fft_input + half_len);
half_fft::fft_split_radix_dit(fft_input);
// Twiddle pointers for level log_len, computed once on first call. They only locate
// the level; TABLE must be expanded to log_len before the butterflies read values.
static ConstIter omg_ptr = TABLE.get_omega_iter<log_len>();
static ConstIter omg3_ptr = TABLE.get_omega3_iter<log_len>();
for (size_t i = 0; i < quarter_len; i += offset)
{
fft_split_radix_dit_butterfly(omg_ptr + i, omg3_ptr + i, fft_input + i);
}
}
// DIF: butterflies first, then recurse into the half and the two quarters.
static void fft_split_radix_dif(Iter fft_input)
{
static ConstIter omg_ptr = TABLE.get_omega_iter<log_len>();
static ConstIter omg3_ptr = TABLE.get_omega3_iter<log_len>();
for (size_t i = 0; i < quarter_len; i += offset)
{
fft_split_radix_dif_butterfly(omg_ptr + i, omg3_ptr + i, fft_input + i);
}
half_fft::fft_split_radix_dif(fft_input);
quarter_fft::fft_split_radix_dif(fft_input + half_len);
quarter_fft::fft_split_radix_dif(fft_input + half_len + quarter_len);
}
};
// Zero-length transform: nothing to do.
template <size_t RI_DIS>
struct FFT<0, RI_DIS>
{
    static void fft_split_radix_dit(Iter) {}
    static void fft_split_radix_dif(Iter) {}
};
// One-point transform: the identity.
template <size_t RI_DIS>
struct FFT<1, RI_DIS>
{
    static void fft_split_radix_dit(Iter) {}
    static void fft_split_radix_dif(Iter) {}
};
// Two-point transform: the same butterfly in either direction, applied to the real
// plane and then to the imaginary plane (RI_DIS elements later).
template <size_t RI_DIS>
struct FFT<2, RI_DIS>
{
    static void fft_split_radix_dit(Iter fft_input)
    {
        HintFloat *const re = fft_input;
        HintFloat *const im = fft_input + RI_DIS;
        fft_2point(re[0], re[1]);
        fft_2point(im[0], im[1]);
    }
    static void fft_split_radix_dif(Iter fft_input)
    {
        fft_split_radix_dit(fft_input);
    }
};
// Hand-unrolled 4-point kernels on split storage (imaginary parts RI_DIS elements
// after the real parts). Both use the forward kernel (a factor of -i per quarter turn):
// fft_dif_4point maps natural-order input to bit-reversed output (X0, X2, X1, X3), and
// fft_dit_4point maps bit-reversed input to natural-order output.
template <size_t RI_DIS>
struct FFT<4, RI_DIS>
{
static void fft_dit_4point(Iter fft_input)
{
HintFloat r0 = fft_input[0];
HintFloat r1 = fft_input[1];
HintFloat r2 = fft_input[2];
HintFloat r3 = fft_input[3];
HintFloat i0 = fft_input[RI_DIS];
HintFloat i1 = fft_input[RI_DIS + 1];
HintFloat i2 = fft_input[RI_DIS + 2];
HintFloat i3 = fft_input[RI_DIS + 3];
HintFloat tr0 = r0 + r1;
HintFloat ti0 = i0 + i1;
HintFloat tr1 = r0 - r1;
HintFloat ti1 = i0 - i1;
HintFloat tr2 = r2 + r3;
HintFloat ti2 = i2 + i3;
// (tr3, ti3) = -i * (x2 - x3)
HintFloat tr3 = i2 - i3;
HintFloat ti3 = r3 - r2;
fft_input[0] = tr0 + tr2;
fft_input[1] = tr1 + tr3;
fft_input[2] = tr0 - tr2;
fft_input[3] = tr1 - tr3;
fft_input[RI_DIS] = ti0 + ti2;
fft_input[RI_DIS + 1] = ti1 + ti3;
fft_input[RI_DIS + 2] = ti0 - ti2;
fft_input[RI_DIS + 3] = ti1 - ti3;
}
static void fft_dif_4point(Iter fft_input)
{
HintFloat r0 = fft_input[0];
HintFloat r1 = fft_input[1];
HintFloat r2 = fft_input[2];
HintFloat r3 = fft_input[3];
HintFloat i0 = fft_input[RI_DIS];
HintFloat i1 = fft_input[RI_DIS + 1];
HintFloat i2 = fft_input[RI_DIS + 2];
HintFloat i3 = fft_input[RI_DIS + 3];
HintFloat tr0 = r0 + r2;
HintFloat ti0 = i0 + i2;
HintFloat tr2 = r0 - r2;
HintFloat ti2 = i0 - i2;
HintFloat tr1 = r1 + r3;
HintFloat ti1 = i1 + i3;
// (tr3, ti3) = -i * (x1 - x3)
HintFloat tr3 = i1 - i3;
HintFloat ti3 = r3 - r1;
fft_input[0] = tr0 + tr1;
fft_input[1] = tr0 - tr1;
fft_input[2] = tr2 + tr3;
fft_input[3] = tr2 - tr3;
fft_input[RI_DIS] = ti0 + ti1;
fft_input[RI_DIS + 1] = ti0 - ti1;
fft_input[RI_DIS + 2] = ti2 + ti3;
fft_input[RI_DIS + 3] = ti2 - ti3;
}
static void fft_split_radix_dit(HintFloat *fft_input)
{
fft_dit_4point(fft_input);
}
static void fft_split_radix_dif(HintFloat *fft_input)
{
fft_dif_4point(fft_input);
}
};
// Hand-unrolled 8-point kernels on split storage (imaginary parts RI_DIS elements
// after the real parts), built from radix-2 and radix-4 stages with the forward kernel.
// The DIF output / DIT input appear to be in bit-reversed order (e.g. DIF slot 1 holds X4
// and slot 4 holds X1).
template <size_t RI_DIS>
struct FFT<8, RI_DIS>
{
static void fft_dit_8point(Iter fft_input)
{
HintFloat r0 = fft_input[0];
HintFloat r1 = fft_input[1];
HintFloat r2 = fft_input[2];
HintFloat r3 = fft_input[3];
HintFloat r4 = fft_input[4];
HintFloat r5 = fft_input[5];
HintFloat r6 = fft_input[6];
HintFloat r7 = fft_input[7];
HintFloat i0 = fft_input[RI_DIS];
HintFloat i1 = fft_input[RI_DIS + 1];
HintFloat i2 = fft_input[RI_DIS + 2];
HintFloat i3 = fft_input[RI_DIS + 3];
HintFloat i4 = fft_input[RI_DIS + 4];
HintFloat i5 = fft_input[RI_DIS + 5];
HintFloat i6 = fft_input[RI_DIS + 6];
HintFloat i7 = fft_input[RI_DIS + 7];
// 4Xdit2
HintFloat tr0 = r0 + r1, ti0 = i0 + i1; // 0-1
HintFloat tr1 = r0 - r1, ti1 = i0 - i1;
HintFloat tr2 = r2 + r3, ti2 = i2 + i3; // 2-3
HintFloat tr3 = i2 - i3, ti3 = r3 - r2;
HintFloat tr4 = r4 + r5, ti4 = i4 + i5; // 4-5
HintFloat tr5 = r4 - r5, ti5 = i4 - i5;
HintFloat tr6 = r6 + r7, ti6 = i6 + i7; // 6-7
HintFloat tr7 = i6 - i7, ti7 = r7 - r6;
// 2Xdit4
r0 = tr0 + tr2, i0 = ti0 + ti2; // 0-2
r2 = tr0 - tr2, i2 = ti0 - ti2;
r1 = tr1 + tr3, i1 = ti1 + ti3; // 1-3
r3 = tr1 - tr3, i3 = ti1 - ti3;
r4 = tr4 + tr6, i4 = ti4 + ti6; // 4-6
r6 = ti4 - ti6, i6 = tr6 - tr4;
r5 = tr5 + tr7, i5 = ti5 + ti7; // 5-7
r7 = tr5 - tr7, i7 = ti5 - ti7;
// cos(pi/4) = sqrt(2)/2 and cos(3*pi/4) = -sqrt(2)/2.
static constexpr HintFloat cos_1_8 = 0.70710678118654752440084436210485;
static constexpr HintFloat cos_3_8 = -cos_1_8;
// (tr5, ti5) = (r5 + i*i5) * e^{-i*pi/4}; (tr7, ti7) = (r7 + i*i7) * e^{-3i*pi/4}.
tr5 = cos_1_8 * (i5 + r5), ti5 = cos_1_8 * (i5 - r5);
tr7 = cos_3_8 * (r7 - i7), ti7 = cos_3_8 * (r7 + i7);
// dit8
fft_input[0] = r0 + r4;
fft_input[1] = r1 + tr5;
fft_input[2] = r2 + r6;
fft_input[3] = r3 + tr7;
fft_input[4] = r0 - r4;
fft_input[5] = r1 - tr5;
fft_input[6] = r2 - r6;
fft_input[7] = r3 - tr7;
fft_input[RI_DIS] = i0 + i4;
fft_input[RI_DIS + 1] = i1 + ti5;
fft_input[RI_DIS + 2] = i2 + i6;
fft_input[RI_DIS + 3] = i3 + ti7;
fft_input[RI_DIS + 4] = i0 - i4;
fft_input[RI_DIS + 5] = i1 - ti5;
fft_input[RI_DIS + 6] = i2 - i6;
fft_input[RI_DIS + 7] = i3 - ti7;
}
static void fft_dif_8point(Iter fft_input)
{
HintFloat r0 = fft_input[0];
HintFloat r1 = fft_input[1];
HintFloat r2 = fft_input[2];
HintFloat r3 = fft_input[3];
HintFloat r4 = fft_input[4];
HintFloat r5 = fft_input[5];
HintFloat r6 = fft_input[6];
HintFloat r7 = fft_input[7];
HintFloat i0 = fft_input[RI_DIS];
HintFloat i1 = fft_input[RI_DIS + 1];
HintFloat i2 = fft_input[RI_DIS + 2];
HintFloat i3 = fft_input[RI_DIS + 3];
HintFloat i4 = fft_input[RI_DIS + 4];
HintFloat i5 = fft_input[RI_DIS + 5];
HintFloat i6 = fft_input[RI_DIS + 6];
HintFloat i7 = fft_input[RI_DIS + 7];
// dif8
HintFloat tr0 = r0 + r4, ti0 = i0 + i4; // 0-4
HintFloat tr4 = r0 - r4, ti4 = i0 - i4;
HintFloat tr1 = r1 + r5, ti1 = i1 + i5; // 1-5
HintFloat tr5 = r1 - r5, ti5 = i1 - i5;
HintFloat tr2 = r2 + r6, ti2 = i2 + i6; // 2-6
HintFloat tr6 = i2 - i6, ti6 = r6 - r2;
HintFloat tr3 = r3 + r7, ti3 = i3 + i7; // 3-7
HintFloat tr7 = r3 - r7, ti7 = i3 - i7;
// cos(pi/4) = sqrt(2)/2 and cos(3*pi/4) = -sqrt(2)/2.
static constexpr HintFloat cos_1_8 = 0.70710678118654752440084436210485;
static constexpr HintFloat cos_3_8 = -cos_1_8;
// (r5, i5) = (tr5 + i*ti5) * e^{-i*pi/4}; (r7, i7) = (tr7 + i*ti7) * e^{-3i*pi/4}.
r5 = cos_1_8 * (ti5 + tr5), i5 = cos_1_8 * (ti5 - tr5);
r7 = cos_3_8 * (tr7 - ti7), i7 = cos_3_8 * (tr7 + ti7);
// 2Xdif4
r0 = tr0 + tr2, i0 = ti0 + ti2; // 0-2
r2 = tr0 - tr2, i2 = ti0 - ti2;
r1 = tr1 + tr3, i1 = ti1 + ti3; // 1-3
r3 = ti1 - ti3, i3 = tr3 - tr1;
r4 = tr4 + tr6, i4 = ti4 + ti6; // 4-6
r6 = tr4 - tr6, i6 = ti4 - ti6;
tr5 = r5 + r7, ti5 = i5 + i7; // 5-7
tr7 = i5 - i7, ti7 = r7 - r5;
// 4xdif2
fft_input[0] = r0 + r1;
fft_input[1] = r0 - r1;
fft_input[2] = r2 + r3;
fft_input[3] = r2 - r3;
fft_input[4] = r4 + tr5;
fft_input[5] = r4 - tr5;
fft_input[6] = r6 + tr7;
fft_input[7] = r6 - tr7;
fft_input[RI_DIS] = i0 + i1;
fft_input[RI_DIS + 1] = i0 - i1;
fft_input[RI_DIS + 2] = i2 + i3;
fft_input[RI_DIS + 3] = i2 - i3;
fft_input[RI_DIS + 4] = i4 + ti5;
fft_input[RI_DIS + 5] = i4 - ti5;
fft_input[RI_DIS + 6] = i6 + ti7;
fft_input[RI_DIS + 7] = i6 - ti7;
}
static void fft_split_radix_dit(Iter fft_input)
{
fft_dit_8point(fft_input);
}
static void fft_split_radix_dif(Iter fft_input)
{
fft_dif_8point(fft_input);
}
};
// Helper dispatch for the DIT transform. If fft_len >= LEN, ensure the twiddle table
// covers LEN points and run FFT<LEN, LEN> (imaginary plane LEN elements after the real
// plane). Otherwise retry with LEN / 2. Starting from a power-of-two LEN, this picks
// the largest power of two not exceeding fft_len (capped at LEN).
template <size_t LEN = 1>
void fft_split_radix_dit_template_alt(Iter input, size_t fft_len)
{
    if (fft_len >= LEN)
    {
        TABLE.expand_topdown(hint_log2(LEN));
        FFT<LEN, LEN>::fft_split_radix_dit(input);
    }
    else
    {
        fft_split_radix_dit_template_alt<LEN / 2>(input, fft_len);
    }
}
// Recursion terminator: nothing left to transform.
template <>
void fft_split_radix_dit_template_alt<0>(Iter, size_t) {}
// Helper dispatch for the DIF transform. If fft_len >= LEN, ensure the twiddle table
// covers LEN points and run FFT<LEN, LEN> (imaginary plane LEN elements after the real
// plane). Otherwise retry with LEN / 2. Starting from a power-of-two LEN, this picks
// the largest power of two not exceeding fft_len (capped at LEN).
template <size_t LEN = 1>
void fft_split_radix_dif_template_alt(Iter input, size_t fft_len)
{
    if (fft_len >= LEN)
    {
        TABLE.expand_topdown(hint_log2(LEN));
        FFT<LEN, LEN>::fft_split_radix_dif(input);
    }
    else
    {
        fft_split_radix_dif_template_alt<LEN / 2>(input, fft_len);
    }
}
// Recursion terminator: nothing left to transform.
template <>
void fft_split_radix_dif_template_alt<0>(Iter, size_t) {}
// auto fft_split_radix_dit = fft_split_radix_dit_template_alt<size_t(1) << lut_max_rank>;
// auto fft_split_radix_dif = fft_split_radix_dif_template_alt<size_t(1) << lut_max_rank>;
}
}
// Zero the bytes of ptr[0, len). This gives the zero value for types whose all-zero
// bit pattern means zero (integers, IEEE-754 doubles).
template <typename T>
inline void ary_clr(T *ptr, size_t len)
{
    const size_t bytes = sizeof(T) * len;
    std::memset(ptr, 0, bytes);
}
// Zero-padded decimal text of the lowest `digits` decimal digits of `input`.
// Characters are filled from the last position backwards, so higher digits beyond
// `digits` are dropped.
inline std::string ui64to_string(UINT_64 input, UINT_8 digits)
{
    std::string text(digits, '0');
    for (auto it = text.rbegin(); it != text.rend(); ++it)
    {
        *it = static_cast<char>('0' + input % 10);
        input /= 10;
    }
    return text;
}
// Parse the first `dig` characters of s as decimal digits (most significant first).
// No validation: each character is assumed to be in '0'..'9'.
constexpr UINT_64 stoui64(const char *s, size_t dig = 4)
{
    UINT_64 value = 0;
    for (const char *it = s; it != s + dig; ++it)
    {
        value = value * 10 + (*it - '0');
    }
    return value;
}
// Parse exactly four decimal characters s[0..3] (most significant first). The '0'
// bias of all four characters is removed in one step: '0' * (1000 + 100 + 10 + 1).
constexpr UINT_32 stobase10000(const char *s)
{
    return ((s[0] * 10 + s[1]) * 10 + s[2]) * 10 + s[3] - '0' * 1111;
}
// Decimal digits packed into one limb of the FFT input.
static constexpr INT_64 DIGIT = 4;
// Limb radix: 10^DIGIT (= 10000).
static constexpr INT_32 BASE = qpow<INT_32>(10, DIGIT);
inline size_t char_to_real(const char *buffer1, size_t len1, HintFloat *ary, size_t fft_len)
{
hint::INT_64 len = len1, pos = len, i = 0;
len = (len + DIGIT - 1) / DIGIT;
while (pos - DIGIT > 0)
{
hint::UINT_32 tmp = stobase10000(buffer1 + pos - DIGIT);
ary[i] = tmp;
i++;
pos -= DIGIT;
}
if (pos > 0)
{
hint::UINT_32 tmp = stoui64(buffer1, pos);
ary[i] = tmp;
i++;
}
ary_clr(ary + i, fft_len - i);
return len;
}
// Lookup table from 0..9999 to four zero-padded ASCII digits packed in a uint32_t.
// The thousands digit is in the low byte, so a little-endian store writes the digits
// in reading order.
class ItoStrBase10000
{
private:
    uint32_t table[10000]{};

public:
    // Pack the four decimal digits of num (0..9999), each offset by '0'.
    static constexpr uint32_t itosbase10000(uint32_t num)
    {
        uint32_t res = '0' * 0x1010101;
        res += (num / 1000 % 10) | ((num / 100 % 10) << 8) |
               ((num / 10 % 10) << 16) | ((num % 10) << 24);
        return res;
    }
    constexpr ItoStrBase10000()
    {
        for (size_t i = 0; i < 10000; i++)
        {
            table[i] = itosbase10000(i);
        }
    }
    // Store the packed digits of num (0..9999) into str[0..3].
    // memcpy replaces the original uint32_t* store. That store violated strict
    // aliasing and could be misaligned when str is not 4-byte aligned. Bytes land in
    // host order, so digits read correctly only on little-endian targets.
    void tostr(char *str, uint32_t num) const
    {
        std::memcpy(str, &table[num], sizeof(uint32_t));
    }
    // Packed digits of num (0..9999), for callers that store the word themselves.
    uint32_t tostr(uint32_t num) const
    {
        return table[num];
    }
};
}
// Bring the library namespaces into scope for main(). `using namespace std;` was
// dropped: main() qualifies the std names it uses (std::tie) and otherwise calls only
// C library functions, so that directive only risked name collisions.
using namespace hint;
using namespace hint_transform;
using namespace hint_fft;
int main()
{
constexpr size_t STR_LEN = 2000008;
constexpr size_t fft_len = 1 << lut_max_rank;
static constexpr ItoStrBase10000 transfer;
static AlignAry<char, STR_LEN> out;
static AlignAry<HintFloat, fft_len * 2> fft_arr;
auto *fft_ary = fft_arr.data();
uint32_t *ary = out.cast_ptr<uint32_t>();
size_t len_a = 0, len_b = 0;
fread(out.data(), 1, STR_LEN, stdin);
char *p = out.data();
/*
struct stat sb;
int fd = fileno(stdin);
fstat(fd, &sb);
p = (char *)mmap(0, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
madvise(p, sb.st_size, MADV_SEQUENTIAL);
*/
while (p[len_a] >= '0')
{
len_a++;
}
if (len_a == 1 && p[0] == '0')
{
puts("0");
return 0;
}
char *b = p + len_a;
while (!isdigit(*b))
{
b++;
}
while (b[len_b] >= '0')
{
len_b++;
}
if (len_b == 1 && b[0] == '0')
{
puts("0");
return 0;
}
size_t len1 = char_to_real(p, len_a, fft_ary, fft_len);
size_t len2 = char_to_real(b, len_b, fft_ary + fft_len, fft_len);
TABLE.expand(hint_log2(fft_len));
FFT<fft_len, fft_len>::fft_split_radix_dif(fft_ary);
for (size_t i = 0; i < fft_len; i += 2)
{
ComputeEnd2 real = fft_ary + i, imag = fft_ary + i + fft_len;
complex_mul_real(real, imag, real, imag).save(fft_ary + i);
(-complex_mul_imag(real, imag, real, imag)).save(fft_ary + i + fft_len);
}
FFT<fft_len, fft_len>::fft_split_radix_dit(fft_ary);
UINT_64 carry = 0;
size_t pos = STR_LEN / 4 - 1;
constexpr HintFloat inv = -0.5 / fft_len;
for (size_t i = 0; i < len1 + len2 - 1; i++)
{
carry += UINT_64(fft_ary[i + fft_len] * inv + 0.5);
UINT_64 num = 0;
std::tie(carry, num) = div_mod<UINT_64>(carry, BASE);
ary[pos] = transfer.tostr(num);
pos--;
}
ary[pos] = transfer.tostr(carry);
pos *= 4;
while (out[pos] == '0')
{
pos++;
} // 0.8ms
fwrite(out.data() + pos, 1, STR_LEN - pos, stdout);
}
// Online judge result for this submission:
// Compilation | N/A | N/A | Compile OK | Score: N/A |
// Testcase #1 | 18.477 ms | 19 MB + 872 KB | Accepted | Score: 100 |