#include <algorithm>
#include <atomic>
#include <cctype>
#include <cmath>
#include <complex>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <future>
#include <iostream>
#include <limits>
#include <new>
#include <random>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <immintrin.h>
#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP
#pragma GCC target("fma")
using Complex = std::complex<double>;
// Two std::complex<double> values processed in parallel in one AVX register.
// Lane layout: [re0, im0, re1, im1].
//
// There is deliberately no user-declared copy constructor. The implicit one keeps
// the type trivially copyable, so values can be passed and returned in YMM registers
// instead of through memory.
struct Complex2
{
    __m256d data;

    // All lanes zero.
    Complex2() : data(_mm256_setzero_pd()) {}
    // All four lanes set to `input`.
    Complex2(double input) : data(_mm256_set1_pd(input)) {}
    Complex2(__m256d input) : data(input) {}
    // Unaligned load of four consecutive doubles.
    Complex2(double const *ptr) : data(_mm256_loadu_pd(ptr)) {}
    // Broadcast one complex number into both slots.
    Complex2(Complex a)
    {
        // std::complex<double> is layout-compatible with double[2].
        const __m128d pair = _mm_loadu_pd(reinterpret_cast<const double *>(&a));
        data = _mm256_insertf128_pd(_mm256_castpd128_pd256(pair), pair, 1);
    }
    // Unaligned load of two consecutive complex numbers ptr[0], ptr[1].
    Complex2(const Complex *ptr) : data(_mm256_loadu_pd(reinterpret_cast<const double *>(ptr))) {}
    void clr()
    {
        data = _mm256_setzero_pd();
    }
    // Unaligned store of both complex numbers to a[0], a[1].
    void store(Complex *a) const
    {
        _mm256_storeu_pd(reinterpret_cast<double *>(a), data);
    }
    // Per-128-bit-lane permutation of the doubles (immediate M as for _mm256_permute_pd).
    template <int M>
    Complex2 element_permute() const
    {
        return _mm256_permute_pd(data, M);
    }
    // (re, re) in each slot.
    Complex2 all_real() const
    {
        return _mm256_movedup_pd(data);
    }
    // (im, im) in each slot.
    Complex2 all_imag() const
    {
        return element_permute<0xF>();
    }
    // (im, re) in each slot.
    Complex2 swap() const
    {
        return element_permute<0x5>();
    }
    // Multiply both values by -i: (re, im) -> (im, -re).
    Complex2 mul_neg_i() const
    {
        // addsub(0, x) = (0 - re, 0 + im) = (-re, im); swapping gives (im, -re).
        return Complex2(_mm256_addsub_pd(_mm256_setzero_pd(), data)).swap();
    }
    Complex2 operator+(Complex2 input) const
    {
        return _mm256_add_pd(data, input.data);
    }
    Complex2 operator-(Complex2 input) const
    {
        return _mm256_sub_pd(data, input.data);
    }
    // Complex multiplication in each slot:
    // (a.re*b.re - a.im*b.im, a.re*b.im + a.im*b.re) via one mul and one fmaddsub.
    Complex2 operator*(Complex2 input) const
    {
        auto imag = _mm256_mul_pd(all_imag().data, input.swap().data);
        return _mm256_fmaddsub_pd(all_real().data, input.data, imag);
    }
    // NOTE: lane-wise division of the four doubles, NOT complex division.
    Complex2 operator/(Complex2 input) const
    {
        return _mm256_div_pd(data, input.data);
    }
};
#endif
namespace hint
{
// Fixed-width integer aliases used throughout the library.
using UINT_8 = uint8_t;
using UINT_16 = uint16_t;
using UINT_32 = uint32_t;
using UINT_64 = uint64_t;
using INT_32 = int32_t;
using INT_64 = int64_t;
using LONG = long;
// Double-precision complex number used by the FFT code.
using Complex = std::complex<double>;
// pi and 2*pi in double precision.
constexpr double HINT_PI = 3.1415926535897932384626433832795;
constexpr double HINT_2PI = HINT_PI * 2;
/// @brief Smallest power of two that is >= n (1 for n <= 1).
/// @throws const char* when that power of two is not representable in T.
///         Previously the doubling loop overflowed and never terminated.
template <typename T>
constexpr T min_2pow(T n)
{
    T res = 1;
    while (res < n)
    {
        if (res > std::numeric_limits<T>::max() / 2)
        {
            throw("min_2pow overflow\n");
        }
        res *= 2;
    }
    return res;
}
/// Zero the bytes of `len` consecutive elements starting at `ptr` (intended for trivial T).
template <typename T>
inline void ary_clr(T *ptr, size_t len)
{
    const size_t bytes = sizeof(T) * len;
    memset(ptr, 0, bytes);
}
/// True when x compares less than zero.
template <typename T>
constexpr bool is_neg(T x)
{
    const bool negative = x < 0;
    return negative;
}
/// Floor of log2(n), computed by repeated halving; 0 for n <= 1.
template <typename T>
constexpr size_t hint_log2(T n)
{
    T bits = 0;
    for (; n > 1; n /= 2)
    {
        ++bits;
    }
    return bits;
}
/// True when the lowest bit of x is set.
template <typename T>
constexpr bool is_odd(T x)
{
    return (x & 1) != 0;
}
/// Quotient and remainder together: first = a / b, second = a % b.
template <typename T>
constexpr std::pair<T, T> div_mod(T a, T b)
{
    return std::pair<T, T>(a / b, a % b);
}
/// @brief Copy `len` elements from `source` to `target` byte-wise (T must be trivially copyable).
/// Overlapping ranges are allowed: std::memmove is used because memcpy on overlapping
/// ranges is undefined behavior.
/// @throws const char* "Ary too long\n" when len is at least INT64_MAX or len * sizeof(T)
///         would overflow size_t.
template <typename T>
void ary_copy(T *target, const T *source, size_t len)
{
    if (len == 0 || target == source)
    {
        return;
    }
    if (len >= INT64_MAX || len > SIZE_MAX / sizeof(T))
    {
        throw("Ary too long\n");
    }
    std::memmove(target, source, len * sizeof(T));
}
/// @brief Resize a malloc/realloc-owned buffer to hold `len` elements of T.
/// @return The (possibly moved) buffer. After success the old pointer must not be used.
/// @throws const char* "realloc error" when the byte size overflows or reaches INT64_MAX,
///         or when realloc fails. In both cases the original buffer is untouched and still
///         owned by the caller. (Previously an oversized request silently returned the
///         unchanged buffer.)
/// A request for 0 elements allocates one element, because realloc(ptr, 0) may free ptr.
template <typename T>
inline T *ary_realloc(T *ptr, size_t len)
{
    const size_t count = len == 0 ? 1 : len;
    if (count > SIZE_MAX / sizeof(T) || count * sizeof(T) >= INT64_MAX)
    {
        throw("realloc error");
    }
    T *resized = static_cast<T *>(realloc(ptr, count * sizeof(T)));
    if (resized == nullptr)
    {
        throw("realloc error");
    }
    return resized;
}
/// @brief Largest power of two that is <= n; returns 0 when n < 1.
/// The search starts from the highest value bit (numeric_limits<T>::digits). For signed T
/// the old code started from the sign bit, which produced a negative value (or UB) and
/// returned it for every positive n.
template <typename T>
constexpr T max_2pow(T n)
{
    T res = T(1) << (std::numeric_limits<T>::digits - 1);
    while (res > n && res > 0)
    {
        res /= 2;
    }
    return res;
}
/// @brief Modular exponentiation m^n mod `mod` by square-and-multiply.
/// m is reduced modulo `mod` first; the old code skipped this, so m*m could overflow.
/// Intermediate products are UINT_64, so `mod` must be non-zero and below 2^32.
/// Conventions: m^0 == 1 % mod (including m == 0), and every result is 0 when mod == 1.
constexpr UINT_64 qpow(UINT_64 m, UINT_64 n, UINT_64 mod)
{
    m %= mod;
    UINT_64 result = 1 % mod;
    while (n > 0)
    {
        if ((n & 1) != 0)
        {
            result = result * m % mod;
        }
        m = m * m % mod;
        n >>= 1;
    }
    return result;
}
/// Modular inverse by Fermat's little theorem: n^(mod - 2) mod `mod`.
/// Only meaningful when `mod` is prime and n is not a multiple of it.
constexpr UINT_64 mod_inv(UINT_64 n, UINT_64 mod)
{
    const UINT_64 exponent = mod - 2;
    return qpow(n, exponent, mod);
}
/// @brief Pack two real sequences into one complex array: target[i] = source1[i] + i*source2[i].
/// Where only one source has an element, the other component is written as 0. Previously
/// that component was left at whatever target already held, which was only correct for
/// pre-zeroed buffers. Entries at or beyond max(len1, len2) are not touched.
template <typename T>
inline void com_ary_combine_copy(Complex *target, const T &source1, size_t len1, const T &source2, size_t len2)
{
    const size_t min_len = std::min(len1, len2);
    size_t i = 0;
    for (; i < min_len; i++)
    {
        target[i] = Complex(source1[i], source2[i]);
    }
    for (; i < len1; i++)
    {
        target[i] = Complex(source1[i], 0);
    }
    for (; i < len2; i++)
    {
        target[i] = Complex(0, source2[i]);
    }
}
namespace hint_transform
{
// Lookup table of roots of unity for the FFTs.
// With omega = e^(-2*pi*i / 2^s), row s holds, for 0 <= k < 2^s / 4:
//   table1[s][k] = omega^k
//   table3[s][k] = omega^(3k)
// Rows 0..2 are {1}; larger rows are filled on demand by expand().
class ComplexTableY
{
private:
    std::vector<std::vector<Complex>> table1;
    std::vector<std::vector<Complex>> table3;
    INT_32 max_log_size = 2; // largest row index expand() may fill
    INT_32 cur_log_size = 2; // largest row index filled so far
    static constexpr size_t FAC = 1;
    ComplexTableY(const ComplexTableY &) = delete;
    ComplexTableY &operator=(const ComplexTableY &) = delete;

public:
    ~ComplexTableY() {}
    // Prepare a table for circles divided into up to 1 << max_shift parts.
    // max_shift is clamped to at least 2 because rows 0..2 are always written below.
    // The previous clamp to 1 indexed table1[2] out of bounds for max_shift < 2.
    ComplexTableY(UINT_32 max_shift)
    {
        max_shift = std::max<UINT_32>(max_shift, 2);
        max_log_size = max_shift;
        table1.resize(max_shift + 1);
        table3.resize(max_shift + 1);
        table1[0] = table1[1] = table3[0] = table3[1] = std::vector<Complex>{1};
        table1[2] = table3[2] = std::vector<Complex>{1};
#if TABLE_PRELOAD == 1
        expand(max_shift);
#endif
    }
    // Fill rows cur_log_size + 1 .. shift (shift is clamped to >= 2).
    // @throws const char* when shift exceeds the size given to the constructor.
    void expand(INT_32 shift)
    {
        shift = std::max<INT_32>(shift, 2);
        if (shift > max_log_size)
        {
            throw("FFT length too long for lut\n");
        }
        for (INT_32 i = cur_log_size + 1; i <= shift; i++)
        {
            size_t len = 1ull << i, vec_size = len * FAC / 4;
            table1[i].resize(vec_size);
            table3[i].resize(vec_size);
            table1[i][0] = table3[i][0] = Complex(1, 0);
            // Even entries repeat the previous row: omega_{2N}^(2k) == omega_N^k.
            for (size_t pos = 0; pos < vec_size / 2; pos++)
            {
                table1[i][pos * 2] = table1[i - 1][pos];
                table3[i][pos * 2] = table3[i - 1][pos];
            }
            // Odd entries are computed directly. Each one is mirrored to index vec_size - pos
            // through the complementary angle (cos and sin swap).
            for (size_t pos = 1; pos < vec_size / 2; pos += 2)
            {
                double cos_theta = std::cos(HINT_2PI * pos / len);
                double sin_theta = std::sin(HINT_2PI * pos / len);
                table1[i][pos] = Complex(cos_theta, -sin_theta);
                table1[i][vec_size - pos] = Complex(sin_theta, -cos_theta);
            }
            table1[i][vec_size / 2] = std::conj(unit_root(8, 1));
            for (size_t pos = 1; pos < vec_size / 2; pos += 2)
            {
                Complex tmp = get_omega(i, pos * 3);
                table3[i][pos] = tmp;
                table3[i][vec_size - pos] = Complex(tmp.imag(), tmp.real());
            }
            table3[i][vec_size / 2] = std::conj(unit_root(8, 3));
        }
        cur_log_size = std::max(cur_log_size, shift);
    }
    // Point on the unit circle at angle theta.
    static Complex unit_root(double theta)
    {
        return std::polar<double>(1.0, theta);
    }
    // n-th of m equally spaced points on the unit circle: e^(2*pi*i*n / m).
    static Complex unit_root(size_t m, size_t n)
    {
        return unit_root((HINT_2PI * n) / m);
    }
    // omega^n for a circle split into 1 << shift parts (0 <= n < 1 << shift).
    // Derived from the quarter-circle row table1[shift] by quadrant symmetry;
    // that row must already be expanded.
    Complex get_omega(UINT_32 shift, size_t n) const
    {
        size_t rank = 1ull << shift;
        const size_t rank_ff = rank - 1, quad_n = n << 2;
        size_t zone = quad_n >> shift; // quadrant index 0..3
        if ((quad_n & rank_ff) == 0)
        {
            // Exact multiples of a quarter turn.
            static constexpr Complex ONES[4] = {Complex(1, 0), Complex(0, -1), Complex(-1, 0), Complex(0, 1)};
            return ONES[zone];
        }
        Complex tmp;
        if ((zone & 2) == 0)
        {
            if ((zone & 1) == 0)
            {
                tmp = table1[shift][n];
            }
            else
            {
                tmp = table1[shift][rank / 2 - n];
                tmp.real(-tmp.real());
            }
        }
        else
        {
            if ((zone & 1) == 0)
            {
                tmp = -table1[shift][n - rank / 2];
            }
            else
            {
                tmp = table1[shift][rank - n];
                tmp.imag(-tmp.imag());
            }
        }
        return tmp;
    }
    // omega^(3n) for a circle split into 1 << shift parts (n < (1 << shift) / 4).
    Complex get_omega3(UINT_32 shift, size_t n) const
    {
        return table3[shift][n];
    }
    // omega^n and omega^(n+1) loaded together (needs n + 1 < (1 << shift) / 4).
    Complex2 get_omegaX2(UINT_32 shift, size_t n) const
    {
        return Complex2(table1[shift].data() + n);
    }
    // omega^(3n) and omega^(3(n+1)) loaded together.
    Complex2 get_omega3X2(UINT_32 shift, size_t n) const
    {
        return Complex2(table3[shift].data() + n);
    }
};
// Largest twiddle-table row: transforms of up to 2^lut_max_rank points are supported;
// ComplexTableY::expand throws for larger sizes.
constexpr size_t lut_max_rank = 23;
// Shared twiddle table. Rows are filled lazily by expand(), or eagerly when TABLE_PRELOAD == 1.
// NOTE(review): expand() mutates the table without any locking, so FFTs of new sizes must
// not run concurrently.
static ComplexTableY TABLE(lut_max_rank);
// Bit-reversal permutation: swaps ary[j] with ary[rev(j)], where rev reverses the
// log2(len) low bits of j. len must be a power of two. Lengths 0 and 1 are a no-op;
// previously `len - 1` underflowed for len == 0 and the loop ran about 2^64 times.
template <typename T>
void binary_reverse_swap(T &ary, size_t len)
{
    if (len < 2)
    {
        return;
    }
    size_t i = 0; // bit-reversed counterpart of j, updated incrementally
    for (size_t j = 1; j < len - 1; j++)
    {
        // Increment i in reversed bit order: toggle from the top bit downward
        // until a 0 bit has been turned into 1.
        size_t k = len >> 1;
        i ^= k;
        while (k > i)
        {
            k >>= 1;
            i ^= k;
        }
        if (j < i)
        {
            std::swap(ary[i], ary[j]);
        }
    }
}
// Base-4 digit-reversal permutation (for radix-4 transforms); len must be a power of four.
// Lengths below 4 are a no-op. Previously rev[0] was written into a zero-length array and
// `log_n - 2` underflowed. The scratch array is now a std::vector, so it is released on
// every path.
template <typename SizeType = UINT_32, typename T>
void quaternary_reverse_swap(T &ary, size_t len)
{
    if (len < 4)
    {
        return;
    }
    const SizeType log_n = hint_log2(len);
    // rev[i] = digit-reversed index of i. Only indices below len / 4 are ever read,
    // since i >> 2 < len / 4.
    std::vector<SizeType> rev(len / 4, 0);
    for (SizeType i = 1; i < len; i++)
    {
        SizeType index = (rev[i >> 2] >> 2) | ((i & 3) << (log_n - 2));
        if (i < len / 4)
        {
            rev[i] = index;
        }
        if (i < index)
        {
            std::swap(ary[i], ary[index]);
        }
    }
}
// 2-point butterfly, in place: (sum, diff) <- (sum + diff, sum - diff).
template <typename T>
inline void fft_2point(T &sum, T &diff)
{
    const T a = sum;
    const T b = diff;
    sum = a + b;
    diff = a - b;
}
// 4-point DFT, in place, of input[0], input[rank], input[2*rank], input[3*rank].
// Output is in natural order. Forward kernel: the odd difference is multiplied by -i.
inline void fft_4point(Complex *input, size_t rank = 1)
{
    Complex x0 = input[0];
    Complex x1 = input[rank];
    Complex x2 = input[rank * 2];
    Complex x3 = input[rank * 3];
    fft_2point(x0, x2); // x0 = a0 + a2, x2 = a0 - a2
    fft_2point(x1, x3); // x1 = a1 + a3, x3 = a1 - a3
    const Complex rotated(x3.imag(), -x3.real()); // x3 * (-i)
    input[0] = x0 + x1;
    input[rank] = x2 + rotated;
    input[rank * 2] = x0 - x1;
    input[rank * 3] = x2 - rotated;
}
// 4-point decimation-in-time DFT, in place.
// Expects bit-reversed input order (a0, a2, a1, a3) at stride `rank`;
// writes the output in natural order.
inline void fft_dit_4point(Complex *input, size_t rank = 1)
{
    Complex e0 = input[0];
    Complex e1 = input[rank];
    Complex o0 = input[rank * 2];
    Complex o1 = input[rank * 3];
    fft_2point(e0, e1);
    fft_2point(o0, o1);
    o1 = Complex(o1.imag(), -o1.real()); // multiply by -i
    input[0] = e0 + o0;
    input[rank] = e1 + o1;
    input[rank * 2] = e0 - o0;
    input[rank * 3] = e1 - o1;
}
// 8-point decimation-in-time DFT, in place, on input[0], input[rank], ..., input[7 * rank].
// This is the LEN == 8 leaf of fft_split_radix_dit_template. fft_dit bit-reverses before
// reaching that template, so the input here is in bit-reversed order and the output is in
// natural order (the same convention as fft_dit_4point).
inline void fft_dit_8point(Complex *input, size_t rank = 1)
{
Complex tmp0 = input[0];
Complex tmp1 = input[rank];
Complex tmp2 = input[rank * 2];
Complex tmp3 = input[rank * 3];
Complex tmp4 = input[rank * 4];
Complex tmp5 = input[rank * 5];
Complex tmp6 = input[rank * 6];
Complex tmp7 = input[rank * 7];
// Stage 1: 2-point butterflies on adjacent pairs.
fft_2point(tmp0, tmp1);
fft_2point(tmp2, tmp3);
fft_2point(tmp4, tmp5);
fft_2point(tmp6, tmp7);
// Twiddle W4^1 = -i: (re, im) -> (im, -re).
tmp3 = Complex(tmp3.imag(), -tmp3.real());
tmp7 = Complex(tmp7.imag(), -tmp7.real());
// Stage 2: complete two 4-point DFTs (tmp0..tmp3 and tmp4..tmp7).
fft_2point(tmp0, tmp2);
fft_2point(tmp1, tmp3);
fft_2point(tmp4, tmp6);
fft_2point(tmp5, tmp7);
// Stage 3 twiddles: tmp5 *= W8^1, tmp6 *= W8^2 = -i, tmp7 *= W8^3,
// with W8 = e^(-i*pi/4) and cos_1_8 = cos(pi/4) = sqrt(2)/2.
static constexpr double cos_1_8 = 0.70710678118654752440084436210485;
tmp5 = cos_1_8 * Complex(tmp5.imag() + tmp5.real(), tmp5.imag() - tmp5.real());
tmp6 = Complex(tmp6.imag(), -tmp6.real());
tmp7 = -cos_1_8 * Complex(tmp7.real() - tmp7.imag(), tmp7.real() + tmp7.imag());
// Final butterflies: output k and k + 4.
input[0] = tmp0 + tmp4;
input[rank] = tmp1 + tmp5;
input[rank * 2] = tmp2 + tmp6;
input[rank * 3] = tmp3 + tmp7;
input[rank * 4] = tmp0 - tmp4;
input[rank * 5] = tmp1 - tmp5;
input[rank * 6] = tmp2 - tmp6;
input[rank * 7] = tmp3 - tmp7;
}
// 4-point decimation-in-frequency DFT, in place.
// Takes natural-order input at stride `rank`; writes bit-reversed output (X0, X2, X1, X3).
inline void fft_dif_4point(Complex *input, size_t rank = 1)
{
    Complex s0 = input[0];
    Complex s1 = input[rank];
    Complex d0 = input[rank * 2];
    Complex d1 = input[rank * 3];
    fft_2point(s0, d0); // s0 = a0 + a2, d0 = a0 - a2
    fft_2point(s1, d1); // s1 = a1 + a3, d1 = a1 - a3
    d1 = Complex(d1.imag(), -d1.real()); // multiply by -i
    input[0] = s0 + s1;
    input[rank] = s0 - s1;
    input[rank * 2] = d0 + d1;
    input[rank * 3] = d0 - d1;
}
// 8-point decimation-in-frequency DFT, in place, on input[0], input[rank], ..., input[7 * rank].
// This is the LEN == 8 leaf of fft_split_radix_dif_template. The input is in natural order;
// the output is in bit-reversed order, which fft_dif undoes with binary_reverse_swap
// when requested.
inline void fft_dif_8point(Complex *input, size_t rank = 1)
{
Complex tmp0 = input[0];
Complex tmp1 = input[rank];
Complex tmp2 = input[rank * 2];
Complex tmp3 = input[rank * 3];
Complex tmp4 = input[rank * 4];
Complex tmp5 = input[rank * 5];
Complex tmp6 = input[rank * 6];
Complex tmp7 = input[rank * 7];
// Stage 1: butterflies between k and k + 4.
fft_2point(tmp0, tmp4);
fft_2point(tmp1, tmp5);
fft_2point(tmp2, tmp6);
fft_2point(tmp3, tmp7);
// Twiddle the differences: tmp5 *= W8^1, tmp6 *= W8^2 = -i, tmp7 *= W8^3,
// with W8 = e^(-i*pi/4) and cos_1_8 = cos(pi/4) = sqrt(2)/2.
static constexpr double cos_1_8 = 0.70710678118654752440084436210485;
tmp5 = cos_1_8 * Complex(tmp5.imag() + tmp5.real(), tmp5.imag() - tmp5.real());
tmp6 = Complex(tmp6.imag(), -tmp6.real());
tmp7 = -cos_1_8 * Complex(tmp7.real() - tmp7.imag(), tmp7.real() + tmp7.imag());
// Stage 2: butterflies between k and k + 2 within each half, then W4^1 = -i on tmp3 and tmp7.
fft_2point(tmp0, tmp2);
fft_2point(tmp1, tmp3);
fft_2point(tmp4, tmp6);
fft_2point(tmp5, tmp7);
tmp3 = Complex(tmp3.imag(), -tmp3.real());
tmp7 = Complex(tmp7.imag(), -tmp7.real());
// Stage 3: final pairwise butterflies; results land in bit-reversed positions.
input[0] = tmp0 + tmp1;
input[rank] = tmp0 - tmp1;
input[rank * 2] = tmp2 + tmp3;
input[rank * 3] = tmp2 - tmp3;
input[rank * 4] = tmp4 + tmp5;
input[rank * 5] = tmp4 - tmp5;
input[rank * 6] = tmp6 + tmp7;
input[rank * 7] = tmp6 - tmp7;
}
// Radix-2 decimation-in-time butterfly: with x = input[0] and y = input[rank],
// (x, y) <- (x + y*omega, x - y*omega).
inline void fft_radix2_dit_butterfly(Complex omega, Complex *input, size_t rank)
{
    Complex *upper = input;
    Complex *lower = input + rank;
    const Complex a = *upper;
    const Complex b = *lower * omega;
    *upper = a + b;
    *lower = a - b;
}
// Radix-2 decimation-in-frequency butterfly: with x = input[0] and y = input[rank],
// (x, y) <- (x + y, (x - y) * omega).
inline void fft_radix2_dif_butterfly(Complex omega, Complex *input, size_t rank)
{
    Complex *upper = input;
    Complex *lower = input + rank;
    const Complex a = *upper;
    const Complex b = *lower;
    *upper = a + b;
    *lower = (a - b) * omega;
}
// // fft基2频率抽取蝶形变换
// inline void fft_radix2_dif_butterfly(Complex2 omega, Complex *input, size_t rank)
// {
// Complex2 tmp0(input);
// Complex2 tmp1(input + rank);
// (tmp0 + tmp1).store(input);
// ((tmp0 - tmp1) * omega).store(input + rank);
// }
// Split-radix decimation-in-time butterfly (scalar). Combines the half-length result at
// input[0] and input[rank] with two quarter-length results: input[2*rank] twiddled by
// omega and input[3*rank] twiddled by omega_cube.
inline void fft_split_radix_dit_butterfly(Complex omega, Complex omega_cube,
                                          Complex *input, size_t rank)
{
    const Complex u0 = input[0];
    const Complex u1 = input[rank];
    Complex z = input[rank * 2] * omega;
    Complex z_cube = input[rank * 3] * omega_cube;
    fft_2point(z, z_cube);
    z_cube = Complex(z_cube.imag(), -z_cube.real()); // multiply by -i
    input[0] = u0 + z;
    input[rank] = u1 + z_cube;
    input[rank * 2] = u0 - z;
    input[rank * 3] = u1 - z_cube;
}
// Split-radix decimation-in-frequency butterfly (scalar).
// Leaves the half-length input in input[0] and input[rank], and the two twiddled
// quarter-length inputs in input[2*rank] (times omega) and input[3*rank] (times omega_cube).
inline void fft_split_radix_dif_butterfly(Complex omega, Complex omega_cube,
                                          Complex *input, size_t rank)
{
    Complex a0 = input[0];
    Complex a1 = input[rank];
    Complex b0 = input[rank * 2];
    Complex b1 = input[rank * 3];
    fft_2point(a0, b0);
    fft_2point(a1, b1);
    b1 = Complex(b1.imag(), -b1.real()); // multiply by -i
    input[0] = a0;
    input[rank] = a1;
    input[rank * 2] = (b0 + b1) * omega;
    input[rank * 3] = (b0 - b1) * omega_cube;
}
// Split-radix decimation-in-time butterfly, AVX version. Processes two adjacent columns
// (input[k] and input[k + 1] in each quarter) at once; omega and omega_cube carry the two
// matching twiddles.
inline void fft_split_radix_dit_butterfly(Complex2 omega, Complex2 omega_cube,
                                          Complex *input, size_t rank)
{
    Complex *p0 = input;
    Complex *p1 = input + rank;
    Complex *p2 = input + rank * 2;
    Complex *p3 = input + rank * 3;
    const Complex2 u0(p0);
    const Complex2 u1(p1);
    Complex2 z = Complex2(p2) * omega;
    Complex2 z_cube = Complex2(p3) * omega_cube;
    fft_2point(z, z_cube);
    z_cube = z_cube.mul_neg_i();
    (u0 + z).store(p0);
    (u1 + z_cube).store(p1);
    (u0 - z).store(p2);
    (u1 - z_cube).store(p3);
}
// Split-radix decimation-in-frequency butterfly, AVX version. Processes two adjacent
// columns at once; omega and omega_cube carry the two matching twiddles.
inline void fft_split_radix_dif_butterfly(Complex2 omega, Complex2 omega_cube,
                                          Complex *input, size_t rank)
{
    Complex *p0 = input;
    Complex *p1 = input + rank;
    Complex *p2 = input + rank * 2;
    Complex *p3 = input + rank * 3;
    Complex2 a0(p0);
    Complex2 a1(p1);
    Complex2 b0(p2);
    Complex2 b1(p3);
    fft_2point(a0, b0);
    fft_2point(a1, b1);
    b1 = b1.mul_neg_i();
    a0.store(p0);
    a1.store(p1);
    ((b0 + b1) * omega).store(p2);
    ((b0 - b1) * omega_cube).store(p3);
}
// Radix-4 decimation-in-time butterfly. Twiddles input[rank], input[2*rank] and
// input[3*rank] by omega, omega^2 and omega^3, then combines the four values
// (forward kernel, -i rotation).
inline void fft_radix4_dit_butterfly(Complex omega, Complex omega_sqr, Complex omega_cube,
                                     Complex *input, size_t rank)
{
    Complex x0 = input[0];
    Complex x1 = input[rank] * omega;
    Complex x2 = input[rank * 2] * omega_sqr;
    Complex x3 = input[rank * 3] * omega_cube;
    fft_2point(x0, x2);
    fft_2point(x1, x3);
    x3 = Complex(x3.imag(), -x3.real()); // multiply by -i
    input[0] = x0 + x1;
    input[rank] = x2 + x3;
    input[rank * 2] = x0 - x1;
    input[rank * 3] = x2 - x3;
}
// Radix-4 decimation-in-frequency butterfly. Combines the four values, then twiddles
// outputs 1, 2 and 3 by omega, omega^2 and omega^3.
inline void fft_radix4_dif_butterfly(Complex omega, Complex omega_sqr, Complex omega_cube,
                                     Complex *input, size_t rank)
{
    Complex x0 = input[0];
    Complex x1 = input[rank];
    Complex x2 = input[rank * 2];
    Complex x3 = input[rank * 3];
    fft_2point(x0, x2);
    fft_2point(x1, x3);
    x3 = Complex(x3.imag(), -x3.real()); // multiply by -i
    input[0] = x0 + x1;
    input[rank] = (x2 + x3) * omega;
    input[rank * 2] = (x0 - x1) * omega_sqr;
    input[rank * 3] = (x2 - x3) * omega_cube;
}
// Replace every element by its complex conjugate divided by `div`
// (conjugation plus normalization, used for inverse transforms).
inline void fft_conj(Complex *input, size_t fft_len, double div = 1)
{
    for (Complex *it = input, *end = input + fft_len; it != end; ++it)
    {
        *it = std::conj(*it) / div;
    }
}
// Divide every element by fft_len (the 1/N scaling of an inverse transform).
inline void fft_normalize(Complex *input, size_t fft_len)
{
    const double scale = static_cast<double>(fft_len);
    std::for_each(input, input + fft_len, [scale](Complex &value) { value /= scale; });
}
// Compile-time split-radix decimation-in-time FFT of LEN points (LEN a power of two).
// Recursively transforms the first half and the two trailing quarters, then merges them
// with the Complex2 split-radix butterfly, two columns per iteration. This requires
// quarter_len to be even; LEN <= 8 is handled by the specializations below.
// Twiddles come from row hint_log2(LEN) of TABLE, which must already be expanded
// (fft_dit_template calls TABLE.expand first).
template <size_t LEN>
void fft_split_radix_dit_template(Complex *input)
{
constexpr size_t log_len = hint_log2(LEN);
constexpr size_t half_len = LEN / 2, quarter_len = LEN / 4;
fft_split_radix_dit_template<half_len>(input);
fft_split_radix_dit_template<quarter_len>(input + half_len);
fft_split_radix_dit_template<quarter_len>(input + half_len + quarter_len);
// Column pair (i, i + 1): twiddles omega^i, omega^(i+1) and their cubes.
for (size_t i = 0; i < quarter_len; i += 2)
{
auto omega = TABLE.get_omegaX2(log_len, i);
auto omega_cube = TABLE.get_omega3X2(log_len, i);
fft_split_radix_dit_butterfly(omega, omega_cube, input + i, quarter_len);
}
}
// Leaf cases of the recursion. These explicit specializations must be declared before
// the template is first instantiated with these sizes.
template <>
void fft_split_radix_dit_template<0>(Complex *input) {}
template <>
void fft_split_radix_dit_template<1>(Complex *input) {}
template <>
void fft_split_radix_dit_template<2>(Complex *input)
{
fft_2point(input[0], input[1]);
}
template <>
void fft_split_radix_dit_template<4>(Complex *input)
{
fft_dit_4point(input, 1);
}
template <>
void fft_split_radix_dit_template<8>(Complex *input)
{
fft_dit_8point(input, 1);
}
// Compile-time split-radix decimation-in-frequency FFT of LEN points (LEN a power of two).
// First applies the Complex2 split-radix butterflies, two columns per iteration; LEN <= 8
// is handled by the specializations below. Then it recursively transforms the first half
// and the two trailing quarters. Output is in bit-reversed order.
// Twiddles come from row hint_log2(LEN) of TABLE, which must already be expanded
// (fft_dif_template calls TABLE.expand first).
template <size_t LEN>
void fft_split_radix_dif_template(Complex *input)
{
constexpr size_t log_len = hint_log2(LEN);
constexpr size_t half_len = LEN / 2, quarter_len = LEN / 4;
// Column pair (i, i + 1): twiddles omega^i, omega^(i+1) and their cubes.
for (size_t i = 0; i < quarter_len; i += 2)
{
auto omega = TABLE.get_omegaX2(log_len, i);
auto omega_cube = TABLE.get_omega3X2(log_len, i);
fft_split_radix_dif_butterfly(omega, omega_cube, input + i, quarter_len);
}
fft_split_radix_dif_template<half_len>(input);
fft_split_radix_dif_template<quarter_len>(input + half_len);
fft_split_radix_dif_template<quarter_len>(input + half_len + quarter_len);
}
// Leaf cases of the recursion. These explicit specializations must be declared before
// the template is first instantiated with these sizes.
template <>
void fft_split_radix_dif_template<0>(Complex *input) {}
template <>
void fft_split_radix_dif_template<1>(Complex *input) {}
template <>
void fft_split_radix_dif_template<2>(Complex *input)
{
fft_2point(input[0], input[1]);
}
template <>
void fft_split_radix_dif_template<4>(Complex *input)
{
fft_dif_4point(input, 1);
}
template <>
void fft_split_radix_dif_template<8>(Complex *input)
{
fft_dif_8point(input, 1);
}
// Map a runtime length onto the compile-time split-radix DIT kernel. LEN doubles until
// LEN >= fft_len (fft_len is expected to be a power of two); then the twiddle table is
// expanded to that size and the kernel runs.
template <size_t LEN = 1>
void fft_dit_template(Complex *input, size_t fft_len)
{
    if (fft_len > LEN)
    {
        fft_dit_template<LEN * 2>(input, fft_len);
        return;
    }
    TABLE.expand(hint_log2(LEN));
    fft_split_radix_dit_template<LEN>(input);
}
// Recursion sentinel, reached only when fft_len > 1 << 23 (beyond the twiddle table).
// It used to return silently and leave the data untransformed; it now reports the same
// error as ComplexTableY::expand.
template <>
inline void fft_dit_template<1 << 24>(Complex *, size_t)
{
    throw("FFT length too long for lut\n");
}
// Map a runtime length onto the compile-time split-radix DIF kernel. LEN doubles until
// LEN >= fft_len (fft_len is expected to be a power of two); then the twiddle table is
// expanded to that size and the kernel runs. Output is in bit-reversed order.
template <size_t LEN = 1>
void fft_dif_template(Complex *input, size_t fft_len)
{
    if (fft_len > LEN)
    {
        fft_dif_template<LEN * 2>(input, fft_len);
        return;
    }
    TABLE.expand(hint_log2(LEN));
    fft_split_radix_dif_template<LEN>(input);
}
// Recursion sentinel, reached only when fft_len > 1 << 23 (beyond the twiddle table).
// It used to return silently and leave the data untransformed; it now reports the same
// error as ComplexTableY::expand.
template <>
inline void fft_dif_template<1 << 24>(Complex *, size_t)
{
    throw("FFT length too long for lut\n");
}
/// @brief Forward decimation-in-time (split-radix) FFT, in place.
/// @param input complex array
/// @param fft_len array length; rounded down to a power of two
/// @param bit_rev when true, the input is bit-reversed first so that natural-order input
///        gives natural-order output; pass false when the input is already bit-reversed
inline void fft_dit(Complex *input, size_t fft_len, bool bit_rev = true)
{
    const size_t n = max_2pow(fft_len);
    if (bit_rev)
    {
        binary_reverse_swap(input, n);
    }
    fft_dit_template<1>(input, n);
}
/// @brief Forward decimation-in-frequency (split-radix) FFT, in place.
/// @param input complex array
/// @param fft_len array length; rounded down to a power of two
/// @param bit_rev when true, the bit-reversed kernel output is permuted back to natural
///        order; pass false to keep it bit-reversed (e.g. before fft_dit(..., false))
inline void fft_dif(Complex *input, size_t fft_len, bool bit_rev = true)
{
    const size_t n = max_2pow(fft_len);
    fft_dif_template<1>(input, n);
    if (bit_rev)
    {
        binary_reverse_swap(input, n);
    }
}
}
// Write `num` as `digit` base-BASE digits, most significant first:
// ary[digit - 1] receives the least significant digit.
// Higher digits that do not fit are dropped.
template <UINT_64 BASE, typename T>
constexpr void save_num_to_ary(T ary[], UINT_64 num, size_t digit)
{
    for (size_t i = digit; i-- > 0; num /= BASE)
    {
        ary[i] = num % BASE;
    }
}
/// @brief Parse an unsigned integer in `base` (2..36) from [begin, end).
/// At most the first 19 characters are consumed. A character that is not a valid digit
/// for `base` contributes a 0 digit (the accumulator is still multiplied by base).
/// Each character is converted to unsigned char before the <cctype> calls, because
/// passing a negative char (e.g. a UTF-8 byte) to tolower/isalnum/isalpha is undefined
/// behavior.
UINT_64 stoui64(const std::string::const_iterator &begin,
                const std::string::const_iterator &end, const UINT_32 base = 10)
{
    UINT_64 result = 0;
    for (auto it = begin; it < end && it - begin < 19; ++it)
    {
        result *= base;
        const int c = std::tolower(static_cast<unsigned char>(*it));
        UINT_64 n = 0;
        if (std::isdigit(c))
        {
            n = c - '0';
        }
        else if (std::isalpha(c))
        {
            n = c - 'a' + 10;
        }
        if (n < base)
        {
            result += n;
        }
    }
    return result;
}
// Convenience overload: parse the whole string with the iterator overload.
UINT_64 stoui64(const std::string &str, const UINT_32 base = 10)
{
    const std::string::const_iterator first = str.begin(), last = str.end();
    return stoui64(first, last, base);
}
// Format `input` as exactly `digits` decimal characters.
// The result is zero-padded on the left; if the number has more digits,
// only its lowest `digits` digits are kept.
std::string ui64to_string(UINT_64 input, UINT_8 digits)
{
    std::string result(digits, '0');
    for (auto pos = result.rbegin(); pos != result.rend(); ++pos)
    {
        *pos = static_cast<char>('0' + input % 10);
        input /= 10;
    }
    return result;
}
/// Growable array of trivially copyable T with a packed sign bit, used for big-number digits.
///  - `size` is the allocated capacity in elements.
///  - `sign_n_len` keeps the logical length in its low bits and the sign in its top bit.
/// Storage comes from malloc, is grown with ary_realloc (realloc) and is released with free.
/// The previous version allocated with new[], grew with realloc and freed with delete[],
/// which is undefined behavior. Elements are not initialized, except by the
/// (size, value) constructor.
template <typename T = UINT_32, typename SIZE_TYPE = UINT_32>
class HintVector
{
private:
    T *ary_ptr = nullptr;
    SIZE_TYPE sign_n_len = 0;
    SIZE_TYPE size = 0;

    // Uninitialized storage for `count` elements (at least one). Throws std::bad_alloc,
    // like the new[] it replaces.
    static T *allocate(SIZE_TYPE count)
    {
        const size_t n = count == 0 ? 1 : static_cast<size_t>(count);
        void *mem = std::malloc(n * sizeof(T));
        if (mem == nullptr)
        {
            throw std::bad_alloc();
        }
        return static_cast<T *>(mem);
    }

public:
    ~HintVector()
    {
        std::free(ary_ptr);
        ary_ptr = nullptr;
    }
    HintVector() {}
    // Reserve capacity for ele_size elements; the length stays 0.
    HintVector(SIZE_TYPE ele_size)
    {
        if (ele_size > 0)
        {
            resize(ele_size);
        }
    }
    // ele_size copies of ele; the length is ele_size.
    HintVector(SIZE_TYPE ele_size, T ele)
    {
        if (ele_size > 0)
        {
            resize(ele_size);
            change_length(ele_size);
            fill_element(ele, ele_size);
        }
    }
    // Copy len elements from ary. @throws const char* when ary is nullptr.
    HintVector(const T *ary, SIZE_TYPE len)
    {
        if (ary == nullptr)
        {
            throw("Can't copy from nullptr\n");
        }
        if (len > 0)
        {
            resize(len);
            change_length(len);
            ary_copy(ary_ptr, ary, len);
        }
    }
    HintVector(const HintVector &input)
    {
        if (input.ary_ptr != nullptr)
        {
            SIZE_TYPE len = input.length();
            resize(len);
            change_length(len);
            change_sign(input.sign());
            ary_copy(ary_ptr, input.ary_ptr, len);
        }
    }
    // Steal the buffer. The source is reset to the empty state; before, it kept its old
    // length with a null pointer.
    HintVector(HintVector &&input) noexcept
        : ary_ptr(input.ary_ptr), sign_n_len(input.sign_n_len), size(input.size)
    {
        input.ary_ptr = nullptr;
        input.sign_n_len = 0;
        input.size = 0;
    }
    HintVector &operator=(const HintVector &input)
    {
        if (this != &input)
        {
            std::free(ary_ptr);
            ary_ptr = nullptr;
            size = 0;
            sign_n_len = 0;
            SIZE_TYPE len = input.length();
            resize(len);
            change_length(len);
            change_sign(input.sign());
            ary_copy(ary_ptr, input.ary_ptr, len);
        }
        return *this;
    }
    HintVector &operator=(HintVector &&input) noexcept
    {
        if (this != &input)
        {
            std::free(ary_ptr);
            ary_ptr = input.ary_ptr;
            sign_n_len = input.sign_n_len;
            size = input.size;
            input.ary_ptr = nullptr;
            input.sign_n_len = 0;
            input.size = 0;
        }
        return *this;
    }
    T &operator[](SIZE_TYPE index) const
    {
        return ary_ptr[index];
    }
    T *data() const
    {
        return ary_ptr;
    }
    // Capacity policy: 2, 4, then the smaller of 1.5 * 2^k and 2^(k+1) that fits new_size.
    static SIZE_TYPE size_generator(SIZE_TYPE new_size)
    {
        constexpr SIZE_TYPE SIZE_TYPE_BITS = sizeof(SIZE_TYPE) * 8;
        constexpr SIZE_TYPE SIZE_80 = (1ull << (SIZE_TYPE_BITS - 1));
        constexpr SIZE_TYPE LEN_MAX = SIZE_80 - 1;
        new_size = std::min<SIZE_TYPE>(new_size, LEN_MAX);
        if (new_size <= 2)
        {
            return 2;
        }
        else if (new_size <= 4)
        {
            return 4;
        }
        SIZE_TYPE size1 = min_2pow(new_size);
        SIZE_TYPE size2 = size1 / 2;
        size2 = size2 + size2 / 2;
        if (new_size <= size2)
        {
            return size2;
        }
        else
        {
            return size1;
        }
    }
    SIZE_TYPE length() const
    {
        constexpr SIZE_TYPE SIZE_TYPE_BITS = sizeof(SIZE_TYPE) * 8;
        constexpr SIZE_TYPE SIZE_80 = (1ull << (SIZE_TYPE_BITS - 1));
        constexpr SIZE_TYPE LEN_MAX = SIZE_80 - 1;
        return sign_n_len & LEN_MAX;
    }
    bool sign() const
    {
        constexpr SIZE_TYPE SIZE_TYPE_BITS = sizeof(SIZE_TYPE) * 8;
        constexpr SIZE_TYPE SIZE_80 = (1ull << (SIZE_TYPE_BITS - 1));
        return (SIZE_80 & sign_n_len) != 0;
    }
    // Change the capacity to size_func(new_size); the length is clamped to the new capacity.
    // `size` is updated only after the (re)allocation succeeds, so a throw leaves the
    // object consistent.
    void resize(SIZE_TYPE new_size, SIZE_TYPE(size_func)(SIZE_TYPE) = size_generator)
    {
        new_size = size_func(new_size);
        if (ary_ptr == nullptr)
        {
            ary_ptr = allocate(new_size);
            size = new_size;
            change_length(0);
        }
        else if (size != new_size)
        {
            ary_ptr = ary_realloc(ary_ptr, new_size);
            size = new_size;
            change_length(length());
        }
    }
    // Set the logical length (clamped to the capacity), keeping the sign.
    // @throws const char* when new_length does not fit below the sign bit.
    void change_length(SIZE_TYPE new_length)
    {
        constexpr SIZE_TYPE SIZE_TYPE_BITS = sizeof(SIZE_TYPE) * 8;
        constexpr SIZE_TYPE SIZE_80 = (1ull << (SIZE_TYPE_BITS - 1));
        constexpr SIZE_TYPE LEN_MAX = SIZE_80 - 1;
        if (new_length > LEN_MAX)
        {
            throw("Length too long\n");
        }
        new_length = std::min(new_length, size);
        bool sign_tmp = sign();
        sign_n_len = new_length;
        change_sign(sign_tmp);
    }
    // Set the sign bit; an empty vector is always non-negative.
    void change_sign(bool is_sign)
    {
        constexpr SIZE_TYPE SIZE_TYPE_BITS = sizeof(SIZE_TYPE) * 8;
        constexpr SIZE_TYPE SIZE_80 = (1ull << (SIZE_TYPE_BITS - 1));
        constexpr SIZE_TYPE LEN_MAX = SIZE_80 - 1;
        if ((!is_sign) || length() == 0)
        {
            sign_n_len = sign_n_len & LEN_MAX;
        }
        else
        {
            sign_n_len = sign_n_len | SIZE_80;
        }
    }
    // Drop trailing zero elements from the length and return the new length.
    SIZE_TYPE set_true_len()
    {
        if (ary_ptr == nullptr)
        {
            return 0;
        }
        SIZE_TYPE t_len = length();
        while (t_len > 0 && ary_ptr[t_len - 1] == 0)
        {
            t_len--;
        }
        change_length(t_len);
        return length();
    }
    // Fill `count` elements from `begin` with ele, clipped to the capacity.
    void fill_element(T ele, SIZE_TYPE count, SIZE_TYPE begin = 0)
    {
        if (begin >= size || ary_ptr == nullptr)
        {
            return;
        }
        if (begin + count >= size)
        {
            count = size - begin;
        }
        std::fill(ary_ptr + begin, ary_ptr + begin + count, ele);
    }
};
}
namespace hint_arithm
{
using namespace hint;
// Length/index type used by the digit vectors in this namespace.
using SIZE_TYPE = UINT_32;
// Digit vector with 32-bit length and sign storage.
template <typename T>
using hintvector = HintVector<T, SIZE_TYPE>;
// Length of `ary` after dropping trailing zero elements. Digits are little-endian here,
// so these are the zero most-significant digits.
template <typename T>
constexpr size_t ary_true_len(const T &ary, size_t len)
{
    for (; len != 0; --len)
    {
        if (!(ary[len - 1] == 0))
        {
            break;
        }
    }
    return len;
}
// Compare two little-endian digit arrays by magnitude, ignoring high-order zero digits.
// Returns 1 if ary1 > ary2, -1 if ary1 < ary2, and 0 if they are equal.
template <typename T>
constexpr INT_32 abs_compare(const T ary1[], const T ary2[], size_t len1, size_t len2)
{
    len1 = ary_true_len(ary1, len1);
    len2 = ary_true_len(ary2, len2);
    if (len1 != len2)
    {
        return len1 < len2 ? -1 : 1;
    }
    if (ary1 == ary2)
    {
        return 0;
    }
    for (size_t i = len1; i > 0; --i)
    {
        const T a = ary1[i - 1];
        const T b = ary2[i - 1];
        if (a != b)
        {
            return a < b ? -1 : 1;
        }
    }
    return 0;
}
// 高精度乘低精度
template <UINT_64 BASE, bool is_carry = true, typename T>
constexpr void abs_mul_num(const T in[], T num, T out[], size_t len)
{
len = ary_true_len(in, len);
num %= BASE;
UINT_64 prod = 0;
if (num == 1)
{
ary_copy(out, in, len);
}
else
{
for (size_t i = 0; i < len; i++)
{
prod += static_cast<UINT_64>(in[i]) * num;
std::tie(prod, out[i]) = div_mod<UINT_64>(prod, BASE);
}
}
if (is_carry)
{
out[len] = prod;
}
}
// Divide a little-endian base-BASE number by a single digit `num` (reduced mod BASE).
// The quotient goes to out and the remainder is returned.
// out[0, true_len(in)) is written, or out[0, len) when num == 1; in and out may alias.
// @throws const char* "Can't divide by zero\n" when num % BASE == 0
//         (previously an integer division by zero).
template <UINT_64 BASE, typename T>
constexpr INT_64 abs_div_num(const T in[], T num, T out[], size_t len)
{
    size_t pos = ary_true_len(in, len);
    num %= BASE;
    if (num == 0)
    {
        throw("Can't divide by zero\n");
    }
    if (num == 1)
    {
        ary_copy(out, in, len);
        return 0;
    }
    UINT_64 last_rem = 0;
    while (pos > 0)
    {
        pos--;
        last_rem = last_rem * BASE + in[pos];
        std::tie(out[pos], last_rem) = div_mod<UINT_64>(last_rem, num);
    }
    return last_rem;
}
// 高精度加法
template <INT_64 BASE, bool is_carry = true, typename T>
constexpr void abs_add(const T in1[], const T in2[], T out[],
size_t len1, size_t len2)
{
if (len1 < len2)
{
std::swap(in1, in2);
std::swap(len1, len2);
}
size_t pos = 0;
UINT_64 carry = 0;
while (pos < len2)
{
carry += (in1[pos] + in2[pos]);
out[pos] = carry < BASE ? carry : carry - BASE;
carry = carry < BASE ? 0 : 1;
pos++;
}
while (pos < len1 && carry > 0)
{
carry += in1[pos];
out[pos] = carry < BASE ? carry : carry - BASE;
carry = carry < BASE ? 0 : 1;
pos++;
}
ary_copy(out + pos, in1 + pos, len1 - pos);
if (is_carry)
{
out[len1] = carry % BASE;
}
}
// out = in1 - in2 for little-endian base-BASE digit arrays. Callers must ensure
// |in1| >= |in2|: a final borrow is silently dropped. Does nothing when len1 < len2.
// Writes out[0, len1); out may alias in1.
template <INT_64 BASE, typename T>
constexpr void abs_sub(const T in1[], const T in2[], T out[],
                       size_t len1, size_t len2)
{
    if (len2 > len1)
    {
        return;
    }
    INT_64 borrow = 0;
    size_t pos = 0;
    for (; pos < len2; ++pos)
    {
        borrow += (static_cast<INT_64>(in1[pos]) - in2[pos]);
        out[pos] = borrow < 0 ? borrow + BASE : borrow;
        borrow = borrow < 0 ? -1 : 0;
    }
    for (; pos < len1 && borrow < 0; ++pos)
    {
        borrow += in1[pos];
        out[pos] = borrow < 0 ? borrow + BASE : borrow;
        borrow = borrow < 0 ? -1 : 0;
    }
    ary_copy(out + pos, in1 + pos, len1 - pos);
}
// fft加速乘法
template <UINT_64 BASE, typename T>
void fft_mul(const T in1[], const T in2[], T out[],
size_t len1, size_t len2)
{
using namespace hint_transform;
if (len1 == 0 || len2 == 0 || in1 == nullptr || in2 == nullptr)
{
return;
}
const size_t conv_res_len = len1 + len2 - 1; // 卷积结果长度
const size_t fft_len = min_2pow(conv_res_len); // fft长度
Complex *fft_ary = new Complex[fft_len];
com_ary_combine_copy(fft_ary, in1, len1, in2, len2);
fft_dif(fft_ary, fft_len, false);
double inv = -1 / (2.0 * fft_len);
for (size_t i = 0; i < fft_len; i++)
{
Complex tmp = fft_ary[i];
fft_ary[i] = std::conj(tmp * tmp * inv);
}
fft_dit(fft_ary, fft_len, false);
hint::UINT_64 carry = 0;
for (size_t i = 0; i < conv_res_len; i++)
{
carry += static_cast<hint::UINT_64>(fft_ary[i].imag() + 0.5);
std::tie(carry, out[i]) = div_mod<UINT_64>(carry, BASE);
}
out[conv_res_len] = carry % BASE;
delete[] fft_ary;
}
// Divisor normalization: scale the divisor so that its most significant digit is at least
// BASE / 2, the precondition of abs_long_div / abs_rec_div.
// The scaled digits go to out[0, len) and the multiplier used is returned. The dividend must
// be scaled by the same multiplier, and the remainder divided by it afterwards (abs_div
// does both).
// NOTE(review): assumes in[len - 1] != 0 (len is the true length). With len == 1 and a zero
// digit, the first branch divides by zero.
// NOTE(review): (BASE - 1) * BASE overflows UINT_64 when BASE >= 2^32.
// @throws const char* when in and out are the same array.
template <UINT_64 BASE, typename T>
constexpr T divisor_normalize(const T in[], T out[], size_t len)
{
if (in == out)
{
throw("In can't be same as out\n");
}
T multiplier = 1;
if (len == 1)
{
// Single digit: largest multiplier that keeps the product below BASE.
multiplier = (BASE - 1) / in[0];
out[0] = in[0] * multiplier;
}
else if (in[len - 1] >= (BASE / 2))
{
// Already normalized: copy unchanged (multiplier 1).
ary_copy(out, in, len);
}
else
{
// Estimate the multiplier from the top two digits. The carry out of the top digit is
// discarded (is_carry = false); the estimate is presumably chosen so that none occurs.
multiplier = (BASE - 1) * BASE / (BASE * in[len - 1] + in[len - 2]);
abs_mul_num<BASE, false>(in, multiplier, out, len);
if (out[len - 1] < (BASE / 2))
{
// Estimate one too small: add `in` once more.
multiplier++;
abs_add<BASE, false>(out, in, out, len, len);
}
}
return multiplier;
}
// Schoolbook long division. On return, `dividend` holds the remainder and the quotient
// digits have been written into quot. Quotient positions that are never reached are left
// untouched, so callers pass a zero-filled buffer of at least len1 - len2 + 1 digits.
// A multi-digit divisor must already be normalized (top digit >= BASE / 2).
// Fixes:
//  - for a single-digit divisor, the higher dividend digits are now cleared, so `dividend`
//    really holds the remainder (previously only dividend[0] was updated);
//  - null pointers are checked before their digits are read.
// @throws const char* on a zero divisor or on an un-normalized multi-digit divisor.
template <UINT_64 BASE, typename T>
void abs_long_div(T dividend[], const T divisor[], T quot[],
                  size_t len1, size_t len2)
{
    if (divisor == nullptr)
    {
        throw("Can't divide by zero\n");
    }
    len2 = ary_true_len(divisor, len2);
    if (len2 == 0)
    {
        throw("Can't divide by zero\n");
    }
    if (dividend == nullptr)
    {
        return;
    }
    len1 = ary_true_len(dividend, len1);
    if (len1 == 0)
    {
        return;
    }
    if (abs_compare(dividend, divisor, len1, len2) < 0)
    {
        return;
    }
    if (len2 == 1)
    {
        T rem = abs_div_num<BASE>(dividend, divisor[0], quot, len1);
        // The remainder is a single digit: clear the higher digits, which still hold the
        // old dividend.
        std::fill(dividend + 1, dividend + len1, T(0));
        dividend[0] = rem;
        return;
    }
    if (divisor[len2 - 1] < (BASE / 2))
    {
        throw("Can't call this proc before normalize the divisor\n");
    }
    quot[len1 - len2] = 0;
    const UINT_64 divisor_2digits = BASE * divisor[len2 - 1] + divisor[len2 - 2];
    hintvector<T> sub(len2 + 1);
    // Keep subtracting while the remaining dividend is >= the divisor.
    while (abs_compare(dividend, divisor, len1, len2) >= 0)
    {
        sub.change_length(len2 + 1);
        UINT_64 dividend_2digits = dividend[len1 - 1] * BASE + dividend[len1 - 2];
        T quo_digit = 0;
        size_t shift = len1 - len2;
        // Dividend's top two digits exceed the divisor's top two: the trial quotient
        // is off by at most 1.
        if (dividend_2digits > divisor_2digits)
        {
            quo_digit = dividend_2digits / divisor_2digits;
            abs_mul_num<BASE>(divisor, quo_digit, sub.data(), len2);
            sub.set_true_len();
            size_t sub_len = sub.length();
            if (abs_compare(dividend + shift, sub.data(), len1 - shift, sub_len) < 0)
            {
                quo_digit--;
                abs_sub<BASE>(sub.data(), divisor, sub.data(), sub_len, len2);
            }
        }
        else if (dividend_2digits == divisor_2digits)
        {
            if (abs_compare(dividend + shift, divisor, len1 - shift, len2) < 0)
            {
                quo_digit = BASE - 1;
                shift--;
                abs_mul_num<BASE>(divisor, quo_digit, sub.data(), len2);
            }
            else
            {
                quo_digit = 1;
                ary_copy(sub.data(), divisor, len2);
                sub.set_true_len();
            }
        }
        else
        {
            // Trial quotient from the dividend's top two digits over the divisor's top digit:
            // off by at most 2.
            quo_digit = dividend_2digits / (divisor_2digits / BASE);
            if (quo_digit >= BASE)
            {
                quo_digit = BASE - 1;
            }
            shift--;
            abs_mul_num<BASE>(divisor, quo_digit, sub.data(), len2);
            sub.set_true_len();
            size_t sub_len = sub.length();
            if (abs_compare(dividend + shift, sub.data(), len1 - shift, sub_len) < 0)
            {
                quo_digit--;
                abs_sub<BASE>(sub.data(), divisor, sub.data(), sub_len, len2);
                if (abs_compare(dividend + shift, sub.data(), len1 - shift, sub_len) < 0)
                {
                    quo_digit--;
                    abs_sub<BASE>(sub.data(), divisor, sub.data(), sub_len, len2);
                }
            }
        }
        abs_sub<BASE>(dividend + shift, sub.data(), dividend + shift, len1 - shift, sub.length());
        len1 = ary_true_len(dividend, len1);
        quot[shift] = quo_digit;
    }
}
// Recursive (divide-and-conquer) division. On return, `dividend` holds the remainder and
// quot holds the quotient. The divisor must be normalized (top digit >= BASE / 2).
// Fixes:
//  - the quotient parameter was garbled into `hintvector<T> "` (an HTML-entity artifact of
//    `&quot`) and did not compile; it is `hintvector<T> &quot` again;
//  - null pointers are checked before their digits are read.
// quot's capacity must cover len1 - len2 + 1 digits (change_length clamps to the capacity).
// @throws const char* on a zero divisor or an un-normalized divisor.
template <UINT_64 BASE, typename T>
void abs_rec_div(T dividend[], T divisor[], hintvector<T> &quot,
                 size_t len1, size_t len2)
{
    if (divisor == nullptr)
    {
        throw("Can't divide by zero\n");
    }
    len2 = ary_true_len(divisor, len2);
    if (len2 == 0)
    {
        throw("Can't divide by zero\n");
    }
    if (dividend == nullptr)
    {
        return;
    }
    len1 = ary_true_len(dividend, len1);
    if (len1 == 0)
    {
        return;
    }
    if (abs_compare(dividend, divisor, len1, len2) < 0)
    {
        return;
    }
    if (divisor[len2 - 1] < (BASE / 2))
    {
        throw("Can't call this proc before normalize the divisor\n");
    }
    size_t quot_len = len1 - len2 + 1;
    quot.change_length(quot_len);
    constexpr size_t LONG_DIV_THRESHOLD = 50;
    if (len2 <= LONG_DIV_THRESHOLD) // short divisor: schoolbook long division
    {
        abs_long_div<BASE>(dividend, divisor, quot.data(), len1, len2);
    }
    else if (len1 >= len2 * 2 || len1 > ((len2 + 1) / 2) * 3) // dividend much longer: two recursive divisions
    {
        size_t base_len = (len1 + 3) / 4;
        size_t quot_tmp_len = base_len * 3 - len2 + 2;
        hintvector<T> quot_tmp(quot_tmp_len, 0);
        // High part first; its remainder stays in place for the second division.
        abs_rec_div<BASE>(dividend + base_len, divisor, quot_tmp, len1 - base_len, len2);
        quot_tmp_len = quot_tmp.set_true_len();
        size_t dividend_len = ary_true_len(dividend, len1);
        abs_rec_div<BASE>(dividend, divisor, quot, dividend_len, len2);
        quot.change_length(quot_len);
        quot_len = quot.set_true_len();
        // Combine: quot += quot_tmp * BASE^base_len.
        abs_add<BASE>(quot.data() + base_len, quot_tmp.data(), quot.data() + base_len, quot_len - base_len, quot_tmp_len);
        quot.change_length(len1 - len2 + 1);
    }
    else
    {
        // Trial quotient: divide dividend / BASE^base_len by divisor / BASE^base_len.
        size_t base_len = len2 / 2;
        abs_rec_div<BASE>(dividend + base_len, divisor + base_len, quot, len1 - base_len, len2 - base_len);
        constexpr T ONE[1] = {1};
        quot_len = quot.set_true_len();
        hintvector<T> prod(base_len + quot_len, 0);
        // Multiply the trial quotient by the divisor's low base_len digits and compare with the
        // remainder. We need quot * (divisor mod BASE^base_len) <= dividend; otherwise
        // decrement quot and add the divisor back.
        fft_mul<BASE>(divisor, quot.data(), prod.data(), base_len, quot_len);
        size_t prod_len = prod.set_true_len();
        len1 = ary_true_len(dividend, len1);
        while (abs_compare(prod.data(), dividend, prod_len, len1) > 0)
        {
            abs_sub<BASE>(quot.data(), ONE, quot.data(), quot_len, 1);
            abs_sub<BASE>(prod.data(), divisor, prod.data(), prod_len, base_len);
            abs_add<BASE>(dividend + base_len, divisor + base_len, dividend + base_len, len1 - base_len, len2 - base_len);
            quot_len = quot.set_true_len();
            prod_len = prod.set_true_len();
            len1 = ary_true_len(dividend, std::max(len1, len2) + 1);
        }
        abs_sub<BASE>(dividend, prod.data(), dividend, len1, prod_len);
    }
}
// Absolute-value division: computes |dividend| / |divisor| into quot.
// The divisor is normalized (scaled so its top limb is >= BASE / 2) and the
// dividend is scaled by the same multiplier before calling abs_rec_div.
// Returns the remainder when ret_rem is true. When ret_rem is false, the returned
// vector still holds normalized (scaled) working data and is not a remainder.
// Throws (const char *) when the divisor is empty.
template <UINT_64 BASE, typename T>
hintvector<T> abs_div(const T dividend[], T divisor[], hintvector<T> &quot,
                      size_t len1, size_t len2, bool ret_rem = true)
{
    // Reject an empty divisor before divisor_normalize() indexes it.
    if (divisor == nullptr || len2 == 0)
    {
        throw("Can't divide by zero\n");
    }
    hintvector<T> normalized_divisor(len2); // normalized divisor
    normalized_divisor.change_length(len2);
    hintvector<T> normalized_dividend(len1 + 1); // normalized dividend (one extra limb for the scaling carry)
    normalized_dividend.change_length(len1 + 1);
    T *divisor_ptr = normalized_divisor.data();
    T *dividend_ptr = normalized_dividend.data();
    T multiplier = divisor_normalize<BASE>(divisor, divisor_ptr, len2); // normalize divisor, get the scale factor
    abs_mul_num<BASE>(dividend, multiplier, dividend_ptr, len1);        // scale the dividend by the same factor
    len1 = normalized_dividend.set_true_len();
    // A dividend with fewer limbs than the divisor has quotient 0.
    // Guard the sizes below so len1 - len2 cannot wrap around.
    const bool dividend_shorter = len1 < len2;
    quot = hintvector<T>(dividend_shorter ? 1 : len1 - len2 + 2, 0);
    if ((!ret_rem) && (!dividend_shorter) && (len1 + 2 < len2 * 2))
    {
        // Quotient only and the divisor is long: drop the low `shift` limbs of both
        // operands so they take no part in the computation.
        size_t shift = len2 * 2 - len1 - 2;
        abs_rec_div<BASE>(dividend_ptr + shift, divisor_ptr + shift, quot, len1 - shift, len2 - shift);
        quot.set_true_len();
        return normalized_dividend;
    }
    abs_rec_div<BASE>(dividend_ptr, divisor_ptr, quot, len1, len2);
    quot.set_true_len();
    if (ret_rem)
    {
        len1 = normalized_dividend.set_true_len();
        abs_div_num<BASE>(dividend_ptr, multiplier, dividend_ptr, len1); // undo the scaling to get the true remainder
        normalized_dividend.set_true_len();
    }
    return normalized_dividend;
}
} // end of the enclosing namespace
class Integer
{
public:
using DataType = hint::UINT_32;
using SizeType = hint::UINT_32;
using DataVec = hint::HintVector<DataType, SizeType>;
private:
DataVec data;
public:
static constexpr hint::UINT_32 DIGIT = 6;
static constexpr hint::UINT_64 BASE = 1000000;
Integer()
{
data = DataVec();
}
Integer(const Integer &input)
{
if (this != &input)
{
data = input.data;
}
}
Integer(Integer &&input) noexcept
{
if (this != &input)
{
data = std::move(input.data);
}
}
Integer &operator=(const Integer &input)
{
if (this != &input)
{
data = input.data;
}
return *this;
}
Integer &operator=(Integer &&input) noexcept
{
if (this != &input)
{
data = std::move(input.data);
}
return *this;
}
Integer &operator=(const std::string &input)
{
string_in(input);
return *this;
}
Integer &operator=(const char input[])
{
string_in(input);
return *this;
}
DataType first_num() const
{
if (length() == 0)
{
return 0;
}
return data[length() - 1];
}
void change_sign(bool is_neg)
{
data.change_sign(is_neg);
}
bool is_neg() const
{
return data.sign();
}
SizeType length() const
{
return data.length();
}
SizeType length_base10() const
{
size_t len = data.length();
if (len == 0)
{
return 1;
}
return (len - 1) * DIGIT + std::ceil(std::log10(first_num() + 1));
}
void string_in(const std::string &str)
{
size_t str_len = str.size();
if (str_len == 0)
{
data = DataVec();
return;
}
hint::INT_64 len = str_len, pos = len, i = 0;
bool is_neg = false;
if (str[0] == '-')
{
is_neg = true;
len--;
}
len = (len + DIGIT - 1) / DIGIT;
data = DataVec(len);
data.change_length(len);
while (pos - DIGIT > 0)
{
hint::UINT_64 tmp = hint::stoui64(str.begin() + pos - DIGIT, str.begin() + pos);
data[i] = static_cast<DataType>(tmp);
i++;
pos -= DIGIT;
}
hint::INT_64 begin = is_neg ? 1 : 0;
if (pos > begin)
{
hint::UINT_64 tmp = hint::stoui64(str.begin() + begin, str.begin() + pos);
data[i] = static_cast<DataType>(tmp);
}
change_sign(is_neg);
data.set_true_len();
}
std::string to_string() const
{
std::string result;
size_t pos = length();
if (pos == 0)
{
return "0";
}
if (is_neg())
{
result += '-';
}
result += std::to_string(first_num());
pos--;
while (pos > 0)
{
pos--;
result += hint::ui64to_string(data[pos], DIGIT);
}
return result;
}
hint::INT_32 abs_compare(const Integer &input) const
{
size_t len1 = length(), len2 = input.length();
return hint_arithm::abs_compare(data.data(), input.data.data(), len1, len2);
}
Integer operator/(const Integer &input) const
{
Integer result;
size_t len1 = length(), len2 = input.length();
if (abs_compare(input) < 0)
{
return result;
}
if (len2 == 0)
{
throw("Can't divide by zero\n");
}
auto ptr1 = data.data();
auto ptr2 = input.data.data();
hint_arithm::abs_div<BASE>(ptr1, ptr2, result.data, len1, len2, false);
result.data.set_true_len();
result.change_sign(is_neg() != input.is_neg());
return result;
}
};
int main()
{
std::string s;
Integer a, b;
std::cin >> s;
a.string_in(s);
std::cin >> s;
b.string_in(s);
std::cout << (a / b).to_string();
}
// Online-judge result for this submission:
// Compilation | N/A       | N/A    | Compile OK | Score: N/A
// Testcase #1 | 709.31 us | 132 KB | Accepted   | Score: 100