#include <algorithm>
#include <atomic>
#include <complex>
#include <future>
#include <iostream>
#include <random>
#include <stack>
#include <string>
#include <thread>
#include <tuple>
#include <cassert>
#include <climits>
#include <cmath>
#include <cstdlib>
#include <cstring>
#ifndef HINT_MATH_HPP
#define HINT_MATH_HPP
// #define MULTITHREAD
namespace hint
{
// Fixed-width integer aliases used throughout this header.
using UINT_8 = uint8_t;
using UINT_16 = uint16_t;
using UINT_32 = uint32_t;
using UINT_64 = uint64_t;
using INT_32 = int32_t;
using INT_64 = int64_t;
using ULONG = unsigned long;
using LONG = long;
// Double-precision complex number used by the FFT routines.
using Complex = std::complex<double>;
// Bit widths of 8-, 16- and 32-bit integers.
constexpr UINT_64 HINT_CHAR_BIT = 8;
constexpr UINT_64 HINT_SHORT_BIT = 16;
constexpr UINT_64 HINT_INT_BIT = 32;
// Masks and limits, all stored as 64-bit values:
//   *_0XFF - all bits of that width set (largest unsigned value)
//   *_0X10 - one past the largest unsigned value (2^width)
//   *_0X80 - only the top bit of that width set
//   *_0X7F - largest signed value of that width
constexpr UINT_64 HINT_INT8_0XFF = UINT8_MAX;
constexpr UINT_64 HINT_INT8_0X10 = (UINT8_MAX + 1ull);
constexpr UINT_64 HINT_INT16_0XFF = UINT16_MAX;
constexpr UINT_64 HINT_INT16_0X10 = (UINT16_MAX + 1ull);
constexpr UINT_64 HINT_INT32_0XFF = UINT32_MAX;
constexpr UINT_64 HINT_INT32_0X01 = 1;
constexpr UINT_64 HINT_INT32_0X80 = 0X80000000ull;
constexpr UINT_64 HINT_INT32_0X7F = INT32_MAX;
constexpr UINT_64 HINT_INT32_0X10 = (UINT32_MAX + 1ull);
constexpr UINT_64 HINT_INT64_0X80 = INT64_MIN; // converted to unsigned: only bit 63 set
constexpr UINT_64 HINT_INT64_0X7F = INT64_MAX;
constexpr UINT_64 HINT_INT64_0XFF = UINT64_MAX;
// Floating-point constants: pi, 2*pi and sqrt(2)/2.
constexpr double HINT_PI = 3.1415926535897932384626433832795;
constexpr double HINT_2PI = HINT_PI * 2;
constexpr double HINT_HSQ_ROOT2 = 0.70710678118654752440084436210485;
// NTT moduli and the roots passed with them as MOD / G_ROOT by ntt_convolution():
// NTT_MOD1 = 3 * 2^30 + 1, NTT_MOD2 = 13 * 2^28 + 1.
constexpr UINT_64 NTT_MOD1 = 3221225473;
constexpr UINT_64 NTT_ROOT1 = 5;
constexpr UINT_64 NTT_MOD2 = 3489660929;
constexpr UINT_64 NTT_ROOT2 = 3;
// 2^28. NOTE(review): presumably the longest supported NTT length; not referenced in the visible code.
constexpr size_t NTT_MAX_LEN = 1ull << 28;
#ifdef MULTITHREAD
// Value reported by std::thread::hardware_concurrency() at start-up.
const UINT_32 hint_threads = std::thread::hardware_concurrency();
// ceil(log2(hint_threads)).
const UINT_32 log2_threads = std::ceil(std::log2(hint_threads));
// Atomic thread counter. NOTE(review): not referenced in the visible part of the file.
std::atomic<UINT_32> cur_ths;
#endif
/// @brief cas(x) = cos(x) + sin(x).
/// @param x Angle in radians.
/// @return cos(x) + sin(x).
/// @note Declared inline because this header defines the function body; without it,
///       including the header from two translation units produces duplicate symbols.
inline double cas(double x)
{
    return std::cos(x) + std::sin(x);
}
/// @brief Largest power of two that does not exceed n.
/// @param n Input value.
/// @return The largest 2^k with 2^k <= n. For n == 0 no such value exists and 1 is returned.
/// @note Uses integer shifts only, so the result is exact over the whole 64-bit range
///       (a floating-point log2 rounds values above 2^53 and is undefined for 0).
constexpr UINT_64 max_2pow(UINT_64 n)
{
    UINT_64 result = 1;
    // Each remaining bit above bit 0 doubles the answer.
    while ((n >>= 1) != 0)
    {
        result <<= 1;
    }
    return result;
}
/// @brief Smallest power of two that is not less than n.
/// @param n Input value.
/// @return The smallest 2^k with 2^k >= n (1 for n == 0 or n == 1).
///         Returns 0 when n > 2^63, because the answer does not fit in 64 bits.
/// @note Uses integer shifts only, so the result is exact over the whole 64-bit range
///       (a floating-point log2 rounds values above 2^53 and is undefined for 0).
constexpr UINT_64 min_2pow(UINT_64 n)
{
    UINT_64 result = 1;
    // result becomes 0 once 2^63 is shifted out; stop there instead of looping forever.
    while (result < n && result != 0)
    {
        result <<= 1;
    }
    return result;
}
/// @brief Tests whether a value is below zero.
/// @return true when x compares less than 0.
template <typename T>
constexpr bool is_neg(T x)
{
    return 0 > x;
}
/// @brief Tests whether the lowest bit of an integer is set.
/// @return true when x is odd.
template <typename T>
constexpr bool is_odd(T x)
{
    return (x & 1) != 0;
}
/// @brief Generic exponentiation by squaring.
/// @param m Base.
/// @param n Non-negative exponent.
/// @return m raised to the n-th power, using only operator* of T (1 when n == 0).
template <typename T>
constexpr T qpow(T m, UINT_64 n)
{
    T product = 1;
    for (; n != 0; n >>= 1)
    {
        // Multiply in the current square whenever the matching exponent bit is set.
        if (n & 1)
        {
            product = product * m;
        }
        m = m * m;
    }
    return product;
}
/// @brief Modular exponentiation by squaring: m^n mod `mod`.
/// @param m Base. Values of 0 and 1 are returned unchanged (so qpow(0, 0, mod) is 0).
/// @param n Exponent.
/// @param mod Modulus; must be non-zero and at most 2^32 so that the intermediate
///            products of two reduced values fit in 64 bits.
/// @return m^n reduced modulo `mod`.
constexpr UINT_64 qpow(UINT_64 m, UINT_64 n, UINT_64 mod)
{
    if (m <= 1)
    {
        return m;
    }
    // Reduce the base first: squaring an unreduced base >= 2^32 would overflow 64 bits.
    m %= mod;
    UINT_64 result = 1;
    while (n > 0)
    {
        if ((n & 1) != 0)
        {
            result = result * m % mod;
        }
        m = m * m % mod;
        n >>= 1;
    }
    return result;
}
/// @brief Computes quotient and remainder together so the compiler can share one division.
/// @param a Dividend.
/// @param b Divisor (must be non-zero).
/// @return {a / b, a % b}.
template <typename T>
constexpr std::pair<T, T> div_mod(T a, T b)
{
    const T quotient = a / b;
    const T remainder = a % b;
    return std::pair<T, T>(quotient, remainder);
}
/// @brief Overflow-free product of a 64-bit and a 32-bit unsigned integer.
/// @param a 64-bit factor.
/// @param b 32-bit factor.
/// @return {upper 32 bits, lower 64 bits} of the 96-bit product a * b.
constexpr std::pair<UINT_32, UINT_64> safe_mul(UINT_64 a, UINT_32 b)
{
    // Multiply the two 32-bit halves of a separately; each partial product fits in 64 bits.
    const UINT_64 low_product = (a & HINT_INT32_0XFF) * b;
    const UINT_64 high_product = (a >> HINT_INT_BIT) * b + (low_product >> HINT_INT_BIT);
    const UINT_64 low_64 = (high_product << HINT_INT_BIT) + (low_product & HINT_INT32_0XFF);
    return std::make_pair(high_product >> HINT_INT_BIT, low_64);
}
/// @brief 64-bit unsigned addition that reports the carry.
/// @return {(a + b) mod 2^64, true when the sum did not fit in 64 bits}.
constexpr std::pair<UINT_64, bool> safe_add(UINT_64 a, UINT_64 b)
{
    // Unsigned addition wraps modulo 2^64; the wrapped sum is smaller than a exactly when a carry occurred.
    const UINT_64 wrapped_sum = a + b;
    const bool carry = wrapped_sum < a;
    return std::make_pair(wrapped_sum, carry);
}
/// @brief 64-bit unsigned subtraction that reports the borrow.
/// @return {(a - b) mod 2^64, true when a < b}.
constexpr std::pair<UINT_64, bool> safe_sub(UINT_64 a, UINT_64 b)
{
    // Unsigned subtraction wraps modulo 2^64, which is exactly a - b + 2^64 when a < b.
    const bool borrow = a < b;
    const UINT_64 wrapped_diff = a - b;
    return std::make_pair(wrapped_diff, borrow);
}
/// @brief Divides a 96-bit unsigned integer by a 64-bit unsigned integer.
/// @param dividend {upper 64 bits, lower 32 bits}, i.e. the value first * 2^32 + second.
/// @param divisor Non-zero divisor.
/// @return The low 64 bits of the quotient.
/// @note The low 32 quotient bits are produced by bitwise long division, so the partial
///       remainder never overflows even when the divisor is larger than 2^32.
constexpr UINT_64 div_3by2(std::pair<UINT_64, UINT_32> dividend, UINT_64 divisor)
{
    const UINT_64 high_quo = dividend.first / divisor;
    UINT_64 rem = dividend.first % divisor;
    UINT_64 low_quo = 0;
    for (int bit = 31; bit >= 0; --bit)
    {
        // Shifting rem left may push a bit out of 64 bits; remember it as a carry.
        const bool carry = (rem >> 63) != 0;
        rem = (rem << 1) | ((dividend.second >> bit) & 1u);
        low_quo <<= 1;
        if (carry || rem >= divisor)
        {
            // With a carry the true value is rem + 2^64; wrapping subtraction still yields the right remainder.
            rem -= divisor;
            low_quo |= 1;
        }
    }
    return (high_quo << 32) + low_quo;
}
/// @brief Greatest common divisor by the Euclidean algorithm.
/// @return gcd(a, b); gcd(a, 0) is a and gcd(0, b) is b.
constexpr UINT_64 gcd(UINT_64 a, UINT_64 b)
{
    while (b != 0)
    {
        const UINT_64 remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
/// @brief Chinese remainder theorem for n congruences.
/// @param mods The n moduli. The inverse is computed as x^(mod-2), so each modulus must be
///             prime, and they must be pairwise distinct.
/// @param nums The n remainders, nums[i] taken modulo mods[i].
/// @param n Number of congruences.
/// @return The solution reduced modulo the product of all moduli.
/// @note All arithmetic is plain 64-bit: the product of the moduli and the intermediate term
///       nums[i] * (product / mods[i]) * inverse must fit in 64 bits.
/// @note Declared inline because the body lives in a header.
inline UINT_64 crt(UINT_64 mods[], UINT_64 nums[], size_t n)
{
    UINT_64 mod_product = 1;
    for (size_t i = 0; i < n; i++)
    {
        mod_product *= mods[i];
    }
    UINT_64 result = 0;
    for (size_t i = 0; i < n; i++)
    {
        const UINT_64 mod = mods[i];
        const UINT_64 tmp = mod_product / mod;
        // Reduce before exponentiating so the squaring inside qpow starts from a value below mod.
        const UINT_64 inv = qpow(tmp % mod, mod - 2, mod);
        // Keep the running sum reduced so that adding several terms cannot wrap.
        result = (result + nums[i] * tmp * inv % mod_product) % mod_product;
    }
    return result;
}
/// @brief Fast two-modulus Chinese remainder theorem.
/// @param num1 Remainder of n modulo mod1.
/// @param num2 Remainder of n modulo mod2.
/// @param mod1 First modulus.
/// @param mod2 Second modulus.
/// @param inv1 Inverse of mod1 modulo mod2.
/// @param inv2 Inverse of mod2 modulo mod1.
/// @return The smallest n that satisfies both congruences.
constexpr UINT_64 qcrt(UINT_64 num1, UINT_64 num2,
                       UINT_64 mod1, UINT_64 mod2,
                       UINT_64 inv1, UINT_64 inv2)
{
    // Always subtract the smaller remainder from the larger so the difference stays non-negative.
    return (num1 > num2)
               ? ((num1 - num2) * inv2 % mod1) * mod2 + num2
               : ((num2 - num1) * inv1 % mod2) * mod1 + num1;
}
/// @brief Copies len elements from source to target with memcpy.
/// @param target Destination buffer.
/// @param source Source buffer; must not partially overlap target.
/// @param len Number of elements.
/// @throws const char* when len is at least INT64_MAX.
/// @note Does nothing when len is 0 or both pointers are equal.
template <typename T>
void ary_copy(T* target, const T* source, size_t len)
{
    const bool nothing_to_do = (target == source) || (len == 0);
    if (!nothing_to_do)
    {
        if (len >= INT64_MAX)
        {
            throw("Ary too long\n");
        }
        std::memcpy(target, source, len * sizeof(T));
    }
}
/// @brief Copies len elements between arrays of different element types.
/// @param target Destination buffer of T1.
/// @param source Source buffer of T2; each element is converted with static_cast<T1>.
/// @param len Number of elements.
template <typename T1, typename T2>
void ary_copy_2type(T1* target, const T2* source, size_t len)
{
    const T2* const source_end = source + len;
    while (source != source_end)
    {
        *target = static_cast<T1>(*source);
        ++target;
        ++source;
    }
}
/// @brief Copies an indexable sequence into the real parts of a complex array.
/// @param target Complex array; imaginary parts are left as they are.
/// @param source Anything supporting operator[] whose elements convert to double.
/// @param len Number of elements.
template <typename T>
inline void com_ary_real_copy(Complex* target, const T& source, size_t len)
{
    size_t i = 0;
    while (i < len)
    {
        const double kept_imag = target[i].imag();
        target[i] = Complex(source[i], kept_imag);
        ++i;
    }
}
/// @brief Copies an indexable sequence into the imaginary parts of a complex array.
/// @param target Complex array; real parts are left as they are.
/// @param source Anything supporting operator[] whose elements convert to double.
/// @param len Number of elements.
template <typename T>
inline void com_ary_img_copy(Complex* target, const T& source, size_t len)
{
    size_t i = 0;
    while (i < len)
    {
        const double kept_real = target[i].real();
        target[i] = Complex(kept_real, source[i]);
        ++i;
    }
}
/// @brief Packs two sequences into one complex array: source1 becomes the real parts,
///        source2 the imaginary parts.
/// @param target Complex array with at least max(len1, len2) elements.
/// @param source1 Sequence for the real parts.
/// @param len1 Length of source1.
/// @param source2 Sequence for the imaginary parts.
/// @param len2 Length of source2.
/// @note Past the shorter sequence only the part supplied by the longer one is written;
///       the other part of those elements keeps its previous value.
template <typename T>
inline void com_ary_combine_copy(Complex* target, const T& source1, size_t len1, const T& source2, size_t len2)
{
    const size_t shared_len = std::min(len1, len2);
    for (size_t i = 0; i < shared_len; ++i)
    {
        target[i] = Complex(source1[i], source2[i]);
    }
    // At most one of the two tail loops below runs.
    for (size_t i = shared_len; i < len1; ++i)
    {
        target[i].real(source1[i]);
    }
    for (size_t i = len1; i < len2; ++i)
    {
        target[i].imag(source2[i]);
    }
}
/// @brief Length of a sequence once the zero elements at its high-index end are ignored.
/// @param ary Indexable sequence.
/// @param len Number of elements to consider.
/// @return Index of the last non-zero element plus one, or 0 when every element is zero.
template <typename T>
constexpr size_t ary_true_len(const T& ary, size_t len)
{
    for (; len != 0; --len)
    {
        if (!(ary[len - 1] == 0))
        {
            break;
        }
    }
    return len;
}
/// @brief Allocates a zero-filled array with calloc.
/// @param ptr Receives the new buffer (a null pointer if calloc fails); its old value is overwritten, not freed.
/// @param len Number of elements.
template <typename T>
inline void ary_calloc(T*& ptr, size_t len)
{
    void* const raw = calloc(len, sizeof(T));
    ptr = static_cast<T*>(raw);
}
/// @brief Sets every byte of the first len elements to zero.
/// @param ptr Start of the array.
/// @param len Number of elements to clear.
template <typename T>
inline void ary_clr(T* ptr, size_t len)
{
    const size_t byte_count = len * sizeof(T);
    memset(ptr, 0, byte_count);
}
/// @brief Resizes a buffer with realloc.
/// @param ptr Buffer previously obtained from malloc/calloc/realloc, or nullptr.
/// @param len New number of elements.
/// @return The result of realloc cast to T*.
/// @throws const char* when the requested byte count is not below INT64_MAX.
template <typename T>
inline T* ary_realloc(T* ptr, size_t len)
{
    const size_t byte_count = len * sizeof(T);
    if (!(byte_count < INT64_MAX))
    {
        throw("realloc error");
    }
    return static_cast<T*>(realloc(ptr, byte_count));
}
/// @brief Bit-reversal permutation: swaps ary[j] with ary[rev(j)], where rev reverses the
///        log2(len) low bits of the index.
/// @param ary Indexable sequence, permuted in place.
/// @param len Number of elements; expected to be a power of two.
/// @note Lengths below 2 return immediately. Without the guard, `len - 1` wraps around for
///       len == 0 and the loop would spin over the whole size_t range.
template <typename T>
constexpr void binary_inverse_swap(T& ary, size_t len)
{
    if (len < 2)
    {
        return;
    }
    size_t i = 0; // bit-reversed counterpart of j, updated incrementally
    for (size_t j = 1; j < len - 1; j++)
    {
        // Add one to i in bit-reversed order: flip bits from the top until one flips 0 -> 1.
        size_t k = len >> 1;
        i ^= k;
        while (k > i)
        {
            k >>= 1;
            i ^= k;
        }
        // Swap each pair only once.
        if (j < i)
        {
            std::swap(ary[i], ary[j]);
        }
    }
}
/// @brief Base-4 digit-reversal permutation of `len` groups of `batch` consecutive elements.
/// @tparam SizeType Integer type used for the index table.
/// @param ary Indexable sequence with len * batch elements, permuted in place.
/// @param len Number of groups; expected to be a power of four.
/// @param batch Number of consecutive elements that move together.
/// @note Lengths below 4 have nothing to permute and return immediately; this also avoids
///       writing into a zero-length index table and shifting by a negative amount.
template <typename SizeType = UINT_32, typename T>
constexpr void quaternary_inverse_swap(T& ary, size_t len, size_t batch = 1)
{
    if (len < 4)
    {
        return;
    }
    SizeType log_n = std::log2(len);
    // rev[i] holds the digit-reversed index of i; only the first quarter is needed to derive the rest.
    SizeType* rev = new SizeType[len / 4];
    rev[0] = 0;
    for (SizeType i = 1; i < len; i++)
    {
        // Reverse of i = reverse of (i / 4) shifted down one digit, with i's lowest digit moved to the top.
        SizeType index = (rev[i >> 2] >> 2) | ((i & 3) << (log_n - 2));
        if (i < len / 4)
        {
            rev[i] = index;
        }
        // Swap each pair only once.
        if (i < index)
        {
            for (size_t j = 0; j < batch; j++)
            {
                std::swap(ary[i * batch + j], ary[index * batch + j]);
            }
        }
    }
    delete[] rev;
}
/// @brief In-place permutation driven by an incrementally updated index, structured like
///        binary_inverse_swap() but seeded with len/2 + len/4 instead of len/2.
/// @param ary Array to permute.
/// @param len Number of elements.
/// @note NOTE(review): the intended ordering is not documented in this file and no caller is
///       visible; the update rule is kept exactly as written.
/// @note Lengths below 2 return immediately, because `len - 1` wraps around for len == 0.
template <typename T>
void q_swap(T ary[], size_t len)
{
    if (len < 2)
    {
        return;
    }
    size_t i = 0;
    for (size_t j = 1; j < len - 1; j++)
    {
        size_t k = (len >> 1) + (len >> 2);
        i ^= k;
        while (k > i)
        {
            k >>= 1;
            i ^= k;
        }
        if (j < i)
        {
            std::swap(ary[i], ary[j]);
        }
    }
}
/// @brief De-interleaves an array of N interleaved sequences: afterwards the elements with
///        index congruent to r modulo N occupy the r-th block of len / N consecutive slots.
/// @tparam N Number of interleaved sequences.
/// @param ary Array rearranged in place.
/// @param len Total number of elements; expected to be a multiple of N.
template <UINT_64 N, typename T>
void ary_interlace(T ary[], size_t len)
{
    const size_t block_len = len / N;
    const size_t tail_len = len - block_len;
    // Sequence 0 is compacted in place; sequences 1 .. N-1 are staged in scratch storage.
    T* scratch = new T[tail_len];
    for (size_t group = 0, src = 0; src < len; ++group, src += N)
    {
        ary[group] = ary[src];
        for (size_t lane = 1; lane < N; lane++)
        {
            scratch[(lane - 1) * block_len + group] = ary[src + lane];
        }
    }
    ary_copy(ary + block_len, scratch, tail_len);
    delete[] scratch;
}
/// Lookup table of points on the unit circle, used as twiddle factors by the FFT/FHT routines.
///
/// Only the first quadrant is stored. The layer for a circle split into 2^s parts (s >= 3)
/// holds the 2^(s-2) values e^{i*2*pi*k/2^s}, k = 0 .. 2^(s-2)-1, starting at offset
/// 2^(s-2) - 2. get_complex() derives the other three quadrants by symmetry.
/// Copying is disabled because the object owns its storage.
class UnitTable
{
private:
    Complex* table = nullptr; // owned storage for all layers
    INT_32 max_log_size = 3;  // largest s the storage was sized for
    INT_32 cur_log_size = 3;  // largest s whose layer has been filled so far
    UnitTable(const UnitTable&) = delete;
    UnitTable& operator=(const UnitTable&) = delete;

public:
    ~UnitTable()
    {
        if (table != nullptr)
        {
            delete[] table;
            table = nullptr;
        }
    }
    /// Allocates storage for circles split into up to 1 << max_shift parts (at least 1 << 3)
    /// and fills the layers up to s = max(4, max_shift - 3), capped at max_shift;
    /// higher layers are filled on demand by expend().
    UnitTable(UINT_32 max_shift)
    {
        max_shift = std::max<size_t>(max_shift, 3);
        max_log_size = max_shift;
        size_t ary_size = (1ull << (max_shift - 1)) - 2;
        table = new Complex[ary_size];
        // Layer s = 3: angles 0 and pi/4.
        for (size_t pos = 0; pos < 2; pos++)
        {
            table[pos] = unit_root(pos * HINT_PI / 4);
        }
        expend(max_shift - 3);
    }
    /// Fills every layer up to s = shift that has not been filled yet.
    /// shift is clamped to the range [4, max_log_size].
    void expend(INT_32 shift)
    {
        // Apply the upper bound last: clamping to 4 afterwards could exceed the allocated
        // storage when the table was built for fewer than 4 levels.
        shift = std::max(shift, 4);
        shift = std::min(shift, max_log_size);
        for (INT_32 i = cur_log_size - 1; i <= shift - 2; i++)
        {
            size_t len = 1ull << i; // entries in the layer for s = i + 2
            size_t begin = len - 2;
            for (size_t pos = 0; pos < len; pos++)
            {
                table[pos + begin] = unit_root(pos * HINT_PI / (len * 2));
            }
        }
        cur_log_size = std::max(cur_log_size, shift);
    }
    /// Returns the point on the unit circle with argument theta.
    static Complex unit_root(double theta)
    {
        return std::polar<double>(1.0, theta);
    }
    /// Returns e^{i*2*pi*n/2^shift}: the circle is split into 1 << shift parts and n selects
    /// the part (taken modulo 1 << shift).
    /// No bounds or fill check is made: for shift >= 3 the layer must already have been
    /// filled by the constructor or expend().
    Complex get_complex(INT_32 shift, size_t n)
    {
        size_t rank = 1ull << shift;
        n &= (rank - 1);
        size_t zone = (n << 2) >> shift; // quadrant index 0..3
        // Exact multiples of a quarter turn are returned without touching the table.
        if (((n << 2) & (rank - 1)) == 0)
        {
            constexpr Complex ary[4] = { Complex(1, 0), Complex(0, 1), Complex(-1, 0), Complex(0, -1) };
            return ary[zone];
        }
        rank >>= 2; // now the number of entries in one quadrant
        const Complex* ptr = table + rank - 2;
        Complex tmp;
        switch (zone)
        {
        case 0: // stored directly
            tmp = ptr[n];
            break;
        case 1: // mirror about the imaginary axis
            tmp = ptr[(rank << 1) - n];
            tmp.real(-tmp.real());
            break;
        case 2: // rotate by half a turn
            tmp = -ptr[n - (rank << 1)];
            break;
        case 3: // mirror about the real axis
            tmp = std::conj(ptr[(rank << 2) - n]);
            break;
        default:
            break;
        }
        return tmp;
    }
};
// Largest log2 transform length served by the lookup table: the table-based FFT routines
// reject lengths above 1 << lut_max_rank.
constexpr size_t lut_max_rank = 23;
// Global twiddle-factor table shared by the table-based FFT/FHT routines.
// NOTE(review): a non-inline global object defined in a header gives one definition per
// including translation unit; confirm this header is included from a single source file.
UnitTable TABLE(lut_max_rank); // initialise the FFT table
/// @brief Radix-2 butterfly on input[pos] and input[pos + rank].
/// @param omega Twiddle factor applied to the upper element before combining.
/// @param input Data array, updated in place.
/// @param pos Index of the lower element.
/// @param rank Distance between the two elements.
inline void fft_radix2_butterfly(Complex omega, Complex* input, size_t pos, size_t rank)
{
    Complex& lower = input[pos];
    Complex& upper = input[pos + rank];
    const Complex even = lower;
    const Complex odd = upper * omega;
    lower = even + odd;
    upper = even - odd;
}
/// @brief Radix-4 butterfly on the four elements input[pos + k * rank], k = 0..3.
/// @param omega Twiddle factor for element 1.
/// @param omega_sqr Twiddle factor for element 2.
/// @param omega_cube Twiddle factor for element 3.
/// @param input Data array, updated in place.
/// @param pos Index of element 0.
/// @param rank Distance between neighbouring elements of the butterfly.
inline void fft_radix4_butterfly(Complex omega, Complex omega_sqr, Complex omega_cube,
                                 Complex* input, size_t pos, size_t rank)
{
    Complex& x0 = input[pos];
    Complex& x1 = input[pos + rank];
    Complex& x2 = input[pos + rank + rank];
    Complex& x3 = input[pos + rank + rank + rank];
    // Apply the twiddle factors first.
    const Complex a = x0;
    const Complex b = x1 * omega;
    const Complex c = x2 * omega_sqr;
    const Complex d = x3 * omega_cube;
    const Complex sum_even = a + c;
    const Complex sum_odd = b + d;
    const Complex diff_even = a - c;
    const Complex diff_odd = b - d;
    // Multiply the odd difference by i: (re, im) -> (-im, re).
    const Complex rotated(-diff_odd.imag(), diff_odd.real());
    x0 = sum_even + sum_odd;
    x1 = diff_even - rotated;
    x2 = sum_even - sum_odd;
    x3 = diff_even + rotated;
}
/// @brief Split-radix butterfly on the four elements input[pos + k * rank], k = 0..3.
///        Elements 0 and 1 are used as they are; elements 2 and 3 are multiplied by
///        omega and omega_cube first.
/// @param omega Twiddle factor for element 2.
/// @param omega_cube Twiddle factor for element 3.
/// @param input Data array, updated in place.
/// @param pos Index of element 0.
/// @param rank Distance between neighbouring elements of the butterfly.
inline void fft_splitradix_butterfly(Complex omega, Complex omega_cube,
                                     Complex* input, size_t pos, size_t rank)
{
    Complex& x0 = input[pos];
    Complex& x1 = input[pos + rank];
    Complex& x2 = input[pos + rank + rank];
    Complex& x3 = input[pos + rank + rank + rank];
    const Complex a = x0;
    const Complex b = x1;
    const Complex c = x2 * omega;
    const Complex d = x3 * omega_cube;
    const Complex sum = c + d;
    const Complex diff = c - d;
    // Multiply the difference by i: (re, im) -> (-im, re).
    const Complex rotated(-diff.imag(), diff.real());
    x0 = a + sum;
    x1 = b + rotated;
    x2 = a - sum;
    x3 = b - rotated;
}
/// @brief 2-point FFT: sum and difference of two values.
/// @param input1 First value.
/// @param input2 Second value.
/// @param sum Receives input1 + input2.
/// @param diff Receives input1 - input2.
/// @note `diff` is an output and therefore taken by reference; taking it by value would
///       silently drop the difference. The inputs are taken by value, so `sum` and `diff`
///       may alias the variables passed as inputs.
inline void fft_2point(Complex input1, Complex input2, Complex& sum, Complex& diff)
{
    sum = input1 + input2;
    diff = input1 - input2;
}
/// @brief 4-point butterfly without twiddle factors on input[pos + k * rank], k = 0..3.
/// @param input Data array, updated in place.
/// @param pos Index of element 0.
/// @param rank Distance between neighbouring elements of the butterfly.
inline void fft_4point(Complex* input, size_t pos, size_t rank)
{
    Complex& x0 = input[pos];
    Complex& x1 = input[pos + rank];
    Complex& x2 = input[pos + rank * 2];
    Complex& x3 = input[pos + rank * 3];
    const Complex a = x0;
    const Complex b = x1;
    const Complex c = x2;
    const Complex d = x3;
    const Complex sum_even = a + c;
    const Complex sum_odd = b + d;
    const Complex diff_even = a - c;
    const Complex diff_odd = b - d;
    // Multiply the odd difference by i: (re, im) -> (-im, re).
    const Complex rotated(-diff_odd.imag(), diff_odd.real());
    x0 = sum_even + sum_odd;
    x1 = diff_even - rotated;
    x2 = sum_even - sum_odd;
    x3 = diff_even + rotated;
}
/// @brief Classic in-place radix-2 FFT (reference implementation, for study).
/// @param input Data array, transformed in place.
/// @param fft_len Transform length; rounded down to a power of two. Lengths below 2 are a no-op.
/// @note Twiddle factors are e^{-2*pi*i*k/gap}, obtained by repeated multiplication.
/// @note Declared inline because the body lives in a header.
inline void fft_radix2(Complex* input, size_t fft_len)
{
    if (fft_len < 2)
    {
        return;
    }
    fft_len = max_2pow(fft_len);
    binary_inverse_swap(input, fft_len);
    for (size_t rank = 1, gap; rank < fft_len; rank *= 2)
    {
        gap = rank * 2;
        Complex unit_omega = std::polar<double>(1, -HINT_2PI / gap);
        for (size_t begin = 0; begin < fft_len; begin += gap)
        {
            Complex omega = Complex(1, 0);
            for (size_t pos = begin; pos < begin + rank; pos++)
            {
                Complex tmp1 = input[pos];
                Complex tmp2 = input[pos + rank] * omega;
                input[pos] = tmp1 + tmp2;
                input[pos + rank] = tmp1 - tmp2;
                omega *= unit_omega;
            }
        }
    }
}
/// @brief In-place radix-4 FFT (reference implementation, for study).
/// @param input Data array, transformed in place.
/// @param fft_len Transform length; rounded down to a power of four.
/// @note Lengths below 4 round down to a 1-point transform and return immediately; running
///       the 4-point stage on them would touch elements outside the array.
/// @note Declared inline because the body lives in a header.
inline void fft_radix4(Complex* input, size_t fft_len)
{
    if (fft_len < 4)
    {
        return;
    }
    size_t log4_len = std::log2(fft_len) / 2;
    fft_len = 1ull << (log4_len * 2);
    quaternary_inverse_swap(input, fft_len);
    // First stage: 4-point transforms on neighbouring elements, no twiddle factors needed.
    for (size_t pos = 0; pos < fft_len; pos += 4)
    {
        fft_4point(input, pos, 1);
    }
    for (size_t rank = 4; rank < fft_len; rank *= 4)
    {
        size_t gap = rank * 4;
        Complex unit_omega = std::polar<double>(1, -HINT_2PI / gap);
        Complex unit_sqr = std::polar<double>(1, -HINT_2PI * 2 / gap);
        Complex unit_cube = std::polar<double>(1, -HINT_2PI * 3 / gap);
        for (size_t begin = 0; begin < fft_len; begin += gap)
        {
            // The first butterfly of each group has twiddle factors equal to 1.
            fft_4point(input, begin, rank);
            Complex omega = unit_omega;
            Complex omega_sqr = unit_sqr;
            Complex omega_cube = unit_cube;
            for (size_t pos = begin + 1; pos < begin + rank; pos++)
            {
                fft_radix4_butterfly(omega, omega_sqr, omega_cube, input, pos, rank);
                omega *= unit_omega;
                omega_sqr *= unit_sqr;
                omega_cube *= unit_cube;
            }
        }
    }
}
/// @brief Replaces every element by its complex conjugate divided by `div`
///        (used to build the inverse transform from the forward one).
/// @param input Data array, updated in place.
/// @param fft_len Number of elements.
/// @param div Divisor applied after conjugation; 1 leaves the magnitude unchanged.
inline void fft_conj(Complex* input, size_t fft_len, double div = 1)
{
    Complex* const last = input + fft_len;
    for (Complex* cur = input; cur != last; ++cur)
    {
        *cur = std::conj(*cur) / div;
    }
}
/// @brief Divides every element by the transform length (normalisation for the inverse transform).
/// @param input Data array, updated in place.
/// @param fft_len Number of elements, also used as the divisor.
inline void fft_normalize(Complex* input, size_t fft_len)
{
    const double scale = static_cast<double>(fft_len);
    Complex* const last = input + fft_len;
    for (Complex* cur = input; cur != last; ++cur)
    {
        *cur /= scale;
    }
}
/// @brief In-place radix-4 FFT that reads its twiddle factors from the global TABLE.
/// @param input Data array, transformed in place.
/// @param fft_len Transform length; rounded down to a power of four. Lengths below 4 are a no-op.
/// @throws const char* when the rounded length exceeds 1 << lut_max_rank.
/// @note The table layers needed for this length are filled before use, so the function
///       does not depend on a previous call to TABLE.expend().
/// @note Declared inline because the body lives in a header.
inline void fft_radix4_lut(Complex* input, size_t fft_len)
{
    if (fft_len < 4)
    {
        return; // rounds down to a 1-point transform; also keeps log2(0) out of the picture
    }
    size_t log4_len = std::log2(fft_len) / 2;
    fft_len = 1ull << (log4_len * 2);
    if (fft_len > (1ull << lut_max_rank))
    {
        throw("fft length too long for lut\n");
    }
    TABLE.expend(static_cast<INT_32>(log4_len * 2));
    quaternary_inverse_swap(input, fft_len);
    // First stage: 4-point transforms on neighbouring elements, no twiddle factors needed.
    for (size_t pos = 0; pos < fft_len; pos += 4)
    {
        fft_4point(input, pos, 1);
    }
    UINT_32 shift = 4; // log2 of the current group size (gap)
    for (size_t rank = 4; rank < fft_len; rank *= 4)
    {
        size_t gap = rank * 4;
        for (size_t begin = 0; begin < fft_len; begin += gap)
        {
            // The first butterfly of each group has twiddle factors equal to 1.
            fft_4point(input, begin, rank);
            for (size_t pos = begin + 1; pos < begin + rank; pos++)
            {
                size_t count = pos - begin;
                // The table stores e^{+i*theta}; the forward transform needs the conjugate.
                Complex omega = std::conj(TABLE.get_complex(shift, count));
                Complex omega_sqr = std::conj(TABLE.get_complex(shift, count * 2));
                Complex omega_cube = std::conj(TABLE.get_complex(shift, count * 3));
                fft_radix4_butterfly(omega, omega_sqr, omega_cube, input, pos, rank);
            }
        }
        shift += 2;
    }
}
/// @brief Runs BATCH interleaved radix-4 FFTs at once; with BATCH == 1 this is a plain radix-4 FFT.
///        Element k of transform j is stored at input[k * BATCH + j].
/// @tparam BATCH Number of interleaved transforms.
/// @param input Data array with fft_len * BATCH elements, transformed in place.
/// @param fft_len Length of each transform; rounded down to a power of four. Lengths below 4 are a no-op.
/// @throws const char* when the rounded length exceeds 1 << lut_max_rank.
/// @note The table layers needed for this length are filled before use, so the function
///       does not depend on a previous call to TABLE.expend().
template <size_t BATCH = 1>
void fft_radix4_batch(Complex* input, size_t fft_len)
{
    if (fft_len < 4)
    {
        return; // rounds down to a 1-point transform; also keeps log2(0) out of the picture
    }
    size_t log4_len = std::log2(fft_len) / 2;
    fft_len = 1ull << (log4_len * 2);
    if (fft_len > (1ull << lut_max_rank))
    {
        throw("fft length too long for lut\n");
    }
    TABLE.expend(static_cast<INT_32>(log4_len * 2));
    quaternary_inverse_swap(input, fft_len, BATCH);
    // First stage: 4-point transforms without twiddle factors, once per interleaved transform.
    for (size_t pos = 0; pos < fft_len; pos += 4)
    {
        for (size_t j = 0; j < BATCH; j++)
        {
            fft_4point(input, j + pos * BATCH, BATCH);
        }
    }
    UINT_32 shift = 4; // log2 of the current group size (gap)
    for (size_t rank = 4; rank < fft_len; rank *= 4)
    {
        size_t gap = rank * 4;
        for (size_t begin = 0; begin < fft_len; begin += gap)
        {
            // The first butterfly of each group has twiddle factors equal to 1.
            for (size_t j = 0; j < BATCH; j++)
            {
                fft_4point(input, j + begin * BATCH, rank * BATCH);
            }
            for (size_t pos = begin + 1; pos < begin + rank; pos++)
            {
                size_t count = pos - begin;
                // The table stores e^{+i*theta}; the forward transform needs the conjugate.
                // The same factors are shared by all BATCH transforms.
                Complex omega = std::conj(TABLE.get_complex(shift, count));
                Complex omega_sqr = std::conj(TABLE.get_complex(shift, count * 2));
                Complex omega_cube = std::conj(TABLE.get_complex(shift, count * 3));
                for (size_t j = 0; j < BATCH; j++)
                {
                    fft_radix4_butterfly(omega, omega_sqr, omega_cube, input, j + pos * BATCH, rank * BATCH);
                }
            }
        }
        shift += 2;
    }
}
/// @brief Radix-4 FFT built from four interleaved quarter-length transforms followed by one
///        radix-4 combining stage and a de-interleave.
/// @param input Data array, transformed in place.
/// @param fft_len Transform length; rounded down to a power of four. Lengths below 4 are a no-op.
/// @throws const char* when the rounded length exceeds 1 << lut_max_rank.
/// @note The table layers needed for this length are filled before use.
/// @note Declared inline because the body lives in a header.
inline void fft_radix4_bat4(Complex* input, size_t fft_len)
{
    if (fft_len < 4)
    {
        return; // rounds down to a 1-point transform; also keeps log2(0) out of the picture
    }
    size_t log_len = std::log2(fft_len) / 2;
    fft_len = 1ull << (log_len * 2);
    if (fft_len > (1ull << lut_max_rank))
    {
        throw("fft length too long for lut\n");
    }
    if (fft_len == 4)
    {
        fft_4point(input, 0, 1);
        return;
    }
    log_len *= 2; // from here on: log2 of the transform length
    TABLE.expend(static_cast<INT_32>(log_len));
    size_t quarter_len = fft_len / 4;
    // Transform the four interleaved sub-sequences (indices congruent to 0..3 modulo 4).
    fft_radix4_batch<4>(input, quarter_len);
    // Combine them: group i / 4 uses the twiddle factors w^(i/4), w^(2i/4), w^(3i/4).
    for (size_t i = 0; i < fft_len; i += 4)
    {
        Complex omega = std::conj(TABLE.get_complex(log_len, i / 4));
        Complex omega_sqr = std::conj(TABLE.get_complex(log_len, i / 2));
        Complex omega_cube = std::conj(TABLE.get_complex(log_len, i * 3 / 4));
        fft_radix4_butterfly(omega, omega_sqr, omega_cube, input, i, 1);
    }
    // Move the four interleaved outputs into four consecutive blocks.
    ary_interlace<4>(input, fft_len);
}
/// @brief FFT for lengths that are not a power of four: two interleaved half-length radix-4
///        transforms followed by one radix-2 combining stage and a de-interleave.
/// @param input Data array, transformed in place.
/// @param fft_len Transform length; rounded down to a power of two. Lengths below 2 are a no-op.
///                NOTE(review): the half length is handed to the radix-4 routine, which rounds
///                it down to a power of four, so log2 of the length is expected to be odd.
/// @throws const char* when the rounded length exceeds 1 << lut_max_rank.
/// @note The table layers needed for this length are filled before use.
/// @note Declared inline because the body lives in a header.
inline void fft_radix2_bat2(Complex* input, size_t fft_len)
{
    if (fft_len < 2)
    {
        return; // nothing to transform; also keeps log2(0) out of the picture
    }
    size_t log_len = std::log2(fft_len);
    fft_len = 1ull << log_len;
    if (fft_len > (1ull << lut_max_rank))
    {
        throw("fft length too long for lut\n");
    }
    if (fft_len == 2)
    {
        // 2-point transform written out so that both outputs are stored.
        const Complex first = input[0];
        const Complex second = input[1];
        input[0] = first + second;
        input[1] = first - second;
        return;
    }
    TABLE.expend(static_cast<INT_32>(log_len));
    size_t sub_len = fft_len / 2;
    // Transform the two interleaved sub-sequences (even and odd indices).
    fft_radix4_batch<2>(input, sub_len);
    // Combine them: pair i / 2 uses the twiddle factor w^(i/2).
    for (size_t i = 0; i < fft_len; i += 2)
    {
        Complex omega = std::conj(TABLE.get_complex(log_len, i / 2));
        fft_radix2_butterfly(omega, input, i, 1);
    }
    // Move the two interleaved outputs into two consecutive blocks.
    ary_interlace<2>(input, fft_len);
}
/// @brief Table-driven forward FFT.
/// @param input Complex array, transformed in place.
/// @param fft_len Transform length; rounded down to a power of two. Lengths below 2 are a no-op.
/// @throws const char* (from the radix routines) when the length exceeds 1 << lut_max_rank.
/// @note Dispatches to fft_radix2_bat2() when log2 of the length is odd and to
///       fft_radix4_bat4() when it is even.
/// @note Declared inline because the body lives in a header.
inline void fft_lut(Complex* input, size_t fft_len)
{
    if (fft_len <= 1)
    {
        return;
    }
    fft_len = max_2pow(fft_len);
    size_t log_len = std::log2(fft_len);
    TABLE.expend(log_len);
    if (is_odd(log_len))
    {
        fft_radix2_bat2(input, fft_len);
    }
    else
    {
        fft_radix4_bat4(input, fft_len);
    }
}
/// @brief Table-driven inverse FFT, computed as conj(FFT(conj(x))) / n.
/// @param input Frequency-domain array, transformed in place.
/// @param fft_len Transform length; rounded down to a power of two. Lengths below 2 are a no-op.
/// @note Declared inline because the body lives in a header.
inline void ifft_lut(Complex* input, size_t fft_len)
{
    if (fft_len <= 1)
    {
        return; // a 1-point inverse transform is the identity
    }
    fft_len = max_2pow(fft_len);
    fft_conj(input, fft_len);
    fft_lut(input, fft_len);
    fft_conj(input, fft_len, fft_len);
}
/// @brief 快速哈特莱变换
/// @param input 浮点数组指针
/// @param fht_len 变换的长度
/// @param is_ifht 是否为逆变换
void fht(double* input, size_t fht_len)
{
fht_len = max_2pow(fht_len);
if (fht_len <= 1)
{
return;
}
UINT_32 log_len = std::log2(fht_len);
TABLE.expend(log_len);
binary_inverse_swap(input, fht_len);
for (size_t i = 0; i < fht_len; i += 2)
{
double tmp1 = input[i];
double tmp2 = input[i + 1];
input[i] = tmp1 + tmp2;
input[i + 1] = tmp1 - tmp2;
}
UINT_32 shift = 2;
for (size_t rank = 2; rank < fht_len; rank *= 2)
{
size_t gap = rank * 2;
size_t half = rank / 2;
for (size_t begin = 0; begin < fht_len; begin += gap)
{
size_t index1 = begin, index2 = begin + half;
size_t index3 = begin + rank, index4 = begin + half * 3;
double tmp1 = input[index1];
double tmp2 = input[index3];
input[index1] = tmp1 + tmp2;
input[index3] = tmp1 - tmp2;
tmp1 = input[index2];
tmp2 = input[index4];
input[index2] = tmp1 + tmp2;
input[index4] = tmp1 - tmp2;
for (size_t pos = 1; pos < half; pos++)
{
index1 = begin + pos;
index2 = rank + begin - pos;
index3 = rank + begin + pos;
index4 = gap + begin - pos;
double tmp1 = input[index1];
double tmp2 = input[index2];
double tmp3 = input[index3];
double tmp4 = input[index4];
Complex omega = TABLE.get_complex(shift, pos);
double t1 = tmp3 * omega.real() + tmp4 * omega.imag();
double t2 = tmp3 * omega.imag() - tmp4 * omega.real();
input[index1] = tmp1 + t1;
input[index2] = tmp2 + t2;
input[index3] = tmp1 - t1;
input[index4] = tmp2 - t2;
}
}
shift++;
}
}
/// @brief Inverse fast Hartley transform: the forward transform followed by division by the length.
/// @param input Array of doubles, transformed in place.
/// @param fht_len Transform length; rounded down to a power of two. Lengths below 2 are a no-op.
/// @note Declared inline because the body lives in a header.
inline void ifht(double* input, size_t fht_len)
{
    if (fht_len <= 1)
    {
        return; // a 1-point transform is the identity
    }
    fht_len = max_2pow(fht_len);
    fht(input, fht_len);
    double len = fht_len;
    for (size_t i = 0; i < fht_len; i++)
    {
        input[i] /= len;
    }
}
/// @brief Modular addition for operands already reduced modulo `mod`.
/// @return (a + b) mod `mod`, using one conditional subtraction instead of a division.
constexpr UINT_64 add_mod(UINT_64 a, UINT_64 b, UINT_64 mod)
{
    const UINT_64 sum = a + b;
    return sum >= mod ? sum - mod : sum;
}
/// @brief Modular subtraction for operands already reduced modulo `mod`.
/// @return (a - b) mod `mod`, adding `mod` first when a < b so the result stays non-negative.
constexpr UINT_64 sub_mod(UINT_64 a, UINT_64 b, UINT_64 mod)
{
    return a >= b ? a - b : a + mod - b;
}
/// @brief Modular multiplication a * b mod MOD without a 128-bit type.
/// @tparam MOD Positive modulus.
/// @param a First factor, expected in [0, MOD).
/// @param b Second factor, expected in [0, MOD).
/// @return a * b reduced into [0, MOD).
/// @note When both factors fit in 32 bits the product fits in 64 unsigned bits and is reduced
///       directly. Otherwise the quotient is estimated in floating point and the remainder is
///       recovered from the wrapped 64-bit product; one correction step absorbs the rounding
///       error of the estimate.
///       NOTE(review): with `double` the estimate is only accurate enough while MOD stays
///       well below 2^63 - confirm the bound for the moduli actually used.
/// @note All products are formed in unsigned arithmetic, where wrap-around is well defined.
template <INT_64 MOD>
constexpr INT_64 mul_mod(INT_64 a, INT_64 b)
{
    const UINT_64 ua = static_cast<UINT_64>(a);
    const UINT_64 ub = static_cast<UINT_64>(b);
    if (((a | b) >> 32) == 0)
    {
        return static_cast<INT_64>(ua * ub % static_cast<UINT_64>(MOD));
    }
    const UINT_64 quotient = static_cast<UINT_64>(1.0 * a / MOD * b);
    INT_64 result = static_cast<INT_64>(ua * ub - quotient * static_cast<UINT_64>(MOD));
    if (result < 0)
    {
        result += MOD;
    }
    else if (result >= MOD)
    {
        result -= MOD;
    }
    return result;
}
/// @brief In-place radix-2 number-theoretic transform (no final scaling for the inverse).
/// @tparam MOD Modulus; (MOD - 1) must be divisible by the transform length.
/// @tparam G_ROOT Root used to derive the roots of unity (expected to be a primitive root of MOD).
/// @param input Array of residues below MOD, transformed in place.
/// @param ntt_len Transform length; rounded down to a power of two. Lengths below 2 are a no-op.
/// @param is_intt When true the inverse root G_ROOT^(MOD-2) is used instead of G_ROOT.
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3>
void ntt_radix2(UINT_32* input, size_t ntt_len, bool is_intt)
{
    if (ntt_len < 2)
    {
        return; // the first stage below always touches input[1]
    }
    ntt_len = max_2pow(ntt_len);
    binary_inverse_swap(input, ntt_len);
    const UINT_64 g_root = is_intt ? qpow(G_ROOT, MOD - 2, MOD) : G_ROOT;
    // First stage: 2-point sums and differences, twiddle factor 1.
    for (size_t pos = 0; pos < ntt_len; pos += 2)
    {
        UINT_32 tmp1 = input[pos];
        UINT_32 tmp2 = input[pos + 1];
        input[pos] = add_mod(tmp1, tmp2, MOD);
        input[pos + 1] = sub_mod(tmp1, tmp2, MOD);
    }
    for (size_t rank = 2; rank < ntt_len; rank *= 2)
    {
        size_t gap = rank * 2;
        // Root of unity of order `gap`.
        UINT_64 unit_omega = qpow(g_root, (MOD - 1) / gap, MOD);
        for (size_t begin = 0; begin < ntt_len; begin += gap)
        {
            // The first butterfly of each group has twiddle factor 1.
            UINT_32 tmp1 = input[begin];
            UINT_32 tmp2 = input[begin + rank];
            input[begin] = add_mod(tmp1, tmp2, MOD);
            input[begin + rank] = sub_mod(tmp1, tmp2, MOD);
            UINT_64 omega = unit_omega;
            for (size_t pos = begin + 1; pos < begin + rank; pos++)
            {
                UINT_32 tmp1 = input[pos];
                UINT_32 tmp2 = omega * input[pos + rank] % MOD;
                input[pos] = add_mod(tmp1, tmp2, MOD);
                input[pos + rank] = sub_mod(tmp1, tmp2, MOD);
                omega = omega * unit_omega % MOD;
            }
        }
    }
}
/// @brief Radix-4 NTT butterfly on the four elements input[pos + k * rank], k = 0..3.
/// @tparam MOD Modulus (below 2^32).
/// @param omega Twiddle factor for element 1.
/// @param omega_sqr Twiddle factor for element 2.
/// @param omega_cube Twiddle factor for element 3.
/// @param quarter Fourth root of unity modulo MOD (plays the role of the imaginary unit).
/// @param input Array of residues below MOD, updated in place.
/// @param pos Index of element 0.
/// @param rank Distance between neighbouring elements of the butterfly.
template <UINT_64 MOD = 2281701377>
constexpr void ntt_radix4_butterfly(UINT_64 omega, UINT_64 omega_sqr, UINT_64 omega_cube,
                                    UINT_64 quarter, UINT_32* input, size_t pos, size_t rank)
{
    UINT_32 tmp1 = input[pos];
    UINT_32 tmp2 = input[pos + rank] * omega % MOD;
    UINT_32 tmp3 = input[pos + rank * 2] * omega_sqr % MOD;
    UINT_32 tmp4 = input[pos + rank * 3] * omega_cube % MOD;
    UINT_32 t1 = add_mod(tmp1, tmp3, MOD);
    UINT_32 t2 = add_mod(tmp2, tmp4, MOD);
    UINT_32 t3 = sub_mod(tmp1, tmp3, MOD);
    // Reduce the difference below MOD before multiplying: (MOD + tmp2 - tmp4) can approach
    // 2 * MOD, and its product with `quarter` exceeds 64 bits for moduli above about 2^31.5.
    UINT_32 t4 = sub_mod(tmp2, tmp4, MOD) * quarter % MOD;
    input[pos] = add_mod(t1, t2, MOD);
    input[pos + rank] = add_mod(t3, t4, MOD);
    input[pos + rank * 2] = sub_mod(t1, t2, MOD);
    input[pos + rank * 3] = sub_mod(t3, t4, MOD);
}
/// @brief 4-point NTT butterfly without twiddle factors on input[pos + k * rank], k = 0..3.
/// @tparam MOD Modulus (below 2^32).
/// @param input Array of residues below MOD, updated in place.
/// @param quarter Fourth root of unity modulo MOD (plays the role of the imaginary unit).
/// @param pos Index of element 0.
/// @param rank Distance between neighbouring elements of the butterfly.
template <UINT_64 MOD = 2281701377>
constexpr void ntt_4point(UINT_32* input, UINT_64 quarter, size_t pos, size_t rank)
{
    UINT_32 tmp1 = input[pos];
    UINT_32 tmp2 = input[pos + rank];
    UINT_32 tmp3 = input[pos + rank * 2];
    UINT_32 tmp4 = input[pos + rank * 3];
    UINT_32 t1 = add_mod(tmp1, tmp3, MOD);
    UINT_32 t2 = add_mod(tmp2, tmp4, MOD);
    UINT_32 t3 = sub_mod(tmp1, tmp3, MOD);
    // Reduce the difference below MOD before multiplying: (MOD + tmp2 - tmp4) can approach
    // 2 * MOD, and its product with `quarter` exceeds 64 bits for moduli above about 2^31.5.
    UINT_32 t4 = sub_mod(tmp2, tmp4, MOD) * quarter % MOD;
    input[pos] = add_mod(t1, t2, MOD);
    input[pos + rank] = add_mod(t3, t4, MOD);
    input[pos + rank * 2] = sub_mod(t1, t2, MOD);
    input[pos + rank * 3] = sub_mod(t3, t4, MOD);
}
/// @brief Runs BATCH interleaved radix-4 NTTs at once; with BATCH == 1 this is a plain
///        radix-4 NTT. Element k of transform j is stored at input[k * BATCH + j].
///        No final scaling is applied for the inverse transform.
/// @tparam MOD Modulus; (MOD - 1) must be divisible by the transform length.
/// @tparam G_ROOT Root used to derive the roots of unity (expected to be a primitive root of MOD).
/// @tparam BATCH Number of interleaved transforms.
/// @param input Array with ntt_len * BATCH residues below MOD, transformed in place.
/// @param ntt_len Length of each transform; rounded down to a power of four. Lengths below 4 are a no-op.
/// @param is_intt When true the inverse root G_ROOT^(MOD-2) is used instead of G_ROOT.
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3, size_t BATCH = 1>
void ntt_radix4_batch(UINT_32* input, size_t ntt_len, bool is_intt)
{
    if (ntt_len < 4)
    {
        return; // rounds down to a 1-point transform; also keeps log2(0) out of the picture
    }
    size_t log4_len = std::log2(ntt_len) / 2;
    ntt_len = 1ull << (log4_len * 2);
    quaternary_inverse_swap(input, ntt_len, BATCH);
    const UINT_64 g_root = is_intt ? qpow(G_ROOT, MOD - 2, MOD) : G_ROOT;
    const UINT_64 quarter = qpow(g_root, (MOD - 1) / 4, MOD); // counterpart of the complex i
    // First stage: 4-point transforms without twiddle factors, once per interleaved transform.
    for (size_t pos = 0; pos < ntt_len; pos += 4)
    {
        for (size_t j = 0; j < BATCH; j++)
        {
            ntt_4point<MOD>(input, quarter, j + pos * BATCH, BATCH);
        }
    }
    for (size_t rank = 4; rank < ntt_len; rank *= 4)
    {
        size_t gap = rank * 4;
        // Root of unity of order `gap`.
        UINT_64 unit_omega = qpow(g_root, (MOD - 1) / gap, MOD);
        for (size_t begin = 0; begin < ntt_len; begin += gap)
        {
            // The first butterfly of each group has twiddle factors equal to 1.
            for (size_t j = 0; j < BATCH; j++)
            {
                ntt_4point<MOD>(input, quarter, j + begin * BATCH, rank * BATCH);
            }
            UINT_64 omega = unit_omega;
            for (size_t pos = begin + 1; pos < begin + rank; pos++)
            {
                UINT_64 omega_sqr = omega * omega % MOD;
                UINT_64 omega_cube = omega_sqr * omega % MOD;
                // The same factors are shared by all BATCH transforms.
                for (size_t j = 0; j < BATCH; j++)
                {
                    ntt_radix4_butterfly<MOD>(omega, omega_sqr, omega_cube, quarter, input, j + pos * BATCH, rank * BATCH);
                }
                omega = omega * unit_omega % MOD;
            }
        }
    }
}
// 单线程NTT
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3>
void ntt_single(UINT_32* input, size_t ntt_len, bool is_intt)
{
ntt_len = max_2pow(ntt_len);
if (ntt_len <= 1)
{
return;
}
size_t log_len = std::log2(ntt_len);
if (is_odd(log_len))
{
size_t half_len = ntt_len / 2;
ntt_radix4_batch<MOD, G_ROOT, 2>(input, half_len, is_intt);
const UINT_64 g_root = is_intt ? qpow(G_ROOT, MOD - 2, MOD) : G_ROOT;
const UINT_64 unit_omega = qpow(g_root, (MOD - 1) / ntt_len, MOD);
UINT_64 omega = 1;
for (size_t i = 0; i < ntt_len; i += 2)
{
UINT_32 tmp1 = input[i];
UINT_32 tmp2 = input[i + 1] * omega % MOD;
input[i] = add_mod(tmp1, tmp2, MOD);
input[i + 1] = sub_mod(tmp1, tmp2, MOD);
omega = omega * unit_omega % MOD;
}
UINT_32* tmp_ary = new UINT_32[half_len];
for (size_t i = 0; i < ntt_len; i += 2)
{
input[i / 2] = input[i];
tmp_ary[i / 2] = input[i + 1];
}
ary_copy(input + half_len, tmp_ary, half_len);
delete[] tmp_ary;
}
else
{
size_t quarter_len = ntt_len / 4;
ntt_radix4_batch<MOD, G_ROOT, 4>(input, quarter_len, is_intt);
const UINT_64 g_root = is_intt ? qpow(G_ROOT, MOD - 2, MOD) : G_ROOT;
const UINT_64 quarter = qpow(g_root, (MOD - 1) / 4, MOD); // 等价于复数的i
const UINT_64 unit_omega = qpow(g_root, (MOD - 1) / ntt_len, MOD);
UINT_64 omega = 1;
for (size_t i = 0; i < ntt_len; i += 4)
{
UINT_64 omega_sqr = omega * omega % MOD;
UINT_64 omega_cube = omega_sqr * omega % MOD;
ntt_radix4_butterfly<MOD>(omega, omega_sqr, omega_cube, quarter, input, i, 1);
omega = omega * unit_omega % MOD;
}
UINT_32* tmp_ary1 = new UINT_32[quarter_len * 3];
UINT_32* tmp_ary2 = tmp_ary1 + quarter_len;
UINT_32* tmp_ary3 = tmp_ary2 + quarter_len;
for (size_t i = 0; i < ntt_len; i += 4)
{
input[i / 4] = input[i];
tmp_ary1[i / 4] = input[i + 1];
tmp_ary2[i / 4] = input[i + 2];
tmp_ary3[i / 4] = input[i + 3];
}
ary_copy(input + quarter_len, tmp_ary1, quarter_len * 3);
delete[] tmp_ary1;
}
}
// 双线程NTT
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3>
void ntt_dual(UINT_32* input, size_t ntt_len, bool is_intt)
{
ntt_len = max_2pow(ntt_len);
if (ntt_len <= 1)
{
return;
}
size_t half_len = ntt_len / 2;
UINT_32* tmp_ary = new UINT_32[half_len];
for (size_t i = 0; i < ntt_len; i += 2)
{
input[i / 2] = input[i];
tmp_ary[i / 2] = input[i + 1];
}
ary_copy(input + half_len, tmp_ary, half_len);
delete[] tmp_ary;
std::future<void> th = std::async(ntt_single<MOD, G_ROOT>, input, half_len, is_intt);
ntt_single<MOD, G_ROOT>(input + half_len, half_len, is_intt);
th.wait();
const UINT_64 g_root = is_intt ? qpow(G_ROOT, MOD - 2, MOD) : G_ROOT;
const UINT_64 unit_omega = qpow(g_root, (MOD - 1) / ntt_len, MOD);
UINT_64 omega2 = qpow(g_root, (MOD - 1) / 4, MOD);
auto merge_proc = [=](size_t start, size_t end, UINT_64 omega_start)
{
UINT_64 omega = omega_start;
for (size_t i = start; i < end; i++)
{
UINT_32 tmp1 = input[i];
UINT_32 tmp2 = input[i + half_len] * omega % MOD;
input[i] = add_mod(tmp1, tmp2, MOD);
input[i + half_len] = sub_mod(tmp1, tmp2, MOD);
omega = omega * unit_omega % MOD;
}
};
th = std::async(merge_proc, 0, half_len / 2, 1);
merge_proc(half_len / 2, half_len, omega2);
th.wait();
}
/// @brief 快速数论变换
/// @tparam T 输入整数组类型
/// @tparam MOD 模数
/// @tparam G_ROOT 原根
/// @param input 输入数组
/// @param ntt_len 数组长度
/// @param is_intt 是否为逆变换
/// @param multi_threads 是否为多线程
template <UINT_64 MOD = 2281701377, UINT_64 G_ROOT = 3>
void ntt(UINT_32* input, size_t ntt_len, bool is_intt, bool multi_threads = false)
{
if (multi_threads)
{
ntt_dual<MOD, G_ROOT>(input, ntt_len, is_intt);
}
else
{
ntt_single<MOD, G_ROOT>(input, ntt_len, is_intt);
}
if (is_intt) // 逆变换需乘以以n的逆元
{
UINT_64 inv = qpow(ntt_len, MOD - 2, MOD);
for (size_t i = 0; i < ntt_len; ++i)
{
input[i] = input[i] * inv % MOD;
}
}
}
/// @brief Element-wise product of two arrays.
/// @param in1 First factor array.
/// @param in2 Second factor array.
/// @param out Receives in1[i] * in2[i]; may be the same array as an input.
/// @param len Number of elements.
template <typename T>
inline void ary_mul(const T in1[], const T in2[], T out[], size_t len)
{
    size_t i = 0;
    while (i < len)
    {
        out[i] = in1[i] * in2[i];
        ++i;
    }
}
/// @brief Element-wise modular product of two arrays, with the main loop unrolled four times.
/// @tparam MOD Modulus.
/// @param in1 First factor array.
/// @param in2 Second factor array.
/// @param out Receives in1[i] * in2[i] mod MOD; may be the same array as an input.
/// @param len Number of elements.
/// @note Each product is formed in 64 bits before it is reduced.
template <UINT_64 MOD, typename T>
constexpr void ary_mul_mod(const T in1[], const T in2[], T out[], size_t len)
{
    const size_t unrolled_end = len - len % 4;
    size_t i = 0;
    for (; i < unrolled_end; i += 4)
    {
        out[i] = static_cast<UINT_64>(in1[i]) * in2[i] % MOD;
        out[i + 1] = static_cast<UINT_64>(in1[i + 1]) * in2[i + 1] % MOD;
        out[i + 2] = static_cast<UINT_64>(in1[i + 2]) * in2[i + 2] % MOD;
        out[i + 3] = static_cast<UINT_64>(in1[i + 3]) * in2[i + 3] % MOD;
    }
    // Up to three leftover elements.
    for (; i < len; i++)
    {
        out[i] = static_cast<UINT_64>(in1[i]) * in2[i] % MOD;
    }
}
/// @brief Schoolbook (direct) convolution of two arrays.
/// @param in1 First input array.
/// @param in2 Second input array.
/// @param out Receives the len1 + len2 - 1 result elements; must not overlap the inputs.
/// @param len1 Length of in1.
/// @param len2 Length of in2.
/// @note Does nothing when either input is empty. Every product is accumulated into
///       out[i + j]; plain assignment would keep only the last contribution.
template <typename T>
constexpr void normal_convolution(const T in1[], const T in2[], T out[],
                                  size_t len1, size_t len2)
{
    if (len1 == 0 || len2 == 0)
    {
        return; // also avoids the wrap-around of len1 + len2 - 1
    }
    std::fill_n(out, len1 + len2 - 1, T(0));
    for (size_t i = 0; i < len1; i++)
    {
        const T num1 = in1[i];
        for (size_t j = 0; j < len2; j++)
        {
            out[i + j] += num1 * in2[j];
        }
    }
}
/// @brief Cyclic convolution through the table-driven FFT.
/// @param fft_ary1 First input; overwritten with its transform.
/// @param fft_ary2 Second input; overwritten with its transform. May be the same array as
///                 fft_ary1, in which case it is transformed only once.
/// @param out Receives the inverse transform of the element-wise product.
/// @param fft_len Transform length.
inline void fft_lut_convolution(Complex fft_ary1[], Complex fft_ary2[], Complex out[], size_t fft_len)
{
    const bool same_input = (fft_ary1 == fft_ary2);
    fft_lut(fft_ary1, fft_len);
    if (!same_input)
    {
        fft_lut(fft_ary2, fft_len);
    }
    ary_mul(fft_ary1, fft_ary2, out, fft_len);
    ifft_lut(out, fft_len);
}
void fht_convolution(double fht_ary1[], double fht_ary2[], double out[], size_t fht_len)
{
fht(fht_ary1, fht_len);
if (fht_ary1 != fht_ary2)
{
fht(fht_ary2, fht_len);
}
out[0] = fht_ary1[0] * fht_ary2[0];
for (size_t i = 1; i < fht_len; ++i)
{
double tmp1 = fht_ary1[i], tmp2 = fht_ary1[fht_len - i];
out[i] = (fht_ary2[i] * (tmp1 + tmp2) + fht_ary2[fht_len - i] * (tmp1 - tmp2)) / 2;
}
ifht(out, fht_len);
}
/// @brief Cyclic convolution through two NTTs with different moduli, recombined with the
///        Chinese remainder theorem so that results up to NTT_MOD1 * NTT_MOD2 are exact.
/// @param ntt_ary1 First input; overwritten (it is transformed and reused as working storage).
/// @param ntt_ary2 Second input; overwritten with its transform. May be the same array as
///                 ntt_ary1 to square a sequence.
/// @param out Receives the ntt_len convolution results.
/// @param ntt_len Transform length.
/// @note Declared inline because the body lives in a header.
inline void ntt_convolution(UINT_32 ntt_ary1[], UINT_32 ntt_ary2[], UINT_64 out[], size_t ntt_len)
{
    constexpr UINT_64 mod1 = NTT_MOD1, mod2 = NTT_MOD2;
    constexpr UINT_64 root1 = NTT_ROOT1, root2 = NTT_ROOT2;
    // Copies of the inputs for the second modulus; the originals are used for the first one.
    UINT_32* ntt_ary3 = nullptr, * ntt_ary4 = nullptr;
    if (ntt_ary1 == ntt_ary2)
    {
        ntt_ary3 = ntt_ary4 = new UINT_32[ntt_len];
        ary_copy(ntt_ary3, ntt_ary1, ntt_len);
    }
    else
    {
        // One allocation holds both copies.
        ntt_ary3 = new UINT_32[ntt_len * 2];
        ntt_ary4 = ntt_ary3 + ntt_len;
        ary_copy(ntt_ary3, ntt_ary1, ntt_len);
        ary_copy(ntt_ary4, ntt_ary2, ntt_len);
    }
    bool multi_threads = false;
#ifdef MULTITHREAD
    multi_threads = hint_threads >= 2;
#endif
    // Forward transforms under both moduli.
    ntt<mod1, root1>(ntt_ary1, ntt_len, false, multi_threads);
    ntt<mod2, root2>(ntt_ary3, ntt_len, false, multi_threads);
    if (ntt_ary1 != ntt_ary2)
    {
        ntt<mod1, root1>(ntt_ary2, ntt_len, false, multi_threads);
        ntt<mod2, root2>(ntt_ary4, ntt_len, false, multi_threads);
    }
    // Element-wise products.
    ary_mul_mod<mod1>(ntt_ary2, ntt_ary1, ntt_ary1, ntt_len);
    ary_mul_mod<mod2>(ntt_ary4, ntt_ary3, ntt_ary3, ntt_len);
    // Inverse transforms.
    ntt<mod1, root1>(ntt_ary1, ntt_len, true, multi_threads);
    ntt<mod2, root2>(ntt_ary3, ntt_len, true, multi_threads);
    constexpr UINT_64 inv1 = qpow(mod1, mod2 - 2, mod2);
    constexpr UINT_64 inv2 = qpow(mod2, mod1 - 2, mod1);
    // Recombine the two residues of every element with the Chinese remainder theorem.
    for (size_t i = 0; i < ntt_len; i++)
    {
        out[i] = qcrt(ntt_ary1[i], ntt_ary3[i], mod1, mod2, inv1, inv2);
    }
    delete[] ntt_ary3;
}
/// @brief Formats an unsigned integer as a fixed-width decimal string.
/// @param input Value to format.
/// @param digits Width of the result.
/// @return The lowest `digits` decimal digits of input, padded with leading '0' characters.
/// @note Declared inline because the body lives in a header.
inline std::string ui64to_string(UINT_64 input, UINT_8 digits)
{
    std::string result(digits, '0');
    // Fill from the least significant digit backwards.
    for (UINT_8 i = 0; i < digits; i++)
    {
        result[digits - i - 1] = static_cast<char>(input % 10 + '0');
        input /= 10;
    }
    return result;
}
/// @brief Parses an unsigned integer from a character range.
/// @param begin Start of the range.
/// @param end End of the range.
/// @param base Numeric base; letters a-z (either case) stand for the digits 10-35.
/// @return The parsed value. At most the first 19 characters are read. Every character read
///         shifts the result by one digit; characters that are not valid digits of the base
///         contribute 0.
/// @note Characters are converted to unsigned char before the <cctype> calls, which are
///       undefined for negative values.
/// @note Declared inline because the body lives in a header.
inline UINT_64 stoui64(const std::string::const_iterator& begin,
                       const std::string::const_iterator& end, const UINT_32 base = 10)
{
    UINT_64 result = 0;
    for (auto it = begin; it < end && it - begin < 19; ++it)
    {
        result *= base;
        const int c = tolower(static_cast<unsigned char>(*it));
        UINT_64 n = 0;
        if (isalnum(c))
        {
            if (isalpha(c))
            {
                n = c - 'a' + 10;
            }
            else
            {
                n = c - '0';
            }
        }
        if (n < base)
        {
            result += n;
        }
    }
    return result;
}
/// @brief Parses an unsigned integer from a whole string; forwards to the iterator overload.
/// @param str Text to parse.
/// @param base Numeric base.
/// @return The parsed value.
/// @note Declared inline because the body lives in a header.
inline UINT_64 stoui64(const std::string& str, const UINT_32 base = 10)
{
    return stoui64(str.begin(), str.end(), base);
}
/// @brief Owning dynamic array whose length word also carries a sign flag in its top bit.
/// @tparam T element type
/// @tparam SIZE_TYPE integer type used for the capacity and for the packed sign/length word
template <typename T = UINT_32, typename SIZE_TYPE = UINT_32>
class HintVector
{
private:
    T* ary_ptr = nullptr;     // owned buffer, released with delete[]; nullptr while nothing is allocated
    SIZE_TYPE sign_n_len = 0; // top bit = sign flag, remaining bits = logical length
    SIZE_TYPE size = 0;       // allocated element count

public:
    static constexpr SIZE_TYPE SIZE_TYPE_BITS = sizeof(SIZE_TYPE) * 8;   // bit width of the size/length members
    static constexpr SIZE_TYPE SIZE_80 = (1ull << (SIZE_TYPE_BITS - 1)); // only the top bit set
    static constexpr SIZE_TYPE LEN_MAX = SIZE_80 - 1;                    // largest representable length
    static constexpr SIZE_TYPE SIZE_FF = SIZE_80 + LEN_MAX;              // all bits set

    /// Releases the owned buffer, if any.
    ~HintVector()
    {
        if (ary_ptr != nullptr)
        {
            delete[] ary_ptr;
            ary_ptr = nullptr;
        }
    }
    /// Creates an empty vector without allocating.
    constexpr HintVector()
    {
        ary_ptr = nullptr;
    }
    /// Allocates room for at least `ele_size` elements; the logical length stays 0.
    HintVector(SIZE_TYPE ele_size)
    {
        if (ele_size > 0)
        {
            resize(ele_size);
        }
    }
    /// Allocates `ele_size` elements, sets the length to `ele_size` and fills them with `ele`.
    HintVector(SIZE_TYPE ele_size, T ele)
    {
        if (ele_size > 0)
        {
            resize(ele_size);
            change_length(ele_size);
            fill_element(ele, ele_size);
        }
    }
    /// Copies `len` elements from `ary`; throws a C string if `ary` is nullptr.
    HintVector(const T* ary, SIZE_TYPE len)
    {
        if (ary == nullptr)
        {
            throw("Can't copy from nullptr\n");
        }
        if (len > 0)
        {
            resize(len);
            change_length(len);
            ary_copy(ary_ptr, ary, len);
        }
    }
    /// Copy constructor: duplicates length, sign and elements, then trims leading zeros.
    HintVector(const HintVector& input)
    {
        // ary_ptr is still nullptr here (member initializer), so there is nothing to release.
        if (input.ary_ptr != nullptr)
        {
            size_t len = input.length();
            resize(len);
            change_length(len);
            change_sign(input.sign());
            ary_copy(ary_ptr, input.ary_ptr, len);
            set_true_len();
        }
    }
    /// Move constructor: steals the buffer and leaves `input` empty (no buffer, length 0, capacity 0).
    HintVector(HintVector&& input)
    {
        if (input.ary_ptr != nullptr)
        {
            size = input.size;
            change_length(input.length());
            change_sign(input.sign());
            ary_ptr = input.ary_ptr;
            // Reset the source completely; a stale length with a null buffer would make
            // back()/front() or a later copy from the moved-from object dereference nullptr.
            input.ary_ptr = nullptr;
            input.size = 0;
            input.sign_n_len = 0;
            set_true_len();
        }
    }
    /// Copy assignment: replaces the contents with a trimmed copy of `input`.
    HintVector& operator=(const HintVector& input)
    {
        if (this != &input)
        {
            if (ary_ptr != nullptr)
            {
                delete[] ary_ptr;
            }
            ary_ptr = nullptr;
            size_t len = input.length();
            resize(len);
            change_length(len);
            change_sign(input.sign());
            ary_copy(ary_ptr, input.ary_ptr, len);
            set_true_len();
        }
        return *this;
    }
    /// Move assignment: releases the current buffer, steals `input`'s and leaves `input` empty.
    HintVector& operator=(HintVector&& input)
    {
        if (this != &input)
        {
            size = input.size;
            change_length(input.length());
            change_sign(input.sign());
            if (ary_ptr != nullptr)
            {
                delete[] ary_ptr;
            }
            ary_ptr = input.ary_ptr;
            // Same reset as in the move constructor, so the source stays self-consistent.
            input.ary_ptr = nullptr;
            input.size = 0;
            input.sign_n_len = 0;
            set_true_len();
        }
        return *this;
    }
    /// Replaces the contents with a converted copy of a vector of another element/size type.
    /// Terminates the process if `input` is longer than LEN_MAX.
    template <typename Ty, typename SIZE_Ty>
    void copy_from(const HintVector<Ty, SIZE_Ty>& input)
    {
        // Compare addresses through void*: `this` and `&input` may point to unrelated
        // instantiations, which cannot be compared directly.
        if (static_cast<const void*>(this) != static_cast<const void*>(&input))
        {
            if (input.length() > LEN_MAX)
            {
                std::cerr << "HintVector err: too long\n";
                exit(EXIT_FAILURE);
            }
            delete[] ary_ptr;
            ary_ptr = nullptr;
            SIZE_TYPE len = input.length();
            resize(len);
            change_length(len);
            change_sign(input.sign());
            // NOTE(review): raw_ptr() terminates the process when `input` holds no buffer - confirm intended.
            ary_copy_2type(ary_ptr, input.raw_ptr(), len);
        }
    }
    /// Overwrites the first min(len, length()) elements from a raw array; throws a C string on nullptr.
    template <typename Ty>
    void copy_from(const Ty* input, SIZE_TYPE len)
    {
        if (input == nullptr)
        {
            throw("Can't copy from nullptr\n");
        }
        len = std::min(len, length());
        ary_copy(ary_ptr, input, len);
    }
    /// First element, or 0 when the length is 0.
    T front() const
    {
        if (length() == 0)
        {
            return 0;
        }
        return ary_ptr[0];
    }
    /// Last element within the logical length, or 0 when the length is 0.
    T back() const
    {
        if (length() == 0)
        {
            return 0;
        }
        return ary_ptr[length() - 1];
    }
    /// Unchecked element access.
    T& operator[](SIZE_TYPE index) const
    {
        return ary_ptr[index];
    }
    /// Untyped pointer to the buffer; terminates the process when no buffer is allocated.
    void* raw_ptr() const
    {
        if (ary_ptr == nullptr)
        {
            std::cerr << "HintVector err: Get a ptr from an empty vector\n";
            exit(EXIT_FAILURE);
        }
        return ary_ptr;
    }
    /// Typed pointer to the buffer; terminates the process when no buffer is allocated.
    T* type_ptr() const
    {
        if (ary_ptr == nullptr)
        {
            std::cerr << "HintVector err: Get a ptr from an empty vector\n";
            exit(EXIT_FAILURE);
        }
        return ary_ptr;
    }
    /// Default growth policy: 2, 4, then the smaller of 1.5 * (p / 2) and p that still fits,
    /// where p = min_2pow(new_size). The request is first clamped to LEN_MAX.
    static SIZE_TYPE size_generator(SIZE_TYPE new_size)
    {
        new_size = std::min(new_size, LEN_MAX);
        if (new_size <= 2)
        {
            return 2;
        }
        else if (new_size <= 4)
        {
            return 4;
        }
        SIZE_TYPE size1 = min_2pow(new_size);
        SIZE_TYPE size2 = size1 / 2;
        size2 = size2 + size2 / 2;
        if (new_size <= size2)
        {
            return size2;
        }
        else
        {
            return size1;
        }
    }
    /// Prints the elements from the highest index down to 0, tab separated, followed by a newline.
    void print() const
    {
        SIZE_TYPE i = length();
        if (i == 0 || ary_ptr == nullptr)
        {
            std::cout << "Empty vector" << std::endl;
        }
        while (i > 0)
        {
            i--;
            std::cout << ary_ptr[i] << "\t";
        }
        std::cout << std::endl;
    }
    /// Allocated element count.
    SIZE_TYPE capacity() const
    {
        return size;
    }
    /// Logical length (sign bit masked off).
    SIZE_TYPE length() const
    {
        return sign_n_len & LEN_MAX;
    }
    /// True when the sign flag is set.
    bool sign() const
    {
        return (SIZE_80 & sign_n_len) != 0;
    }
    /// Changes the capacity to size_func(new_size). A fresh allocation resets the length to 0;
    /// a reallocation clamps the length to the new capacity.
    constexpr void resize(SIZE_TYPE new_size, SIZE_TYPE(size_func)(SIZE_TYPE) = size_generator)
    {
        new_size = size_func(new_size);
        if (ary_ptr == nullptr)
        {
            size = new_size;
            ary_ptr = new T[size];
            change_length(0);
        }
        else if (size != new_size)
        {
            size = new_size;
            ary_ptr = ary_realloc(ary_ptr, size);
            change_length(std::min(length(), size));
        }
    }
    /// Sets the logical length (clamped to the capacity) while keeping the sign flag;
    /// throws a C string when `new_length` exceeds LEN_MAX.
    void change_length(SIZE_TYPE new_length)
    {
        if (new_length > LEN_MAX)
        {
            throw("Length too long\n");
        }
        new_length = std::min(new_length, size);
        bool sign_tmp = sign();
        sign_n_len = new_length;
        change_sign(sign_tmp);
    }
    /// Sets or clears the sign flag; the flag is always cleared while the length is 0.
    void change_sign(bool is_sign)
    {
        if ((!is_sign) || length() == 0)
        {
            sign_n_len = sign_n_len & LEN_MAX;
        }
        else
        {
            sign_n_len = sign_n_len | SIZE_80;
        }
    }
    /// Shrinks the length past trailing (high-index) zero elements and returns the new length.
    SIZE_TYPE set_true_len()
    {
        if (ary_ptr == nullptr)
        {
            return 0;
        }
        SIZE_TYPE t_len = length();
        while (t_len > 0 && ary_ptr[t_len - 1] == 0)
        {
            t_len--;
        }
        change_length(t_len);
        return length();
    }
    /// Writes `ele` into [begin, begin + count), clipped to the capacity.
    void fill_element(T ele, SIZE_TYPE count, SIZE_TYPE begin = 0)
    {
        if (begin >= size || ary_ptr == nullptr)
        {
            return;
        }
        if (begin + count >= size)
        {
            count = size - begin;
        }
        std::fill(ary_ptr + begin, ary_ptr + begin + count, ele);
    }
    /// Clears the whole allocated buffer via ary_clr; the length is left unchanged.
    void clear()
    {
        ary_clr(ary_ptr, size);
    }
};
}
#endif
#include "hint_math.hpp"
#ifndef HINT_MINI_HPP
#define HINT_MINI_HPP
namespace hint_arithm
{
using namespace hint;
using SIZE_TYPE = UINT_32;
template <typename T>
using hintvector = HintVector<T, SIZE_TYPE>;
/// @brief Print an array from its last element down to its first, tab separated, then a newline.
/// @param ary array to print
/// @param len number of elements in `ary`
template <typename T>
void ary_print(T ary[], size_t len)
{
    for (size_t remaining = len; remaining != 0; --remaining)
    {
        std::cout << ary[remaining - 1] << "\t";
    }
    std::cout << "\n";
}
/// @brief Element-wise bitwise AND over the common prefix of two arrays.
/// @param in1 first operand
/// @param in2 second operand
/// @param out receives min(len1, len2) results; later elements are not touched
/// @param len1 element count of `in1`
/// @param len2 element count of `in2`
template <typename T>
constexpr void ary_and(const T in1[], const T in2[], T out[], size_t len1, size_t len2)
{
    const size_t common = std::min(len1, len2);
    for (size_t idx = 0; idx < common; ++idx)
    {
        out[idx] = in1[idx] & in2[idx];
    }
}
/// @brief Element-wise bitwise OR of two arrays of possibly different lengths.
///
/// The overlapping prefix is OR-ed; the remaining elements of the longer input are
/// copied through unchanged, so exactly max(len1, len2) elements of `out` are written.
/// @param in1 first operand
/// @param in2 second operand
/// @param out destination with room for max(len1, len2) elements
/// @param len1 element count of `in1`
/// @param len2 element count of `in2`
template <typename T>
constexpr void ary_or(const T in1[], const T in2[], T out[], size_t len1, size_t len2)
{
    const bool first_is_longer = len1 >= len2;
    const T *longer = first_is_longer ? in1 : in2;
    const T *shorter = first_is_longer ? in2 : in1;
    const size_t long_len = first_is_longer ? len1 : len2;
    const size_t short_len = first_is_longer ? len2 : len1;
    // One element per step. The previous 4-way unrolled body advanced by 1 instead of 4
    // and therefore touched up to three elements past the overlap.
    for (size_t i = 0; i < short_len; i++)
    {
        out[i] = longer[i] | shorter[i];
    }
    for (size_t i = short_len; i < long_len; i++)
    {
        out[i] = longer[i];
    }
}
/// @brief Element-wise bitwise XOR of two arrays of possibly different lengths.
///
/// The overlapping prefix is XOR-ed; the remaining elements of the longer input are
/// copied through unchanged, so exactly max(len1, len2) elements of `out` are written.
/// @param in1 first operand
/// @param in2 second operand
/// @param out destination with room for max(len1, len2) elements
/// @param len1 element count of `in1`
/// @param len2 element count of `in2`
template <typename T>
constexpr void ary_xor(const T in1[], const T in2[], T out[], size_t len1, size_t len2)
{
    const bool first_is_longer = len1 >= len2;
    const T *longer = first_is_longer ? in1 : in2;
    const T *shorter = first_is_longer ? in2 : in1;
    const size_t long_len = first_is_longer ? len1 : len2;
    const size_t short_len = first_is_longer ? len2 : len1;
    // One element per step. The previous 4-way unrolled body advanced by 1 instead of 4
    // and therefore touched up to three elements past the overlap.
    for (size_t i = 0; i < short_len; i++)
    {
        out[i] = longer[i] ^ shorter[i];
    }
    for (size_t i = short_len; i < long_len; i++)
    {
        out[i] = longer[i];
    }
}
/// @brief Element-wise bitwise complement.
/// @param in source array
/// @param out receives `len` complemented elements
/// @param len number of elements to process
template <typename T>
constexpr void ary_not(const T in[], T out[], size_t len)
{
    for (size_t idx = 0; idx != len; ++idx)
    {
        out[idx] = ~in[idx];
    }
}
/// @brief Compare the magnitudes of two multi-digit numbers (index 0 is compared last).
/// @param ary1 digits of the first number
/// @param ary2 digits of the second number
/// @param len1 digit count of `ary1` (trimmed with ary_true_len before comparing)
/// @param len2 digit count of `ary2` (trimmed with ary_true_len before comparing)
/// @return 1 if the first is larger, -1 if it is smaller, 0 if both are equal
template <typename T>
constexpr INT_32 abs_compare(const T ary1[], const T ary2[], size_t len1, size_t len2)
{
    const size_t true_len1 = ary_true_len(ary1, len1);
    const size_t true_len2 = ary_true_len(ary2, len2);
    if (true_len1 > true_len2)
    {
        return 1;
    }
    if (true_len1 < true_len2)
    {
        return -1;
    }
    if (ary1 == ary2)
    {
        return 0; // same storage and same length: trivially equal
    }
    // Equal lengths: the highest differing digit decides.
    for (size_t idx = true_len1; idx-- > 0;)
    {
        const T lhs = ary1[idx];
        const T rhs = ary2[idx];
        if (lhs != rhs)
        {
            return lhs > rhs ? 1 : -1;
        }
    }
    return 0;
}
/// @brief Compare |ary1| with |ary2| shifted left by `shift` digits.
/// @param ary1 digits of the first number
/// @param ary2 digits of the second number (before shifting)
/// @param len1 digit count of `ary1`
/// @param len2 digit count of `ary2`
/// @param shift number of zero digits conceptually inserted below `ary2`
/// @return 1 if ary1 is larger, -1 if it is smaller, 0 if both are equal
template <typename T>
constexpr INT_32 abs_compare_shift(const T ary1[], const T ary2[], size_t len1, size_t len2, size_t shift = 0)
{
    len1 = ary_true_len(ary1, len1);
    len2 = ary_true_len(ary2, len2);
    if (len1 != len2 + shift)
    {
        return len1 > (len2 + shift) ? 1 : -1;
    }
    // The view starting at ary1 + shift only has len1 - shift digits; passing len1 here
    // would let the comparison look past the end of ary1.
    INT_32 cmp = abs_compare(ary1 + shift, ary2, len1 - shift, len2);
    if (cmp != 0)
    {
        return cmp;
    }
    // The high parts are equal, so any non-zero low digit makes ary1 the larger one.
    for (size_t i = 0; i < shift; i++)
    {
        if (ary1[i] > 0)
        {
            return 1;
        }
    }
    return 0;
}
/// @brief Multi-digit addition in an arbitrary base: out = in1 + in2.
/// @tparam is_carry when true the final carry is stored in out[max(len1, len2)]
/// @param in1 digits of the first operand
/// @param in2 digits of the second operand
/// @param out destination; needs max(len1, len2) digits, plus one when is_carry is true
/// @param len1 digit count of `in1`
/// @param len2 digit count of `in2`
/// @param base radix of the digits
template <bool is_carry = true, typename T>
constexpr void abs_add(const T in1[], const T in2[], T out[],
                       size_t len1, size_t len2, const UINT_64 base)
{
    if (len1 < len2)
    {
        std::swap(in1, in2);
        std::swap(len1, len2);
    }
    size_t pos = 0;
    UINT_64 carry = 0;
    while (pos < len2)
    {
        // Widen each digit before adding so the sum cannot wrap in a narrower T.
        carry += static_cast<UINT_64>(in1[pos]);
        carry += static_cast<UINT_64>(in2[pos]);
        out[pos] = carry < base ? carry : carry - base;
        carry = carry < base ? 0 : 1;
        pos++;
    }
    // Propagate the carry through the longer operand only as far as needed.
    while (pos < len1 && carry > 0)
    {
        carry += in1[pos];
        out[pos] = carry < base ? carry : carry - base;
        carry = carry < base ? 0 : 1;
        pos++;
    }
    ary_copy(out + pos, in1 + pos, len1 - pos);
    if (is_carry)
    {
        out[len1] = carry % base;
    }
}
/// @brief Multi-digit subtraction in an arbitrary base: out = in1 - in2.
///
/// Does nothing when len1 < len2.
/// @param in1 digits of the minuend
/// @param in2 digits of the subtrahend
/// @param out destination with room for len1 digits
/// @param len1 digit count of `in1`
/// @param len2 digit count of `in2`
/// @param base radix of the digits
template <typename T>
constexpr void abs_sub(const T in1[], const T in2[], T out[],
                       size_t len1, size_t len2, const INT_64 base)
{
    if (len1 < len2)
    {
        return;
    }
    INT_64 borrow = 0;
    size_t pos = 0;
    for (; pos < len2; ++pos)
    {
        borrow += (static_cast<INT_64>(in1[pos]) - in2[pos]);
        if (borrow < 0)
        {
            out[pos] = borrow + base;
            borrow = -1;
        }
        else
        {
            out[pos] = borrow;
            borrow = 0;
        }
    }
    // Keep borrowing from the higher digits of in1 while a borrow is pending.
    for (; pos < len1 && borrow < 0; ++pos)
    {
        borrow += in1[pos];
        if (borrow < 0)
        {
            out[pos] = borrow + base;
            borrow = -1;
        }
        else
        {
            out[pos] = borrow;
            borrow = 0;
        }
    }
    ary_copy(out + pos, in1 + pos, len1 - pos);
}
// 64位搞精度加法
constexpr void abs_add64(const UINT_64 in1[], const UINT_64 in2[], UINT_64 out[],
size_t len1, size_t len2)
{
if (len1 < len2)
{
std::swap(in1, in2);
std::swap(len1, len2);
}
size_t pos = 0;
UINT_64 carry = 0;
while (pos < len2)
{
bool is_carry1 = false, is_carry2 = false;
std::tie(carry, is_carry1) = safe_add(carry, in1[pos]);
std::tie(carry, is_carry2) = safe_add(carry, in2[pos]);
out[pos] = carry;
carry = is_carry1 || is_carry2 ? 1 : 0;
pos++;
}
while (pos < len1)
{
bool is_carry = false;
std::tie(carry, is_carry) = safe_add(carry, in1[pos]);
out[pos] = carry;
carry = carry ? 1 : 0;
pos++;
}
out[len1] = carry;
}
// 64位多精度减法
constexpr void abs_sub64(const UINT_64 in1[], const UINT_64 in2[], UINT_64 out[],
size_t len1, size_t len2)
{
if (len1 < len2)
{
std::swap(in1, in2);
std::swap(len1, len2);
}
size_t pos = 0;
UINT_64 carry = 0;
while (pos < len2)
{
bool is_carry1 = false, is_carry2 = false;
std::tie(carry, is_carry1) = safe_add(carry, in1[pos]);
std::tie(carry, is_carry2) = safe_add(carry, in2[pos]);
out[pos] = carry;
carry = is_carry1 || is_carry2 ? 1 : 0;
pos++;
}
while (pos < len1)
{
bool is_carry = false;
std::tie(carry, is_carry) = safe_add(carry, in1[pos]);
out[pos] = carry;
carry = carry ? 1 : 0;
pos++;
}
out[len1] = carry;
}
// 小学乘法
template <typename T>
void normal_mul(const T in1[], const T in2[], T out[],
size_t len1, size_t len2, const UINT_64 base)
{
if (len1 < len2)
{
std::swap(in1, in2);
std::swap(len1, len2);
}
if (len1 == 0 || in1 == nullptr || in2 == nullptr)
{
return;
}
T *res = out;
if (in1 == out || in2 == out)
{
res = new T[len1 + len2];
}
ary_clr(res, len1 + len2);
for (size_t i = 0; i < len1; i++)
{
UINT_64 num1 = in1[i];
for (size_t j = 0; j < len2; j++)
{
UINT_64 tmp = num1 * in2[j];
for (size_t k = i + j; tmp > 0 && k < len1 + len2; k++)
{
tmp += res[k];
std::tie(tmp, res[k]) = div_mod(tmp, base);
}
}
}
if (res != out)
{
ary_copy(out, res, len1 + len2);
delete[] res;
}
}
// fft加速乘法
template <typename T>
void fft_mul(const T in1[], const T in2[], T out[],
size_t len1, size_t len2, const UINT_64 base)
{
if (len1 == 0 || len2 == 0 || in1 == nullptr || in2 == nullptr)
{
return;
}
size_t conv_res_len = len1 + len2 - 1; // 卷积结果长度
size_t fft_len = min_2pow(conv_res_len); // fft长度
Complex *fft_ary = new Complex[fft_len];
com_ary_combine_copy(fft_ary, in1, len1, in2, len2);
hint::fft_lut(fft_ary, fft_len);
for (size_t i = 0; i < fft_len; i++)
{
Complex tmp = fft_ary[i];
fft_ary[i] = std::conj(tmp * tmp);
}
hint::fft_lut(fft_ary, fft_len);
hint::UINT_64 carry = 0;
double inv = 2 * fft_len;
for (size_t i = 0; i < conv_res_len; i++)
{
carry += static_cast<hint::UINT_64>(-fft_ary[i].imag() / inv + 0.5);
std::tie(carry, out[i]) = div_mod<UINT_64>(carry, base);
}
out[conv_res_len] = carry % base;
delete[] fft_ary;
}
/// @brief FHT-accelerated squaring: out = in * in in the given base.
/// @param in digits of the number to square
/// @param out destination with room for 2 * len digits
/// @param len digit count of `in`
/// @param base radix of the digits
template <typename T>
void fht_sqr(const T in[], T out[], size_t len, const UINT_64 base)
{
    if (len == 0 || in == nullptr)
    {
        return;
    }
    const size_t conv_res_len = len * 2 - 1;       // length of the convolution result
    const size_t fht_len = min_2pow(conv_res_len); // transform length
    // First half holds the input, second half receives the convolution.
    double *work = new double[fht_len * 2];
    double *conv = work + fht_len;
    ary_clr(work, fht_len);
    ary_copy_2type(work, in, len);
    fht_convolution(work, work, conv, fht_len);
    hint::UINT_64 carry = 0;
    for (size_t idx = 0; idx < conv_res_len; idx++)
    {
        carry += static_cast<hint::UINT_64>(work[idx + fht_len] + 0.5);
        std::tie(carry, out[idx]) = div_mod(carry, base);
    }
    out[conv_res_len] = carry % base;
    delete[] work;
}
// ntt加速乘法
template <typename T>
void ntt_mul(const T in1[], const T in2[], T out[],
size_t len1, size_t len2, const UINT_64 base)
{
if (len1 == 0 || len2 == 0 || in1 == nullptr || in2 == nullptr)
{
return;
}
size_t conv_res_len = len1 + len2 - 1; // 卷积结果长度
size_t ntt_len = min_2pow(conv_res_len); // ntt长度
UINT_32 *ntt_ary1 = new UINT_32[ntt_len * 4];
ary_clr(ntt_ary1, ntt_len * 4);
UINT_32 *ntt_ary2 = ntt_ary1 + ntt_len;
UINT_32 *ntt_ary3 = ntt_ary1 + ntt_len * 2;
UINT_32 *ntt_ary4 = ntt_ary1 + ntt_len * 3;
hint::ary_copy_2type(ntt_ary1, in1, len1);
hint::ary_copy_2type(ntt_ary2, in1, len1);
hint::ary_copy_2type(ntt_ary3, in2, len2);
hint::ary_copy_2type(ntt_ary4, in2, len2);
bool is_multithread = false;
#ifdef MULTITHREAD
is_multithread = hint_threads >= 2 && ntt_len >= 1 << 18;
#endif
constexpr UINT_64 mod1 = NTT_MOD1, mod2 = NTT_MOD2;
constexpr UINT_64 root1 = NTT_ROOT1, root2 = NTT_ROOT2;
hint::ntt<mod1, root1>(ntt_ary1, ntt_len, false, is_multithread);
hint::ntt<mod2, root2>(ntt_ary2, ntt_len, false, is_multithread);
hint::ntt<mod1, root1>(ntt_ary3, ntt_len, false, is_multithread);
hint::ntt<mod2, root2>(ntt_ary4, ntt_len, false, is_multithread);
hint::ary_mul_mod<mod1>(ntt_ary1, ntt_ary3, ntt_ary1, ntt_len);
hint::ary_mul_mod<mod2>(ntt_ary2, ntt_ary4, ntt_ary2, ntt_len);
hint::ntt<mod1, root1>(ntt_ary1, ntt_len, true, is_multithread);
hint::ntt<mod2, root2>(ntt_ary2, ntt_len, true, is_multithread);
constexpr UINT_64 inv1 = qpow(mod1, mod2 - 2, mod2);
constexpr UINT_64 inv2 = qpow(mod2, mod1 - 2, mod1);
hint::UINT_64 carry = 0;
for (size_t i = 0; i < conv_res_len; i++)
{
carry += qcrt(ntt_ary1[i], ntt_ary2[i], mod1, mod2, inv1, inv2);
std::tie(carry, out[i]) = div_mod(carry, base);
}
out[conv_res_len] = carry % base;
delete[] ntt_ary1;
}
// ntt加速平方
template <typename T>
void ntt_sqr(const T in[], T out[], size_t len, const UINT_64 base)
{
if (len == 0 || in == nullptr)
{
return;
}
size_t conv_res_len = len * 2 - 1; // 卷积结果长度
size_t ntt_len = min_2pow(conv_res_len); // ntt长度
UINT_32 *ntt_ary1 = new UINT_32[ntt_len * 2];
ary_clr(ntt_ary1, ntt_len * 2);
UINT_32 *ntt_ary2 = ntt_ary1 + ntt_len;
hint::ary_copy_2type(ntt_ary1, in, len);
hint::ary_copy(ntt_ary2, ntt_ary1, len);
bool is_multithread = false;
#ifdef MULTITHREAD
is_multithread = hint_threads >= 2 && ntt_len >= 1 << 18;
#endif
constexpr UINT_64 mod1 = NTT_MOD1, mod2 = NTT_MOD2;
constexpr UINT_64 root1 = NTT_ROOT1, root2 = NTT_ROOT2;
hint::ntt<mod1, root1>(ntt_ary1, ntt_len, false, is_multithread);
hint::ntt<mod2, root2>(ntt_ary2, ntt_len, false, is_multithread);
hint::ary_mul_mod<mod1>(ntt_ary1, ntt_ary1, ntt_ary1, ntt_len);
hint::ary_mul_mod<mod2>(ntt_ary2, ntt_ary2, ntt_ary2, ntt_len);
hint::ntt<mod1, root1>(ntt_ary1, ntt_len, true, is_multithread);
hint::ntt<mod2, root2>(ntt_ary2, ntt_len, true, is_multithread);
constexpr UINT_64 inv1 = qpow(mod1, mod2 - 2, mod2);
constexpr UINT_64 inv2 = qpow(mod2, mod1 - 2, mod1);
hint::UINT_64 carry = 0;
for (size_t i = 0; i < conv_res_len; i++)
{
carry += qcrt(ntt_ary1[i], ntt_ary2[i], mod1, mod2, inv1, inv2);
std::tie(carry, out[i]) = div_mod(carry, base);
}
out[conv_res_len] = carry % base;
delete[] ntt_ary1;
}
// Karatsuba multiplication: out = in1 * in2 in the given base.
// Products whose convolution length fits NTT_MAX_LEN are handed to ntt_sqr / ntt_mul;
// larger ones are split in two and combined from three recursive products.
// in1/in2: digit arrays of the factors; len1/len2: their digit counts (trimmed with ary_true_len);
// out: destination, cleared over len1 + len2 digits on the recursive path.
template <typename T, typename NTT_Ty = UINT_16>
constexpr void karatsuba_mul(const T in1[], const T in2[], T out[],
size_t len1, size_t len2, const INT_64 base)
{
// (a*base^n+b)*(c*base^n+d) = a*c*base^2n+(a*d+b*c)*base^n+b*d
// compute: a*c, b*d, (a+b)*(c+d); then a*d+b*c = (a+b)*(c+d)-a*c-b*d
len1 = ary_true_len(in1, len1);
len2 = ary_true_len(in2, len2);
// Make in1 the longer operand.
if (len1 < len2)
{
std::swap(in1, in2);
std::swap(len1, len2);
}
// std::cin.get();
if (len2 == 0 || in1 == nullptr || in2 == nullptr)
{
return;
}
if (len1 + len2 - 1 <= NTT_MAX_LEN)
{
// Digit counts rescaled by sizeof(T) / sizeof(NTT_Ty) (at least 1).
// NOTE(review): in1/in2/out are T pointers but the template argument is forced to NTT_Ty;
// this looks like it only compiles when T and NTT_Ty are the same type - confirm.
const size_t ntt_len1 = len1 * std::max<size_t>(1, sizeof(T) / sizeof(NTT_Ty));
const size_t ntt_len2 = len2 * std::max<size_t>(1, sizeof(T) / sizeof(NTT_Ty));
if (in1 == in2 && len1 == len2)
{
ntt_sqr<NTT_Ty>(in1, out, ntt_len1, base);
}
else
{
ntt_mul<NTT_Ty>(in1, in2, out, ntt_len1, ntt_len2, base);
}
return;
}
size_t len_a = len1 / 2;
size_t len_b = len1 - len_a; // length of the low half b (the split point)
// NOTE(review): when len2 <= len_b, len_c is 1 and c_ptr points at in2[len2],
// one past the trimmed length - verify that this digit is readable and zero.
size_t len_c = len2 > len_b ? len2 - len_b : 1;
size_t len_d = len2 > len_b ? len_b : len2;
size_t len_ac = len_c > 0 ? len_a + len_c : 0; // length of a*c
size_t len_bd = len_b + len_d; // length of b*d
size_t len_add_mul = len_b + len_d + 2; // length of (a+b)*(c+d)
const T *a_ptr = in1 + len_b; // in1 itself is used as b
const T *c_ptr = in2 + len_d; // in2 itself is used as d
hintvector<T> mul_ac(len_ac, 0); // holds a*c
hintvector<T> mul_bd(len_bd, 0); // holds b*d
hintvector<T> add_mul(len_add_mul, 0); // first holds a+b and c+d (a+b takes len_b+1 digits)
T *add_ab = add_mul.type_ptr();
T *add_cd = add_ab + len_b + 1;
abs_add(in1, a_ptr, add_ab, len_b, len_a, base); // b+a
abs_add(in2, c_ptr, add_cd, len_d, len_c, base); // d+c
size_t len_add_ab = ary_true_len(add_ab, len_b + 1);
size_t len_add_cd = ary_true_len(add_cd, len_d + 1);
karatsuba_mul(a_ptr, c_ptr, mul_ac.type_ptr(), len_a, len_c, base); // a*c
karatsuba_mul(in1, in2, mul_bd.type_ptr(), len_b, len_d, base); // b*d
// The product is written over the buffer that also holds both summands.
karatsuba_mul(add_ab, add_cd, add_mul.type_ptr(), len_b + 1, len_d + 1, base); // (a+b)*(c+d)
add_mul.change_length(len_add_ab + len_add_cd);
len_ac = mul_ac.set_true_len();
len_bd = mul_bd.set_true_len();
len_add_mul = add_mul.set_true_len();
ary_clr(out, len1 + len2);
ary_copy(out, mul_bd.type_ptr(), len_bd); // low part of the result: b*d
ary_copy(out + len_b * 2, mul_ac.type_ptr(), len_ac); // high part of the result: a*c
// Add (a+b)*(c+d) - a*c - b*d at digit offset len_b, carrying with floor division.
INT_64 carry = 0;
for (size_t pos = len_b; pos < len1 + len2; pos++)
{
size_t t_pos = pos - len_b;
carry += out[pos];
if (t_pos < len_add_mul)
{
carry += add_mul[t_pos];
}
if (t_pos < len_ac)
{
carry -= mul_ac[t_pos];
}
if (t_pos < len_bd)
{
carry -= mul_bd[t_pos];
}
INT_64 rem = carry % base;
// Floor division so that a negative intermediate sum borrows from the next digit.
carry = carry < 0 ? (carry + 1) / base - 1 : carry / base;
out[pos] = rem < 0 ? rem + base : rem;
}
}
/// @brief Multi-digit multiplication front end: picks an algorithm by operand size.
/// @param in1 digits of the first factor
/// @param in2 digits of the second factor
/// @param out destination with room for len1 + len2 digits
/// @param len1 digit count of `in1` (trimmed with ary_true_len)
/// @param len2 digit count of `in2` (trimmed with ary_true_len)
/// @param base radix of the digits
template <typename T>
constexpr void abs_mul(const T in1[], const T in2[], T out[],
                       size_t len1, size_t len2, const INT_64 base)
{
    len1 = ary_true_len(in1, len1);
    len2 = ary_true_len(in2, len2);
    // Small or very unbalanced operands: schoolbook multiplication.
    if (len1 + len2 <= 48 || len1 * len2 < (len1 + len2) * std::log2(len1 + len2))
    {
        normal_mul(in1, in2, out, len1, len2, base);
        return;
    }
    // Convolution fits the FFT lookup table.
    if (len1 + len2 - 1 <= (1 << hint::lut_max_rank))
    {
        fft_mul(in1, in2, out, len1, len2, base);
        return;
    }
    // Convolution fits the NTT length limit.
    if (len1 + len2 - 1 <= NTT_MAX_LEN)
    {
        ntt_mul(in1, in2, out, len1, len2, base);
        return;
    }
    karatsuba_mul(in1, in2, out, len1, len2, base);
}
/// @brief Multi-digit squaring front end: picks an algorithm by operand size.
/// @param in digits of the number to square
/// @param out destination with room for 2 * len digits
/// @param len digit count of `in` (trimmed with ary_true_len)
/// @param base radix of the digits
template <typename T>
constexpr void abs_sqr(const T in[], T out[], size_t len, const INT_64 base)
{
    len = ary_true_len(in, len);
    if (len <= 24)
    {
        normal_mul(in, in, out, len, len, base);
        return;
    }
    if (len * 2 - 1 <= (1 << hint::lut_max_rank))
    {
        fht_sqr(in, out, len, base);
        return;
    }
    if (len * 2 - 1 <= NTT_MAX_LEN)
    {
        ntt_sqr(in, out, len, base);
        return;
    }
    karatsuba_mul(in, in, out, len, len, base);
}
// 高精度乘低精度
template <bool is_carry = true, typename T>
constexpr void abs_mul_num(const T in[], T num, T out[], size_t len, const UINT_64 base)
{
len = ary_true_len(in, len);
num %= base;
UINT_64 prod = 0;
if (num == 1)
{
ary_copy(out, in, len);
}
else
{
for (size_t i = 0; i < len; i++)
{
prod += static_cast<UINT_64>(in[i]) * num;
std::tie(prod, out[i]) = div_mod<UINT_64>(prod, base);
}
}
if (is_carry)
{
out[len] = prod;
}
}
/// @brief Divide a multi-digit number by a single digit: out = in / num.
/// @param in digits of the dividend
/// @param num single-digit divisor, reduced modulo `base` first
/// @param out receives the quotient digits
/// @param len digit count of `in`
/// @param base radix of the digits
/// @return the remainder of the division
/// @throws const char* when the reduced divisor is zero
template <typename T>
constexpr INT_64 abs_div_num(const T in[], T num, T out[], size_t len, const UINT_64 base)
{
    size_t pos = ary_true_len(in, len);
    num %= base;
    if (num == 0)
    {
        // Reject a zero divisor (also reached when num == base) before any division happens.
        throw("Can't divide by zero\n");
    }
    if (num == 1)
    {
        ary_copy(out, in, len);
        return 0;
    }
    UINT_64 last_rem = 0;
    // Long division from the most significant digit downwards.
    while (pos > 0)
    {
        pos--;
        last_rem = last_rem * base + in[pos];
        std::tie(out[pos], last_rem) = div_mod<UINT_64>(last_rem, num);
    }
    return last_rem;
}
/// @brief Scale a divisor so that its top digit is at least base / 2.
/// @param in digits of the divisor
/// @param out receives the scaled divisor; must not be the same array as `in`
/// @param len digit count of `in`
/// @param base radix of the digits
/// @return the multiplier that was applied (1 when the divisor was already normalized)
/// @throws const char* when `in` and `out` are the same array
template <typename T>
constexpr T divisor_normalize(const T in[], T out[], size_t len, const UINT_64 base)
{
    if (in == out)
    {
        throw("In can't be same as out\n");
    }
    T scale = 1;
    if (len == 1)
    {
        scale = (base - 1) / in[0];
        out[0] = in[0] * scale;
        return scale;
    }
    if (in[len - 1] >= (base / 2))
    {
        // Already normalized: copy through unchanged.
        ary_copy(out, in, len);
        return scale;
    }
    // Estimate the multiplier from the two leading digits.
    scale = (base - 1) * base / (base * in[len - 1] + in[len - 2]);
    abs_mul_num<false>(in, scale, out, len, base);
    if (out[len - 1] < (base / 2))
    {
        // Still below the threshold: add the divisor once more.
        scale++;
        abs_add<false>(out, in, out, len, len, base);
    }
    return scale;
}
// Long division. The quotient digits go to quot and the remainder is left in dividend.
// The divisor must already be normalized (top digit >= base / 2) unless it has one digit.
// dividend: digits of the dividend, overwritten; divisor: digits of the divisor;
// quot: quotient digits, written at index len1 - len2 and at the shift of each step;
// len1/len2: digit counts (trimmed with ary_true_len); base: radix of the digits.
// Throws a C string for a zero divisor or a non-normalized multi-digit divisor.
template <typename T>
void abs_long_div(T dividend[], const T divisor[], T quot[],
size_t len1, size_t len2, const UINT_64 base)
{
len1 = ary_true_len(dividend, len1);
len2 = ary_true_len(divisor, len2);
if (divisor == nullptr || len2 == 0)
{
throw("Can't divide by zero\n");
}
if (dividend == nullptr || len1 == 0)
{
return;
}
// Dividend smaller than divisor: nothing is written, the dividend already is the remainder.
if (abs_compare(dividend, divisor, len1, len2) < 0)
{
return;
}
if (len2 == 1)
{
// Single-digit divisor: delegate and store the remainder in dividend[0].
// NOTE(review): the higher digits of dividend are not cleared here - confirm callers expect that.
T rem = abs_div_num(dividend, divisor[0], quot, len1, base);
dividend[0] = rem;
return;
}
if (divisor[len2 - 1] < (base / 2))
{
throw("Can't call this proc before normalize the divisor\n");
}
quot[len1 - len2] = 0;
const UINT_64 divisor_2digits = base * divisor[len2 - 1] + divisor[len2 - 2];
hintvector<T> sub(len2 + 1);
sub.change_length(len2 + 1);
// Keep subtracting while the dividend (the running remainder) is not smaller than the divisor.
while (abs_compare(dividend, divisor, len1, len2) >= 0)
{
sub.change_length(len2 + 1);
UINT_64 dividend_2digits = dividend[len1 - 1] * base + dividend[len1 - 2];
T quo_digit = 0;
size_t shift = len1 - len2;
// Top two dividend digits >= top two divisor digits: the trial digit is off by at most 1.
if (dividend_2digits >= divisor_2digits)
{
quo_digit = dividend_2digits / divisor_2digits;
abs_mul_num(divisor, quo_digit, sub.type_ptr(), len2, base);
sub.set_true_len();
size_t sub_len = sub.length();
if (abs_compare(dividend + shift, sub.type_ptr(), len1 - shift, sub_len) < 0)
{
quo_digit--;
abs_sub(sub.type_ptr(), divisor, sub.type_ptr(), sub_len, len2, base);
}
}
else
{
// Top two dividend digits against the top divisor digit: the trial digit is off by at most 2.
quo_digit = dividend_2digits / (divisor_2digits / base);
shift--;
abs_mul_num(divisor, quo_digit, sub.type_ptr(), len2, base);
sub.set_true_len();
size_t sub_len = sub.length();
if (abs_compare(dividend + shift, sub.type_ptr(), len1 - shift, sub_len) < 0)
{
quo_digit--;
abs_sub(sub.type_ptr(), divisor, sub.type_ptr(), sub_len, len2, base);
if (abs_compare(dividend + shift, sub.type_ptr(), len1 - shift, sub_len) < 0)
{
quo_digit--;
abs_sub(sub.type_ptr(), divisor, sub.type_ptr(), sub_len, len2, base);
}
}
}
// NOTE(review): the length passed for dividend + shift is len1, although that view only has
// len1 - shift digits; sub.length() is also not refreshed after the corrections above.
// Verify both before changing either.
abs_sub(dividend + shift, sub.type_ptr(), dividend + shift, len1, sub.length(), base);
len1 = ary_true_len(dividend, len1);
quot[shift] = quo_digit;
}
}
// Call counter used only by the debug trace in abs_rec_div.
int c = 0;
/// @brief Recursive division. The quotient goes to `quot`, the remainder stays in `dividend`.
///
/// The divisor must already be normalized (top digit >= base / 2).
/// NOTE(review): the routine still prints a debug trace (call counter and intermediate
/// values) to stdout; it is kept as-is here.
/// @param dividend digits of the dividend, overwritten with the remainder
/// @param divisor digits of the normalized divisor
/// @param quot receives the quotient (the parameter had been mangled into a stray quote
///             character; it is restored as the reference `&quot`)
/// @param len1 digit count of `dividend` (trimmed with ary_true_len)
/// @param len2 digit count of `divisor` (trimmed with ary_true_len)
/// @param base radix of the digits
/// @throws const char* for a zero or non-normalized divisor
template <typename T>
void abs_rec_div(T dividend[], T divisor[], hintvector<T> &quot,
                 size_t len1, size_t len2, const UINT_64 base)
{
    c++;
    len1 = ary_true_len(dividend, len1);
    len2 = ary_true_len(divisor, len2);
    printf("%d\n", c);
    ary_print(dividend, len1);
    ary_print(divisor, len2);
    if (divisor == nullptr || len2 == 0)
    {
        throw("Can't divide by zero\n");
    }
    if (dividend == nullptr || len1 == 0)
    {
        return;
    }
    if (abs_compare(dividend, divisor, len1, len2) < 0)
    {
        return;
    }
    if (divisor[len2 - 1] < (base / 2))
    {
        throw("Can't call this proc before normalize the divisor\n");
    }
    size_t quot_len = len1 - len2 + 1;
    constexpr size_t LONG_DIV_THRESHOLD = 2;
    // std::cin.get();
    if (len2 <= LONG_DIV_THRESHOLD) // at or below the threshold: fall back to long division
    {
        quot.resize(quot_len);
        quot.change_length(quot_len);
        abs_long_div(dividend, divisor, quot.type_ptr(), len1, len2, base);
        printf("%d\n", c);
        quot.print();
    }
    else if (len1 >= len2 * 2 || len1 > ((len2 + 1) / 2) * 3) // 2n/n division: recurse twice
    {
        size_t base_len = (len1 + 3) / 4;
        size_t quot_tmp_len = base_len * 3 - len2 + 1;
        hintvector<T> quot_tmp(quot_tmp_len, 0);
        abs_rec_div(dividend + base_len, divisor, quot_tmp, len1 - base_len, len2, base);
        quot_tmp_len = quot_tmp.set_true_len();
        printf("hh%d\n", c);
        quot_tmp.print();
        printf("hd%d\n", c);
        quot.print();
        quot.resize(quot_len);
        quot.clear();
        abs_rec_div(dividend, divisor, quot, base_len * 3, len2, base);
        quot.change_length(quot_len);
        abs_add(quot.type_ptr() + base_len, quot_tmp.type_ptr(), quot.type_ptr() + base_len, quot_len - base_len, quot_tmp_len, base);
        quot.set_true_len();
    }
    else
    {
        quot.resize(quot_len);
        quot.change_length(quot_len);
        size_t base_len = len2 / 2;
        // Trial quotient: divide dividend / base^base_len by divisor / base^base_len.
        abs_rec_div(dividend + base_len, divisor + base_len, quot, len1 - base_len, len2 - base_len, base);
        printf("%d\n", c);
        constexpr T ONE[1] = {1};
        quot_len = quot.set_true_len();
        std::cout << quot_len << "\t" << len1 << "\t" << len2 << "\n";
        hintvector<T> prod(base_len + quot_len);
        prod.change_length(base_len + quot_len);
        // Multiply the low base_len digits of the divisor by the trial quotient and compare
        // with the remainder; quot * (divisor % base^base_len) <= dividend must hold.
        abs_mul(divisor, quot.type_ptr(), prod.type_ptr(), base_len, quot_len, base);
        size_t prod_len = prod.set_true_len();
        len1 = ary_true_len(dividend, len1);
        prod.print();
        ary_print(dividend, len1);
        quot.print();
        while (abs_compare(prod.type_ptr(), dividend, prod_len, len1) > 0)
        {
            printf("kkk\n");
            abs_sub(quot.type_ptr(), ONE, quot.type_ptr(), quot_len, 1, base);
            abs_sub(prod.type_ptr(), divisor, prod.type_ptr(), prod_len, base_len, base);
            abs_add(dividend + base_len, divisor + base_len, dividend + base_len, len1 - base_len, len2 - base_len, base);
            quot_len = quot.set_true_len();
            prod_len = prod.set_true_len();
            len1 = ary_true_len(dividend, std::max(len1, len2) + 1);
            ary_print(dividend, len1);
            prod.print();
            quot.print();
        }
        abs_sub(dividend, prod.type_ptr(), dividend, len1, prod_len, base);
        quot.print();
        // ary_print(dividend, len1);
    }
    quot.set_true_len();
    printf("%d\n", c);
    quot.print();
}
/// @brief Magnitude division: computes quotient and remainder of dividend / divisor.
///
/// Divisor and dividend are both scaled by the multiplier from divisor_normalize,
/// divided with abs_rec_div, and the remainder is scaled back afterwards.
/// @param dividend digits of the dividend
/// @param divisor digits of the divisor
/// @param quot receives the quotient (the parameter had been mangled into a stray quote
///             character; it is restored as the reference `&quot`)
/// @param len1 digit count of `dividend`
/// @param len2 digit count of `divisor`
/// @param base radix of the digits
/// @return the remainder
template <typename T>
hintvector<T> abs_div(const T dividend[], T divisor[], hintvector<T> &quot,
                      size_t len1, size_t len2, const UINT_64 base)
{
    hintvector<T> normalized_divisor(len2); // normalized divisor
    normalized_divisor.change_length(len2);
    hintvector<T> normalized_dividend(len1 + 1); // dividend scaled by the same multiplier
    normalized_dividend.change_length(len1 + 1);
    T *divisor_ptr = normalized_divisor.type_ptr();
    T *dividend_ptr = normalized_dividend.type_ptr();
    T multiplier = divisor_normalize(divisor, divisor_ptr, len2, base); // normalize the divisor, get the multiplier
    abs_mul_num(dividend, multiplier, dividend_ptr, len1, base);        // scale the dividend
    normalized_dividend.set_true_len();
    // std::cout << multiplier << "\n";
    // normalized_dividend.print();
    // normalized_divisor.print();
    abs_rec_div(dividend_ptr, divisor_ptr, quot, normalized_dividend.length(), len2, base);
    // quot.resize(len1 - len2 + 1);
    // quot.change_length(len1 - len2 + 1);
    // abs_long_div(dividend_ptr, divisor_ptr, quot.type_ptr(), normalized_dividend.length(), len2, base);
    // Undo the scaling: the remainder of the scaled division is multiplier times the real one.
    abs_div_num(dividend_ptr, multiplier, dividend_ptr, normalized_dividend.length(), base);
    normalized_dividend.set_true_len();
    return normalized_dividend;
}
/// @brief Multi-digit radix conversion, performed in place.
/// @tparam T digit type
/// @param data_ary digits of the number; must have enough spare room for the converted result
/// @param in_len digit count on entry; set to the converted digit count on the main path
/// @param BASE1 input radix
/// @param BASE2 output radix
template <typename T>
void base_conversion(T data_ary[], size_t &in_len,
                     const UINT_64 BASE1 = 1 << 16, const UINT_64 BASE2 = 1e4)
{
    if (in_len == 0 || BASE1 == BASE2)
    {
        return;
    }
    if (in_len < 2)
    {
        // Single input digit: split it directly into BASE2 digits.
        // NOTE(review): in_len is not updated on this path - confirm callers expect that.
        UINT_64 tmp = data_ary[0];
        size_t pos = 0;
        while (tmp > 0)
        {
            std::tie(tmp, data_ary[pos]) = div_mod(tmp, BASE2);
            pos++;
        }
        return;
    }
    const size_t max_rank = min_2pow(in_len) / 2; // highest power of BASE1 kept in unit_ary
    const UINT_64 base1to2_len = static_cast<UINT_64>(std::ceil(std::log2(BASE1) / std::log2(BASE2))); // BASE2 digits per BASE1 digit
    size_t result_len = static_cast<size_t>(max_rank * base1to2_len * 2); // length of the result
    ary_clr(data_ary + in_len, result_len - in_len);                      // clear the spare room
    // Input radix larger than output radix: spread every input digit over base1to2_len slots first.
    if (BASE1 > BASE2)
    {
        size_t pos = in_len;
        while (pos > 0)
        {
            pos--;
            UINT_64 tmp = data_ary[pos];
            size_t i = 0, trans_pos = pos * base1to2_len;
            while (tmp > 0)
            {
                std::tie(tmp, data_ary[trans_pos + i]) = div_mod(tmp, BASE2);
                i++;
            }
        }
        // If BASE1 is an exact power of BASE2 the digit-wise split is already the answer.
        // NOTE(review): in_len is not updated on this early return - confirm callers expect that.
        UINT_64 tmp = BASE2;
        while (tmp < BASE1)
        {
            tmp *= BASE2;
            if (tmp == BASE1)
            {
                return;
            }
        }
    }
    size_t unit_ary_len = max_rank * base1to2_len; // length of unit_ary
    T *unit_ary = new T[unit_ary_len];             // holds BASE1^1, BASE1^2, BASE1^4, ... written in BASE2
    ary_clr(unit_ary, unit_ary_len);
    UINT_64 tmp = BASE1;
    size_t i = 0;
    // Store BASE1 in BASE2 digits, least significant first. The index has to advance
    // with every digit; without that every digit overwrote unit_ary[0].
    while (tmp > 0 && i < unit_ary_len)
    {
        std::tie(tmp, unit_ary[i]) = div_mod(tmp, BASE2);
        i++;
    }
    T *tmp_product = new T[max_rank * base1to2_len * 2];
    for (size_t rank = 1; rank <= max_rank; rank *= 2)
    {
        size_t gap = rank * 2;
        // Combine neighbouring blocks: low + high * BASE1^rank.
        for (size_t i = 0; i < result_len; i += gap)
        {
            T *work_ary = data_ary + i;
            abs_mul(work_ary + rank, unit_ary, tmp_product, rank, rank, BASE2);
            abs_add<false>(work_ary, tmp_product, work_ary, rank, gap, BASE2);
        }
        if (rank < max_rank)
        {
            abs_sqr(unit_ary, unit_ary, rank, BASE2);
        }
    }
    result_len = ary_true_len(data_ary, result_len);
    in_len = result_len;
    delete[] unit_ary;
    delete[] tmp_product;
}
}
// 简单高精度简单实现
class Integer
{
private:
using DataType = hint::UINT_16;
using SizeType = hint::UINT_32;
using DataVec = hint::HintVector<DataType, SizeType>;
DataVec data;
public:
static constexpr hint::UINT_32 DIGIT = 4;
static constexpr hint::UINT_64 BASE = hint::qpow(10, DIGIT);
Integer()
{
data = DataVec();
}
// Integer 拷贝构造
Integer(const Integer &input)
{
if (this != &input)
{
data = input.data;
}
}
// Move constructor: takes over the digit vector of `input`.
Integer(Integer &&input) noexcept
{
    if (this == &input)
    {
        return;
    }
    data = std::move(input.data);
}
// string 参数构造
Integer(const std::string &input)
{
string_in(input);
}
// Construct from a mutable C string via string_in.
Integer(char input[])
{
    const std::string text(input);
    string_in(text);
}
// Construct from a constant C string via string_in.
Integer(const char input[])
{
    const std::string text(input);
    string_in(text);
}
// 通用构造
template <typename T>
Integer(T input)
{
bool is_neg = hint::is_neg(input);
hint::UINT_64 tmp = std::abs<hint::INT_64>(input);
size_t digits = std::ceil(std::log10(tmp + 1));
size_t len = (digits + DIGIT - 1) / DIGIT;
data = DataVec(len);
data.change_length(len);
for (size_t i = 0; i < len; i++)
{
data[i] = tmp % BASE;
tmp /= BASE;
}
data.change_sign(is_neg);
data.set_true_len();
}
// Copy assignment: duplicates the digit vector of `input` unless assigning to itself.
Integer &operator=(const Integer &input)
{
    if (this == &input)
    {
        return *this;
    }
    data = input.data;
    return *this;
}
// Move assignment: takes over the digit vector of `input` unless assigning to itself.
Integer &operator=(Integer &&input) noexcept
{
    if (this == &input)
    {
        return *this;
    }
    data = std::move(input.data);
    return *this;
}
// Assign from a std::string via string_in.
Integer &operator=(const std::string &input)
{
    this->string_in(input);
    return *this;
}
// Assign from a constant C string via string_in.
Integer &operator=(const char input[])
{
    const std::string text(input);
    string_in(text);
    return *this;
}
// Assign from a mutable C string via string_in.
Integer &operator=(char input[])
{
    const std::string text(input);
    string_in(text);
    return *this;
}
// Generic assignment from a numeric value.
// The limb count is obtained by integer division. The previous log10-on-double estimate
// rounded 10^16 + 1 down to 10^16 and dropped the top limb, so assigning 10^16 stored 0.
template <typename T>
Integer &operator=(T input)
{
    const bool is_neg = hint::is_neg(input);
    // Magnitude in unsigned arithmetic, so the most negative 64-bit value and unsigned
    // values above INT64_MAX are handled as well.
    hint::UINT_64 tmp = is_neg
                            ? hint::UINT_64(0) - static_cast<hint::UINT_64>(static_cast<hint::INT_64>(input))
                            : static_cast<hint::UINT_64>(input);
    size_t len = 0;
    for (hint::UINT_64 rest = tmp; rest > 0; rest /= BASE)
    {
        len++;
    }
    data = DataVec(len);
    data.change_length(len);
    for (size_t i = 0; i < len; i++)
    {
        std::tie(tmp, data[i]) = hint::div_mod(tmp, BASE);
    }
    data.change_sign(is_neg);
    data.set_true_len();
    return *this;
}
// Most significant limb, or 0 when the number has no limbs.
DataType first_num() const
{
    const SizeType len = length();
    if (len != 0)
    {
        return data[len - 1];
    }
    return 0;
}
// Value of the two most significant limbs combined (top * BASE + next);
// falls back to first_num() when fewer than two limbs exist.
hint::UINT_64 first2_num() const
{
    if (length() >= 2)
    {
        return data[length() - 1] * BASE + data[length() - 2];
    }
    return first_num();
}
// 更改符号
void change_sign(bool is_neg)
{
data.change_sign(is_neg);
}
// True when the sign flag of the digit vector is set.
bool is_neg() const
{
    const bool negative = data.sign();
    return negative;
}
// Number of limbs currently in use.
SizeType length() const
{
    const SizeType limb_count = data.length();
    return limb_count;
}
// Number of decimal digits; a number without limbs counts as one digit.
SizeType length_base10() const
{
    const size_t len = data.length();
    if (len != 0)
    {
        // Every limb below the top one contributes DIGIT digits; the top limb
        // contributes its own decimal width.
        return (len - 1) * DIGIT + std::ceil(std::log10(first_num() + 1));
    }
    return 1;
}
void string_in(const std::string &str)
{
size_t str_len = str.size();
if (str_len == 0)
{
data = DataVec();
return;
}
hint::INT_64 len = str_len, pos = len, i = 0;
bool is_neg = false;
if (str[0] == '-')
{
is_neg = true;
len--;
}
len = (len + DIGIT - 1) / DIGIT;
data = DataVec(len);
data.change_length(len);
while (pos - DIGIT > 0)
{
hint::UINT_64 tmp = hint::stoui64(str.begin() + pos - DIGIT, str.begin() + pos);
data[i] = static_cast<DataType>(tmp);
i++;
pos -= DIGIT;
}
hint::INT_64 begin = is_neg ? 1 : 0;
if (pos > begin)
{
hint::UINT_64 tmp = hint::stoui64(str.begin() + begin, str.begin() + pos);
data[i] = static_cast<DataType>(tmp);
}
change_sign(is_neg);
data.set_true_len();
}
/// @brief Renders the value as text.
/// @return "0" when there are no elements; otherwise an optional '-', the leading
///         element via std::to_string, then every remaining element formatted by
///         hint::ui64to_string(element, DIGIT) from most to least significant.
std::string to_string() const
{
    const size_t count = length();
    if (count == 0)
    {
        return "0";
    }
    std::string text;
    if (is_neg())
    {
        text.push_back('-');
    }
    text += std::to_string(first_num());
    // Walk the remaining elements downwards; idx is one past the element to emit.
    for (size_t idx = count - 1; idx > 0; idx--)
    {
        text += hint::ui64to_string(data[idx - 1], DIGIT);
    }
    return text;
}
/// @brief Writes the value to standard output, one group per stored element.
///
/// Output is an optional '-', the leading element, then each remaining element
/// preceded by a space and zero-padded to four characters, followed by a newline.
/// A value without elements prints a single '0' and no newline.
void print() const
{
    const size_t count = length();
    if (count < 1)
    {
        putchar('0');
        return;
    }
    if (is_neg())
    {
        putchar('-');
    }
    printf("%d", first_num());
    // idx is one past the element to print, moving towards element 0.
    for (size_t idx = count - 1; idx > 0; idx--)
    {
        printf(" %04d", data[idx - 1]);
    }
    putchar('\n');
}
/// @brief Compares the magnitudes of this value and `input`, ignoring signs.
/// @param input Value to compare against.
/// @return Result of hint_arithm::abs_compare() on the two element arrays; callers
///         in this class treat >0 as "this is larger", <0 as "input is larger"
///         and 0 as equal.
hint::INT_32 abs_compare(const Integer &input) const
{
    size_t own_len = length();
    size_t other_len = input.length();
    auto own_ptr = data.type_ptr();
    auto other_ptr = input.data.type_ptr();
    return hint_arithm::abs_compare(own_ptr, other_ptr, own_len, other_len);
}
/// @brief Adds the magnitudes of this value and `input`.
///
/// The result buffer is sized to the longer operand plus one element for a
/// possible carry; the sign of the result is left as default-constructed.
/// @param input Second operand.
/// @return New Integer filled by hint_arithm::abs_add<true>().
Integer abs_add(const Integer &input) const
{
    size_t own_len = length();
    size_t other_len = input.length();
    const size_t out_len = std::max(own_len, other_len) + 1;
    Integer sum;
    sum.data = DataVec(out_len);
    sum.data.change_length(out_len);
    auto own_ptr = data.type_ptr();
    auto other_ptr = input.data.type_ptr();
    auto out_ptr = sum.data.type_ptr();
    hint_arithm::abs_add<true>(own_ptr, other_ptr, out_ptr, own_len, other_len, BASE);
    sum.data.set_true_len();
    return sum;
}
/// @brief Subtracts the magnitude of `input` from the magnitude of this value.
///
/// The result buffer is sized to the longer operand; the sign of the result is
/// left as default-constructed. Callers in this class only invoke it with the
/// larger magnitude on the left.
/// @param input Value whose magnitude is subtracted.
/// @return New Integer filled by hint_arithm::abs_sub().
Integer abs_sub(const Integer &input) const
{
    size_t own_len = length();
    size_t other_len = input.length();
    const size_t out_len = std::max(own_len, other_len);
    Integer diff;
    diff.data = DataVec(out_len);
    diff.data.change_length(out_len);
    auto own_ptr = data.type_ptr();
    auto other_ptr = input.data.type_ptr();
    auto out_ptr = diff.data.type_ptr();
    hint_arithm::abs_sub(own_ptr, other_ptr, out_ptr, own_len, other_len, BASE);
    diff.data.set_true_len();
    return diff;
}
/// @brief Strict "greater than" comparison.
///
/// Operands with different sign flags are ordered by sign alone. With equal
/// signs the magnitudes decide: for non-negative operands the larger magnitude
/// is greater, for negative operands the smaller magnitude is greater. Equal
/// magnitudes are never "greater" (previously two equal negative values
/// compared as greater than each other).
/// @param input Right-hand operand.
/// @return true when this value is strictly greater than `input`.
bool operator>(const Integer &input) const
{
    if (is_neg() != input.is_neg())
    {
        return !is_neg();
    }
    const hint::INT_32 cmp = abs_compare(input);
    return is_neg() ? (cmp < 0) : (cmp > 0);
}
/// @brief Strict "less than" comparison.
///
/// Operands with different sign flags are ordered by sign alone. With equal
/// signs the magnitudes decide: for non-negative operands the smaller magnitude
/// is less, for negative operands the larger magnitude is less. Equal
/// magnitudes are never "less" (previously two equal negative values compared
/// as less than each other).
/// @param input Right-hand operand.
/// @return true when this value is strictly less than `input`.
bool operator<(const Integer &input) const
{
    if (is_neg() != input.is_neg())
    {
        return is_neg();
    }
    const hint::INT_32 cmp = abs_compare(input);
    return is_neg() ? (cmp > 0) : (cmp < 0);
}
/// @brief "Greater than or equal", defined as the negation of operator<.
bool operator>=(const Integer &input) const
{
    const bool strictly_less = (*this < input);
    return !strictly_less;
}
/// @brief "Less than or equal", defined as the negation of operator>.
bool operator<=(const Integer &input) const
{
    const bool strictly_greater = (*this > input);
    return !strictly_greater;
}
/// @brief Equality: both sign flags match and the magnitudes compare equal.
/// @param input Right-hand operand.
bool operator==(const Integer &input) const
{
    const bool same_sign = (is_neg() == input.is_neg());
    // Magnitudes are only compared when the signs already agree.
    return same_sign && (abs_compare(input) == 0);
}
/// @brief Inequality, defined as the negation of operator==.
bool operator!=(const Integer &input) const
{
    const bool equal = (*this == input);
    return !equal;
}
/// @brief Signed addition built from the magnitude helpers.
///
/// Same sign: magnitudes are added and the common sign is kept.
/// Different signs: the smaller magnitude is subtracted from the larger and the
/// result takes the sign of the operand with the larger magnitude; equal
/// magnitudes give a default-constructed Integer.
/// @param input Right-hand operand.
/// @return The sum as a new Integer.
Integer operator+(const Integer &input) const
{
    if (is_neg() == input.is_neg())
    {
        Integer sum = abs_add(input);
        sum.change_sign(is_neg());
        return sum;
    }
    Integer diff;
    const hint::INT_32 cmp = abs_compare(input);
    if (cmp < 0)
    {
        // |input| dominates, so the result carries input's sign (opposite of ours).
        diff = input.abs_sub(*this);
        diff.change_sign(!is_neg());
    }
    else if (cmp > 0)
    {
        diff = abs_sub(input);
        diff.change_sign(is_neg());
    }
    return diff;
}
/// @brief Signed subtraction built from the magnitude helpers.
///
/// Subtracting an object from itself returns a default-constructed Integer.
/// Different signs: magnitudes are added and this operand's sign is kept.
/// Same sign: the smaller magnitude is subtracted from the larger; the result
/// keeps this operand's sign when its magnitude is larger and the opposite sign
/// otherwise. Equal magnitudes give a default-constructed Integer.
/// @param input Right-hand operand.
/// @return The difference as a new Integer.
Integer operator-(const Integer &input) const
{
    if (this == &input)
    {
        return Integer();
    }
    if (is_neg() != input.is_neg())
    {
        Integer sum = abs_add(input);
        sum.change_sign(is_neg());
        return sum;
    }
    Integer diff;
    const hint::INT_32 cmp = abs_compare(input);
    if (cmp < 0)
    {
        diff = input.abs_sub(*this);
        diff.change_sign(!is_neg());
    }
    else if (cmp > 0)
    {
        diff = abs_sub(input);
        diff.change_sign(is_neg());
    }
    return diff;
}
/// @brief Signed multiplication.
///
/// If either operand has no elements a default-constructed Integer is returned.
/// Otherwise a cleared buffer of length() + input.length() elements is filled by
/// hint_arithm::abs_sqr() when both magnitudes are equal, or by
/// hint_arithm::abs_mul() in the general case. The result is negative exactly
/// when the operand signs differ.
/// @param input Right-hand operand.
/// @return The product as a new Integer.
Integer operator*(const Integer &input) const
{
    size_t own_len = length();
    size_t other_len = input.length();
    Integer product;
    if (own_len == 0 || other_len == 0)
    {
        return product;
    }
    const size_t out_len = own_len + other_len;
    product.data = DataVec(out_len);
    product.data.change_length(out_len);
    product.data.clear();
    auto own_ptr = data.type_ptr();
    auto other_ptr = input.data.type_ptr();
    auto out_ptr = product.data.type_ptr();
    const bool same_magnitude = (abs_compare(input) == 0);
    if (same_magnitude)
    {
        hint_arithm::abs_sqr(own_ptr, out_ptr, own_len, BASE);
    }
    else
    {
        hint_arithm::abs_mul(own_ptr, other_ptr, out_ptr, own_len, other_len, BASE);
    }
    product.data.set_true_len();
    product.change_sign(is_neg() != input.is_neg());
    return product;
}
/// @brief Signed division; the quotient magnitude comes from hint_arithm::abs_div().
///
/// A dividend with fewer elements than the divisor yields a default-constructed
/// Integer. The quotient is negative exactly when the operand signs differ.
/// @param input Divisor.
/// @return The quotient as a new Integer.
/// @throws const char* ("Can't divide by zero\n") when the divisor has no elements.
Integer operator/(const Integer &input) const
{
    size_t own_len = length();
    size_t other_len = input.length();
    // An empty divisor can never be longer than the dividend, so testing it first
    // does not change which inputs reach the early return below.
    if (other_len == 0)
    {
        throw("Can't divide by zero\n");
    }
    if (own_len < other_len)
    {
        return Integer();
    }
    Integer quotient;
    auto own_ptr = data.type_ptr();
    auto other_ptr = input.data.type_ptr();
    hint_arithm::abs_div(own_ptr, other_ptr, quotient.data, own_len, other_len, BASE);
    quotient.data.set_true_len();
    quotient.change_sign(is_neg() != input.is_neg());
    return quotient;
}
};
/// @brief Approximates pi as a scaled big integer.
///
/// Starts from 2 * hint::qpow(Integer(10), n + 5) and sums the series whose k-th
/// term is the previous term times k / (2k + 1), stopping once a term becomes
/// zero. The five extra powers of ten added to `n` are not removed from the result.
/// NOTE(review): hint::qpow is presumably integer exponentiation - confirm.
/// @param n Requested scale; 5 is added internally.
/// @return The accumulated sum (approximately pi * 10^(n + 5)).
Integer pi_generator(hint::UINT_32 n)
{
    n += 5;
    Integer total = hint::qpow(Integer(10), n) * Integer(2);
    Integer term = total / 3;
    total = total + term;
    for (hint::UINT_32 k = 2; term.length() > 0; k++)
    {
        term = term * k;
        term = term / (k * 2 + 1);
        total = total + term;
    }
    return total;
}
/// @brief Checks that `dividend / divisor` behaves like a floor-style quotient.
///
/// Computes quo = dividend / divisor and prod = quo * divisor, then verifies
/// prod <= dividend < prod + divisor.
/// @tparam T Type providing /, *, +, >= and <.
/// @param dividend Value being divided.
/// @param divisor Value to divide by.
/// @return true when the quotient satisfies the bound above.
template <typename T>
bool div_test(const T &dividend, const T &divisor)
{
    T quo = dividend / divisor;
    T prod = quo * divisor;
    return (dividend >= prod) && (dividend < (prod + divisor));
}
#endif
#include <iostream>
using namespace std;
int main() {
string s;
cin>>s;
Integer a = s;
cin>>s;
Integer b= s;
cout<<(a*b).to_string();
return 0;
}
// Compilation | N/A | N/A | Compile Error | Score: N/A | Show more |