#include <algorithm>
#include <atomic>
#include <complex>
#include <future>
#include <iostream>
#include <random>
#include <stack>
#include <string>
#include <thread>
#include <tuple>
#include <cassert>
#include <climits>
#include <cmath>
#include <cstdlib>
#include <cstring>
#ifndef HINT_MATH_HPP
#define HINT_MATH_HPP
// #define MULTITHREAD
namespace hint
{
// Fixed-width integer aliases used throughout the library.
using UINT_8 = uint8_t;
using UINT_16 = uint16_t;
using UINT_32 = uint32_t;
using UINT_64 = uint64_t;
using INT_32 = int32_t;
using INT_64 = int64_t;
using ULONG = unsigned long;
using LONG = long;
// Double-precision complex number type used by the FFT routines.
using Complex = std::complex<double>;
// Bit widths of 8-, 16- and 32-bit integers.
constexpr UINT_64 HINT_CHAR_BIT = 8;
constexpr UINT_64 HINT_SHORT_BIT = 16;
constexpr UINT_64 HINT_INT_BIT = 32;
// Limits and masks, all stored as 64-bit values. Naming scheme:
//   0XFF = every bit of the width set (maximum unsigned value)
//   0X10 = maximum unsigned value + 1, i.e. 2^width
//   0X80 = only the top bit of the width set
//   0X7F = every bit except the top one (maximum signed value)
constexpr UINT_64 HINT_INT8_0XFF = UINT8_MAX;
constexpr UINT_64 HINT_INT8_0X10 = (UINT8_MAX + 1ull);
constexpr UINT_64 HINT_INT16_0XFF = UINT16_MAX;
constexpr UINT_64 HINT_INT16_0X10 = (UINT16_MAX + 1ull);
constexpr UINT_64 HINT_INT32_0XFF = UINT32_MAX;
constexpr UINT_64 HINT_INT32_0X01 = 1;
constexpr UINT_64 HINT_INT32_0X80 = 0X80000000ull;
constexpr UINT_64 HINT_INT32_0X7F = INT32_MAX;
constexpr UINT_64 HINT_INT32_0X10 = (UINT32_MAX + 1ull);
// INT64_MIN converted to unsigned yields 2^63 (top bit only).
constexpr UINT_64 HINT_INT64_0X80 = INT64_MIN;
constexpr UINT_64 HINT_INT64_0X7F = INT64_MAX;
constexpr UINT_64 HINT_INT64_0XFF = UINT64_MAX;
// Pi, 2*pi and sqrt(2)/2.
constexpr double HINT_PI = 3.1415926535897932384626433832795;
constexpr double HINT_2PI = HINT_PI * 2;
constexpr double HINT_HSQ_ROOT2 = 0.70710678118654752440084436210485;
// NTT moduli with their roots: 3221225473 = 3 * 2^30 + 1 and 3489660929 = 13 * 2^28 + 1.
// NOTE(review): the ROOT values are presumably primitive roots of the respective modulus - confirm.
constexpr UINT_64 NTT_MOD1 = 3221225473;
constexpr UINT_64 NTT_ROOT1 = 5;
constexpr UINT_64 NTT_MOD2 = 3489660929;
constexpr UINT_64 NTT_ROOT2 = 3;
// Upper bound on NTT length (2^28); not referenced in the visible part of the file.
constexpr size_t NTT_MAX_LEN = 1ull << 28;
#ifdef MULTITHREAD
// Hardware thread count as reported by the standard library, and ceil(log2) of it.
const UINT_32 hint_threads = std::thread::hardware_concurrency();
const UINT_32 log2_threads = std::ceil(std::log2(hint_threads));
// NOTE(review): presumably counts the worker threads currently running - confirm against its users.
std::atomic<UINT_32> cur_ths;
#endif
/// @brief Hartley kernel cas(x) = cos(x) + sin(x).
/// @param x Angle in radians.
/// @return cos(x) + sin(x).
/// @note Declared inline because this definition lives in a header; without it
///       every translation unit including the header emits its own external
///       definition and the program fails to link.
inline double cas(double x)
{
    return std::cos(x) + std::sin(x);
}
/// @brief Largest power of two that does not exceed n.
/// @param n Input value.
/// @return The largest 2^k with 2^k <= n, or 0 when n is 0.
inline UINT_64 max_2pow(UINT_64 n)
{
    // Integer-only computation. The previous std::log2 based version converted
    // -infinity to an integer for n == 0 (undefined behaviour) and, because a
    // double has only 53 mantissa bits, rounded values such as 2^60 - 1 up to
    // the next power of two.
    UINT_64 result = (n != 0) ? 1 : 0;
    while (n > 1)
    {
        n >>= 1;
        result <<= 1;
    }
    return result;
}
/// @brief Smallest power of two that is not less than n.
/// @param n Input value.
/// @return The smallest 2^k with 2^k >= n (1 for n == 0 or n == 1).
///         Returns 0 when n > 2^63, because the answer does not fit in 64 bits.
inline UINT_64 min_2pow(UINT_64 n)
{
    // Integer-only computation. The previous std::log2 based version was
    // undefined for n == 0 and lost precision above 2^53 (e.g. 2^60 + 1 was
    // rounded down to 2^60 instead of up to 2^61).
    UINT_64 result = 1;
    while (result != 0 && result < n)
    {
        result <<= 1; // becomes 0 once the top bit is shifted out
    }
    return result;
}
/// @brief Tells whether a value is strictly below zero.
/// @param x Value to inspect.
/// @return true when x < 0, false otherwise.
template <typename T>
constexpr bool is_neg(T x)
{
    const bool negative = (x < 0);
    return negative;
}
/// @brief Tells whether the lowest bit of an integer value is set.
/// @param x Integer value to inspect.
/// @return true when (x & 1) is non-zero.
template <typename T>
constexpr bool is_odd(T x)
{
    return (x & 1) ? true : false;
}
/// @brief Generic exponentiation by squaring.
/// @param m Base; T must be constructible from 1 and support operator*.
/// @param n Non-negative exponent.
/// @return m raised to the n-th power (1 when n == 0).
template <typename T>
constexpr T qpow(T m, UINT_64 n)
{
    T acc = 1;
    for (; n != 0; n >>= 1)
    {
        // Multiply the current square into the result for every set exponent bit.
        if (n & 1)
        {
            acc = acc * m;
        }
        m = m * m;
    }
    return acc;
}
/// @brief Modular exponentiation by squaring.
/// @param m   Base.
/// @param n   Exponent.
/// @param mod Modulus; must be non-zero. Intermediate products are formed in
///            64 bits, so results are only exact for mod <= 2^32.
/// @return m^n modulo mod. As before, a base of 0 or 1 is returned unchanged
///         without looking at n or mod.
constexpr UINT_64 qpow(UINT_64 m, UINT_64 n, UINT_64 mod)
{
    if (m <= 1)
    {
        return m;
    }
    // Reduce the base first: an unreduced base >= 2^32 made the first m * m
    // wrap around 2^64 and produced a wrong power.
    m %= mod;
    UINT_64 result = 1;
    while (n > 0)
    {
        if ((n & 1) != 0)
        {
            result = result * m % mod;
        }
        m = m * m % mod;
        n >>= 1;
    }
    return result;
}
/// @brief Computes quotient and remainder together so the compiler can share
///        a single division.
/// @param a Dividend.
/// @param b Divisor (must not be zero).
/// @return Pair {a / b, a % b}.
template <typename T>
constexpr std::pair<T, T> div_mod(T a, T b)
{
    const T quotient = a / b;
    const T remainder = a % b;
    return std::pair<T, T>(quotient, remainder);
}
/// @brief Overflow-free product of a 64-bit and a 32-bit unsigned integer.
/// @param a 64-bit factor.
/// @param b 32-bit factor.
/// @return Pair {upper 32 bits, lower 64 bits} of the 96-bit product.
constexpr std::pair<UINT_32, UINT_64> safe_mul(UINT_64 a, UINT_32 b)
{
    // Multiply the two 32-bit halves of a separately; each partial product fits in 64 bits.
    const UINT_64 low_product = (a & HINT_INT32_0XFF) * b;
    const UINT_64 high_product = (a >> HINT_INT_BIT) * b + (low_product >> HINT_INT_BIT);
    return std::make_pair(high_product >> HINT_INT_BIT,
                          (high_product << HINT_INT_BIT) + (low_product & HINT_INT32_0XFF));
}
/// @brief 64-bit unsigned addition that reports the carry.
/// @param a First addend.
/// @param b Second addend.
/// @return Pair {sum modulo 2^64, true if the addition carried out of 64 bits}.
constexpr std::pair<UINT_64, bool> safe_add(UINT_64 a, UINT_64 b)
{
    // Unsigned addition wraps modulo 2^64; a wrapped sum is smaller than either operand.
    const UINT_64 wrapped_sum = a + b;
    return std::make_pair(wrapped_sum, wrapped_sum < a);
}
/// @brief 64-bit unsigned subtraction that reports the borrow.
/// @param a Minuend.
/// @param b Subtrahend.
/// @return Pair {a - b modulo 2^64, true if a < b (a borrow was needed)}.
constexpr std::pair<UINT_64, bool> safe_sub(UINT_64 a, UINT_64 b)
{
    // Unsigned subtraction wraps modulo 2^64, which is exactly the borrowed result.
    const UINT_64 wrapped_diff = a - b;
    return std::make_pair(wrapped_diff, a < b);
}
/// @brief Divides a 96-bit value by a 64-bit value using two 64-bit divisions.
/// @param dividend {upper 64 bits, lower 32 bits} of the 96-bit dividend.
/// @param divisor  Divisor (must not be zero).
/// @return The quotient truncated to 64 bits.
/// @note The shifts by 32 are plain 64-bit operations, so the partial quotient
///       and partial remainder are assumed to fit in 32 bits.
constexpr UINT_64 div_3by2(std::pair<UINT_64, UINT_32> dividend, UINT_64 divisor)
{
    const UINT_64 high_quo = dividend.first / divisor;
    const UINT_64 high_rem = dividend.first % divisor;
    // Bring down the low 32 bits next to the remainder of the high part.
    const UINT_64 low_part = (high_rem << 32) + dividend.second;
    return (high_quo << 32) + low_part / divisor;
}
/// @brief Greatest common divisor by the Euclidean algorithm.
/// @param a First value.
/// @param b Second value.
/// @return gcd(a, b); returns a when b is 0.
constexpr UINT_64 gcd(UINT_64 a, UINT_64 b)
{
    while (b != 0)
    {
        const UINT_64 remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
/// @brief Chinese remainder theorem for n moduli.
/// @param mods Moduli. They must be prime (the inverse is computed as
///             x^(mod-2)), pairwise distinct, each below 2^32, and their
///             product must fit in 64 bits.
/// @param nums Residues; nums[i] is the wanted remainder modulo mods[i].
/// @param n    Number of moduli/residues.
/// @return The solution x with 0 <= x < product(mods); 0 when n == 0.
/// @note Declared inline because the definition lives in a header.
inline UINT_64 crt(UINT_64 mods[], UINT_64 nums[], size_t n)
{
    UINT_64 mod_product = 1;
    for (size_t i = 0; i < n; i++)
    {
        mod_product *= mods[i];
    }
    UINT_64 result = 0;
    for (size_t i = 0; i < n; i++)
    {
        const UINT_64 mod = mods[i];
        const UINT_64 cofactor = mod_product / mod;
        // Reduce the cofactor before exponentiation so qpow never squares a value >= 2^32.
        const UINT_64 inv = qpow(cofactor % mod, mod - 2, mod);
        // (nums[i] * inv mod m_i) * cofactor is congruent to nums[i] * cofactor * inv
        // modulo the full product, but stays below mod_product, so it cannot
        // overflow the way the former triple product did.
        const UINT_64 term = (nums[i] % mod) * inv % mod * cofactor;
        // Modular accumulation without ever forming result + term when it would wrap.
        const UINT_64 room = mod_product - term;
        result = (result >= room) ? result - room : result + term;
    }
    return result;
}
/// @brief Fast two-modulus Chinese remainder theorem.
/// @param num1 Remainder of n modulo mod1.
/// @param num2 Remainder of n modulo mod2.
/// @param mod1 First modulus.
/// @param mod2 Second modulus.
/// @param inv1 Inverse of mod1 modulo mod2.
/// @param inv2 Inverse of mod2 modulo mod1.
/// @return The smallest n that satisfies both congruences.
constexpr UINT_64 qcrt(UINT_64 num1, UINT_64 num2,
                       UINT_64 mod1, UINT_64 mod2,
                       UINT_64 inv1, UINT_64 inv2)
{
    // Always subtract the smaller residue from the larger one so the
    // difference never goes below zero.
    return (num1 > num2)
               ? ((num1 - num2) * inv2 % mod1) * mod2 + num2
               : ((num2 - num1) * inv1 % mod2) * mod1 + num1;
}
/// @brief Copies len elements from source to target with memcpy.
/// @param target Destination buffer.
/// @param source Source buffer; must not partially overlap target.
/// @param len    Number of elements to copy.
/// @throws const char* when len is at least INT64_MAX.
template <typename T>
void ary_copy(T *target, const T *source, size_t len)
{
    const bool nothing_to_do = (len == 0) || (target == source);
    if (!nothing_to_do)
    {
        if (len >= INT64_MAX)
        {
            throw("Ary too long\n");
        }
        std::memcpy(target, source, len * sizeof(T));
    }
}
/// @brief Element-wise copy between arrays of different element types.
/// @param target Destination array of T1.
/// @param source Source array of T2; each element is converted with static_cast<T1>.
/// @param len    Number of elements to copy.
template <typename T1, typename T2>
void ary_copy_2type(T1 *target, const T2 *source, size_t len)
{
    const T2 *const source_end = source + len;
    while (source != source_end)
    {
        *target = static_cast<T1>(*source);
        ++target;
        ++source;
    }
}
/// @brief Copies an indexable sequence into the real parts of a complex array.
/// @param target Complex array; imaginary parts are left unchanged.
/// @param source Any object supporting source[i] with values convertible to double.
/// @param len    Number of elements to copy.
template <typename T>
inline void com_ary_real_copy(Complex *target, const T &source, size_t len)
{
    size_t i = 0;
    while (i < len)
    {
        target[i].real(source[i]);
        ++i;
    }
}
/// @brief Copies an indexable sequence into the imaginary parts of a complex array.
/// @param target Complex array; real parts are left unchanged.
/// @param source Any object supporting source[i] with values convertible to double.
/// @param len    Number of elements to copy.
template <typename T>
inline void com_ary_img_copy(Complex *target, const T &source, size_t len)
{
    size_t i = 0;
    while (i < len)
    {
        target[i].imag(source[i]);
        ++i;
    }
}
/// @brief Packs two sequences into one complex array: source1 becomes the
///        real parts and source2 the imaginary parts.
/// @param target  Complex array receiving the data.
/// @param source1 Sequence written to the real parts.
/// @param len1    Length of source1.
/// @param source2 Sequence written to the imaginary parts.
/// @param len2    Length of source2.
/// @note Beyond the shorter length only the component supplied by the longer
///       sequence is written; the other component of target keeps its value.
template <typename T>
inline void com_ary_combine_copy(Complex *target, const T &source1, size_t len1, const T &source2, size_t len2)
{
    const size_t common = std::min(len1, len2);
    for (size_t i = 0; i < common; i++)
    {
        target[i] = Complex(source1[i], source2[i]);
    }
    // At most one of the two tail loops below executes.
    for (size_t i = common; i < len1; i++)
    {
        target[i].real(source1[i]);
    }
    for (size_t i = common; i < len2; i++)
    {
        target[i].imag(source2[i]);
    }
}
/// @brief Length of an array once the zero elements at its high-index end
///        (the leading zeros of a little-endian number) are dropped.
/// @param ary Indexable sequence.
/// @param len Current length.
/// @return Index of the last non-zero element plus one, or 0 if all are zero.
template <typename T>
constexpr size_t ary_true_len(const T &ary, size_t len)
{
    for (; len > 0; --len)
    {
        if (!(ary[len - 1] == 0))
        {
            break;
        }
    }
    return len;
}
/// @brief Allocates a zero-filled array with calloc.
/// @param ptr Receives the new block (nullptr if calloc fails); any previous
///            value is overwritten without being freed.
/// @param len Number of elements to allocate.
template <typename T>
inline void ary_calloc(T *&ptr, size_t len)
{
    void *const raw = calloc(len, sizeof(T));
    ptr = static_cast<T *>(raw);
}
/// @brief Sets every byte of an array to zero.
/// @param ptr Array to clear.
/// @param len Number of elements (not bytes).
template <typename T>
inline void ary_clr(T *ptr, size_t len)
{
    const size_t byte_count = len * sizeof(T);
    memset(ptr, 0, byte_count);
}
/// @brief Resizes an array obtained from calloc/realloc.
/// @param ptr Block to resize (may be nullptr).
/// @param len New number of elements.
/// @return Pointer to the resized block.
/// @throws const char* ("realloc error") when the request is too large or
///         realloc fails. The caller's original block is untouched in both cases.
template <typename T>
inline T *ary_realloc(T *ptr, size_t len)
{
    // Compare by division so len * sizeof(T) cannot wrap. An oversized request
    // used to skip realloc and hand back the old, too-small block as if the
    // resize had succeeded.
    if (len > INT64_MAX / sizeof(T))
    {
        throw("realloc error");
    }
    T *resized = static_cast<T *>(realloc(ptr, len * sizeof(T)));
    if (resized == nullptr)
    {
        throw("realloc error");
    }
    return resized;
}
/// @brief De-interleaves an array of N interleaved sequences in place:
///        afterwards all elements with index % N == 0 come first, then those
///        with index % N == 1, and so on.
/// @tparam N  Number of interleaved sequences.
/// @param ary Array to rearrange.
/// @param len Array length; the indexing assumes it is a multiple of N.
template <UINT_64 N, typename T>
void ary_interlace(T ary[], size_t len)
{
    const size_t group_count = len / N;
    const size_t rest_len = len - group_count;
    // Lane 0 is compacted in place; lanes 1..N-1 are staged in a scratch buffer.
    T *rest = new T[rest_len];
    for (size_t src = 0, group = 0; src < len; src += N, group++)
    {
        ary[group] = ary[src];
        for (size_t lane = 1; lane < N; lane++)
        {
            rest[(lane - 1) * group_count + group] = ary[src + lane];
        }
    }
    ary_copy(ary + group_count, rest, rest_len);
    delete[] rest;
}
namespace hint_transform
{
/// Lookup table of unit roots in the first quadrant.
///
/// Level i (i >= 1) holds the 2^i values exp(j * pos * pi / 2^(i + 1)),
/// pos = 0 .. 2^i - 1, starting at offset 2^i - 2. It serves lookups with
/// shift == i + 2, i.e. a circle divided into 1 << shift parts. Only the level
/// for shift 3 is filled by the constructor; higher levels are filled on
/// demand by expend().
class UnitTable
{
private:
    Complex *table = nullptr; // owned storage, released by the destructor
    INT_32 max_log_size = 3;  // largest shift the storage has room for
    INT_32 cur_log_size = 3;  // largest shift whose level has been filled
    UnitTable(const UnitTable &) = delete;
    UnitTable &operator=(const UnitTable &) = delete;

public:
    ~UnitTable()
    {
        delete[] table; // deleting a null pointer is a no-op
        table = nullptr;
    }
    /// Allocates room for all levels up to max_shift (at least 3) and fills
    /// the level for shift 3.
    UnitTable(UINT_32 max_shift)
    {
        max_shift = std::max<size_t>(max_shift, 3);
        max_log_size = max_shift;
        size_t ary_size = (1ull << (max_shift - 1)) - 2;
        table = new Complex[ary_size];
        for (size_t pos = 0; pos < 2; pos++)
        {
            table[pos] = unit_root(pos * HINT_PI / 4);
        }
    }
    /// Fills every level needed for lookups up to the given shift.
    /// The request is raised to at least 4 and then capped at the capacity
    /// chosen at construction.
    void expend(INT_32 shift)
    {
        // Raise first, cap last: with the opposite order a table built with
        // max_shift == 3 was pushed to shift 4 and written past its storage.
        shift = std::max(shift, 4);
        shift = std::min(shift, max_log_size);
        for (INT_32 i = cur_log_size - 1; i <= shift - 2; i++)
        {
            size_t len = 1ull << i;
            size_t begin = len - 2;
            for (size_t pos = 0; pos < len; pos++)
            {
                table[pos + begin] = unit_root(pos * HINT_PI / (len * 2));
            }
        }
        cur_log_size = std::max(cur_log_size, shift);
    }
    /// Returns the point on the unit circle with argument theta.
    static Complex unit_root(double theta)
    {
        return std::polar<double>(1.0, theta);
    }
    /// Returns the n-th root of the circle divided into 1 << shift parts.
    /// n is reduced modulo 1 << shift. Roots on the axes are returned exactly;
    /// all others are read from the level for this shift and mirrored into the
    /// right quadrant, so that level must already have been filled (see expend).
    Complex get_complex(INT_32 shift, size_t n)
    {
        size_t rank = 1ull << shift;
        n &= (rank - 1);
        size_t zone = (n << 2) >> shift; // quadrant index 0..3
        if (((n << 2) & (rank - 1)) == 0)
        {
            // n is a multiple of a quarter turn: 1, i, -1 or -i.
            constexpr Complex ary[4] = {Complex(1, 0), Complex(0, 1), Complex(-1, 0), Complex(0, -1)};
            return ary[zone];
        }
        rank >>= 2; // number of entries per quadrant
        const Complex *ptr = table + rank - 2;
        Complex tmp;
        switch (zone)
        {
        case 0:
            tmp = ptr[n];
            break;
        case 1:
            // Mirror about the imaginary axis.
            tmp = ptr[(rank << 1) - n];
            tmp.real(-tmp.real());
            break;
        case 2:
            // Half-turn rotation.
            tmp = -ptr[n - (rank << 1)];
            break;
        case 3:
            // Mirror about the real axis.
            tmp = std::conj(ptr[(rank << 2) - n]);
            break;
        default:
            break;
        }
        return tmp;
    }
};
// Largest log2(length) served by the shared lookup table; the *_lut transforms
// reject lengths above 1 << lut_max_rank.
constexpr size_t lut_max_rank = 10;
UnitTable TABLE(lut_max_rank); // shared unit-root table used by the FFT/FHT routines
/// @brief Bit-reversal permutation, performed in place.
/// @param ary Indexable sequence whose elements can be swapped.
/// @param len Number of elements; expected to be a power of two.
///            Lengths 0 and 1 leave the sequence untouched.
template <typename T>
void binary_inverse_swap(T &ary, size_t len)
{
    // Without this guard len == 0 made the unsigned bound len - 1 wrap around
    // and the loop ran for 2^64 iterations.
    if (len < 2)
    {
        return;
    }
    size_t i = 0; // bit-reversed counterpart of j, updated incrementally
    for (size_t j = 1; j < len - 1; j++)
    {
        // Add one to i "from the top": toggle bits downwards until a bit
        // flips from 0 to 1.
        size_t k = len >> 1;
        i ^= k;
        while (k > i)
        {
            k >>= 1;
            i ^= k;
        }
        // Swap each pair only once.
        if (j < i)
        {
            std::swap(ary[i], ary[j]);
        }
    }
}
/// @brief Base-4 digit-reversal permutation, performed in place.
/// @tparam BATCH    Number of consecutive elements that move together as one item.
/// @tparam SizeType Integer type used for indices.
/// @param ary Indexable sequence holding len * BATCH elements.
/// @param len Number of items; expected to be a power of four.
///            Lengths below 4 leave the sequence untouched.
template <size_t BATCH = 1, typename SizeType = UINT_32, typename T>
void quaternary_inverse_swap(T &ary, size_t len)
{
    // For len < 4 the index table would have zero elements (so rev[0] wrote
    // out of bounds) and log_n - 2 wrapped around; there is nothing to permute.
    if (len < 4)
    {
        return;
    }
    SizeType log_n = std::log2(len);
    SizeType *rev = new SizeType[len / 4];
    rev[0] = 0;
    for (SizeType i = 1; i < len; i++)
    {
        // Reversed index of i: the reversal of i / 4 moved down one digit,
        // with the lowest digit of i placed on top.
        SizeType index = (rev[i >> 2] >> 2) | ((i & 3) << (log_n - 2));
        // Only the first quarter is ever looked up again, so only it is stored.
        if (i < len / 4)
        {
            rev[i] = index;
        }
        if (i < index)
        {
            for (size_t j = 0; j < BATCH; j++)
            {
                std::swap(ary[i * BATCH + j], ary[index * BATCH + j]);
            }
        }
    }
    delete[] rev;
}
/// @brief Radix-2 decimation-in-time butterfly: the twiddle factor is applied
///        to the second element before the sum/difference.
/// @param omega Twiddle factor.
/// @param input Data array.
/// @param pos   Index of the first element.
/// @param rank  Distance to the second element.
inline void fft_radix2_dit_butterfly(Complex omega, Complex *input, size_t pos, size_t rank)
{
    Complex &upper = input[pos];
    Complex &lower = input[pos + rank];
    const Complex even = upper;
    const Complex twiddled = lower * omega;
    upper = even + twiddled;
    lower = even - twiddled;
}
/// @brief Radix-2 decimation-in-frequency butterfly: the twiddle factor is
///        applied to the difference after the sum/difference.
/// @param omega Twiddle factor.
/// @param input Data array.
/// @param pos   Index of the first element.
/// @param rank  Distance to the second element.
inline void fft_radix2_dif_butterfly(Complex omega, Complex *input, size_t pos, size_t rank)
{
    Complex &upper = input[pos];
    Complex &lower = input[pos + rank];
    const Complex first = upper;
    const Complex second = lower;
    upper = first + second;
    lower = (first - second) * omega;
}
/// @brief Radix-4 decimation-in-time butterfly on four elements spaced rank
///        apart; elements 1..3 are multiplied by their twiddles first.
/// @param omega      Twiddle for the second element.
/// @param omega_sqr  Twiddle for the third element.
/// @param omega_cube Twiddle for the fourth element.
/// @param input      Data array.
/// @param pos        Index of the first element.
/// @param rank       Spacing between the four elements.
inline void fft_radix4_dit_butterfly(Complex omega, Complex omega_sqr, Complex omega_cube,
                                     Complex *input, size_t pos, size_t rank)
{
    Complex &x0 = input[pos];
    Complex &x1 = input[pos + rank];
    Complex &x2 = input[pos + rank * 2];
    Complex &x3 = input[pos + rank * 3];
    const Complex a0 = x0;
    const Complex a1 = x1 * omega;
    const Complex a2 = x2 * omega_sqr;
    const Complex a3 = x3 * omega_cube;
    const Complex sum02 = a0 + a2;
    const Complex sum13 = a1 + a3;
    const Complex diff02 = a0 - a2;
    const Complex diff13 = a1 - a3;
    // Multiply the odd difference by i: (re, im) -> (-im, re).
    const Complex rotated(-diff13.imag(), diff13.real());
    x0 = sum02 + sum13;
    x1 = diff02 - rotated;
    x2 = sum02 - sum13;
    x3 = diff02 + rotated;
}
/// @brief Radix-4 decimation-in-frequency butterfly on four elements spaced
///        rank apart; outputs 1..3 are multiplied by their twiddles last.
/// @param omega      Twiddle for the second output.
/// @param omega_sqr  Twiddle for the third output.
/// @param omega_cube Twiddle for the fourth output.
/// @param input      Data array.
/// @param pos        Index of the first element.
/// @param rank       Spacing between the four elements.
inline void fft_radix4_dif_butterfly(Complex omega, Complex omega_sqr, Complex omega_cube,
                                     Complex *input, size_t pos, size_t rank)
{
    Complex &x0 = input[pos];
    Complex &x1 = input[pos + rank];
    Complex &x2 = input[pos + rank * 2];
    Complex &x3 = input[pos + rank * 3];
    const Complex a0 = x0;
    const Complex a1 = x1;
    const Complex a2 = x2;
    const Complex a3 = x3;
    const Complex sum02 = a0 + a2;
    const Complex sum13 = a1 + a3;
    const Complex diff02 = a0 - a2;
    const Complex diff13 = a1 - a3;
    // Multiply the odd difference by i: (re, im) -> (-im, re).
    const Complex rotated(-diff13.imag(), diff13.real());
    x0 = (sum02 + sum13);
    x1 = (diff02 - rotated) * omega;
    x2 = (sum02 - sum13) * omega_sqr;
    x3 = (diff02 + rotated) * omega_cube;
}
/// @brief Two-point FFT: replaces the pair by its sum and its difference.
/// @param sum  In: first value. Out: first + second.
/// @param diff In: second value. Out: first - second.
inline void fft_2point(Complex &sum, Complex &diff)
{
    const Complex difference = sum - diff;
    sum += diff;
    diff = difference;
}
/// @brief Four-point FFT without twiddle factors on elements spaced rank apart.
/// @param input Data array.
/// @param pos   Index of the first element.
/// @param rank  Spacing between the four elements.
inline void fft_4point(Complex *input, size_t pos, size_t rank)
{
    Complex &x0 = input[pos];
    Complex &x1 = input[pos + rank];
    Complex &x2 = input[pos + rank * 2];
    Complex &x3 = input[pos + rank * 3];
    const Complex a0 = x0;
    const Complex a1 = x1;
    const Complex a2 = x2;
    const Complex a3 = x3;
    const Complex sum02 = a0 + a2;
    const Complex sum13 = a1 + a3;
    const Complex diff02 = a0 - a2;
    const Complex diff13 = a1 - a3;
    // Multiply the odd difference by i: (re, im) -> (-im, re).
    const Complex rotated(-diff13.imag(), diff13.real());
    x0 = sum02 + sum13;
    x1 = diff02 - rotated;
    x2 = sum02 - sum13;
    x3 = diff02 + rotated;
}
/// @brief Conjugates every element and divides it by div; used by the inverse transform.
/// @param input   Data array.
/// @param fft_len Number of elements.
/// @param div     Divisor applied after conjugation (1 leaves the magnitude unchanged).
inline void fft_conj(Complex *input, size_t fft_len, double div = 1)
{
    Complex *const end = input + fft_len;
    for (Complex *it = input; it != end; ++it)
    {
        *it = std::conj(*it) / div;
    }
}
/// @brief Divides every element by the transform length; used by the inverse transform.
/// @param input   Data array.
/// @param fft_len Number of elements, also used as the divisor.
inline void fft_normalize(Complex *input, size_t fft_len)
{
    const double scale = static_cast<double>(fft_len);
    Complex *const end = input + fft_len;
    for (Complex *it = input; it != end; ++it)
    {
        *it /= scale;
    }
}
/// @brief Textbook radix-2 decimation-in-time FFT (reference implementation).
/// @param input   Complex array transformed in place.
/// @param fft_len Requested length; rounded down to a power of two.
///                Lengths 0 and 1 are a no-op.
void fft_radix2_dit(Complex *input, size_t fft_len)
{
    // Nothing to transform; also keeps a zero length away from the rounding
    // and bit-reversal helpers, which are not defined for it.
    if (fft_len <= 1)
    {
        return;
    }
    fft_len = max_2pow(fft_len);
    binary_inverse_swap(input, fft_len);
    for (size_t rank = 1; rank < fft_len; rank *= 2)
    {
        // rank is the length of the sub-transforms of the previous stage,
        // gap the length produced by combining two of them.
        size_t gap = rank * 2;
        Complex unit_omega = std::polar<double>(1, -HINT_2PI / gap);
        for (size_t begin = 0; begin < fft_len; begin += gap)
        {
            Complex omega = Complex(1, 0);
            for (size_t pos = begin; pos < begin + rank; pos++)
            {
                Complex tmp1 = input[pos];
                Complex tmp2 = input[pos + rank] * omega;
                input[pos] = tmp1 + tmp2;
                input[pos + rank] = tmp1 - tmp2;
                omega *= unit_omega;
            }
        }
    }
}
/// @brief Textbook radix-4 decimation-in-time FFT (reference implementation).
/// @param input   Complex array transformed in place.
/// @param fft_len Requested length; rounded down to a power of four.
///                Lengths below 4 are a no-op.
void fft_radix4_dit(Complex *input, size_t fft_len)
{
    // Lengths 1..3 round down to a single point, whose transform is the
    // identity. Without this guard the 4-point pass below still touched
    // input[0..3] and the digit-reversal helper was called with length 1.
    if (fft_len < 4)
    {
        return;
    }
    size_t log4_len = std::log2(fft_len) / 2;
    fft_len = 1ull << (log4_len * 2);
    quaternary_inverse_swap(input, fft_len);
    for (size_t pos = 0; pos < fft_len; pos += 4)
    {
        fft_4point(input, pos, 1);
    }
    for (size_t rank = 4; rank < fft_len; rank *= 4)
    {
        // rank is the length of the sub-transforms of the previous stage,
        // gap the length produced by combining four of them.
        size_t gap = rank * 4;
        Complex unit_omega = std::polar<double>(1, -HINT_2PI / gap);
        Complex unit_sqr = std::polar<double>(1, -HINT_2PI * 2 / gap);
        Complex unit_cube = std::polar<double>(1, -HINT_2PI * 3 / gap);
        for (size_t begin = 0; begin < fft_len; begin += gap)
        {
            // The first butterfly of each block has all twiddles equal to 1.
            fft_4point(input, begin, rank);
            Complex omega = unit_omega;
            Complex omega_sqr = unit_sqr;
            Complex omega_cube = unit_cube;
            for (size_t pos = begin + 1; pos < begin + rank; pos++)
            {
                fft_radix4_dit_butterfly(omega, omega_sqr, omega_cube, input, pos, rank);
                omega *= unit_omega;
                omega_sqr *= unit_sqr;
                omega_cube *= unit_cube;
            }
        }
    }
}
/// @brief Radix-2 decimation-in-time FFT that reads its twiddles from TABLE.
/// @param input   Complex array transformed in place.
/// @param fft_len Requested length; rounded down to a power of two.
/// @param bit_inv When true the input is bit-reversed first.
/// @throws const char* when the rounded length exceeds 1 << lut_max_rank.
/// @note TABLE is not expanded here.
///       NOTE(review): presumably the caller must have filled TABLE up to log2(fft_len) - confirm.
void fft_radix2_dit_lut(Complex *input, size_t fft_len, bool bit_inv = true)
{
    fft_len = max_2pow(fft_len);
    if (fft_len > (1ull << lut_max_rank))
    {
        throw("FFT length too long for lut\n");
    }
    if (fft_len <= 1)
    {
        return;
    }
    if (fft_len == 2)
    {
        fft_2point(input[0], input[1]);
        return;
    }
    if (bit_inv)
    {
        binary_inverse_swap(input, fft_len);
    }
    // Stage with rank 1: plain 2-point transforms.
    for (size_t i = 0; i < fft_len; i += 2)
    {
        fft_2point(input[i], input[i + 1]);
    }
    INT_32 shift = 2; // log2 of the block length (gap) of the current stage
    size_t rank = 2;
    // Early stages: two neighbouring blocks share one twiddle lookup.
    for (; rank < fft_len / 4; rank *= 2)
    {
        size_t gap = rank * 2;
        for (size_t begin = 0; begin < fft_len; begin += (gap * 2))
        {
            fft_2point(input[begin], input[begin + rank]);
            fft_2point(input[gap + begin], input[gap + begin + rank]);
            for (size_t pos = begin + 1; pos < begin + rank; pos++)
            {
                Complex omega = std::conj(TABLE.get_complex(shift, pos - begin));
                fft_radix2_dit_butterfly(omega, input, pos, rank);
                fft_radix2_dit_butterfly(omega, input, pos + gap, rank);
            }
        }
        shift++;
    }
    // Remaining stages continue from the rank reached above. For fft_len == 4
    // that is rank 2; starting again from fft_len / 4 == 1 repeated the rank-1
    // stage and left shift one level ahead of the stage being processed.
    for (; rank < fft_len; rank *= 2)
    {
        size_t gap = rank * 2;
        for (size_t begin = 0; begin < fft_len; begin += gap)
        {
            fft_2point(input[begin], input[begin + rank]);
            for (size_t pos = begin + 1; pos < begin + rank; pos++)
            {
                Complex omega = std::conj(TABLE.get_complex(shift, pos - begin));
                fft_radix2_dit_butterfly(omega, input, pos, rank);
            }
        }
        shift++;
    }
}
/// @brief Radix-2 decimation-in-frequency FFT that reads its twiddles from TABLE.
/// @param input   Complex array transformed in place.
/// @param fft_len Requested length; rounded down to a power of two by max_2pow.
/// @param bit_inv When true the output is bit-reversed at the end.
/// @throws const char* when the rounded length exceeds 1 << lut_max_rank.
/// @note TABLE is not expanded here.
///       NOTE(review): presumably the caller must have filled TABLE up to log2(fft_len) - confirm.
void fft_radix2_dif_lut(Complex *input, size_t fft_len, bool bit_inv = true)
{
fft_len = max_2pow(fft_len);
if (fft_len > (1ull << lut_max_rank))
{
throw("FFT length too long for lut\n");
}
if (fft_len <= 1)
{
return;
}
if (fft_len == 2)
{
fft_2point(input[0], input[1]);
return;
}
// shift is log2 of the block length (gap) of the current stage.
INT_32 shift = std::log2(fft_len);
// First stage only (rank == fft_len / 2): one block spanning the whole array.
for (size_t rank = fft_len / 2, gap; rank > fft_len / 4; rank /= 2)
{
gap = rank * 2;
for (size_t begin = 0; begin < fft_len; begin += gap)
{
for (size_t pos = begin; pos < begin + rank; pos++)
{
Complex omega = std::conj(TABLE.get_complex(shift, pos - begin));
fft_radix2_dif_butterfly(omega, input, pos, rank);
}
}
shift--;
}
// Stages from rank fft_len / 4 down to rank 2: two neighbouring blocks share
// one twiddle lookup.
for (size_t rank = fft_len / 4, gap; rank > 1; rank /= 2)
{
gap = rank * 2;
for (size_t begin = 0; begin < fft_len; begin += gap * 2)
{
for (size_t pos = begin; pos < begin + rank; pos++)
{
Complex omega = std::conj(TABLE.get_complex(shift, pos - begin));
fft_radix2_dif_butterfly(omega, input, pos, rank);
fft_radix2_dif_butterfly(omega, input, pos + gap, rank);
}
}
shift--;
}
// Last stage (rank 1): plain 2-point transforms, no twiddles.
for (size_t i = 0; i < fft_len; i += 2)
{
fft_2point(input[i], input[i + 1]);
}
if (bit_inv)
{
binary_inverse_swap(input, fft_len);
}
}
/// @brief Radix-4 decimation-in-time FFT that reads its twiddles from TABLE.
/// @param input   Complex array transformed in place.
/// @param fft_len Requested length; rounded down to a power of four.
/// @param bit_inv When true the input is permuted by quaternary_inverse_swap first.
/// @throws const char* when the rounded length exceeds 1 << lut_max_rank.
/// @note TABLE is not expanded here.
///       NOTE(review): presumably the caller must have filled TABLE up to log2(fft_len) - confirm.
void fft_radix4_dit_lut(Complex *input, size_t fft_len, bool bit_inv = true)
{
size_t log4_len = std::log2(fft_len) / 2;
fft_len = 1ull << (log4_len * 2);
if (fft_len > (1ull << lut_max_rank))
{
throw("FFT length too long for lut\n");
}
if (fft_len <= 1)
{
return;
}
if (bit_inv)
{
quaternary_inverse_swap(input, fft_len);
}
// First stage: 4-point transforms on adjacent elements, no twiddles.
for (size_t pos = 0; pos < fft_len; pos += 4)
{
fft_4point(input, pos, 1);
}
// shift is log2 of the block length (gap) of the current stage; it grows by 2 per stage.
UINT_32 shift = 4;
for (size_t rank = 4; rank < fft_len; rank *= 4)
{
size_t gap = rank * 4;
for (size_t begin = 0; begin < fft_len; begin += gap)
{
// The first butterfly of each block needs no twiddles.
fft_4point(input, begin, rank);
for (size_t pos = begin + 1; pos < begin + rank; pos++)
{
size_t count = pos - begin;
// Conjugated table entries for the 1st, 2nd and 3rd multiple of count.
Complex omega = std::conj(TABLE.get_complex(shift, count));
Complex omega_sqr = std::conj(TABLE.get_complex(shift, count * 2));
Complex omega_cube = std::conj(TABLE.get_complex(shift, count * 3));
fft_radix4_dit_butterfly(omega, omega_sqr, omega_cube, input, pos, rank);
}
}
shift += 2;
}
}
/// @brief Radix-4 decimation-in-frequency FFT that reads its twiddles from TABLE.
/// @param input   Complex array transformed in place.
/// @param fft_len Requested length; rounded down to a power of four.
/// @param bit_inv When true the output is permuted by quaternary_inverse_swap at the end.
/// @throws const char* when the rounded length exceeds 1 << lut_max_rank.
/// @note TABLE is not expanded here.
///       NOTE(review): presumably the caller must have filled TABLE up to log2(fft_len) - confirm.
void fft_radix4_dif_lut(Complex *input, size_t fft_len, bool bit_inv = true)
{
size_t log4_len = std::log2(fft_len) / 2;
fft_len = 1ull << (log4_len * 2);
if (fft_len > (1ull << lut_max_rank))
{
throw("FFT length too long for lut\n");
}
if (fft_len <= 1)
{
return;
}
// A length of 4 is a single 4-point transform; no permutation is applied on this path.
if (fft_len == 4)
{
fft_4point(input, 0, 1);
return;
}
// shift is log2 of the block length (gap) of the current stage; it shrinks by 2 per stage.
UINT_32 shift = log4_len * 2;
for (size_t rank = fft_len / 4; rank > 1; rank /= 4)
{
size_t gap = rank * 4;
for (size_t begin = 0; begin < fft_len; begin += gap)
{
// The first butterfly of each block needs no twiddles.
fft_4point(input, begin, rank);
for (size_t pos = begin + 1; pos < begin + rank; pos++)
{
size_t count = pos - begin;
// Conjugated table entries for the 1st, 2nd and 3rd multiple of count.
Complex omega = std::conj(TABLE.get_complex(shift, count));
Complex omega_sqr = std::conj(TABLE.get_complex(shift, count * 2));
Complex omega_cube = std::conj(TABLE.get_complex(shift, count * 3));
fft_radix4_dif_butterfly(omega, omega_sqr, omega_cube, input, pos, rank);
}
}
shift -= 2;
}
// Last stage: 4-point transforms on adjacent elements, no twiddles.
for (size_t pos = 0; pos < fft_len; pos += 4)
{
fft_4point(input, pos, 1);
}
if (bit_inv)
{
quaternary_inverse_swap(input, fft_len);
}
}
/// @brief Radix-4 decimation-in-time FFT applied to BATCH interleaved sequences
///        at once; with BATCH == 1 it is an ordinary radix-4 FFT.
/// @tparam BATCH  Number of interleaved sequences; element p of sequence j is
///                accessed at input[j + p * BATCH].
/// @param input   Complex array holding fft_len * BATCH elements, transformed in place.
/// @param fft_len Length of each sequence; rounded down to a power of four.
/// @param bit_inv When true the input is permuted by quaternary_inverse_swap<BATCH> first.
/// @throws const char* when the rounded length exceeds 1 << lut_max_rank.
/// @note TABLE is not expanded here.
///       NOTE(review): presumably the caller must have filled TABLE up to log2(fft_len) - confirm.
template <size_t BATCH = 1>
void fft_radix4_batch(Complex *input, size_t fft_len, bool bit_inv = true)
{
size_t log4_len = std::log2(fft_len) / 2;
fft_len = 1ull << (log4_len * 2);
if (fft_len > (1ull << lut_max_rank))
{
throw("FFT length too long for lut\n");
}
if (fft_len <= 1)
{
return;
}
if (bit_inv)
{
quaternary_inverse_swap<BATCH>(input, fft_len);
}
// First stage: 4-point transforms without twiddles, once per sequence.
for (size_t pos = 0; pos < fft_len; pos += 4)
{
for (size_t j = 0; j < BATCH; j++)
{
fft_4point(input, j + pos * BATCH, BATCH);
}
}
// shift is log2 of the block length (gap) of the current stage.
UINT_32 shift = 4;
for (size_t rank = 4; rank < fft_len; rank *= 4)
{
size_t gap = rank * 4;
for (size_t begin = 0; begin < fft_len; begin += gap)
{
// The first butterfly of each block needs no twiddles.
for (size_t j = 0; j < BATCH; j++)
{
fft_4point(input, j + begin * BATCH, rank * BATCH);
}
for (size_t pos = begin + 1; pos < begin + rank; pos++)
{
size_t count = pos - begin;
// One set of twiddle lookups is shared by all BATCH sequences.
Complex omega = std::conj(TABLE.get_complex(shift, count));
Complex omega_sqr = std::conj(TABLE.get_complex(shift, count * 2));
Complex omega_cube = std::conj(TABLE.get_complex(shift, count * 3));
for (size_t j = 0; j < BATCH; j++)
{
fft_radix4_dit_butterfly(omega, omega_sqr, omega_cube, input, j + pos * BATCH, rank * BATCH);
}
}
}
shift += 2;
}
}
/// @brief FFT for lengths that are not a power of four: the even- and
///        odd-indexed halves are transformed together as a batch of two
///        radix-4 FFTs, merged by one radix-2 stage, and then de-interleaved.
/// @param input   Complex array transformed in place.
/// @param fft_len Requested length; rounded down to a power of two.
/// @throws const char* when the rounded length exceeds 1 << lut_max_rank.
void fft_radix2_bat2(Complex *input, size_t fft_len)
{
    const size_t log_len = std::log2(fft_len);
    fft_len = 1ull << log_len;
    if (fft_len > (1ull << lut_max_rank))
    {
        throw("FFT length too long for lut\n");
    }
    switch (fft_len)
    {
    case 1:
        return;
    case 2:
        fft_2point(input[0], input[1]);
        return;
    default:
        break;
    }
    const size_t half_len = fft_len / 2;
    fft_radix4_batch<2>(input, half_len);
    // Merge stage: each adjacent pair holds matching outputs of the two half transforms.
    for (size_t pair = 0; pair < half_len; pair++)
    {
        const Complex omega = std::conj(TABLE.get_complex(log_len, pair));
        fft_radix2_dit_butterfly(omega, input, pair * 2, 1);
    }
    ary_interlace<2>(input, fft_len);
}
/// @brief Radix-4 FFT built from a batch of four interleaved quarter-length
///        transforms, one merging radix-4 stage, and a de-interleave.
/// @param input   Complex array transformed in place.
/// @param fft_len Requested length; rounded down to a power of four.
/// @throws const char* when the rounded length exceeds 1 << lut_max_rank.
void fft_radix4_bat4(Complex *input, size_t fft_len)
{
    const size_t log4_len = std::log2(fft_len) / 2;
    const size_t log_len = log4_len * 2;
    fft_len = 1ull << log_len;
    if (fft_len > (1ull << lut_max_rank))
    {
        throw("FFT length too long for lut\n");
    }
    switch (fft_len)
    {
    case 1:
        return;
    case 4:
        fft_4point(input, 0, 1);
        return;
    default:
        break;
    }
    const size_t quarter_len = fft_len / 4;
    fft_radix4_batch<4>(input, quarter_len);
    // Merge stage: each group of four adjacent elements holds matching outputs
    // of the four quarter transforms.
    for (size_t group = 0; group < quarter_len; group++)
    {
        const Complex omega = std::conj(TABLE.get_complex(log_len, group));
        const Complex omega_sqr = std::conj(TABLE.get_complex(log_len, group * 2));
        const Complex omega_cube = std::conj(TABLE.get_complex(log_len, group * 3));
        fft_radix4_dit_butterfly(omega, omega_sqr, omega_cube, input, group * 4, 1);
    }
    ary_interlace<4>(input, fft_len);
}
/// @brief Forward FFT dispatcher (decimation in frequency for power-of-four lengths).
/// @param input      Complex array transformed in place.
/// @param fft_len    Requested length; rounded down to a power of two.
/// @param r4_bit_inv Passed to the radix-4 kernel as its bit_inv flag; not used
///                   when log2 of the length is odd.
void fft_dif(Complex *input, size_t fft_len, bool r4_bit_inv = true)
{
    const size_t log_len = std::log2(fft_len);
    fft_len = 1ull << log_len;
    if (fft_len > 1)
    {
        TABLE.expend(log_len);
        const bool odd_log = (log_len & 1) != 0;
        if (odd_log)
        {
            fft_radix2_bat2(input, fft_len);
        }
        else
        {
            fft_radix4_dif_lut(input, fft_len, r4_bit_inv);
        }
    }
}
/// @brief Forward FFT dispatcher (decimation in time for power-of-four lengths).
/// @param input      Complex array transformed in place.
/// @param fft_len    Requested length; rounded down to a power of two.
/// @param r4_bit_inv Passed to the radix-4 kernel as its bit_inv flag; not used
///                   when log2 of the length is odd.
void fft_dit(Complex *input, size_t fft_len, bool r4_bit_inv = true)
{
    const size_t log_len = std::log2(fft_len);
    fft_len = 1ull << log_len;
    if (fft_len > 1)
    {
        TABLE.expend(log_len);
        const bool odd_log = (log_len & 1) != 0;
        if (odd_log)
        {
            fft_radix2_bat2(input, fft_len);
        }
        else
        {
            fft_radix4_dit_lut(input, fft_len, r4_bit_inv);
        }
    }
}
/// @brief Table-driven fast Fourier transform.
/// @param input      Complex array transformed in place.
/// @param fft_len    Transform length; rounded down to a power of two.
/// @param r4_bit_inv Whether the radix-4 path performs its digit-reversal
///                   permutation; setting it to false here and in the inverse
///                   transform saves time.
void fft_lut(Complex *input, size_t fft_len, bool r4_bit_inv = true)
{
    const size_t log_len = std::log2(fft_len);
    const size_t rounded_len = 1ull << log_len;
    if (rounded_len > 1)
    {
        fft_dit(input, rounded_len, r4_bit_inv);
    }
}
/// @brief Table-driven inverse fast Fourier transform.
/// @param input      Complex array transformed in place.
/// @param fft_len    Transform length; rounded down to a power of two.
/// @param r4_bit_inv Whether the radix-4 path performs its digit-reversal
///                   permutation; setting it to false here and in the forward
///                   transform saves time.
void ifft_lut(Complex *input, size_t fft_len, bool r4_bit_inv = true)
{
    const size_t log_len = std::log2(fft_len);
    fft_len = 1ull << log_len;
    if (fft_len > 1)
    {
        fft_len = max_2pow(fft_len);
        // Inverse via the forward transform: conjugate, transform, then
        // conjugate again while dividing by the length.
        fft_conj(input, fft_len);
        fft_dif(input, fft_len, r4_bit_inv);
        fft_conj(input, fft_len, static_cast<double>(fft_len));
    }
}
/// @brief Fast Hartley transform, performed in place.
/// @param input   Pointer to the array of doubles.
/// @param fht_len Transform length; rounded down to a power of two by max_2pow.
/// @note There is no inverse flag on this function; ifht applies this
///       transform and then divides by the length.
void fht(double *input, size_t fht_len)
{
fht_len = max_2pow(fht_len);
if (fht_len <= 1)
{
return;
}
UINT_32 log_len = std::log2(fht_len);
// Make sure the shared table holds the roots needed for this length.
TABLE.expend(log_len);
binary_inverse_swap(input, fht_len);
// Stage with rank 1: sums and differences of adjacent pairs.
for (size_t i = 0; i < fht_len; i += 2)
{
double tmp1 = input[i];
double tmp2 = input[i + 1];
input[i] = tmp1 + tmp2;
input[i + 1] = tmp1 - tmp2;
}
// shift is log2 of the block length (gap) of the current stage.
UINT_32 shift = 2;
for (size_t rank = 2; rank < fht_len; rank *= 2)
{
size_t gap = rank * 2;
size_t half = rank / 2;
for (size_t begin = 0; begin < fht_len; begin += gap)
{
// Offsets 0 and rank / 2 inside each half are combined by a plain
// sum/difference, without a table lookup.
size_t index1 = begin, index2 = begin + half;
size_t index3 = begin + rank, index4 = begin + half * 3;
double tmp1 = input[index1];
double tmp2 = input[index3];
input[index1] = tmp1 + tmp2;
input[index3] = tmp1 - tmp2;
tmp1 = input[index2];
tmp2 = input[index4];
input[index2] = tmp1 + tmp2;
input[index4] = tmp1 - tmp2;
// Remaining offsets are processed as mirrored pairs (pos and rank - pos in
// each half), combining four values per step.
for (size_t pos = 1; pos < half; pos++)
{
index1 = begin + pos;
index2 = rank + begin - pos;
index3 = rank + begin + pos;
index4 = gap + begin - pos;
double tmp1 = input[index1];
double tmp2 = input[index2];
double tmp3 = input[index3];
double tmp4 = input[index4];
// omega.real() / omega.imag() supply the cosine and sine for this offset.
Complex omega = TABLE.get_complex(shift, pos);
double t1 = tmp3 * omega.real() + tmp4 * omega.imag();
double t2 = tmp3 * omega.imag() - tmp4 * omega.real();
input[index1] = tmp1 + t1;
input[index2] = tmp2 + t2;
input[index3] = tmp1 - t1;
input[index4] = tmp2 - t2;
}
}
shift++;
}
}
/// @brief Inverse fast Hartley transform: applies fht and divides every
///        element by the transform length.
/// @param input   Pointer to the array of doubles, transformed in place.
/// @param fht_len Transform length; rounded down to a power of two.
void ifht(double *input, size_t fht_len)
{
    fht_len = max_2pow(fht_len);
    fht(input, fht_len);
    const double scale = fht_len;
    double *const end = input + fht_len;
    for (double *it = input; it != end; ++it)
    {
        *it /= scale;
    }
}
/// @brief Modular addition for operands already reduced modulo mod.
/// @param a   First operand, expected to be < mod.
/// @param b   Second operand, expected to be < mod.
/// @param mod Modulus.
/// @return (a + b) reduced by at most one subtraction of mod.
constexpr UINT_64 add_mod(UINT_64 a, UINT_64 b, UINT_64 mod)
{
    const UINT_64 sum = a + b;
    if (sum < mod)
    {
        return sum;
    }
    return sum - mod;
}
/// @brief Modular subtraction for operands already reduced modulo mod.
/// @param a   Minuend, expected to be < mod.
/// @param b   Subtrahend, expected to be < mod.
/// @param mod Modulus.
/// @return a - b, with mod added first when a < b.
constexpr UINT_64 sub_mod(UINT_64 a, UINT_64 b, UINT_64 mod)
{
    if (a < b)
    {
        return a + mod - b;
    }
    return a - b;
}
/// @brief Modular multiplication a * b mod MOD.
/// @tparam MOD Modulus (positive).
/// @param a First factor, expected to lie in [0, MOD).
/// @param b Second factor, expected to lie in [0, MOD).
/// @return a * b modulo MOD, in [0, MOD).
/// @note All wrap-around arithmetic is done on unsigned values. The earlier
///       signed a * b overflowed (undefined behaviour) as soon as the product
///       reached 2^63, which two factors just below 2^32 already do on the
///       fast path.
/// @note The slow path estimates the quotient in double precision and fixes
///       an error of one multiple of MOD in either direction.
///       NOTE(review): for very large MOD the 53-bit estimate may be off by more - confirm the intended range.
template <INT_64 MOD>
constexpr INT_64 mul_mod(INT_64 a, INT_64 b)
{
    const UINT_64 ua = static_cast<UINT_64>(a);
    const UINT_64 ub = static_cast<UINT_64>(b);
    const UINT_64 umod = static_cast<UINT_64>(MOD);
    // Both factors fit in 32 bits: the exact product fits in an unsigned 64-bit value.
    if (((a | b) >> 32) == 0)
    {
        return static_cast<INT_64>(ua * ub % umod);
    }
    // a * b - floor(a * b / MOD) * MOD, both products taken modulo 2^64; the
    // true difference is small, so the wrapped difference equals it.
    const UINT_64 quotient = static_cast<UINT_64>(1.0 * a / MOD * b);
    INT_64 result = static_cast<INT_64>(ua * ub - quotient * umod);
    if (result < 0)
    {
        result += MOD;
    }
    else if (result >= MOD)
    {
        result -= MOD;
    }
    return result;
}
/// @brief Multiplies every element by the modular inverse of the length;
///        used by the inverse NTT.
/// @tparam MOD    Modulus; the inverse is computed as ntt_len^(MOD - 2).
/// @param input   Data array.
/// @param ntt_len Number of elements, also the value being inverted.
template <UINT_64 MOD>
inline void ntt_normalize(UINT_32 *input, size_t ntt_len)
{
    const UINT_64 inv = qpow(ntt_len, MOD - 2, MOD);
    UINT_32 *const end = input + ntt_len;
    for (UINT_32 *it = input; it != end; ++it)
    {
        *it = inv * (*it) % MOD;
    }
}
/// @brief Radix-2 decimation-in-time NTT butterfly: the twiddle is applied to
///        the second element before the modular sum/difference.
/// @tparam MOD  Modulus.
/// @param omega Twiddle factor.
/// @param input Data array.
/// @param pos   Index of the first element.
/// @param rank  Distance to the second element.
template <UINT_64 MOD>
constexpr void ntt_radix2_dit_butterfly(UINT_64 omega, UINT_32 *input, size_t pos, size_t rank)
{
    UINT_32 &upper = input[pos];
    UINT_32 &lower = input[pos + rank];
    const UINT_32 even = upper;
    const UINT_32 twiddled = lower * omega % MOD;
    upper = add_mod(even, twiddled, MOD);
    lower = sub_mod(even, twiddled, MOD);
}
/// @brief Radix-2 decimation-in-frequency NTT butterfly: the twiddle is
///        applied to the difference after the modular sum/difference.
/// @tparam MOD  Modulus.
/// @param omega Twiddle factor.
/// @param input Data array.
/// @param pos   Index of the first element.
/// @param rank  Distance to the second element.
template <UINT_64 MOD>
constexpr void ntt_radix2_dif_butterfly(UINT_64 omega, UINT_32 *input, size_t pos, size_t rank)
{
    UINT_32 &upper = input[pos];
    UINT_32 &lower = input[pos + rank];
    const UINT_32 first = upper;
    const UINT_32 second = lower;
    upper = add_mod(first, second, MOD);
    lower = sub_mod(first, second, MOD) * omega % MOD;
}
/// @brief Radix-4 decimation-in-time NTT butterfly on four elements spaced
///        rank apart; elements 1..3 are multiplied by their twiddles first.
/// @tparam MOD       Modulus.
/// @param omega      Twiddle for the second element.
/// @param omega_sqr  Twiddle for the third element.
/// @param omega_cube Twiddle for the fourth element.
/// @param quarter    Factor applied to the odd difference (the callers in this
///                   file pass G_ROOT^((MOD - 1) / 4), the analogue of i).
/// @param input      Data array.
/// @param pos        Index of the first element.
/// @param rank       Spacing between the four elements.
template <UINT_64 MOD = 2281701377>
constexpr void ntt_radix4_dit_butterfly(UINT_64 omega, UINT_64 omega_sqr, UINT_64 omega_cube,
                                        UINT_64 quarter, UINT_32 *input, size_t pos, size_t rank)
{
    UINT_32 &x0 = input[pos];
    UINT_32 &x1 = input[pos + rank];
    UINT_32 &x2 = input[pos + rank * 2];
    UINT_32 &x3 = input[pos + rank * 3];
    const UINT_32 a0 = x0;
    const UINT_32 a1 = x1 * omega % MOD;
    const UINT_32 a2 = x2 * omega_sqr % MOD;
    const UINT_32 a3 = x3 * omega_cube % MOD;
    const UINT_32 sum02 = add_mod(a0, a2, MOD);
    const UINT_32 sum13 = add_mod(a1, a3, MOD);
    const UINT_32 diff02 = sub_mod(a0, a2, MOD);
    // MOD is added before subtracting so the difference cannot go below zero.
    const UINT_32 rotated = (MOD + a1 - a3) * quarter % MOD;
    x0 = add_mod(sum02, sum13, MOD);
    x1 = add_mod(diff02, rotated, MOD);
    x2 = sub_mod(sum02, sum13, MOD);
    x3 = sub_mod(diff02, rotated, MOD);
}
/// @brief Radix-4 decimation-in-frequency NTT butterfly on four elements
///        spaced rank apart; outputs 1..3 are multiplied by their twiddles last.
/// @tparam MOD       Modulus.
/// @param omega      Twiddle for the second output.
/// @param omega_sqr  Twiddle for the third output.
/// @param omega_cube Twiddle for the fourth output.
/// @param quarter    Factor applied to the odd difference (the callers in this
///                   file pass G_ROOT^((MOD - 1) / 4), the analogue of i).
/// @param input      Data array.
/// @param pos        Index of the first element.
/// @param rank       Spacing between the four elements.
template <UINT_64 MOD = 2281701377>
constexpr void ntt_radix4_dif_butterfly(UINT_64 omega, UINT_64 omega_sqr, UINT_64 omega_cube,
                                        UINT_64 quarter, UINT_32 *input, size_t pos, size_t rank)
{
    UINT_32 &x0 = input[pos];
    UINT_32 &x1 = input[pos + rank];
    UINT_32 &x2 = input[pos + rank * 2];
    UINT_32 &x3 = input[pos + rank * 3];
    const UINT_32 a0 = x0;
    const UINT_32 a1 = x1;
    const UINT_32 a2 = x2;
    const UINT_32 a3 = x3;
    const UINT_32 sum02 = add_mod(a0, a2, MOD);
    const UINT_32 sum13 = add_mod(a1, a3, MOD);
    const UINT_32 diff02 = sub_mod(a0, a2, MOD);
    const UINT_32 rotated = sub_mod(a1, a3, MOD) * quarter % MOD;
    x0 = add_mod(sum02, sum13, MOD);
    x1 = add_mod(diff02, rotated, MOD) * omega % MOD;
    x2 = sub_mod(sum02, sum13, MOD) * omega_sqr % MOD;
    x3 = sub_mod(diff02, rotated, MOD) * omega_cube % MOD;
}
/// @brief Two-point NTT: replaces the pair by its modular sum and difference.
/// @tparam MOD Modulus.
/// @param sum  In: first value. Out: (first + second) mod MOD.
/// @param diff In: second value. Out: (first - second) mod MOD.
template <UINT_64 MOD>
constexpr void ntt_2point(UINT_32 &sum, UINT_32 &diff)
{
    const UINT_32 difference = sub_mod(sum, diff, MOD);
    sum = add_mod(sum, diff, MOD);
    diff = difference;
}
/// @brief Four-point NTT without twiddle factors on elements spaced rank apart.
/// @tparam MOD    Modulus.
/// @param input   Data array.
/// @param quarter Factor applied to the odd difference (the callers in this
///                file pass G_ROOT^((MOD - 1) / 4), the analogue of i).
/// @param pos     Index of the first element.
/// @param rank    Spacing between the four elements.
template <UINT_64 MOD>
constexpr void ntt_4point(UINT_32 *input, UINT_64 quarter, size_t pos, size_t rank)
{
    UINT_32 &x0 = input[pos];
    UINT_32 &x1 = input[pos + rank];
    UINT_32 &x2 = input[pos + rank * 2];
    UINT_32 &x3 = input[pos + rank * 3];
    const UINT_32 a0 = x0;
    const UINT_32 a1 = x1;
    const UINT_32 a2 = x2;
    const UINT_32 a3 = x3;
    const UINT_32 sum02 = add_mod(a0, a2, MOD);
    const UINT_32 sum13 = add_mod(a1, a3, MOD);
    const UINT_32 diff02 = sub_mod(a0, a2, MOD);
    // MOD is added before subtracting so the difference cannot go below zero.
    const UINT_32 rotated = (MOD + a1 - a3) * quarter % MOD;
    x0 = add_mod(sum02, sum13, MOD);
    x1 = add_mod(diff02, rotated, MOD);
    x2 = sub_mod(sum02, sum13, MOD);
    x3 = sub_mod(diff02, rotated, MOD);
}
/// @brief Radix-2 decimation-in-time number-theoretic transform.
/// @tparam MOD    Modulus.
/// @tparam G_ROOT Root used to derive the twiddles as G_ROOT^((MOD - 1) / gap).
/// @param input   Array transformed in place.
/// @param ntt_len Requested length; rounded down to a power of two.
/// @param bit_inv When true the input is bit-reversed first (off by default).
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3>
void ntt_radix2_dit(UINT_32 *input, size_t ntt_len, bool bit_inv = false)
{
    ntt_len = max_2pow(ntt_len);
    if (ntt_len <= 1)
    {
        return;
    }
    if (ntt_len == 2)
    {
        ntt_2point<MOD>(input[0], input[1]);
        return;
    }
    if (bit_inv)
    {
        binary_inverse_swap(input, ntt_len);
    }
    // Stage with rank 1: plain 2-point transforms.
    for (size_t pos = 0; pos < ntt_len; pos += 2)
    {
        ntt_2point<MOD>(input[pos], input[pos + 1]);
    }
    constexpr size_t THRESHOLD = 4;
    size_t rank = 2;
    // Early stages: two neighbouring blocks share one running twiddle.
    for (; rank < ntt_len / THRESHOLD; rank *= 2)
    {
        size_t gap = rank * 2;
        UINT_64 unit_omega = qpow(G_ROOT, (MOD - 1) / gap, MOD);
        for (size_t begin = 0; begin < ntt_len; begin += (gap * 2))
        {
            ntt_2point<MOD>(input[begin], input[begin + rank]);
            ntt_2point<MOD>(input[begin + gap], input[begin + rank + gap]);
            UINT_64 omega = unit_omega;
            for (size_t pos = begin + 1; pos < begin + rank; pos++)
            {
                ntt_radix2_dit_butterfly<MOD>(omega, input, pos, rank);
                ntt_radix2_dit_butterfly<MOD>(omega, input, pos + gap, rank);
                omega = omega * unit_omega % MOD;
            }
        }
    }
    // Remaining stages continue from the rank reached above. For ntt_len == 4
    // that is rank 2; starting again from ntt_len / THRESHOLD == 1 ran the
    // rank-1 stage a second time and corrupted the result.
    for (; rank < ntt_len; rank *= 2)
    {
        size_t gap = rank * 2;
        UINT_64 unit_omega = qpow(G_ROOT, (MOD - 1) / gap, MOD);
        for (size_t begin = 0; begin < ntt_len; begin += gap)
        {
            ntt_2point<MOD>(input[begin], input[begin + rank]);
            UINT_64 omega = unit_omega;
            for (size_t pos = begin + 1; pos < begin + rank; pos++)
            {
                ntt_radix2_dit_butterfly<MOD>(omega, input, pos, rank);
                omega = omega * unit_omega % MOD;
            }
        }
    }
}
/// @brief Radix-2 decimation-in-frequency number-theoretic transform.
/// @tparam MOD    Modulus (default 2281701377).
/// @tparam G_ROOT Root used to derive the twiddles as G_ROOT^((MOD - 1) / ntt_len).
/// @param input   Array transformed in place.
/// @param ntt_len Requested length; rounded down to a power of two by max_2pow.
/// @param bit_inv When true (the default here) the output is bit-reversed at the end.
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3>
void ntt_radix2_dif(UINT_32 *input, size_t ntt_len, bool bit_inv = true)
{
ntt_len = max_2pow(ntt_len);
if (ntt_len <= 1)
{
return;
}
if (ntt_len == 2)
{
ntt_2point<MOD>(input[0], input[1]);
return;
}
// Twiddle step of the first stage; it is squared after every stage.
UINT_64 unit_omega = qpow(G_ROOT, (MOD - 1) / ntt_len, MOD);
constexpr size_t THRESHOLD1 = 4;
// First stage only (rank == ntt_len / 2): one block spanning the whole array.
for (size_t rank = ntt_len / 2; rank > ntt_len / THRESHOLD1; rank /= 2)
{
size_t gap = rank * 2;
for (size_t begin = 0; begin < ntt_len; begin += gap)
{
// Offset 0 has twiddle 1, so it is a plain 2-point transform.
ntt_2point<MOD>(input[begin], input[begin + rank]);
UINT_64 omega = unit_omega;
for (size_t pos = begin + 1; pos < begin + rank; pos++)
{
ntt_radix2_dif_butterfly<MOD>(omega, input, pos, rank);
omega = omega * unit_omega % MOD;
}
}
unit_omega = unit_omega * unit_omega % MOD;
}
// Stages from rank ntt_len / 4 down to rank 2: two neighbouring blocks share
// one running twiddle.
for (size_t rank = ntt_len / THRESHOLD1; rank > 1; rank /= 2)
{
size_t gap = rank * 2;
for (size_t begin = 0; begin < ntt_len; begin += (gap * 2))
{
ntt_2point<MOD>(input[begin], input[begin + rank]);
ntt_2point<MOD>(input[begin + gap], input[begin + rank + gap]);
UINT_64 omega = unit_omega;
for (size_t pos = begin + 1; pos < begin + rank; pos++)
{
ntt_radix2_dif_butterfly<MOD>(omega, input, pos, rank);
ntt_radix2_dif_butterfly<MOD>(omega, input, pos + gap, rank);
omega = omega * unit_omega % MOD;
}
}
unit_omega = unit_omega * unit_omega % MOD;
}
// Last stage (rank 1): plain 2-point transforms.
for (size_t pos = 0; pos < ntt_len; pos += 2)
{
ntt_2point<MOD>(input[pos], input[pos + 1]);
}
if (bit_inv)
{
binary_inverse_swap(input, ntt_len);
}
}
/// @brief Batched radix-4 decimation-in-time number-theoretic transform; with
///        the default BATCH of 1 it is an ordinary radix-4 transform.
/// @tparam MOD    Modulus (default 2281701377).
/// @tparam G_ROOT Root used to derive the twiddles as G_ROOT^((MOD - 1) / gap).
/// @tparam BATCH  Number of interleaved sequences; element p of sequence j is
///                accessed at input[j + p * BATCH].
/// @param input   Array holding ntt_len * BATCH elements, transformed in place.
/// @param ntt_len Length of each sequence; rounded down to a power of four.
/// @param bit_inv When true the input is permuted by quaternary_inverse_swap<BATCH> first.
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3, size_t BATCH = 1>
void ntt_radix4_dit(UINT_32 *input, size_t ntt_len, bool bit_inv = true)
{
size_t log4_len = std::log2(ntt_len) / 2;
ntt_len = 1ull << (log4_len * 2);
if (ntt_len <= 1)
{
return;
}
if (bit_inv)
{
quaternary_inverse_swap<BATCH>(input, ntt_len);
}
constexpr UINT_64 quarter = qpow(G_ROOT, (MOD - 1) / 4, MOD); // plays the role of the complex unit i
// First stage: 4-point transforms without twiddles, once per sequence.
for (size_t pos = 0; pos < ntt_len; pos += 4)
{
for (size_t j = 0; j < BATCH; j++)
{
ntt_4point<MOD>(input, quarter, j + pos * BATCH, BATCH);
}
}
for (size_t rank = 4; rank < ntt_len; rank *= 4)
{
size_t gap = rank * 4;
// Twiddle step for blocks of length gap.
UINT_64 unit_omega = qpow(G_ROOT, (MOD - 1) / gap, MOD);
for (size_t begin = 0; begin < ntt_len; begin += gap)
{
// Offset 0 of each block needs no twiddles.
for (size_t j = 0; j < BATCH; j++)
{
ntt_4point<MOD>(input, quarter, j + begin * BATCH, rank * BATCH);
}
UINT_64 omega = unit_omega;
for (size_t pos = begin + 1; pos < begin + rank; pos++)
{
// Square and cube of the running twiddle are shared by all BATCH sequences.
UINT_64 omega_sqr = omega * omega % MOD;
UINT_64 omega_cube = omega_sqr * omega % MOD;
for (size_t j = 0; j < BATCH; j++)
{
ntt_radix4_dit_butterfly<MOD>(omega, omega_sqr, omega_cube, quarter, input, j + pos * BATCH, rank * BATCH);
}
omega = omega * unit_omega % MOD;
}
}
}
}
/// @brief Batched radix-4 decimation-in-frequency number-theoretic transform.
/// @tparam MOD    Modulus (default 2281701377).
/// @tparam G_ROOT Root used to derive the twiddles as G_ROOT^((MOD - 1) / ntt_len).
/// @tparam BATCH  Number of interleaved sequences; element p of sequence j is
///                accessed at input[j + p * BATCH].
/// @param input   Array holding ntt_len * BATCH elements, transformed in place.
/// @param ntt_len Length of each sequence; rounded down to a power of four.
/// @param bit_inv When true the output is permuted by quaternary_inverse_swap<BATCH>
///                at the end (not applied on the ntt_len == 4 shortcut).
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3, size_t BATCH = 1>
void ntt_radix4_dif(UINT_32 *input, size_t ntt_len, bool bit_inv = true)
{
size_t log4_len = std::log2(ntt_len) / 2;
ntt_len = 1ull << (log4_len * 2);
if (ntt_len <= 1)
{
return;
}
constexpr UINT_64 quarter = qpow(G_ROOT, (MOD - 1) / 4, MOD); // plays the role of the complex unit i
// A length of 4 is a single 4-point transform per sequence.
if (ntt_len == 4)
{
for (size_t j = 0; j < BATCH; j++)
{
ntt_4point<MOD>(input, quarter, j, BATCH);
}
return;
}
// Twiddle step of the first stage; raised to the fourth power after every stage.
UINT_64 unit_omega = qpow(G_ROOT, (MOD - 1) / ntt_len, MOD);
for (size_t rank = ntt_len / 4; rank > 1; rank /= 4)
{
size_t gap = rank * 4;
for (size_t begin = 0; begin < ntt_len; begin += (gap))
{
// Offset 0 of each block needs no twiddles.
for (size_t j = 0; j < BATCH; j++)
{
ntt_4point<MOD>(input, quarter, j + begin * BATCH, rank * BATCH);
}
UINT_64 omega = unit_omega;
for (size_t pos = begin + 1; pos < begin + rank; pos++)
{
// Square and cube of the running twiddle are shared by all BATCH sequences.
UINT_64 omega_sqr = omega * omega % MOD;
UINT_64 omega_cube = omega_sqr * omega % MOD;
for (size_t j = 0; j < BATCH; j++)
{
ntt_radix4_dif_butterfly<MOD>(omega, omega_sqr, omega_cube, quarter, input, j + pos * BATCH, rank * BATCH);
}
omega = omega * unit_omega % MOD;
}
}
// Two squarings: unit_omega^4 is the step for the next, four times shorter stage.
unit_omega = unit_omega * unit_omega % MOD;
unit_omega = unit_omega * unit_omega % MOD;
}
// Last stage: 4-point transforms on adjacent items, no twiddles.
for (size_t pos = 0; pos < ntt_len; pos += 4)
{
for (size_t j = 0; j < BATCH; j++)
{
ntt_4point<MOD>(input, quarter, j + pos * BATCH, BATCH);
}
}
if (bit_inv)
{
quaternary_inverse_swap<BATCH>(input, ntt_len);
}
}
// Forward NTT front end.
// Rounds ntt_len down to a power of two, then dispatches on the parity of
// log2(length): odd exponents use ntt_radix2_dif, even exponents use
// ntt_radix4_dif. Lengths of 0 or 1 are left untouched.
//   input   - array transformed in place
//   ntt_len - requested length
//   bit_inv - forwarded unchanged to the selected transform
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3>
void ntt_dif(UINT_32 *input, size_t ntt_len, bool bit_inv = true)
{
    ntt_len = max_2pow(ntt_len);
    if (ntt_len <= 1)
    {
        return;
    }
    size_t exponent = std::log2(ntt_len);
    if (!is_odd(exponent))
    {
        ntt_radix4_dif<MOD, G_ROOT>(input, ntt_len, bit_inv);
        return;
    }
    ntt_radix2_dif<MOD, G_ROOT>(input, ntt_len, bit_inv);
}
// Inverse NTT front end.
// Rounds ntt_len down to a power of two, runs the radix-2 or radix-4
// transform (chosen by the parity of log2(length)) with the modular inverse
// of G_ROOT, then calls ntt_normalize. Lengths of 0 or 1 are left untouched.
//   input   - array transformed in place
//   ntt_len - requested length
//   bit_inv - forwarded unchanged to the selected transform
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3>
void intt_dit(UINT_32 *input, size_t ntt_len, bool bit_inv = true)
{
    ntt_len = max_2pow(ntt_len);
    if (ntt_len <= 1)
    {
        return;
    }
    // G_ROOT^(MOD - 2) is the inverse of G_ROOT when MOD is prime.
    constexpr UINT_64 IG_ROOT = qpow(G_ROOT, MOD - 2, MOD);
    size_t exponent = std::log2(ntt_len);
    const bool use_radix2 = is_odd(exponent);
    if (use_radix2)
    {
        ntt_radix2_dit<MOD, IG_ROOT>(input, ntt_len, bit_inv);
    }
    if (!use_radix2)
    {
        ntt_radix4_dit<MOD, IG_ROOT>(input, ntt_len, bit_inv);
    }
    ntt_normalize<MOD>(input, ntt_len);
}
// Single-threaded NTT.
// Rounds ntt_len down to a power of two. When log2(length) is odd the whole
// transform is done by ntt_radix2_dif. Otherwise it runs a 4-lane batched
// ntt_radix4_dif over ntt_len / 4 points, applies one
// ntt_radix4_dit_butterfly to every group of four adjacent values, and
// finishes with ary_interlace<4>.
//   input   - array transformed in place
//   ntt_len - requested length
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3>
void ntt_single(UINT_32 *input, size_t ntt_len)
{
ntt_len = max_2pow(ntt_len);
if (ntt_len <= 1)
{
return;
}
size_t log_len = std::log2(ntt_len);
if (is_odd(log_len))
{
ntt_radix2_dif<MOD, G_ROOT>(input, ntt_len, true);
}
else
{
size_t quarter_len = ntt_len / 4;
// BATCH = 4: the four lanes are interleaved in `input`.
ntt_radix4_dif<MOD, G_ROOT, 4>(input, quarter_len, true);
constexpr UINT_64 quarter = qpow(G_ROOT, (MOD - 1) / 4, MOD); // plays the role of the complex unit i
const UINT_64 unit_omega = qpow(G_ROOT, (MOD - 1) / ntt_len, MOD);
UINT_64 omega = 1;
// Group i / 4 is combined with twiddle unit_omega^(i / 4).
for (size_t i = 0; i < ntt_len; i += 4)
{
UINT_64 omega_sqr = omega * omega % MOD;
UINT_64 omega_cube = omega_sqr * omega % MOD;
ntt_radix4_dit_butterfly<MOD>(omega, omega_sqr, omega_cube, quarter, input, i, 1);
omega = omega * unit_omega % MOD;
}
ary_interlace<4>(input, ntt_len);
}
}
// Two-thread NTT.
// Splits the input into its even- and odd-indexed halves, transforms both
// halves with ntt_single (one of them through std::async), then merges them
// with radix-2 butterflies; the merge is also split across two tasks.
//   input   - array transformed in place
//   ntt_len - requested length; rounded down to a power of two
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3>
void ntt_dual(UINT_32 *input, size_t ntt_len)
{
    ntt_len = max_2pow(ntt_len);
    if (ntt_len <= 1)
    {
        return;
    }
    if (ntt_len < 4)
    {
        // With ntt_len == 2 the merge range [half_len / 2, half_len) starts at
        // index 0, so seeding it with omega2 (a 4th root of unity) instead of 1
        // would give a wrong result. Such a short transform gains nothing from
        // a second thread anyway.
        ntt_single<MOD, G_ROOT>(input, ntt_len);
        return;
    }
    size_t half_len = ntt_len / 2;
    // De-interleave: evens are packed into the lower half in place, odds go
    // through a temporary buffer and are then copied to the upper half.
    UINT_32 *tmp_ary = new UINT_32[half_len];
    for (size_t i = 0; i < ntt_len; i += 2)
    {
        input[i / 2] = input[i];
        tmp_ary[i / 2] = input[i + 1];
    }
    ary_copy(input + half_len, tmp_ary, half_len);
    delete[] tmp_ary;
    // Transform both halves; the two tasks work on disjoint ranges.
    std::future<void> th = std::async(ntt_single<MOD, G_ROOT>, input, half_len);
    ntt_single<MOD, G_ROOT>(input + half_len, half_len);
    th.wait();
    // omega2 == unit_omega^(ntt_len / 4): the twiddle at index half_len / 2.
    constexpr UINT_64 omega2 = qpow(G_ROOT, (MOD - 1) / 4, MOD);
    const UINT_64 unit_omega = qpow(G_ROOT, (MOD - 1) / ntt_len, MOD);
    // Radix-2 merge of indices [start, end); omega_start is the twiddle that
    // belongs to index `start`.
    auto merge_proc = [=](size_t start, size_t end, UINT_64 omega_start)
    {
        UINT_64 omega = omega_start;
        for (size_t i = start; i < end; i++)
        {
            UINT_32 tmp1 = input[i];
            UINT_32 tmp2 = input[i + half_len] * omega % MOD;
            input[i] = add_mod(tmp1, tmp2, MOD);
            input[i + half_len] = sub_mod(tmp1, tmp2, MOD);
            omega = omega * unit_omega % MOD;
        }
    };
    th = std::async(merge_proc, 0, half_len / 2, 1);
    merge_proc(half_len / 2, half_len, omega2);
    th.wait();
}
// Forward NTT entry point.
//   input         - array transformed in place
//   ntt_len       - requested length, forwarded unchanged
//   multi_threads - true selects ntt_dual, false selects ntt_single
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3>
void ntt(UINT_32 *input, size_t ntt_len, bool multi_threads = false)
{
    if (!multi_threads)
    {
        ntt_single<MOD, G_ROOT>(input, ntt_len);
        return;
    }
    ntt_dual<MOD, G_ROOT>(input, ntt_len);
}
// Inverse NTT entry point.
// Runs the forward machinery with the modular inverse of G_ROOT and then
// calls ntt_normalize.
//   input         - array transformed in place
//   ntt_len       - requested length; rounded down to a power of two
//   multi_threads - true selects ntt_dual, false selects ntt_single
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3>
void intt(UINT_32 *input, size_t ntt_len, bool multi_threads = false)
{
    constexpr UINT_64 IG_ROOT = qpow(G_ROOT, MOD - 2, MOD);
    // The transforms round the length down to a power of two internally, so
    // the normalisation must use that same length (as intt_dit does);
    // otherwise a non-power-of-two request is normalised with the wrong length.
    ntt_len = max_2pow(ntt_len);
    if (multi_threads)
    {
        ntt_dual<MOD, IG_ROOT>(input, ntt_len);
    }
    else
    {
        ntt_single<MOD, IG_ROOT>(input, ntt_len);
    }
    ntt_normalize<MOD>(input, ntt_len);
}
}
// Element-wise product of two arrays: out[i] = in1[i] * in2[i] for i < len.
// `out` may be the same array as `in1` or `in2`.
template <typename T>
inline void ary_mul(const T in1[], const T in2[], T out[], size_t len)
{
    const T *lhs = in1;
    const T *rhs = in2;
    T *dst = out;
    for (T *const dst_end = out + len; dst != dst_end; ++dst, ++lhs, ++rhs)
    {
        *dst = *lhs * *rhs;
    }
}
// Element-wise modular product: out[i] = in1[i] * in2[i] % MOD for i < len.
// Each product is widened to UINT_64 before the reduction. The main loop
// handles elements in blocks of four; a tail loop handles the len % 4 rest.
template <UINT_64 MOD, typename T>
constexpr void ary_mul_mod(const T in1[], const T in2[], T out[], size_t len)
{
    const size_t blocks_end = len - (len % 4);
    for (size_t block = 0; block < blocks_end; block += 4)
    {
        for (size_t lane = 0; lane < 4; ++lane)
        {
            const size_t idx = block + lane;
            out[idx] = static_cast<UINT_64>(in1[idx]) * in2[idx] % MOD;
        }
    }
    for (size_t idx = blocks_end; idx < len; ++idx)
    {
        out[idx] = static_cast<UINT_64>(in1[idx]) * in2[idx] % MOD;
    }
}
// Schoolbook linear convolution: out[k] = sum over i + j == k of in1[i] * in2[j].
//   in1, len1 - first sequence
//   in2, len2 - second sequence
//   out       - receives len1 + len2 - 1 values; must not overlap the inputs
// When both lengths are zero nothing is written.
template <typename T>
constexpr void normal_convolution(const T in1[], const T in2[], T out[],
                                  size_t len1, size_t len2)
{
    if (len1 == 0 && len2 == 0)
    {
        return; // len1 + len2 - 1 would wrap around
    }
    const size_t out_len = len1 + len2 - 1;
    for (size_t k = 0; k < out_len; k++)
    {
        out[k] = T(0);
    }
    for (size_t i = 0; i < len1; i++)
    {
        const T num1 = in1[i];
        for (size_t j = 0; j < len2; j++)
        {
            // Accumulate: several (i, j) pairs contribute to the same output slot.
            out[i + j] += num1 * in2[j];
        }
    }
}
// Convolution through hint_transform::fht / ifht.
// Both inputs are transformed in place (only once when they are the same
// array), combined element-wise into `out`, and `out` is passed to ifht.
//   fht_ary1, fht_ary2 - input arrays of fht_len values; overwritten
//   out                - receives fht_len values
//   fht_len            - transform length
void fht_convolution(double fht_ary1[], double fht_ary2[], double out[], size_t fht_len)
{
    const bool distinct_inputs = (fht_ary1 != fht_ary2);
    hint_transform::fht(fht_ary1, fht_len);
    if (distinct_inputs)
    {
        hint_transform::fht(fht_ary2, fht_len);
    }
    out[0] = fht_ary1[0] * fht_ary2[0];
    size_t idx = 1;
    while (idx < fht_len)
    {
        const size_t mirror = fht_len - idx;
        const double here = fht_ary1[idx];
        const double there = fht_ary1[mirror];
        out[idx] = (fht_ary2[idx] * (here + there) + fht_ary2[mirror] * (here - there)) / 2;
        ++idx;
    }
    hint_transform::ifht(out, fht_len);
}
// Formats `input` as exactly `digits` decimal characters, left-padded with
// '0'. Higher-order digits that do not fit are dropped.
std::string ui64to_string(UINT_64 input, UINT_8 digits)
{
    std::string text(digits, '0');
    // Fill from the least significant digit at the back towards the front.
    for (auto slot = text.rbegin(); slot != text.rend(); ++slot)
    {
        *slot = static_cast<char>(input % 10 + '0');
        input /= 10;
    }
    return text;
}
// Parses at most the first 19 characters of [begin, end) as an unsigned
// number in the given base.
// Digits are '0'-'9' and letters (case-insensitive, 'a' == 10). Every
// character advances the positional value; a character that is not
// alphanumeric, or whose value is not below `base`, contributes 0.
UINT_64 stoui64(const std::string::const_iterator &begin,
                const std::string::const_iterator &end, const UINT_32 base = 10)
{
    UINT_64 result = 0;
    for (auto it = begin; it < end && it - begin < 19; ++it)
    {
        result *= base;
        // The <cctype> functions require a value representable as unsigned
        // char; passing a negative plain char is undefined behaviour.
        const int lowered = tolower(static_cast<unsigned char>(*it));
        UINT_64 n = 0;
        if (isalnum(lowered))
        {
            if (isalpha(lowered))
            {
                n = lowered - 'a' + 10;
            }
            else
            {
                n = lowered - '0';
            }
        }
        if (n < base)
        {
            result += n;
        }
    }
    return result;
}
// Convenience overload: parses the whole string with the iterator-based stoui64.
UINT_64 stoui64(const std::string &str, const UINT_32 base = 10)
{
    const std::string::const_iterator first = str.begin();
    const std::string::const_iterator last = str.end();
    return stoui64(first, last, base);
}
// Heap-backed array that packs a sign flag and a length into one SIZE_TYPE.
//   T         - element type
//   SIZE_TYPE - unsigned type used for capacity and for the packed sign/length
// The top bit of sign_n_len is the sign flag; the remaining bits are the
// length. The length never exceeds the capacity and the sign flag is cleared
// whenever the length is zero.
template <typename T = UINT_32, typename SIZE_TYPE = UINT_32>
class HintVector
{
private:
    T *ary_ptr = nullptr;     // owned storage, nullptr when nothing is allocated
    SIZE_TYPE sign_n_len = 0; // sign flag (top bit) and length (remaining bits)
    SIZE_TYPE size = 0;       // capacity in elements

public:
    static constexpr SIZE_TYPE SIZE_TYPE_BITS = sizeof(SIZE_TYPE) * 8;   // bit width of the size/length members
    static constexpr SIZE_TYPE SIZE_80 = (1ull << (SIZE_TYPE_BITS - 1)); // only the top bit set
    static constexpr SIZE_TYPE LEN_MAX = SIZE_80 - 1;                    // largest representable length
    static constexpr SIZE_TYPE SIZE_FF = SIZE_80 + LEN_MAX;              // all bits set

    ~HintVector()
    {
        if (ary_ptr != nullptr)
        {
            delete[] ary_ptr;
            ary_ptr = nullptr;
        }
    }
    // Empty vector: no storage, length 0.
    constexpr HintVector()
    {
        ary_ptr = nullptr;
    }
    // Allocates capacity for at least ele_size elements; the length stays 0.
    HintVector(SIZE_TYPE ele_size)
    {
        if (ele_size > 0)
        {
            resize(ele_size);
        }
    }
    // Allocates storage, sets the length to ele_size and fills with `ele`.
    HintVector(SIZE_TYPE ele_size, T ele)
    {
        if (ele_size > 0)
        {
            resize(ele_size);
            change_length(ele_size);
            fill_element(ele, ele_size);
        }
    }
    // Copies `len` elements from `ary`. Throws a C string when ary is nullptr.
    HintVector(const T *ary, SIZE_TYPE len)
    {
        if (ary == nullptr)
        {
            throw("Can't copy from nullptr\n");
        }
        if (len > 0)
        {
            resize(len);
            change_length(len);
            ary_copy(ary_ptr, ary, len);
        }
    }
    // Copy constructor: duplicates length, sign and elements. Copying a
    // vector without storage yields an empty vector.
    HintVector(const HintVector &input)
    {
        if (input.ary_ptr != nullptr)
        {
            size_t len = input.length();
            resize(len);
            change_length(len);
            change_sign(input.sign());
            ary_copy(ary_ptr, input.ary_ptr, len);
        }
    }
    // Move constructor: takes over the storage and leaves `input` empty.
    HintVector(HintVector &&input) noexcept
    {
        if (input.ary_ptr != nullptr)
        {
            ary_ptr = input.ary_ptr;
            size = input.size;
            sign_n_len = input.sign_n_len;
        }
        // Reset the source completely so its length can never refer to
        // storage it no longer owns.
        input.ary_ptr = nullptr;
        input.size = 0;
        input.sign_n_len = 0;
    }
    // Copy assignment: replaces the contents with a copy of `input`.
    HintVector &operator=(const HintVector &input)
    {
        if (this != &input)
        {
            if (ary_ptr != nullptr)
            {
                delete[] ary_ptr;
            }
            ary_ptr = nullptr;
            size_t len = input.length();
            resize(len);
            change_length(len);
            change_sign(input.sign());
            if (len > 0)
            {
                ary_copy(ary_ptr, input.ary_ptr, len);
            }
        }
        return *this;
    }
    // Move assignment: releases the current storage, takes over the storage
    // of `input` and leaves `input` empty.
    HintVector &operator=(HintVector &&input) noexcept
    {
        if (this != &input)
        {
            delete[] ary_ptr;
            if (input.ary_ptr != nullptr)
            {
                ary_ptr = input.ary_ptr;
                size = input.size;
                sign_n_len = input.sign_n_len;
            }
            else
            {
                ary_ptr = nullptr;
                size = 0;
                sign_n_len = 0;
            }
            input.ary_ptr = nullptr;
            input.size = 0;
            input.sign_n_len = 0;
        }
        return *this;
    }
    // Replaces the contents with a converted copy of a vector that may have a
    // different element or size type. Terminates the program when the source
    // is longer than LEN_MAX.
    template <typename Ty, typename SIZE_Ty>
    void copy_from(const HintVector<Ty, SIZE_Ty> &input)
    {
        // The two objects may have unrelated types, so compare addresses
        // through void pointers.
        if (static_cast<const void *>(this) != static_cast<const void *>(&input))
        {
            if (input.length() > LEN_MAX)
            {
                std::cerr << "HintVector err: too long\n";
                exit(EXIT_FAILURE);
            }
            delete[] ary_ptr;
            ary_ptr = nullptr;
            SIZE_TYPE len = input.length();
            resize(len);
            change_length(len);
            change_sign(input.sign());
            if (len > 0)
            {
                // Use the typed pointer: a void* source cannot be read element-wise.
                ary_copy_2type(ary_ptr, input.type_ptr(), len);
            }
        }
    }
    // Overwrites the first min(len, length()) elements from a raw array.
    // Throws a C string when input is nullptr.
    template <typename Ty>
    void copy_from(const Ty *input, SIZE_TYPE len)
    {
        if (input == nullptr)
        {
            throw("Can't copy from nullptr\n");
        }
        len = std::min(len, length());
        ary_copy(ary_ptr, input, len);
    }
    // First element, or 0 when the length is zero.
    T front() const
    {
        if (length() == 0)
        {
            return 0;
        }
        return ary_ptr[0];
    }
    // Last element within the current length, or 0 when the length is zero.
    T back() const
    {
        if (length() == 0)
        {
            return 0;
        }
        return ary_ptr[length() - 1];
    }
    // Unchecked element access.
    T &operator[](SIZE_TYPE index) const
    {
        return ary_ptr[index];
    }
    // Untyped pointer to the storage; terminates the program when nothing is allocated.
    void *raw_ptr() const
    {
        if (ary_ptr == nullptr)
        {
            std::cerr << "HintVector err: Get a ptr from an empty vector\n";
            exit(EXIT_FAILURE);
        }
        return ary_ptr;
    }
    // Typed pointer to the storage; terminates the program when nothing is allocated.
    T *type_ptr() const
    {
        if (ary_ptr == nullptr)
        {
            std::cerr << "HintVector err: Get a ptr from an empty vector\n";
            exit(EXIT_FAILURE);
        }
        return ary_ptr;
    }
    // Default capacity policy: requests up to 2 give 2, up to 4 give 4;
    // larger requests give min_2pow(request), or three quarters of it when
    // that is already enough. The request is capped at LEN_MAX.
    static SIZE_TYPE size_generator(SIZE_TYPE new_size)
    {
        new_size = std::min(new_size, LEN_MAX);
        if (new_size <= 2)
        {
            return 2;
        }
        else if (new_size <= 4)
        {
            return 4;
        }
        SIZE_TYPE size1 = min_2pow(new_size);
        SIZE_TYPE size2 = size1 / 2;
        size2 = size2 + size2 / 2;
        if (new_size <= size2)
        {
            return size2;
        }
        else
        {
            return size1;
        }
    }
    // Prints the elements from the last to the first, tab separated.
    void print() const
    {
        SIZE_TYPE i = length();
        if (i == 0 || ary_ptr == nullptr)
        {
            std::cout << "Empty vector" << std::endl;
        }
        while (i > 0)
        {
            i--;
            std::cout << ary_ptr[i] << "\t";
        }
        std::cout << std::endl;
    }
    // Allocated capacity in elements.
    SIZE_TYPE capacity() const
    {
        return size;
    }
    // Current length (sign bit masked off).
    SIZE_TYPE length() const
    {
        return sign_n_len & LEN_MAX;
    }
    // True when the sign flag is set.
    bool sign() const
    {
        return (SIZE_80 & sign_n_len) != 0;
    }
    // Changes the capacity to size_func(new_size). A first allocation resets
    // the length to 0; a reallocation goes through ary_realloc and clamps the
    // length to the new capacity.
    constexpr void resize(SIZE_TYPE new_size, SIZE_TYPE(size_func)(SIZE_TYPE) = size_generator)
    {
        new_size = size_func(new_size);
        if (ary_ptr == nullptr)
        {
            size = new_size;
            ary_ptr = new T[size];
            change_length(0);
        }
        else if (size != new_size)
        {
            size = new_size;
            ary_ptr = ary_realloc(ary_ptr, size);
            change_length(length());
        }
    }
    // Sets the length (clamped to the capacity) and keeps the sign flag.
    // Throws a C string when new_length exceeds LEN_MAX.
    void change_length(SIZE_TYPE new_length)
    {
        if (new_length > LEN_MAX)
        {
            throw("Length too long\n");
        }
        new_length = std::min(new_length, size);
        bool sign_tmp = sign();
        sign_n_len = new_length;
        change_sign(sign_tmp);
    }
    // Sets or clears the sign flag; a zero-length vector is never signed.
    void change_sign(bool is_sign)
    {
        if ((!is_sign) || length() == 0)
        {
            sign_n_len = sign_n_len & LEN_MAX;
        }
        else
        {
            sign_n_len = sign_n_len | SIZE_80;
        }
    }
    // Shrinks the length past trailing zero elements and returns the new length.
    SIZE_TYPE set_true_len()
    {
        if (ary_ptr == nullptr)
        {
            return 0;
        }
        SIZE_TYPE t_len = length();
        while (t_len > 0 && ary_ptr[t_len - 1] == 0)
        {
            t_len--;
        }
        change_length(t_len);
        return length();
    }
    // Writes `ele` into `count` slots starting at `begin`, clipped to the capacity.
    void fill_element(T ele, SIZE_TYPE count, SIZE_TYPE begin = 0)
    {
        if (begin >= size || ary_ptr == nullptr)
        {
            return;
        }
        // Compare against the remaining room so begin + count cannot overflow.
        if (count > size - begin)
        {
            count = size - begin;
        }
        std::fill(ary_ptr + begin, ary_ptr + begin + count, ele);
    }
    // Zeroes the whole allocated storage; the length is left unchanged.
    void clear()
    {
        if (ary_ptr != nullptr)
        {
            ary_clr(ary_ptr, size);
        }
    }
};
}
#endif
#ifndef HINT_MINI_HPP
#define HINT_MINI_HPP
namespace hint_arithm
{
// Make the aliases and helpers of namespace hint visible in hint_arithm.
using namespace hint;
// Size/length type used for the vectors in this namespace.
using SIZE_TYPE = UINT_32;
// HintVector whose size/length fields use SIZE_TYPE.
template <typename T>
using hintvector = HintVector<T, SIZE_TYPE>;
// Magnitude comparison of two multi-digit numbers stored least significant
// digit first. Returns 1 when the first is larger, -1 when it is smaller and
// 0 when they are equal. Lengths are first reduced with ary_true_len.
template <typename T>
constexpr INT_32 abs_compare(const T ary1[], const T ary2[], size_t len1, size_t len2)
{
    const size_t digits1 = ary_true_len(ary1, len1);
    const size_t digits2 = ary_true_len(ary2, len2);
    if (digits1 > digits2)
    {
        return 1;
    }
    if (digits1 < digits2)
    {
        return -1;
    }
    if (ary1 == ary2)
    {
        return 0; // same storage, same length
    }
    // Walk down from the most significant digit to the first difference.
    for (size_t idx = digits1; idx > 0;)
    {
        --idx;
        const T lhs = ary1[idx];
        const T rhs = ary2[idx];
        if (lhs != rhs)
        {
            return lhs > rhs ? 1 : -1;
        }
    }
    return 0;
}
// Compares ary1 with ary2 shifted left by `shift` digits (ary2 * base^shift).
// Returns 1 when ary1 is larger, -1 when it is smaller and 0 when equal.
template <typename T>
constexpr INT_32 abs_compare_shift(const T ary1[], const T ary2[], size_t len1, size_t len2, size_t shift = 0)
{
    len1 = ary_true_len(ary1, len1);
    len2 = ary_true_len(ary2, len2);
    if (len1 != len2 + shift)
    {
        return len1 > (len2 + shift) ? 1 : -1;
    }
    // Only len1 - shift digits remain above ary1 + shift; passing len1 here
    // would make the callee scan `shift` digits past the end of ary1.
    INT_32 cmp = abs_compare(ary1 + shift, ary2, len1 - shift, len2);
    if (cmp != 0)
    {
        return cmp;
    }
    // The high parts are equal: any non-zero low digit makes ary1 larger.
    for (size_t i = 0; i < shift; i++)
    {
        if (ary1[i] > 0)
        {
            return 1;
        }
    }
    return 0;
}
// Multi-digit addition in the given base, least significant digit first.
//   is_carry   - when true the final carry is stored in out[max(len1, len2)]
//   in1, in2   - operands of len1 and len2 digits
//   out        - receives max(len1, len2) digits, plus one when is_carry
//   base       - radix of the digits
template <bool is_carry = true, typename T>
constexpr void abs_add(const T in1[], const T in2[], T out[],
                       size_t len1, size_t len2, const UINT_64 base)
{
    // Make in1 the longer operand.
    if (len1 < len2)
    {
        std::swap(in1, in2);
        std::swap(len1, len2);
    }
    size_t pos = 0;
    UINT_64 carry = 0;
    while (pos < len2)
    {
        // Widen before adding so two 32-bit digits cannot wrap around.
        carry += static_cast<UINT_64>(in1[pos]) + in2[pos];
        out[pos] = carry < base ? carry : carry - base;
        carry = carry < base ? 0 : 1;
        pos++;
    }
    // Propagate the carry through the longer operand while one is pending.
    while (pos < len1 && carry > 0)
    {
        carry += in1[pos];
        out[pos] = carry < base ? carry : carry - base;
        carry = carry < base ? 0 : 1;
        pos++;
    }
    // Remaining high digits are unaffected.
    ary_copy(out + pos, in1 + pos, len1 - pos);
    if (is_carry)
    {
        out[len1] = carry % base;
    }
}
// 高精度减法
template <typename T>
constexpr void abs_sub(const T in1[], const T in2[], T out[],
size_t len1, size_t len2, const INT_64 base)
{
if (len1 < len2)
{
return;
}
size_t pos = 0;
INT_64 borrow = 0;
while (pos < len2)
{
borrow += (static_cast<INT_64>(in1[pos]) - in2[pos]);
out[pos] = borrow < 0 ? borrow + base : borrow;
borrow = borrow < 0 ? -1 : 0;
pos++;
}
while (pos < len1 && borrow < 0)
{
borrow += in1[pos];
out[pos] = borrow < 0 ? borrow + base : borrow;
borrow = borrow < 0 ? -1 : 0;
pos++;
}
ary_copy(out + pos, in1 + pos, len1 - pos);
}
// Multi-word addition with full 64-bit words, least significant word first.
//   in1, in2 - operands of len1 and len2 words
//   out      - receives max(len1, len2) + 1 words; the last one is the carry
constexpr void abs_add64(const UINT_64 in1[], const UINT_64 in2[], UINT_64 out[],
                         size_t len1, size_t len2)
{
    // Make in1 the longer operand.
    if (len1 < len2)
    {
        std::swap(in1, in2);
        std::swap(len1, len2);
    }
    size_t pos = 0;
    UINT_64 carry = 0;
    while (pos < len2)
    {
        bool is_carry1 = false, is_carry2 = false;
        std::tie(carry, is_carry1) = safe_add(carry, in1[pos]);
        std::tie(carry, is_carry2) = safe_add(carry, in2[pos]);
        out[pos] = carry;
        carry = is_carry1 || is_carry2 ? 1 : 0;
        pos++;
    }
    while (pos < len1)
    {
        bool is_carry = false;
        std::tie(carry, is_carry) = safe_add(carry, in1[pos]);
        out[pos] = carry;
        // The next carry is the overflow flag, not the (usually non-zero) sum.
        carry = is_carry ? 1 : 0;
        pos++;
    }
    out[len1] = carry;
}
// 64位多精度减法
constexpr void abs_sub64(const UINT_64 in1[], const UINT_64 in2[], UINT_64 out[],
size_t len1, size_t len2)
{
if (len1 < len2)
{
std::swap(in1, in2);
std::swap(len1, len2);
}
size_t pos = 0;
UINT_64 carry = 0;
while (pos < len2)
{
bool is_carry1 = false, is_carry2 = false;
std::tie(carry, is_carry1) = safe_add(carry, in1[pos]);
std::tie(carry, is_carry2) = safe_add(carry, in2[pos]);
out[pos] = carry;
carry = is_carry1 || is_carry2 ? 1 : 0;
pos++;
}
while (pos < len1)
{
bool is_carry = false;
std::tie(carry, is_carry) = safe_add(carry, in1[pos]);
out[pos] = carry;
carry = carry ? 1 : 0;
pos++;
}
out[len1] = carry;
}
// Schoolbook multiplication in the given base, least significant digit first.
//   in1, in2 - operands of len1 and len2 digits
//   out      - receives len1 + len2 digits; may be the same array as an operand
// Returns without writing when the longer operand is empty or a pointer is null.
template <typename T>
void normal_mul(const T in1[], const T in2[], T out[],
                size_t len1, size_t len2, const UINT_64 base)
{
    // Make in1 the longer operand.
    if (len1 < len2)
    {
        std::swap(in1, in2);
        std::swap(len1, len2);
    }
    if (len1 == 0 || in1 == nullptr || in2 == nullptr)
    {
        return;
    }
    const size_t prod_len = len1 + len2;
    // Accumulate into scratch storage when the output overlaps an operand.
    const bool aliased = (in1 == out) || (in2 == out);
    T *acc = aliased ? new T[prod_len] : out;
    ary_clr(acc, prod_len);
    for (size_t i = 0; i < len1; i++)
    {
        const UINT_64 digit1 = in1[i];
        for (size_t j = 0; j < len2; j++)
        {
            UINT_64 tmp = digit1 * in2[j];
            // Add the partial product at position i + j and ripple the carry up.
            size_t k = i + j;
            while (tmp > 0 && k < prod_len)
            {
                tmp += acc[k];
                std::tie(tmp, acc[k]) = div_mod(tmp, base);
                ++k;
            }
        }
    }
    if (aliased)
    {
        ary_copy(out, acc, prod_len);
        delete[] acc;
    }
}
// fft加速乘法
template <typename T>
void fft_mul(const T in1[], const T in2[], T out[],
size_t len1, size_t len2, const UINT_64 base)
{
if (len1 == 0 || len2 == 0 || in1 == nullptr || in2 == nullptr)
{
return;
}
size_t conv_res_len = len1 + len2 - 1; // 卷积结果长度
size_t fft_len = min_2pow(conv_res_len); // fft长度
Complex *fft_ary = new Complex[fft_len];
com_ary_combine_copy(fft_ary, in1, len1, in2, len2);
hint_transform::fft_dif(fft_ary, fft_len, false);
for (size_t i = 0; i < fft_len; i++)
{
Complex tmp = fft_ary[i];
fft_ary[i] = std::conj(tmp * tmp);
}
hint_transform::fft_dit(fft_ary, fft_len, false);
hint::UINT_64 carry = 0;
double inv = 2 * fft_len;
for (size_t i = 0; i < conv_res_len; i++)
{
carry += static_cast<hint::UINT_64>(-fft_ary[i].imag() / inv + 0.5);
std::tie(carry, out[i]) = div_mod<UINT_64>(carry, base);
}
out[conv_res_len] = carry % base;
delete[] fft_ary;
}
// Squaring through fht_convolution in the given base.
// One buffer of 2 * fht_len doubles is used: the first half holds the
// zero-padded input, the second half receives the convolution.
//   in  - operand of len digits
//   out - receives 2 * len digits
// Returns without writing when len is zero or in is null.
template <typename T>
void fht_sqr(const T in[], T out[], size_t len, const UINT_64 base)
{
    if (len == 0 || in == nullptr)
    {
        return;
    }
    size_t conv_res_len = len * 2 - 1;       // length of the convolution result
    size_t fht_len = min_2pow(conv_res_len); // transform length
    double *work = new double[fht_len * 2];
    double *conv = work + fht_len;
    ary_clr(work, fht_len);
    ary_copy_2type(work, in, len);
    fht_convolution(work, work, conv, fht_len);
    hint::UINT_64 carry = 0;
    size_t idx = 0;
    while (idx < conv_res_len)
    {
        carry += static_cast<hint::UINT_64>(work[idx + fht_len] + 0.5);
        std::tie(carry, out[idx]) = div_mod(carry, base);
        ++idx;
    }
    out[conv_res_len] = carry % base;
    delete[] work;
}
// ntt加速乘法
template <typename T>
void ntt_mul(const T in1[], const T in2[], T out[],
size_t len1, size_t len2, const UINT_64 base)
{
if (len1 == 0 || len2 == 0 || in1 == nullptr || in2 == nullptr)
{
return;
}
size_t conv_res_len = len1 + len2 - 1; // 卷积结果长度
size_t ntt_len = min_2pow(conv_res_len); // ntt长度
UINT_32 *ntt_ary1 = new UINT_32[ntt_len * 4];
ary_clr(ntt_ary1, ntt_len * 4);
UINT_32 *ntt_ary2 = ntt_ary1 + ntt_len;
UINT_32 *ntt_ary3 = ntt_ary1 + ntt_len * 2;
UINT_32 *ntt_ary4 = ntt_ary1 + ntt_len * 3;
hint::ary_copy_2type(ntt_ary1, in1, len1);
hint::ary_copy_2type(ntt_ary2, in1, len1);
hint::ary_copy_2type(ntt_ary3, in2, len2);
hint::ary_copy_2type(ntt_ary4, in2, len2);
constexpr UINT_64 mod1 = NTT_MOD1, mod2 = NTT_MOD2;
constexpr UINT_64 root1 = NTT_ROOT1, root2 = NTT_ROOT2;
hint_transform::ntt_dif<mod1, root1>(ntt_ary1, ntt_len, false);
hint_transform::ntt_dif<mod2, root2>(ntt_ary2, ntt_len, false);
hint_transform::ntt_dif<mod1, root1>(ntt_ary3, ntt_len, false);
hint_transform::ntt_dif<mod2, root2>(ntt_ary4, ntt_len, false);
hint::ary_mul_mod<mod1>(ntt_ary1, ntt_ary3, ntt_ary1, ntt_len);
hint::ary_mul_mod<mod2>(ntt_ary2, ntt_ary4, ntt_ary2, ntt_len);
hint_transform::intt_dit<mod1, root1>(ntt_ary1, ntt_len, false);
hint_transform::intt_dit<mod2, root2>(ntt_ary2, ntt_len, false);
constexpr UINT_64 inv1 = qpow(mod1, mod2 - 2, mod2);
constexpr UINT_64 inv2 = qpow(mod2, mod1 - 2, mod1);
hint::UINT_64 carry = 0;
for (size_t i = 0; i < conv_res_len; i++)
{
carry += qcrt(ntt_ary1[i], ntt_ary2[i], mod1, mod2, inv1, inv2);
std::tie(carry, out[i]) = div_mod(carry, base);
}
out[conv_res_len] = carry % base;
delete[] ntt_ary1;
}
// NTT-accelerated squaring in the given base.
// The operand is transformed under the two moduli NTT_MOD1 and NTT_MOD2,
// each spectrum is squared element-wise, transformed back, combined with
// qcrt and carried into `out`.
//   in  - operand of len digits
//   out - receives 2 * len digits
// Returns without writing when len is zero or in is null.
template <typename T>
void ntt_sqr(const T in[], T out[], size_t len, const UINT_64 base)
{
    if (len == 0 || in == nullptr)
    {
        return;
    }
    size_t conv_res_len = len * 2 - 1;       // length of the convolution result
    size_t ntt_len = min_2pow(conv_res_len); // transform length
    UINT_32 *ntt_ary1 = new UINT_32[ntt_len * 2];
    ary_clr(ntt_ary1, ntt_len * 2);
    UINT_32 *ntt_ary2 = ntt_ary1 + ntt_len;
    hint::ary_copy_2type(ntt_ary1, in, len);
    hint::ary_copy(ntt_ary2, ntt_ary1, len);
    constexpr UINT_64 mod1 = NTT_MOD1, mod2 = NTT_MOD2;
    constexpr UINT_64 root1 = NTT_ROOT1, root2 = NTT_ROOT2;
    hint_transform::ntt_dif<mod1, root1>(ntt_ary1, ntt_len);
    hint_transform::ntt_dif<mod2, root2>(ntt_ary2, ntt_len);
    hint::ary_mul_mod<mod1>(ntt_ary1, ntt_ary1, ntt_ary1, ntt_len);
    hint::ary_mul_mod<mod2>(ntt_ary2, ntt_ary2, ntt_ary2, ntt_len);
    // intt_dit inverts the root it is given, so it must receive the same
    // roots as the forward transform (as ntt_mul does). Passing pre-inverted
    // roots would invert them twice and run a second forward transform.
    hint_transform::intt_dit<mod1, root1>(ntt_ary1, ntt_len);
    hint_transform::intt_dit<mod2, root2>(ntt_ary2, ntt_len);
    constexpr UINT_64 inv1 = qpow(mod1, mod2 - 2, mod2);
    constexpr UINT_64 inv2 = qpow(mod2, mod1 - 2, mod1);
    hint::UINT_64 carry = 0;
    for (size_t i = 0; i < conv_res_len; i++)
    {
        carry += qcrt(ntt_ary1[i], ntt_ary2[i], mod1, mod2, inv1, inv2);
        std::tie(carry, out[i]) = div_mod(carry, base);
    }
    out[conv_res_len] = carry % base;
    delete[] ntt_ary1;
}
// Karatsuba multiplication in the given base, least significant digit first.
// Products whose convolution length fits in NTT_MAX_LEN are handed to
// ntt_sqr / ntt_mul; larger ones are split into halves and handled with
// three recursive multiplications.
//   in1, in2 - operands of len1 and len2 digits
//   out      - receives len1 + len2 digits on the recursive path
//   NTT_Ty   - digit type handed to the NTT routines
// Returns without writing when the shorter operand is empty or a pointer is null.
template <typename T, typename NTT_Ty = UINT_16>
constexpr void karatsuba_mul(const T in1[], const T in2[], T out[],
size_t len1, size_t len2, const INT_64 base)
{
// (a*base^n+b)*(c*base^n+d) = a*c*base^2n+(a*d+b*c)*base^n+b*d
// compute: a*c,b*d,(a+b)*(c+d),a*d+b*c = (a+b)*(c+d)-a*c-b*d
len1 = ary_true_len(in1, len1);
len2 = ary_true_len(in2, len2);
// Make in1 the longer operand.
if (len1 < len2)
{
std::swap(in1, in2);
std::swap(len1, len2);
}
// std::cin.get();
if (len2 == 0 || in1 == nullptr || in2 == nullptr)
{
return;
}
if (len1 + len2 - 1 <= NTT_MAX_LEN)
{
// Lengths are scaled by sizeof(T) / sizeof(NTT_Ty) before the NTT call.
// NOTE(review): in1/in2/out are T pointers passed to routines instantiated
// for NTT_Ty - confirm this is only used with T == NTT_Ty.
const size_t ntt_len1 = len1 * std::max<size_t>(1, sizeof(T) / sizeof(NTT_Ty));
const size_t ntt_len2 = len2 * std::max<size_t>(1, sizeof(T) / sizeof(NTT_Ty));
if (in1 == in2 && len1 == len2)
{
ntt_sqr<NTT_Ty>(in1, out, ntt_len1, base);
}
else
{
ntt_mul<NTT_Ty>(in1, in2, out, ntt_len1, ntt_len2, base);
}
return;
}
size_t len_a = len1 / 2;
size_t len_b = len1 - len_a; // common (low part) length
size_t len_c = len2 > len_b ? len2 - len_b : 1;
size_t len_d = len2 > len_b ? len_b : len2;
size_t len_ac = len_c > 0 ? len_a + len_c : 0; // length of a*c
size_t len_bd = len_b + len_d; // length of b*d
size_t len_add_mul = len_b + len_d + 2; // length of (a+b)*(c+d)
const T *a_ptr = in1 + len_b; // in1 itself stands for b
const T *c_ptr = in2 + len_d; // in2 itself stands for d
hintvector<T> mul_ac(len_ac, 0); // holds a*c
hintvector<T> mul_bd(len_bd, 0); // holds b*d
hintvector<T> add_mul(len_add_mul, 0); // holds a+b and c+d; a+b takes len_b+1 digits
T *add_ab = add_mul.type_ptr();
T *add_cd = add_ab + len_b + 1;
abs_add(in1, a_ptr, add_ab, len_b, len_a, base); // b+a
abs_add(in2, c_ptr, add_cd, len_d, len_c, base); // d+c
size_t len_add_ab = ary_true_len(add_ab, len_b + 1);
size_t len_add_cd = ary_true_len(add_cd, len_d + 1);
karatsuba_mul(a_ptr, c_ptr, mul_ac.type_ptr(), len_a, len_c, base); // a*c
karatsuba_mul(in1, in2, mul_bd.type_ptr(), len_b, len_d, base); // b*d
karatsuba_mul(add_ab, add_cd, add_mul.type_ptr(), len_b + 1, len_d + 1, base); // (a+b)*(c+d)
add_mul.change_length(len_add_ab + len_add_cd);
len_ac = mul_ac.set_true_len();
len_bd = mul_bd.set_true_len();
len_add_mul = add_mul.set_true_len();
ary_clr(out, len1 + len2);
ary_copy(out, mul_bd.type_ptr(), len_bd); // add b*d to the result
ary_copy(out + len_b * 2, mul_ac.type_ptr(), len_ac);
// Add (a+b)*(c+d) - a*c - b*d at digit offset len_b, with a signed carry.
INT_64 carry = 0;
for (size_t pos = len_b; pos < len1 + len2; pos++)
{
size_t t_pos = pos - len_b;
carry += out[pos];
if (t_pos < len_add_mul)
{
carry += add_mul[t_pos];
}
if (t_pos < len_ac)
{
carry -= mul_ac[t_pos];
}
if (t_pos < len_bd)
{
carry -= mul_bd[t_pos];
}
INT_64 rem = carry % base;
// Floor division so a negative carry borrows from the next digit.
carry = carry < 0 ? (carry + 1) / base - 1 : carry / base;
out[pos] = rem < 0 ? rem + base : rem;
}
}
// Multi-digit multiplication dispatcher.
// Picks normal_mul, fft_mul, ntt_mul or karatsuba_mul from the significant
// lengths of the operands.
//   in1, in2 - operands of len1 and len2 digits
//   out      - receives the product
template <typename T>
constexpr void abs_mul(const T in1[], const T in2[], T out[],
                       size_t len1, size_t len2, const INT_64 base)
{
    len1 = ary_true_len(in1, len1);
    len2 = ary_true_len(in2, len2);
    const size_t digit_sum = len1 + len2;
    // Short or very unbalanced operands: schoolbook multiplication.
    if (digit_sum <= 48 || len1 * len2 < digit_sum * std::log2(digit_sum))
    {
        normal_mul(in1, in2, out, len1, len2, base);
        return;
    }
    if (digit_sum - 1 <= (1 << hint_transform::lut_max_rank))
    {
        fft_mul(in1, in2, out, len1, len2, base);
        return;
    }
    if (digit_sum - 1 <= NTT_MAX_LEN)
    {
        ntt_mul(in1, in2, out, len1, len2, base);
        return;
    }
    karatsuba_mul(in1, in2, out, len1, len2, base);
}
// Multi-digit squaring dispatcher.
// Picks normal_mul, fht_sqr, ntt_sqr or karatsuba_mul from the significant
// length of the operand.
//   in  - operand of len digits
//   out - receives the square
template <typename T>
constexpr void abs_sqr(const T in[], T out[], size_t len, const INT_64 base)
{
    len = ary_true_len(in, len);
    if (len <= 24)
    {
        normal_mul(in, in, out, len, len, base);
        return;
    }
    if (len * 2 - 1 <= (1 << hint_transform::lut_max_rank))
    {
        fht_sqr(in, out, len, base);
        return;
    }
    if (len * 2 - 1 <= NTT_MAX_LEN)
    {
        ntt_sqr(in, out, len, base);
        return;
    }
    karatsuba_mul(in, in, out, len, len, base);
}
// Multiplies a multi-digit number by a single digit in the given base.
//   is_carry - when true the final carry is stored in out[significant length]
//   in       - operand; its significant length is taken with ary_true_len
//   num      - multiplier, reduced modulo base before use
//   out      - receives the product digits
template <bool is_carry = true, typename T>
constexpr void abs_mul_num(const T in[], T num, T out[], size_t len, const UINT_64 base)
{
    len = ary_true_len(in, len);
    num %= base;
    UINT_64 prod = 0;
    if (num != 1)
    {
        size_t idx = 0;
        while (idx < len)
        {
            prod += static_cast<UINT_64>(in[idx]) * num;
            std::tie(prod, out[idx]) = div_mod<UINT_64>(prod, base);
            ++idx;
        }
    }
    else
    {
        // Multiplying by one is a plain copy and leaves no carry.
        ary_copy(out, in, len);
    }
    if (is_carry)
    {
        out[len] = prod;
    }
}
// Divides a multi-digit number by a single digit and returns the remainder.
//   in   - dividend of len digits, least significant digit first
//   num  - divisor, reduced modulo base before use
//   out  - receives len quotient digits; may be the same array as `in`
//   base - radix of the digits
// Throws a C string when the reduced divisor is zero.
template <typename T>
constexpr INT_64 abs_div_num(const T in[], T num, T out[], size_t len, const UINT_64 base)
{
    size_t pos = ary_true_len(in, len);
    num %= base;
    if (num == 0)
    {
        throw("Can't divide by zero\n");
    }
    if (num == 1)
    {
        ary_copy(out, in, len);
        return 0;
    }
    // Quotient digits above the significant length are zero; write them so a
    // separate output buffer does not keep stale values there.
    for (size_t i = pos; i < len; i++)
    {
        out[i] = 0;
    }
    UINT_64 last_rem = 0;
    // Long division from the most significant digit down.
    while (pos > 0)
    {
        pos--;
        last_rem = last_rem * base + in[pos];
        std::tie(out[pos], last_rem) = div_mod<UINT_64>(last_rem, num);
    }
    return last_rem;
}
// Normalises a divisor so that its top digit is at least base / 2.
// Writes the scaled divisor to `out` and returns the multiplier that was
// applied (1 when the divisor was already normalised).
// Throws a C string when `in` and `out` are the same array.
template <typename T>
constexpr T divisor_normalize(const T in[], T out[], size_t len, const UINT_64 base)
{
    if (in == out)
    {
        throw("In can't be same as out\n");
    }
    if (len == 1)
    {
        // Single digit: scale it as close to base - 1 as possible.
        const T factor = (base - 1) / in[0];
        out[0] = in[0] * factor;
        return factor;
    }
    if (in[len - 1] >= (base / 2))
    {
        ary_copy(out, in, len);
        return 1;
    }
    // Estimate the multiplier from the two leading digits.
    T factor = (base - 1) * base / (base * in[len - 1] + in[len - 2]);
    abs_mul_num<false>(in, factor, out, len, base);
    if (out[len - 1] < (base / 2))
    {
        // The estimate was one too small: add the divisor once more.
        factor++;
        abs_add<false>(out, in, out, len, len, base);
    }
    return factor;
}
// 长除法,从被除数返回余数,需要确保除数的规则化
template <typename T>
void abs_long_div(T dividend[], const T divisor[], T quot[],
size_t len1, size_t len2, const UINT_64 base)
{
len1 = ary_true_len(dividend, len1);
len2 = ary_true_len(divisor, len2);
if (divisor == nullptr || len2 == 0)
{
throw("Can't divide by zero\n");
}
if (dividend == nullptr || len1 == 0)
{
return;
}
if (abs_compare(dividend, divisor, len1, len2) < 0)
{
return;
}
if (len2 == 1)
{
T rem = abs_div_num(dividend, divisor[0], quot, len1, base);
dividend[0] = rem;
return;
}
if (divisor[len2 - 1] < (base / 2))
{
throw("Can't call this proc before normalize the divisor\n");
}
quot[len1 - len2] = 0;
const UINT_64 divisor_2digits = base * divisor[len2 - 1] + divisor[len2 - 2];
hintvector<T> sub(len2 + 1);
sub.change_length(len2 + 1);
// 被除数(余数大于等于除数则继续减)
while (abs_compare(dividend, divisor, len1, len2) >= 0)
{
sub.change_length(len2 + 1);
UINT_64 dividend_2digits = dividend[len1 - 1] * base + dividend[len1 - 2];
T quo_digit = 0;
size_t shift = len1 - len2;
// 被除数前两位大于等于除数前两位试商的结果偏差不大于1
if (dividend_2digits >= divisor_2digits)
{
quo_digit = dividend_2digits / divisor_2digits;
abs_mul_num(divisor, quo_digit, sub.type_ptr(), len2, base);
sub.set_true_len();
size_t sub_len = sub.length();
if (abs_compare(dividend + shift, sub.type_ptr(), len1 - shift, sub_len) < 0)
{
quo_digit--;
abs_sub(sub.type_ptr(), divisor, sub.type_ptr(), sub_len, len2, base);
}
}
else
{
// 被除数前两位和除数前一位试商的结果偏差不大于2
quo_digit = dividend_2digits / (divisor_2digits / base);
if (quo_digit >= base)
{
quo_digit = base - 1;
}
shift--;
abs_mul_num(divisor, quo_digit, sub.type_ptr(), len2, base);
sub.set_true_len();
size_t sub_len = sub.length();
if (abs_compare(dividend + shift, sub.type_ptr(), len1 - shift, sub_len) < 0)
{
quo_digit--;
abs_sub(sub.type_ptr(), divisor, sub.type_ptr(), sub_len, len2, base);
if (abs_compare(dividend + shift, sub.type_ptr(), len1 - shift, sub_len) < 0)
{
quo_digit--;
abs_sub(sub.type_ptr(), divisor, sub.type_ptr(), sub_len, len2, base);
}
}
}
abs_sub(dividend + shift, sub.type_ptr(), dividend + shift, len1, sub.length(), base);
len1 = ary_true_len(dividend, len1);
quot[shift] = quo_digit;
}
}
// Recursive division. The quotient is written to `quot` and the remainder is
// left in `dividend`. The divisor must be normalised (top digit at least
// base / 2).
//   dividend - len1 digits, least significant first; overwritten by the remainder
//   divisor  - len2 digits
//   quot     - receives the quotient; its length is set to len1 - len2 + 1
//   c        - value passed along the recursive calls; not otherwise used here
// Throws a C string on a zero divisor or a divisor that is not normalised.
template <typename T>
void abs_rec_div(T dividend[], T divisor[], hintvector<T> &quot,
                 size_t len1, size_t len2, const UINT_64 base, int c = 0)
{
    len1 = ary_true_len(dividend, len1);
    len2 = ary_true_len(divisor, len2);
    if (divisor == nullptr || len2 == 0)
    {
        throw("Can't divide by zero\n");
    }
    if (dividend == nullptr || len1 == 0)
    {
        return;
    }
    if (abs_compare(dividend, divisor, len1, len2) < 0)
    {
        return;
    }
    if (divisor[len2 - 1] < (base / 2))
    {
        throw("Can't call this proc before normalize the divisor\n");
    }
    size_t quot_len = len1 - len2 + 1;
    quot.change_length(quot_len);
    constexpr size_t LONG_DIV_THRESHOLD = 50;
    if (len2 <= LONG_DIV_THRESHOLD) // at or below the threshold: plain long division
    {
        abs_long_div(dividend, divisor, quot.type_ptr(), len1, len2, base);
    }
    else if (len1 >= len2 * 2 || len1 > ((len2 + 1) / 2) * 3) // long dividend: two recursive divisions
    {
        size_t base_len = (len1 + 3) / 4;
        size_t quot_tmp_len = base_len * 3 - len2 + 1;
        hintvector<T> quot_tmp(quot_tmp_len, 0);
        // Divide the high part first, then the remainder joined with the low part.
        abs_rec_div(dividend + base_len, divisor, quot_tmp, len1 - base_len, len2, base, c + 1);
        quot_tmp_len = quot_tmp.set_true_len();
        size_t dividend_len = ary_true_len(dividend, len1);
        abs_rec_div(dividend, divisor, quot, dividend_len, len2, base, c + 2);
        quot.change_length(quot_len);
        quot_len = quot.set_true_len();
        abs_add(quot.type_ptr() + base_len, quot_tmp.type_ptr(), quot.type_ptr() + base_len, quot_len - base_len, quot_tmp_len, base);
        quot.change_length(len1 - len2 + 1);
    }
    else
    {
        // Trial quotient: dividend / base^base_len divided by divisor / base^base_len.
        size_t base_len = len2 / 2;
        abs_rec_div(dividend + base_len, divisor + base_len, quot, len1 - base_len, len2 - base_len, base, c + 2);
        constexpr T ONE[1] = {1};
        quot_len = quot.set_true_len();
        hintvector<T> prod(base_len + quot_len);
        prod.change_length(base_len + quot_len);
        // Multiply the low base_len digits of the divisor by the trial
        // quotient and compare with the remainder: the correct quotient needs
        // quot * (divisor % base^base_len) <= dividend.
        abs_mul(divisor, quot.type_ptr(), prod.type_ptr(), base_len, quot_len, base);
        size_t prod_len = prod.set_true_len();
        len1 = ary_true_len(dividend, len1);
        while (abs_compare(prod.type_ptr(), dividend, prod_len, len1) > 0)
        {
            // Trial quotient too large: take one off and give the divisor back.
            abs_sub(quot.type_ptr(), ONE, quot.type_ptr(), quot_len, 1, base);
            abs_sub(prod.type_ptr(), divisor, prod.type_ptr(), prod_len, base_len, base);
            abs_add(dividend + base_len, divisor + base_len, dividend + base_len, len1 - base_len, len2 - base_len, base);
            quot_len = quot.set_true_len();
            prod_len = prod.set_true_len();
            len1 = ary_true_len(dividend, std::max(len1, len2) + 1);
        }
        abs_sub(dividend, prod.type_ptr(), dividend, len1, prod_len, base);
    }
}
// Magnitude division. Normalises the divisor, scales the dividend by the
// same multiplier, divides with abs_rec_div and scales the remainder back.
//   dividend - len1 digits, least significant first
//   divisor  - len2 digits
//   quot     - replaced by a new vector that receives the quotient
// Returns the remainder.
template <typename T>
hintvector<T> abs_div(const T dividend[], T divisor[], hintvector<T> &quot,
                      size_t len1, size_t len2, const UINT_64 base)
{
    hintvector<T> normalized_divisor(len2); // normalised divisor
    normalized_divisor.change_length(len2);
    hintvector<T> normalized_dividend(len1 + 1); // scaled dividend
    normalized_dividend.change_length(len1 + 1);
    T *divisor_ptr = normalized_divisor.type_ptr();
    T *dividend_ptr = normalized_dividend.type_ptr();
    T multiplier = divisor_normalize(divisor, divisor_ptr, len2, base); // normalise the divisor, get the multiplier
    abs_mul_num(dividend, multiplier, dividend_ptr, len1, base);        // scale the dividend
    len1 = normalized_dividend.set_true_len();
    // A dividend shorter than the divisor has a zero quotient; avoid the
    // wrap-around of len1 - len2 + 2 in that case.
    const size_t quot_size = len1 >= len2 ? len1 - len2 + 2 : 2;
    quot = hintvector<T>(quot_size, 0);
    abs_rec_div(dividend_ptr, divisor_ptr, quot, len1, len2, base);
    // abs_long_div(dividend_ptr, divisor_ptr, quot.type_ptr(), normalized_dividend.length(), len2, base);
    abs_div_num(dividend_ptr, multiplier, dividend_ptr, len1, base); // divide the remainder by the multiplier
    normalized_dividend.set_true_len();
    return normalized_dividend;
}
}
// 简单高精度简单实现
class Integer
{
public:
using DataType = hint::UINT_16;
using SizeType = hint::UINT_32;
using DataVec = hint::HintVector<DataType, SizeType>;
private:
DataVec data;
public:
static constexpr hint::UINT_32 DIGIT = 4;
static constexpr hint::UINT_64 BASE = hint::qpow(10, DIGIT);
Integer()
{
data = DataVec();
}
// Integer 拷贝构造
Integer(const Integer &input)
{
if (this != &input)
{
data = input.data;
}
}
// Integer 移动构造
Integer(Integer &&input) noexcept
{
if (this != &input)
{
data = std::move(input.data);
}
}
// string 参数构造
Integer(const std::string &input)
{
string_in(input);
}
// 字符串构造
Integer(char input[])
{
string_in(input);
}
// 字符串构造
Integer(const char input[])
{
string_in(input);
}
// 通用构造
template <typename T>
Integer(T input)
{
bool is_neg = hint::is_neg(input);
hint::UINT_64 tmp = std::abs<hint::INT_64>(input);
size_t digits = std::ceil(std::log10(tmp + 1));
size_t len = (digits + DIGIT - 1) / DIGIT;
data = DataVec(len);
data.change_length(len);
for (size_t i = 0; i < len; i++)
{
data[i] = tmp % BASE;
tmp /= BASE;
}
data.change_sign(is_neg);
data.set_true_len();
}
// Integer 拷贝赋值
Integer &operator=(const Integer &input)
{
if (this != &input)
{
data = input.data;
}
return *this;
}
// Integer 移动赋值
Integer &operator=(Integer &&input) noexcept
{
if (this != &input)
{
data = std::move(input.data);
}
return *this;
}
// string 赋值
Integer &operator=(const std::string &input)
{
string_in(input);
return *this;
}
// 字符串赋值
Integer &operator=(const char input[])
{
string_in(input);
return *this;
}
// 字符串赋值
Integer &operator=(char input[])
{
string_in(input);
return *this;
}
// 通用赋值
template <typename T>
Integer &operator=(T input)
{
bool is_neg = hint::is_neg(input);
hint::UINT_64 tmp = std::abs<hint::INT_64>(input);
size_t digits = std::ceil(std::log10(tmp + 1));
size_t len = (digits + DIGIT - 1) / DIGIT;
data = DataVec(len);
data.change_length(len);
for (size_t i = 0; i < len; i++)
{
std::tie(tmp, data[i]) = hint::div_mod(tmp, BASE);
}
data.change_sign(is_neg);
data.set_true_len();
return *this;
}
// 首位的数字
DataType first_num() const
{
if (length() == 0)
{
return 0;
}
return data[length() - 1];
}
// 前2位的数字
hint::UINT_64 first2_num() const
{
if (length() < 2)
{
return first_num();
}
return data[length() - 1] * BASE + data[length() - 2];
}
// 更改符号
void change_sign(bool is_neg)
{
data.change_sign(is_neg);
}
// 是否为负号
bool is_neg() const
{
return data.sign();
}
SizeType length() const
{
return data.length();
}
SizeType length_base10() const
{
size_t len = data.length();
if (len == 0)
{
return 1;
}
return (len - 1) * DIGIT + std::ceil(std::log10(first_num() + 1));
}
void string_in(const std::string &str)
{
size_t str_len = str.size();
if (str_len == 0)
{
data = DataVec();
return;
}
hint::INT_64 len = str_len, pos = len, i = 0;
bool is_neg = false;
if (str[0] == '-')
{
is_neg = true;
len--;
}
len = (len + DIGIT - 1) / DIGIT;
data = DataVec(len);
data.change_length(len);
while (pos - DIGIT > 0)
{
hint::UINT_64 tmp = hint::stoui64(str.begin() + pos - DIGIT, str.begin() + pos);
data[i] = static_cast<DataType>(tmp);
i++;
pos -= DIGIT;
}
hint::INT_64 begin = is_neg ? 1 : 0;
if (pos > begin)
{
hint::UINT_64 tmp = hint::stoui64(str.begin() + begin, str.begin() + pos);
data[i] = static_cast<DataType>(tmp);
}
change_sign(is_neg);
data.set_true_len();
}
std::string to_string() const
{
std::string result;
size_t pos = length();
if (pos == 0)
{
return "0";
}
if (is_neg())
{
result += '-';
}
result += std::to_string(first_num());
pos--;
while (pos > 0)
{
pos--;
result += hint::ui64to_string(data[pos], DIGIT);
}
return result;
}
void print() const
{
size_t pos = length();
if (pos < 1)
{
putchar('0');
return;
}
if (is_neg())
{
putchar('-');
}
printf("%d", first_num());
pos--;
while (pos > 0)
{
pos--;
printf(" %04d", data[pos]);
}
putchar('\n');
}
hint::INT_32 abs_compare(const Integer &input) const
{
size_t len1 = length(), len2 = input.length();
return hint_arithm::abs_compare(data.type_ptr(), input.data.type_ptr(), len1, len2);
}
Integer abs_add(const Integer &input) const
{
size_t len1 = length(), len2 = input.length();
Integer result;
result.data = DataVec(std::max(len1, len2) + 1);
result.data.change_length(std::max(len1, len2) + 1);
auto ptr1 = data.type_ptr();
auto ptr2 = input.data.type_ptr();
auto res_ptr = result.data.type_ptr();
hint_arithm::abs_add<true>(ptr1, ptr2, res_ptr, len1, len2, BASE);
result.data.set_true_len();
return result;
}
Integer abs_sub(const Integer &input) const
{
size_t len1 = length(), len2 = input.length();
Integer result;
result.data = DataVec(std::max(len1, len2));
result.data.change_length(std::max(len1, len2));
auto ptr1 = data.type_ptr();
auto ptr2 = input.data.type_ptr();
auto res_ptr = result.data.type_ptr();
hint_arithm::abs_sub(ptr1, ptr2, res_ptr, len1, len2, BASE);
result.data.set_true_len();
return result;
}
bool operator>(const Integer &input) const
{
if (is_neg() != input.is_neg())
{
return !is_neg();
}
return is_neg() != (abs_compare(input) > 0);
}
bool operator<(const Integer &input) const
{
if (is_neg() != input.is_neg())
{
return is_neg();
}
return is_neg() != (abs_compare(input) < 0);
}
bool operator>=(const Integer &input) const
{
return !(*this < input);
}
bool operator<=(const Integer &input) const
{
return !(*this > input);
}
bool operator==(const Integer &input) const
{
if (is_neg() != input.is_neg())
{
return false;
}
return abs_compare(input) == 0;
}
bool operator!=(const Integer &input) const
{
return !(*this == input);
}
Integer operator/(const Integer &input) const
{
Integer result;
size_t len1 = length(), len2 = input.length();
if (abs_compare(input) < 0)
{
return result;
}
if (len2 == 0)
{
throw("Can't divide by zero\n");
}
auto ptr1 = data.type_ptr();
auto ptr2 = input.data.type_ptr();
hint_arithm::abs_div(ptr1, ptr2, result.data, len1, len2, BASE);
result.data.set_true_len();
result.change_sign(is_neg() != input.is_neg());
return result;
}
};
#endif
using namespace std;
int main()
{
Integer a, b;
string s;
cin >> s;
a = s;
cin >> s;
b = s;
cout << (a / b).to_string();
return 0;
}
// Judge result: Compilation | N/A | N/A | Compile OK | Score: N/A
// Judge result: Testcase #1 | 2.235 ms | 116 KB | Accepted | Score: 100