#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>
#include "stopwatch.hpp"
// Signed fixed-width and platform-width integer aliases.
using INT_32 = int32_t;
using INT_64 = int64_t;
using LONG = long;

// Unsigned fixed-width and platform-width integer aliases.
using UINT_8 = uint8_t;
using UINT_16 = uint16_t;
using UINT_32 = uint32_t;
using UINT_64 = uint64_t;
using ULONG = unsigned long;

// Double-precision complex number used for all FFT data.
using Complex = std::complex<double>;

// pi, and 2*pi derived from it. Doubling a double is exact, so HINT_2PI is the same
// correctly rounded value as the decimal literal 6.283185307179586...
constexpr double HINT_PI = 3.1415926535897932384626433832795;
constexpr double HINT_2PI = 2.0 * HINT_PI;
// Returns exp(i * theta): the point at angle theta (radians) on the unit circle.
inline Complex unit_root(double theta)
{
    const double radius = 1.0;
    return std::polar(radius, theta);
}
class CosTable
{
private:
Complex *table = nullptr;
size_t log_size = 0;
CosTable(const CosTable &) = delete;
CosTable &operator=(const CosTable &) = delete;
public:
~CosTable()
{
if (table != nullptr)
{
delete[] table;
table = nullptr;
}
}
// 初始化可以生成平分圆1<<shift份产生的单位根的表
CosTable(size_t shift)
{
shift = std::max<size_t>(shift, 3);
log_size = shift;
size_t ary_size = (1ull << (shift - 1)) - 2;
table = new Complex[ary_size];
shift -= 2;
for (size_t i = 1; i <= shift; i++)
{
size_t len = 1ull << i;
size_t begin = len - 2;
for (size_t pos = 0; pos < len; pos++)
{
table[pos + begin] = unit_root(pos * HINT_PI / (len * 2));
}
}
}
// shift表示圆平分为1<<shift份,n表示第几个单位根
Complex get_complex(size_t shift, size_t n) const
{
size_t rank = 1ull << shift;
n &= (rank - 1);
size_t zone = (n << 2) >> shift;
if (((n << 2) & (rank - 1)) == 0)
{
Complex ary[4] = {Complex(1, 0), Complex(0, 1), Complex(-1, 0), Complex(0, -1)};
return ary[zone];
}
rank >>= 2;
const Complex *ptr = table + rank - 2;
Complex tmp;
switch (zone)
{
case 0:
tmp = ptr[n];
break;
case 1:
tmp = ptr[(rank << 1) - n];
tmp = Complex(-tmp.real(), tmp.imag());
break;
case 2:
tmp = -ptr[n - (rank << 1)];
break;
case 3:
tmp = std::conj(ptr[(rank << 2) - n]);
break;
default:
break;
}
return tmp;
}
};
// log2 of the largest transform length the lookup table supports.
constexpr size_t lut_max_len = 21;
// Shared twiddle-factor table used by the FFT routines. It is built once, when the
// program starts, and holds (1 << (lut_max_len - 1)) - 2 Complex entries.
const CosTable TABLE{lut_max_len};
// Returns the largest power of two that is <= n, or 0 when n == 0 (same contract as
// C++20 std::bit_floor). std::uint64_t is the same type as this file's UINT_64.
// Computed with integer shifts: the former floor(log2(n)) version was undefined for
// n == 0 and could round log2 up for n close to 2^k (k >= ~49), overshooting n.
inline std::uint64_t max_2pow(std::uint64_t n)
{
    if (n == 0)
    {
        return 0;
    }
    std::uint64_t power = 1;
    while ((n >>= 1) != 0)
    {
        power <<= 1;
    }
    return power;
}
// Returns the smallest power of two that is >= n (1 for n == 0, like C++20
// std::bit_ceil). When that power is 2^64 (n > 2^63), which does not fit, 0 is returned.
// std::uint64_t is the same type as this file's UINT_64.
// Computed with integer shifts: the former ceil(log2(n)) version was undefined for
// n == 0 and could round log2 down for n just above 2^k (k >= ~49), returning a
// power smaller than n and under-sizing the caller's buffer.
inline std::uint64_t min_2pow(std::uint64_t n)
{
    constexpr std::uint64_t top_bit = std::uint64_t{1} << 63;
    if (n > top_bit)
    {
        return 0; // 2^64 is not representable
    }
    std::uint64_t power = 1;
    while (power < n)
    {
        power <<= 1;
    }
    return power;
}
// Returns true when the lowest bit of x is set (x is odd, for integer types).
template <typename T>
constexpr bool is_odd(T x)
{
    return (x & 1) != 0;
}
// Copies len elements from source into target with a raw byte copy and returns
// target. T must be trivially copyable, and the two ranges must not overlap.
template <typename T>
inline T *ary_copy(T *target, const T *source, size_t len)
{
    const size_t byte_count = sizeof(T) * len;
    std::memcpy(target, source, byte_count);
    return target;
}
// Radix-2 butterfly. With t = input[pos + rank] * omega it sets
//   input[pos]        <- input[pos] + t
//   input[pos + rank] <- input[pos] - t   (using the value input[pos] had on entry)
inline void fft_radix2_butterfly(Complex omega, Complex *input, size_t pos, size_t rank)
{
    Complex *const upper = input + pos + rank;
    const Complex lower_val = input[pos];
    const Complex twisted = *upper * omega;
    input[pos] = lower_val + twisted;
    *upper = lower_val - twisted;
}
// Radix-4 butterfly on input[pos + q * rank], q = 0..3.
// Elements q = 1, 2, 3 are first multiplied by omega, omega_sqr and omega_cube. The
// four values then go through a 4-point DFT whose rotation step multiplies by +i:
//   y0 = a0 + a1 + a2 + a3,  y1 = a0 + i*a1 - a2 - i*a3,
//   y2 = a0 - a1 + a2 - a3,  y3 = a0 - i*a1 - a2 + i*a3.
inline void fft_radix4_butterfly(Complex omega, Complex omega_sqr, Complex omega_cube,
                                 Complex *input, size_t pos, size_t rank)
{
    Complex *const base = input + pos;
    const Complex x0 = base[0];
    const Complex x1 = base[rank] * omega;
    const Complex x2 = base[rank * 2] * omega_sqr;
    const Complex x3 = base[rank * 3] * omega_cube;

    const Complex sum02 = x0 + x2;
    const Complex sum13 = x1 + x3;
    const Complex diff02 = x0 - x2;
    const Complex diff13 = x1 - x3;
    const Complex rot13(-1.0 * diff13.imag(), diff13.real()); // i * diff13

    base[0] = sum02 + sum13;
    base[rank] = diff02 + rot13;
    base[rank * 2] = sum02 - sum13;
    base[rank * 3] = diff02 - rot13;
}
// Permutes ary[0..len) into bit-reversed index order: the element at index i is
// swapped with the element whose log2(len)-bit index is i with its bits reversed.
// A radix-2 decimation-in-time FFT needs its input in this order.
// Precondition: len is a power of two. Lengths 0 and 1 are left untouched; for
// len == 1 the old code wrote rev[0] into a zero-length array.
// (No longer constexpr: the function allocates, so it could never be constant-evaluated.)
template <typename T>
void binary_inverse_swap(T *ary, size_t len)
{
    if (len < 2)
    {
        return;
    }
    size_t log_n = 0; // floor(log2(len)), computed exactly with integers
    while ((len >> log_n) > 1)
    {
        log_n++;
    }
    // rev[j] caches the reversed index of j for j < len / 2; higher entries are derived from it.
    std::vector<size_t> rev(len / 2, 0);
    for (size_t i = 1; i < len; i++)
    {
        // Shift the reversal of i / 2 down one bit and put i's lowest bit on top.
        const size_t index = (rev[i >> 1] >> 1) | ((i & 1) << (log_n - 1));
        if (i < len / 2)
        {
            rev[i] = index;
        }
        if (i < index) // swap each pair once
        {
            std::swap(ary[i], ary[index]);
        }
    }
}
// Base-4 digit reversal.
// Permutes ary[0..len) so that the element at index i is swapped with the element
// whose base-4 digits are those of i in reverse order. A radix-4 decimation-in-time
// FFT needs its input in this order.
// Precondition: len is a power of four. Lengths below 4 are left untouched; for
// len 1 or 2 the old code wrote rev[0] into a zero-length array.
// (No longer constexpr: the function allocates, so it could never be constant-evaluated.)
template <typename T>
void quaternary_inverse_swap(T *ary, size_t len)
{
    if (len < 4)
    {
        return;
    }
    size_t log_n = 0; // floor(log2(len)), computed exactly with integers
    while ((len >> log_n) > 1)
    {
        log_n++;
    }
    // rev[j] caches the reversed index of j for j < len / 4; higher entries are derived from it.
    std::vector<size_t> rev(len / 4, 0);
    for (size_t i = 1; i < len; i++)
    {
        // Shift the reversal of i / 4 down one digit and put i's lowest digit on top.
        const size_t index = (rev[i >> 2] >> 2) | ((i & 3) << (log_n - 2));
        if (i < len / 4)
        {
            rev[i] = index;
        }
        if (i < index) // swap each pair once
        {
            std::swap(ary[i], ary[index]);
        }
    }
}
// Replaces each of the first fft_len elements with its complex conjugate divided by
// div. ifft_lut uses this to express the inverse transform through the forward one.
inline void fft_conj(Complex *input, size_t fft_len, double div = 1)
{
    Complex *const stop = input + fft_len;
    for (Complex *cur = input; cur != stop; ++cur)
    {
        *cur = std::conj(*cur) / div;
    }
}
// Radix-2 lookup-table FFT.
// Transforms, in place, the first N = max_2pow(fft_len) elements of input, computing
// X[k] = sum_j x[j] * exp(+2*pi*i * j*k / N). The angle is positive because
// TABLE.get_complex / unit_root use positive angles.
// Throws a C string if N exceeds 1 << lut_max_len.
void fft_radix2_lut(Complex *input, size_t fft_len)
{
    // Transforms of 0 or 1 points are the identity. Returning here also avoids
    // max_2pow(0) and bit-reversing a single element.
    if (fft_len < 2)
    {
        return;
    }
    fft_len = max_2pow(fft_len);
    if (fft_len > (1ull << lut_max_len))
    {
        throw("fft length too long for lut\n");
    }
    binary_inverse_swap(input, fft_len);
    UINT_8 log_rank = 1; // log2 of the current butterfly group size (gap)
    for (size_t rank = 1; rank < fft_len; rank *= 2)
    {
        const size_t gap = rank * 2;
        for (size_t begin = 0; begin < fft_len; begin += gap)
        {
            for (size_t k = 0; k < rank; k++)
            {
                const Complex omega = TABLE.get_complex(log_rank, k);
                fft_radix2_butterfly(omega, input, begin + k, rank);
            }
        }
        log_rank++;
    }
}
// Radix-4 lookup-table fast Fourier transform.
// Transforms, in place, the first N elements of input, where N is the largest power
// of four not exceeding fft_len. It computes X[k] = sum_j x[j] * exp(+2*pi*i * j*k / N),
// with a positive angle, matching TABLE.get_complex.
// Throws a C string if N exceeds 1 << lut_max_len.
void fft_radix4_lut(Complex *input, size_t fft_len)
{
    // Below 4 the length rounds down to one point, whose transform is the identity.
    // Returning here avoids std::log2(0) and a base-4 reversal of a single element.
    if (fft_len < 4)
    {
        return;
    }
    size_t log2_len = 0; // floor(log2(fft_len)), computed exactly with integers
    while ((fft_len >> log2_len) > 1)
    {
        log2_len++;
    }
    const size_t log4_len = log2_len / 2;
    fft_len = 1ull << (log4_len * 2);
    if (fft_len > (1ull << lut_max_len))
    {
        throw("fft length too long for lut\n");
    }
    quaternary_inverse_swap(input, fft_len);
    UINT_16 log_rank = 2; // log2 of the current butterfly group size (gap)
    for (size_t rank = 1; rank < fft_len; rank *= 4)
    {
        const size_t gap = rank * 4;
        for (size_t begin = 0; begin < fft_len; begin += gap)
        {
            for (size_t k = 0; k < rank; k++)
            {
                const Complex omega = TABLE.get_complex(log_rank, k);
                const Complex omega_sqr = TABLE.get_complex(log_rank, k << 1);
                const Complex omega_cube = omega * omega_sqr;
                fft_radix4_butterfly(omega, omega_sqr, omega_cube, input, begin + k, rank);
            }
        }
        log_rank += 2;
    }
}
// Forward FFT of the first N = max_2pow(fft_len) elements of input, in place.
// When log2(N) is even, one radix-4 transform is used. Otherwise the input is split
// into even and odd samples, each half (a power of four) is transformed with radix 4,
// and the halves are combined with one radix-2 stage.
// Throws a C string if N exceeds 1 << lut_max_len. The check happens before the input
// is modified.
void fft_lut(Complex *input, size_t fft_len)
{
    // A transform of 0 or 1 points is the identity; this also avoids max_2pow(0).
    if (fft_len < 2)
    {
        return;
    }
    fft_len = max_2pow(fft_len);
    if (fft_len > (1ull << lut_max_len))
    {
        throw("fft length too long for lut\n");
    }
    size_t log_len = 0; // log2(fft_len), computed exactly with integers
    while ((fft_len >> log_len) > 1)
    {
        log_len++;
    }
    if (is_odd(log_len))
    {
        const size_t half_len = fft_len / 2;
        // Even samples go to the front half, odd samples to the back half. In-place
        // compaction is safe: step i writes index i and reads indices 2i and 2i + 1 >= i.
        std::vector<Complex> odd_part(half_len);
        for (size_t i = 0; i < half_len; i++)
        {
            input[i] = input[2 * i];
            odd_part[i] = input[2 * i + 1];
        }
        ary_copy(input + half_len, odd_part.data(), half_len);
        // half_len is a power of four; a 1-point transform (fft_len == 2) is the identity.
        if (half_len > 1)
        {
            fft_radix4_lut(input, half_len);
            fft_radix4_lut(input + half_len, half_len);
        }
        // Final radix-2 stage: X[k] = E[k] + w^k O[k], X[k + N/2] = E[k] - w^k O[k].
        for (size_t i = 0; i < half_len; i++)
        {
            const Complex omega = TABLE.get_complex(log_len, i);
            fft_radix2_butterfly(omega, input, i, half_len);
        }
    }
    else
    {
        fft_radix4_lut(input, fft_len);
    }
}
// Inverse of fft_lut on the first N = max_2pow(fft_len) elements of input, in place,
// including the 1/N scaling. It uses the identity ifft(X) = conj(fft(conj(X))) / N.
void ifft_lut(Complex *input, size_t fft_len)
{
    // The inverse of a 0- or 1-point transform is the identity; this also avoids max_2pow(0).
    if (fft_len < 2)
    {
        return;
    }
    fft_len = max_2pow(fft_len);
    fft_conj(input, fft_len);
    fft_lut(input, fft_len);
    fft_conj(input, fft_len, fft_len);
}
// Element-wise product: out[i] = in1[i] * in2[i] for i in [0, len).
// Processes indices in increasing order, so out may alias in1 and/or in2.
template <typename T>
inline void ary_mul(const T in1[], const T in2[], T out[], size_t len)
{
    const T *lhs = in1;
    const T *rhs = in2;
    for (T *dst = out, *const dst_end = out + len; dst != dst_end; ++dst, ++lhs, ++rhs)
    {
        *dst = *lhs * *rhs;
    }
}
void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c)
{
if (n < m)
{
std::swap(a, b);
std::swap(n, m);
}
size_t fft_len = min_2pow(n + m);
Complex *fft_ary = new Complex[fft_len];
UINT_32 tmp = 0;
size_t pos = 0;
while (pos < m)
{
fft_ary[pos] = Complex(a[pos], b[pos]);
pos++;
}
while (pos < n)
{
fft_ary[pos] = Complex(a[pos], 0);
pos++;
}
fft_lut(fft_ary, fft_len);
ary_mul(fft_ary, fft_ary, fft_ary, fft_len);
ifft_lut(fft_ary, fft_len);
for (size_t i = 0; i < m + n - 1; i++)
{
c[i] = static_cast<unsigned>(fft_ary[i].imag() / 2 + 0.5);
}
delete[] fft_ary;
}
// Online-judge status row pasted from the submission page (not code): Compilation | N/A | N/A | Compile Error | Score: N/A | 显示更多 (Show more) |