提交记录 28450


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1004. 【模板题】高精度乘法 Accepted 100 36.925 ms 81572 KB C++14 69.13 KB
提交时间 评测时间
2025-09-10 22:09:52 2025-09-10 22:09:56
// TSKY 2025/8/23
#include <vector>
#include <array>
#include <complex>
#include <iostream>
#include <chrono>
#include <string>
#include <bitset>
#include <type_traits>
#include <cstdint>
#include <climits>
#include <cfloat>
#include <cmath>
#include <ctime>
#include <cstring>
#include <cassert>

#include <immintrin.h>

#define __FMA__
#define __AVX2__

#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP

#pragma GCC target("avx2")
#pragma GCC target("fma")

namespace hint_simd
{
    /// Fixed-capacity array of LEN elements whose storage starts on a 4096-byte boundary.
    /// The constructor does not initialise the elements (they are indeterminate for trivial T).
    template <typename T, size_t LEN>
    class AlignAry
    {
    public:
        constexpr AlignAry() {}

        // Unchecked element access.
        constexpr T &operator[](size_t index) { return elems[index]; }
        constexpr const T &operator[](size_t index) const { return elems[index]; }

        // Pointer to the first element.
        T *data() { return elems; }
        const T *data() const { return elems; }

        // The same storage viewed as a pointer to Ty; the caller is responsible for the reinterpretation being valid.
        template <typename Ty>
        Ty *cast_ptr() { return reinterpret_cast<Ty *>(elems); }
        template <typename Ty>
        const Ty *cast_ptr() const { return reinterpret_cast<const Ty *>(elems); }

    private:
        alignas(4096) T elems[LEN];
    };
    // Converts split real/imag vectors into interleaved pairs:
    //   row0 = (a0,a1,a2,a3), row1 = (b0,b1,b2,b3)
    //   ->  row0 = (a0,b0,a1,b1), row1 = (a2,b2,a3,b3).
    template <typename YMM>
    inline void transpose64_2X4(YMM &row0, YMM &row1)
    {
        const __m256d a = __m256d(row0);
        const __m256d b = __m256d(row1);
        const __m256d evens = _mm256_unpacklo_pd(a, b); // a0,b0,a2,b2
        const __m256d odds = _mm256_unpackhi_pd(a, b);  // a1,b1,a3,b3

        row0 = YMM(_mm256_permute2f128_pd(evens, odds, 0x20)); // a0,b0,a1,b1
        row1 = YMM(_mm256_permute2f128_pd(evens, odds, 0x31)); // a2,b2,a3,b3
    }
    // Inverse of transpose64_2X4: converts interleaved pairs back to split vectors:
    //   row0 = (p0,p1,p2,p3), row1 = (q0,q1,q2,q3)
    //   ->  row0 = (p0,p2,q0,q2), row1 = (p1,p3,q1,q3),
    // i.e. RIRI (r0,i0,r1,i1 | r2,i2,r3,i3) -> RRRR (r0..r3) and IIII (i0..i3).
    template <typename YMM>
    inline void transpose64_4X2(YMM &row0, YMM &row1)
    {
        auto t0 = _mm256_permute2f128_pd(__m256d(row0), __m256d(row1), 0x20); // 0,1,2,3 4,5,6,7 -> 0,1,4,5
        auto t1 = _mm256_permute2f128_pd(__m256d(row0), __m256d(row1), 0x31); // 0,1,2,3 4,5,6,7 -> 2,3,6,7
        row0 = YMM(_mm256_unpacklo_pd(t0, t1));                               // 0,1,4,5 2,3,6,7 -> 0,2,4,6
        row1 = YMM(_mm256_unpackhi_pd(t0, t1));                               // 0,1,4,5 2,3,6,7 -> 1,3,5,7
    }

    // Transposes an 8x4 matrix of doubles (input register k holds row k, so the
    // registers row0..row7 contain elements 0..31 in row-major order) into a 4x8
    // matrix: afterwards column c of the input (8 values) occupies the register
    // pair (row{2c}, row{2c+1}).
    template <typename YMM>
    inline void transpose64_8X4(YMM &row0, YMM &row1, YMM &row2, YMM &row3,
                                YMM &row4, YMM &row5, YMM &row6, YMM &row7)
    {
        auto t0 = _mm256_unpacklo_pd(__m256d(row0), __m256d(row1)); // 0,1,2,3 4,5,6,7 -> 0,4,2,6
        auto t1 = _mm256_unpackhi_pd(__m256d(row0), __m256d(row1)); // 0,1,2,3 4,5,6,7 -> 1,5,3,7
        auto t2 = _mm256_unpacklo_pd(__m256d(row2), __m256d(row3)); // 8,9,10,11 12,13,14,15 -> 8,12,10,14
        auto t3 = _mm256_unpackhi_pd(__m256d(row2), __m256d(row3)); // 8,9,10,11 12,13,14,15 -> 9,13,11,15
        auto t4 = _mm256_unpacklo_pd(__m256d(row4), __m256d(row5)); // 16,17,18,19 20,21,22,23 -> 16,20,18,22
        auto t5 = _mm256_unpackhi_pd(__m256d(row4), __m256d(row5)); // 16,17,18,19 20,21,22,23 -> 17,21,19,23
        auto t6 = _mm256_unpacklo_pd(__m256d(row6), __m256d(row7)); // 24,25,26,27 28,29,30,31 -> 24,28,26,30
        auto t7 = _mm256_unpackhi_pd(__m256d(row6), __m256d(row7)); // 24,25,26,27 28,29,30,31 -> 25,29,27,31

        // Convert through YMM(...) like the other transpose helpers, so this also works
        // for YMM types that are constructible from __m256d but not implicitly assignable from it.
        row0 = YMM(_mm256_permute2f128_pd(t0, t2, 0x20)); // 0,4,8,12
        row1 = YMM(_mm256_permute2f128_pd(t4, t6, 0x20)); // 16,20,24,28
        row2 = YMM(_mm256_permute2f128_pd(t1, t3, 0x20)); // 1,5,9,13
        row3 = YMM(_mm256_permute2f128_pd(t5, t7, 0x20)); // 17,21,25,29
        row4 = YMM(_mm256_permute2f128_pd(t0, t2, 0x31)); // 2,6,10,14
        row5 = YMM(_mm256_permute2f128_pd(t4, t6, 0x31)); // 18,22,26,30
        row6 = YMM(_mm256_permute2f128_pd(t1, t3, 0x31)); // 3,7,11,15
        row7 = YMM(_mm256_permute2f128_pd(t5, t7, 0x31)); // 19,23,27,31
    }

    // Transposes a 4x8 matrix of doubles into an 8x4 one. Input row k (8 values)
    // occupies the register pair (row{2k}, row{2k+1}); afterwards register row_j
    // holds column j of the input.
    template <typename YMM>
    inline void transpose64_4X8(YMM &row0, YMM &row1, YMM &row2, YMM &row3,
                                YMM &row4, YMM &row5, YMM &row6, YMM &row7)
    {
        // Left halves of the four input rows (registers 0, 2, 4, 6).
        const auto lo02 = _mm256_unpacklo_pd(__m256d(row0), __m256d(row2)); // 0,8,2,10
        const auto hi02 = _mm256_unpackhi_pd(__m256d(row0), __m256d(row2)); // 1,9,3,11
        const auto lo46 = _mm256_unpacklo_pd(__m256d(row4), __m256d(row6)); // 16,24,18,26
        const auto hi46 = _mm256_unpackhi_pd(__m256d(row4), __m256d(row6)); // 17,25,19,27
        // Right halves of the four input rows (registers 1, 3, 5, 7).
        const auto lo13 = _mm256_unpacklo_pd(__m256d(row1), __m256d(row3)); // 4,12,6,14
        const auto hi13 = _mm256_unpackhi_pd(__m256d(row1), __m256d(row3)); // 5,13,7,15
        const auto lo57 = _mm256_unpacklo_pd(__m256d(row5), __m256d(row7)); // 20,28,22,30
        const auto hi57 = _mm256_unpackhi_pd(__m256d(row5), __m256d(row7)); // 21,29,23,31

        row0 = YMM(_mm256_permute2f128_pd(lo02, lo46, 0x20)); // 0,8,16,24
        row1 = YMM(_mm256_permute2f128_pd(hi02, hi46, 0x20)); // 1,9,17,25
        row2 = YMM(_mm256_permute2f128_pd(lo02, lo46, 0x31)); // 2,10,18,26
        row3 = YMM(_mm256_permute2f128_pd(hi02, hi46, 0x31)); // 3,11,19,27
        row4 = YMM(_mm256_permute2f128_pd(lo13, lo57, 0x20)); // 4,12,20,28
        row5 = YMM(_mm256_permute2f128_pd(hi13, hi57, 0x20)); // 5,13,21,29
        row6 = YMM(_mm256_permute2f128_pd(lo13, lo57, 0x31)); // 6,14,22,30
        row7 = YMM(_mm256_permute2f128_pd(hi13, hi57, 0x31)); // 7,15,23,31
    }

    // In-register transpose of a 4x4 matrix of doubles: on return row_j holds
    // the former column j.
    template <typename YMM>
    inline void transpose64_4X4(YMM &row0, YMM &row1, YMM &row2, YMM &row3)
    {
        const auto lo01 = _mm256_unpacklo_pd(__m256d(row0), __m256d(row1)); // 0,4,2,6
        const auto hi01 = _mm256_unpackhi_pd(__m256d(row0), __m256d(row1)); // 1,5,3,7
        const auto lo23 = _mm256_unpacklo_pd(__m256d(row2), __m256d(row3)); // 8,12,10,14
        const auto hi23 = _mm256_unpackhi_pd(__m256d(row2), __m256d(row3)); // 9,13,11,15

        row0 = YMM(_mm256_permute2f128_pd(lo01, lo23, 0x20)); // 0,4,8,12
        row1 = YMM(_mm256_permute2f128_pd(hi01, hi23, 0x20)); // 1,5,9,13
        row2 = YMM(_mm256_permute2f128_pd(lo01, lo23, 0x31)); // 2,6,10,14
        row3 = YMM(_mm256_permute2f128_pd(hi01, hi23, 0x31)); // 3,7,11,15
    }

    /// Value wrapper around an AVX `__m256d` holding four doubles (lanes 0..3).
    /// Arithmetic maps one-to-one onto AVX intrinsics. The `const F64 *` constructor,
    /// `load`, `fromMem` and `store` use aligned intrinsics, so their pointers must be
    /// 32-byte aligned; the `u`-suffixed variants accept any alignment.
    class Float64X4
    {
    public:
        using F64 = double;
        using F64X4 = Float64X4;
        // All lanes zero.
        Float64X4() : data(_mm256_setzero_pd()) {}
        // Wraps an existing register value.
        Float64X4(__m256d in_data) : data(in_data) {}
        // Broadcasts in_data into all four lanes.
        Float64X4(F64 in_data) : data(_mm256_set1_pd(in_data)) {}
        // Aligned load of in_data[0..3].
        Float64X4(const F64 *in_data) : data(_mm256_load_pd(in_data)) {}

        // Lane-wise arithmetic.
        F64X4 operator+(const F64X4 &other) const
        {
            return _mm256_add_pd(data, other.data);
        }
        F64X4 operator-(const F64X4 &other) const
        {
            return _mm256_sub_pd(data, other.data);
        }
        F64X4 operator*(const F64X4 &other) const
        {
            return _mm256_mul_pd(data, other.data);
        }
        F64X4 operator/(const F64X4 &other) const
        {
            return _mm256_div_pd(data, other.data);
        }

        F64X4 &operator+=(const F64X4 &other)
        {
            return *this = *this + other;
        }
        F64X4 &operator-=(const F64X4 &other)
        {
            return *this = *this - other;
        }
        F64X4 &operator*=(const F64X4 &other)
        {
            return *this = *this * other;
        }
        F64X4 &operator/=(const F64X4 &other)
        {
            return *this = *this / other;
        }
        // Lane-wise floor.
        F64X4 floor() const
        {
            return _mm256_floor_pd(data);
        }
        // a * b + c (single rounding when FMA is available, otherwise mul then add)
        static F64X4 fmadd(const F64X4 &a, const F64X4 &b, const F64X4 &c)
        {
#ifdef __FMA__
            return _mm256_fmadd_pd(a.data, b.data, c.data);
#else
#pragma message("No FMA support")
            return a * b + c;
#endif
        }
        // a * b - c (single rounding when FMA is available, otherwise mul then sub)
        static F64X4 fmsub(const F64X4 &a, const F64X4 &b, const F64X4 &c)
        {
#ifdef __FMA__
            return _mm256_fmsub_pd(a.data, b.data, c.data);
#else
#pragma message("No FMA support")
            return a * b - c;
#endif
        }
        // Full-width lane shuffle: result lane j = source lane ((N >> 2j) & 3).
#ifdef __AVX2__
        template <int N>
        F64X4 permute4x64() const
        {
            return _mm256_permute4x64_pd(data, N);
        }
#else
        template <int N>
        F64X4 permute4x64() const
        {
            alignas(32) uint64_t arr[4];
            alignas(32) uint64_t dst[4];
            this->store(reinterpret_cast<F64 *>(arr));
            dst[0] = arr[(N >> 0) & 3];
            dst[1] = arr[(N >> 2) & 3];
            dst[2] = arr[(N >> 4) & 3];
            dst[3] = arr[(N >> 6) & 3];
            return fromMem(reinterpret_cast<const F64 *>(dst));
        }
#endif
        // Even-indexed lanes of the 8-element sequence (in0, in1):
        // returns (in0[0], in0[2], in1[0], in1[2]).
        static F64X4 extractEven64X4(const F64X4 &in0, const F64X4 &in1)
        {
            F64X4 result = _mm256_unpacklo_pd(in0.data, in1.data); // 0,1,2,3 4,5,6,7 -> 0,4,2,6
            return result.permute4x64<0b11011000>();               // 0,4,2,6 -> 0,2,4,6
        }

        // In-128-bit-lane shuffle controlled by the immediate N (_mm256_permute_pd).
        template <int N>
        F64X4 permute() const
        {
            return _mm256_permute_pd(data, N);
        }
        // Lanes in reverse order: (x3, x2, x1, x0).
        F64X4 reverse() const
        {
            return permute4x64<0b00011011>();
        }
        // Aligned load of p[0..3].
        void load(const F64 *p)
        {
            data = _mm256_load_pd(p);
        }
        // Unaligned load of p[0..3].
        void loadu(const F64 *p)
        {
            data = _mm256_loadu_pd(p);
        }
        // Broadcasts *p into all lanes.
        void load1(const F64 *p)
        {
            data = _mm256_broadcast_sd(p);
        }
        static F64X4 fromMem(const F64 *p)
        {
            return _mm256_load_pd(p);
        }
        static F64X4 fromUMem(const F64 *p)
        {
            return _mm256_loadu_pd(p);
        }
        // Aligned store to p[0..3].
        void store(F64 *p) const
        {
            _mm256_store_pd(p, data);
        }
        // Unaligned store to p[0..3].
        void storeu(F64 *p) const
        {
            _mm256_storeu_pd(p, data);
        }
        operator __m256d() const
        {
            return data;
        }
        // Converts each lane to int64 by truncation toward zero. Only meaningful for
        // 0 <= x < 2^53. The AVX2 bit-manipulation path returns 0 for negative lanes,
        // lanes >= 2^53, infinities and NaN (the shift count goes out of range).
#ifdef __AVX2__
        __m256i toI64X4() const
        {
            constexpr uint64_t mask = (uint64_t(1) << 52) - 1;
            constexpr uint64_t offset = (uint64_t(1) << 10) - 1;
            const __m256i f64bits = _mm256_castpd_si256(data);
            // Mantissa with the implicit leading 1 restored.
            __m256i tail = _mm256_and_si256(f64bits, _mm256_set1_epi64x(mask));
            tail = _mm256_or_si256(tail, _mm256_set1_epi64x(mask + 1));
            // Shift right by (bias + 52 - exponent) to drop the fractional bits.
            __m256i exp = _mm256_srli_epi64(f64bits, 52);
            exp = _mm256_sub_epi64(_mm256_set1_epi64x(offset + 52), exp);
            return _mm256_srlv_epi64(tail, exp);
        }
#else
#pragma message("No AVX2 support")
        __m256i toI64X4() const
        {
            alignas(32) F64 arr[4];
            alignas(32) int64_t i64_arr[4];
            this->store(arr);
            i64_arr[0] = arr[0];
            i64_arr[1] = arr[1];
            i64_arr[2] = arr[2];
            i64_arr[3] = arr[3];
            return _mm256_load_si256(reinterpret_cast<const __m256i *>(i64_arr));
        }
#endif
        // Returns lane N (0..3).
        template <int N>
        F64 nthEle() const
        {
            static_assert(N >= 0 && N < 4, "Float64X4::nthEle: lane index out of range");
            // Extract the raw 64-bit lane, then reinterpret the bits with memcpy, which is
            // well-defined (reading an inactive union member is UB in ISO C++).
            const int64_t bits = _mm256_extract_epi64(_mm256_castpd_si256(data), N);
            F64 value;
            std::memcpy(&value, &bits, sizeof(value));
            return value;
        }

        // Writes "[x0,x1,x2,x3]" and std::endl to std::cout.
        void print() const
        {
            std::cout << "[" << nthEle<0>() << "," << nthEle<1>()
                      << "," << nthEle<2>() << "," << nthEle<3>() << "]" << std::endl;
        }

    private:
        __m256d data;
    };

    /// Four complex doubles in split form: `real` holds the four real parts and
    /// `imag` the four imaginary parts. Memory layout used by load/store is
    /// "RRRRIIII" (8 doubles: real parts at p[0..3], imaginary parts at p[4..7]).
    struct Complex64X4
    {
        using C64X4 = Complex64X4;
        using F64X4 = Float64X4;
        using F64 = double;
        Complex64X4() {}
        Complex64X4(F64X4 real, F64X4 imag) : real(real), imag(imag) {}
        // Aligned load of the RRRRIIII block starting at p.
        Complex64X4(const F64 *p) : real(p), imag(p + 4) {}
        // Aligned loads of the real and imaginary parts from separate arrays.
        Complex64X4(const F64 *p_real, const F64 *p_imag) : real(p_real), imag(p_imag) {}
        C64X4 operator+(const C64X4 &other) const
        {
            return C64X4(real + other.real, imag + other.imag);
        }
        C64X4 operator-(const C64X4 &other) const
        {
            return C64X4(real - other.real, imag - other.imag);
        }
        // Lane-wise complex product: this * other.
        C64X4 mul(const C64X4 &other) const
        {
            const F64X4 ii = imag * other.imag;
            const F64X4 ri = real * other.imag;
            const F64X4 r = F64X4::fmsub(real, other.real, ii);
            const F64X4 i = F64X4::fmadd(imag, other.real, ri);
            return C64X4(r, i);
        }
        // Lane-wise product with the conjugate: this * conj(other).
        C64X4 mulConj(const C64X4 &other) const
        {
            const F64X4 ii = imag * other.imag;
            const F64X4 ri = real * other.imag;
            const F64X4 r = F64X4::fmadd(real, other.real, ii);
            const F64X4 i = F64X4::fmsub(imag, other.real, ri);
            return C64X4(r, i);
        }
        // Lane k (k = 0..3) holds exp(i * (begin + k * theta)).
        static C64X4 omegaSeq0To3(F64 theta, F64 begin = 0)
        {
            // F64X4(const F64 *) performs an aligned 32-byte load (_mm256_load_pd), so the
            // staging arrays must be 32-byte aligned; a plain double array only guarantees 8.
            alignas(32) F64 real_arr[4] = {std::cos(begin), std::cos(theta + begin), std::cos(2 * theta + begin), std::cos(3 * theta + begin)};
            alignas(32) F64 imag_arr[4] = {std::sin(begin), std::sin(theta + begin), std::sin(2 * theta + begin), std::sin(3 * theta + begin)};
            return C64X4(F64X4(real_arr), F64X4(imag_arr));
        }

        // Tag-dispatched load: std::false_type = memory is already RRRRIIII.
        template <typename T>
        void load(const T *p, std::false_type)
        {
            this->load(p);
        }
        // Tag-dispatched load: std::true_type = memory is interleaved RIRI; convert after loading.
        template <typename T>
        void load(const T *p, std::true_type)
        {
            this->load(p);
            *this = this->toRRIIPermu();
        }
        // Aligned load of 8 doubles: real parts from p[0..3], imaginary parts from p[4..7].
        template <typename T>
        void load(const T *p)
        {
            real.load(reinterpret_cast<const F64 *>(p));
            imag.load(reinterpret_cast<const F64 *>(p) + 4);
        }
        // Unaligned variant of load(p).
        template <typename T>
        void loadu(const T *p)
        {
            real.loadu(reinterpret_cast<const F64 *>(p));
            imag.loadu(reinterpret_cast<const F64 *>(p) + 4);
        }
        // Broadcasts *real_p and *imag_p into every lane.
        void load1(const F64 *real_p, const F64 *imag_p)
        {
            real.load1(real_p);
            imag.load1(imag_p);
        }

        // Tag-dispatched store: std::false_type = write RRRRIIII as-is.
        template <typename T>
        void store(T *p, std::false_type) const
        {
            this->store(p);
        }
        // Tag-dispatched store: std::true_type = write interleaved RIRI.
        template <typename T>
        void store(T *p, std::true_type) const
        {
            this->toRIRIPermu().store(p);
        }
        // Aligned store of 8 doubles: real parts to p[0..3], imaginary parts to p[4..7].
        template <typename T>
        void store(T *p) const
        {
            real.store(reinterpret_cast<F64 *>(p));
            imag.store(reinterpret_cast<F64 *>(p) + 4);
        }
        // Unaligned variant of store(p).
        template <typename T>
        void storeu(T *p) const
        {
            real.storeu(reinterpret_cast<F64 *>(p));
            imag.storeu(reinterpret_cast<F64 *>(p) + 4);
        }
        // Lane-wise z^2 = (r^2 - i^2) + (2ri)i.
        C64X4 square() const
        {
            const F64X4 ii = imag * imag;
            const F64X4 ri = real * imag;
            const F64X4 r = F64X4::fmsub(real, real, ii);
            const F64X4 i = ri + ri;
            return C64X4(r, i);
        }
        // Lane-wise z^3 = r(r^2 - 3i^2) + i(3r^2 - i^2)i.
        C64X4 cube() const
        {
            const F64X4 rr = real * real;
            const F64X4 ii = imag * imag;
            const F64X4 rr3 = rr + rr + rr;
            const F64X4 ii3 = ii + ii + ii;
            const F64X4 r = real * (rr - ii3);
            const F64X4 i = imag * (rr3 - ii);
            return C64X4(r, i);
        }
        // Returns the same values with real/imag interleaved: (r0,i0,r1,i1), (r2,i2,r3,i3).
        C64X4 toRIRIPermu() const
        {
            C64X4 res = *this;
            transpose64_2X4(res.real, res.imag);
            return res;
        }
        // Inverse of toRIRIPermu: interleaved RIRI back to split RRRR / IIII.
        C64X4 toRRIIPermu() const
        {
            C64X4 res = *this;
            transpose64_4X2(res.real, res.imag);
            return res;
        }
        // Writes "[(r0, i0), (r1, i1), (r2, i2), (r3, i3)]" and std::endl to std::cout.
        void print() const
        {
            alignas(32) F64 real_arr[4]{}, imag_arr[4]{};
            real.storeu(real_arr);
            imag.storeu(imag_arr);
            std::cout << "[(" << real_arr[0] << ", " << imag_arr[0] << "), ("
                      << real_arr[1] << ", " << imag_arr[1] << "), ("
                      << real_arr[2] << ", " << imag_arr[2] << "), ("
                      << real_arr[3] << ", " << imag_arr[3] << ")]" << std::endl;
        }

        // Tag-dispatched no-op: values are returned unchanged.
        C64X4 transToI64(std::false_type) const
        {
            return *this;
        }
        // Rounds each lane to the nearest integer by adding 0.5 and truncating
        // (F64X4::toI64X4, so lanes are assumed non-negative and below 2^53). The int64
        // bit patterns are returned reinterpreted as doubles in `real` / `imag`.
        C64X4 transToI64(std::true_type) const
        {
            constexpr int64_t F1_2 = 4602678819172646912; // bit pattern of 0.5 (0x3FE0000000000000)
            auto F1_2X4 = F64X4(__m256d(_mm256_set1_epi64x(F1_2)));
            auto real_i64 = (real + F1_2X4).toI64X4();
            auto imag_i64 = (imag + F1_2X4).toI64X4();
            return C64X4(__m256d(real_i64), __m256d(imag_i64));
        }

        F64X4 real, imag;
    };
}
#endif

namespace hint
{
    // Scalar and complex floating-point aliases used across the library.
    using Float32 = float;
    using Float64 = double;
    using Complex32 = std::complex<Float32>;
    using Complex64 = std::complex<Float64>;

    // pi and 2*pi in double precision.
    constexpr Float64 HINT_PI = 3.141592653589793238462643;
    constexpr Float64 HINT_2PI = 2 * HINT_PI;
    // Holds cos(pi/4) = sqrt(2)/2 (the real part of exp(2*pi*i/8)), not cos(pi/8).
    constexpr Float64 COS_PI_8 = 0.707106781186547524400844;

    // Largest FFT length constant: 2^23.
    constexpr size_t FFT_MAX_LEN = size_t(1) << 23;

    // Largest power of two not exceeding n, for n >= 1. The highest set bit is smeared
    // into every lower position, then only that bit is kept. Note: n == 0 yields 1.
    template <typename T>
    constexpr T int_floor2(T n)
    {
        constexpr int width = sizeof(n) * 8;
        int shift = 1;
        while (shift < width)
        {
            n |= (n >> shift);
            shift *= 2;
        }
        // n is now 0b0..011..1; drop one bit and add one to isolate the top bit.
        return (n >> 1) + 1;
    }

    // Smallest power of two that is >= n (n == 0 yields 0 through unsigned wrap-around).
    template <typename T>
    constexpr T int_ceil2(T n)
    {
        constexpr int width = sizeof(n) * 8;
        // Decrement first so that exact powers of two map to themselves.
        --n;
        for (int shift = 1; shift < width; shift <<= 1)
        {
            n |= n >> shift;
        }
        return n + 1;
    }

    // True when n is a (positive) power of two: exactly one bit set.
    template <typename IntTy>
    constexpr bool is_2pow(IntTy n)
    {
        if (n == 0)
        {
            return false;
        }
        // Clearing the lowest set bit leaves zero only if it was the only one.
        return (n & (n - 1)) == 0;
    }

    // Integer base-2 logarithm: floor(log2(n)) for n >= 1, and -1 for n <= 0.
    template <typename T>
    constexpr int hint_log2(T n)
    {
        // Non-positive inputs have no logarithm; keep the established -1 result.
        if (n < T(1))
        {
            return -1;
        }
        // Search in the unsigned counterpart: for signed T, T(1) << (bits - 1) is the
        // (negative) minimum value, which compared as "<= n" and made e.g.
        // hint_log2(1 << 30) return 31 instead of 30.
        using U = typename std::make_unsigned<T>::type;
        const U u = static_cast<U>(n);
        constexpr int bits = sizeof(U) * 8;
        // Binary search for the largest l with 2^l <= u (invariant: 2^l <= u < 2^r).
        int l = -1, r = bits;
        while ((l + 1) != r)
        {
            const int mid = (l + r) / 2;
            if ((U(1) << mid) > u)
            {
                r = mid;
            }
            else
            {
                l = mid;
            }
        }
        return l;
    }
    // Count of trailing zero bits in x; returns 32 for x == 0.
    constexpr int hint_ctz(uint32_t x)
    {
        // Isolate the lowest set bit (two's-complement x & -x), then read its position
        // one bit at a time: each mask covers the positions whose corresponding index bit is 0.
        const uint32_t low = x & (~x + 1u);
        int pos = (low & 0x55555555u) ? 0 : 1;
        pos |= (low & 0x33333333u) ? 0 : 2;
        pos |= (low & 0x0F0F0F0Fu) ? 0 : 4;
        pos |= (low & 0x00FF00FFu) ? 0 : 8;
        pos |= (low & 0x0000FFFFu) ? 0 : 16;
        return low == 0 ? 32 : pos;
    }

    // Count of trailing zero bits in x; returns 64 for x == 0.
    constexpr int hint_ctz(uint64_t x)
    {
        // Isolate the lowest set bit, then assemble its index bit by bit from mask hits.
        const uint64_t low = x & (~x + 1);
        int pos = (low & 0x5555555555555555ull) ? 0 : 1;
        pos |= (low & 0x3333333333333333ull) ? 0 : 2;
        pos |= (low & 0x0F0F0F0F0F0F0F0Full) ? 0 : 4;
        pos |= (low & 0x00FF00FF00FF00FFull) ? 0 : 8;
        pos |= (low & 0x0000FFFF0000FFFFull) ? 0 : 16;
        pos |= (low & 0x00000000FFFFFFFFull) ? 0 : 32;
        return low == 0 ? 64 : pos;
    }

    // Number of set bits in n (SWAR: sum adjacent fields of doubling width).
    constexpr int hint_popcnt(uint32_t n)
    {
        uint32_t acc = n;
        acc = (acc & 0x55555555u) + ((acc >> 1) & 0x55555555u); // 2-bit field sums
        acc = (acc & 0x33333333u) + ((acc >> 2) & 0x33333333u); // 4-bit field sums
        acc = (acc & 0x0F0F0F0Fu) + ((acc >> 4) & 0x0F0F0F0Fu); // per-byte sums
        acc = (acc & 0x00FF00FFu) + ((acc >> 8) & 0x00FF00FFu); // per-16-bit sums
        return static_cast<int>((acc & 0xFFFFu) + (acc >> 16));
    }
    // Number of set bits in n (SWAR: sum adjacent fields of doubling width).
    constexpr int hint_popcnt(uint64_t n)
    {
        uint64_t acc = n;
        acc = (acc & 0x5555555555555555ull) + ((acc >> 1) & 0x5555555555555555ull);   // 2-bit sums
        acc = (acc & 0x3333333333333333ull) + ((acc >> 2) & 0x3333333333333333ull);   // 4-bit sums
        acc = (acc & 0x0F0F0F0F0F0F0F0Full) + ((acc >> 4) & 0x0F0F0F0F0F0F0F0Full);   // byte sums
        acc = (acc & 0x00FF00FF00FF00FFull) + ((acc >> 8) & 0x00FF00FF00FF00FFull);   // 16-bit sums
        acc = (acc & 0x0000FFFF0000FFFFull) + ((acc >> 16) & 0x0000FFFF0000FFFFull);  // 32-bit sums
        return static_cast<int>((acc & 0xFFFFFFFFull) + (acc >> 32));
    }

    // Lifts the compile-time constant N of type T into a type; the value is ::value.
    template <typename T, T N>
    struct StaticObject
    {
        using Type = T;
        static constexpr T value = N;
    };

    // Compile-time size_t constant.
    template <size_t N>
    using StaticSize = StaticObject<size_t, N>;

    // Compile-time int constant.
    template <int N>
    using StaticInt = StaticObject<int, N>;

    // Namespace for FFT and FFT-like transforms
    namespace transform
    {
        using namespace hint_simd;

        // In-place radix-2 butterfly: (sum, diff) <- (sum + diff, sum - diff).
        // Both inputs are copied first, so passing the same object twice is safe.
        template <typename T>
        inline void transform2(T &sum, T &diff)
        {
            T lhs = sum;
            T rhs = diff;
            sum = lhs + rhs;
            diff = lhs - rhs;
        }

        // Out-of-place radix-2 butterfly: sum = a + b, diff = a - b.
        // a and b are taken by value on purpose so that callers may pass the output
        // objects as inputs (e.g. transform2(x, y, y, x)).
        template <typename T>
        inline void transform2(const T a, const T b, T &sum, T &diff)
        {
            sum = a + b;  // written first: if sum and diff alias, diff wins
            diff = a - b;
        }

        // Returns the point on the unit circle with argument theta, i.e. exp(i * theta),
        // as std::complex<FloatTy>.
        template <typename FloatTy>
        inline auto unit_root(FloatTy theta)
        {
            const FloatTy radius = static_cast<FloatTy>(1.0);
            return std::polar(radius, theta);
        }

        // Bit-reversal permutation of [begin, end): the element at offset k is exchanged
        // with the element at offset rev(k), where rev reverses the low log2(len) bits of k.
        // Assumes len = end - begin is a power of two (the bit-reversed increment below
        // relies on it) and random-access iterators (uses <, +, - and indexing).
        template <typename It>
        void binary_reverse_swap(It begin, It end)
        {
            const size_t len = end - begin;
            // Nothing to permute. This also avoids forming begin + 1 / end - 1 outside
            // the range in the short-length loop when the range is empty.
            if (len < 2)
            {
                return;
            }
            // Swap only when the left iterator precedes the right one, so each pair is exchanged exactly once.
            auto smaller_swap = [=](It it_left, It it_right)
            {
                if (it_left < it_right)
                {
                    std::swap(it_left[0], it_right[0]);
                }
            };
            // Given the iterator at position rev(i), return the iterator at position rev(i + 1)
            // (the increment is carried from the top bit downwards).
            auto get_next_bitrev = [=](It last)
            {
                size_t k = len / 2, indx = last - begin;
                indx ^= k;
                while (k > indx)
                {
                    k >>= 1;
                    indx ^= k;
                }
                return begin + indx;
            };
            // Short ranges: walk i = 1 .. len - 2 while tracking j = rev(i).
            if (len <= 16)
            {
                for (auto i = begin + 1, j = begin + len / 2; i < end - 1; i++)
                {
                    smaller_swap(i, j);
                    j = get_next_bitrev(j);
                }
                return;
            }
            // Longer ranges: for each k in [1, len / 8) handle the eight offsets
            // k, k + len/2, k + len/4, k + 3len/4, k + len/8, k + 5len/8, k + 3len/8, k + 7len/8,
            // whose bit-reversed partners are the consecutive offsets rev(k) + 0 .. rev(k) + 7.
            // Pairs involving k = 0 are handled from their partner's side.
            const size_t len_8 = len / 8;
            const auto last = begin + len_8;
            auto i0 = begin + 1, i1 = i0 + len / 2, i2 = i0 + len / 4, i3 = i1 + len / 4;
            for (auto j = begin + len / 2; i0 < last; i0++, i1++, i2++, i3++)
            {
                smaller_swap(i0, j);
                smaller_swap(i1, j + 1);
                smaller_swap(i2, j + 2);
                smaller_swap(i3, j + 3);
                smaller_swap(i0 + len_8, j + 4);
                smaller_swap(i1 + len_8, j + 5);
                smaller_swap(i2 + len_8, j + 6);
                smaller_swap(i3 + len_8, j + 7);
                j = get_next_bitrev(j);
            }
        }

        // Bit-reversal permutation of the len elements starting at ary
        // (forwards to the iterator-range overload).
        template <typename T>
        void binary_reverse_swap(T ary, const size_t len)
        {
            const T last = ary + len;
            binary_reverse_swap(ary, last);
        }
        namespace fft
        {
            // Short aliases for the scalar, complex and SIMD types used by the FFT code in this namespace.
            using F64 = Float64;
            using C64 = std::complex<F64>;
            using F64X4 = Float64X4;   // four packed doubles
            using C64X4 = Complex64X4; // four complex doubles, split real/imag
            /// Fixed table of OMEGA_LEN twiddle factors
            ///   w_k = exp(-2*pi*i * factor * k / theta_divider),  k = 0 .. OMEGA_LEN-1,
            /// stored in blocks of `stride` real parts followed by the matching `stride`
            /// imaginary parts (stride = 4 gives the RRRRIIII layout Complex64X4::load reads).
            template <typename Float, size_t OMEGA_LEN>
            class TableFix
            {
                alignas(64) std::array<Float, OMEGA_LEN * 2> table;

            public:
                TableFix(size_t theta_divider, size_t factor, size_t stride)
                {
                    // Guard the divisors: theta_divider == 0 would divide by zero, and
                    // stride == 0 would make the modulo undefined and the fill loop never advance.
                    assert(theta_divider != 0);
                    assert(stride != 0 && OMEGA_LEN % stride == 0);
                    const Float theta = -HINT_2PI * factor / theta_divider;
                    // `index` counts twiddles continuously across the stride blocks.
                    for (size_t begin = 0, index = 0; begin < OMEGA_LEN * 2; begin += stride * 2)
                    {
                        for (size_t j = 0; j < stride; j++, index++)
                        {
                            table[begin + j] = std::cos(theta * index);
                            table[begin + j + stride] = std::sin(theta * index);
                        }
                    }
                }
                // Raw element access (real/imag blocks interleaved as described above).
                constexpr const Float &operator[](size_t index) const
                {
                    return table[index];
                }
                constexpr const Float *getOmegaIt(size_t index) const
                {
                    return &table[index];
                }
            };
            template <typename Float, int LOG_BEGIN, int LOG_END, int DIV>
            class TableFixMulti
            {
                static_assert(LOG_END >= LOG_BEGIN);
                static_assert(is_2pow(DIV));
                static constexpr size_t TABLE_CPX_LEN = (size_t(1) << (LOG_END + 1)) / DIV;
                alignas(64) std::array<Float, TABLE_CPX_LEN * 2> table;

            public:
                TableFixMulti(size_t factor, size_t stride = 4)
                {
                    assert(((size_t(1) << LOG_BEGIN) / DIV) % stride == 0);
                    initBottomUp(factor, stride);
                }
                void initTopDown(size_t factor, size_t stride)
                {
                    static_assert(std::is_same<Float, Float64>::value);
                    assert(stride == 4);
                }
                void initBottomUp(size_t factor, size_t stride)
                {
                    static_assert(std::is_same<Float, Float64>::value);
                    assert(stride == 4);
                    size_t len = size_t(1) << LOG_BEGIN, cpx_len = len / DIV;
                    Float theta = -HINT_2PI * factor / len;
                    auto it = getBeginLog(LOG_BEGIN);
                    for (size_t i = 0; i < cpx_len; i++)
                    {
                        it[0] = std::cos(theta * i), it[stride] = std::sin(theta * i);
                        it += (i % stride == stride - 1 ? stride + 1 : 1);
                    }
                    it = getBeginLog(LOG_BEGIN);
                    for (int log_len = LOG_BEGIN + 1; log_len <= LOG_END; log_len++)
                    {
                        len = size_t(1) << log_len, cpx_len = len / DIV;
                        theta = -HINT_2PI * factor / len;
                        auto it = getBeginLog(log_len), it_last = getBeginLog(log_len - 1);
                        C64X4 unit(std::cos(theta), std::sin(theta));
                        for (auto end = it + cpx_len * 2; it < end; it += 16, it_last += 8)
                        {
                            Complex64X4 omega0, omega1;
                            omega0.load(it_last);
                            omega1 = omega0.mul(unit);
                            transpose64_2X4(omega0.real, omega1.real);
                            transpose64_2X4(omega0.imag, omega1.imag);
                            omega0.store(it), omega1.store(it + 8);
                        }
                    }
                }
                constexpr const Float *getBeginLog(int log_rank) const
                {
                    return getBegin(size_t(1) << log_rank);
                }
                constexpr Float *getBeginLog(int log_rank)
                {
                    return getBegin(size_t(1) << log_rank);
                }
                constexpr const Float *getBegin(size_t rank) const
                {
                    return &table[rank * 2 / DIV];
                }
                constexpr Float *getBegin(size_t rank)
                {
                    return &table[rank * 2 / DIV];
                }
            };
            struct FFTFixed
            {
                // log2 of MAX_LEN (2^23, the same value as FFT_MAX_LEN).
                static constexpr size_t LOG_MAX = 23;
                // log2 of SHORT_LEN; also the top level of multi_table_2 / multi_table_3.
                // NOTE(review): presumably the cutoff for a "short" transform path; confirm at use sites.
                static constexpr size_t LOG_SHORT = 10;
                static constexpr size_t SHORT_LEN = size_t(1) << LOG_SHORT;
                static constexpr size_t MAX_LEN = size_t(1) << LOG_MAX;
                // Fixed twiddle tables of 4 or 8 complex values. They are only declared here;
                // their definitions (theta_divider, factor, stride) are outside this section.
                // NOTE(review): names suggest length-8/16/32 roots with factors 1 and 3; confirm.
                static const TableFix<Float64, 4> table_8;
                static const TableFix<Float64, 4> table_16_1;
                static const TableFix<Float64, 4> table_16_3;
                static const TableFix<Float64, 8> table_32_1;
                static const TableFix<Float64, 8> table_32_3;
                // Multi-level tables for log-lengths 6..LOG_SHORT and 6..LOG_MAX; with DIV = 4 each
                // level stores a quarter of its length. NOTE(review): the _1/_2/_3 suffix presumably
                // names the `factor` constructor argument; confirm against the definitions.
                static const TableFixMulti<Float64, 6, LOG_SHORT, 4> multi_table_2;
                static const TableFixMulti<Float64, 6, LOG_SHORT, 4> multi_table_3;
                static const TableFixMulti<Float64, 6, LOG_MAX, 4> multi_table_1;

                // Pointers to the first element of each fixed table, passed to the aligned
                // Complex64X4 loads (e.g. omega.load(it8)).
                static constexpr const Float64 *it8 = &table_8[0];
                static constexpr const Float64 *it16_1 = &table_16_1[0];
                static constexpr const Float64 *it16_3 = &table_16_3[0];
                static constexpr const Float64 *it32_1 = &table_32_1[0];
                static constexpr const Float64 *it32_3 = &table_32_3[0];

                // Radix-4 decimation-in-frequency butterfly, in place, on four complex values
                // x_k = (r_k, i_k) stored as separate real/imaginary parts. With
                // a = x0 + x2, b = x1 + x3, c = x0 - x2, d = x1 - x3 the results are
                //   slot 0: a + b,  slot 1: a - b,  slot 2: c - i*d,  slot 3: c + i*d,
                // i.e. the 4-point DFT with kernel exp(-2*pi*i*jk/4) in bit-reversed order
                // (X0, X2, X1, X3). No scaling is applied.
                template <typename Float>
                static void dif4(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
                {
                    // Slots 0/2 become a / c and slots 1/3 become b / d.
                    transform2(r0, r2);
                    transform2(i0, i2);
                    transform2(r1, r3);
                    transform2(i1, i3);

                    // Slots 0/1: a + b and a - b.
                    transform2(r0, r1);
                    transform2(i0, i1);
                    // r2 = c.r + d.i (real of slot 2); i3 temporarily holds c.r - d.i (real of slot 3).
                    transform2(r2, i3);
                    // r3 temporarily holds c.i + d.r (imag of slot 3); i2 = c.i - d.r (imag of slot 2).
                    // The by-value overload makes the aliased arguments safe.
                    transform2(i2, r3, r3, i2);
                    // Move slot 3's temporaries into their proper real/imag places.
                    std::swap(i3, r3);
                }
                // Radix-4 decimation-in-time butterfly with kernel exp(+2*pi*i*jk/4), in place,
                // taking its inputs in the bit-reversed order that dif4 produces. With
                // p = y0 + y1, q = y0 - y1, s = y2 + y3, t = y2 - y3 the results are
                //   slot 0: p + s,  slot 1: q + i*t,  slot 2: p - s,  slot 3: q - i*t.
                // Unnormalized inverse of dif4: idit4(dif4(x)) == 4 * x.
                template <typename Float>
                static void idit4(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
                {
                    // Slots 0/1 become p / q and slots 2/3 become s / t.
                    transform2(r0, r1);
                    transform2(i0, i1);
                    transform2(r2, r3);
                    transform2(i2, i3);

                    // Slots 0/2: p + s and p - s.
                    transform2(r0, r2);
                    transform2(i0, i2);
                    // i3 temporarily holds q.r + t.i (real of slot 3); r1 = q.r - t.i (real of slot 1).
                    transform2(r1, i3, i3, r1);
                    // i1 = q.i + t.r (imag of slot 1); r3 temporarily holds q.i - t.r (imag of slot 3).
                    transform2(i1, r3);
                    // Move slot 3's temporaries into their proper real/imag places.
                    std::swap(i3, r3);
                }
                // Like dif4 but without the final a +/- b combination: with a = x0 + x2,
                // b = x1 + x3, c = x0 - x2, d = x1 - x3 the results are
                //   slot 0: a,  slot 1: b,  slot 2: c - i*d,  slot 3: c + i*d.
                // NOTE(review): the name suggests this is the split-radix step; callers are outside this view.
                template <typename Float>
                static void difSplit(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
                {
                    transform2(r0, r2);
                    transform2(i0, i2);
                    transform2(r1, r3);
                    transform2(i1, i3);

                    // Same rotation sequence as in dif4: slot 2 = c - i*d, slot 3 = c + i*d.
                    transform2(r2, i3);
                    transform2(i2, r3, r3, i2);
                    std::swap(i3, r3);
                }
                // Decimation-in-time counterpart of difSplit (kernel exp(+i...)): with
                // s = y2 + y3, t = y2 - y3 the results are
                //   slot 0: y0 + s,  slot 1: y1 + i*t,  slot 2: y0 - s,  slot 3: y1 - i*t.
                // No scaling is applied.
                template <typename Float>
                static void iditSplit(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
                {
                    // Slots 2/3 become s / t.
                    transform2(r2, r3);
                    transform2(i2, i3);

                    // Slots 0/2: y0 + s and y0 - s.
                    transform2(r0, r2);
                    transform2(i0, i2);
                    // Rotation for slots 1/3 (same pattern as idit4): y1 +/- i*t.
                    transform2(r1, i3, i3, r1);
                    transform2(i1, r3);
                    std::swap(i3, r3);
                }
                // Applies dif4 separately to the four complex values held in each (r_k, i_k)
                // pair: transpose so that matching elements share a register, run the
                // butterfly lane-wise, then transpose back.
                static void dif4x4(F64X4 &r0, F64X4 &i0, F64X4 &r1, F64X4 &i1, F64X4 &r2, F64X4 &i2, F64X4 &r3, F64X4 &i3)
                {
                    auto transpose_both = [&]()
                    {
                        transpose64_4X4(r0, r1, r2, r3);
                        transpose64_4X4(i0, i1, i2, i3);
                    };
                    transpose_both();
                    dif4(r0, i0, r1, i1, r2, i2, r3, i3);
                    transpose_both();
                }
                // Applies idit4 separately to the four complex values held in each (r_k, i_k)
                // pair, transposing in and out around the lane-wise butterfly.
                static void idit4x4(F64X4 &r0, F64X4 &i0, F64X4 &r1, F64X4 &i1, F64X4 &r2, F64X4 &i2, F64X4 &r3, F64X4 &i3)
                {
                    auto transpose_both = [&]()
                    {
                        transpose64_4X4(r0, r1, r2, r3);
                        transpose64_4X4(i0, i1, i2, i3);
                    };
                    transpose_both();
                    idit4(r0, i0, r1, i1, r2, i2, r3, i3);
                    transpose_both();
                }
                // Two independent 8-point DIF transforms: (c0, c1) hold elements 0..3 / 4..7 of
                // the first input and (c2, c3) those of the second. A radix-2 stage pairs element
                // k with k + 4 and rotates the difference by omega (lane-wise); each group of
                // four is then finished by dif4x4.
                static void dif8x2(Complex64X4 &c0, Complex64X4 &c1, Complex64X4 &c2, Complex64X4 &c3, const Complex64X4 &omega)
                {
                    auto radix2_then_twiddle = [&omega](Complex64X4 &lo, Complex64X4 &hi)
                    {
                        transform2(lo, hi);
                        hi = hi.mul(omega);
                    };
                    radix2_then_twiddle(c0, c1);
                    radix2_then_twiddle(c2, c3);
                    dif4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                }
                /// Inverse of dif8x2.
                /// Steps: a 4-point inverse inside every chunk (idit4x4), then c1 and c3 multiplied by
                /// conj(omega), then the radix-2 layer on (c0, c1) and (c2, c3).
                static void idit8x2(Complex64X4 &c0, Complex64X4 &c1, Complex64X4 &c2, Complex64X4 &c3, const Complex64X4 &omega)
                {
                    idit4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    c1 = c1.mulConj(omega);
                    c3 = c3.mulConj(omega);
                    transform2(c0, c1);
                    transform2(c2, c3);
                }
                /// In-memory form of dif8x2: four consecutive 8-double chunks (32 doubles) are
                /// loaded, transformed as two 8-point DIFs with the it8 twiddles, and stored back.
                static void dif8x2(Float64 in_out[])
                {
                    Complex64X4 chunk[4], twiddle;
                    for (size_t k = 0; k < 4; k++)
                    {
                        chunk[k].load(in_out + 8 * k);
                    }
                    twiddle.load(it8);
                    dif8x2(chunk[0], chunk[1], chunk[2], chunk[3], twiddle);
                    for (size_t k = 0; k < 4; k++)
                    {
                        chunk[k].store(in_out + 8 * k);
                    }
                }
                /// In-memory form of idit8x2: four consecutive 8-double chunks (32 doubles) are
                /// loaded, inverse-transformed with the it8 twiddles, and stored back.
                static void idit8x2(Float64 in_out[])
                {
                    Complex64X4 chunk[4], twiddle;
                    for (size_t k = 0; k < 4; k++)
                    {
                        chunk[k].load(in_out + 8 * k);
                    }
                    twiddle.load(it8);
                    idit8x2(chunk[0], chunk[1], chunk[2], chunk[3], twiddle);
                    for (size_t k = 0; k < 4; k++)
                    {
                        chunk[k].store(in_out + 8 * k);
                    }
                }
                /// 16-point DIF on 32 doubles (four Complex64X4 chunks at offsets 0, 8, 16, 24).
                /// Steps: dif4 across the four chunks; twiddles it8 / it16_1 / it16_3 on chunks
                /// 1 / 2 / 3; then dif4x4 to finish inside each chunk.
                /// NOTE(review): it8 / it16_1 / it16_3 are declared outside this view — presumably the
                /// data of table_8 / table_16_1 / table_16_3.
                static void dif16(Float64 in_out[])
                {
                    Complex64X4 c0, c1, c2, c3, omega;
                    c0.load(in_out), c1.load(in_out + 8), c2.load(in_out + 16), c3.load(in_out + 24);
                    dif4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    omega.load(it8), c1 = c1.mul(omega);
                    omega.load(it16_1), c2 = c2.mul(omega);
                    omega.load(it16_3), c3 = c3.mul(omega);
                    dif4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    c0.store(in_out), c1.store(in_out + 8), c2.store(in_out + 16), c3.store(in_out + 24);
                }
                /// Inverse of dif16 on 32 doubles (four chunks at offsets 0, 8, 16, 24).
                /// Steps: idit4x4 inside each chunk; conjugated twiddles it8 / it16_1 / it16_3 on
                /// chunks 1 / 2 / 3; then idit4 across the chunks.
                static void idit16(Float64 in_out[])
                {
                    Complex64X4 c0, c1, c2, c3, omega;
                    c0.load(in_out), c1.load(in_out + 8), c2.load(in_out + 16), c3.load(in_out + 24);
                    idit4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    omega.load(it8), c1 = c1.mulConj(omega);
                    omega.load(it16_1), c2 = c2.mulConj(omega);
                    omega.load(it16_3), c3 = c3.mulConj(omega);
                    idit4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    c0.store(in_out), c1.store(in_out + 8), c2.store(in_out + 16), c3.store(in_out + 24);
                }
                /// 32-point split-radix DIF on 64 doubles (eight 8-double chunks, numbered 0..7).
                /// Steps:
                ///  1. A split layer (difSplit) on chunks {0,2,4,6} and {1,3,5,7}, with twiddles
                ///     it32_1 / it32_3 on the last two quarters.
                ///  2. A pair of 8-point DIFs on chunks (4,5) and (6,7) via dif8x2.
                ///  3. A 16-point DIF on chunks 0..3 via dif16.
                /// NOTE(review): it32_1 / it32_3 / it8 are declared outside this view — presumably the
                /// data of table_32_1 / table_32_3 / table_8.
                static void dif32(Float64 in_out[])
                {
                    Complex64X4 c0, c1, c2, c3, omega;
                    // Split layer on the even chunks 0,2,4,6.
                    c0.load(in_out), c1.load(in_out + 16), c2.load(in_out + 32), c3.load(in_out + 48);
                    difSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    omega.load(it32_1), c2 = c2.mul(omega);
                    omega.load(it32_3), c3 = c3.mul(omega);
                    c0.store(in_out), c1.store(in_out + 16), c2.store(in_out + 32), c3.store(in_out + 48);

                    // Split layer on the odd chunks; chunks 5 and 7 stay in c2 / c3 for the next step.
                    c0.load(in_out + 8), c1.load(in_out + 24), c2.load(in_out + 40), c3.load(in_out + 56); // 1,3,5,7
                    difSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    omega.load(it32_1 + 8), c2 = c2.mul(omega);
                    omega.load(it32_3 + 8), c3 = c3.mul(omega);
                    c0.store(in_out + 8), c1.store(in_out + 24);
                    c0.load(in_out + 32), c1.load(in_out + 48), omega.load(it8); // 4,6
                    // 8-point transforms on chunks (4,5) = (c0, c2) and (6,7) = (c1, c3).
                    dif8x2(c0, c2, c1, c3, omega);
                    c0.store(in_out + 32), c2.store(in_out + 40), c1.store(in_out + 48), c3.store(in_out + 56);
                    // 16-point transform on the first half (chunks 0..3).
                    dif16(in_out);
                }
                /// Inverse of dif32 on 64 doubles (chunks 0..7), running the stages in reverse.
                /// Steps:
                ///  1. idit16 on chunks 0..3.
                ///  2. idit8x2 on chunks 4..7.
                ///  3. Inverse split layers (conjugated it32_1 / it32_3 twiddles, then iditSplit),
                ///     first on chunks {0,2,4,6}, then on {1,3,5,7}.
                static void idit32(Float64 in_out[])
                {
                    Complex64X4 c0, c1, c2, c3, omega;
                    idit16(in_out);
                    c0.load(in_out + 32), c1.load(in_out + 40), c2.load(in_out + 48), c3.load(in_out + 56), omega.load(it8); // 4,5,6,7
                    idit8x2(c0, c1, c2, c3, omega);
                    // Chunks 5 and 7 are written back; chunks 4 and 6 stay in c0 / c2.
                    c1.store(in_out + 40), c3.store(in_out + 56);

                    // Inverse split layer on chunks 0, 2 (c1, c3) and 4, 6 (c0, c2).
                    c1.load(in_out), c3.load(in_out + 16);
                    omega.load(it32_1), c0 = c0.mulConj(omega);
                    omega.load(it32_3), c2 = c2.mulConj(omega);
                    iditSplit(c1.real, c1.imag, c3.real, c3.imag, c0.real, c0.imag, c2.real, c2.imag);
                    c1.store(in_out), c3.store(in_out + 16), c0.store(in_out + 32), c2.store(in_out + 48);

                    // Inverse split layer on the odd chunks 1,3,5,7.
                    c0.load(in_out + 8), c1.load(in_out + 24), c2.load(in_out + 40), c3.load(in_out + 56);
                    omega.load(it32_1 + 8), c2 = c2.mulConj(omega);
                    omega.load(it32_3 + 8), c3 = c3.mulConj(omega);
                    iditSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    c0.store(in_out + 8), c1.store(in_out + 24), c2.store(in_out + 40), c3.store(in_out + 56);
                }
                /// Applies the 32-double dif16 kernel to every consecutive 32-double block of the
                /// buffer, walking it in 128-double groups (float_len is expected to be a multiple of 128).
                static void dif16(Float64 in_out[], size_t float_len)
                {
                    assert(float_len >= 128);
                    Float64 *const stop = in_out + float_len;
                    for (Float64 *group = in_out; group < stop; group += 128)
                    {
                        for (size_t offset = 0; offset < 128; offset += 32)
                        {
                            dif16(group + offset);
                        }
                    }
                }
                /// Applies the 32-double idit16 kernel to every consecutive 32-double block of the
                /// buffer, walking it in 128-double groups (float_len is expected to be a multiple of 128).
                static void idit16(Float64 in_out[], size_t float_len)
                {
                    assert(float_len >= 128);
                    Float64 *const stop = in_out + float_len;
                    for (Float64 *group = in_out; group < stop; group += 128)
                    {
                        for (size_t offset = 0; offset < 128; offset += 32)
                        {
                            idit16(group + offset);
                        }
                    }
                }
                /// Applies the 64-double dif32 kernel to every consecutive 64-double block of the
                /// buffer, walking it in 256-double groups (float_len is expected to be a multiple of 256).
                static void dif32(Float64 in_out[], size_t float_len)
                {
                    assert(float_len >= 256);
                    Float64 *const stop = in_out + float_len;
                    for (Float64 *group = in_out; group < stop; group += 256)
                    {
                        for (size_t offset = 0; offset < 256; offset += 64)
                        {
                            dif32(group + offset);
                        }
                    }
                }
                /// Applies the 64-double idit32 kernel to every consecutive 64-double block of the
                /// buffer, walking it in 256-double groups (float_len is expected to be a multiple of 256).
                static void idit32(Float64 in_out[], size_t float_len)
                {
                    assert(float_len >= 256);
                    Float64 *const stop = in_out + float_len;
                    for (Float64 *group = in_out; group < stop; group += 256)
                    {
                        for (size_t offset = 0; offset < 256; offset += 64)
                        {
                            idit32(group + offset);
                        }
                    }
                }
                /// Iterative DIF for short transforms: fft_len = float_len / 2 complex values with
                /// 128 <= fft_len <= SHORT_LEN (asserted).
                /// Radix-4 passes run from rank = fft_len down while rank >= 64 (rank /= 4). In each pass:
                ///  - every block of rank complex values (rank * 2 doubles) is split into four quarters
                ///    of stride1 doubles;
                ///  - dif4 combines the quarters;
                ///  - outputs 1 / 2 / 3 are multiplied by the multi_table_2 / multi_table_1 /
                ///    multi_table_3 entries for that rank.
                /// The last pass has rank 64 when log2(fft_len) is even and rank 128 when it is odd.
                /// It therefore leaves blocks of 16 or 32 complex values, finished by dif16 / dif32.
                static void difIter(Float64 in_out[], size_t float_len)
                {
                    size_t fft_len = float_len / 2;
                    assert(128 <= fft_len && fft_len <= SHORT_LEN);
                    for (size_t rank = fft_len; rank >= 64; rank /= 4)
                    {
                        const size_t stride1 = rank / 2, stride2 = stride1 * 2, stride3 = stride1 * 3;
                        for (auto begin = in_out, end = in_out + float_len; begin < end; begin += rank * 2)
                        {
                            // Twiddle pointers restart for every block of this rank.
                            auto table1 = multi_table_1.getBegin(rank), table2 = multi_table_2.getBegin(rank), table3 = multi_table_3.getBegin(rank);
                            auto it0 = begin, it1 = begin + stride1, it2 = begin + stride2, it3 = begin + stride3;
                            for (; it0 < begin + stride1; it0 += 8, it1 += 8, it2 += 8, it3 += 8, table1 += 8, table2 += 8, table3 += 8)
                            {
                                Complex64X4 c0, c1, c2, c3, omega;
                                c0.load(it0), c1.load(it1), c2.load(it2), c3.load(it3);
                                dif4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                                omega.load(table2), c1 = c1.mul(omega);
                                omega.load(table1), c2 = c2.mul(omega);
                                omega.load(table3), c3 = c3.mul(omega);
                                c0.store(it0), c1.store(it1), c2.store(it2), c3.store(it3);
                            }
                        }
                    }
                    // Finish the remaining 16- or 32-point blocks (see header comment).
                    if (hint_log2(fft_len) % 2 == 0)
                    {
                        dif16(in_out, float_len);
                    }
                    else
                    {
                        dif32(in_out, float_len);
                    }
                }
                /// Iterative inverse (DIT) for short transforms, the mirror image of difIter
                /// (128 <= fft_len <= SHORT_LEN, asserted).
                /// The first pass depends on log2(fft_len): idit16 (blocks of 16 complex) when even,
                /// idit32 (blocks of 32 complex) when odd.
                /// Radix-4 passes then run for rank = 64 or 128 up to fft_len (rank *= 4). Each pass
                /// multiplies quarters 1 / 2 / 3 by the conjugated multi_table_2 / multi_table_1 /
                /// multi_table_3 entries and combines the four quarters with idit4.
                static void iditIter(Float64 in_out[], size_t float_len)
                {
                    size_t fft_len = float_len / 2;
                    assert(128 <= fft_len && fft_len <= SHORT_LEN);
                    size_t rank = 0;
                    if (hint_log2(fft_len) % 2 == 0)
                    {
                        idit16(in_out, float_len);
                        rank = 64;
                    }
                    else
                    {
                        idit32(in_out, float_len);
                        rank = 128;
                    }
                    for (; rank <= fft_len; rank *= 4)
                    {
                        const size_t stride1 = rank / 2, stride2 = stride1 * 2, stride3 = stride1 * 3;
                        for (auto begin = in_out, end = in_out + float_len; begin < end; begin += rank * 2)
                        {
                            // Twiddle pointers restart for every block of this rank.
                            auto table1 = multi_table_1.getBegin(rank), table2 = multi_table_2.getBegin(rank), table3 = multi_table_3.getBegin(rank);
                            auto it0 = begin, it1 = begin + stride1, it2 = begin + stride2, it3 = begin + stride3;
                            for (; it0 < begin + stride1; it0 += 8, it1 += 8, it2 += 8, it3 += 8, table1 += 8, table2 += 8, table3 += 8)
                            {
                                Complex64X4 c0, c1, c2, c3, omega;
                                c0.load(it0), c1.load(it1), c2.load(it2), c3.load(it3);
                                omega.load(table2), c1 = c1.mulConj(omega);
                                omega.load(table1), c2 = c2.mulConj(omega);
                                omega.load(table3), c3 = c3.mulConj(omega);
                                idit4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                                c0.store(it0), c1.store(it1), c2.store(it2), c3.store(it3);
                            }
                        }
                    }
                }
                /// Recursive DIF driver over float_len doubles (fft_len = float_len / 2 complex values).
                /// - fft_len <= SHORT_LEN: handled by difIter.
                /// - Longer: one split-radix layer (difSplit, twiddles omega and omega^3 on the last two
                ///   quarters), then recursion on the half / quarter / quarter sub-buffers.
                /// FROM_RIRI_PERM selects the tagged Complex64X4::load used to read the input chunks;
                /// results are always written with the plain store.
                /// The flag is honoured for every length. At or below SHORT_LEN, a pre-pass converts
                /// each 8-double chunk in place with the same load/store pair the long path uses.
                template <bool FROM_RIRI_PERM = false>
                static void difRec(Float64 in_out[], size_t float_len)
                {
                    using FromRIRI = std::integral_constant<bool, FROM_RIRI_PERM>;
                    const size_t fft_len = float_len / 2;
                    assert(fft_len <= MAX_LEN);
                    if (fft_len <= SHORT_LEN)
                    {
                        if (FROM_RIRI_PERM)
                        {
                            // difIter reads the plain chunk layout, so convert every chunk in place first.
                            for (auto end = in_out + float_len, it = in_out; it < end; it += 8)
                            {
                                Complex64X4 chunk;
                                chunk.load(it, FromRIRI{});
                                chunk.store(it);
                            }
                        }
                        difIter(in_out, float_len);
                        return;
                    }
                    const size_t stride1 = float_len / 4, stride2 = stride1 * 2, stride3 = stride1 * 3;
                    auto table1 = multi_table_1.getBegin(fft_len);
                    for (auto end = in_out + stride1, it = in_out; it < end; it += 8, table1 += 8)
                    {
                        Complex64X4 c0, c1, c2, c3, omega;
                        c0.load(it, FromRIRI{}), c1.load(it + stride1, FromRIRI{}), c2.load(it + stride2, FromRIRI{}), c3.load(it + stride3, FromRIRI{});
                        difSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                        omega.load(table1), c2 = c2.mul(omega);
                        c3 = c3.mul(omega.cube());
                        c0.store(it), c1.store(it + stride1), c2.store(it + stride2), c3.store(it + stride3);
                    }
                    // Split-radix recursion: first half, then the two quarters.
                    difRec(in_out, stride2);
                    difRec(in_out + stride2, stride1);
                    difRec(in_out + stride3, stride1);
                }
                /// Recursive inverse (DIT) driver, the mirror image of difRec.
                /// It recurses on the half / quarter / quarter sub-buffers, then applies one split-radix
                /// layer with conjugated twiddles (omega, omega^3) and iditSplit.
                /// TO_INT64: every output chunk goes through Complex64X4::transToI64 before storing.
                /// TO_RIRI_PERM: every output chunk is written with the tagged store.
                /// Both flags are honoured for every length. At or below SHORT_LEN, a final pass
                /// re-reads each chunk written by iditIter and applies the same
                /// transToI64 / tagged store as the long path.
                template <bool TO_RIRI_PERM = false, bool TO_INT64 = false>
                static void iditRec(Float64 in_out[], size_t float_len)
                {
                    using ToRIRI = std::integral_constant<bool, TO_RIRI_PERM>;
                    using ToI64 = std::integral_constant<bool, TO_INT64>;
                    const size_t fft_len = float_len / 2;
                    assert(fft_len <= MAX_LEN);
                    if (fft_len <= SHORT_LEN)
                    {
                        iditIter(in_out, float_len);
                        if (TO_RIRI_PERM || TO_INT64)
                        {
                            // Same per-chunk output conversion as the long path below, applied afterwards.
                            for (auto end = in_out + float_len, it = in_out; it < end; it += 8)
                            {
                                Complex64X4 chunk;
                                chunk.load(it);
                                chunk = chunk.transToI64(ToI64{});
                                chunk.store(it, ToRIRI{});
                            }
                        }
                        return;
                    }
                    const size_t stride1 = float_len / 4, stride2 = stride1 * 2, stride3 = stride1 * 3;
                    // Split-radix recursion: first half, then the two quarters.
                    iditRec(in_out, stride2);
                    iditRec(in_out + stride2, stride1);
                    iditRec(in_out + stride3, stride1);
                    auto table1 = multi_table_1.getBegin(fft_len);
                    for (auto end = in_out + stride1, it = in_out; it < end; it += 8, table1 += 8)
                    {
                        Complex64X4 c0, c1, c2, c3, omega;
                        c0.load(it), c1.load(it + stride1), c2.load(it + stride2), c3.load(it + stride3);
                        omega.load(table1), c2 = c2.mulConj(omega);
                        c3 = c3.mulConj(omega.cube());
                        iditSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                        c0 = c0.transToI64(ToI64{}), c1 = c1.transToI64(ToI64{}), c2 = c2.transToI64(ToI64{}), c3 = c3.transToI64(ToI64{});
                        c0.store(it, ToRIRI{}), c1.store(it + stride1, ToRIRI{}), c2.store(it + stride2, ToRIRI{}), c3.store(it + stride3, ToRIRI{});
                    }
                }
            };
            // Out-of-class definitions of FFTFixed's static constexpr members.
            // In C++14 they are required whenever a member is odr-used (e.g. bound to a reference).
            // From C++17 on they are redundant, because static constexpr members are implicitly inline.
            constexpr size_t FFTFixed::LOG_MAX;
            constexpr size_t FFTFixed::LOG_SHORT;
            constexpr size_t FFTFixed::SHORT_LEN;
            constexpr size_t FFTFixed::MAX_LEN;

            // Twiddle tables shared by all FFTFixed kernels. The constructor arguments are passed to
            // TableFix / TableFixMulti, which are defined outside this view.
            // NOTE(review): the argument meaning (transform length, twiddle exponent 1 or 3, lane
            // width 4) is inferred from the names here and should be confirmed against those classes.
            const TableFix<Float64, 4> FFTFixed::table_8(8, 1, 4);
            const TableFix<Float64, 4> FFTFixed::table_16_1(16, 1, 4);
            const TableFix<Float64, 4> FFTFixed::table_16_3(16, 3, 4);
            const TableFix<Float64, 8> FFTFixed::table_32_1(32, 1, 4);
            const TableFix<Float64, 8> FFTFixed::table_32_3(32, 3, 4);
            // Per-rank tables: 2 and 3 are only needed up to LOG_SHORT (difIter / iditIter);
            // 1 is also used by difRec / iditRec up to LOG_MAX.
            const TableFixMulti<Float64, 6, FFTFixed::LOG_SHORT, 4> FFTFixed::multi_table_2(2);
            const TableFixMulti<Float64, 6, FFTFixed::LOG_SHORT, 4> FFTFixed::multi_table_3(3);
            const TableFixMulti<Float64, 6, FFTFixed::LOG_MAX, 4> FFTFixed::multi_table_1(1);

            /// Reverses the order of all 32 bits of n (bit 0 <-> bit 31, bit 1 <-> bit 30, ...).
            constexpr uint32_t bitrev32(uint32_t n)
            {
                // Swap halves, then bytes, nibbles, bit pairs and finally single bits.
                // After each step the mask is refined for the next, smaller group size:
                // 0x0000ffff -> 0x00ff00ff -> 0x0f0f0f0f -> 0x33333333 -> 0x55555555.
                uint32_t mask = 0x0000ffffu;
                for (uint32_t shift = 16; shift != 0; shift >>= 1, mask ^= mask << shift)
                {
                    n = ((n >> shift) & mask) | ((n << shift) & ~mask);
                }
                return n;
            }
            /// Reverses the low `len` bits of n (0 <= len <= 32); bits of n above `len` are discarded.
            constexpr uint32_t bitrev(uint32_t n, int len)
            {
                assert(len >= 0 && len <= 32);
                // Shifting a 32-bit value by 32 is undefined behaviour, so len == 0 is handled explicitly:
                // reversing zero bits yields 0.
                return len <= 0 ? 0 : bitrev32(n) >> (32 - len);
            }

            /// Incremental twiddle generator used by real_dot_binrev4. Each iterate() returns one C64X4
            /// holding BLOCK consecutive twiddles.
            /// All angles are scaled by factor = 2^-(log_fft_len - log_max_iter).
            /// NOTE(review): reset()/iterate() mutate the object without any synchronisation, so a
            /// shared instance is presumably not safe for concurrent use.
            class BinRevTableC64X4HP
            {
            public:
                using F64 = double;
                using C64 = std::complex<F64>;
                using C64X4 = hint_simd::Complex64X4;
                static constexpr int MAX_LOG_LEN = 32, LOG_BLOCK = 2, BLOCK = 1 << LOG_BLOCK;
                static constexpr size_t MAX_LEN = size_t(1) << MAX_LOG_LEN;

                /// Precomputes units[i] = getOmega(2^(i+1), 1, factor), i.e. angle
                /// -HINT_2PI * factor / 2^(i+1).
                /// Also fills table[0] with the BLOCK starting twiddles: lane 0 is 1 and lane k uses
                /// getOmega(BLOCK, bitrev(k, LOG_BLOCK), factor).
                BinRevTableC64X4HP(int log_max_iter_in, int log_fft_len_in)
                    : index(0), pop(0), log_max_iter(log_max_iter_in), log_fft_len(log_fft_len_in)
                {
                    assert(log_max_iter <= log_fft_len);
                    assert(log_fft_len <= MAX_LOG_LEN);
                    const F64 factor = F64(1) / (size_t(1) << (log_fft_len - log_max_iter));
                    for (int i = 0; i < MAX_LOG_LEN; i++)
                    {
                        units[i] = getOmega(size_t(1) << (i + 1), 1, factor);
                    }
                    // table[0] is written through a double view. The indexing below relies on C64X4
                    // storing its BLOCK real parts first, followed by its BLOCK imaginary parts.
                    auto fp = reinterpret_cast<F64 *>(table);
                    fp[0] = 1, fp[BLOCK] = 0;
                    for (int i = 1; i < BLOCK; i++)
                    {
                        C64 omega = getOmega(BLOCK, bitrev(i, LOG_BLOCK), factor);
                        fp[i] = omega.real(), fp[i + BLOCK] = omega.imag();
                    }
                }

                /// Restarts the sequence at twiddle position i.
                /// Only i == 0, or a power of two that is a multiple of BLOCK, is supported (asserted).
                /// For i != 0, table[1] becomes table[0] multiplied by units[ctz(i / BLOCK) + 2].
                /// That factor is loaded with load1, which presumably broadcasts it to all lanes.
                void reset(size_t i = 0)
                {
                    if (i == 0)
                    {
                        pop = 0, index = i;
                        return;
                    }
                    assert((i & (i - 1)) == 0);
                    assert(i % BLOCK == 0);
                    pop = 1, index = i / BLOCK;
                    int zero = hint_ctz(index);
                    auto fp = reinterpret_cast<F64 *>(&units[zero + 2]);
                    table[1].load1(fp, fp + 1);
                    table[1] = table[1].mul(table[0]);
                }
                /// Returns the current block of BLOCK twiddles, then advances by one block.
                /// The new top entry is table[pop] (after dropping ctz(index) levels) times
                /// units[ctz(index) + 2].
                /// NOTE(review): units has MAX_LOG_LEN entries, so this stays in bounds only while
                /// ctz(index) + 2 < MAX_LOG_LEN.
                C64X4 iterate()
                {
                    C64X4 res = table[pop], unit4;
                    index++;
                    int zero = hint_ctz(index);
                    auto fp = reinterpret_cast<F64 *>(&units[zero + 2]);
                    unit4.load1(fp, fp + 1);
                    pop -= zero;
                    table[pop + 1] = table[pop].mul(unit4);
                    pop++;
                    return res;
                }

                /// Unit complex number with angle -HINT_2PI * index / n, scaled by factor.
                static C64 getOmega(size_t n, size_t index, F64 factor = 1)
                {
                    F64 theta = -HINT_2PI * index / n;
                    return std::polar<F64>(1, theta * factor);
                }

            private:
                C64 units[MAX_LOG_LEN]{};   // per-level step factors (see constructor)
                C64X4 table[MAX_LOG_LEN]{}; // stack of partial products; table[pop] is the next result
                size_t index;               // current block index
                int pop;                    // current stack depth in table
                int log_max_iter, log_fft_len;
            };

            /// Pointwise product step of a real-input convolution for one mirrored pair of packed bins.
            /// inout0/inout1 and in0/in1 address the two partner bins of each operand. A real part is
            /// at [0] and the imaginary part at [RI_DIFF].
            /// Steps: both operands are unpacked with twiddle omega0, multiplied bin-by-bin, repacked,
            /// scaled by factor, and written back into inout0 / inout1.
            template <size_t RI_DIFF = 1, typename FloatTy>
            inline void dot_rfft(FloatTy *inout0, FloatTy *inout1, const FloatTy *in0, const FloatTy *in1,
                                 const std::complex<FloatTy> &omega0, const FloatTy factor = 1)
            {
                using Complex = std::complex<FloatTy>;
                const FloatTy wr = omega0.real();
                const FloatTy wi = omega0.imag();

                // Turns one partner pair (a, b) into the two spectrum values lo / hi.
                auto unpack = [wr, wi](FloatTy ar, FloatTy ai, FloatTy br, FloatTy bi, Complex &lo, Complex &hi)
                {
                    const FloatTy sum_r = ar + br, sum_i = ai - bi;
                    const FloatTy dif_r = ar - br, dif_i = ai + bi;
                    const FloatTy rot_r = dif_i * wr + dif_r * wi;
                    const FloatTy rot_i = dif_i * wi - dif_r * wr;
                    lo.real(sum_r + rot_r);
                    lo.imag(sum_i + rot_i);
                    hi.real(sum_r - rot_r);
                    hi.imag(rot_i - sum_i);
                };

                Complex a_lo, a_hi, b_lo, b_hi;
                unpack(inout0[0], inout0[RI_DIFF], inout1[0], inout1[RI_DIFF], a_lo, a_hi);
                unpack(in0[0], in0[RI_DIFF], in1[0], in1[RI_DIFF], b_lo, b_hi);
                a_lo *= b_lo;
                a_hi *= b_hi;

                // Repack the products into the partner-pair layout (inverse of unpack).
                const FloatTy pr = a_lo.real(), pi = a_lo.imag();
                const FloatTy qr = a_hi.real(), qi = a_hi.imag();
                const FloatTy sum_r = pr + qr, sum_i = pi - qi;
                const FloatTy dif_r = pr - qr, dif_i = pi + qi;
                const FloatTy rot_r = dif_r * wi - dif_i * wr;
                const FloatTy rot_i = dif_r * wr + dif_i * wi;

                inout0[0] = (sum_r + rot_r) * factor;
                inout0[RI_DIFF] = (sum_i + rot_i) * factor;
                inout1[0] = (sum_r - rot_r) * factor;
                inout1[RI_DIFF] = (rot_i - sum_i) * factor;
            }
            /// Incremental scalar twiddle generator (one std::complex per iterate() call), used by
            /// real_dot_binrev, dif and idit in this file.
            /// The step factors are scaled by factor = 2^-(log_fft_len - log_max_iter).
            /// NOTE(review): reset()/iterate() mutate the object without any synchronisation, so one
            /// instance is presumably not safe for concurrent use.
            template <typename FloatTy>
            class BinRevTableComplexIterHP
            {
            public:
                using Complex = std::complex<FloatTy>;
                static constexpr int MAX_LOG_LEN = 32;
                static constexpr size_t MAX_LEN = size_t(1) << MAX_LOG_LEN;

                /// Precomputes units[i] = getOmega(2^(i+1), 1, factor) and sets table[0] = 1.
                BinRevTableComplexIterHP(int log_max_iter_in, int log_fft_len_in)
                    : index(0), pop(0), log_max_iter(log_max_iter_in), log_fft_len(log_fft_len_in)
                {
                    assert(log_max_iter <= log_fft_len);
                    assert(log_fft_len <= MAX_LOG_LEN);
                    const FloatTy factor = FloatTy(1) / (size_t(1) << (log_fft_len - log_max_iter));
                    for (int i = 0; i < MAX_LOG_LEN; i++)
                    {
                        units[i] = getOmega(size_t(1) << (i + 1), 1, factor);
                    }
                    table[0] = Complex(1, 0);
                }
                /// Restarts the sequence at position i.
                /// pop becomes popcount(i). table[pop], table[pop - 1], ..., table[1] are filled with
                /// getOmega(2^log_fft_len, bitrev(prefix, log_max_iter)), where prefix starts at i and
                /// loses its lowest set bit at each step.
                void reset(size_t i = 0)
                {
                    index = i;
                    if (i == 0)
                    {
                        pop = 0;
                        return;
                    }
                    pop = hint_popcnt(i);
                    const size_t len = size_t(1) << log_fft_len;
                    for (int p = pop; p > 0; p--)
                    {
                        table[p] = getOmega(len, bitrev(i, log_max_iter));
                        i &= (i - 1);
                    }
                }
                /// Returns table[pop], then advances.
                /// The stack is popped by ctz(index) levels, and the new top is the remaining top
                /// times units[ctz(index)].
                Complex iterate()
                {
                    Complex res = table[pop];
                    index++;
                    int zero = hint_ctz(index);
                    pop -= zero;
                    table[pop + 1] = table[pop] * units[zero];
                    pop++;
                    return res;
                }

                /// Unit complex number with angle -HINT_2PI * index / n, scaled by factor.
                static Complex getOmega(size_t n, size_t index, FloatTy factor = 1)
                {
                    FloatTy theta = -HINT_2PI * index / n;
                    return std::polar<FloatTy>(1, theta * factor);
                }

            private:
                Complex units[MAX_LOG_LEN]{}; // per-level step factors (see constructor)
                Complex table[MAX_LOG_LEN]{}; // stack of partial products; table[pop] is the next result
                size_t index;                 // current position
                int pop;                      // current stack depth in table
                int log_max_iter, log_fft_len;
            };

            /// Pointwise product step of a real-input convolution, on spectra produced by dif()
            /// (interleaved re/im pairs, float_len Floats in total, float_len a power of two >= 2).
            /// in_out is overwritten with the product spectrum of in_out and in, scaled by
            /// 2 / float_len. Bins are processed in groups [0,2), [2,4), [4,8), [8,16), ..., and
            /// only the groups that exist for this length are touched, so short inputs
            /// (float_len 2 or 4) stay in bounds.
            template <typename Float>
            inline void real_dot_binrev(Float in_out[], const Float in[], size_t float_len)
            {
                assert(is_2pow(float_len));
                assert(float_len >= 2);
                using Complex = std::complex<Float>;
                Float inv = 2.0 / float_len;
                {
                    // Bin 0 packs two real-only values; split them, multiply, and re-pack.
                    auto r0 = in_out[0], i0 = in_out[1], r1 = in[0], i1 = in[1];
                    transform2(r0, i0);
                    transform2(r1, i1);
                    r0 *= r1, i0 *= i1;
                    transform2(r0, i0);
                    in_out[0] = r0 * 0.5 * inv, in_out[1] = i0 * 0.5 * inv;
                }
                if (float_len < 4)
                {
                    return;
                }
                // Bin 1 is its own partner: a plain complex product.
                auto temp = Complex(in_out[2], in_out[3]) * Complex(in[2], in[3]) * inv;
                in_out[2] = temp.real(), in_out[3] = temp.imag();
                if (float_len < 8)
                {
                    return;
                }
                // The paired bins carry an extra factor of 8 from unpack / multiply / repack.
                inv /= 8;
                dot_rfft(&in_out[4], &in_out[6], &in[4], &in[6], Complex(COS_PI_8, -COS_PI_8), inv);
                BinRevTableComplexIterHP<Float> table(31, 32);
                for (size_t begin = 8; begin < float_len; begin *= 2)
                {
                    table.reset(begin / 2);
                    // Pair the group [begin, 2 * begin) from both ends towards the middle.
                    auto it0 = in_out + begin, it1 = it0 + begin - 2, it2 = in + begin, it3 = it2 + begin - 2;
                    for (; it0 < it1; it0 += 2, it1 -= 2, it2 += 2, it3 -= 2)
                    {
                        auto omega = table.iterate();
                        dot_rfft(it0, it1, it2, it3, omega, inv);
                    }
                }
            }
            /// Vectorised dot_rfft: processes four mirrored bin pairs at once.
            /// - inout0 / in0: one C64X4 load of four bins.
            /// - inout1 / in1: their partners, lane-reversed on load and reversed back before storing.
            /// Each lane uses its own twiddle from omega0. Results are scaled by inv and written back
            /// to inout0 / inout1. F64X4::fmadd / fmsub are the fused counterparts of the scalar
            /// expressions in dot_rfft.
            inline void dot_rfftX4(F64 *inout0, F64 *inout1, const F64 *in0, const F64 *in1, const C64X4 &omega0, const F64X4 &inv)
            {
                // Turns the partner pair (c0, c1) into the two spectrum values out0 / out1.
                auto combine2 = [&omega0](C64X4 c0, C64X4 c1, C64X4 &out0, C64X4 &out1)
                {
                    auto tr0 = c0.real + c1.real, ti0 = c0.imag - c1.imag; // sum
                    auto tr1 = c0.real - c1.real, ti1 = c0.imag + c1.imag; // diff

                    c0.real = F64X4::fmadd(ti1, omega0.real, tr1 * omega0.imag);
                    c0.imag = F64X4::fmsub(ti1, omega0.imag, tr1 * omega0.real);

                    out0.real = tr0 + c0.real;
                    out0.imag = ti0 + c0.imag;

                    out1.real = tr0 - c0.real;
                    out1.imag = c0.imag - ti0;
                };
                C64X4 x0, x1;
                {
                    C64X4 x2, x3, x4, x5;
                    x0.load(inout0), x1.load(inout1);
                    x1.real = x1.real.reverse();
                    x1.imag = x1.imag.reverse();

                    combine2(x0, x1, x2, x3);

                    x0.load(in0), x1.load(in1);
                    x1.real = x1.real.reverse();
                    x1.imag = x1.imag.reverse();

                    combine2(x0, x1, x4, x5);

                    // Pointwise product of the two unpacked spectra.
                    x0 = x2.mul(x4);
                    x1 = x3.mul(x5);
                }
                { // separate2: repack the products (inverse of combine2) and apply the scale.
                    auto tr0 = x0.real + x1.real, ti0 = x0.imag - x1.imag; // sum
                    auto tr1 = x0.real - x1.real, ti1 = x0.imag + x1.imag; // diff
                    auto r = F64X4::fmsub(tr1, omega0.imag, ti1 * omega0.real);
                    auto i = F64X4::fmadd(tr1, omega0.real, ti1 * omega0.imag);

                    x0.real = (tr0 + r) * inv;
                    x0.imag = (ti0 + i) * inv;
                    x1.real = (tr0 - r) * inv;
                    x1.imag = (i - ti0) * inv;
                }
                // Undo the lane reversal of the partner chunk before storing.
                x1.real = x1.real.reverse();
                x1.imag = x1.imag.reverse();
                x0.store(inout0), x1.store(inout1);
            }
            template <typename Float>
            inline void idit(std::complex<Float> in_out[], size_t len, bool norm = false)
            {
                using Complex = std::complex<Float>;
                BinRevTableComplexIterHP<Float> table(31, 32);
                for (size_t rank = 2; rank <= len; rank *= 2)
                {
                    size_t stride = rank / 2;
                    auto it0 = in_out, it1 = it0 + stride;
                    table.reset(0);
                    for (size_t begin = 0; begin < len; begin += rank, it0 += rank, it1 += rank)
                    {
                        Complex omega = table.iterate();
                        for (size_t i = 0; i < stride; i++)
                        {
                            Complex x = it0[i];
                            Complex y = it1[i];
                            it0[i] = x + y;
                            it1[i] = (x - y) * std::conj(omega);
                        }
                    }
                }
                if (norm)
                {
                    const double len_inv = 1.0 / len;
                    for (size_t i = 0; i < len; i++)
                    {
                        in_out[i] *= len_inv;
                    }
                }
            }

            template <typename Float>
            inline void dif(std::complex<Float> in_out[], size_t len)
            {
                using Complex = std::complex<Float>;
                BinRevTableComplexIterHP<Float> table(31, 32);
                for (size_t rank = len; rank >= 2; rank /= 2)
                {
                    size_t stride = rank / 2;
                    auto it0 = in_out, it1 = it0 + stride;
                    table.reset(0);
                    for (size_t begin = 0; begin < len; begin += rank, it0 += rank, it1 += rank)
                    {
                        Complex omega = table.iterate();
                        for (size_t i = 0; i < stride; i++)
                        {
                            Complex x = it0[i];
                            Complex y = it1[i] * omega;
                            it0[i] = x + y;
                            it1[i] = x - y;
                        }
                    }
                }
            }

            /// Convenience overload: len counts Floats, and each (re, im) pair is treated as one
            /// std::complex<Float>, using the array-compatible layout std::complex guarantees.
            template <typename Float>
            inline void idit(Float in_out[], size_t len, bool norm = false)
            {
                idit(reinterpret_cast<std::complex<Float> *>(in_out), len / 2, norm);
            }

            /// Convenience overload: len counts Floats, and each (re, im) pair is treated as one
            /// std::complex<Float>, using the array-compatible layout std::complex guarantees.
            template <typename Float>
            inline void dif(Float in_out[], size_t len)
            {
                dif(reinterpret_cast<std::complex<Float> *>(in_out), len / 2);
            }

            /// Cyclic convolution of two real sequences of len doubles via the scalar complex FFT.
            /// The result replaces in_out1; in2 is overwritten with its transform.
            inline void real_conv_rfft(Float64 in_out1[], Float64 in2[], size_t len)
            {
                assert(is_2pow(len));
                // Forward transforms of both operands, in_out1 first.
                Float64 *const operands[] = {in_out1, in2};
                for (Float64 *buf : operands)
                {
                    dif(buf, len);
                }
                // Pointwise product into in_out1; this step also applies the 2 / len scaling.
                hint::transform::fft::real_dot_binrev(in_out1, in2, len);
                // Unnormalised inverse transform.
                idit(in_out1, len);
            }

            /// Pointwise product step of a real-input convolution for the chunked layout used by
            /// FFTFixed: within each 8-double chunk, real parts are at [0, 4) and imaginary parts at
            /// [4, 8) (cf. RI_DIFF = 4 below).
            /// in_out is overwritten with the product spectrum of in_out and in, scaled by
            /// 2 / float_len. Requires float_len >= 8. A single chunk (float_len == 8) stops after
            /// its own bins, so nothing past the buffer is accessed.
            inline void real_dot_binrev4(Float64 in_out[], Float64 in[], size_t float_len)
            {
                assert(float_len >= 8);
                using Complex = std::complex<Float64>;
                Float64 inv = 2.0 / float_len;
                {
                    // Bin 0 packs two real-only values; split them, multiply, and re-pack.
                    auto r0 = in_out[0], i0 = in_out[4], r1 = in[0], i1 = in[4];
                    transform2(r0, i0);
                    transform2(r1, i1);
                    r0 *= r1, i0 *= i1;
                    transform2(r0, i0);
                    in_out[0] = r0 * 0.5 * inv, in_out[4] = i0 * 0.5 * inv;
                }
                // Bin 1 is its own partner: a plain complex product.
                auto temp = Complex(in_out[1], in_out[5]) * Complex(in[1], in[5]) * inv;
                in_out[1] = temp.real(), in_out[5] = temp.imag();
                // The paired bins carry an extra factor of 8 from unpack / multiply / repack.
                inv /= 8;
                static BinRevTableC64X4HP table(31, 32);
                dot_rfft<4>(&in_out[2], &in_out[3], &in[2], &in[3], Complex(COS_PI_8, -COS_PI_8), inv);
                if (float_len < 16)
                {
                    // A single chunk holds every bin; indices 8..11 below would be out of bounds.
                    return;
                }
                constexpr Float64 COS_16_1 = 0.92387953251128675612818318939;
                constexpr Float64 SIN_16_1 = 0.38268343236508977172845998403;
                dot_rfft<4>(&in_out[8], &in_out[11], &in[8], &in[11], C64(COS_16_1, -SIN_16_1), inv);
                dot_rfft<4>(&in_out[9], &in_out[10], &in[9], &in[10], C64(-SIN_16_1, -COS_16_1), inv);
                const Float64X4 inv4 = F64X4(inv);
                for (size_t begin = 16; begin < float_len; begin *= 2)
                {
                    table.reset(begin / 2);
                    // Pair whole chunks of the group [begin, 2 * begin) from both ends towards the middle.
                    auto it0 = in_out + begin, it1 = it0 + begin - 8, it2 = in + begin, it3 = it2 + begin - 8;
                    for (; it0 < it1; it0 += 8, it1 -= 8, it2 += 8, it3 -= 8)
                    {
                        auto omega = table.iterate();
                        dot_rfftX4(it0, it1, it2, it3, omega, inv4);
                    }
                }
            }

            /// AVX cyclic convolution of two real sequences of float_len doubles, built on FFTFixed.
            /// The result replaces in_out1 and in2 is overwritten with its transform. TO_INT is
            /// forwarded to iditRec's TO_INT64 flag.
            template <bool TO_INT = false>
            inline void real_conv_avx(F64 *in_out1, F64 *in2, size_t float_len)
            {
                assert(is_2pow(float_len));
                // Forward transforms of both operands (in_out1 first), with FROM_RIRI_PERM = true.
                F64 *const operands[] = {in_out1, in2};
                for (F64 *buf : operands)
                {
                    FFTFixed::difRec<true>(buf, float_len);
                }
                real_dot_binrev4(in_out1, in2, float_len);
                // Inverse transform with TO_RIRI_PERM = true and TO_INT64 = TO_INT.
                FFTFixed::iditRec<true, TO_INT>(in_out1, float_len);
            }
        }
    }
}

namespace string_util
{

    using namespace hint;
    using namespace transform;
    using namespace fft;
    /// Table-driven converter from a value in [0, 9999] to its four ASCII digits.
    /// Each table entry packs the digits into one uint32_t so that, on a
    /// little-endian target, the first byte in memory is the thousands digit.
    class ItoStrBase10000
    {
    private:
        // table[v] == itosbase10000(v) for every v in [0, 9999].
        uint32_t table[10000]{};

    public:
        /// Pack the last four decimal digits of num as ASCII: thousands digit in the
        /// least significant byte, units digit in the most significant byte.
        static constexpr uint32_t itosbase10000(uint32_t num)
        {
            const uint32_t thousands = num / 1000 % 10;
            const uint32_t hundreds = num / 100 % 10;
            const uint32_t tens = num / 10 % 10;
            const uint32_t units = num % 10;
            const uint32_t digits = thousands | (hundreds << 8) | (tens << 16) | (units << 24);
            return digits + '0' * 0x1010101; // add the '0' bias to all four bytes at once
        }
        constexpr ItoStrBase10000()
        {
            for (uint32_t value = 0; value < 10000; ++value)
            {
                table[value] = itosbase10000(value);
            }
        }
        /// Write the four digits of num (must be < 10000) to str[0..3]; no terminator.
        void tostr(char *str, uint32_t num) const
        {
            const uint32_t packed = table[num];
            std::memcpy(str, &packed, sizeof(packed));
        }
        /// Packed four-digit ASCII word for num (must be < 10000).
        uint32_t tostr(uint32_t num) const
        {
            return table[num];
        }
    };

    /// Table-driven parser for two ASCII decimal digits.
    /// The two chars are read as one uint16_t in native byte order and used as a table
    /// key. Keys are built with the tens digit in the low byte, which matches memory
    /// order on little-endian targets. The key of every pair "00".."99" maps to its
    /// value; every other key maps to UINT16_MAX.
    class StrtoIBase100
    {
    private:
        // Every digit-pair key is <= 0x3939, so 2^15 entries cover them all. Keys at or
        // above TABLE_SIZE (second char >= 0x80) are rejected in toInt().
        static constexpr size_t TABLE_SIZE = size_t(1) << 15;
        uint16_t table[TABLE_SIZE]{};

    public:
        /// Packed ASCII form of num % 100: tens digit in the low byte, units in the high byte.
        static constexpr uint16_t itosbase100(uint16_t num)
        {
            uint16_t res = (num / 10 % 10) | ((num % 10) << 8);
            return res + '0' * 0x0101;
        }
        constexpr StrtoIBase100()
        {
            for (size_t i = 0; i < TABLE_SIZE; i++)
            {
                table[i] = UINT16_MAX; // default: not a digit pair
            }
            for (size_t i = 0; i < 100; i++)
            {
                table[itosbase100(i)] = i;
            }
        }
        /// Parse str[0..1] as two decimal digits.
        /// @return 0..99, or UINT16_MAX when the two chars are not both '0'..'9'.
        uint16_t toInt(const char *str) const
        {
            uint16_t num;
            std::memcpy(&num, str, sizeof(num));
            // A second char with its high bit set gives num >= 0x8000, beyond the
            // table; such input can never be a digit pair, so reject it without indexing.
            return num < TABLE_SIZE ? table[num] : uint16_t(UINT16_MAX);
        }
    };

    // Shared lookup tables, produced by the constexpr constructors at compile time.
    // Being const at namespace scope, each object has internal linkage.
    constexpr ItoStrBase10000 itosbase10000 = {};
    constexpr StrtoIBase100 strtoibase100 = {};

    /// Value of the four decimal digits s[0..3] (no validation, no terminator needed).
    constexpr uint32_t stobase10000(const char *s)
    {
        // Horner form of s[0]*1000 + s[1]*100 + s[2]*10 + s[3]; the four '0' biases
        // are removed together by subtracting '0' * 1111.
        return ((s[0] * 10 + s[1]) * 10 + s[2]) * 10 + s[3] - '0' * 1111;
    }

    /// Owning buffer of len elements of T, allocated with _mm_malloc at ALIGN-byte
    /// alignment and released with _mm_free. Elements are neither constructed nor
    /// destroyed, so T should be a trivial type.
    /// Move-only: the implicitly generated copy operations would free the same
    /// pointer twice.
    template <typename T, size_t ALIGN = 64>
    class AlignMem
    {
    public:
        using Ptr = T *;
        using ConstPtr = const T *;

        /// Empty buffer: begin() == end() == nullptr.
        AlignMem() = default;

        /// Allocate n uninitialized elements. If _mm_malloc fails, the object is left
        /// empty (begin() == end() == nullptr).
        explicit AlignMem(size_t n)
            : ptr(static_cast<Ptr>(_mm_malloc(n * sizeof(T), ALIGN))), len(ptr ? n : 0) {}

        AlignMem(const AlignMem &) = delete;
        AlignMem &operator=(const AlignMem &) = delete;

        /// Take ownership of other's allocation; other becomes empty.
        AlignMem(AlignMem &&other) noexcept : ptr(other.ptr), len(other.len)
        {
            other.ptr = nullptr;
            other.len = 0;
        }

        /// Free the current allocation, then take ownership of other's; other becomes empty.
        AlignMem &operator=(AlignMem &&other) noexcept
        {
            if (this != &other)
            {
                release();
                ptr = other.ptr;
                len = other.len;
                other.ptr = nullptr;
                other.len = 0;
            }
            return *this;
        }

        ~AlignMem()
        {
            release();
        }

        T &operator[](size_t i)
        {
            return ptr[i];
        }
        const T &operator[](size_t i) const
        {
            return ptr[i];
        }
        Ptr begin()
        {
            return ptr;
        }
        Ptr end()
        {
            return ptr + len;
        }
        ConstPtr begin() const
        {
            return ptr;
        }
        ConstPtr end() const
        {
            return ptr + len;
        }
        /// Number of elements owned.
        size_t size() const
        {
            return len;
        }

    private:
        // Free the allocation (if any) and return to the empty state.
        void release() noexcept
        {
            if (ptr)
            {
                _mm_free(ptr);
            }
            ptr = nullptr;
            len = 0;
        }

        // Declared before len: the sized constructor initializes len from ptr.
        T *ptr = nullptr;
        size_t len = 0;
    };

    /// Zero every byte of the range [begin, end); for arithmetic T this is the value 0.
    template <typename T>
    void fill_zero(T *begin, T *end)
    {
        const size_t byte_count = static_cast<size_t>(end - begin) * sizeof(T);
        std::memset(static_cast<void *>(begin), 0, byte_count);
    }

    /// Split the decimal digit string str[0..len) into base-10000 limbs, most
    /// significant group first: ary[k] receives the value of chars [4k, 4k + 4).
    /// A trailing group shorter than four digits is right-padded with zeros, so the
    /// limbs encode (number) * 10^shift. Writes ceil(len / 4) elements of ary.
    /// Characters are not validated; non-digits yield meaningless limbs.
    /// @return shift, the number of zero digits appended (0..3).
    template <typename T>
    size_t str_num_to_array_base10000(const char *str, size_t len, T *ary)
    {
        constexpr size_t BLOCK = 4;
        const size_t full_groups = len / BLOCK;
        const size_t tail_digits = len % BLOCK;
        for (size_t g = 0; g < full_groups; g++)
        {
            ary[g] = stobase10000(str + g * BLOCK);
        }
        if (tail_digits == 0)
        {
            return 0;
        }
        // Partial last group: read the remaining digits, then scale up as if zeros
        // had been appended to fill the group.
        const char *digit = str + full_groups * BLOCK;
        int limb = 0;
        for (size_t pos = 0; pos < BLOCK; pos++)
        {
            limb = limb * 10 + (pos < tail_digits ? digit[pos] - '0' : 0);
        }
        ary[full_groups] = limb;
        return BLOCK - tail_digits;
    }

    /// Normalize convolution output into decimal text.
    /// ary[0..conv_len) holds base-10000 limbs, most significant first, as floating
    /// values that may exceed 9999. Each is rounded to nearest (+0.5 then truncate,
    /// which assumes no value is below -0.5), and carries propagate toward the front.
    /// (conv_len + 1) * 4 digit chars are written to res. The last `shift` of them are
    /// the padding zeros introduced when the operands were split into groups, and they
    /// are excluded from the result.
    /// @param res      buffer of at least (conv_len + 1) * 4 chars
    /// @param res_len  [out] number of result digits starting at res + offset
    /// @return offset of the first result digit. Leading zeros are skipped, but at
    ///         least one digit is kept, so a zero value yields "0".
    template <typename T>
    size_t conv_to_str_base10000(const T *ary, size_t conv_len, size_t shift, char *res, size_t &res_len)
    {
        constexpr size_t BLOCK = 4, BASE = 10000;
        res_len = (conv_len + 1) * BLOCK;
        auto end = res + res_len;
        size_t i = conv_len;
        uint64_t carry = 0;
        while (i > 0)
        {
            i--;
            end -= BLOCK;
            carry += uint64_t(ary[i] + 0.5);
            itosbase10000.tostr(end, carry % BASE);
            carry /= BASE;
        }
        assert(carry < BASE);
        end -= BLOCK;
        itosbase10000.tostr(end, carry);
        // Digits in [res, res + value_len) form the value; the rest is shift padding.
        const size_t value_len = res_len > shift ? res_len - shift : 0;
        // Strip leading zeros, but stop at the last value digit. Without this bound an
        // all-zero product would scan past the end of the buffer.
        const char *last = res + (value_len > 0 ? value_len - 1 : 0);
        while (end < last && *end == '0')
        {
            end++;
        }
        const size_t offset = end - res;
        res_len = value_len > offset ? value_len - offset : 0;
        return offset;
    }

    /// Multiply two non-negative decimal strings (digits only, no sign) with a real FFT.
    /// @param res      output buffer of at least preserve_strlen(len1, len2) chars
    /// @param res_len  [out] number of result digits
    /// @return pointer into res where the result digits begin (not NUL-terminated).
    ///         If either operand is empty, res_len is set to 0 and res is returned.
    char *big_mul(const char *str1, size_t len1, const char *str2, size_t len2, char *res, size_t &res_len)
    {
        constexpr size_t BLOCK = 4;
        // Lower bound on the transform length. The pointwise-product code used by
        // real_conv_avx appears to address fixed elements up to index 11 and to start
        // its main loop at length 16. Extra zero padding leaves the linear convolution
        // unchanged.
        constexpr size_t MIN_FFT_LEN = 16;
        if (len1 == 0 || len2 == 0)
        {
            // Nothing to multiply. This also keeps block_len1 + block_len2 - 1 from
            // underflowing when both operands are empty.
            res_len = 0;
            return res;
        }
        const size_t block_len1 = (len1 + BLOCK - 1) / BLOCK;
        const size_t block_len2 = (len2 + BLOCK - 1) / BLOCK;
        const size_t conv_len = block_len1 + block_len2 - 1;
        // int_ceil2 presumably rounds up to a power of two (real_conv_avx asserts is_2pow).
        size_t fft_len = hint::int_ceil2(conv_len);
        if (fft_len < MIN_FFT_LEN)
        {
            fft_len = MIN_FFT_LEN;
        }
        AlignMem<Float64> ary1(fft_len), ary2(fft_len);
        // Each operand is right-padded to whole 4-digit groups; shift totals the padded zeros.
        size_t shift = str_num_to_array_base10000(str1, len1, &ary1[0]);
        shift += str_num_to_array_base10000(str2, len2, &ary2[0]);
        fill_zero(ary1.begin() + block_len1, ary1.end());
        fill_zero(ary2.begin() + block_len2, ary2.end());
        real_conv_avx(ary1.begin(), ary2.begin(), fft_len);
        // real_conv_rfft(ary1.begin(), ary2.begin(), fft_len);
        return res + conv_to_str_base10000(ary1.begin(), conv_len, shift, res, res_len);
    }

    /// Output-buffer size big_mul needs for operands of len1 and len2 digits:
    /// four chars per 4-digit input group of either operand.
    size_t preserve_strlen(size_t len1, size_t len2)
    {
        constexpr size_t BLOCK = 4;
        const size_t groups = (len1 + BLOCK - 1) / BLOCK + (len2 + BLOCK - 1) / BLOCK;
        return groups * BLOCK;
    }
    /// Read two decimal numbers from stdin and write their product to stdout
    /// (no trailing newline). Does nothing if two operands cannot be read.
    void mul_test()
    {
        std::string s1, s2;
        if (!(std::cin >> s1 >> s2))
        {
            return; // missing operand(s): nothing to multiply
        }
        const size_t len1 = s1.size(), len2 = s2.size();
        size_t res_len = preserve_strlen(len1, len2);
        std::vector<char> res(res_len, '0');
        const char *begin = big_mul(s1.data(), len1, s2.data(), len2, res.data(), res_len);
        // Write through <iostream>, which this file includes, rather than fwrite, whose
        // header <cstdio> is never included directly.
        std::cout.write(begin, static_cast<std::streamsize>(res_len));
    }
}

// Entry point: multiply the two numbers read from stdin (see string_util::mul_test).
int main()
{
    string_util::mul_test();
    return 0;
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 36.925 ms | 79 MB + 676 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2025-09-18 21:50:05 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠