#include <tuple>
#include <iostream>
#include <cstdint>
#include <cstring>
#include <cmath>
#include <immintrin.h>
#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP
#pragma GCC target("avx2")
#pragma GCC target("fma")
namespace hint_simd
{
// Fixed-length array of LEN elements whose storage is aligned to 128 bytes.
// The constructor does not initialise the elements.
template <typename T, size_t LEN>
class AlignAry
{
public:
    constexpr AlignAry() {}

    // Number of elements (always LEN).
    constexpr size_t size() const { return LEN; }

    // Unchecked element access.
    constexpr T &operator[](size_t index) { return ary[index]; }
    constexpr const T &operator[](size_t index) const { return ary[index]; }

    // Pointer to the first element.
    T *data() { return ary; }
    const T *data() const { return ary; }

    // Reinterpret the underlying storage as an array of Ty.
    template <typename Ty>
    Ty *cast_ptr()
    {
        return reinterpret_cast<Ty *>(ary);
    }
    template <typename Ty>
    const Ty *cast_ptr() const
    {
        return reinterpret_cast<const Ty *>(ary);
    }

private:
    alignas(128) T ary[LEN];
};
// AVX implementation: four doubles processed in parallel in one 256-bit register.
struct DoubleX4
{
    __m256d data;

    // All lanes zero.
    DoubleX4(){
        data = _mm256_setzero_pd();}
    // Broadcast one value to all four lanes.
    DoubleX4(double input){
        data = _mm256_set1_pd(input);}
    // Wrap a raw register value.
    DoubleX4(__m256d input){
        data = input;}
    DoubleX4(const DoubleX4 &input){
        data = input.data;}
    // Construct from four consecutive doubles. Uses the ALIGNED load, so ptr
    // must be 32-byte aligned.
    DoubleX4(double const *ptr){
        load(ptr);}
    // Construct from four values; the arguments are forwarded to _mm256_set_pd
    // in the same order (the first argument ends up in the highest lane).
    DoubleX4(double a7, double a6, double a5, double a4){
        data = _mm256_set_pd(a7, a6, a5, a4);}

    // Reset all lanes to zero.
    void clr(){
        data = _mm256_setzero_pd();}
    // Aligned / unaligned load of four doubles.
    void load(double const *ptr){
        data = _mm256_load_pd(ptr);}
    void loadu(double const *ptr){
        data = _mm256_loadu_pd(ptr);}
    // Aligned / unaligned store of four doubles.
    void store(double *ptr) const{
        _mm256_store_pd(ptr, data);}
    void storeu(double *ptr) const{
        _mm256_storeu_pd(ptr, data);}
    // Shorthand for an unaligned load from ptr.
    void operator<<(double *ptr){
        data = _mm256_loadu_pd(ptr);}

    // Print the four lanes, lowest lane first.
    void print() const{
        double ary[4];
        // A plain stack array is not guaranteed to be 32-byte aligned, so the
        // unaligned store must be used here (the aligned one may fault).
        storeu(ary);
        printf("(%lf,%lf,%lf,%lf)\n",
               ary[0], ary[1], ary[2], ary[3]);}

    // Thin wrappers around _mm256_permute4x64_pd / _mm256_permute_pd with
    // immediate N.
    template <int N>
    DoubleX4 permute4x64() const{
        return _mm256_permute4x64_pd(data, N);}
    template <int N>
    DoubleX4 permute() const{
        return _mm256_permute_pd(data, N);}
    // Reverse the order of the four lanes.
    DoubleX4 reverse() const{
        return permute4x64<0b00011011>();}
    // Wrapper around _mm256_addsub_pd(data, input).
    DoubleX4 addsub(DoubleX4 input) const{
        return _mm256_addsub_pd(data, input.data);}
    // mul1 * mul2 + *this (fused).
    DoubleX4 fmadd(DoubleX4 mul1, DoubleX4 mul2) const{
        return _mm256_fmadd_pd(mul1.data, mul2.data, data);}
    // mul1 * mul2 - *this (fused).
    DoubleX4 fmsub(DoubleX4 mul1, DoubleX4 mul2) const{
        return _mm256_fmsub_pd(mul1.data, mul2.data, data);}

    // Lane-wise arithmetic.
    DoubleX4 operator+(DoubleX4 input) const{
        return _mm256_add_pd(data, input.data);}
    DoubleX4 operator-(DoubleX4 input) const{
        return _mm256_sub_pd(data, input.data);}
    DoubleX4 operator*(DoubleX4 input) const{
        return _mm256_mul_pd(data, input.data);}
    DoubleX4 operator/(DoubleX4 input) const{
        return _mm256_div_pd(data, input.data);}
    // Lane-wise negation, computed as 0 - x.
    DoubleX4 operator-() const{
        return _mm256_sub_pd(_mm256_setzero_pd(), data);}
};
#ifdef __AVX512F__
#pragma GCC target("avx512f")
// AVX-512 implementation: eight doubles in one 512-bit register.
struct DoubleX8
{
    __m512d data;

    // All lanes zero.
    DoubleX8() : data(_mm512_setzero_pd()) {}
    // Broadcast one value to all eight lanes.
    DoubleX8(double input) : data(_mm512_set1_pd(input)) {}
    // Wrap a raw register value.
    DoubleX8(__m512d input) : data(input) {}
    DoubleX8(const DoubleX8 &input) : data(input.data) {}
    // Construct from eight consecutive doubles (aligned load).
    DoubleX8(double const *ptr) : data(_mm512_load_pd(ptr)) {}
    // Construct from eight values, forwarded to _mm512_set_pd in the same order.
    DoubleX8(double a7, double a6, double a5, double a4, double a3, double a2, double a1, double a0)
        : data(_mm512_set_pd(a7, a6, a5, a4, a3, a2, a1, a0))
    {
    }

    // Reset all lanes to zero.
    void clr() { data = _mm512_setzero_pd(); }

    // Aligned / unaligned load of eight doubles.
    void load(double const *ptr) { data = _mm512_load_pd(ptr); }
    void loadu(double const *ptr) { data = _mm512_loadu_pd(ptr); }

    // Aligned / unaligned store of eight doubles.
    void store(double *ptr) const { _mm512_store_pd(ptr, data); }
    void storeu(double *ptr) const { _mm512_storeu_pd(ptr, data); }

    // Print the eight lanes, lowest lane first.
    void print() const
    {
        double lanes[8];
        _mm512_storeu_pd(lanes, data);
        printf("(%lf,%lf,%lf,%lf,%lf,%lf,%lf,%lf)\n",
               lanes[0], lanes[1], lanes[2], lanes[3], lanes[4], lanes[5], lanes[6], lanes[7]);
    }

    // Thin wrappers around the corresponding AVX-512 permute/shuffle
    // intrinsics with immediate N.
    template <int N>
    DoubleX8 permutex() const
    {
        return DoubleX8(_mm512_permutex_pd(data, N));
    }
    template <int N>
    DoubleX8 permute() const
    {
        return DoubleX8(_mm512_permute_pd(data, N));
    }
    template <int N>
    DoubleX8 shuffle(DoubleX8 in) const
    {
        return DoubleX8(_mm512_shuffle_pd(data, in.data, N));
    }
    template <int N>
    DoubleX8 shuffle_f128(DoubleX8 in) const
    {
        return DoubleX8(_mm512_shuffle_f64x2(data, in.data, N));
    }

    // Swap the two doubles inside every 128-bit pair.
    DoubleX8 swap_oe() const { return permute<0b01010101>(); }

    // Reverse all eight lanes: swap within pairs, then reverse the four pairs.
    DoubleX8 reverse() const
    {
        const DoubleX8 swapped = swap_oe();
        return swapped.shuffle_f128<0b00011011>(swapped);
    }

    // mul1 * mul2 + *this (fused).
    DoubleX8 fmadd(DoubleX8 mul1, DoubleX8 mul2) const
    {
        return DoubleX8(_mm512_fmadd_pd(mul1.data, mul2.data, data));
    }
    // mul1 * mul2 - *this (fused).
    DoubleX8 fmsub(DoubleX8 mul1, DoubleX8 mul2) const
    {
        return DoubleX8(_mm512_fmsub_pd(mul1.data, mul2.data, data));
    }

    // Lane-wise arithmetic.
    DoubleX8 operator+(DoubleX8 input) const { return DoubleX8(_mm512_add_pd(data, input.data)); }
    DoubleX8 operator-(DoubleX8 input) const { return DoubleX8(_mm512_sub_pd(data, input.data)); }
    DoubleX8 operator*(DoubleX8 input) const { return DoubleX8(_mm512_mul_pd(data, input.data)); }
    DoubleX8 operator/(DoubleX8 input) const { return DoubleX8(_mm512_div_pd(data, input.data)); }
    // Lane-wise negation, computed as 0 - x.
    DoubleX8 operator-() const { return DoubleX8(_mm512_sub_pd(_mm512_setzero_pd(), data)); }
};
#else
struct DoubleX8
{
DoubleX4 data0;
DoubleX4 data1;
DoubleX8(){
data0 = data1 = _mm256_setzero_pd();}
DoubleX8(DoubleX4 in0, DoubleX4 in1) : data0(in0), data1(in1) {}
DoubleX8(double input){
data0 = data1 = _mm256_set1_pd(input);}
DoubleX8(const DoubleX8 &input){
data0 = input.data0;
data1 = input.data1;}
// 从连续的数组构造
DoubleX8(double const *ptr){
load(ptr);}
// 用4个数构造
DoubleX8(double a7, double a6, double a5, double a4, double a3, double a2, double a1, double a0){
data1 = _mm256_set_pd(a7, a6, a5, a4);
data0 = _mm256_set_pd(a3, a2, a1, a0);}
void clr(){
data0 = data1 = _mm256_setzero_pd();}
void load(double const *ptr){
data0 = _mm256_load_pd(ptr);
data1 = _mm256_load_pd(ptr + 4);}
void loadu(double const *ptr){
data0 = _mm256_loadu_pd(ptr);
data1 = _mm256_loadu_pd(ptr + 4);}
void store(double *ptr) const{
_mm256_store_pd(ptr, data0.data);
_mm256_store_pd(ptr + 4, data1.data);}
void storeu(double *ptr) const{
_mm256_storeu_pd(ptr, data0.data);
_mm256_storeu_pd(ptr + 4, data1.data);}
void print() const{
double ary[8];
storeu(ary);
printf("(%lf,%lf,%lf,%lf,%lf,%lf,%lf,%lf)\n",
ary[0], ary[1], ary[2], ary[3], ary[4], ary[5], ary[6], ary[7]);}
DoubleX8 reverse() const{
return DoubleX8(data1.reverse(), data0.reverse());}
DoubleX8 fmadd(DoubleX8 mul1, DoubleX8 mul2) const{
return DoubleX8(data0.fmadd(mul1.data0, mul2.data0), data1.fmadd(mul1.data1, mul2.data1));}
DoubleX8 fmsub(DoubleX8 mul1, DoubleX8 mul2) const{
return DoubleX8(data0.fmsub(mul1.data0, mul2.data0), data1.fmsub(mul1.data1, mul2.data1));}
DoubleX8 operator+(DoubleX8 input) const{
return DoubleX8(data0 + input.data0, data1 + input.data1);}
DoubleX8 operator-(DoubleX8 input) const{
return DoubleX8(data0 - input.data0, data1 - input.data1);}
DoubleX8 operator*(DoubleX8 input) const{
return DoubleX8(data0 * input.data0, data1 * input.data1);}
DoubleX8 operator/(DoubleX8 input) const{
return DoubleX8(data0 / input.data0, data1 / input.data1);}
DoubleX8 operator-() const{
return DoubleX8(-data0, -data1);}
};
#endif
}
#endif
#include <array>
#include <vector>
#include <iostream>
#include <future>
#include <ctime>
#include <cstring>
namespace hint
{
using namespace hint_simd;
// Floating-point type aliases used throughout the transform code.
using Float32 = float;
using Float64 = double;
// Pi and 2*Pi as Float64 constants (used to build the cosine tables).
constexpr Float64 HINT_PI = 3.141592653589793238462643;
constexpr Float64 HINT_2PI = HINT_PI * 2;
// Largest transform length supported: 2^19 elements. fht_convolution rejects
// longer inputs and the dispatch helpers start their search at this length.
constexpr size_t FHT_MAX_LEN = size_t(1) << 19;
// Largest power of two that is <= n.
// Returns 0 for n == 0, because no power of two is <= 0 (the bit-smearing
// formula below would otherwise yield 1, which is larger than the input).
// Intended for non-negative values.
template <typename T>
constexpr T int_floor2(T n)
{
    if (n == 0)
    {
        return 0;
    }
    constexpr int bits = sizeof(n) * 8;
    // Smear the highest set bit into every lower position ...
    for (int i = 1; i < bits; i *= 2)
    {
        n |= (n >> i);
    }
    // ... so that (n >> 1) + 1 isolates that highest bit.
    return (n >> 1) + 1;
}
// Smallest power of two that is >= n (n itself when it already is one).
// For n == 0 the decrement wraps around and the result is 0.
template <typename T>
constexpr T int_ceil2(T n)
{
    T smeared = n - 1;
    // Propagate the highest set bit of (n - 1) into all lower positions.
    for (int shift = 1; shift < int(sizeof(T) * 8); shift *= 2)
    {
        smeared |= (smeared >> shift);
    }
    return smeared + 1;
}
// Integer base-2 logarithm: returns floor(log2(n)) for n > 0 and -1 otherwise.
//
// The binary search shifts n to the right instead of shifting 1 to the left,
// so no shift ever reaches the sign bit of a signed T (the left-shift form
// returned bits - 1 for large signed values such as 1 << 30).
template <typename T>
constexpr int hint_log2(T n)
{
    if (!(n > 0))
    {
        return -1;
    }
    constexpr int bits = sizeof(n) * 8;
    // Invariant: (n >> lo) != 0, and hi == bits or (n >> hi) == 0.
    int lo = 0, hi = bits;
    while ((lo + 1) != hi)
    {
        int mid = (lo + hi) / 2;
        if ((n >> mid) == 0)
        {
            hi = mid;
        }
        else
        {
            lo = mid;
        }
    }
    return lo;
}
// Quotient and remainder of dividend / divisor, returned as
// {quotient, remainder}. The divisor must be non-zero.
template <typename T>
constexpr std::pair<T, T> div_mod(T dividend, T divisor)
{
    return std::pair<T, T>(dividend / divisor, dividend % divisor);
}
namespace hint_transform
{
// Table of cosines with (2^log_len) / len_div entries stored in an AlignAry.
// After init(factor), entry i holds cos(-2*pi*factor*i / (entries * len_div)).
template <typename FloatTy, int log_len, int len_div>
class CosStaticTable
{
public:
    using Ty = FloatTy;
    static constexpr size_t vec_len = (size_t(1) << log_len) / len_div;
    using TableTy = hint_simd::AlignAry<Ty, vec_len>;

    // Leaves the table uninitialised; call init() before reading it.
    CosStaticTable() {}
    // Builds the table immediately.
    CosStaticTable(int factor) { init(factor); }

    // Fill every entry from scratch with std::cos.
    void init(int factor)
    {
        const size_t full_len = table.size() * len_div;
        const FloatTy step = -HINT_2PI * factor / full_len;
        for (size_t k = 0; k != table.size(); ++k)
        {
            table[k] = std::cos(step * k);
        }
    }

    // Unchecked element access.
    const auto &operator[](size_t n) const { return table[n]; }
    auto &operator[](size_t n) { return table[n]; }

    // Pointer to entry n (read-only).
    auto get_it(size_t n = 0) const { return &table[n]; }

private:
    TableTy table;
};
// Two-point butterfly, in place: (sum, diff) <- (sum + diff, sum - diff).
// Both inputs are copied before either output is written.
template <typename T>
inline void transform_2point(T &sum, T &diff)
{
    T lhs = sum;
    T rhs = diff;
    sum = lhs + rhs;
    diff = lhs - rhs;
}
namespace hint_fht{
// Shared (static) twiddle table for a transform of length 2^log_len.
// The table has 2^log_len / 4 entries; init() fills entry k with
// cos(2*pi*k / 2^log_len). Because sin(x) = cos(pi/2 - x), the matching sine
// values are obtained by walking the same table backwards from get_last().
// The table is built from the half-length table (HalfTable) plus one
// angle-addition step, so only three std::cos/std::sin calls are made per
// level.
// NOTE: has_init is a plain bool with no synchronisation.
template <typename FloatTy, int log_len>
class FHTTable{
public:
using HalfTable = FHTTable<FloatTy, log_len - 1>;
using TableTy = CosStaticTable<FloatTy, log_len, 4>;
// Constructing an instance makes sure the shared table is initialised.
FHTTable(){
init();}
// Build the table once; later calls return immediately.
static void init(){
if (!has_init)
{
// The half-length table supplies every even-indexed entry below.
HalfTable::init();
constexpr size_t table_len = table.vec_len, fht_len = size_t(1) << log_len;
static_assert(table_len * 4 == fht_len, "Length of table must be fht_len / 4 ");
static_assert(table_len > 1, "Length of table must be larger than 1");
// unity is the angle step 2*pi / fht_len between consecutive entries.
static const FloatTy unity = HINT_2PI / fht_len;
static const FloatTy cos_unit = std::cos(unity);
static const FloatTy sin_unit = std::sin(unity);
table[0] = 1;
table[1] = cos_unit;
// The last entry doubles as sin(unity): cos((table_len - 1) * unity) == sin(unity).
table[table_len - 1] = sin_unit;
table[table_len / 2] = std::cos(HINT_PI / 4);
// i walks up over even indices, j walks down; table[j] is both cos(j * unity)
// and sin(i * unity) because i + j == table_len (a quarter turn).
for (size_t i = 2, j = table_len - 2; i < j; i += 2, j -= 2)
{
FloatTy cos_i = table[i] = HalfTable::table[i / 2]; // cos(i * unity), copied from the half-length table
FloatTy sin_i = table[j] = HalfTable::table[j / 2]; // cos(j * unity) == sin(i * unity)
table[i + 1] = cos_i * cos_unit - sin_i * sin_unit; // cos((i + 1) * unity) via angle addition
table[j - 1] = sin_i * cos_unit + cos_i * sin_unit; // sin((i + 1) * unity) == cos((j - 1) * unity)
}
has_init = true;
}}
// Pointer to table entry n.
auto get_it(size_t n = 0) const{
return table.get_it(n);}
// Pointer to the last table entry (index 2^log_len / 4 - 1).
auto get_last() const{
return get_it((size_t(1) << log_len) / 4 - 1);}
static bool has_init;
static TableTy table;};
// Out-of-class definitions of the static members.
template <typename FloatTy, int log_len>
typename FHTTable<FloatTy, log_len>::TableTy FHTTable<FloatTy, log_len>::table;
template <typename FloatTy, int log_len>
bool FHTTable<FloatTy, log_len>::has_init = false;
// Base case of the table recursion (length 2^3 = 8): the two-entry table is
// computed directly with CosStaticTable::init instead of from a smaller table.
// Every call to init() recomputes it.
template <typename FloatTy>
class FHTTable<FloatTy, 3>
{
public:
    using TableTy = CosStaticTable<FloatTy, 3, 4>;
    static TableTy table;

    FHTTable() { init(); }

    static void init() { table.init(1); }
};
// Out-of-class definition of the static table.
template <typename FloatTy>
typename FHTTable<FloatTy, 3>::TableTy FHTTable<FloatTy, 3>::table;
// Recursive fast Hartley transform of compile-time length LEN.
//
// dit() first transforms both halves and then runs the combining butterfly
// pass; dif() runs the butterfly pass first and then recurses into the halves.
// The butterfly pass handles elements 1..3 with scalar code, elements 4..7
// with DoubleX4 and the rest of the first quarter with DoubleX8. The ascending
// pointers (it0, it2, cos_it) use aligned loads/stores, so in_out is presumably
// expected to be suitably aligned (e.g. an AlignAry); the descending pointers
// (it1, it3, sin_it) use unaligned accesses and are lane-reversed on the fly.
// The SIMD paths load through double pointers, so FloatTy is effectively
// double. Small lengths (0, 1, 2, 4, 8, 16) are handled by explicit
// specialisations of this template.
template <size_t LEN, typename FloatTy>
struct FHT
{
    enum
    {
        fht_len = LEN,
        half_len = LEN / 2,
        quarter_len = LEN / 4,
        log_len = hint_log2(fht_len)
    };
    using HalfFHT = FHT<half_len, FloatTy>;
    using TableTy = FHTTable<FloatTy, log_len>;
    // Twiddle table for this length; its constructor triggers the table build.
    static TableTy TABLE;

    // In-place transform: recurse first, combine afterwards.
    template <typename FloatIt>
    static void dit(FloatIt in_out)
    {
        static_assert(std::is_same<typename std::iterator_traits<FloatIt>::value_type, FloatTy>::value, "Must be same as the FHT template float type");
        HalfFHT::dit(in_out + half_len);
        HalfFHT::dit(in_out);
        // Elements 0 and quarter_len need no twiddle multiplication.
        transform_2point(in_out[0], in_out[half_len]);
        transform_2point(in_out[quarter_len], in_out[half_len + quarter_len]);
        // it0/it2 walk upwards from index 1, it1/it3 walk downwards from the
        // end of each half; cos_it walks up the table, sin_it walks down.
        auto it0 = in_out + 1, it1 = in_out + half_len - 1;
        auto it2 = it0 + half_len, it3 = it1 + half_len;
        auto cos_it = TABLE.get_it(1), sin_it = TABLE.get_last();
        // Scalar butterflies for indices 1..3.
        for (; it0 < in_out + 4; ++it0, --it1, ++it2, --it3, cos_it++, sin_it--)
        {
            auto c = cos_it[0], s = sin_it[0];
            auto temp0 = it2[0], temp1 = it3[0];
            auto temp2 = temp0 * c + temp1 * s;
            auto temp3 = temp0 * s - temp1 * c;
            temp0 = it0[0], temp1 = it1[0];
            it0[0] = temp0 + temp2;
            it1[0] = temp1 + temp3;
            it2[0] = temp0 - temp2;
            it3[0] = temp1 - temp3;
        }
        // Move the descending pointers to the start of their 4-element window.
        it1 -= 3, it3 -= 3, sin_it -= 3;
        // 4-wide butterflies for indices 4..7.
        for (; it0 < in_out + 8; it0 += 4, it1 -= 4, it2 += 4, it3 -= 4, cos_it += 4, sin_it -= 4)
        {
            DoubleX4 c4, s4, temp0, temp1, temp2, temp3;
            c4.load(&cos_it[0]), s4.loadu(&sin_it[0]);
            temp0.load(&it2[0]), temp1.loadu(&it3[0]);
            // temp2 = (temp1 * s4).reverse().fmadd(temp0, c4);
            temp2 = (temp0 * c4).fmadd(temp1.reverse(), s4.reverse());
            temp3 = (c4.reverse() * temp1).fmsub(temp0.reverse(), s4);
            temp0.load(&it0[0]), temp1.loadu(&it1[0]);
            (temp0 + temp2).store(&it0[0]);
            (temp1 + temp3).storeu(&it1[0]);
            (temp0 - temp2).store(&it2[0]);
            (temp1 - temp3).storeu(&it3[0]);
        }
        // Widen the descending window from 4 to 8 elements.
        it1 -= 4, it3 -= 4, sin_it -= 4;
        // 8-wide butterflies for the rest of the first quarter.
        for (; it0 < in_out + quarter_len; it0 += 8, it1 -= 8, it2 += 8, it3 -= 8, cos_it += 8, sin_it -= 8)
        {
            DoubleX8 c4, s4, temp0, temp1, temp2, temp3;
            c4.load(&cos_it[0]), s4.loadu(&sin_it[0]);
            temp0.load(&it2[0]), temp1.loadu(&it3[0]);
            // temp2 = (temp1 * s4).reverse().fmadd(temp0, c4);
            temp2 = (temp0 * c4).fmadd(temp1.reverse(), s4.reverse());
            temp3 = (c4.reverse() * temp1).fmsub(temp0.reverse(), s4);
            temp0.load(&it0[0]), temp1.loadu(&it1[0]);
            (temp0 + temp2).store(&it0[0]);
            (temp1 + temp3).storeu(&it1[0]);
            (temp0 - temp2).store(&it2[0]);
            (temp1 - temp3).storeu(&it3[0]);
        }
    }

    // In-place transform: combine first, recurse afterwards.
    template <typename FloatIt>
    static void dif(FloatIt in_out)
    {
        static_assert(std::is_same<typename std::iterator_traits<FloatIt>::value_type, FloatTy>::value, "Must be same as the FHT template float type");
        auto it0 = in_out + 1, it1 = in_out + half_len - 1;
        auto it2 = it0 + half_len, it3 = it1 + half_len;
        auto cos_it = TABLE.get_it(1), sin_it = TABLE.get_last();
        // Scalar butterflies for indices 1..3.
        for (; it0 < in_out + 4; ++it0, --it1, ++it2, --it3, cos_it++, sin_it--)
        {
            auto c = cos_it[0], s = sin_it[0];
            auto temp0 = it0[0], temp1 = it1[0];
            auto temp2 = it2[0], temp3 = it3[0];
            it0[0] = temp0 + temp2;
            it1[0] = temp1 + temp3;
            temp0 = temp0 - temp2;
            temp1 = temp1 - temp3;
            it2[0] = temp0 * c + temp1 * s;
            it3[0] = temp0 * s - temp1 * c;
        }
        // Move the descending pointers to the start of their 4-element window.
        it1 -= 3, it3 -= 3, sin_it -= 3;
        // 4-wide butterflies for indices 4..7.
        for (; it0 < in_out + 8; it0 += 4, it1 -= 4, it2 += 4, it3 -= 4, cos_it += 4, sin_it -= 4)
        {
            DoubleX4 c4, s4, temp0, temp1, temp2, temp3;
            temp0.load(&it0[0]), temp1.loadu(&it1[0]);
            temp2.load(&it2[0]), temp3.loadu(&it3[0]);
            (temp0 + temp2).store(&it0[0]);
            (temp1 + temp3).storeu(&it1[0]);
            temp0 = temp0 - temp2;
            temp1 = temp1 - temp3;
            c4.load(&cos_it[0]), s4.loadu(&sin_it[0]);
            // temp2 = (temp1 * s4).reverse().fmadd(temp0, c4);
            temp2 = (temp0 * c4).fmadd(temp1.reverse(), s4.reverse());
            temp3 = (c4.reverse() * temp1).fmsub(temp0.reverse(), s4);
            temp2.store(&it2[0]), temp3.storeu(&it3[0]);
        }
        // Widen the descending window from 4 to 8 elements.
        it1 -= 4, it3 -= 4, sin_it -= 4;
        // 8-wide butterflies for the rest of the first quarter.
        for (; it0 < in_out + quarter_len; it0 += 8, it1 -= 8, it2 += 8, it3 -= 8, cos_it += 8, sin_it -= 8)
        {
            DoubleX8 c4, s4, temp0, temp1, temp2, temp3;
            temp0.load(&it0[0]), temp1.loadu(&it1[0]);
            temp2.load(&it2[0]), temp3.loadu(&it3[0]);
            (temp0 + temp2).store(&it0[0]);
            (temp1 + temp3).storeu(&it1[0]);
            temp0 = temp0 - temp2;
            temp1 = temp1 - temp3;
            c4.load(&cos_it[0]), s4.loadu(&sin_it[0]);
            // temp2 = (temp1 * s4).reverse() + (temp0 * c4);
            temp2 = (temp0 * c4).fmadd(temp1.reverse(), s4.reverse());
            temp3 = (c4.reverse() * temp1).fmsub(temp0.reverse(), s4);
            temp2.store(&it2[0]), temp3.storeu(&it3[0]);
        }
        // Elements 0 and quarter_len need no twiddle multiplication.
        transform_2point(in_out[0], in_out[half_len]);
        transform_2point(in_out[quarter_len], in_out[half_len + quarter_len]);
        HalfFHT::dif(in_out);
        HalfFHT::dif(in_out + half_len);
    }
};
// Out-of-class definition of the per-length static twiddle table.
template <size_t LEN, typename FloatTy>
typename FHT<LEN, FloatTy>::TableTy FHT<LEN, FloatTy>::TABLE; //(FHT<LEN, FloatTy>::log_len, 4, 1);
// Length-0 transform: nothing to do.
template <typename FloatTy>
struct FHT<0, FloatTy>
{
    template <typename FloatIt>
    static void dit(FloatIt)
    {
    }

    template <typename FloatIt>
    static void dif(FloatIt)
    {
    }
};
// Length-1 transform: a single element is left unchanged.
template <typename FloatTy>
struct FHT<1, FloatTy>
{
    template <typename FloatIt>
    static void dit(FloatIt)
    {
    }

    template <typename FloatIt>
    static void dif(FloatIt)
    {
    }
};
// Length-2 transform: one butterfly; dit and dif are the same operation.
template <typename FloatTy>
struct FHT<2, FloatTy>
{
    template <typename FloatIt>
    static void dit(FloatIt in_out)
    {
        transform_2point(in_out[0], in_out[1]);
    }

    template <typename FloatIt>
    static void dif(FloatIt in_out)
    {
        dit(in_out);
    }
};
// Length-4 transform, fully unrolled into two layers of butterflies.
template <typename FloatTy>
struct FHT<4, FloatTy>
{
    // Butterflies on neighbouring pairs first, then across the two halves.
    template <typename FloatIt>
    static void dit(FloatIt in_out)
    {
        auto v0 = in_out[0], v1 = in_out[1];
        auto v2 = in_out[2], v3 = in_out[3];
        transform_2point(v0, v1);
        transform_2point(v2, v3);
        transform_2point(v0, v2);
        transform_2point(v1, v3);
        in_out[0] = v0;
        in_out[1] = v1;
        in_out[2] = v2;
        in_out[3] = v3;
    }

    // Butterflies across the two halves first, then on neighbouring pairs.
    template <typename FloatIt>
    static void dif(FloatIt in_out)
    {
        auto v0 = in_out[0], v1 = in_out[1];
        auto v2 = in_out[2], v3 = in_out[3];
        transform_2point(v0, v2);
        transform_2point(v1, v3);
        transform_2point(v0, v1);
        transform_2point(v2, v3);
        in_out[0] = v0;
        in_out[1] = v1;
        in_out[2] = v2;
        in_out[3] = v3;
    }
};
// Length-8 transform, fully unrolled. The only non-trivial twiddle factor is
// sqrt(2), applied to two of the intermediate values.
template <typename FloatTy>
struct FHT<8, FloatTy>
{
    template <typename FloatIt>
    static void dit(FloatIt in_out)
    {
        auto v0 = in_out[0], v1 = in_out[1];
        auto v2 = in_out[2], v3 = in_out[3];
        auto v4 = in_out[4], v5 = in_out[5];
        auto v6 = in_out[6], v7 = in_out[7];
        static constexpr decltype(v0) SQRT_2 = 1.4142135623730950488016887242097;

        // Layer 1: neighbouring pairs.
        transform_2point(v0, v1);
        transform_2point(v2, v3);
        transform_2point(v4, v5);
        transform_2point(v6, v7);

        // Layer 2: distance 2 (v5 and v7 are only scaled).
        transform_2point(v0, v2);
        transform_2point(v1, v3);
        transform_2point(v4, v6);
        v5 *= SQRT_2;
        v7 *= SQRT_2;

        // Layer 3: distance 4.
        transform_2point(v0, v4);
        transform_2point(v1, v5);
        transform_2point(v2, v6);
        transform_2point(v3, v7);

        in_out[0] = v0;
        in_out[1] = v1;
        in_out[2] = v2;
        in_out[3] = v3;
        in_out[4] = v4;
        in_out[5] = v5;
        in_out[6] = v6;
        in_out[7] = v7;
    }

    template <typename FloatIt>
    static void dif(FloatIt in_out)
    {
        auto v0 = in_out[0], v1 = in_out[1];
        auto v2 = in_out[2], v3 = in_out[3];
        auto v4 = in_out[4], v5 = in_out[5];
        auto v6 = in_out[6], v7 = in_out[7];
        static constexpr decltype(v0) SQRT_2 = 1.4142135623730950488016887242097;

        // Layer 1: distance 4.
        transform_2point(v0, v4);
        transform_2point(v1, v5);
        transform_2point(v2, v6);
        transform_2point(v3, v7);

        // Layer 2: distance 2 (v5 and v7 are only scaled).
        transform_2point(v0, v2);
        transform_2point(v1, v3);
        v5 *= SQRT_2;
        v7 *= SQRT_2;
        transform_2point(v4, v6);

        // Layer 3: neighbouring pairs.
        transform_2point(v0, v1);
        transform_2point(v2, v3);
        transform_2point(v4, v5);
        transform_2point(v6, v7);

        in_out[0] = v0;
        in_out[1] = v1;
        in_out[2] = v2;
        in_out[3] = v3;
        in_out[4] = v4;
        in_out[5] = v5;
        in_out[6] = v6;
        in_out[7] = v7;
    }
};
// Length-16 transform with hard-coded twiddle factors.
// The lower 8 elements go through the 8-point transform and the upper 8
// through two 4-point transforms; the combining step uses cos/sin(2*pi/16)
// for the odd indices and sqrt(2) for indices 10 and 14. All work is done in
// place, so the statement order matters.
template <typename FloatTy>
struct FHT<16, FloatTy>{
// Sub-transforms first, combining step afterwards.
template <typename FloatIt>
static void dit(FloatIt in_out){
using value_type = typename std::iterator_traits<FloatIt>::value_type;
FHT<4, FloatTy>::dit(in_out + 12);
FHT<4, FloatTy>::dit(in_out + 8);
FHT<8, FloatTy>::dit(in_out);
static constexpr value_type SQRT_2 = 1.4142135623730950488016887242097;
static constexpr value_type COS_16 = 0.9238795325112867561281831893967; // cos(2PI/16);
static constexpr value_type SIN_16 = 0.3826834323650897717284599840304; // sin(2PI/16);
// Rotate the odd-indexed elements of the upper half (9, 11, 13, 15).
auto temp4 = in_out[9], temp5 = in_out[11];
auto temp0 = temp4 * COS_16 + temp5 * SIN_16;
auto temp2 = temp4 * SIN_16 - temp5 * COS_16;
temp4 = in_out[13], temp5 = in_out[15];
auto temp1 = temp4 * SIN_16 + temp5 * COS_16;
auto temp3 = temp4 * COS_16 - temp5 * SIN_16;
transform_2point(temp0, temp1);
transform_2point(temp3, temp2);
// Combine the rotated values with the odd-indexed elements of the lower half.
temp4 = in_out[1], temp5 = in_out[3];
in_out[1] = temp4 + temp0, in_out[3] = temp5 + temp1;
in_out[9] = temp4 - temp0, in_out[11] = temp5 - temp1;
temp4 = in_out[5], temp5 = in_out[7];
in_out[5] = temp4 + temp2, in_out[7] = temp5 + temp3;
in_out[13] = temp4 - temp2, in_out[15] = temp5 - temp3;
// Even indices: 10 and 14 are scaled by sqrt(2), 8 and 12 get a butterfly,
// then every even element is combined with its partner 8 positions away.
in_out[10] *= SQRT_2, in_out[14] *= SQRT_2;
transform_2point(in_out[8], in_out[12]);
transform_2point(in_out[0], in_out[8]);
transform_2point(in_out[2], in_out[10]);
transform_2point(in_out[4], in_out[12]);
transform_2point(in_out[6], in_out[14]);}
// Combining step first, sub-transforms afterwards (the steps of dit in
// reverse order).
template <typename FloatIt>
static void dif(FloatIt in_out){
using value_type = typename std::iterator_traits<FloatIt>::value_type;
static constexpr value_type SQRT_2 = 1.4142135623730950488016887242097;
static constexpr value_type COS_16 = 0.9238795325112867561281831893967; // cos(2PI/16);
static constexpr value_type SIN_16 = 0.3826834323650897717284599840304; // sin(2PI/16);
// Even indices: butterflies at distance 8, then 8/12, then the sqrt(2) scaling.
transform_2point(in_out[0], in_out[8]);
transform_2point(in_out[2], in_out[10]);
transform_2point(in_out[4], in_out[12]);
transform_2point(in_out[6], in_out[14]);
transform_2point(in_out[8], in_out[12]);
in_out[10] *= SQRT_2, in_out[14] *= SQRT_2;
// Odd indices: sums go to the lower half, differences are kept in
// temp0..temp3 for the rotation below.
auto temp0 = in_out[9], temp1 = in_out[11];
auto temp2 = in_out[13], temp3 = in_out[15];
auto temp4 = in_out[1], temp5 = in_out[3];
in_out[1] = temp4 + temp0;
in_out[3] = temp5 + temp1;
temp0 = temp4 - temp0;
temp1 = temp5 - temp1;
temp4 = in_out[5], temp5 = in_out[7];
in_out[5] = temp4 + temp2;
in_out[7] = temp5 + temp3;
temp2 = temp4 - temp2;
temp3 = temp5 - temp3;
transform_2point(temp0, temp1);
transform_2point(temp3, temp2);
// Rotate by 2*pi/16 and store into the odd slots of the upper half.
in_out[9] = temp0 * COS_16 + temp2 * SIN_16;
in_out[11] = temp0 * SIN_16 - temp2 * COS_16;
in_out[13] = temp1 * SIN_16 + temp3 * COS_16;
in_out[15] = temp1 * COS_16 - temp3 * SIN_16;
FHT<8, FloatTy>::dif(in_out);
FHT<4, FloatTy>::dif(in_out + 8);
FHT<4, FloatTy>::dif(in_out + 12);}};
// Dispatch helper: maps the run-time length fht_len to the compile-time
// FHT<LEN>::dit. Starting from LEN, the candidate length is halved until it
// is no larger than fht_len, and that transform is run on input.
template <size_t LEN = 1>
inline void fht_dit_template_alt(Float64 *input, size_t fht_len)
{
    if (fht_len >= LEN)
    {
        FHT<LEN, Float64>::dit(input);
    }
    else
    {
        fht_dit_template_alt<LEN / 2>(input, fht_len);
    }
}
// End of the halving chain: nothing to do.
template <>
inline void fht_dit_template_alt<0>(Float64 *, size_t)
{
}
// Dispatch helper: maps the run-time length fht_len to the compile-time
// FHT<LEN>::dif. Starting from LEN, the candidate length is halved until it
// is no larger than fht_len, and that transform is run on input.
template <size_t LEN = 1>
inline void fht_dif_template_alt(Float64 *input, size_t fht_len)
{
    if (fht_len >= LEN)
    {
        FHT<LEN, Float64>::dif(input);
    }
    else
    {
        fht_dif_template_alt<LEN / 2>(input, fht_len);
    }
}
// End of the halving chain: nothing to do.
template <>
inline void fht_dif_template_alt<0>(Float64 *, size_t)
{
}
auto fht_dit = fht_dit_template_alt<FHT_MAX_LEN>;
auto fht_dif = fht_dif_template_alt<FHT_MAX_LEN>;
// Cyclic convolution of two real sequences, accelerated with the FHT.
//
// fht_ary1, fht_ary2: inputs of fht_len elements each; both are overwritten
//                     with their transforms. They may be the same array
//                     (squaring), in which case only one forward transform
//                     is performed.
// out:                receives the convolution; may alias fht_ary1.
// fht_len:            rounded down to a power of two before use.
//
// Throws a string literal (const char *) when the rounded length exceeds
// FHT_MAX_LEN; the exception type is kept as-is for existing handlers.
inline void fht_convolution(Float64 fht_ary1[], Float64 fht_ary2[], Float64 out[], size_t fht_len)
{
    if (fht_len == 0)
    {
        return;
    }
    if (fht_len == 1)
    {
        out[0] = fht_ary1[0] * fht_ary2[0];
        return;
    }
    fht_len = int_floor2(fht_len);
    if (fht_len > FHT_MAX_LEN)
    {
        throw("FHT len cannot be larger than FHT_MAX_LEN");
    }
    fht_dif(fht_ary1, fht_len);
    // When both inputs are the same array, transform only once (speeds up squaring).
    if (fht_ary1 != fht_ary2)
    {
        fht_dif(fht_ary2, fht_len);
    }
    // 0.5 comes from the convolution theorem, 1 / fht_len normalises the
    // inverse transform.
    const double inv = 0.5 / fht_len;
    // Slots 0 and 1 are multiplied on their own.
    out[0] = fht_ary1[0] * fht_ary2[0] / fht_len;
    out[1] = fht_ary1[1] * fht_ary2[1] / fht_len;
    if (fht_len > 2)
    {
        // Convolution theorem of the DHT: the remaining slots are combined in
        // pairs. Slots 2 and 3 form the first pair.
        auto temp0 = fht_ary1[2], temp1 = fht_ary1[3];
        auto temp2 = fht_ary2[2], temp3 = fht_ary2[3];
        transform_2point(temp0, temp1);
        out[2] = (temp2 * temp0 + temp3 * temp1) * inv;
        out[3] = (temp3 * temp0 - temp2 * temp1) * inv;
        // Every block [i, 2i) is walked from both ends towards the middle,
        // two pairs per iteration. Each pair is fully read before it is
        // written, which keeps out == fht_ary1 safe.
        for (size_t i = 4; i < fht_len; i *= 2)
        {
            auto it0 = fht_ary1 + i, it1 = it0 + i - 1;
            auto it2 = fht_ary2 + i, it3 = it2 + i - 1;
            auto it4 = out + i, it5 = it4 + i - 1;
            for (; it0 < it1; it0 += 2, it1 -= 2, it2 += 2, it3 -= 2, it4 += 2, it5 -= 2)
            {
                temp0 = *it0, temp1 = *it1, temp2 = *it2, temp3 = *it3;
                transform_2point(temp0, temp1);
                *it4 = (temp2 * temp0 + temp3 * temp1) * inv;
                *it5 = (temp3 * temp0 - temp2 * temp1) * inv;
                temp0 = *(it1 - 1), temp1 = *(it0 + 1), temp2 = *(it3 - 1), temp3 = *(it2 + 1);
                transform_2point(temp0, temp1);
                *(it5 - 1) = (temp2 * temp0 + temp3 * temp1) * inv;
                *(it4 + 1) = (temp3 * temp0 - temp2 * temp1) * inv;
            }
        }
    }
    // The inverse transform must run for every length >= 2. Returning early
    // for fht_len == 2 left the result in the transform domain.
    fht_dit(out, fht_len);
}
} // closes namespace hint_fht (this brace shared the function's last line)
}
// Format the lowest `digits` decimal digits of input as a zero-padded string
// of exactly `digits` characters. Higher digits are dropped.
inline std::string ui64to_string(uint64_t input, uint8_t digits)
{
    std::string text(digits, '0');
    // Fill from the least significant digit (last character) backwards.
    for (auto it = text.rbegin(); it != text.rend(); ++it)
    {
        *it = static_cast<char>(input % 10 + '0');
        input /= 10;
    }
    return text;
}
// Parse the first `dig` characters of s as an unsigned decimal number.
// The characters are not validated.
constexpr uint64_t stoui64(const char *s, size_t dig = 4)
{
    uint64_t value = 0;
    for (const char *end = s + dig; s != end; ++s)
    {
        value = value * 10 + (*s - '0');
    }
    return value;
}
// Convert exactly four ASCII digits s[0..3] to their numeric value (0..9999).
// The characters are not validated.
constexpr uint32_t stobase10000(const char *s)
{
    return (((s[0] - '0') * 10 + (s[1] - '0')) * 10 + (s[2] - '0')) * 10 + (s[3] - '0');
}
// Convert exactly five ASCII digits s[0..4] to their numeric value (0..99999).
// The characters are not validated.
constexpr uint32_t stobase100000(const char *s)
{
    return ((((s[0] - '0') * 10 + (s[1] - '0')) * 10 + (s[2] - '0')) * 10 + (s[3] - '0')) * 10 + (s[4] - '0');
}
// Number of decimal digits packed into one limb when a digit string is split
// into chunks (see char_to_float64).
static constexpr int DIGIT = 4;
// Numeric base of one limb, 10^DIGIT; used for carry propagation on output.
constexpr uint64_t BASE = 10000;
// Split a decimal digit string into base-10000 limbs stored as doubles,
// least significant limb first.
//
// buffer:    str_len ASCII digits, most significant digit first.
// float_ary: output; receives ceil(str_len / DIGIT) values.
// Returns the number of limbs written.
inline size_t char_to_float64(const char *buffer, double *float_ary, size_t str_len)
{
    int64_t remaining = str_len;
    const int64_t limb_count = (remaining + DIGIT - 1) / DIGIT;
    double *dst = float_ary;
    // Take full 4-digit groups from the end of the string.
    for (; remaining - DIGIT > 0; remaining -= DIGIT)
    {
        const uint32_t limb = stobase10000(buffer + remaining - DIGIT);
        *dst = limb;
        ++dst;
    }
    // The leading group holds the remaining 1..DIGIT digits.
    if (remaining > 0)
    {
        const uint32_t limb = stoui64(buffer, remaining);
        *dst = limb;
    }
    return limb_count;
}
// Lookup table that converts a number in [0, 9999] to its four zero-padded
// ASCII digits, packed into one uint32_t. The most significant digit sits in
// the lowest byte, so the packed value reads as text in memory on a
// little-endian machine.
class ItoStrBase10000
{
private:
    uint32_t table[10000]{};

public:
    // Pack the four decimal digits of num (taken modulo 10000) into a uint32_t.
    static constexpr uint32_t itosbase10000(uint32_t num)
    {
        uint32_t res = '0' * 0x1010101;
        res += (num / 1000 % 10) | ((num / 100 % 10) << 8) |
               ((num / 10 % 10) << 16) | ((num % 10) << 24);
        return res;
    }

    // Precompute all 10000 entries (usable in constant expressions).
    constexpr ItoStrBase10000()
    {
        for (size_t i = 0; i < 10000; i++)
        {
            table[i] = itosbase10000(i);
        }
    }

    // Write the four digit characters of num to str[0..3]; num must be below
    // 10000. memcpy is used instead of a store through a casted pointer so
    // that str needs no particular alignment and no aliasing rule is broken.
    void tostr(char *str, uint32_t num) const
    {
        std::memcpy(str, &table[num], sizeof(uint32_t));
    }

    // Return the packed digits of num; num must be below 10000.
    uint32_t tostr(uint32_t num) const
    {
        return table[num];
    }
};
// Locate two digit strings inside the NUL-terminated text s.
//
// a / len_a: start and length of the first run of decimal digits.
// b / len_b: start of the second number; it extends to the end of the text
//            with trailing non-digit characters (newline, spaces) trimmed.
//
// Every scan stops at the terminating NUL, so a missing number yields a
// length of 0 instead of reading past the end of the buffer.
void read_2num_str(const char *s, const char *&a, size_t &len_a, const char *&b, size_t &len_b)
{
    // Locale-independent digit test that is safe for any char value.
    const auto is_digit = [](char c) { return c >= '0' && c <= '9'; };

    // Skip everything in front of the first number.
    while (*s != '\0' && !is_digit(*s))
    {
        s++;
    }
    a = s;
    // The first number is the maximal run of digits.
    while (is_digit(*s))
    {
        s++;
    }
    len_a = s - a;
    // Skip the separator(s) between the two numbers.
    while (*s != '\0' && !is_digit(*s))
    {
        s++;
    }
    b = s;
    len_b = strlen(b);
    // Trim trailing non-digits; the len_b > 0 test prevents indexing b[-1].
    while (len_b > 0 && !is_digit(b[len_b - 1]))
    {
        len_b--;
    }
}
}
using namespace std;
using namespace hint;
using namespace hint_simd;
using namespace hint_transform;
using namespace hint_fht;
// Reads two non-negative decimal integers from stdin and writes their product
// to stdout (no trailing newline, except for the "0" shortcut via puts).
//
// The numbers are split into base-10000 limbs, multiplied with an FHT-based
// convolution, and the result is carried and converted back to text. The
// input buffer `out` is reused for the output, which is built backwards from
// the end of the buffer, four characters per limb.
int main()
{
    constexpr size_t STR_LEN = 2000008;
    static constexpr ItoStrBase10000 transfer;
    static AlignAry<char, STR_LEN> out;
    static AlignAry<Float64, FHT_MAX_LEN> fht_ary1;
    static AlignAry<Float64, FHT_MAX_LEN> fht_ary2;
    uint32_t *ary = out.template cast_ptr<uint32_t>();
    size_t len_a = 0, len_b = 0;
    // Read at most STR_LEN - 1 bytes: the static buffer starts zero-filled, so
    // the text is then always NUL-terminated for the parser below.
    const size_t read_len = fread(out.data(), 1, STR_LEN - 1, stdin);
    if (read_len == 0)
    {
        return 0; // no input at all
    }
    const char *a, *b;
    read_2num_str(out.data(), a, len_a, b, len_b);
    /*
    struct stat sb;
    int fd = fileno(stdin);
    fstat(fd, &sb);
    p = (char *)mmap(0, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
    madvise(p, sb.st_size, MADV_SEQUENTIAL);
    */
    if (len_a == 0 || len_b == 0)
    {
        return 0; // fewer than two numbers in the input
    }
    // A literal "0" operand short-circuits the whole multiplication.
    if (len_a == 1 && a[0] == '0')
    {
        puts("0");
        return 0;
    }
    if (len_b == 1 && b[0] == '0')
    {
        puts("0");
        return 0;
    }
    size_t len2 = char_to_float64(b, fht_ary2.data(), len_b);
    size_t len1 = char_to_float64(a, fht_ary1.data(), len_a);
    size_t fht_len = int_ceil2(len1 + len2 - 1);
    fht_convolution(fht_ary1.data(), fht_ary2.data(), fht_ary1.data(), fht_len);
    // Round every coefficient, propagate carries in base 10000 and write the
    // limbs as text from the end of the buffer towards the front.
    uint64_t carry = 0;
    size_t pos = STR_LEN / 4 - 1;
    for (size_t i = 0; i < len1 + len2 - 1; i++)
    {
        carry += uint64_t(std::abs(fht_ary1[i]) + 0.5);
        uint64_t num = 0;
        std::tie(carry, num) = div_mod<uint64_t>(carry, BASE);
        ary[pos] = transfer.tostr(num);
        pos--;
    }
    ary[pos] = transfer.tostr(carry);
    pos *= 4;
    // Skip leading zeros but always keep the last digit, so an all-zero
    // product (e.g. "00" times anything) prints "0" instead of scanning past
    // the end of the buffer.
    const size_t text_end = STR_LEN / 4 * 4;
    while (pos + 1 < text_end && out[pos] == '0')
    {
        pos++;
    }
    fwrite(out.data() + pos, 1, text_end - pos, stdout);
}
// Judge result copied from the submission page (not part of the program):
// Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// Testcase #1 | 11.703 ms | 13 MB + 872 KB | Accepted | Score: 100 | Show more |