#include <algorithm>
#include <bitset>
#include <cassert>
#include <chrono>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <new>
#include <string>
#include <utility>
#include <vector>
namespace hint
{
    using Float32 = float;
    using Float64 = double;
    using Complex32 = std::complex<Float32>;
    using Complex64 = std::complex<Float64>;
    constexpr Float64 HINT_PI = 3.141592653589793238462643;
    constexpr Float64 HINT_2PI = HINT_PI * 2;
    // NOTE: despite the name, this value is cos(pi/4) = sqrt(2)/2, not cos(pi/8).
    // Callers use it for both components of the twiddle e^{-i*pi/4}; the name
    // is kept so existing code continues to compile.
    constexpr Float64 COS_PI_8 = 0.707106781186547524400844;
    constexpr size_t FFT_MAX_LEN = size_t(1) << 23;

    // Largest power of two <= n for n > 0 (returns 1 for n == 0).
    // Intended for unsigned T.
    template <typename T>
    constexpr T int_floor2(T n)
    {
        constexpr int bits = sizeof(n) * 8;
        // Smear the highest set bit into every lower bit position.
        for (int i = 1; i < bits; i *= 2)
        {
            n |= (n >> i);
        }
        return (n >> 1) + 1;
    }

    // Smallest power of two >= n for n > 0. For n == 0, or when the result
    // does not fit in T, the unsigned arithmetic wraps and 0 is returned.
    template <typename T>
    constexpr T int_ceil2(T n)
    {
        constexpr int bits = sizeof(n) * 8;
        n--;
        for (int i = 1; i < bits; i *= 2)
        {
            n |= (n >> i);
        }
        return n + 1;
    }

    // True iff n has exactly one bit set.
    template <typename IntTy>
    constexpr bool is_2pow(IntTy n)
    {
        return n != 0 && (n & (n - 1)) == 0;
    }

    // Integer base-2 logarithm: floor(log2(n)) for n > 0, -1 for n == 0.
    // Binary search for the largest shift with (1 << shift) <= n.
    template <typename T>
    constexpr int hint_log2(T n)
    {
        constexpr int bits = sizeof(n) * 8;
        int l = -1, r = bits;
        while ((l + 1) != r)
        {
            int mid = (l + r) / 2;
            if ((T(1) << mid) > n)
            {
                r = mid;
            }
            else
            {
                l = mid;
            }
        }
        return l;
    }

    // Count trailing zero bits; returns 32 for x == 0.
    // Isolates the lowest set bit, then clears one bit of the answer for each
    // mask that contains it.
    constexpr int hint_ctz(uint32_t x)
    {
        int r0 = 31;
        x &= (-x);
        if (x & 0x55555555)
        {
            r0 &= ~1;
        }
        if (x & 0x33333333)
        {
            r0 &= ~2;
        }
        if (x & 0x0F0F0F0F)
        {
            r0 &= ~4;
        }
        if (x & 0x00FF00FF)
        {
            r0 &= ~8;
        }
        if (x & 0x0000FFFF)
        {
            r0 &= ~16;
        }
        r0 += (x == 0);
        return r0;
    }

    // Count trailing zero bits; returns 64 for x == 0.
    constexpr int hint_ctz(uint64_t x)
    {
        int r0 = 63;
        x &= (-x);
        if (x & 0x5555555555555555)
        {
            r0 &= ~1; // -1
        }
        if (x & 0x3333333333333333)
        {
            r0 &= ~2; // -2
        }
        if (x & 0x0F0F0F0F0F0F0F0F)
        {
            r0 &= ~4; // -4
        }
        if (x & 0x00FF00FF00FF00FF)
        {
            r0 &= ~8; // -8
        }
        if (x & 0x0000FFFF0000FFFF)
        {
            r0 &= ~16; // -16
        }
        if (x & 0x00000000FFFFFFFF)
        {
            r0 &= ~32; // -32
        }
        r0 += (x == 0);
        return r0;
    }

    // Reverse the order of all 32 bits of n.
    constexpr uint32_t bitrev(uint32_t n)
    {
        constexpr uint32_t mask55 = 0x55555555;
        constexpr uint32_t mask33 = 0x33333333;
        constexpr uint32_t mask0f = 0x0f0f0f0f;
        constexpr uint32_t maskff = 0x00ff00ff;
        n = ((n & mask55) << 1) | ((n >> 1) & mask55);
        n = ((n & mask33) << 2) | ((n >> 2) & mask33);
        n = ((n & mask0f) << 4) | ((n >> 4) & mask0f);
        n = ((n & maskff) << 8) | ((n >> 8) & maskff);
        return (n << 16) | (n >> 16);
    }

    // Reverse the order of all 64 bits of n.
    constexpr uint64_t bitrev(uint64_t n)
    {
        constexpr uint64_t mask5555 = 0x5555555555555555;
        constexpr uint64_t mask3333 = 0x3333333333333333;
        constexpr uint64_t mask0f0f = 0x0f0f0f0f0f0f0f0f;
        constexpr uint64_t mask00ff = 0x00ff00ff00ff00ff;
        constexpr uint64_t maskffff = 0x0000ffff0000ffff;
        n = ((n & mask5555) << 1) | ((n >> 1) & mask5555);
        n = ((n & mask3333) << 2) | ((n >> 2) & mask3333);
        n = ((n & mask0f0f) << 4) | ((n >> 4) & mask0f0f);
        n = ((n & mask00ff) << 8) | ((n >> 8) & mask00ff);
        n = ((n & maskffff) << 16) | ((n >> 16) & maskffff);
        return (n << 32) | (n >> 32);
    }

    // Reverse the lowest `bits` bits of n (0 <= bits <= 64); higher bits of n
    // are discarded. Reversing zero bits yields 0.
    constexpr size_t bitrev(size_t n, int bits)
    {
        if (bits <= 0)
        {
            // Without this guard bits == 0 would shift a 32-bit value by 32 (UB).
            return 0;
        }
        if (bits < 32)
        {
            return bitrev(static_cast<uint32_t>(n)) >> (32 - bits);
        }
        return bitrev(static_cast<uint64_t>(n)) >> (64 - bits);
    }

    // Number of set bits.
    constexpr int hint_popcnt(uint32_t n)
    {
        return __builtin_popcount(n);
    }

    // 64-bit population count, so size_t/uint64_t arguments are not
    // silently truncated to 32 bits.
    constexpr int hint_popcnt(uint64_t n)
    {
        return __builtin_popcountll(n);
    }

    // Wraps a compile-time constant N of type T as a type.
    template <typename T, T N>
    struct StaticObject
    {
        using Type = T;
        static constexpr Type value = N;
    };
    template <size_t N>
    using StaticSize = StaticObject<size_t, N>;
    template <int N>
    using StaticInt = StaticObject<int, N>;
} // namespace hint
// In-place butterfly: (sum, diff) <- (sum + diff, sum - diff).
// Both inputs are snapshotted first, so the update uses the original values.
template <typename T>
inline void transform2(T &sum, T &diff)
{
    const T lhs = sum;
    const T rhs = diff;
    sum = lhs + rhs;
    diff = lhs - rhs;
}
// Out-of-place butterfly: sum = a + b, diff = a - b.
template <typename T>
inline void transform2(const T a, const T b, T &sum, T &diff)
{
    T added = a + b;
    T subtracted = a - b;
    sum = added;
    diff = subtracted;
}
using namespace hint;
// Stores the forward-transform twiddle factor exp(-2*pi*i * indx / rank) in omega.
// The angle is computed in double and converted to Float by std::polar.
template <typename Float>
void twiddle_omega(size_t rank, size_t indx, std::complex<Float> &omega)
{
    const auto theta = -HINT_2PI * indx / rank;
    omega = std::polar<Float>(1.0, theta);
}
// Fills table with the rank/2 natural-order twiddles
// exp(-2*pi*i * k / rank), k = 0 .. rank/2 - 1 (resizing it as needed).
template <typename Float>
void init_table(std::vector<std::complex<Float>> &table, size_t rank)
{
    const size_t half = rank / 2;
    table.resize(half);
    for (size_t k = 0; k < half; ++k)
    {
        twiddle_omega(rank, k, table[k]);
    }
}
// Fills table with rank/2 twiddles in bit-reversed order:
// table[i] = exp(-2*pi*i * bitrev(i, log2(rank) - 1) / rank).
// rank is expected to be a power of two.
template <typename Float>
void init_table_rev(std::vector<std::complex<Float>> &table, size_t rank)
{
    table.resize(rank / 2);
    const int bits = hint_log2(rank);
    for (size_t i = 0; i < table.size(); i++)
    {
        // For rank == 2 there are no index bits to reverse (only i == 0);
        // calling bitrev(i, 0) would shift a 32-bit value by 32, which is UB.
        const size_t rev_indx = bits > 1 ? bitrev(i, bits - 1) : 0;
        twiddle_omega(rank, rev_indx, table[i]);
    }
}
// Inverse radix-2 decimation-in-time FFT, in place, with natural-order
// twiddles rebuilt at every stage. It undoes dif(): the input is the
// bit-reversed-order spectrum that dif() leaves and the output is in
// natural order. If norm is true the result is scaled by 1/len.
template <typename Float>
inline void idit(std::complex<Float> in_out[], size_t len, bool norm = false)
{
    using Complex = std::complex<Float>;
    std::vector<Complex> twiddles;
    for (size_t rank = 2; rank <= len; rank <<= 1)
    {
        init_table(twiddles, rank);
        const size_t half = rank >> 1;
        for (size_t base = 0; base < len; base += rank)
        {
            Complex *lo = in_out + base;
            Complex *hi = lo + half;
            for (size_t k = 0; k < half; ++k)
            {
                const Complex x = lo[k];
                const Complex y = hi[k] * std::conj(twiddles[k]);
                lo[k] = x + y;
                hi[k] = x - y;
            }
        }
    }
    if (!norm)
    {
        return;
    }
    const double len_inv = 1.0 / len;
    for (size_t k = 0; k < len; ++k)
    {
        in_out[k] *= len_inv;
    }
}
// Forward radix-2 decimation-in-frequency FFT, in place, with natural-order
// twiddles rebuilt at every stage. Input is in natural order; the spectrum is
// left in bit-reversed order (there is no permutation pass).
template <typename Float>
inline void dif(std::complex<Float> in_out[], size_t len)
{
    using Complex = std::complex<Float>;
    std::vector<Complex> twiddles;
    for (size_t rank = len; rank >= 2; rank >>= 1)
    {
        init_table(twiddles, rank);
        const size_t half = rank >> 1;
        for (size_t base = 0; base < len; base += rank)
        {
            Complex *lo = in_out + base;
            Complex *hi = lo + half;
            for (size_t k = 0; k < half; ++k)
            {
                const Complex x = lo[k];
                const Complex y = hi[k];
                lo[k] = x + y;
                hi[k] = (x - y) * twiddles[k];
            }
        }
    }
}
// Inverse of dif_rtable(): stages run from rank 2 upward. Each block of a
// stage uses one twiddle, taken in order from a bit-reversed table built once
// for len, and the butterfly is applied before the conjugate twiddle multiply.
// If norm is true the result is scaled by 1/len.
template <typename Float>
inline void idit_rtable(std::complex<Float> in_out[], size_t len, bool norm = false)
{
    using Complex = std::complex<Float>;
    std::vector<Complex> table;
    init_table_rev(table, len);
    for (size_t rank = 2; rank <= len; rank <<= 1)
    {
        const size_t half = rank >> 1;
        size_t block = 0;
        for (size_t base = 0; base < len; base += rank, ++block)
        {
            const Complex w_conj = std::conj(table[block]);
            Complex *lo = in_out + base;
            Complex *hi = lo + half;
            for (size_t k = 0; k < half; ++k)
            {
                const Complex x = lo[k];
                const Complex y = hi[k];
                lo[k] = x + y;
                hi[k] = (x - y) * w_conj;
            }
        }
    }
    if (!norm)
    {
        return;
    }
    const double len_inv = 1.0 / len;
    for (size_t k = 0; k < len; ++k)
    {
        in_out[k] *= len_inv;
    }
}
// Forward transform with bit-reversed twiddles: stages run from rank len
// downward, and each block of a stage multiplies its upper half by a single
// twiddle before the butterfly. The twiddles are taken in order from a
// bit-reversed table built once for len.
template <typename Float>
inline void dif_rtable(std::complex<Float> in_out[], size_t len)
{
    using Complex = std::complex<Float>;
    std::vector<Complex> table;
    init_table_rev(table, len);
    for (size_t rank = len; rank >= 2; rank >>= 1)
    {
        const size_t half = rank >> 1;
        size_t block = 0;
        for (size_t base = 0; base < len; base += rank, ++block)
        {
            const Complex w = table[block];
            Complex *lo = in_out + base;
            Complex *hi = lo + half;
            for (size_t k = 0; k < half; ++k)
            {
                const Complex x = lo[k];
                const Complex y = hi[k] * w;
                lo[k] = x + y;
                hi[k] = x - y;
            }
        }
    }
}
// Incremental generator of twiddle factors in bit-reversed index order.
// Define w(k) = exp(-2*pi*i * bitrev(k, log_max_iter) / 2^log_fft_len).
// After reset(i), successive iterate() calls return w(i), w(i+1), w(i+2), ...
// Each step costs one complex multiply: table[p] caches w() of the current
// index restricted to its p highest set bits (table[0] == 1), and units[z]
// equals w(2^z). Incrementing the index only rebuilds the one entry above the
// carry position. This works because w(a + b) == w(a) * w(b) when a and b
// share no set bits (bit reversal maps disjoint bit sets to disjoint bit sets).
template <typename FloatTy>
class BinRevTableComplexIterHP
{
public:
using Complex = std::complex<FloatTy>;
static constexpr int MAX_LOG_LEN = 32;
static constexpr size_t MAX_LEN = size_t(1) << MAX_LOG_LEN;
// log_max_iter: number of index bits that are bit-reversed.
// log_fft_len: log2 of the transform length the angles are measured against.
BinRevTableComplexIterHP(int log_max_iter_in, int log_fft_len_in)
: index(0), pop(0), log_max_iter(log_max_iter_in), log_fft_len(log_fft_len_in)
{
assert(log_max_iter <= log_fft_len);
assert(log_fft_len <= MAX_LOG_LEN);
// units[i] = exp(-2*pi*i * factor / 2^(i+1)). With this factor that equals
// w(2^i) = exp(-2*pi*i * 2^(log_max_iter-1-i) / 2^log_fft_len).
const FloatTy factor = FloatTy(1) / (size_t(1) << (log_fft_len - log_max_iter));
for (int i = 0; i < MAX_LOG_LEN; i++)
{
units[i] = getOmega(size_t(1) << (i + 1), 1, factor);
}
table[0] = Complex(1, 0); // w(0); never overwritten afterwards
}
// Positions the iterator so that the next iterate() returns w(i).
// table[pop], table[pop-1], ... are computed directly (not incrementally) for
// i, then i with its lowest set bit cleared, and so on.
// NOTE(review): hint_popcnt has a uint32_t overload in this file; if that is
// the only overload, indices >= 2^32 would be truncated here.
void reset(size_t i = 0)
{
index = i;
if (i == 0)
{
pop = 0;
return;
}
pop = hint_popcnt(i);
const size_t len = size_t(1) << log_fft_len;
for (int p = pop; p > 0; p--)
{
table[p] = getOmega(len, bitrev(i, log_max_iter));
i &= (i - 1); // drop the lowest set bit for the next (shorter) prefix
}
}
// Returns w(index), then advances index by one.
// After the increment, `zero` trailing ones of the old index became zeros
// and bit `zero` became one. Dropping those ones leaves pop - zero cached
// prefixes; the new entry is that prefix times w(2^zero).
// NOTE(review): table/units hold MAX_LOG_LEN entries, so the index must stay
// small enough that pop + 1 < MAX_LOG_LEN and zero < MAX_LOG_LEN.
Complex iterate()
{
Complex res = table[pop];
index++;
int zero = hint_ctz(index);
pop -= zero;
table[pop + 1] = table[pop] * units[zero];
pop++;
return res;
}
// exp(-2*pi*i * index / n * factor). The angle is formed in FloatTy, so for
// float it carries only single-precision accuracy.
static Complex getOmega(size_t n, size_t index, FloatTy factor = 1)
{
FloatTy theta = -HINT_2PI * index / n;
return std::polar<FloatTy>(1, theta * factor);
}
private:
Complex units[MAX_LOG_LEN]{}; // units[z] == w(2^z)
Complex table[MAX_LOG_LEN]{}; // table[p] == w(p highest set bits of index)
size_t index;                 // index whose twiddle iterate() returns next
int pop;                      // number of set bits in index
int log_max_iter, log_fft_len;
};
template <typename Float>
inline void idit_qtable(std::complex<Float> in_out[], size_t len, bool norm = false)
{
using Complex = std::complex<Float>;
BinRevTableComplexIterHP<Float> table(31, 32);
for (size_t rank = 2; rank <= len; rank *= 2)
{
size_t stride = rank / 2;
auto it0 = in_out, it1 = it0 + stride;
table.reset(0);
for (size_t begin = 0; begin < len; begin += rank, it0 += rank, it1 += rank)
{
Complex omega = table.iterate();
for (size_t i = 0; i < stride; i++)
{
Complex x = it0[i];
Complex y = it1[i];
it0[i] = x + y;
it1[i] = (x - y) * std::conj(omega);
}
}
}
if (norm)
{
const double len_inv = 1.0 / len;
for (size_t i = 0; i < len; i++)
{
in_out[i] *= len_inv;
}
}
}
template <typename Float>
inline void dif_qtable(std::complex<Float> in_out[], size_t len)
{
using Complex = std::complex<Float>;
BinRevTableComplexIterHP<Float> table(31, 32);
for (size_t rank = len; rank >= 2; rank /= 2)
{
size_t stride = rank / 2;
auto it0 = in_out, it1 = it0 + stride;
table.reset(0);
for (size_t begin = 0; begin < len; begin += rank, it0 += rank, it1 += rank)
{
Complex omega = table.iterate();
for (size_t i = 0; i < stride; i++)
{
Complex x = it0[i];
Complex y = it1[i] * omega;
it0[i] = x + y;
it1[i] = x - y;
}
}
}
}
// Real-array overload: views in_out (len floats) as len/2 interleaved complex
// values (re, im, re, im, ...) and applies idit_qtable() to them.
template <typename Float>
inline void idit(Float in_out[], size_t len, bool norm = false)
{
    const size_t complex_len = len / 2;
    idit_qtable(reinterpret_cast<std::complex<Float> *>(in_out), complex_len, norm);
}
// Real-array overload: views in_out (len floats) as len/2 interleaved complex
// values (re, im, re, im, ...) and applies dif_qtable() to them.
template <typename Float>
inline void dif(Float in_out[], size_t len)
{
    const size_t complex_len = len / 2;
    dif_qtable(reinterpret_cast<std::complex<Float> *>(in_out), complex_len);
}
// Reference implementation: cyclic convolution of two real arrays of length
// len (a power of two). Both are widened to complex, transformed with the
// natural-twiddle dif(), multiplied, and inverted with idit(norm = true).
// The real parts of the result overwrite in_out1; in2 is only read.
template <typename Float>
void real_conv_std(Float in_out1[], Float in2[], size_t len)
{
    using Complex = std::complex<Float>;
    assert(is_2pow(len));
    std::vector<Complex> lhs(in_out1, in_out1 + len);
    std::vector<Complex> rhs(in2, in2 + len);
    dif(lhs.data(), len);
    dif(rhs.data(), len);
    for (size_t k = 0; k < len; ++k)
    {
        lhs[k] *= rhs[k];
    }
    idit(lhs.data(), len, true);
    for (size_t k = 0; k < len; ++k)
    {
        in_out1[k] = lhs[k].real();
    }
}
// Same as real_conv_std(), but it uses the iterator-based bit-reversed-twiddle
// transforms dif_qtable() / idit_qtable(). The result overwrites in_out1.
template <typename Float>
void real_conv_rev_table(Float in_out1[], Float in2[], size_t len)
{
    using Complex = std::complex<Float>;
    assert(is_2pow(len));
    std::vector<Complex> lhs(in_out1, in_out1 + len);
    std::vector<Complex> rhs(in2, in2 + len);
    dif_qtable(lhs.data(), len);
    dif_qtable(rhs.data(), len);
    for (size_t k = 0; k < len; ++k)
    {
        lhs[k] *= rhs[k];
    }
    idit_qtable(lhs.data(), len, true);
    for (size_t k = 0; k < len; ++k)
    {
        in_out1[k] = lhs[k].real();
    }
}
// Processes one mirrored pair of bins of a packed real FFT for two operands
// and writes back their scaled, re-packed product.
//
// Each complex value is stored with its real part at p[0] and its imaginary
// part at p[RI_DIFF] (RI_DIFF == 1 is ordinary interleaved std::complex).
// inout0/inout1 hold the pair (a, b) of the first operand and receive the
// result; in0/in1 hold the same pair of the second operand and are only read.
//
// combine2, with S = a + conj(b), D = a - conj(b), T = D * (-i * omega0):
//   out0 = S + T,  out1 = conj(S - T)
// The two operands' outputs are multiplied elementwise. separate2 then
// computes, with S' = x0 + conj(x1), D' = x0 - conj(x1), T' = D' * (i * conj(omega0)):
//   r0 + i*i0 = S' + T',  r1 + i*i1 = conj(S' - T')
// Finally every component is multiplied by `factor`. The default 0.125
// presumably cancels the factors of 2 introduced by combine2 (one per operand)
// and by separate2. NOTE(review): confirm against callers.
template <size_t RI_DIFF = 1, typename FloatTy>
inline void dot_rfft(FloatTy *inout0, FloatTy *inout1, const FloatTy *in0, const FloatTy *in1,
const std::complex<FloatTy> &omega0, const FloatTy factor = 0.125)
{
using Complex = std::complex<FloatTy>;
// (a = r0 + i*i0, b = r1 + i*i1) -> (S + T, conj(S - T)); see header comment.
auto combine2 = [&omega0](auto r0, auto i0, auto r1, auto i1, Complex &out0, Complex &out1)
{
auto tr0 = r0 + r1, ti0 = i0 - i1; // sum
auto tr1 = r0 - r1, ti1 = i0 + i1; // diff
// (r0, i0) <- D * (-i * omega0)
r0 = ti1 * omega0.real() + tr1 * omega0.imag();
i0 = ti1 * omega0.imag() - tr1 * omega0.real();
out0.real(tr0 + r0);
out0.imag(ti0 + i0);
out1.real(tr0 - r0);
out1.imag(i0 - ti0);
};
Complex x0, x1, x2, x3;
auto r0 = inout0[0], i0 = inout0[RI_DIFF], r1 = inout1[0], i1 = inout1[RI_DIFF];
combine2(r0, i0, r1, i1, x0, x1);
r0 = in0[0], i0 = in0[RI_DIFF], r1 = in1[0], i1 = in1[RI_DIFF];
combine2(r0, i0, r1, i1, x2, x3);
// Elementwise product of the two operands' untangled bins.
x0 *= x2;
x1 *= x3;
{ // separate2
r0 = x0.real(), i0 = x0.imag(), r1 = x1.real(), i1 = x1.imag();
auto tr0 = r0 + r1, ti0 = i0 - i1; // sum
auto tr1 = r0 - r1, ti1 = i0 + i1; // diff
// (r, i) = D' * (i * conj(omega0))
auto r = tr1 * omega0.imag() - ti1 * omega0.real();
auto i = tr1 * omega0.real() + ti1 * omega0.imag();
r0 = tr0 + r;
i0 = ti0 + i;
r1 = tr0 - r;
i1 = i - ti0;
}
inout0[0] = r0 * factor, inout0[RI_DIFF] = i0 * factor, inout1[0] = r1 * factor, inout1[RI_DIFF] = i1 * factor;
}
// Pointwise product of two real sequences' packed spectra, prepared for an
// unnormalised inverse transform.
//
// in_out and in each hold float_len (a power of two) values. Both are
// interpreted as float_len/2 interleaved complex bins in bit-reversed order,
// as left by the real-array dif(). Bin 0 packs the DC and Nyquist terms.
// Bit-reversed position 1 (natural index float_len/4) is its own mirror.
// All further mirrored pairs are untangled, multiplied, and re-packed by
// dot_rfft(). Everything is scaled by 2/float_len; the dot_rfft pairs are
// additionally divided by 8 to cancel dot_rfft's internal factors of 2. The
// product overwrites in_out, so a following idit(..., norm = false) produces
// the cyclic convolution.
//
// Small lengths are handled explicitly. The original code always touched
// in_out[0..7], which overflowed any buffer shorter than 8 values.
template <typename Float>
inline void real_dot_binrev(Float in_out[], const Float in[], size_t float_len)
{
    assert(is_2pow(float_len));
    using Complex = std::complex<Float>;
    if (float_len < 2)
    {
        // Not even one complex bin: a length-1 cyclic convolution is a plain product.
        if (float_len == 1)
        {
            in_out[0] *= in[0];
        }
        return;
    }
    Float inv = 2.0 / float_len;
    { // bin 0: (re, im) -> (re + im, re - im) = (DC, Nyquist) of the real spectrum
        auto r0 = in_out[0], i0 = in_out[1], r1 = in[0], i1 = in[1];
        transform2(r0, i0);
        transform2(r1, i1);
        r0 *= r1, i0 *= i1;
        transform2(r0, i0);
        in_out[0] = r0 * 0.5 * inv, in_out[1] = i0 * 0.5 * inv;
    }
    if (float_len < 4)
    {
        return; // a single complex bin: nothing else to multiply
    }
    // Self-mirrored bin: a plain complex product.
    auto temp = Complex(in_out[2], in_out[3]) * Complex(in[2], in[3]) * inv;
    in_out[2] = temp.real(), in_out[3] = temp.imag();
    if (float_len < 8)
    {
        return; // two complex bins: no mirrored pairs left
    }
    inv /= 8;
    dot_rfft(&in_out[4], &in_out[6], &in[4], &in[6], Complex(COS_PI_8, -COS_PI_8), inv);
    BinRevTableComplexIterHP<Float> table(31, 32);
    // Each doubling group [begin, 2*begin) pairs its ascending bins with its
    // descending ones; the twiddle for the k-th pair comes from the bit-reversed
    // iterator started at complex index begin/2.
    for (size_t begin = 8; begin < float_len; begin *= 2)
    {
        table.reset(begin / 2);
        auto it0 = in_out + begin, it1 = it0 + begin - 2, it2 = in + begin, it3 = it2 + begin - 2;
        for (; it0 < it1; it0 += 2, it1 -= 2, it2 += 2, it3 -= 2)
        {
            auto omega = table.iterate();
            dot_rfft(it0, it1, it2, it3, omega, inv);
        }
    }
}
// Complex-typed variant of real_dot_binrev(). in_out and in hold len packed
// complex bins (bit-reversed order) of two real sequences; on return in_out
// holds their product spectrum. No 1/len normalisation is applied here
// (dot_rfft's default factor of 0.125 only offsets its own internal factors of
// 2), so the caller scales the result before inverting.
//
// Changes: no longer prints to stdout. Uses a per-call twiddle iterator; the
// previous function-local static was mutated by reset()/iterate() and so was
// shared across concurrent calls. Lengths below 4 no longer index past the
// arrays.
inline void real_conv_binrev(Complex64 in_out[], Complex64 in[], size_t len)
{
    if (len == 0)
    {
        return;
    }
    // Bin 0 packs the DC and Nyquist terms; multiply them separately.
    auto t0 = in_out[0], t1 = in[0];
    auto t2 = (t0.real() + t0.imag()) * (t1.real() + t1.imag());
    auto t3 = (t0.real() - t0.imag()) * (t1.real() - t1.imag());
    in_out[0] = Complex64(t2 + t3, t2 - t3) * 0.5;
    if (len < 2)
    {
        return;
    }
    in_out[1] *= in[1]; // self-mirrored bin
    if (len < 4)
    {
        return;
    }
    dot_rfft(reinterpret_cast<double *>(&in_out[2]), reinterpret_cast<double *>(&in_out[3]),
             reinterpret_cast<const double *>(&in[2]), reinterpret_cast<const double *>(&in[3]),
             Complex64(COS_PI_8, -COS_PI_8));
    BinRevTableComplexIterHP<double> table(31, 32);
    for (size_t begin = 4; begin < len; begin *= 2)
    {
        table.reset(begin);
        auto it1 = in_out + begin, it2 = in + begin;
        // Pair bin begin+i with its mirror 2*begin-1-i inside this group.
        for (size_t i = 0; i < begin / 2; i++)
        {
            Complex64 omega = table.iterate();
            dot_rfft(reinterpret_cast<double *>(&it1[i]), reinterpret_cast<double *>(&it1[begin - i - 1]),
                     reinterpret_cast<const double *>(&it2[i]), reinterpret_cast<const double *>(&it2[begin - i - 1]),
                     omega);
        }
    }
}
// Cyclic convolution of two real arrays of length len (a power of two) via
// the packed real-FFT pipeline. The result overwrites in_out1, and in2 is
// clobbered (it is transformed in place). Normalisation happens inside
// real_dot_binrev, so the inverse transform is run with norm = false.
void real_conv_rfft(Float64 in_out1[], Float64 in2[], size_t len)
{
    assert(is_2pow(len));
    for (Float64 *operand : {in_out1, in2})
    {
        dif(operand, len);
    }
    real_dot_binrev(in_out1, in2, len);
    idit(in_out1, len, false);
}
// Multiplies two polynomials with non-negative integer-valued coefficients
// using the real-FFT convolution in double precision. Coefficients are
// rounded with +0.5 truncation, which is only correct for non-negative values.
// Returns len1 + len2 coefficients; the last one is the unused slot of the
// length-(len1 + len2 - 1) product and comes out as 0.
template <typename T>
std::vector<T> poly_multiply(const std::vector<T> &in1, const std::vector<T> &in2)
{
    const size_t len1 = in1.size(), len2 = in2.size();
    const size_t conv_len = len1 + len2;
    // Power-of-two transform length, at least 8 doubles, so that the packed
    // real-FFT post-processing always has its fixed leading bins available.
    size_t float_len = int_ceil2(conv_len);
    if (float_len < 8)
    {
        float_len = 8;
    }
    // RAII buffers (zero-padded) instead of unchecked malloc/free.
    std::vector<double> p1(float_len, 0.0), p2(float_len, 0.0);
    std::copy(in1.begin(), in1.end(), p1.begin());
    std::copy(in2.begin(), in2.end(), p2.begin());
    real_conv_rfft(p1.data(), p2.data(), float_len);
    std::vector<T> res(conv_len);
    for (size_t i = 0; i < conv_len; i++)
    {
        res[i] = static_cast<T>(p1[i] + 0.5);
    }
    return res;
}
// Checks that res is the triangle produced by convolving two constant
// sequences of length res.size()/2 (values ele1, ele2):
//   res[i] == (i + 1) * ele1 * ele2          for i <  len/2
//   res[i] == (len - i - 1) * ele1 * ele2    for i >= len/2
// Prints "success" or the first mismatching index with expected/actual values.
template <typename T>
void result_test(const std::vector<T> &res, uint64_t ele1, uint64_t ele2)
{
    const size_t len = res.size();
    const size_t half = len / 2;
    const uint64_t unit = ele1 * ele2;
    for (size_t i = 0; i < len; ++i)
    {
        const uint64_t multiple = (i < half) ? i + 1 : len - i - 1;
        const uint64_t expected = multiple * unit;
        const uint64_t actual = res[i];
        if (expected != actual)
        {
            std::cout << "fail:" << i << "\t" << expected << "\t" << actual << "\n";
            return;
        }
    }
    std::cout << "success\n";
}
// Benchmark and self-check. Convolves two length-len/2 sequences whose
// elements are all ele1 and ele2 respectively via poly_multiply(), prints every
// coefficient, verifies the result with result_test(), and reports the time.
void perf_conv()
{
    int n = 13;
    // std::cin >> n;
    size_t len = size_t(1) << n; // transform length
    std::cout << "conv len:" << len << "\n";
    uint64_t ele1 = 9999, ele2 = 7777;
    // Convolve two length-len/2 sequences with every element equal to ele1 /
    // ele2. All len/2 elements must be set: result_test() checks the full
    // triangle that constant inputs produce. Filling only a fixed prefix made
    // the check fail and overran the vectors for small n.
    std::vector<uint64_t> in1(len / 2, ele1);
    std::vector<uint64_t> in2(len / 2, ele2);
    auto t1 = std::chrono::high_resolution_clock::now();
    std::vector<uint64_t> res = poly_multiply(in1, in2);
    auto t2 = std::chrono::high_resolution_clock::now();
    for (auto i : res)
    {
        std::cout << i << " ";
    }
    std::cout << "\n";
    result_test<uint64_t>(res, ele1, ele2); // verify the result
    std::cout << "cost:" << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count() << "us\n";
}
// Interactive check of BinRevTableComplexIterHP(31, 32). Reads a start index
// from stdin and compares the iterator against directly computed twiddles
// exp(-2*pi*i * (bitrev(i, 31) / 2) / 2^31) for every index below 2^23.
// Prints the first mismatch (difference > 1e-15), or "end" if there is none.
void test_rev_tb()
{
    using cpx = std::complex<double>;
    BinRevTableComplexIterHP<double> table(31, 32);
    const auto reference = [](size_t idx, int div)
    {
        constexpr int kBits = 31;
        cpx w;
        twiddle_omega(size_t(1) << kBits, bitrev(idx, kBits) / div, w);
        return w;
    };
    size_t begin = 0;
    const size_t end = 1 << 23;
    std::cin >> begin;
    table.reset(begin);
    for (size_t i = begin; i < end; ++i)
    {
        const cpx expected = reference(i, 2);
        const cpx got = table.iterate();
        if (std::abs(expected - got) > 1e-15)
        {
            std::cout << expected << got << '\n';
            return;
        }
    }
    std::cout << "end";
}
inline void test_rconv()
{
using Complex = Complex64;
static double ary1[1 << 23]{};
static double ary2[1 << 23]{};
size_t n = 1 << 23;
// real_conv_binrev(ary1, ary2, n);
for (size_t i = 0; i < n / 2; i++)
{
ary1[i] = 9999;
ary2[i] = 9999;
}
dif(ary1, n);
dif(ary2, n);
auto t1 = std::chrono::high_resolution_clock::now();
// real_conv_binrev4((double *)ary1, (double *)ary2, n / 2);
real_conv_binrev((Complex *)ary1, (Complex *)ary2, n / 2);
auto t2 = std::chrono::high_resolution_clock::now();
for (size_t i = 0; i < n; i++)
{
ary1[i] = ary1[i] * (0.25 / n);
}
idit(ary1, n, false);
for (size_t i = 0; i < n; i++)
{
std::cout << uint64_t((ary1)[i] + 0.5) / (9999ull * 9999) << "\t";
}
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count() << "us" << std::endl;
}
// Same experiment as test_rconv(), but the buffers are std::complex arrays and
// the forward/inverse transforms are the complex natural-twiddle dif()/idit()
// on n/2 points. Prints every coefficient divided by 9999^2, followed by the
// time spent in real_conv_binrev. (Commented-out SIMD permutation experiments
// that referenced code outside this file were dropped.)
inline void test_rconv1()
{
    using Complex = Complex64;
    static Complex ary1[1 << 23]{};
    static Complex ary2[1 << 23]{};
    const size_t n = 1 << 23;
    double *flat1 = reinterpret_cast<double *>(ary1);
    double *flat2 = reinterpret_cast<double *>(ary2);
    for (size_t k = 0; k < n / 2; ++k)
    {
        flat1[k] = 9999;
        flat2[k] = 9999;
    }
    dif(ary1, n / 2);
    dif(ary2, n / 2);
    const auto t_start = std::chrono::high_resolution_clock::now();
    real_conv_binrev(ary1, ary2, n / 2);
    const auto t_stop = std::chrono::high_resolution_clock::now();
    for (size_t k = 0; k < n / 2; ++k)
    {
        ary1[k] = ary1[k] * (0.25 / n);
    }
    idit(ary1, n / 2, false);
    for (size_t k = 0; k < n; ++k)
    {
        std::cout << uint64_t(flat1[k] + 0.5) / (9999ull * 9999) << "\t";
    }
    std::cout << std::chrono::duration_cast<std::chrono::microseconds>(t_stop - t_start).count() << "us" << std::endl;
}
void test_rdot()
{
using Complex = Complex64;
static double ary1[1 << 23]{};
static double ary2[1 << 23]{};
static double cary1[1 << 23]{};
static double cary2[1 << 23]{};
size_t n = 1 << 5;
srand(0);
for (size_t i = 0; i < n; i++)
{
ary1[i] = rand() % 10;
ary2[i] = rand() % 10;
cary1[i] = ary1[i];
cary2[i] = ary2[i];
}
real_conv_std(ary1, ary2, n);
dif(ary1, n);
dif(cary1, n);
dif(cary2, n);
real_dot_binrev(cary1, cary2, n);
// real_conv_binrev((Complex *)cary1, (Complex *)cary2, n / 2);
auto cp1 = reinterpret_cast<Complex *>(cary1), cp2 = reinterpret_cast<Complex *>(ary1);
for (size_t i = 0; i < n / 2; i++)
{
if (std::abs(cp1[i] - cp2[i]) > 1e-10)
{
std::cout << i << "\t" << cp1[i] << " " << cp2[i] << std::endl;
}
}
std::cout << "end\n";
}
// Lookup table converting 0..9999 to four ASCII digits (zero padded).
// Each entry packs the thousands digit in the lowest byte through the ones
// digit in the highest byte, so a little-endian memcpy writes "dddd" in
// reading order.
class ItoStrBase10000
{
public:
    // Packs the four decimal digits of num as ASCII bytes (see class comment).
    static constexpr uint32_t itosbase10000(uint32_t num)
    {
        const uint32_t thousands = num / 1000 % 10;
        const uint32_t hundreds = num / 100 % 10;
        const uint32_t tens = num / 10 % 10;
        const uint32_t ones = num % 10;
        const uint32_t packed = thousands | (hundreds << 8) | (tens << 16) | (ones << 24);
        return packed + '0' * 0x1010101; // add '0' to every byte
    }
    constexpr ItoStrBase10000()
    {
        for (uint32_t v = 0; v < 10000; ++v)
        {
            table[v] = itosbase10000(v);
        }
    }
    // Writes exactly four characters (no terminator) for num in [0, 10000).
    void tostr(char *str, uint32_t num) const
    {
        std::memcpy(str, table + num, sizeof(uint32_t));
    }
    // Returns the packed four-character representation of num.
    uint32_t tostr(uint32_t num) const
    {
        return table[num];
    }

private:
    uint32_t table[10000]{};
};
// Lookup table converting two ASCII decimal digits to 0..99.
// The two bytes are read as a native uint16_t key. Keys that are not two
// digits map to UINT16_MAX.
class StrtoIBase100
{
private:
    static constexpr size_t TABLE_SIZE = size_t(1) << 15;
    uint16_t table[TABLE_SIZE]{};

public:
    // Packs the two decimal digits of num (tens in the low byte, ones in the
    // high byte) as ASCII, matching a little-endian load of "to".
    static constexpr uint16_t itosbase100(uint16_t num)
    {
        uint16_t res = (num / 10 % 10) | ((num % 10) << 8);
        return res + '0' * 0x0101;
    }
    constexpr StrtoIBase100()
    {
        for (size_t i = 0; i < TABLE_SIZE; i++)
        {
            table[i] = UINT16_MAX;
        }
        for (size_t i = 0; i < 100; i++)
        {
            table[itosbase100(i)] = i;
        }
    }
    // Parses str[0], str[1] as two digits; returns 0..99 or UINT16_MAX.
    uint16_t toInt(const char *str) const
    {
        uint16_t num;
        std::memcpy(&num, str, sizeof(num));
        // Any two bytes form a key in [0, 0xFFFF], but the table only covers
        // [0, TABLE_SIZE). A byte >= 0x80 in the key's high half would otherwise
        // read past the end of the table.
        return size_t(num) < TABLE_SIZE ? table[num] : UINT16_MAX;
    }
};
// Global, constant-initialised conversion tables shared by the I/O helpers.
constexpr ItoStrBase10000 itosbase10000 = ItoStrBase10000{}; // 0..9999 -> four ASCII digits
constexpr StrtoIBase100 strtoibase100 = StrtoIBase100{};     // two ASCII digits -> 0..99
// Parses exactly four ASCII decimal digits at s into 0..9999 (no validation).
constexpr uint32_t stobase10000(const char *s)
{
    const int thousands = s[0] - '0';
    const int hundreds = s[1] - '0';
    const int tens = s[2] - '0';
    const int ones = s[3] - '0';
    return thousands * 1000 + hundreds * 100 + tens * 10 + ones;
}
#include <immintrin.h>
// Owning buffer of n elements of T whose storage is aligned to ALIGN bytes.
// The storage is allocated with _mm_malloc and released with _mm_free.
// Elements are NOT constructed or initialised, so the class is intended for
// trivial types such as double.
// The class is move-only: the implicitly generated copy would have freed the
// same buffer twice.
template <typename T, size_t ALIGN = 64>
class AlignMem
{
public:
    using Ptr = T *;
    using ConstPtr = const T *;
    // Empty buffer (begin() == end()).
    AlignMem() = default;
    // Allocates room for n elements; throws std::bad_alloc if that fails.
    AlignMem(size_t n) : ptr(static_cast<Ptr>(_mm_malloc(n * sizeof(T), ALIGN))), len(n)
    {
        if (ptr == nullptr)
        {
            len = 0;
            if (n != 0)
            {
                throw std::bad_alloc();
            }
        }
    }
    AlignMem(const AlignMem &) = delete;
    AlignMem &operator=(const AlignMem &) = delete;
    AlignMem(AlignMem &&other) noexcept
        : ptr(std::exchange(other.ptr, nullptr)), len(std::exchange(other.len, size_t(0))) {}
    AlignMem &operator=(AlignMem &&other) noexcept
    {
        if (this != &other)
        {
            release();
            ptr = std::exchange(other.ptr, nullptr);
            len = std::exchange(other.len, size_t(0));
        }
        return *this;
    }
    ~AlignMem()
    {
        release();
    }
    T &operator[](size_t i)
    {
        return ptr[i];
    }
    const T &operator[](size_t i) const
    {
        return ptr[i];
    }
    Ptr begin()
    {
        return ptr;
    }
    Ptr end()
    {
        return ptr + len;
    }
    ConstPtr begin() const
    {
        return ptr;
    }
    ConstPtr end() const
    {
        return ptr + len;
    }

private:
    // Frees the storage (if any) and leaves the object empty.
    void release() noexcept
    {
        if (ptr)
        {
            _mm_free(ptr);
        }
        ptr = nullptr;
        len = 0;
    }
    T *ptr = nullptr;
    size_t len = 0;
};
// Sets every byte in [begin, end) to zero. Intended for trivially copyable T,
// where all-zero bytes mean 0 / 0.0.
template <typename T>
void fill_zero(T *begin, T *end)
{
    const size_t bytes = static_cast<size_t>(end - begin) * sizeof(T);
    std::memset(static_cast<void *>(begin), 0, bytes);
}
// Splits a decimal digit string into base-10000 limbs, most significant first,
// grouping four digits from the left. A short final group is right-padded with
// zeros, so the stored number equals the input times 10^shift.
// Writes ceil(len / 4) values to ary and returns shift (0..3).
template <typename T>
size_t str_num_to_array_base10000(const char *str, size_t len, T *ary)
{
    constexpr size_t BLOCK = 4;
    const size_t full_blocks = len / BLOCK;
    const size_t rem = len % BLOCK;
    for (size_t b = 0; b < full_blocks; ++b)
    {
        ary[b] = stobase10000(str + b * BLOCK);
    }
    if (rem == 0)
    {
        return 0;
    }
    const char *tail = str + full_blocks * BLOCK;
    int n = 0;
    for (size_t k = 0; k < rem; ++k)
    {
        n = n * 10 + tail[k] - '0';
    }
    for (size_t k = rem; k < BLOCK; ++k)
    {
        n *= 10; // pad the short group on the right
    }
    ary[full_blocks] = n;
    return BLOCK - rem;
}
// Converts convolution output (base-10000 limbs, most significant first, as
// doubles) into decimal text with carry propagation.
// Writes (conv_len + 1) * 4 characters into res. Leading zeros are skipped:
// the return value is the offset of the first digit. res_len receives the
// number of digits, excluding `shift` trailing padding zeros.
// A zero product yields the single digit "0"; previously the leading-zero scan
// ran past the end of the buffer in that case.
template <typename T>
size_t conv_to_str_base10000(const T *ary, size_t conv_len, size_t shift, char *res, size_t &res_len)
{
    constexpr size_t BLOCK = 4, BASE = 10000;
    const size_t total_len = (conv_len + 1) * BLOCK;
    res_len = total_len;
    char *end = res + total_len;
    size_t i = conv_len;
    uint64_t carry = 0;
    while (i > 0)
    {
        i--;
        end -= BLOCK;
        carry += uint64_t(ary[i] + 0.5); // round to the nearest integer
        itosbase10000.tostr(end, carry % BASE);
        carry /= BASE;
    }
    assert(carry < BASE);
    end -= BLOCK;
    itosbase10000.tostr(end, carry);
    // Never skip past the last significant position: the final `shift`
    // characters are padding zeros, and for a zero product every character is
    // '0'.
    const char *last = res + (total_len > shift ? total_len - shift - 1 : 0);
    while (end < last && *end == '0')
    {
        end++;
    }
    const size_t offset = end - res;
    res_len -= (offset + shift);
    return offset;
}
// Multiplies two decimal digit strings (no sign, no separators) using a
// base-10000 real-FFT convolution.
// res must have room for preserve_strlen(len1, len2) characters. Returns a
// pointer to the first result digit inside res (leading zeros are skipped, not
// removed) and stores the number of result digits in res_len.
// If either operand is empty, res_len is set to 0 and res is returned.
char *big_mul(const char *str1, size_t len1, const char *str2, size_t len2, char *res, size_t &res_len)
{
    constexpr size_t BLOCK = 4;
    if (len1 == 0 || len2 == 0)
    {
        // Avoid the unsigned underflow of conv_len below.
        res_len = 0;
        return res;
    }
    const size_t block_len1 = (len1 + BLOCK - 1) / BLOCK, block_len2 = (len2 + BLOCK - 1) / BLOCK;
    const size_t conv_len = block_len1 + block_len2 - 1;
    // The packed real-FFT pipeline (real_conv_rfft -> real_dot_binrev) works
    // on fixed leading bins spanning the first 8 doubles, so never go shorter.
    size_t fft_len = hint::int_ceil2(conv_len);
    if (fft_len < 8)
    {
        fft_len = 8;
    }
    AlignMem<Float64> ary1(fft_len), ary2(fft_len);
    size_t shift = str_num_to_array_base10000(str1, len1, &ary1[0]);
    shift += str_num_to_array_base10000(str2, len2, &ary2[0]);
    fill_zero(ary1.begin() + block_len1, ary1.end());
    fill_zero(ary2.begin() + block_len2, ary2.end());
    real_conv_rfft(ary1.begin(), ary2.begin(), fft_len);
    return res + conv_to_str_base10000(ary1.begin(), conv_len, shift, res, res_len);
}
// Scratch size big_mul() needs for the result text: one 4-character group per
// base-10000 limb of each operand.
size_t preserve_strlen(size_t len1, size_t len2)
{
    constexpr size_t BLOCK = 4;
    const size_t limbs = (len1 + BLOCK - 1) / BLOCK + (len2 + BLOCK - 1) / BLOCK;
    return limbs * BLOCK;
}
// Reads two whitespace-separated decimal integers from stdin and writes their
// product to stdout (no trailing newline).
void mul_test()
{
    std::string lhs, rhs;
    std::cin >> lhs >> rhs;
    size_t out_len = preserve_strlen(lhs.size(), rhs.size());
    std::vector<char> buffer(out_len, '0');
    const char *digits = big_mul(lhs.data(), lhs.size(), rhs.data(), rhs.size(), buffer.data(), out_len);
    fwrite(digits, 1, out_len, stdout);
}
// Entry point: multiply the two numbers given on stdin.
int main()
{
    mul_test();
    return 0;
}
/* Online-judge result for this submission:
Compilation | N/A       | N/A            | Compile OK | Score: N/A
Testcase #1 | 44.015 ms | 15 MB + 660 KB | Accepted   | Score: 100
*/