提交记录 28525


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1004. 【模板题】高精度乘法 Accepted 100 8.305 ms 12248 KB C++14 52.89 KB
提交时间 评测时间
2025-09-21 22:58:08 2025-09-21 22:58:11
// TSKY 2025/9/16
#include <array>
#include <complex>
#include <iostream>
#include <type_traits>
#include <cstdint>
#include <climits>
#include <cstring>
#include <cassert>

#include <immintrin.h>

#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP

#pragma GCC target("avx2")
#pragma GCC target("fma")

namespace hint_simd
{
    /// Interleave two 4-lane double vectors pairwise.
    ///   in : row0 = [a0 a1 a2 a3], row1 = [b0 b1 b2 b3]
    ///   out: row0 = [a0 b0 a1 b1], row1 = [a2 b2 a3 b3]
    /// With row0 = real parts and row1 = imaginary parts this yields the
    /// interleaved r/i order of std::complex.
    template <typename YMM>
    inline void transpose64_2X4(YMM &row0, YMM &row1){
        const __m256d a = __m256d(row0);
        const __m256d b = __m256d(row1);
        // In-lane interleave: lo = [a0 b0 a2 b2], hi = [a1 b1 a3 b3]
        const __m256d lo = _mm256_unpacklo_pd(a, b);
        const __m256d hi = _mm256_unpackhi_pd(a, b);
        // Take the low 128-bit halves of both, then the high halves.
        row0 = YMM(_mm256_permute2f128_pd(lo, hi, 0x20));
        row1 = YMM(_mm256_permute2f128_pd(lo, hi, 0x31));
    }
    /// De-interleave two 4-lane double vectors (inverse of transpose64_2X4).
    ///   in : row0 = [a0 a1 a2 a3], row1 = [b0 b1 b2 b3]
    ///   out: row0 = [a0 a2 b0 b2], row1 = [a1 a3 b1 b3]
    /// Loaded from interleaved memory r0 i0 r1 i1 | r2 i2 r3 i3 this gives
    /// row0 = [r0 r1 r2 r3] and row1 = [i0 i1 i2 i3].
    template <typename YMM>
    inline void transpose64_4X2(YMM &row0, YMM &row1){
        const __m256d a = __m256d(row0);
        const __m256d b = __m256d(row1);
        // Regroup 128-bit halves: low_halves = [a0 a1 b0 b1], high_halves = [a2 a3 b2 b3]
        const __m256d low_halves = _mm256_permute2f128_pd(a, b, 0x20);
        const __m256d high_halves = _mm256_permute2f128_pd(a, b, 0x31);
        row0 = YMM(_mm256_unpacklo_pd(low_halves, high_halves));
        row1 = YMM(_mm256_unpackhi_pd(low_halves, high_halves));
    }

    /// Transpose a 4x4 matrix of doubles held as four row vectors:
    /// afterwards row k holds what was column k.
    template <typename YMM>
    inline void transpose64_4X4(YMM &row0, YMM &row1, YMM &row2, YMM &row3){
        const __m256d m0 = __m256d(row0), m1 = __m256d(row1);
        const __m256d m2 = __m256d(row2), m3 = __m256d(row3);
        // Element m[r][c] is lane c of row r.
        const __m256d even01 = _mm256_unpacklo_pd(m0, m1); // [m00 m10 m02 m12]
        const __m256d odd01 = _mm256_unpackhi_pd(m0, m1);  // [m01 m11 m03 m13]
        const __m256d even23 = _mm256_unpacklo_pd(m2, m3); // [m20 m30 m22 m32]
        const __m256d odd23 = _mm256_unpackhi_pd(m2, m3);  // [m21 m31 m23 m33]
        row0 = YMM(_mm256_permute2f128_pd(even01, even23, 0x20)); // column 0
        row1 = YMM(_mm256_permute2f128_pd(odd01, odd23, 0x20));   // column 1
        row2 = YMM(_mm256_permute2f128_pd(even01, even23, 0x31)); // column 2
        row3 = YMM(_mm256_permute2f128_pd(odd01, odd23, 0x31));   // column 3
    }

    /// Value wrapper around one AVX register holding four doubles.
    ///
    /// Every operation maps directly onto an AVX / AVX2 / FMA intrinsic. The
    /// pointer constructor, load(), fromMem() and store() use the aligned
    /// intrinsics and therefore need 32-byte aligned addresses; loadu(),
    /// fromUMem() and storeu() accept any address.
    class Float64X4
    {
    public:
        using F64 = double;
        using F64X4 = Float64X4;

        // Construction: zero / wrap a register / broadcast a scalar / aligned load.
        Float64X4() : data(_mm256_setzero_pd()) {}
        Float64X4(__m256d in_data) : data(in_data) {}
        Float64X4(F64 in_data) : data(_mm256_set1_pd(in_data)) {}
        Float64X4(const F64 *in_data) : data(_mm256_load_pd(in_data)) {}

        // Lane-wise arithmetic.
        F64X4 operator+(const F64X4 &other) const { return F64X4(_mm256_add_pd(data, other.data)); }
        F64X4 operator-(const F64X4 &other) const { return F64X4(_mm256_sub_pd(data, other.data)); }
        F64X4 operator*(const F64X4 &other) const { return F64X4(_mm256_mul_pd(data, other.data)); }
        F64X4 operator/(const F64X4 &other) const { return F64X4(_mm256_div_pd(data, other.data)); }

        F64X4 &operator+=(const F64X4 &other)
        {
            data = _mm256_add_pd(data, other.data);
            return *this;
        }
        F64X4 &operator-=(const F64X4 &other)
        {
            data = _mm256_sub_pd(data, other.data);
            return *this;
        }
        F64X4 &operator*=(const F64X4 &other)
        {
            data = _mm256_mul_pd(data, other.data);
            return *this;
        }
        F64X4 &operator/=(const F64X4 &other)
        {
            data = _mm256_div_pd(data, other.data);
            return *this;
        }

        /// Round every lane toward negative infinity.
        F64X4 floor() const { return F64X4(_mm256_floor_pd(data)); }

        /// Fused a * b + c (single rounding).
        static F64X4 fmadd(const F64X4 &a, const F64X4 &b, const F64X4 &c)
        {
            return F64X4(_mm256_fmadd_pd(a.data, b.data, c.data));
        }
        /// Fused a * b - c (single rounding).
        static F64X4 fmsub(const F64X4 &a, const F64X4 &b, const F64X4 &c)
        {
            return F64X4(_mm256_fmsub_pd(a.data, b.data, c.data));
        }

        /// Cross-lane permutation; N is the 8-bit immediate of _mm256_permute4x64_pd.
        template <int N>
        F64X4 permute4x64() const { return F64X4(_mm256_permute4x64_pd(data, N)); }

        /// Even-indexed elements of the 8-element sequence in0 ++ in1,
        /// i.e. [in0[0], in0[2], in1[0], in1[2]].
        static F64X4 extractEven64X4(const F64X4 &in0, const F64X4 &in1)
        {
            const F64X4 pairs(_mm256_unpacklo_pd(in0.data, in1.data)); // [a0 b0 a2 b2]
            return pairs.permute4x64<0b11011000>();                    // lanes 0,2,1,3
        }

        /// In-lane permutation; N is the 4-bit immediate of _mm256_permute_pd.
        template <int N>
        F64X4 permute() const { return F64X4(_mm256_permute_pd(data, N)); }
        /// Lane order reversed: [x3 x2 x1 x0].
        F64X4 reverse() const { return permute4x64<0b00011011>(); }

        // Memory access.
        void load(const F64 *p) { data = _mm256_load_pd(p); }
        void loadu(const F64 *p) { data = _mm256_loadu_pd(p); }
        /// Broadcast *p into all four lanes.
        void load1(const F64 *p) { data = _mm256_broadcast_sd(p); }
        static F64X4 fromMem(const F64 *p) { return F64X4(_mm256_load_pd(p)); }
        static F64X4 fromUMem(const F64 *p) { return F64X4(_mm256_loadu_pd(p)); }
        void store(F64 *p) const { _mm256_store_pd(p, data); }
        void storeu(F64 *p) const { _mm256_storeu_pd(p, data); }

        operator __m256d() const { return data; }

        /// Convert each lane to a 64-bit integer by decoding the IEEE-754 fields.
        /// For lanes in [0, 2^53) the result is floor(x). Every other input
        /// (negative values, whose sign bit inflates the exponent field; values
        /// >= 2^53; inf; NaN) needs a negative shift, which wraps to a huge
        /// count, and _mm256_srlv_epi64 then yields 0.
        __m256i toI64X4() const
        {
            constexpr uint64_t MANTISSA_MASK = (uint64_t(1) << 52) - 1;
            constexpr uint64_t EXP_BIAS = (uint64_t(1) << 10) - 1;
            const __m256i bits = _mm256_castpd_si256(data);
            __m256i mantissa = _mm256_and_si256(bits, _mm256_set1_epi64x(MANTISSA_MASK));
            // Restore the implicit leading 1 of a normal number.
            mantissa = _mm256_or_si256(mantissa, _mm256_set1_epi64x(MANTISSA_MASK + 1));
            const __m256i biased_exp = _mm256_srli_epi64(bits, 52);
            const __m256i shift = _mm256_sub_epi64(_mm256_set1_epi64x(EXP_BIAS + 52), biased_exp);
            return _mm256_srlv_epi64(mantissa, shift);
        }

    private:
        __m256d data;
    };

    /// Four complex doubles in split form: `real` holds the four real parts and
    /// `imag` the four imaginary parts.
    ///
    /// Memory layouts:
    ///  - RRII: four real parts followed by four imaginary parts (8 doubles).
    ///    This is what load()/store()/loadu()/storeu() read and write.
    ///  - RIRI: r0 i0 r1 i1 r2 i2 r3 i3 (std::complex order). Selected by passing
    ///    std::true_type to load()/store(); std::false_type selects RRII.
    /// The aligned variants require 32-byte aligned addresses.
    struct Complex64X4
    {
        using C64X4 = Complex64X4;
        using F64X4 = Float64X4;
        using F64 = double;
        Complex64X4() {}
        Complex64X4(F64X4 real, F64X4 imag) : real(real), imag(imag) {}
        /// Aligned RRII load of p[0..7].
        Complex64X4(const F64 *p) : real(p), imag(p + 4) {}
        Complex64X4(const F64 *p_real, const F64 *p_imag) : real(p_real), imag(p_imag) {}
        C64X4 operator+(const C64X4 &other) const{
            return C64X4(real + other.real, imag + other.imag);}
        C64X4 operator-(const C64X4 &other) const{
            return C64X4(real - other.real, imag - other.imag);}
        /// Scale real and imaginary parts lane-wise by a real vector.
        C64X4 operator*(const F64X4 &other) const{
            return C64X4(real * other, imag * other);}
        /// Lane-wise complex product this * other.
        C64X4 mul(const C64X4 &other) const{
            const F64X4 ii = imag * other.imag;
            const F64X4 ri = real * other.imag;
            const F64X4 r = F64X4::fmsub(real, other.real, ii);
            const F64X4 i = F64X4::fmadd(imag, other.real, ri);
            return C64X4(r, i);}
        /// Lane-wise product with the conjugate: this * conj(other).
        C64X4 mulConj(const C64X4 &other) const{
            const F64X4 ii = imag * other.imag;
            const F64X4 ri = real * other.imag;
            const F64X4 r = F64X4::fmadd(real, other.real, ii);
            const F64X4 i = F64X4::fmsub(imag, other.real, ri);
            return C64X4(r, i);}
        /// Reverse the order of the four complex lanes.
        C64X4 reverse() const{
            return C64X4(real.reverse(), imag.reverse());}

        // Tag-dispatched load: std::false_type reads RRII, std::true_type reads RIRI.
        template <typename T>
        void load(const T *p, std::false_type){
            this->load(p);}
        template <typename T>
        void load(const T *p, std::true_type){
            this->load(p);
            *this = this->toRRIIPermu();}
        /// Aligned RRII load from p (reinterpreted as doubles).
        template <typename T>
        void load(const T *p){
            real.load(reinterpret_cast<const F64 *>(p));
            imag.load(reinterpret_cast<const F64 *>(p) + 4);}
        template <typename T>
        void loadu(const T *p){
            real.loadu(reinterpret_cast<const F64 *>(p));
            imag.loadu(reinterpret_cast<const F64 *>(p) + 4);}
        /// Broadcast one complex value (*real_p, *imag_p) to all lanes.
        void load1(const F64 *real_p, const F64 *imag_p){
            real.load1(real_p);
            imag.load1(imag_p);}

        // Tag-dispatched store: std::false_type writes RRII, std::true_type writes RIRI.
        template <typename T>
        void store(T *p, std::false_type) const{
            this->store(p);}
        template <typename T>
        void store(T *p, std::true_type) const{
            this->toRIRIPermu().store(p);}
        /// Aligned RRII store to p (reinterpreted as doubles).
        template <typename T>
        void store(T *p) const{
            real.store(reinterpret_cast<F64 *>(p));
            imag.store(reinterpret_cast<F64 *>(p) + 4);}
        template <typename T>
        void storeu(T *p) const{
            real.storeu(reinterpret_cast<F64 *>(p));
            imag.storeu(reinterpret_cast<F64 *>(p) + 4);}
        /// z^2 per lane: (r^2 - i^2) + i(2ri).
        C64X4 square() const{
            const F64X4 ii = imag * imag;
            const F64X4 ri = real * imag;
            const F64X4 r = F64X4::fmsub(real, real, ii);
            const F64X4 i = ri + ri;
            return C64X4(r, i);}
        /// z^3 per lane: r(r^2 - 3i^2) + i * i(3r^2 - i^2).
        C64X4 cube() const{
            const F64X4 rr = real * real;
            const F64X4 ii = imag * imag;
            const F64X4 rr3 = rr + rr + rr;
            const F64X4 ii3 = ii + ii + ii;
            const F64X4 r = real * (rr - ii3);
            const F64X4 i = imag * (rr3 - ii);
            return C64X4(r, i);}
        /// Split -> interleaved registers: real = [r0 i0 r1 i1], imag = [r2 i2 r3 i3],
        /// so a plain RRII-style store() then writes RIRI memory.
        C64X4 toRIRIPermu() const{
            C64X4 res = *this;
            transpose64_2X4(res.real, res.imag);
            return res;}
        /// Inverse of toRIRIPermu: registers loaded from RIRI memory become split.
        C64X4 toRRIIPermu() const{
            C64X4 res = *this;
            transpose64_4X2(res.real, res.imag);
            return res;}
        /// No-op overload, selected when integer conversion is disabled.
        C64X4 transToI64(std::false_type) const{
            return *this;}
        /// Round every real and imaginary lane to the nearest integer and return the
        /// int64 bit patterns stored in the double lanes (for a following store()).
        /// The result is floor(x + 0.5) when x + 0.5 lies in [0, 2^53) and 0 otherwise,
        /// following Float64X4::toI64X4.
        C64X4 transToI64(std::true_type) const{
            // Adding one half before the truncating conversion rounds to nearest
            // (ties round up).
            const F64X4 half(0.5);
            const __m256i real_i64 = (real + half).toI64X4();
            const __m256i imag_i64 = (imag + half).toI64X4();
            // _mm256_castsi256_pd is the documented bit-reinterpretation intrinsic.
            // A C-style cast between __m256i and __m256d relies on GCC/Clang vector
            // extensions and is not accepted where these are distinct types.
            return C64X4(F64X4(_mm256_castsi256_pd(real_i64)), F64X4(_mm256_castsi256_pd(imag_i64)));}

        F64X4 real, imag;
    };
}
#endif

namespace hint
{
    // Scalar and complex floating-point aliases used throughout the library.
    using Float32 = float;
    using Float64 = double;
    using Complex32 = std::complex<Float32>;
    using Complex64 = std::complex<Float64>;

    // pi written with more digits than a double holds; the compiler rounds it.
    constexpr Float64 HINT_PI = 3.141592653589793238462643;
    // 2*pi (scaling by two is exact in binary floating point).
    constexpr Float64 HINT_2PI = 2 * HINT_PI;
    // sqrt(2)/2 = cos(pi/4) = cos(2*pi/8), the real part of the primitive 8th root
    // of unity. NOTE(review): despite the name this is not cos(pi/8) (~0.92388);
    // presumably "PI_8" means 2*pi/8 — confirm before relying on the name.
    constexpr Float64 COS_PI_8 = 0.707106781186547524400844;

    // 2^23; presumably the largest FFT length the library accepts (not used in
    // this section — confirm at the call sites).
    constexpr size_t FFT_MAX_LEN = static_cast<size_t>(1) << 23;

    /// Largest power of two not exceeding n, for n >= 1 (unsigned types).
    /// The highest set bit is smeared into every lower position, leaving 2^(k+1)-1;
    /// half of that plus one is 2^k. Note: int_floor2(0) returns 1.
    template <typename T>
    constexpr T int_floor2(T n){
        constexpr int width = sizeof(n) * 8;
        for (int shift = 1; shift < width; shift <<= 1){
            n |= (n >> shift);
        }
        return (n >> 1) + 1;
    }

    /// Smallest power of two not below n (unsigned types). Stepping down by one
    /// first makes exact powers of two map to themselves. int_ceil2(0) and inputs
    /// above the largest representable power of two wrap around to 0.
    template <typename T>
    constexpr T int_ceil2(T n){
        constexpr int width = sizeof(n) * 8;
        T m = n - 1;
        for (int shift = 1; shift < width; shift <<= 1){
            m |= (m >> shift);
        }
        return m + 1;
    }

    /// True iff n has exactly one set bit (n is a power of two; 0 is not).
    template <typename IntTy>
    constexpr bool is_2pow(IntTy n){
        // Clearing the lowest set bit of a power of two leaves nothing behind.
        return !(n == 0 || (n & (n - 1)) != 0);
    }

    /// floor(log2(n)) for n >= 1, found by binary search over bit positions;
    /// returns -1 for n == 0.
    template <typename T>
    constexpr int hint_log2(T n){
        constexpr int bits = sizeof(n) * 8;
        // For positive n the loop keeps 2^below <= n < 2^above.
        int below = -1, above = bits;
        while (above - below > 1){
            const int mid = (below + above) / 2;
            if ((T(1) << mid) <= n){
                below = mid;
            }
            else{
                above = mid;
            }
        }
        return below;
    }
    /// Number of trailing zero bits of x; 32 when x == 0.
    constexpr int hint_ctz(uint32_t x){
        // Isolate the lowest set bit, then read its index one bit at a time:
        // each mask selects the positions whose index has that bit set.
        const uint32_t lowest = x & (0u - x);
        int pos = 0;
        pos |= (lowest & 0xAAAAAAAAu) ? 1 : 0;
        pos |= (lowest & 0xCCCCCCCCu) ? 2 : 0;
        pos |= (lowest & 0xF0F0F0F0u) ? 4 : 0;
        pos |= (lowest & 0xFF00FF00u) ? 8 : 0;
        pos |= (lowest & 0xFFFF0000u) ? 16 : 0;
        return lowest == 0 ? 32 : pos;
    }

    /// Number of set bits in n (SWAR reduction, usable in constant expressions).
    constexpr int hint_popcnt(uint32_t n){
        // Add neighbouring fields of width 1, 2, 4 and 8 bits. Afterwards each
        // 16-bit half holds its own bit count.
        constexpr uint32_t masks[4] = {0x55555555u, 0x33333333u, 0x0F0F0F0Fu, 0x00FF00FFu};
        for (int level = 0, width = 1; level < 4; ++level, width <<= 1){
            n = (n & masks[level]) + ((n >> width) & masks[level]);
        }
        return static_cast<int>((n & 0xFFFFu) + (n >> 16));
    }

    namespace transform
    {
        using namespace hint_simd;

        /// In-place radix-2 butterfly: (sum, diff) <- (sum + diff, sum - diff).
        template <typename T>
        inline void transform2(T &sum, T &diff){
            const T a = sum;
            const T b = diff;
            sum = a + b;
            diff = a - b;}

        /// Radix-2 butterfly into separate outputs: sum = a + b, diff = a - b.
        /// a and b are copies, so sum/diff may alias the variables passed as a or b.
        template <typename T>
        inline void transform2(const T a, const T b, T &sum, T &diff){
            const T s = a + b;
            const T d = a - b;
            sum = s;
            diff = d;}
        namespace fft{
            using F64 = Float64;
            using C64 = std::complex<F64>;
            using F64X4 = Float64X4;
            using C64X4 = Complex64X4;
            /// Fixed table of OMEGA_LEN complex twiddles cos(theta*k) + i*sin(theta*k),
            /// theta = -2*pi*factor/theta_divider, k = 0 .. OMEGA_LEN-1.
            ///
            /// Storage is packed in blocks of `stride` cosines followed by `stride` sines.
            /// With stride == 4 this is the layout Complex64X4::load reads (reals at p,
            /// imaginaries at p + 4).
            template <typename Float, size_t OMEGA_LEN>
            class TableFix{
                alignas(64) std::array<Float, OMEGA_LEN * 2> table;

            public:
                TableFix(size_t theta_divider, size_t factor, size_t stride){
                    const Float theta = -HINT_2PI * factor / theta_divider;
                    assert(OMEGA_LEN % stride == 0);
                    size_t k = 0;
                    for (size_t block = 0; block < OMEGA_LEN * 2; block += stride * 2){
                        Float *const cos_part = &table[block];
                        Float *const sin_part = cos_part + stride;
                        for (size_t lane = 0; lane < stride; ++lane, ++k){
                            cos_part[lane] = std::cos(theta * k);
                            sin_part[lane] = std::sin(theta * k);
                        }
                    }
                }
                /// Raw access to the packed float storage (a float offset, not a complex index).
                constexpr const Float &operator[](size_t index) const
                {
                    return table[index];
                }
                /// Pointer into the packed float storage at float offset `index`.
                constexpr const Float *getOmegaIt(size_t index) const
                {
                    return &table[index];
                }
            };
            template <typename Float, int LOG_BEGIN, int LOG_END, int DIV>
            class TableFixMulti{
                static_assert(LOG_END >= LOG_BEGIN);
                static_assert(is_2pow(DIV));
                static constexpr size_t TABLE_CPX_LEN = (size_t(1) << (LOG_END + 1)) / DIV;
                alignas(64) std::array<Float, TABLE_CPX_LEN * 2> table;

            public:
                TableFixMulti(size_t factor, size_t stride = 4){
                    assert(((size_t(1) << LOG_BEGIN) / DIV) % stride == 0);
                    initBottomUp(factor, stride);}
                void initTopDown(size_t factor, size_t stride){
                    static_assert(std::is_same<Float, Float64>::value);
                    assert(stride == 4);}
                void initBottomUp(size_t factor, size_t stride){
                    static_assert(std::is_same<Float, Float64>::value);
                    assert(stride == 4);
                    size_t len = size_t(1) << LOG_BEGIN, cpx_len = len / DIV;
                    Float theta = -HINT_2PI * factor / len;
                    auto it = getBeginLog(LOG_BEGIN);
                    for (size_t i = 0; i < cpx_len; i++){
                        it[0] = std::cos(theta * i), it[stride] = std::sin(theta * i);
                        it += (i % stride == stride - 1 ? stride + 1 : 1);}
                    it = getBeginLog(LOG_BEGIN);
                    for (int log_len = LOG_BEGIN + 1; log_len <= LOG_END; log_len++){
                        len = size_t(1) << log_len, cpx_len = len / DIV;
                        theta = -HINT_2PI * factor / len;
                        auto it = getBeginLog(log_len), it_last = getBeginLog(log_len - 1);
                        C64X4 unit(std::cos(theta), std::sin(theta));
                        for (auto end = it + cpx_len * 2; it < end; it += 16, it_last += 8){
                            Complex64X4 omega0, omega1;
                            omega0.load(it_last);
                            omega1 = omega0.mul(unit);
                            transpose64_2X4(omega0.real, omega1.real);
                            transpose64_2X4(omega0.imag, omega1.imag);
                            omega0.store(it), omega1.store(it + 8);
                        }}}
                constexpr const Float *getBeginLog(int log_rank) const
                {
                    return getBegin(size_t(1) << log_rank);}
                constexpr Float *getBeginLog(int log_rank){
                    return getBegin(size_t(1) << log_rank);}
                constexpr const Float *getBegin(size_t rank) const
                {
                    return &table[rank * 2 / DIV];}
                constexpr Float *getBegin(size_t rank){
                    return &table[rank * 2 / DIV];}};

            /// Lane-wise radix-4 butterflies on split real/imag values.
            /// Each (r_k, i_k) is one complex value x_k; the template works on scalars or on
            /// SIMD vectors (every lane independently).
            struct FFT{
                /// 4-point forward DFT (kernel exp(-2*pi*i*n*k/4)) of natural-order x0..x3.
                /// Outputs in bit-reversed order: slots 0..3 receive X0, X2, X1, X3.
                template <typename Float>
                static void dif4(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3){
                    difSplit(r0, i0, r1, i1, r2, i2, r3, i3);
                    transform2(r0, r1);
                    transform2(i0, i1);}
                /// Unnormalized 4-point inverse DFT (kernel exp(+2*pi*i*n*k/4)).
                /// Takes the bit-reversed input Y0, Y2, Y1, Y3 and returns natural order.
                template <typename Float>
                static void idit4(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3){
                    transform2(r0, r1);
                    transform2(i0, i1);
                    iditSplit(r0, i0, r1, i1, r2, i2, r3, i3);}
                /// Split-radix forward step. With d = x0 - x2 and e = x1 - x3 it produces
                ///   x0 <- x0 + x2,  x1 <- x1 + x3,  x2 <- d - i*e,  x3 <- d + i*e.
                template <typename Float>
                static void difSplit(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3){
                    transform2(r0, r2);
                    transform2(i0, i2);
                    transform2(r1, r3);
                    transform2(i1, i3);

                    // Multiply by -i / +i without multiplications: mix real and imaginary
                    // parts, then swap so (r3, i3) ends up holding d + i*e.
                    transform2(r2, i3);
                    transform2(i2, r3, r3, i2);
                    std::swap(i3, r3);}
                /// Split-radix inverse step. With s = y2 + y3 and t = y2 - y3 it produces
                ///   y0 <- y0 + s,  y2 <- y0 - s,  y1 <- y1 + i*t,  y3 <- y1 - i*t.
                template <typename Float>
                static void iditSplit(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3){
                    transform2(r2, r3);
                    transform2(i2, i3);

                    transform2(r0, r2);
                    transform2(i0, i2);
                    // Multiplication by +/- i folded into the butterfly (by-value overload,
                    // so the aliased outputs are safe), then fixed up by the swap.
                    transform2(r1, i3, i3, r1);
                    transform2(i1, r3);
                    std::swap(i3, r3);}};

            struct FFTAVX : public FFT{
                // Length thresholds, as log2 of the complex FFT length:
                //  - blocks of at most SHORT_LEN points go to difIter/iditIter (which assert it);
                //  - difLarge hands blocks of at most MID_LEN points to difMid;
                //  - MAX_LEN is the largest rank held by multi_table_1.
                static constexpr size_t LOG_SHORT = 10;
                static constexpr size_t LOG_MID = 14;
                static constexpr size_t LOG_MAX = 18;
                static constexpr size_t SHORT_LEN = size_t(1) << LOG_SHORT;
                static constexpr size_t MID_LEN = size_t(1) << LOG_MID;
                static constexpr size_t MAX_LEN = size_t(1) << LOG_MAX;
                // Small fixed twiddle tables for the 8/16/32-point kernels. They are defined
                // (with their constructor arguments) outside this section. NOTE(review): the names
                // suggest roots of unity of order 8/16/32 raised to the power 1 or 3 — confirm
                // against the definitions.
                static const TableFix<Float64, 4> table_8;
                static const TableFix<Float64, 4> table_16_1;
                static const TableFix<Float64, 4> table_16_3;
                static const TableFix<Float64, 8> table_32_1;
                static const TableFix<Float64, 8> table_32_3;
                // Multi-level twiddle tables covering ranks 2^6 .. 2^LOG_SHORT (or 2^LOG_MAX).
                // Presumably built with factor 3, 2 and 1 respectively: difIter uses them for the
                // c3/c1/c2 twiddles, where dif2LayerMid instead derives c1 and c3 from
                // multi_table_1 by squaring and multiplying.
                static const TableFixMulti<Float64, 6, LOG_SHORT, 4> multi_table_3;
                static const TableFixMulti<Float64, 6, LOG_SHORT, 4> multi_table_2;
                static const TableFixMulti<Float64, 6, LOG_MAX, 4> multi_table_1;

                // Start pointers into the fixed tables, passed straight to Complex64X4::load.
                static constexpr const Float64 *it8 = &table_8[0];
                static constexpr const Float64 *it16_1 = &table_16_1[0];
                static constexpr const Float64 *it16_3 = &table_16_3[0];
                static constexpr const Float64 *it32_1 = &table_32_1[0];
                static constexpr const Float64 *it32_3 = &table_32_3[0];

                /// In-place 4-point forward DFT inside each of four complex vectors.
                /// Row j = (r_j, i_j) holds complex values z_j[0..3] in its lanes. The 4x4
                /// transposes move lane k of every row into row k so that FFT::dif4 works
                /// lane-wise. Transposing back leaves each z_j replaced by its 4-point DFT
                /// in bit-reversed lane order (X0, X2, X1, X3).
                static void dif4x4(F64X4 &r0, F64X4 &i0, F64X4 &r1, F64X4 &i1, F64X4 &r2, F64X4 &i2, F64X4 &r3, F64X4 &i3){
                    transpose64_4X4(r0, r1, r2, r3);
                    transpose64_4X4(i0, i1, i2, i3);

                    dif4(r0, i0, r1, i1, r2, i2, r3, i3);

                    transpose64_4X4(r0, r1, r2, r3);
                    transpose64_4X4(i0, i1, i2, i3);}
                /// Inverse counterpart of dif4x4. Each vector's four lanes, taken in
                /// bit-reversed order, are replaced by their unnormalized 4-point inverse
                /// DFT in natural order (FFT::idit4 applied lane-wise between two transposes).
                static void idit4x4(F64X4 &r0, F64X4 &i0, F64X4 &r1, F64X4 &i1, F64X4 &r2, F64X4 &i2, F64X4 &r3, F64X4 &i3){
                    transpose64_4X4(r0, r1, r2, r3);
                    transpose64_4X4(i0, i1, i2, i3);

                    idit4(r0, i0, r1, i1, r2, i2, r3, i3);

                    transpose64_4X4(r0, r1, r2, r3);
                    transpose64_4X4(i0, i1, i2, i3);}
                /// Two independent 8-point forward DFTs, one in (c0, c1) and one in (c2, c3).
                /// In each pair the first vector holds elements 0-3 and the second elements 4-7.
                /// A radix-2 butterfly pairs element k with k+4, and the difference half is
                /// multiplied by `omega`. The in-file callers pass the it8 table, presumably
                /// w8^k for k = 0..3. dif4x4 then finishes every 4-point half.
                static void dif8x2(Complex64X4 &c0, Complex64X4 &c1, Complex64X4 &c2, Complex64X4 &c3, const Complex64X4 &omega){
                    transform2(c0, c1);
                    transform2(c2, c3);
                    c1 = c1.mul(omega);
                    c3 = c3.mul(omega);
                    FFTAVX::dif4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);}
        /// Inverse counterpart of dif8x2, applying its steps in reverse order:
        /// 1. a 4-point inverse DFT inside every vector;
        /// 2. multiplication of c1 and c3 by conj(omega);
        /// 3. the radix-2 butterflies (c0, c1) and (c2, c3).
                static void idit8x2(Complex64X4 &c0, Complex64X4 &c1, Complex64X4 &c2, Complex64X4 &c3, const Complex64X4 &omega){
                    idit4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    c1 = c1.mulConj(omega);
                    c3 = c3.mulConj(omega);
                    transform2(c0, c1);
                    transform2(c2, c3);}
                /// In-place dif8x2 on 32 doubles: four RRII blocks of 8 doubles (two 8-point
                /// sequences), with twiddles taken from it8.
                static void dif8x2(Float64 in_out[]){
                    Complex64X4 blk[4];
                    for (int k = 0; k < 4; k++){
                        blk[k].load(in_out + 8 * k);
                    }
                    Complex64X4 omega;
                    omega.load(it8);
                    dif8x2(blk[0], blk[1], blk[2], blk[3], omega);
                    for (int k = 0; k < 4; k++){
                        blk[k].store(in_out + 8 * k);
                    }
                }
                /// In-place idit8x2 on 32 doubles: four RRII blocks of 8 doubles, with
                /// twiddles taken from it8.
                static void idit8x2(Float64 in_out[]){
                    Complex64X4 blk[4];
                    for (int k = 0; k < 4; k++){
                        blk[k].load(in_out + 8 * k);
                    }
                    Complex64X4 omega;
                    omega.load(it8);
                    idit8x2(blk[0], blk[1], blk[2], blk[3], omega);
                    for (int k = 0; k < 4; k++){
                        blk[k].store(in_out + 8 * k);
                    }
                }
                /// In-place 16-point forward FFT on 32 doubles stored as four RRII blocks
                /// (block k holds complex elements 4k .. 4k+3).
                /// Steps:
                /// 1. A radix-4 butterfly across the blocks combines elements n, n+4, n+8
                ///    and n+12.
                /// 2. The three non-trivial outputs are twiddled with it8, it16_1 and it16_3.
                /// 3. dif4x4 does the 4-point DFTs inside each block.
                static void dif16(Float64 in_out[]){
                    Complex64X4 c0, c1, c2, c3, omega;
                    c0.load(in_out), c1.load(in_out + 8), c2.load(in_out + 16), c3.load(in_out + 24);
                    dif4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    omega.load(it8), c1 = c1.mul(omega);
                    omega.load(it16_1), c2 = c2.mul(omega);
                    omega.load(it16_3), c3 = c3.mul(omega);
                    dif4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    c0.store(in_out), c1.store(in_out + 8), c2.store(in_out + 16), c3.store(in_out + 24);}
                /// Inverse counterpart of dif16 on 32 doubles (four RRII blocks). Steps:
                /// 1. A 4-point inverse DFT inside each block.
                /// 2. Conjugate twiddles from it8, it16_1 and it16_3.
                /// 3. The cross-block radix-4 inverse butterfly.
                /// The result is unnormalized (no division by 16 here).
                static void idit16(Float64 in_out[]){
                    Complex64X4 c0, c1, c2, c3, omega;
                    c0.load(in_out), c1.load(in_out + 8), c2.load(in_out + 16), c3.load(in_out + 24);
                    idit4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    omega.load(it8), c1 = c1.mulConj(omega);
                    omega.load(it16_1), c2 = c2.mulConj(omega);
                    omega.load(it16_3), c3 = c3.mulConj(omega);
                    idit4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    c0.store(in_out), c1.store(in_out + 8), c2.store(in_out + 16), c3.store(in_out + 24);}
                /// In-place 32-point forward split-radix FFT on 64 doubles (eight RRII blocks).
                /// Steps:
                /// 1. difSplit combines elements n, n+8, n+16 and n+24 for n = 0..7, in two
                ///    passes of four lanes.
                /// 2. The sums stay in elements 0..15 and are finished by dif16.
                /// 3. The two odd-output quarter branches (elements 16..23 and 24..31) are
                ///    twiddled by it32_1 / it32_3.
                /// 4. Those two branches are finished as two 8-point transforms by dif8x2.
                static void dif32(Float64 in_out[]){
                    Complex64X4 c0, c1, c2, c3, omega;
                    // Pass 1: lanes n = 0..3.
                    c0.load(in_out), c1.load(in_out + 16), c2.load(in_out + 32), c3.load(in_out + 48);
                    difSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    omega.load(it32_1), c2 = c2.mul(omega);
                    omega.load(it32_3), c3 = c3.mul(omega);
                    c0.store(in_out), c1.store(in_out + 16), c2.store(in_out + 32), c3.store(in_out + 48);

                    // Pass 2: lanes n = 4..7. c2/c3 stay in registers and feed dif8x2 directly,
                    // together with the pass-1 branch results reloaded into c0/c1.
                    c0.load(in_out + 8), c1.load(in_out + 24), c2.load(in_out + 40), c3.load(in_out + 56);
                    difSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    omega.load(it32_1 + 8), c2 = c2.mul(omega);
                    omega.load(it32_3 + 8), c3 = c3.mul(omega);
                    c0.store(in_out + 8), c1.store(in_out + 24);
                    c0.load(in_out + 32), c1.load(in_out + 48), omega.load(it8);
                    dif8x2(c0, c2, c1, c3, omega);
                    c0.store(in_out + 32), c2.store(in_out + 40), c1.store(in_out + 48), c3.store(in_out + 56);
                    dif16(in_out);}
                /// Inverse counterpart of dif32 on 64 doubles (eight RRII blocks). Steps:
                /// 1. idit16 on elements 0..15.
                /// 2. idit8x2 on elements 16..31.
                /// 3. Conjugate twiddles (it32_1 / it32_3) on the two odd branches.
                /// 4. iditSplit recombines the halves, in two passes of four lanes.
                /// The result is unnormalized.
                static void idit32(Float64 in_out[]){
                    Complex64X4 c0, c1, c2, c3, omega;
                    idit16(in_out);
                    c0.load(in_out + 32), c1.load(in_out + 40), c2.load(in_out + 48), c3.load(in_out + 56), omega.load(it8);
                    idit8x2(c0, c1, c2, c3, omega);
                    c1.store(in_out + 40), c3.store(in_out + 56);

                    // Pass 1 (lanes n = 0..3) reuses c0/c2 straight from idit8x2.
                    c1.load(in_out), c3.load(in_out + 16);
                    omega.load(it32_1), c0 = c0.mulConj(omega);
                    omega.load(it32_3), c2 = c2.mulConj(omega);
                    iditSplit(c1.real, c1.imag, c3.real, c3.imag, c0.real, c0.imag, c2.real, c2.imag);
                    c1.store(in_out), c3.store(in_out + 16), c0.store(in_out + 32), c2.store(in_out + 48);

                    // Pass 2: lanes n = 4..7.
                    c0.load(in_out + 8), c1.load(in_out + 24), c2.load(in_out + 40), c3.load(in_out + 56);
                    omega.load(it32_1 + 8), c2 = c2.mulConj(omega);
                    omega.load(it32_3 + 8), c3 = c3.mulConj(omega);
                    iditSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    c0.store(in_out + 8), c1.store(in_out + 24), c2.store(in_out + 40), c3.store(in_out + 56);}
                /// Apply the 16-point forward kernel to each consecutive 32-double block.
                static void dif16(Float64 in_out[], size_t float_len){
                    assert(float_len >= 32);
                    for (size_t offset = 0; offset < float_len; offset += 32){
                        dif16(in_out + offset);
                    }
                }
                /// Apply the 16-point inverse kernel to each consecutive 32-double block.
                static void idit16(Float64 in_out[], size_t float_len){
                    assert(float_len >= 32);
                    for (size_t offset = 0; offset < float_len; offset += 32){
                        idit16(in_out + offset);
                    }
                }
                /// Apply the 32-point forward kernel to each consecutive 64-double block.
                static void dif32(Float64 in_out[], size_t float_len){
                    assert(float_len >= 64);
                    for (size_t offset = 0; offset < float_len; offset += 64){
                        dif32(in_out + offset);
                    }
                }
                /// Apply the 32-point inverse kernel to each consecutive 64-double block.
                static void idit32(Float64 in_out[], size_t float_len){
                    assert(float_len >= 64);
                    for (size_t offset = 0; offset < float_len; offset += 64){
                        idit32(in_out + offset);
                    }
                }
                /// In-place forward FFT of fft_len = float_len/2 complex points stored in RRII
                /// blocks of 8 doubles, for 16 <= fft_len <= SHORT_LEN. Presumably fft_len must
                /// also be a power of two — confirm at the call sites.
                ///
                /// Radix-4 DIF passes run at rank = fft_len, fft_len/4, ... while rank >= 64.
                /// Each pass twiddles the three non-trivial quarters with multi_table_2 (c1),
                /// multi_table_1 (c2) and multi_table_3 (c3). The remaining 16- or 32-point
                /// sub-blocks are finished by dif16 (even log2(fft_len)) or dif32 (odd).
                /// NOTE(review): the output is in the permuted (bit-reversed) order produced by
                /// DIF butterflies, which iditIter consumes — inferred from the structure,
                /// confirm before reading the spectrum directly.
                static void difIter(Float64 in_out[], size_t float_len){
                    size_t fft_len = float_len / 2;
                    assert(fft_len <= SHORT_LEN);
                    for (size_t rank = fft_len; rank >= 64; rank /= 4){
                        // Quarter offsets in doubles: a quarter is rank/4 complex = rank/2 doubles.
                        const size_t stride1 = rank / 2, stride2 = stride1 * 2, stride3 = stride1 * 3;
                        for (auto begin = in_out, end = in_out + float_len; begin < end; begin += rank * 2){
                            // Twiddle pointers restart for every block of `rank` points.
                            auto table1 = multi_table_1.getBegin(rank), table2 = multi_table_2.getBegin(rank), table3 = multi_table_3.getBegin(rank);
                            auto it0 = begin, it1 = begin + stride1, it2 = begin + stride2, it3 = begin + stride3;
                            for (; it0 < begin + stride1; it0 += 8, it1 += 8, it2 += 8, it3 += 8, table1 += 8, table2 += 8, table3 += 8){
                                Complex64X4 c0, c1, c2, c3, omega;
                                c0.load(it0), c1.load(it1), c2.load(it2), c3.load(it3);
                                dif4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                                omega.load(table2), c1 = c1.mul(omega);
                                omega.load(table1), c2 = c2.mul(omega);
                                omega.load(table3), c3 = c3.mul(omega);
                                c0.store(it0), c1.store(it1), c2.store(it2), c3.store(it3);}}}
                    // The passes stop at rank 16 (even log2) or 32 (odd log2).
                    if (hint_log2(fft_len) % 2 == 0){
                        dif16(in_out, float_len);}
                    else{
                        dif32(in_out, float_len);}}
                /// Inverse counterpart of difIter on fft_len = float_len/2 complex points
                /// (16 <= fft_len <= SHORT_LEN). Steps:
                /// 1. The 16- or 32-point base kernels.
                /// 2. Radix-4 inverse passes at rank = 64 or 128, multiplied by 4 up to fft_len.
                ///    Each pass applies conjugate twiddles before idit4.
                /// The result is unnormalized.
                static void iditIter(Float64 in_out[], size_t float_len){
                    size_t fft_len = float_len / 2;
                    assert(fft_len <= SHORT_LEN);
                    size_t rank = 0;
                    // Mirror difIter: even log2 ends with 16-point kernels, odd with 32-point.
                    if (hint_log2(fft_len) % 2 == 0){
                        idit16(in_out, float_len);
                        rank = 64;}
                    else{
                        idit32(in_out, float_len);
                        rank = 128;}
                    for (; rank <= fft_len; rank *= 4){
                        const size_t stride1 = rank / 2, stride2 = stride1 * 2, stride3 = stride1 * 3;
                        for (auto begin = in_out, end = in_out + float_len; begin < end; begin += rank * 2){
                            auto table1 = multi_table_1.getBegin(rank), table2 = multi_table_2.getBegin(rank), table3 = multi_table_3.getBegin(rank);
                            auto it0 = begin, it1 = begin + stride1, it2 = begin + stride2, it3 = begin + stride3;
                            for (; it0 < begin + stride1; it0 += 8, it1 += 8, it2 += 8, it3 += 8, table1 += 8, table2 += 8, table3 += 8){
                                Complex64X4 c0, c1, c2, c3, omega;
                                c0.load(it0), c1.load(it1), c2.load(it2), c3.load(it3);
                                omega.load(table2), c1 = c1.mulConj(omega);
                                omega.load(table1), c2 = c2.mulConj(omega);
                                omega.load(table3), c3 = c3.mulConj(omega);
                                idit4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                                c0.store(it0), c1.store(it1), c2.store(it2), c3.store(it3);}}}}
                /// One radix-4 DIF pass at block size `rank` complex points over the whole buffer
                /// of float_len doubles.
                /// In each block, quarters start rank/2 doubles apart. dif4 combines matching
                /// elements, then the outputs are twiddled by w^n (c2), w^{2n} (c1) and w^{3n} (c3).
                /// w^n is read from multi_table_1 at this rank; the other powers are derived from it
                /// by squaring and multiplying. FROM_RIRI_PERM = true reads the input in interleaved
                /// (RIRI) order and converts it to RRII; the output is always RRII.
                /// multi_table_1 only holds ranks 2^6 .. MAX_LEN, so rank must lie in that range.
                template <bool FROM_RIRI_PERM = false>
                static void dif2LayerMid(Float64 in_out[], size_t float_len, size_t rank){
                    using FromRIRI = std::integral_constant<bool, FROM_RIRI_PERM>;
                    const size_t stride1 = rank / 2, stride2 = stride1 * 2, stride3 = stride1 * 3;
                    for (auto begin = in_out, end = in_out + float_len; begin < end; begin += rank * 2){
                        auto table1 = multi_table_1.getBegin(rank);
                        auto it0 = begin, it1 = begin + stride1, it2 = begin + stride2, it3 = begin + stride3;
                        for (; it0 < begin + stride1; it0 += 8, it1 += 8, it2 += 8, it3 += 8, table1 += 8){
                            Complex64X4 c0, c1, c2, c3, omega, omega2;
                            c0.load(it0, FromRIRI{}), c1.load(it1, FromRIRI{}), c2.load(it2, FromRIRI{}), c3.load(it3, FromRIRI{});
                            dif4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                            omega.load(table1), c2 = c2.mul(omega);
                            // Derive w^{2n} and w^{3n} on the fly instead of storing two more tables.
                            omega2 = omega.square(), c1 = c1.mul(omega2);
                            c3 = c3.mul(omega2.mul(omega));
                            c0.store(it0), c1.store(it1), c2.store(it2), c3.store(it3);
                        }}}
                /// Inverse counterpart of dif2LayerMid at block size `rank`: conjugate twiddles
                /// w^{-2n}, w^{-n}, w^{-3n} are applied first, then idit4.
                /// TO_INT64 = true rounds every result lane to the nearest integer and stores the
                /// int64 bit pattern (Complex64X4::transToI64). TO_RIRI_PERM = true stores in
                /// interleaved order. Same rank limits as dif2LayerMid (2^6 .. MAX_LEN).
                template <bool TO_RIRI_PERM = false, bool TO_INT64 = false>
                static void idit2LayerMid(Float64 in_out[], size_t float_len, size_t rank){
                    using ToRIRI = std::integral_constant<bool, TO_RIRI_PERM>;
                    using ToI64 = std::integral_constant<bool, TO_INT64>;
                    const size_t stride1 = rank / 2, stride2 = stride1 * 2, stride3 = stride1 * 3;
                    for (auto begin = in_out, end = in_out + float_len; begin < end; begin += rank * 2){
                        auto table1 = multi_table_1.getBegin(rank);
                        auto it0 = begin, it1 = begin + stride1, it2 = begin + stride2, it3 = begin + stride3;
                        for (; it0 < begin + stride1; it0 += 8, it1 += 8, it2 += 8, it3 += 8, table1 += 8){
                            Complex64X4 c0, c1, c2, c3, omega, omega2;
                            c0.load(it0), c1.load(it1), c2.load(it2), c3.load(it3);
                            omega.load(table1), c2 = c2.mulConj(omega);
                            omega2 = omega.square(), c1 = c1.mulConj(omega2);
                            c3 = c3.mulConj(omega2.mul(omega));
                            idit4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                            // Optional post-processing, selected at compile time by tag dispatch.
                            c0 = c0.transToI64(ToI64{}), c1 = c1.transToI64(ToI64{}), c2 = c2.transToI64(ToI64{}), c3 = c3.transToI64(ToI64{});
                            c0.store(it0, ToRIRI{}), c1.store(it1, ToRIRI{}), c2.store(it2, ToRIRI{}), c3.store(it3, ToRIRI{});
                        }}}
                /// Forward FFT for mid-size lengths, fft_len = float_len/2 complex points.
                /// Radix-4 passes (dif2LayerMid) run from rank = fft_len downward. The first pass
                /// always runs and is the only one honouring FROM_RIRI_PERM; later passes run
                /// while the block size exceeds SHORT_LEN. difIter then finishes every remaining
                /// block.
                template <bool FROM_RIRI_PERM = false>
                static void difMid(Float64 in_out[], size_t float_len){
                    size_t rank = float_len / 2;
                    dif2LayerMid<FROM_RIRI_PERM>(in_out, float_len, rank);
                    for (rank /= 4; rank > SHORT_LEN; rank /= 4){
                        dif2LayerMid(in_out, float_len, rank);
                    }
                    const size_t block_floats = rank * 2;
                    for (size_t offset = 0; offset < float_len; offset += block_floats){
                        difIter(in_out + offset, block_floats);
                    }
                }
                /// Inverse transform (decimation in time) for fft_len = float_len / 2 up to MID_LEN.
                /// This is the exact mirror of difMid. Leaf sub-blocks of `rank` complex values
                /// are inverted with iditIter. Fused two-layer passes then climb back up by
                /// factors of 4 until the whole fft_len is covered.
                /// TO_RIRI_PERM / TO_INT64 are applied only by the final (outermost) pass.
                /// @param in_out    float_len Float64s, transformed in place.
                /// @param float_len 2 * fft_len. fft_len is a power of two >= 4; difMid has the same requirement.
                template <bool TO_RIRI_PERM = false, bool TO_INT64 = false>
                static void iditMid(Float64 in_out[], size_t float_len){
                    const size_t fft_len = float_len / 2;
                    assert(fft_len >= 4);
                    // Descend exactly like difMid: one unconditional /4 (its outermost pass
                    // always runs), then /4 while above SHORT_LEN. The leaf size and the pass
                    // sequence therefore match the forward transform for every length.
                    // A parity-based start clipped to fft_len disagrees with difMid when
                    // fft_len <= SHORT_LEN. The final pass would then run with rank = 4 * fft_len,
                    // writing far outside the buffer.
                    size_t rank = fft_len / 4;
                    while (rank > SHORT_LEN){
                        rank /= 4;}
                    for (auto it = in_out, end = in_out + float_len; it < end; it += rank * 2){
                        iditIter(it, rank * 2);}
                    rank *= 4;
                    for (; rank < fft_len; rank *= 4){
                        idit2LayerMid(in_out, float_len, rank);}
                    // Here rank == fft_len: the last pass spans the whole buffer.
                    idit2LayerMid<TO_RIRI_PERM, TO_INT64>(in_out, float_len, rank);}

                /// Forward transform for lengths above MID_LEN. Fused two-layer passes divide
                /// the active sub-length by 4 until it is at most MID_LEN. Each resulting
                /// sub-block is then handed to difMid.
                /// FROM_RIRI_PERM only affects the first (outermost) pass.
                template <bool FROM_RIRI_PERM = false>
                static void difLarge(Float64 in_out[], size_t float_len)
                {
                    const size_t fft_len = float_len / 2;
                    dif2LayerMid<FROM_RIRI_PERM>(in_out, float_len, fft_len);
                    size_t sub_len = fft_len / 4;
                    while (sub_len > MID_LEN)
                    {
                        dif2LayerMid(in_out, float_len, sub_len);
                        sub_len /= 4;
                    }
                    const size_t sub_floats = sub_len * 2;
                    Float64 *const stop = in_out + float_len;
                    for (Float64 *block = in_out; block < stop; block += sub_floats)
                    {
                        difMid(block, sub_floats);
                    }
                }
                /// Inverse transform for lengths above MID_LEN; the exact mirror of difLarge.
                /// Sub-blocks of `rank` complex values (the size difLarge hands to difMid) are
                /// inverted with iditMid. Fused two-layer passes then climb by factors of 4.
                /// TO_RIRI_PERM / TO_INT64 are applied only by the final (outermost) pass.
                /// @param in_out    float_len Float64s, transformed in place.
                /// @param float_len 2 * fft_len. fft_len is a power of two >= 4.
                template <bool TO_RIRI_PERM = false, bool TO_INT64 = false>
                static void iditLarge(Float64 in_out[], size_t float_len){
                    const size_t fft_len = float_len / 2;
                    assert(fft_len >= 4);
                    // Same descent as difLarge: one unconditional /4, then /4 while above MID_LEN.
                    // A parity-based start clipped to fft_len disagrees with difLarge when
                    // fft_len <= MID_LEN, and its last pass would use rank = 4 * fft_len (out of bounds).
                    size_t rank = fft_len / 4;
                    while (rank > MID_LEN){
                        rank /= 4;}
                    for (auto it = in_out, end = in_out + float_len; it < end; it += rank * 2){
                        iditMid(it, rank * 2);}
                    rank *= 4;
                    for (; rank < fft_len; rank *= 4){
                        idit2LayerMid(in_out, float_len, rank);}
                    // Here rank == fft_len: the last pass spans the whole buffer.
                    idit2LayerMid<TO_RIRI_PERM, TO_INT64>(in_out, float_len, rank);}
                /// Recursive radix-4 forward transform (decimation in frequency) for
                /// fft_len = float_len / 2 complex values. It is used above MID_LEN; at or below
                /// MID_LEN the iterative difMid takes over.
                /// One radix-4 layer (dif4 followed by twiddles from multi_table_1) mixes the
                /// four quarters of the buffer, then each quarter is transformed recursively.
                /// FROM_RIRI_PERM: the input is read in the RIRI-permuted layout. Only this
                /// outermost layer (or difMid for small inputs) reads it; results are stored in
                /// the normal layout, so the recursive calls use the default (false).
                template <bool FROM_RIRI_PERM = false>
                static void difRecS(Float64 in_out[], size_t float_len)
                {
                    using FromRIRI = std::integral_constant<bool, FROM_RIRI_PERM>;
                    const size_t fft_len = float_len / 2;
                    assert(fft_len <= MAX_LEN);
                    if (fft_len <= MID_LEN)
                    {
                        difMid<FROM_RIRI_PERM>(in_out, float_len);
                        return;
                    }
                    const size_t stride1 = float_len / 4, stride2 = stride1 * 2, stride3 = stride1 * 3;
                    auto table1 = multi_table_1.getBegin(fft_len);
                    for (auto end = in_out + stride1, it = in_out; it < end; it += 8, table1 += 8)
                    {
                        Complex64X4 c0, c1, c2, c3, omega, omega2;
                        // Each quarter is loaded exactly once per step (8 Float64s = one Complex64X4).
                        c0.load(it, FromRIRI{}), c1.load(it + stride1, FromRIRI{});
                        c2.load(it + stride2, FromRIRI{}), c3.load(it + stride3, FromRIRI{});
                        dif4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                        // Post-butterfly twiddles: quarter 2 by omega, quarter 1 by omega^2, quarter 3 by omega^3.
                        omega.load(table1), c2 = c2.mul(omega);
                        omega2 = omega.square(), c1 = c1.mul(omega2);
                        c3 = c3.mul(omega2.mul(omega));
                        c0.store(it), c1.store(it + stride1), c2.store(it + stride2), c3.store(it + stride3);
                    }
                    difRecS(in_out, stride1);
                    difRecS(in_out + stride1, stride1);
                    difRecS(in_out + stride2, stride1);
                    difRecS(in_out + stride3, stride1);
                }
                /// Recursive radix-4 inverse transform (decimation in time); mirror of difRecS.
                /// At or below MID_LEN it delegates to iditMid. Otherwise it first inverts each
                /// quarter recursively, using the default layout and no conversion. One radix-4
                /// layer (conjugate twiddles, then idit4) then merges the quarters.
                /// TO_RIRI_PERM / TO_INT64 only affect the stores of this outermost merge.
                template <bool TO_RIRI_PERM = false, bool TO_INT64 = false>
                static void iditRecS(Float64 in_out[], size_t float_len)
                {
                    const size_t fft_len = float_len / 2;
                    assert(fft_len <= MAX_LEN);
                    if (fft_len <= MID_LEN)
                    {
                        iditMid<TO_RIRI_PERM, TO_INT64>(in_out, float_len);
                        return;
                    }
                    using ToRIRI = std::integral_constant<bool, TO_RIRI_PERM>;
                    using ToI64 = std::integral_constant<bool, TO_INT64>;
                    const size_t quarter = float_len / 4;
                    Float64 *const q0 = in_out;
                    Float64 *const q1 = q0 + quarter;
                    Float64 *const q2 = q1 + quarter;
                    Float64 *const q3 = q2 + quarter;
                    iditRecS(q0, quarter);
                    iditRecS(q1, quarter);
                    iditRecS(q2, quarter);
                    iditRecS(q3, quarter);
                    auto twiddle = multi_table_1.getBegin(fft_len);
                    for (size_t i = 0; i < quarter; i += 8, twiddle += 8)
                    {
                        Complex64X4 x0, x1, x2, x3, w1;
                        x0.load(q0 + i);
                        x1.load(q1 + i);
                        x2.load(q2 + i);
                        x3.load(q3 + i);
                        w1.load(twiddle);
                        Complex64X4 w2 = w1.square();
                        x2 = x2.mulConj(w1);
                        x1 = x1.mulConj(w2);
                        x3 = x3.mulConj(w2.mul(w1));
                        idit4(x0.real, x0.imag, x1.real, x1.imag, x2.real, x2.imag, x3.real, x3.imag);
                        x0 = x0.transToI64(ToI64{});
                        x1 = x1.transToI64(ToI64{});
                        x2 = x2.transToI64(ToI64{});
                        x3 = x3.transToI64(ToI64{});
                        x0.store(q0 + i, ToRIRI{});
                        x1.store(q1 + i, ToRIRI{});
                        x2.store(q2 + i, ToRIRI{});
                        x3.store(q3 + i, ToRIRI{});
                    }
                }
                };
                
            // Out-of-class definitions of FFTAVX's constexpr static data members. Before
            // C++17 these are required whenever the members are ODR-used (e.g. bound to a
            // const reference as in std::min).
            constexpr size_t FFTAVX::LOG_SHORT;
            constexpr size_t FFTAVX::LOG_MID;
            constexpr size_t FFTAVX::LOG_MAX;
            constexpr size_t FFTAVX::SHORT_LEN;
            constexpr size_t FFTAVX::MID_LEN;
            constexpr size_t FFTAVX::MAX_LEN;

            // Precomputed twiddle tables shared by all FFTAVX kernels.
            // NOTE(review): the meaning of the constructor arguments (presumably length,
            // root exponent, SIMD width) is defined by TableFix / TableFixMulti outside this
            // chunk; confirm there. multi_table_1 spans up to LOG_MAX and is the table read
            // by difRecS / iditRecS / idit2LayerMid. multi_table_2/3 only span up to LOG_SHORT.
            const TableFix<Float64, 4> FFTAVX::table_8(8, 1, 4);
            const TableFix<Float64, 4> FFTAVX::table_16_1(16, 1, 4);
            const TableFix<Float64, 4> FFTAVX::table_16_3(16, 3, 4);
            const TableFix<Float64, 8> FFTAVX::table_32_1(32, 1, 4);
            const TableFix<Float64, 8> FFTAVX::table_32_3(32, 3, 4);
            const TableFixMulti<Float64, 6, FFTAVX::LOG_SHORT, 4> FFTAVX::multi_table_3(3);
            const TableFixMulti<Float64, 6, FFTAVX::LOG_SHORT, 4> FFTAVX::multi_table_2(2);
            const TableFixMulti<Float64, 6, FFTAVX::LOG_MAX, 4> FFTAVX::multi_table_1(1);

            /// Reverses the order of all 32 bits of n (bit 0 <-> bit 31, bit 1 <-> bit 30, ...).
            constexpr uint32_t bitrev32(uint32_t n){
                // Swap adjacent bits, then 2-bit pairs, nibbles, bytes, and finally the 16-bit halves.
                n = ((n >> 1) & 0x55555555u) | ((n & 0x55555555u) << 1);
                n = ((n >> 2) & 0x33333333u) | ((n & 0x33333333u) << 2);
                n = ((n >> 4) & 0x0f0f0f0fu) | ((n & 0x0f0f0f0fu) << 4);
                n = ((n >> 8) & 0x00ff00ffu) | ((n & 0x00ff00ffu) << 8);
                return (n >> 16) | (n << 16);}
            /// Reverses the low `len` bits of n (0 <= len <= 32). Bits of n at or above position
            /// `len` are discarded. len == 0 yields 0.
            constexpr uint32_t bitrev(uint32_t n, int len){
                assert(len >= 0 && len <= 32);
                // Shifting a 32-bit value right by 32 is undefined behaviour, so len == 0 is
                // handled explicitly instead of computing bitrev32(n) >> 32.
                return len == 0 ? 0u : bitrev32(n) >> (32 - len);}

            // Incremental generator of twiddle factors, BLOCK (= 4) complex values per call,
            // used by real_dot_binrev4. It does not evaluate trig per factor. Instead it keeps
            // a small stack (`table`) of partial products: table[0] holds the four seed roots,
            // and each level above it multiplies in one entry of `units`.
            // NOTE(review): the emitted order appears to be bit-reversed (bitrev() seeding plus
            // the ctz-driven stack in iterate()); confirm against the consumer.
            class BinRevTableC64X4HP{
            public:
                using F64 = double;
                using C64 = std::complex<F64>;
                using C64X4 = hint_simd::Complex64X4;
                // MAX_LOG_LEN sizes both the `units` and `table` arrays.
                static constexpr int MAX_LOG_LEN = 32, LOG_BLOCK = 2, BLOCK = 1 << LOG_BLOCK;
                static constexpr size_t MAX_LEN = size_t(1) << MAX_LOG_LEN;

                // Every root is scaled by factor = 2^-(log_fft_len - log_max_iter); see getOmega.
                BinRevTableC64X4HP(int log_max_iter_in, int log_fft_len_in)
                    : index(0), pop(0), log_max_iter(log_max_iter_in), log_fft_len(log_fft_len_in){
                    assert(log_max_iter <= log_fft_len);
                    assert(log_fft_len <= MAX_LOG_LEN);
                    const F64 factor = F64(1) / (size_t(1) << (log_fft_len - log_max_iter));
                    // units[i] = getOmega(2^(i+1), 1, factor), i.e. angle -HINT_2PI / 2^(i+1) * factor.
                    for (int i = 0; i < MAX_LOG_LEN; i++){
                        units[i] = getOmega(size_t(1) << (i + 1), 1, factor);}
                    // Seed table[0] through a raw double view. This assumes a C64X4 stores BLOCK
                    // real parts followed by BLOCK imaginary parts, which is how fp[] is written here.
                    // Lane i holds root number bitrev(i, LOG_BLOCK) of a BLOCK-point transform.
                    auto fp = reinterpret_cast<F64 *>(table);
                    fp[0] = 1, fp[BLOCK] = 0;
                    for (int i = 1; i < BLOCK; i++){
                        C64 omega = getOmega(BLOCK, bitrev(i, LOG_BLOCK), factor);
                        fp[i] = omega.real(), fp[i + BLOCK] = omega.imag();}}

                // Repositions the generator. i == 0 restarts from table[0]. Otherwise i must be
                // a power of two and a multiple of BLOCK. The sequence then continues at
                // index = i / BLOCK with table[1] = table[0] * units[ctz(index) + 2].
                // (hint_ctz presumably counts trailing zero bits; defined outside this chunk.)
                void reset(size_t i = 0){
                    if (i == 0){
                        pop = 0, index = i;
                        return;}
                    assert((i & (i - 1)) == 0);
                    assert(i % BLOCK == 0);
                    pop = 1, index = i / BLOCK;
                    int zero = hint_ctz(index);
                    auto fp = reinterpret_cast<F64 *>(&units[zero + 2]);
                    // load1 presumably broadcasts (real, imag) of one unit to all four lanes.
                    table[1].load1(fp, fp + 1);
                    table[1] = table[1].mul(table[0]);}
                // Returns the current four factors (top of the stack), then advances. It
                // increments index, drops `zero` = ctz(index) stack levels, and pushes
                // table[pop] * units[zero + 2].
                // NOTE(review): units has MAX_LOG_LEN entries, so callers must keep
                // ctz(index) + 2 < MAX_LOG_LEN.
                C64X4 iterate(){
                    C64X4 res = table[pop], unit4;
                    index++;
                    int zero = hint_ctz(index);
                    auto fp = reinterpret_cast<F64 *>(&units[zero + 2]);
                    unit4.load1(fp, fp + 1);
                    pop -= zero;
                    table[pop + 1] = table[pop].mul(unit4);
                    pop++;
                    return res;}

                // Returns polar(1, -HINT_2PI * index / n * factor). HINT_2PI is presumably 2*pi.
                static C64 getOmega(size_t n, size_t index, F64 factor = 1){
                    F64 theta = -HINT_2PI * index / n;
                    return std::polar<F64>(1, theta * factor);}

            private:
                C64 units[MAX_LOG_LEN]{};
                C64X4 table[MAX_LOG_LEN]{};  // stack of partial products; table[pop] is the next result
                size_t index;                // position in the generated sequence (in units of BLOCK)
                int pop;                     // index of the current top of `table`
                int log_max_iter, log_fft_len;};  // only read by the constructor in this class

            /// Scalar pointwise-product step for a real-input FFT convolution.
            /// Two mirrored bins of the first operand (inout0 / inout1) and of the second
            /// operand (in0 / in1) are split with twiddle omega0. The split values are
            /// multiplied (scaled by `factor`) and recombined, and the result is written back
            /// to inout0 / inout1.
            /// Each bin's real part is at p[0] and its imaginary part at p[RI_DIFF].
            template <size_t RI_DIFF = 1, typename FloatTy>
            inline void dot_rfft(FloatTy *inout0, FloatTy *inout1, const FloatTy *in0, const FloatTy *in1,
                                 const std::complex<FloatTy> &omega0, const FloatTy factor = 1){
                using Complex = std::complex<FloatTy>;
                // Twiddle applied when splitting the two mirrored bins.
                auto split_twiddle = [](Complex a, Complex w){
                    return Complex(a.imag() * w.real() + a.real() * w.imag(),
                                   a.imag() * w.imag() - a.real() * w.real());};
                // Twiddle applied when merging the product back.
                auto merge_twiddle = [](Complex a, Complex w){
                    return Complex(a.real() * w.imag() - a.imag() * w.real(),
                                   a.real() * w.real() + a.imag() * w.imag());};
                // Butterfly of lo with conj(hi), twiddle of the difference, recombination.
                auto butterfly = [&omega0](Complex lo, Complex hi, Complex &out_lo, Complex &out_hi, auto twiddle){
                    Complex hi_conj = std::conj(hi);
                    transform2(lo, hi_conj);
                    const Complex rotated = twiddle(hi_conj, omega0);
                    out_lo = lo + rotated;
                    out_hi = std::conj(lo - rotated);};
                const Complex a0(inout0[0], inout0[RI_DIFF]);
                const Complex a1(inout1[0], inout1[RI_DIFF]);
                const Complex b0(in0[0], in0[RI_DIFF]);
                const Complex b1(in1[0], in1[RI_DIFF]);
                Complex fa0, fa1, fb0, fb1;
                butterfly(a0, a1, fa0, fa1, split_twiddle);
                butterfly(b0, b1, fb0, fb1, split_twiddle);
                fa0 *= fb0 * factor;
                fa1 *= fb1 * factor;
                Complex r0, r1;
                butterfly(fa0, fa1, r0, r1, merge_twiddle);
                inout0[0] = r0.real();
                inout0[RI_DIFF] = r0.imag();
                inout1[0] = r1.real();
                inout1[RI_DIFF] = r1.imag();}
            // SIMD counterpart of dot_rfft that processes four bin pairs at once.
            // inout0 / in0 hold four consecutive bins. inout1 / in1 hold their mirror partners
            // in opposite lane order, which is why they are lane-reversed (reverse()) on load
            // and again before the store.
            // The product of the two split spectra is scaled by `inv` before recombination.
            // NOTE(review): compute2 below computes t0 = c0 + conj(c1), t1 = c0 - conj(c1).
            // That matches dot_rfft's compute2 assuming transform2 is the (a + b, a - b)
            // butterfly; confirm.
            inline void dot_rfftX4(F64 *inout0, F64 *inout1, const F64 *in0, const F64 *in1, const C64X4 &omega0, const F64X4 &inv){
                auto mul1 = [](C64X4 c0, C64X4 c1){
                    return C64X4(F64X4::fmadd(c0.imag, c1.real, c0.real * c1.imag),
                                 F64X4::fmsub(c0.imag, c1.imag, c0.real * c1.real));};
                auto mul2 = [](C64X4 c0, C64X4 c1){
                    return C64X4(F64X4::fmsub(c0.real, c1.imag, c0.imag * c1.real),
                                 F64X4::fmadd(c0.real, c1.real, c0.imag * c1.imag));};
                // c0 / c1 are taken by value, so out0 / out1 may alias the caller's inputs (used below).
                // out1 = conj(t0 - t1), written component-wise.
                auto compute2 = [&omega0](C64X4 c0, C64X4 c1, C64X4 &out0, C64X4 &out1, auto Func){
                    C64X4 t0(c0.real + c1.real, c0.imag - c1.imag), t1(c0.real - c1.real, c0.imag + c1.imag);
                    t1 = Func(t1, omega0);
                    out0 = t0 + t1;
                    out1.real = t0.real - t1.real;
                    out1.imag = t1.imag - t0.imag;};
                C64X4 c0, c1;{
                    C64X4 x0, x1, x2, x3;
                    // Split the first operand's bin pairs.
                    c0.load(inout0), c1.load(inout1);
                    compute2(c0, c1.reverse(), x0, x1, mul1);

                    // Split the second operand's bin pairs.
                    c0.load(in0), c1.load(in1);
                    compute2(c0, c1.reverse(), x2, x3, mul1);
                    // Pointwise product, scaled.
                    c0 = x0.mul(x2) * inv;
                    c1 = x1.mul(x3) * inv;
                    // Recombine in place (aliasing is safe, see compute2).
                    compute2(c0, c1, c0, c1, mul2);}
                c0.store(inout0), c1.reverse().store(inout1);}

            // Frequency-domain pointwise product of two real-input spectra, as produced by
            // FFTAVX::difRecS<true> in real_conv_avxS. The result is written into in_out and
            // scaled by 2 / float_len so that the following inverse transform yields the
            // convolution.
            // Layout (as the indexing below implies): groups of 8 Float64s with 4 real parts
            // followed by 4 imaginary parts, i.e. element k of a group has its real part at
            // [k] and its imaginary part at [k + 4].
            // The first 16 Float64s are special-cased with scalar code. Every later power-of-two
            // segment [begin, 2 * begin) pairs groups from its front and back, moving inward,
            // with twiddles from BinRevTableC64X4HP.
            // Requires float_len >= 16 (indices up to 11 are touched unconditionally).
            // NOTE(review): the function-local static `table` is mutated on every call, so this
            // function is not reentrant / thread-safe.
            inline void real_dot_binrev4(Float64 in_out[], Float64 in[], size_t float_len){
                using Complex = std::complex<Float64>;
                Float64 inv = 2.0 / float_len;{
                    // Bins 0 and 4 hold two purely real values each; they are separated with
                    // transform2, multiplied, and recombined (the extra 0.5 undoes the butterfly gain).
                    auto r0 = in_out[0], i0 = in_out[4], r1 = in[0], i1 = in[4];
                    transform2(r0, i0);
                    transform2(r1, i1);
                    r0 *= r1, i0 *= i1;
                    transform2(r0, i0);
                    in_out[0] = r0 * 0.5 * inv, in_out[4] = i0 * 0.5 * inv;}
                // Element 1 (real at [1], imaginary at [5]) is self-paired: a plain complex product.
                auto temp = Complex(in_out[1], in_out[5]) * Complex(in[1], in[5]) * inv;
                in_out[1] = temp.real(), in_out[5] = temp.imag();
                // NOTE(review): extra 1/8 scale for the dot_rfft / dot_rfftX4 paths, presumably
                // offsetting their unnormalized butterflies; confirm.
                inv /= 8;
                static BinRevTableC64X4HP table(31, 32);
                // NOTE(review): Complex(COS_PI_8, -COS_PI_8) is a unit root only if COS_PI_8 == sqrt(2)/2;
                // its definition is outside this chunk.
                dot_rfft<4>(&in_out[2], &in_out[3], &in[2], &in[3], Complex(COS_PI_8, -COS_PI_8), inv);
                // cos(pi/8) and sin(pi/8).
                constexpr Float64 COS_16_1 = 0.92387953251128675612818318939;
                constexpr Float64 SIN_16_1 = 0.38268343236508977172845998403;
                dot_rfft<4>(&in_out[8], &in_out[11], &in[8], &in[11], C64(COS_16_1, -SIN_16_1), inv);
                dot_rfft<4>(&in_out[9], &in_out[10], &in[9], &in[10], C64(-SIN_16_1, -COS_16_1), inv);
                const Float64X4 inv4 = F64X4(inv);
                for (size_t begin = 16; begin < float_len; begin *= 2){
                    table.reset(begin / 2);
                    // it0 / it2 walk forward from the start of the segment; it1 / it3 walk backward from its last group.
                    auto it0 = in_out + begin, it1 = it0 + begin - 8, it2 = in + begin, it3 = it2 + begin - 8;
                    for (; it0 < it1; it0 += 8, it1 -= 8, it2 += 8, it3 -= 8){
                        auto omega = table.iterate();
                        dot_rfftX4(it0, it1, it2, it3, omega, inv4);}}}

            /// Cyclic convolution of two real sequences of float_len values (a power of two).
            /// Both inputs are transformed in place. The pointwise product replaces in_out1's
            /// spectrum, and the inverse transform leaves the convolution in in_out1.
            /// in2 is clobbered (it holds its spectrum afterwards).
            /// TO_INT: the inverse transform also converts the results to 64-bit integers.
            template <bool TO_INT = false>
            inline void real_conv_avxS(F64 *in_out1, F64 *in2, size_t float_len)
            {
                assert(is_2pow(float_len));
                // Forward transforms; both operands are read in the RIRI-permuted layout.
                F64 *const operands[] = {in_out1, in2};
                for (F64 *operand : operands)
                {
                    FFTAVX::difRecS<true>(operand, float_len);
                }
                real_dot_binrev4(in_out1, in2, float_len);
                FFTAVX::iditRecS<true, TO_INT>(in_out1, float_len);
            }
        } // closes the enclosing namespace opened before this chunk (presumably fft)
    }
}

namespace string_util{
    using namespace hint;
    using namespace transform;
    using namespace fft;
    /// Lookup table that turns a base-10000 limb (0..9999) into its four zero-padded
    /// ASCII digits packed into one uint32_t. The thousands digit sits in the lowest
    /// byte, so on a little-endian host copying the word to memory yields the digits
    /// in reading order.
    class ItoStrBase10000{
    public:
        /// Packs the last four decimal digits of num (i.e. num % 10000) as ASCII:
        /// byte 0 = thousands, byte 1 = hundreds, byte 2 = tens, byte 3 = units.
        static constexpr uint32_t itosbase10000(uint32_t num){
            const uint32_t thousands = num / 1000 % 10;
            const uint32_t hundreds = num / 100 % 10;
            const uint32_t tens = num / 10 % 10;
            const uint32_t units = num % 10;
            const uint32_t packed = thousands | (hundreds << 8) | (tens << 16) | (units << 24);
            // Each byte holds 0..9, so adding '0' to every byte cannot carry.
            return packed + '0' * 0x1010101;
        }

        /// Fills the whole table. It is constexpr so instances can be built at compile time.
        constexpr ItoStrBase10000(){
            for (uint32_t value = 0; value < 10000; ++value){
                table[value] = itosbase10000(value);
            }
        }

        /// Copies the four digit characters of num (< 10000) to str; no terminator is written.
        void tostr(char *str, uint32_t num) const{
            const uint32_t packed = tostr(num);
            std::memcpy(str, &packed, sizeof(packed));
        }

        /// Returns the packed four-character encoding of num (< 10000).
        uint32_t tostr(uint32_t num) const{
            return table[num];
        }

    private:
        uint32_t table[10000]{};
    };

    /// Reverse lookup for two ASCII decimal digits. toInt reads two consecutive chars as
    /// one uint16_t (host byte order, via memcpy) and maps that key to 0..99. Any key that
    /// is not two digits maps to UINT16_MAX.
    /// Assumes a little-endian host: itosbase100 puts the first (tens) char in the low byte.
    class StrtoIBase100{
    private:
        // Every valid key fits in 15 bits (the largest is "99" -> 0x3939), so 1 << 15
        // entries suffice. Keys at or above TABLE_SIZE (second char >= 0x80) are rejected
        // in toInt instead of being read out of bounds.
        static constexpr size_t TABLE_SIZE = size_t(1) << 15;
        uint16_t table[TABLE_SIZE]{};

    public:
        /// Returns the two-char key that encodes num (0..99): tens digit in the low byte.
        static constexpr uint16_t itosbase100(uint16_t num){
            uint16_t res = (num / 10 % 10) | ((num % 10) << 8);
            return res + '0' * 0x0101;}
        constexpr StrtoIBase100(){
            for (size_t i = 0; i < TABLE_SIZE; i++){
                table[i] = UINT16_MAX;}
            for (size_t i = 0; i < 100; i++){
                table[itosbase100(i)] = i;}}
        /// Parses str[0..1] as two decimal digits; returns UINT16_MAX if they are not digits.
        uint16_t toInt(const char *str) const{
            uint16_t num;
            std::memcpy(&num, str, sizeof(num));
            return num < TABLE_SIZE ? table[num] : UINT16_MAX;}
    };

    // Shared lookup tables, built at compile time (both constructors are constexpr).
    // itosbase10000: limb 0..9999 -> four packed ASCII digits.
    constexpr ItoStrBase10000 itosbase10000{};
    // strtoibase100: two ASCII digits -> 0..99 (UINT16_MAX if the pair is not two digits).
    constexpr StrtoIBase100 strtoibase100{};

    /// Parses the four ASCII digits s[0..3] as one base-10000 limb.
    /// Non-digit input is not rejected here: each bad pair contributes UINT16_MAX from toInt.
    inline uint32_t stobase10000(const char *s){
        const uint32_t high_pair = strtoibase100.toInt(s);
        const uint32_t low_pair = strtoibase100.toInt(s + 2);
        return high_pair * 100 + low_pair;
    }

    /// Owning heap buffer of `n` elements of T aligned to ALIGN bytes (64 by default),
    /// allocated with _mm_malloc and released with _mm_free. Elements are not initialized.
    /// Move-only: the implicit copy would leave two owners freeing the same pointer.
    template <typename T, size_t ALIGN = 64>
    class AlignMem{
    public:
        using Ptr = T *;
        using ConstPtr = const T *;
        AlignMem() : ptr(nullptr), len(0) {}
        AlignMem(size_t n) : ptr(reinterpret_cast<Ptr>(_mm_malloc(n * sizeof(T), ALIGN))), len(n) {}
        AlignMem(const AlignMem &) = delete;
        AlignMem &operator=(const AlignMem &) = delete;
        AlignMem(AlignMem &&other) noexcept : ptr(other.ptr), len(other.len){
            other.ptr = nullptr;
            other.len = 0;}
        AlignMem &operator=(AlignMem &&other) noexcept{
            if (this != &other){
                release();
                ptr = other.ptr;
                len = other.len;
                other.ptr = nullptr;
                other.len = 0;}
            return *this;}
        ~AlignMem(){
            release();}
        T &operator[](size_t i){
            return ptr[i];}
        const T &operator[](size_t i) const{
            return ptr[i];}
        Ptr begin(){
            return ptr;}
        Ptr end(){
            return ptr + len;}
        ConstPtr begin() const{
            return ptr;}
        ConstPtr end() const{
            return ptr + len;}
        size_t size() const{
            return len;
        }
    private:
        // Frees the buffer (if any) and leaves the object empty.
        void release() noexcept{
            if (ptr){
                _mm_free(ptr);}
            ptr = nullptr;
            len = 0;}
        T *ptr;
        size_t len;
    };

    /// Sets every byte of [begin, end) to zero. Intended for trivially copyable T
    /// (e.g. Float64, integers), where all-zero bytes represent zero.
    /// An empty or inverted range is a no-op: memset with a null pointer is undefined
    /// even for length 0, and end < begin would otherwise become a huge size.
    template <typename T>
    void fill_zero(T *begin, T *end){
        if (end <= begin){
            return;}
        std::memset(begin, 0, static_cast<size_t>(end - begin) * sizeof(T));
    }

    /// Converts the decimal digits str[0..len) into base-10000 limbs, most significant
    /// first: ary[k] holds digits 4k..4k+3.
    /// A trailing partial group is right-padded with zeros, so the stored value equals the
    /// input times 10^shift.
    /// @return shift, the number of padding zeros appended (0..3).
    template <typename T>
    size_t str_num_to_array_base10000(const char *str, size_t len, T *ary){
        constexpr size_t BLOCK = 4;
        const size_t full_blocks = len / BLOCK;
        const size_t rem = len % BLOCK;
        for (size_t b = 0; b < full_blocks; b++){
            ary[b] = stobase10000(str + b * BLOCK);
        }
        if (rem == 0){
            return 0;
        }
        // Last, partial group: real digits first, then zeros up to a full limb.
        const char *digits = str + full_blocks * BLOCK;
        int tail = 0;
        for (size_t k = 0; k < BLOCK; k++){
            tail = tail * 10 + (k < rem ? digits[k] - '0' : 0);
        }
        ary[full_blocks] = tail;
        return BLOCK - rem;
    }

    /// Renders convolution output ary[0..conv_len) as decimal text. The limbs are base 10000,
    /// most significant first, and may be >= 10000 before carrying.
    /// Writes (conv_len + 1) * 4 chars into res, with one extra limb for the final carry.
    /// It then skips leading zeros and excludes the trailing `shift` padding digits that
    /// str_num_to_array_base10000 introduced.
    /// @param res_len [out] number of significant characters; 0 when the value is zero.
    /// @return offset in res of the first significant character.
    template <typename T>
    size_t conv_to_str_base10000(const T *ary, size_t conv_len, size_t shift, char *res, size_t &res_len){
        constexpr size_t BLOCK = 4, BASE = 10000;
        const size_t total_len = (conv_len + 1) * BLOCK;
        char *const text_end = res + total_len;
        char *pos = text_end;
        uint64_t carry = 0;
        for (size_t i = conv_len; i > 0; i--){
            pos -= BLOCK;
            carry += ary[i - 1];
            itosbase10000.tostr(pos, carry % BASE);
            carry /= BASE;}
        assert(carry < BASE);
        pos -= BLOCK;
        itosbase10000.tostr(pos, carry);
        // The last `shift` digits only undo the parse-time zero padding, so the search for
        // the first significant digit stops before them. An unbounded scan would run past
        // the buffer for a zero product and then underflow res_len (offset + shift > length).
        const char *const digits_end = text_end - (shift < total_len ? shift : total_len);
        while (pos < digits_end && *pos == '0'){
            pos++;}
        res_len = static_cast<size_t>(digits_end - pos);
        return static_cast<size_t>(pos - res);
    }
    // Fixed-capacity work buffers of 2^19 Float64 elements each. big_mul uses them as the
    // two convolution operands, so its transform length must not exceed this capacity.
    // mul also stages the input text, and then the decimal output, in the upper half of ary2.
    static size_t fft_len = size_t(1)<<19;
    static AlignMem<Float64> ary1(fft_len), ary2(fft_len);
     
    /// Multiplies two non-negative decimal integers given as digit strings, using a
    /// real-input FFT convolution over base-10000 limbs in the shared buffers ary1 / ary2.
    /// @param res     output buffer; must hold preserve_strlen(len1, len2) chars.
    /// @param res_len [out] number of significant chars of the product; 0 means zero.
    /// @return pointer into res to the first significant digit.
    char *big_mul(const char *str1, size_t len1, const char *str2, size_t len2, char *res, size_t &res_len){
        constexpr size_t BLOCK = 4;
        if (len1 == 0 || len2 == 0){
            // An empty operand would underflow conv_len below; treat it as the value zero.
            res_len = 0;
            return res;}
        const size_t block_len1 = (len1 + BLOCK - 1) / BLOCK, block_len2 = (len2 + BLOCK - 1) / BLOCK;
        const size_t conv_len = block_len1 + block_len2 - 1;
        // Power-of-two transform length, never below 256 Float64s.
        const size_t fft_len = std::max<size_t>(hint::int_ceil2(conv_len), 256);
        // ary1 / ary2 have a fixed capacity; a longer transform would overrun them.
        assert(fft_len <= ary1.size() && fft_len <= ary2.size());
        size_t shift = str_num_to_array_base10000(str1, len1, ary1.begin());
        shift += str_num_to_array_base10000(str2, len2, ary2.begin());
        fill_zero(ary1.begin() + block_len1, ary1.end());
        fill_zero(ary2.begin() + block_len2, ary2.end());
        real_conv_avxS<true>(ary1.begin(), ary2.begin(), fft_len);
        // With TO_INT = true the inverse transform leaves 64-bit integer limbs in ary1's storage.
        const uint64_t *coeffs = reinterpret_cast<const uint64_t *>(ary1.begin());
        return res + conv_to_str_base10000(coeffs, conv_len, shift, res, res_len);
    }

    /// Number of chars big_mul writes into its output buffer for operands of len1 and
    /// len2 digits: one 4-char limb per input limb, i.e. (conv_len + 1) * 4.
    size_t preserve_strlen(size_t len1, size_t len2){
        constexpr size_t BLOCK = 4;
        auto limbs = [](size_t digits){ return (digits + BLOCK - 1) / BLOCK; };
        const size_t total_limbs = limbs(len1) + limbs(len2);
        return total_limbs * BLOCK;
    }

    /// Length of the leading run of ASCII decimal digits ('0'..'9') in str.
    /// str must contain a non-digit (e.g. '\0' or whitespace) after the run.
    /// Characters above '9' (':' through '~') are no longer counted as digits, so they
    /// can never reach the digit-pair parser.
    size_t digit_strlen(const char *str){
        const char *p = str;
        while (*p >= '0' && *p <= '9'){
            p++;}
        return static_cast<size_t>(p - str);
    }

    void mul(){
        constexpr size_t STR_LEN = 2000008;
        auto str = reinterpret_cast<char*>(&ary2[ary2.size()/2]);
        // static char str[STR_LEN] = "0 10";
        fread(str, 1, STR_LEN, stdin);
        char *s1 = str, *s2;
        size_t len1 = digit_strlen(str);
        s2 = s1 + len1;
        while (*s2 < '0'){
            s2++;}
        size_t len2 = digit_strlen(s2);
        size_t res_len = preserve_strlen(len1, len2);
        auto begin = big_mul(s1, len1, s2, len2, str, res_len);
        auto end = begin + res_len;
        if(res_len == 0){
            puts("0");}
        auto buf = reinterpret_cast<char*>(ary1.begin());
        // setvbuf(stdout, buf, _IOFBF, 1 << 14);
        fwrite(begin, 1, res_len, stdout);
        // fflush(stdout);
    }
}

// Entry point: multiply the two numbers read from stdin and print the product.
int main()
{
    string_util::mul();
    return 0;
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 8.305 ms | 11 MB + 984 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2025-10-26 18:17:26 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠