提交记录 28312


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1004. 【模板题】高精度乘法 Compile Error 0 0 ns 0 KB C++14 62.55 KB
提交时间 评测时间
2025-06-23 10:18:37 2025-06-23 10:18:39
// TSKY 2025/6/20
#include <vector>
#include <array>
#include <complex>
#include <iostream>
#include <chrono>
#include <string>
#include <bitset>
#include <type_traits>
#include <cstdint>
#include <cfloat>
#include <cmath>
#include <ctime>
#include <cstring>
#include <cassert>

#include <iostream>
#include <complex>
#include <type_traits>
#include <cstdint>
#include <immintrin.h>
#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP

#pragma GCC target("avx")
#pragma GCC target("fma")
#pragma GCC target("avx2")

namespace hint_simd
{
    /// Fixed-size array of LEN elements of type T whose storage is declared
    /// with 4096-byte alignment.
    template <typename T, size_t LEN>
    class AlignAry
    {
    public:
        /// Leaves the elements default-initialised.
        constexpr AlignAry() {}

        /// Unchecked element access.
        constexpr T &operator[](size_t index) { return ary[index]; }
        constexpr const T &operator[](size_t index) const { return ary[index]; }

        /// Pointer to the first element.
        T *data() { return ary; }
        const T *data() const { return ary; }

        /// Pointer to the same storage viewed as Ty; the bytes are
        /// reinterpreted, not converted.
        template <typename Ty>
        Ty *cast_ptr() { return reinterpret_cast<Ty *>(ary); }
        template <typename Ty>
        const Ty *cast_ptr() const { return reinterpret_cast<const Ty *>(ary); }

    private:
        alignas(4096) T ary[LEN];
    };
    /// Transposes the 2x4 matrix of 64-bit lanes held in (row0, row1).
    /// With row0 = {0,1,2,3} and row1 = {4,5,6,7} on entry, the result is
    /// row0 = {0,4,1,5} and row1 = {2,6,3,7}.
    template <typename YMM>
    inline void transpose64_2X4(YMM &row0, YMM &row1)
    {
        const __m256d upper = __m256d(row0);
        const __m256d lower = __m256d(row1);
        const __m256d even_pairs = _mm256_unpacklo_pd(upper, lower); // 0,4,2,6
        const __m256d odd_pairs = _mm256_unpackhi_pd(upper, lower);  // 1,5,3,7

        // 0x20 joins the two low 128-bit halves, 0x31 the two high halves.
        row0 = YMM(_mm256_permute2f128_pd(even_pairs, odd_pairs, 0x20)); // 0,4,1,5
        row1 = YMM(_mm256_permute2f128_pd(even_pairs, odd_pairs, 0x31)); // 2,6,3,7
    }
    /// Inverse of transpose64_2X4: with row0 = {0,1,2,3} and row1 = {4,5,6,7}
    /// on entry, the result is row0 = {0,4,2,6} and row1 = {1,5,3,7}.
    template <typename YMM>
    inline void transpose64_4X2(YMM &row0, YMM &row1)
    {
        const __m256d upper = __m256d(row0);
        const __m256d lower = __m256d(row1);
        const __m256d low_halves = _mm256_permute2f128_pd(upper, lower, 0x20);  // 0,1,4,5
        const __m256d high_halves = _mm256_permute2f128_pd(upper, lower, 0x31); // 2,3,6,7

        row0 = YMM(_mm256_unpacklo_pd(low_halves, high_halves)); // 0,4,2,6
        row1 = YMM(_mm256_unpackhi_pd(low_halves, high_halves)); // 1,5,3,7
    }

    /// Transposes an 8x4 matrix of 64-bit lanes, one matrix row per register.
    ///
    /// On entry rowK holds elements {4K, 4K+1, 4K+2, 4K+3}. On exit the
    /// register pairs (row0,row1), (row2,row3), (row4,row5), (row6,row7) hold
    /// the four 8-element rows of the transposed matrix:
    ///   row0 = {0,4,8,12}   row1 = {16,20,24,28}
    ///   row2 = {1,5,9,13}   row3 = {17,21,25,29}
    ///   row4 = {2,6,10,14}  row5 = {18,22,26,30}
    ///   row6 = {3,7,11,15}  row7 = {19,23,27,31}
    ///
    /// Results are converted back with YMM(...), exactly like the other
    /// transpose64_* helpers, so the template works for every register type
    /// they accept rather than only for types implicitly constructible from
    /// __m256d.
    template <typename YMM>
    inline void transpose64_8X4(YMM &row0, YMM &row1, YMM &row2, YMM &row3,
                                YMM &row4, YMM &row5, YMM &row6, YMM &row7)
    {
        auto t0 = _mm256_unpacklo_pd(__m256d(row0), __m256d(row1)); // 0,1,2,3 4,5,6,7 -> 0,4,2,6
        auto t1 = _mm256_unpackhi_pd(__m256d(row0), __m256d(row1)); // 0,1,2,3 4,5,6,7 -> 1,5,3,7
        auto t2 = _mm256_unpacklo_pd(__m256d(row2), __m256d(row3)); // 8,9,10,11 12,13,14,15 -> 8,12,10,14
        auto t3 = _mm256_unpackhi_pd(__m256d(row2), __m256d(row3)); // 8,9,10,11 12,13,14,15 -> 9,13,11,15
        auto t4 = _mm256_unpacklo_pd(__m256d(row4), __m256d(row5)); // 16,17,18,19 20,21,22,23 -> 16,20,18,22
        auto t5 = _mm256_unpackhi_pd(__m256d(row4), __m256d(row5)); // 16,17,18,19 20,21,22,23 -> 17,21,19,23
        auto t6 = _mm256_unpacklo_pd(__m256d(row6), __m256d(row7)); // 24,25,26,27 28,29,30,31 -> 24,28,26,30
        auto t7 = _mm256_unpackhi_pd(__m256d(row6), __m256d(row7)); // 24,25,26,27 28,29,30,31 -> 25,29,27,31

        // 0x20 joins the low 128-bit halves of both operands, 0x31 the high halves.
        row0 = YMM(_mm256_permute2f128_pd(t0, t2, 0x20)); // 0,4,8,12
        row1 = YMM(_mm256_permute2f128_pd(t4, t6, 0x20)); // 16,20,24,28
        row2 = YMM(_mm256_permute2f128_pd(t1, t3, 0x20)); // 1,5,9,13
        row3 = YMM(_mm256_permute2f128_pd(t5, t7, 0x20)); // 17,21,25,29
        row4 = YMM(_mm256_permute2f128_pd(t0, t2, 0x31)); // 2,6,10,14
        row5 = YMM(_mm256_permute2f128_pd(t4, t6, 0x31)); // 18,22,26,30
        row6 = YMM(_mm256_permute2f128_pd(t1, t3, 0x31)); // 3,7,11,15
        row7 = YMM(_mm256_permute2f128_pd(t5, t7, 0x31)); // 19,23,27,31
    }

    /// Transposes a 4x8 matrix of 64-bit lanes stored as two registers per
    /// matrix row: (row0,row1) = {0..7}, (row2,row3) = {8..15},
    /// (row4,row5) = {16..23}, (row6,row7) = {24..31}.
    /// On exit rowK holds column K of that matrix: {K, K+8, K+16, K+24}.
    template <typename YMM>
    inline void transpose64_4X8(YMM &row0, YMM &row1, YMM &row2, YMM &row3,
                                YMM &row4, YMM &row5, YMM &row6, YMM &row7)
    {
        // Left half of the matrix (columns 0..3) lives in rows 0, 2, 4, 6.
        const __m256d left_even_a = _mm256_unpacklo_pd(__m256d(row0), __m256d(row2)); // 0,8,2,10
        const __m256d left_odd_a = _mm256_unpackhi_pd(__m256d(row0), __m256d(row2));  // 1,9,3,11
        const __m256d left_even_b = _mm256_unpacklo_pd(__m256d(row4), __m256d(row6)); // 16,24,18,26
        const __m256d left_odd_b = _mm256_unpackhi_pd(__m256d(row4), __m256d(row6));  // 17,25,19,27
        // Right half of the matrix (columns 4..7) lives in rows 1, 3, 5, 7.
        const __m256d right_even_a = _mm256_unpacklo_pd(__m256d(row1), __m256d(row3)); // 4,12,6,14
        const __m256d right_odd_a = _mm256_unpackhi_pd(__m256d(row1), __m256d(row3));  // 5,13,7,15
        const __m256d right_even_b = _mm256_unpacklo_pd(__m256d(row5), __m256d(row7)); // 20,28,22,30
        const __m256d right_odd_b = _mm256_unpackhi_pd(__m256d(row5), __m256d(row7));  // 21,29,23,31

        row0 = YMM(_mm256_permute2f128_pd(left_even_a, left_even_b, 0x20));   // 0,8,16,24
        row1 = YMM(_mm256_permute2f128_pd(left_odd_a, left_odd_b, 0x20));     // 1,9,17,25
        row2 = YMM(_mm256_permute2f128_pd(left_even_a, left_even_b, 0x31));   // 2,10,18,26
        row3 = YMM(_mm256_permute2f128_pd(left_odd_a, left_odd_b, 0x31));     // 3,11,19,27
        row4 = YMM(_mm256_permute2f128_pd(right_even_a, right_even_b, 0x20)); // 4,12,20,28
        row5 = YMM(_mm256_permute2f128_pd(right_odd_a, right_odd_b, 0x20));   // 5,13,21,29
        row6 = YMM(_mm256_permute2f128_pd(right_even_a, right_even_b, 0x31)); // 6,14,22,30
        row7 = YMM(_mm256_permute2f128_pd(right_odd_a, right_odd_b, 0x31));   // 7,15,23,31
    }

    /// In-place transpose of a 4x4 matrix of 64-bit lanes, one row per register.
    /// With rowK = {4K, 4K+1, 4K+2, 4K+3} on entry, rowK = {K, K+4, K+8, K+12}
    /// on exit.
    template <typename YMM>
    inline void transpose64_4X4(YMM &row0, YMM &row1, YMM &row2, YMM &row3)
    {
        const __m256d top_even = _mm256_unpacklo_pd(__m256d(row0), __m256d(row1));    // 0,4,2,6
        const __m256d top_odd = _mm256_unpackhi_pd(__m256d(row0), __m256d(row1));     // 1,5,3,7
        const __m256d bottom_even = _mm256_unpacklo_pd(__m256d(row2), __m256d(row3)); // 8,12,10,14
        const __m256d bottom_odd = _mm256_unpackhi_pd(__m256d(row2), __m256d(row3));  // 9,13,11,15

        row0 = YMM(_mm256_permute2f128_pd(top_even, bottom_even, 0x20)); // 0,4,8,12
        row1 = YMM(_mm256_permute2f128_pd(top_odd, bottom_odd, 0x20));   // 1,5,9,13
        row2 = YMM(_mm256_permute2f128_pd(top_even, bottom_even, 0x31)); // 2,6,10,14
        row3 = YMM(_mm256_permute2f128_pd(top_odd, bottom_odd, 0x31));   // 3,7,11,15
    }

    /// Thin value wrapper around one __m256d register, i.e. four packed doubles.
    class Float64X4
    {
    public:
        using F64 = double;
        using F64X4 = Float64X4;

        /// All four lanes set to zero.
        Float64X4() : data(_mm256_setzero_pd()) {}
        /// Wraps an existing register value.
        Float64X4(__m256d in_data) : data(in_data) {}
        /// Broadcasts one scalar into every lane.
        Float64X4(F64 in_data) : data(_mm256_set1_pd(in_data)) {}
        /// Loads four doubles with _mm256_load_pd (the aligned-load intrinsic).
        Float64X4(const F64 *in_data) : data(_mm256_load_pd(in_data)) {}

        // ---- lane-wise arithmetic ----
        F64X4 operator+(const F64X4 &other) const
        {
            return F64X4(_mm256_add_pd(data, other.data));
        }
        F64X4 operator-(const F64X4 &other) const
        {
            return F64X4(_mm256_sub_pd(data, other.data));
        }
        F64X4 operator*(const F64X4 &other) const
        {
            return F64X4(_mm256_mul_pd(data, other.data));
        }
        F64X4 operator/(const F64X4 &other) const
        {
            return F64X4(_mm256_div_pd(data, other.data));
        }

        F64X4 &operator+=(const F64X4 &other)
        {
            data = _mm256_add_pd(data, other.data);
            return *this;
        }
        F64X4 &operator-=(const F64X4 &other)
        {
            data = _mm256_sub_pd(data, other.data);
            return *this;
        }
        F64X4 &operator*=(const F64X4 &other)
        {
            data = _mm256_mul_pd(data, other.data);
            return *this;
        }
        F64X4 &operator/=(const F64X4 &other)
        {
            data = _mm256_div_pd(data, other.data);
            return *this;
        }

        /// Lane-wise floor.
        F64X4 floor() const
        {
            return F64X4(_mm256_floor_pd(data));
        }

        /// a * b + c; uses the fused intrinsic when __FMA__ is defined.
        static F64X4 fmadd(const F64X4 &a, const F64X4 &b, const F64X4 &c)
        {
#ifdef __FMA__
            return F64X4(_mm256_fmadd_pd(a.data, b.data, c.data));
#else
#pragma message("No FMA support")
            return a * b + c;
#endif
        }
        /// a * b - c; uses the fused intrinsic when __FMA__ is defined.
        static F64X4 fmsub(const F64X4 &a, const F64X4 &b, const F64X4 &c)
        {
#ifdef __FMA__
            return F64X4(_mm256_fmsub_pd(a.data, b.data, c.data));
#else
#pragma message("No FMA support")
            return a * b - c;
#endif
        }

        /// Lane permutation across the whole register: result lane k is the
        /// source lane selected by bits [2k+1 : 2k] of N.
#ifdef __AVX2__
        template <int N>
        F64X4 permute4x64() const
        {
            return F64X4(_mm256_permute4x64_pd(data, N));
        }
#else
        template <int N>
        F64X4 permute4x64() const
        {
            // Scalar fallback: spill to memory, pick the lanes, reload.
            alignas(32) uint64_t src[4];
            alignas(32) uint64_t dst[4];
            this->store(reinterpret_cast<F64*>(src));
            for (int lane = 0; lane < 4; ++lane)
            {
                dst[lane] = src[(N >> (2 * lane)) & 3];
            }
            return fromMem(reinterpret_cast<const F64*>(dst));
        }
#endif

        /// With in0 = {0,1,2,3} and in1 = {4,5,6,7}, returns {0,2,4,6}.
        static F64X4 extractEven64X4(const F64X4 &in0, const F64X4 &in1)
        {
            const F64X4 interleaved(_mm256_unpacklo_pd(in0.data, in1.data)); // 0,4,2,6
            return interleaved.permute4x64<0xD8>();                           // 0,2,4,6
        }

        /// Thin wrapper over _mm256_permute_pd with control N.
        template <int N>
        F64X4 permute() const
        {
            return F64X4(_mm256_permute_pd(data, N));
        }

        /// Lanes in reverse order: {a,b,c,d} -> {d,c,b,a}.
        F64X4 reverse() const
        {
            return permute4x64<0x1B>();
        }

        // ---- memory access ----
        /// Aligned load (_mm256_load_pd).
        void load(const F64 *p)
        {
            data = _mm256_load_pd(p);
        }
        /// Unaligned load (_mm256_loadu_pd).
        void loadu(const F64 *p)
        {
            data = _mm256_loadu_pd(p);
        }
        /// Broadcasts the single double at p into all lanes.
        void load1(const F64 *p)
        {
            data = _mm256_broadcast_sd(p);
        }
        /// Aligned load returning a new value.
        static F64X4 fromMem(const F64 *p)
        {
            return F64X4(_mm256_load_pd(p));
        }
        /// Unaligned load returning a new value.
        static F64X4 fromUMem(const F64 *p)
        {
            return F64X4(_mm256_loadu_pd(p));
        }
        /// Aligned store (_mm256_store_pd).
        void store(F64 *p) const
        {
            _mm256_store_pd(p, data);
        }
        /// Unaligned store (_mm256_storeu_pd).
        void storeu(F64 *p) const
        {
            _mm256_storeu_pd(p, data);
        }

        /// Implicit access to the wrapped register.
        operator __m256d() const
        {
            return data;
        }

#ifdef __AVX2__
        /// Convert positive double to int64.
        /// Works on the bit pattern: takes the 52 fraction bits, sets the
        /// implicit leading one, then shifts right by (1023 + 52 - exponent field).
        __m256i toI64X4() const
        {
            constexpr uint64_t fraction_mask = (uint64_t(1) << 52) - 1;
            constexpr uint64_t exponent_bias = (uint64_t(1) << 10) - 1;
            const __m256i bits = _mm256_castpd_si256(data);
            const __m256i fraction = _mm256_and_si256(bits, _mm256_set1_epi64x(fraction_mask));
            const __m256i mantissa = _mm256_or_si256(fraction, _mm256_set1_epi64x(fraction_mask + 1));
            const __m256i exponent = _mm256_srli_epi64(bits, 52);
            const __m256i shift = _mm256_sub_epi64(_mm256_set1_epi64x(exponent_bias + 52), exponent);
            return _mm256_srlv_epi64(mantissa, shift);
        }
#else
#pragma message("No AVX2 support")
        /// Scalar fallback: converts each lane with the built-in double -> int64_t conversion.
        __m256i toI64X4() const
        {
            alignas(32) F64 lanes[4];
            alignas(32) int64_t converted[4];
            this->store(lanes);
            for (int k = 0; k < 4; ++k)
            {
                converted[k] = static_cast<int64_t>(lanes[k]);
            }
            return _mm256_load_si256(reinterpret_cast<const __m256i *>(converted));
        }
#endif

        /// Returns lane N as a scalar (the 64-bit pattern is extracted and
        /// reinterpreted through a union).
        template <int N>
        F64 nthEle() const
        {
            union F64I64
            {
                int64_t i64;
                F64 f64;
            } pun;
            pun.i64 = _mm256_extract_epi64(__m256i(data), N);
            return pun.f64;
        }

        /// Writes "[a,b,c,d]" followed by std::endl to std::cout.
        void print() const
        {
            std::cout << "[" << nthEle<0>()
                      << "," << nthEle<1>()
                      << "," << nthEle<2>()
                      << "," << nthEle<3>()
                      << "]" << std::endl;
        }

    private:
        __m256d data;
    };

    struct Complex64X4
    {
        using C64X4 = Complex64X4;
        using F64X4 = Float64X4;
        using F64 = double;
        Complex64X4() {}
        Complex64X4(F64X4 real, F64X4 imag) : real(real), imag(imag) {}
        Complex64X4(const F64 *p) : real(p), imag(p + 4) {}
        Complex64X4(const F64 *p_real, const F64 *p_imag) : real(p_real), imag(p_imag) {}
        C64X4 operator+(const C64X4 &other) const
        {
            return C64X4(real + other.real, imag + other.imag);
        }
        C64X4 operator-(const C64X4 &other) const
        {
            return C64X4(real - other.real, imag - other.imag);
        }
        C64X4 mul(const C64X4 &other) const
        {
            const F64X4 ii = imag * other.imag;
            const F64X4 ri = real * other.imag;
            const F64X4 r = F64X4::fmsub(real, other.real, ii);
            const F64X4 i = F64X4::fmadd(imag, other.real, ri);
            return C64X4(r, i);
        }
        C64X4 mulConj(const C64X4 &other) const
        {
            const F64X4 ii = imag * other.imag;
            const F64X4 ri = real * other.imag;
            const F64X4 r = F64X4::fmadd(real, other.real, ii);
            const F64X4 i = F64X4::fmsub(imag, other.real, ri);
            return C64X4(r, i);
        }
        // exp{i*theta*k},k in {0,1,2,3}
        static C64X4 omegaSeq0To3(F64 theta, F64 begin = 0)
        {
            F64 real_arr[4] = {cos(begin), cos(theta + begin), cos(2 * theta + begin), cos(3 * theta + begin)};
            F64 imag_arr[4] = {sin(begin), sin(theta + begin), sin(2 * theta + begin), sin(3 * theta + begin)};
            return C64X4(F64X4(real_arr), F64X4(imag_arr));
        }

        template <typename T>
        void load(const T *p, std::false_type)
        {
            this->load(p);
        }
        // From RIRI permutation
        template <typename T>
        void load(const T *p, std::true_type)
        {
            this->load(p);
            *this = this->toRRIIPermu();
        }
        template <typename T>
        void load(const T *p)
        {
            real.load(reinterpret_cast<const F64 *>(p));
            imag.load(reinterpret_cast<const F64 *>(p) + 4);
        }
        template <typename T>
        void loadu(const T *p)
        {
            real.loadu(reinterpret_cast<const F64 *>(p));
            imag.loadu(reinterpret_cast<const F64 *>(p) + 4);
        }
        void load1(const F64 *real_p, const F64 *imag_p)
        {
            real.load1(real_p);
            imag.load1(imag_p);
        }

        template <typename T>
        void store(T *p, std::false_type) const
        {
            this->store(p);
        }
        // To RIRI permutation
        template <typename T>
        void store(T *p, std::true_type) const
        {
            this->toRIRIPermu().store(p);
        }
        template <typename T>
        void store(T *p) const
        {
            real.store(reinterpret_cast<F64 *>(p));
            imag.store(reinterpret_cast<F64 *>(p) + 4);
        }
        template <typename T>
        void storeu(T *p) const
        {
            real.storeu(reinterpret_cast<F64 *>(p));
            imag.storeu(reinterpret_cast<F64 *>(p) + 4);
        }
        C64X4 square() const
        {
            const F64X4 ii = imag * imag;
            const F64X4 ri = real * imag;
            const F64X4 r = F64X4::fmsub(real, real, ii);
            const F64X4 i = ri + ri;
            return C64X4(r, i);
        }
        C64X4 cube() const
        {
            const F64X4 rr = real * real;
            const F64X4 ii = imag * imag;
            const F64X4 rr3 = rr + rr + rr;
            const F64X4 ii3 = ii + ii + ii;
            const F64X4 r = real * (rr - ii3);
            const F64X4 i = imag * (rr3 - ii);
            return C64X4(r, i);
        }
        C64X4 toRIRIPermu() const
        {
            C64X4 res = *this;
            transpose64_2X4(res.real, res.imag);
            return res;
        }
        C64X4 toRRIIPermu() const
        {
            C64X4 res = *this;
            transpose64_4X2(res.real, res.imag);
            return res;
        }
        void print() const
        {
            alignas(32) F64 real_arr[4]{}, imag_arr[4]{};
            real.storeu(real_arr);
            imag.storeu(imag_arr);
            std::cout << "[(" << real_arr[0] << ", " << imag_arr[0] << "), ("
                      << real_arr[1] << ", " << imag_arr[1] << "), ("
                      << real_arr[2] << ", " << imag_arr[2] << "), ("
                      << real_arr[3] << ", " << imag_arr[3] << ")]" << std::endl;
        }

        C64X4 transToI64(std::false_type) const
        {
            return *this;
        }
        C64X4 transToI64(std::true_type) const
        {
            constexpr int64_t F1_2 = 4602678819172646912; // magic::bit_cast<int64_t>(0.5);
            auto F1_2X4 = F64X4(__m256d(_mm256_set1_epi64x(F1_2)));
            auto real_i64 = (real + F1_2X4).toI64X4();
            auto imag_i64 = (imag + F1_2X4).toI64X4();
            return C64X4(__m256d(real_i64), __m256d(imag_i64));
        }

        F64X4 real, imag;
    };
}
#endif


namespace hint
{
    // Scalar and complex floating-point aliases used throughout the hint namespace.
    using Float32 = float;
    using Float64 = double;
    using Complex32 = std::complex<Float32>;
    using Complex64 = std::complex<Float64>;

    // Pi and 2*Pi as Float64 constants.
    constexpr Float64 HINT_PI = 3.141592653589793238462643;
    constexpr Float64 HINT_2PI = HINT_PI * 2;
    // 2^23; presumably the largest supported transform length — it is not used in the visible code.
    constexpr size_t FFT_MAX_LEN = size_t(1) << 23;

    /// Largest power of two that does not exceed n.
    /// For n == 0 the result is 1.
    template <typename T>
    constexpr T int_floor2(T n)
    {
        constexpr int bits = sizeof(n) * 8;
        // Smear the highest set bit into every lower position.
        int shift = 1;
        while (shift < bits)
        {
            n |= (n >> shift);
            shift *= 2;
        }
        // n is now 2^(k+1) - 1, so halving and adding one yields 2^k.
        return (n >> 1) + 1;
    }

    /// Smallest power of two that is not less than n.
    /// For an unsigned n == 0 the decrement wraps around and the result is 0.
    template <typename T>
    constexpr T int_ceil2(T n)
    {
        constexpr int bits = sizeof(n) * 8;
        --n; // so that exact powers of two map to themselves
        int shift = 1;
        while (shift < bits)
        {
            n |= (n >> shift);
            shift *= 2;
        }
        return n + 1;
    }

    /// True when n has at most one bit set; note that this includes n == 0.
    template <typename IntTy>
    constexpr bool is_2pow(IntTy n)
    {
        // Subtracting one clears the lowest set bit and sets every bit below it.
        return !(n & (n - 1));
    }

    /// Integer base-2 logarithm: returns floor(log2(n)) for n > 0 and -1 for
    /// n <= 0.
    ///
    /// The binary search tests `n >> mid` rather than `T(1) << mid`, so no
    /// shift can overflow when T is a signed type.
    template <typename T>
    constexpr int hint_log2(T n)
    {
        constexpr int bits = sizeof(n) * 8;
        if (!(n > 0))
        {
            return -1;
        }
        // Invariant: (n >> l) != 0, and n has no set bit at position r or above.
        int l = 0, r = bits;
        while ((l + 1) != r)
        {
            const int mid = (l + r) / 2;
            if ((n >> mid) == 0)
            {
                r = mid;
            }
            else
            {
                l = mid;
            }
        }
        return l;
    }
    /// Number of trailing zero bits of a 32-bit value; returns 32 for x == 0.
    constexpr int hint_ctz(uint32_t x)
    {
        if (x == 0)
        {
            return 32;
        }
        // Keep only the lowest set bit (two's complement negation).
        const uint32_t lowest = x & (~x + 1);
        // Each mask covers the positions whose index has a given bit clear;
        // a miss means that bit of the index is set.
        int count = 0;
        count |= (lowest & 0x55555555) ? 0 : 1;
        count |= (lowest & 0x33333333) ? 0 : 2;
        count |= (lowest & 0x0F0F0F0F) ? 0 : 4;
        count |= (lowest & 0x00FF00FF) ? 0 : 8;
        count |= (lowest & 0x0000FFFF) ? 0 : 16;
        return count;
    }

    /// Number of trailing zero bits of a 64-bit value; returns 64 for x == 0.
    constexpr int hint_ctz(uint64_t x)
    {
        if (x == 0)
        {
            return 64;
        }
        // Keep only the lowest set bit (two's complement negation).
        const uint64_t lowest = x & (~x + 1);
        // Each mask covers the positions whose index has a given bit clear;
        // a miss means that bit of the index is set.
        int count = 0;
        count |= (lowest & 0x5555555555555555) ? 0 : 1;
        count |= (lowest & 0x3333333333333333) ? 0 : 2;
        count |= (lowest & 0x0F0F0F0F0F0F0F0F) ? 0 : 4;
        count |= (lowest & 0x00FF00FF00FF00FF) ? 0 : 8;
        count |= (lowest & 0x0000FFFF0000FFFF) ? 0 : 16;
        count |= (lowest & 0x00000000FFFFFFFF) ? 0 : 32;
        return count;
    }

    /// Carries the compile-time constant N of type T as a type.
    template <typename T, T N>
    struct StaticObject
    {
        typedef T Type;
        static constexpr T value = N;
    };

    /// StaticObject for size_t constants.
    template <size_t N>
    using StaticSize = StaticObject<size_t, N>;

    /// StaticObject for int constants.
    template <int N>
    using StaticInt = StaticObject<int, N>;

    // Namespace for the FFT and FFT-like transforms
    namespace transform
    {
        using namespace hint_simd;

        /// In-place 2-point butterfly: (sum, diff) <- (sum + diff, sum - diff).
        template <typename T>
        inline void transform2(T &sum, T &diff)
        {
            // Snapshot both inputs before either one is overwritten.
            const T lhs = sum;
            const T rhs = diff;
            sum = lhs + rhs;
            diff = lhs - rhs;
        }

        /// Out-of-place 2-point butterfly: sum <- a + b, diff <- a - b.
        /// a and b are taken by value, so sum and diff may refer to the
        /// caller's input objects.
        template <typename T>
        inline void transform2(const T a, const T b, T &sum, T &diff)
        {
            const T added = a + b;
            const T subtracted = a - b;
            sum = added;
            diff = subtracted;
        }

        /// Returns the point on the unit circle with argument theta, i.e. the
        /// complex number with magnitude 1 and phase theta.
        template <typename FloatTy>
        inline auto unit_root(FloatTy theta)
        {
            const FloatTy magnitude = 1.0;
            return std::polar<FloatTy>(magnitude, theta);
        }

        // Bit-reversal permutation of the range [begin, end), done in place by
        // swapping each element with the one at its bit-reversed index.
        // The halving steps below assume the length is a power of two.
        template <typename It>
        void binary_reverse_swap(It begin, It end)
        {
            const size_t len = end - begin;
            // Swap only when the left position precedes the right one, so that
            // no pair is swapped twice.
            auto smaller_swap = [=](It it_left, It it_right)
            {
                if (it_left < it_right)
                {
                    std::swap(it_left[0], it_right[0]);
                }
            };
            // Given the iterator at the bit-reversed index of i, returns the
            // iterator at the bit-reversed index of i + 1.
            auto get_next_bitrev = [=](It last)
            {
                size_t k = len / 2, indx = last - begin;
                indx ^= k;
                // Propagate the carry from the top bit downwards (a reversed increment).
                while (k > indx)
                {
                    k >>= 1;
                    indx ^= k;
                };
                return begin + indx;
            };
            // Short ranges: plain element-by-element bit reversal.
            if (len <= 16)
            {
                for (auto i = begin + 1, j = begin + len / 2; i < end - 1; i++)
                {
                    smaller_swap(i, j);
                    j = get_next_bitrev(j);
                }
                return;
            }
            // Longer ranges: i0 walks the first eighth of the range and every
            // iteration handles eight index pairs. Adding len/2, len/4 or len/8
            // to an index adds 1, 2 or 4 to its bit-reversed index, which is
            // where the offsets j + 1 ... j + 7 come from.
            const size_t len_8 = len / 8;
            const auto last = begin + len_8;
            auto i0 = begin + 1, i1 = i0 + len / 2, i2 = i0 + len / 4, i3 = i1 + len / 4;
            for (auto j = begin + len / 2; i0 < last; i0++, i1++, i2++, i3++)
            {
                smaller_swap(i0, j);
                smaller_swap(i1, j + 1);
                smaller_swap(i2, j + 2);
                smaller_swap(i3, j + 3);
                smaller_swap(i0 + len_8, j + 4);
                smaller_swap(i1 + len_8, j + 5);
                smaller_swap(i2 + len_8, j + 6);
                smaller_swap(i3 + len_8, j + 7);
                j = get_next_bitrev(j);
            }
        }

        /// Bit-reversal permutation of the len elements starting at ary;
        /// forwards to the iterator-pair overload.
        template <typename T>
        void binary_reverse_swap(T ary, const size_t len)
        {
            const auto first = ary;
            const auto last = ary + len;
            binary_reverse_swap(first, last);
        }
        namespace fft
        {
            // Short names for the scalar, complex and 4-lane SIMD types used by the FFT code.
            using F64 = Float64;
            using C64 = std::complex<F64>;
            using F64X4 = Float64X4;
            using C64X4 = Complex64X4;
            /// Table of OMEGA_LEN complex values cos(theta*k) + i*sin(theta*k),
            /// k = 0 .. OMEGA_LEN-1, with theta = -2*pi*factor/theta_divider.
            /// Storage is blocked: `stride` real parts are followed by the
            /// matching `stride` imaginary parts.
            template <typename Float, size_t OMEGA_LEN>
            class TableFix
            {
                alignas(64) std::array<Float, OMEGA_LEN * 2> table;

            public:
                /// Fills the table; OMEGA_LEN must be a multiple of stride
                /// (checked with assert).
                TableFix(size_t theta_divider, size_t factor, size_t stride)
                {
                    const Float theta = -HINT_2PI * factor / theta_divider;
                    assert(OMEGA_LEN % stride == 0);
                    size_t index = 0;
                    for (size_t block = 0; block < OMEGA_LEN * 2; block += stride * 2)
                    {
                        for (size_t lane = 0; lane < stride; lane++, index++)
                        {
                            const Float angle = theta * index;
                            table[block + lane] = std::cos(angle);
                            table[block + lane + stride] = std::sin(angle);
                        }
                    }
                }
                /// Raw table entry (a real or an imaginary part, depending on the index).
                constexpr const Float &operator[](size_t index) const
                {
                    return table[index];
                }
                /// Pointer to the raw table entry at index.
                constexpr const Float *getOmegaIt(size_t index) const
                {
                    return &table[index];
                }
            };
            /// Twiddle tables for every length 2^L with LOG_BEGIN <= L <= LOG_END,
            /// packed into one array.
            ///
            /// Level L stores (2^L)/DIV complex values starting at
            /// getBegin(L) == &table[2 * (2^L)/DIV]. Element k of level L is
            /// cos(theta*k) + i*sin(theta*k) with theta = -2*pi*factor/2^L.
            /// Values are stored in groups of four: 4 real parts followed by the
            /// 4 matching imaginary parts.
            template <typename Float, int LOG_BEGIN, int LOG_END, int DIV>
            class TableFixMulti
            {
                // Messages are spelled out because static_assert without one
                // is only available from C++17 on.
                static_assert(LOG_END >= LOG_BEGIN, "LOG_END must not be smaller than LOG_BEGIN");
                static_assert(is_2pow(DIV), "DIV must be a power of two");
                static constexpr size_t TABLE_CPX_LEN = (size_t(1) << (LOG_END + 1)) / DIV;
                alignas(64) std::array<Float, TABLE_CPX_LEN * 2> table;

            public:
                /// Builds all levels and prints the time this took to std::cout.
                /// NOTE(review): the timing line goes to standard output — confirm
                /// that this is intended wherever stdout is compared against
                /// expected output.
                TableFixMulti(size_t factor, size_t stride = 4)
                {
                    assert(((size_t(1) << LOG_BEGIN) / DIV) % stride == 0);
                    auto t1 = std::chrono::steady_clock::now();
                    initAVXF64(factor, stride);
                    auto t2 = std::chrono::steady_clock::now();
                    std::cout << "TableFixMulti init time: "
                              << std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count()
                              << "us" << std::endl;
                }
                /// Fills the table. Level LOG_BEGIN is computed directly with
                /// cos/sin; each later level is derived from the previous one:
                /// even entries are copied, odd entries are the previous entry
                /// multiplied by the current level's unit root.
                /// Requires Float == Float64 and stride == 4 (both asserted).
                void initAVXF64(size_t factor, size_t stride)
                {
                    assert((std::is_same<Float, Float64>::value));
                    assert(stride == 4);
                    size_t len = size_t(1) << LOG_BEGIN, cpx_len = len / DIV;
                    Float theta = -HINT_2PI * factor / len;
                    auto it = getBegin(LOG_BEGIN);
                    for (size_t i = 0; i < cpx_len; i++)
                    {
                        it[0] = std::cos(theta * i), it[stride] = std::sin(theta * i);
                        // After a full group of `stride` reals, skip over the imaginary block.
                        it += (i % stride == stride - 1 ? stride + 1 : 1);
                    }
                    for (int log_len = LOG_BEGIN + 1; log_len <= LOG_END; log_len++)
                    {
                        len = size_t(1) << log_len, cpx_len = len / DIV;
                        theta = -HINT_2PI * factor / len;
                        auto it_cur = getBegin(log_len), it_last = getBegin(log_len - 1);
                        C64X4 unit(std::cos(theta), std::sin(theta));
                        // Four values of the previous level yield eight values of this level.
                        for (auto end = it_cur + cpx_len * 2; it_cur < end; it_cur += 16, it_last += 8)
                        {
                            Complex64X4 omega0, omega1;
                            omega0.load(it_last);      // entries 2k of this level
                            omega1 = omega0.mul(unit); // entries 2k + 1
                            // Interleave even and odd entries into ascending order.
                            transpose64_2X4(omega0.real, omega1.real);
                            transpose64_2X4(omega0.imag, omega1.imag);
                            omega0.store(it_cur), omega1.store(it_cur + 8);
                        }
                    }
                }
                // Disabled alternative initialisation, kept for reference.
                // void initEndAVXF64(size_t factor, size_t stride)
                // {
                //     const size_t end_len = (size_t(1) << LOG_END), cpx_len = end_len / DIV;
                //     const Float theta = -HINT_2PI * factor / end_len;
                //     auto begin = getBegin(LOG_END);
                //     for (size_t i = 0; i < stride; i++)
                //     {
                //         begin[i] = std::cos(theta * i), begin[i + stride] = std::sin(theta * i);
                //         begin[i + stride * 2] = std::cos(theta * (i + stride)), begin[i + stride * 3] = std::sin(theta * (i + stride));
                //     }
                //     begin[0] = 1, begin[stride] = 0;
                //     auto last = getBegin(LOG_END - 1);
                //     C64X4 last0, last1;
                //     last0.load(begin), last1.load(begin + 8);
                //     last0.real = F64X4::extractEven64X4(last0.real, last1.real);
                //     last0.imag = F64X4::extractEven64X4(last0.imag, last1.imag);
                //     last0.store(last);
                //     last += 8;
                //     for (size_t len = stride * 2; len < cpx_len; len *= 2)
                //     {
                //         const Float angle = theta * len;
                //         C64X4 unit(std::cos(angle), std::sin(angle));
                //         auto it = begin + len * 2;
                //         for (size_t i = 0; i < len * 2; i += 16, last += 8)
                //         {
                //             last0.load(&begin[i]), last1.load(&begin[i + 8]);
                //             last0 = last0.mul(unit);
                //             last1 = last1.mul(unit);
                //             last0.store(&it[i]), last1.store(&it[i + 8]);
                //             last0.real = F64X4::extractEven64X4(last0.real, last1.real);
                //             last0.imag = F64X4::extractEven64X4(last0.imag, last1.imag);
                //             last0.store(last);
                //         }
                //     }
                // }
                // void initAVXF64(size_t factor, size_t stride)
                // {
                //     assert((std::is_same<Float, Float64>::value));
                //     assert(stride == 4);
                //     initEndAVXF64(factor, stride);
                //     for (int log_len = LOG_END - 2; log_len >= LOG_BEGIN; log_len--)
                //     {
                //         auto it_src = getBegin(log_len + 1), it = getBegin(log_len);
                //         size_t cpx_len = (size_t(1) << log_len) / DIV;
                //         for (auto end = it + cpx_len * 2; it < end; it += 8, it_src += 16)
                //         {
                //             Complex64X4 omega0, omega1, omega2, omega3;
                //             omega0.load(it_src), omega1.load(it_src + 8);
                //             omega0.real = F64X4::extractEven64X4(omega0.real, omega1.real);
                //             omega0.imag = F64X4::extractEven64X4(omega0.imag, omega1.imag);
                //             omega0.store(it);
                //         }
                //     }
                // }
                /// Pointer to the first stored value of level log_len.
                constexpr const Float *getBegin(int log_len) const
                {
                    size_t shift = (size_t(1) << log_len) / DIV;
                    return &table[shift * 2];
                }
                constexpr Float *getBegin(int log_len)
                {
                    size_t shift = (size_t(1) << log_len) / DIV;
                    return &table[shift * 2];
                }
            };
            struct FFTFixed
            {
                // Exponents (log2 of a length) that bound the multi-level tables
                // below; SHORT_LEN is 2^LOG_SHORT.
                static constexpr size_t LOG_MAX = 18;
                static constexpr size_t LOG_SHORT = 10;
                static constexpr size_t SHORT_LEN = size_t(1) << LOG_SHORT;
                // Fixed-size twiddle tables. These are declarations only: the
                // constructor arguments (theta_divider, factor, stride) come from
                // definitions outside this part of the file.
                // NOTE(review): the names suggest the transform length (8/16/32)
                // and the factor (_1/_3) — confirm against the definitions.
                static const TableFix<Float64, 4> table_8;
                static const TableFix<Float64, 4> table_16_1;
                static const TableFix<Float64, 4> table_16_3;
                static const TableFix<Float64, 8> table_32_1;
                static const TableFix<Float64, 8> table_32_3;
                // Multi-level twiddle tables covering lengths 2^6 .. 2^LOG_SHORT
                // (multi_table_2, multi_table_3) and 2^6 .. 2^LOG_MAX (multi_table_1),
                // each with DIV = 4.
                static const TableFixMulti<Float64, 6, LOG_SHORT, 4> multi_table_2;
                static const TableFixMulti<Float64, 6, LOG_SHORT, 4> multi_table_3;
                static const TableFixMulti<Float64, 6, LOG_MAX, 4> multi_table_1;

                // Pointers to the first entry of each fixed-size table.
                // NOTE(review): these initialisers require `&table_x[0]` to be a
                // constant expression although the tables are declared
                // `static const`, not `constexpr` — confirm that the target
                // compiler accepts this.
                static constexpr const Float64 *it8 = &table_8[0];
                static constexpr const Float64 *it16_1 = &table_16_1[0];
                static constexpr const Float64 *it16_3 = &table_16_3[0];
                static constexpr const Float64 *it32_1 = &table_32_1[0];
                static constexpr const Float64 *it32_3 = &table_32_3[0];

                // In-place 4-point butterfly on x_k = r_k + i*i_k, without twiddle factors.
                // With a = x0 + x2, b = x1 + x3, c = x0 - x2, d = x1 - x3 the results are
                //   x0 = a + b,  x1 = a - b,  x2 = c - i*d,  x3 = c + i*d.
                template <typename Float>
                static void dif4(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
                {
                    // First stage: (x0, x2) -> (a, c) and (x1, x3) -> (b, d).
                    transform2(r0, r2);
                    transform2(i0, i2);
                    transform2(r1, r3);
                    transform2(i1, i3);

                    // Second stage, upper half: (a, b) -> (a + b, a - b).
                    transform2(r0, r1);
                    transform2(i0, i1);
                    // Second stage, lower half: multiplying d by -i/+i mixes real and
                    // imaginary parts, hence the crossed arguments and the final swap.
                    transform2(r2, i3);         // r2 = c.re + d.im, i3 = c.re - d.im
                    transform2(i2, r3, r3, i2); // r3 = c.im + d.re, i2 = c.im - d.re
                    std::swap(i3, r3);          // put x3's real and imaginary parts in place
                }
                // In-place 4-point butterfly on x_k = r_k + i*i_k, without twiddle factors.
                // With s = x0 + x1, t = x0 - x1, u = x2 + x3, v = x2 - x3 the results are
                //   x0 = s + u,  x2 = s - u,  x1 = t + i*v,  x3 = t - i*v.
                template <typename Float>
                static void idit4(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
                {
                    // First stage: (x0, x1) -> (s, t) and (x2, x3) -> (u, v).
                    transform2(r0, r1);
                    transform2(i0, i1);
                    transform2(r2, r3);
                    transform2(i2, i3);

                    // Second stage, even outputs: (s, u) -> (s + u, s - u).
                    transform2(r0, r2);
                    transform2(i0, i2);
                    // Second stage, odd outputs: multiplying v by +i/-i mixes real and
                    // imaginary parts, hence the crossed arguments and the final swap.
                    transform2(r1, i3, i3, r1); // i3 = t.re + v.im, r1 = t.re - v.im
                    transform2(i1, r3);         // i1 = t.im + v.re, r3 = t.im - v.re
                    std::swap(i3, r3);          // put x3's real and imaginary parts in place
                }
                // Same as dif4 but without the second-stage butterfly on the upper
                // half. With c = x0 - x2 and d = x1 - x3 the results are
                //   x0 = x0 + x2,  x1 = x1 + x3,  x2 = c - i*d,  x3 = c + i*d.
                template <typename Float>
                static void difSplit(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
                {
                    // First stage: sums stay in x0/x1, differences c and d go to x2/x3.
                    transform2(r0, r2);
                    transform2(i0, i2);
                    transform2(r1, r3);
                    transform2(i1, i3);

                    // Lower half only: combine c with -i*d and +i*d.
                    transform2(r2, i3);         // r2 = c.re + d.im, i3 = c.re - d.im
                    transform2(i2, r3, r3, i2); // r3 = c.im + d.re, i2 = c.im - d.re
                    std::swap(i3, r3);          // put x3's real and imaginary parts in place
                }
                // Same as idit4 but without the first-stage butterfly between values 0 and 1:
                // only the 2/3 pair is combined before the final stage.
                // Used by idit32 and the generic iditRec after their smaller sub-transforms.
                template <typename Float>
                static void iditSplit(Float &r0, Float &i0, Float &r1, Float &i1, Float &r2, Float &i2, Float &r3, Float &i3)
                {
                    // Values 0 and 1 enter the final stage unchanged.
                    transform2(r2, r3);
                    transform2(i2, i3);

                    transform2(r0, r2);
                    transform2(i0, i2);
                    transform2(r1, i3, i3, r1);
                    transform2(i1, r3);
                    std::swap(i3, r3);
                }
                // Applies dif4 across the lanes of the vectors instead of across the vectors:
                // transposes the four real and four imaginary vectors, runs dif4, and
                // transposes back. All eight vectors are updated in place.
                static void dif4x4(F64X4 &r0, F64X4 &i0, F64X4 &r1, F64X4 &i1, F64X4 &r2, F64X4 &i2, F64X4 &r3, F64X4 &i3)
                {
                    transpose64_4X4(r0, r1, r2, r3);
                    transpose64_4X4(i0, i1, i2, i3);

                    dif4(r0, i0, r1, i1, r2, i2, r3, i3);

                    transpose64_4X4(r0, r1, r2, r3);
                    transpose64_4X4(i0, i1, i2, i3);
                }
                // Inverse-direction counterpart of dif4x4: transposes the four real and four
                // imaginary vectors, runs idit4, and transposes back. Updated in place.
                static void idit4x4(F64X4 &r0, F64X4 &i0, F64X4 &r1, F64X4 &i1, F64X4 &r2, F64X4 &i2, F64X4 &r3, F64X4 &i3)
                {
                    transpose64_4X4(r0, r1, r2, r3);
                    transpose64_4X4(i0, i1, i2, i3);

                    idit4(r0, i0, r1, i1, r2, i2, r3, i3);

                    transpose64_4X4(r0, r1, r2, r3);
                    transpose64_4X4(i0, i1, i2, i3);
                }
                // Register-level forward kernel on four Complex64X4 values: butterflies the
                // pairs (c0, c1) and (c2, c3), multiplies c1 and c3 by `omega`, then runs
                // dif4x4 over all four. The name suggests two 8-point transforms — TODO confirm.
                static void dif8x2(Complex64X4 &c0, Complex64X4 &c1, Complex64X4 &c2, Complex64X4 &c3, const Complex64X4 &omega)
                {
                    transform2(c0, c1);
                    transform2(c2, c3);
                    c1 = c1.mul(omega);
                    c3 = c3.mul(omega);
                    dif4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                }
                // Register-level inverse kernel mirroring dif8x2 in reverse order: idit4x4
                // over all four values, multiply c1 and c3 by `omega` via mulConj, then
                // butterfly the pairs (c0, c1) and (c2, c3).
                static void idit8x2(Complex64X4 &c0, Complex64X4 &c1, Complex64X4 &c2, Complex64X4 &c3, const Complex64X4 &omega)
                {
                    idit4x4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    c1 = c1.mulConj(omega);
                    c3 = c3.mulConj(omega);
                    transform2(c0, c1);
                    transform2(c2, c3);
                }
                // In-memory wrapper for the register-level dif8x2: loads four Complex64X4
                // values from in_out (8 doubles apart), takes omega from it8, runs the kernel
                // and stores the results back to the same positions.
                static void dif8x2(Float64 in_out[])
                {
                    Complex64X4 x0, x1, x2, x3, w;
                    x0.load(in_out);
                    x1.load(in_out + 8);
                    x2.load(in_out + 16);
                    x3.load(in_out + 24);
                    w.load(it8);

                    dif8x2(x0, x1, x2, x3, w);

                    x0.store(in_out);
                    x1.store(in_out + 8);
                    x2.store(in_out + 16);
                    x3.store(in_out + 24);
                }
                // In-memory wrapper for the register-level idit8x2: loads four Complex64X4
                // values from in_out (8 doubles apart), takes omega from it8, runs the kernel
                // and stores the results back to the same positions.
                static void idit8x2(Float64 in_out[])
                {
                    Complex64X4 x0, x1, x2, x3, w;
                    x0.load(in_out);
                    x1.load(in_out + 8);
                    x2.load(in_out + 16);
                    x3.load(in_out + 24);
                    w.load(it8);

                    idit8x2(x0, x1, x2, x3, w);

                    x0.store(in_out);
                    x1.store(in_out + 8);
                    x2.store(in_out + 16);
                    x3.store(in_out + 24);
                }
                // Forward transform of the 32 doubles at in_out, done entirely in registers:
                // dif4 across the four vectors, multiplication of vectors 1..3 by the
                // it8 / it16_1 / it16_3 tables, then dif4x4 across the lanes.
                static void dif16(Float64 in_out[])
                {
                    Complex64X4 x0, x1, x2, x3, w;
                    x0.load(in_out);
                    x1.load(in_out + 8);
                    x2.load(in_out + 16);
                    x3.load(in_out + 24);

                    dif4(x0.real, x0.imag, x1.real, x1.imag, x2.real, x2.imag, x3.real, x3.imag);

                    w.load(it8);
                    x1 = x1.mul(w);
                    w.load(it16_1);
                    x2 = x2.mul(w);
                    w.load(it16_3);
                    x3 = x3.mul(w);

                    dif4x4(x0.real, x0.imag, x1.real, x1.imag, x2.real, x2.imag, x3.real, x3.imag);

                    x0.store(in_out);
                    x1.store(in_out + 8);
                    x2.store(in_out + 16);
                    x3.store(in_out + 24);
                }
                // Inverse-direction counterpart of dif16 on the 32 doubles at in_out:
                // idit4x4 across the lanes, mulConj of vectors 1..3 with the
                // it8 / it16_1 / it16_3 tables, then idit4 across the four vectors.
                static void idit16(Float64 in_out[])
                {
                    Complex64X4 x0, x1, x2, x3, w;
                    x0.load(in_out);
                    x1.load(in_out + 8);
                    x2.load(in_out + 16);
                    x3.load(in_out + 24);

                    idit4x4(x0.real, x0.imag, x1.real, x1.imag, x2.real, x2.imag, x3.real, x3.imag);

                    w.load(it8);
                    x1 = x1.mulConj(w);
                    w.load(it16_1);
                    x2 = x2.mulConj(w);
                    w.load(it16_3);
                    x3 = x3.mulConj(w);

                    idit4(x0.real, x0.imag, x1.real, x1.imag, x2.real, x2.imag, x3.real, x3.imag);

                    x0.store(in_out);
                    x1.store(in_out + 8);
                    x2.store(in_out + 16);
                    x3.store(in_out + 24);
                }
                // Forward transform of the 64 doubles at in_out. One difSplit pass over the
                // four 16-double quarters (done in two halves of 8 doubles each), then
                // dif8x2 on the upper 32 doubles and dif16 on the lower 32 doubles.
                static void dif32(Float64 in_out[])
                {
                    Complex64X4 c0, c1, c2, c3, omega;
                    // First half of each quarter: offsets 0, 16, 32, 48.
                    c0.load(in_out), c1.load(in_out + 16), c2.load(in_out + 32), c3.load(in_out + 48);
                    difSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    omega.load(it32_1), c2 = c2.mul(omega);
                    omega.load(it32_3), c3 = c3.mul(omega);
                    c0.store(in_out), c1.store(in_out + 16), c2.store(in_out + 32), c3.store(in_out + 48);

                    // Second half of each quarter: offsets 8, 24, 40, 56.
                    c0.load(in_out + 8), c1.load(in_out + 24), c2.load(in_out + 40), c3.load(in_out + 56); // 1,3,5,7
                    difSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    omega.load(it32_1 + 8), c2 = c2.mul(omega);
                    omega.load(it32_3 + 8), c3 = c3.mul(omega);
                    c0.store(in_out + 8), c1.store(in_out + 24);
                    // c2/c3 (offsets 40, 56) stay in registers; c0/c1 are reloaded with the
                    // results stored above at offsets 32 and 48 so dif8x2 sees blocks 4..7.
                    c0.load(in_out + 32), c1.load(in_out + 48), omega.load(it8); // 4,6
                    dif8x2(c0, c2, c1, c3, omega);
                    c0.store(in_out + 32), c2.store(in_out + 40), c1.store(in_out + 48), c3.store(in_out + 56);
                    // Finish the lower 32 doubles.
                    dif16(in_out);
                }
                // Inverse-direction counterpart of dif32 on the 64 doubles at in_out:
                // idit16 on the lower 32 doubles and idit8x2 on the upper 32 doubles,
                // followed by one iditSplit pass over the four 16-double quarters.
                static void idit32(Float64 in_out[])
                {
                    Complex64X4 c0, c1, c2, c3, omega;
                    idit16(in_out);
                    c0.load(in_out + 32), c1.load(in_out + 40), c2.load(in_out + 48), c3.load(in_out + 56), omega.load(it8); // 4,5,6,7
                    idit8x2(c0, c1, c2, c3, omega);
                    // Only the second halves (offsets 40, 56) go back to memory here; c0/c2
                    // (offsets 32, 48) stay in registers for the pass below.
                    c1.store(in_out + 40), c3.store(in_out + 56);

                    // First half of each quarter: c1/c3 are reused for offsets 0 and 16.
                    c1.load(in_out), c3.load(in_out + 16);
                    omega.load(it32_1), c0 = c0.mulConj(omega);
                    omega.load(it32_3), c2 = c2.mulConj(omega);
                    iditSplit(c1.real, c1.imag, c3.real, c3.imag, c0.real, c0.imag, c2.real, c2.imag);
                    c1.store(in_out), c3.store(in_out + 16), c0.store(in_out + 32), c2.store(in_out + 48);

                    // Second half of each quarter: offsets 8, 24, 40, 56.
                    c0.load(in_out + 8), c1.load(in_out + 24), c2.load(in_out + 40), c3.load(in_out + 56);
                    omega.load(it32_1 + 8), c2 = c2.mulConj(omega);
                    omega.load(it32_3 + 8), c3 = c3.mulConj(omega);
                    iditSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                    c0.store(in_out + 8), c1.store(in_out + 24), c2.store(in_out + 40), c3.store(in_out + 56);
                }
                // Non-template overloads that end the recursion of the generic
                // difRec/iditRec templates. Being non-templates, they are preferred by
                // overload resolution when the StaticInt argument matches exactly.

                // Sizes LOG_SHORT and LOG_SHORT - 1 switch to the iterative kernels.
                static void difRec(Float64 in_out[], StaticInt<LOG_SHORT>)
                {
                    difIter<LOG_SHORT>(in_out);
                }
                static void iditRec(Float64 in_out[], StaticInt<LOG_SHORT>)
                {
                    iditIter<LOG_SHORT>(in_out);
                }
                static void difRec(Float64 in_out[], StaticInt<LOG_SHORT - 1>)
                {
                    difIter<LOG_SHORT - 1>(in_out);
                }
                static void iditRec(Float64 in_out[], StaticInt<LOG_SHORT - 1>)
                {
                    iditIter<LOG_SHORT - 1>(in_out);
                }
                // Sizes 4 and 5 map to the hand-written 16- and 32-point kernels; the
                // iterative kernels use these as their last / first stage.
                static void difRec(Float64 in_out[], StaticInt<4>)
                {
                    dif16(in_out);
                }
                static void iditRec(Float64 in_out[], StaticInt<4>)
                {
                    idit16(in_out);
                }
                static void difRec(Float64 in_out[], StaticInt<5>)
                {
                    dif32(in_out);
                }
                static void iditRec(Float64 in_out[], StaticInt<5>)
                {
                    idit32(in_out);
                }
                // Generic recursive forward step for 2^LOG_N complex values stored at in_out.
                // One difSplit pass combines the four quarter blocks, then the routine
                // recurses on the first half (LOG_N - 1) and on the two remaining quarters
                // (LOG_N - 2 each).
                // FROM_RIRI_PERM is forwarded to Complex64X4::load as a tag type on this
                // level only; nested calls use the default (false).
                template <bool FROM_RIRI_PERM = false, int LOG_N>
                static void difRec(Float64 in_out[], StaticInt<LOG_N>)
                {
                    using FromRIRI = std::integral_constant<bool, FROM_RIRI_PERM>;
                    constexpr size_t LEN = size_t(1) << LOG_N;
                    constexpr size_t STRIDE1 = LEN / 2, STRIDE2 = STRIDE1 * 2, STRIDE3 = STRIDE1 * 3;
                    // Only multi_table_1 is consulted: the multiplier for c3 is derived as
                    // omega.cube(). The previous, never-dereferenced multi_table_3 lookup was
                    // dropped because that table is sized by LOG_SHORT, not LOG_MAX.
                    auto table1 = multi_table_1.getBegin(LOG_N);
                    for (auto end = in_out + STRIDE1, it = in_out; it < end; it += 8, table1 += 8)
                    {
                        Complex64X4 c0, c1, c2, c3, omega;
                        c0.load(it, FromRIRI{}), c1.load(it + STRIDE1, FromRIRI{}), c2.load(it + STRIDE2, FromRIRI{}), c3.load(it + STRIDE3, FromRIRI{});
                        difSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                        omega.load(table1), c2 = c2.mul(omega);
                        c3 = c3.mul(omega.cube());
                        c0.store(it), c1.store(it + STRIDE1), c2.store(it + STRIDE2), c3.store(it + STRIDE3);
                    }
                    difRec(in_out, StaticInt<LOG_N - 1>{});
                    difRec(in_out + STRIDE2, StaticInt<LOG_N - 2>{});
                    difRec(in_out + STRIDE3, StaticInt<LOG_N - 2>{});
                }
                // Generic recursive inverse step for 2^LOG_N complex values stored at in_out;
                // mirrors difRec in reverse. It first recurses on the first half (LOG_N - 1)
                // and the two remaining quarters (LOG_N - 2 each), then combines the four
                // quarter blocks with one iditSplit pass.
                // TO_INT64 is forwarded to Complex64X4::transToI64 and TO_RIRI_PERM to
                // Complex64X4::store as tag types on this level only; nested calls use the
                // defaults (false).
                template <bool TO_RIRI_PERM = false, bool TO_INT64 = false, int LOG_N>
                static void iditRec(Float64 in_out[], StaticInt<LOG_N>)
                {
                    using ToRIRI = std::integral_constant<bool, TO_RIRI_PERM>;
                    using ToI64 = std::integral_constant<bool, TO_INT64>;
                    constexpr size_t LEN = size_t(1) << LOG_N;
                    constexpr size_t STRIDE1 = LEN / 2, STRIDE2 = STRIDE1 * 2, STRIDE3 = STRIDE1 * 3;
                    iditRec(in_out, StaticInt<LOG_N - 1>{});
                    iditRec(in_out + STRIDE2, StaticInt<LOG_N - 2>{});
                    iditRec(in_out + STRIDE3, StaticInt<LOG_N - 2>{});
                    // Only multi_table_1 is consulted: the multiplier for c3 is derived as
                    // omega.cube(). The previous, never-dereferenced multi_table_3 lookup was
                    // dropped because that table is sized by LOG_SHORT, not LOG_MAX.
                    auto table1 = multi_table_1.getBegin(LOG_N);
                    for (auto end = in_out + STRIDE1, it = in_out; it < end; it += 8, table1 += 8)
                    {
                        Complex64X4 c0, c1, c2, c3, omega;
                        c0.load(it), c1.load(it + STRIDE1), c2.load(it + STRIDE2), c3.load(it + STRIDE3);
                        omega.load(table1), c2 = c2.mulConj(omega);
                        c3 = c3.mulConj(omega.cube());
                        iditSplit(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                        c0 = c0.transToI64(ToI64{}), c1 = c1.transToI64(ToI64{}), c2 = c2.transToI64(ToI64{}), c3 = c3.transToI64(ToI64{});
                        c0.store(it, ToRIRI{}), c1.store(it + STRIDE1, ToRIRI{}), c2.store(it + STRIDE2, ToRIRI{}), c3.store(it + STRIDE3, ToRIRI{});
                    }
                }
                // Iterative forward transform of 2^LOG_N complex values (2^(LOG_N + 1) doubles)
                // at in_out. Runs dif4 passes that lower log_rank by two each time while it
                // is >= 6, then finishes every remaining block with the 16-point (LOG_N even)
                // or 32-point (LOG_N odd) kernel.
                template <int LOG_N>
                static void difIter(Float64 in_out[])
                {
                    constexpr size_t LEN = size_t(1) << LOG_N, FLOAT_LEN = LEN * 2;
                    assert(LEN >= 64);
                    int log_rank = LOG_N;
                    for (; log_rank >= 6; log_rank -= 2)
                    {
                        // rank2 doubles per block at this level, split into four quarters.
                        size_t rank2 = size_t(1) << (log_rank + 1), stride1 = rank2 / 4, stride2 = stride1 * 2, stride3 = stride1 * 3;
                        for (auto begin = in_out, end = in_out + FLOAT_LEN; begin < end; begin += rank2)
                        {
                            // Table cursors restart at the beginning for every block.
                            auto table1 = multi_table_1.getBegin(log_rank), table2 = multi_table_2.getBegin(log_rank), table3 = multi_table_3.getBegin(log_rank);
                            for (auto it = begin; it < begin + stride1; it += 8, table1 += 8, table2 += 8, table3 += 8)
                            {
                                Complex64X4 c0, c1, c2, c3, omega;
                                c0.load(it), c1.load(it + stride1), c2.load(it + stride2), c3.load(it + stride3);
                                dif4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                                // Note the pairing: c1 uses table2, c2 uses table1, c3 uses table3.
                                omega.load(table2), c1 = c1.mul(omega);
                                omega.load(table1), c2 = c2.mul(omega);
                                omega.load(table3), c3 = c3.mul(omega);
                                c0.store(it), c1.store(it + stride1), c2.store(it + stride2), c3.store(it + stride3);
                            }
                        }
                    }
                    // The loop above stops at log_rank 4 or 5 depending on the parity of LOG_N.
                    constexpr int SMALL_FFT_LOG = LOG_N % 2 == 0 ? 4 : 5;
                    constexpr size_t SMALL_FFT_LEN = size_t(1) << SMALL_FFT_LOG, STRIDE = SMALL_FFT_LEN * 2;
                    // Four small transforms per iteration (manually unrolled).
                    for (auto it = in_out; it < in_out + FLOAT_LEN; it += STRIDE * 4)
                    {
                        using SmallSize = StaticInt<SMALL_FFT_LOG>;
                        difRec(it, SmallSize{});
                        difRec(it + STRIDE, SmallSize{});
                        difRec(it + STRIDE * 2, SmallSize{});
                        difRec(it + STRIDE * 3, SmallSize{});
                    }
                }
                // Iterative inverse-direction transform of 2^LOG_N complex values at in_out;
                // mirrors difIter in reverse. Every block is first processed by the 16-point
                // (LOG_N even) or 32-point (LOG_N odd) kernel, then idit4 passes raise
                // log_rank by two each time up to LOG_N.
                template <int LOG_N>
                static void iditIter(Float64 in_out[])
                {
                    const size_t LEN = size_t(1) << LOG_N, FLOAT_LEN = LEN * 2;
                    assert(LEN >= 64);
                    constexpr int SMALL_FFT_LOG = LOG_N % 2 == 0 ? 4 : 5;
                    constexpr size_t SMALL_FFT_LEN = size_t(1) << SMALL_FFT_LOG, STRIDE = SMALL_FFT_LEN * 2;
                    // Four small transforms per iteration (manually unrolled).
                    for (auto it = in_out; it < in_out + FLOAT_LEN; it += STRIDE * 4)
                    {
                        using SmallSize = StaticInt<SMALL_FFT_LOG>;
                        iditRec(it, SmallSize{});
                        iditRec(it + STRIDE, SmallSize{});
                        iditRec(it + STRIDE * 2, SmallSize{});
                        iditRec(it + STRIDE * 3, SmallSize{});
                    }
                    int log_rank = SMALL_FFT_LOG + 2;
                    for (; log_rank <= LOG_N; log_rank += 2)
                    {
                        // rank2 doubles per block at this level, split into four quarters.
                        size_t rank2 = size_t(1) << (log_rank + 1), stride1 = rank2 / 4, stride2 = stride1 * 2, stride3 = stride1 * 3;
                        for (auto begin = in_out, end = in_out + FLOAT_LEN; begin < end; begin += rank2)
                        {
                            // Table cursors restart at the beginning for every block.
                            auto table1 = multi_table_1.getBegin(log_rank), table2 = multi_table_2.getBegin(log_rank), table3 = multi_table_3.getBegin(log_rank);
                            auto it0 = begin, it1 = begin + stride1, it2 = begin + stride2, it3 = begin + stride3;
                            for (; it0 < begin + stride1; it0 += 8, it1 += 8, it2 += 8, it3 += 8, table1 += 8, table2 += 8, table3 += 8)
                            {
                                Complex64X4 c0, c1, c2, c3, omega;
                                c0.load(it0), c1.load(it1), c2.load(it2), c3.load(it3);
                                // Same table pairing as difIter, applied with mulConj before the butterfly.
                                omega.load(table2), c1 = c1.mulConj(omega);
                                omega.load(table1), c2 = c2.mulConj(omega);
                                omega.load(table3), c3 = c3.mulConj(omega);
                                idit4(c0.real, c0.imag, c1.real, c1.imag, c2.real, c2.imag, c3.real, c3.imag);
                                c0.store(it0), c1.store(it1), c2.store(it2), c3.store(it3);
                            }
                        }
                    }
                }
            };
            // Out-of-class definitions of the constexpr static members; before C++17 these
            // are required whenever the members are odr-used.
            constexpr size_t FFTFixed::LOG_SHORT;
            constexpr size_t FFTFixed::SHORT_LEN;

            // Construction of the static coefficient tables used by the FFTFixed kernels.
            // NOTE(review): the constructor arguments presumably select the transform
            // length, the index multiplier (1 or 3) and a lane/factor count — confirm
            // against TableFix and TableFixMulti.
            const TableFix<Float64, 4> FFTFixed::table_8(8, 1, 4);
            const TableFix<Float64, 4> FFTFixed::table_16_1(16, 1, 4);
            const TableFix<Float64, 4> FFTFixed::table_16_3(16, 3, 4);
            const TableFix<Float64, 8> FFTFixed::table_32_1(32, 1, 4);
            const TableFix<Float64, 8> FFTFixed::table_32_3(32, 3, 4);
            const TableFixMulti<Float64, 6, FFTFixed::LOG_SHORT, 4> FFTFixed::multi_table_2(2);
            const TableFixMulti<Float64, 6, FFTFixed::LOG_SHORT, 4> FFTFixed::multi_table_3(3);
            const TableFixMulti<Float64, 6, FFTFixed::LOG_MAX, 4> FFTFixed::multi_table_1(1);

            // Incremental generator of unit-circle factors, four at a time.
            // reset() seeds the current group of four factors for a start index and
            // iterate() returns that group and advances to the next one by multiplying
            // with a per-step factor looked up via the trailing-zero count of the index.
            // NOTE(review): the name and the rev4 pattern suggest the factors are produced
            // in bit-reversed index order — confirm against the forward FFT output layout.
            class BinRevTableC64X4
            {
            public:
                using F64 = double;
                using C64 = std::complex<F64>;
                using C64X4 = hint_simd::Complex64X4;

                static constexpr int MAX_LOG_LEN = CHAR_BIT * sizeof(size_t) - 2;
                static constexpr size_t MAX_LEN = size_t(1) << MAX_LOG_LEN;

                // Built from the maximum FFT length and the maximum number of iterations
                // (both given as base-2 logarithms). Fills the step-factor table and seeds
                // the generator at index 0.
                BinRevTableC64X4(int log_max_iter_in, int log_fft_len_in)
                    : log_max_iter(log_max_iter_in), log_fft_len(log_fft_len_in)
                {
                    assert(log_max_iter <= log_fft_len);
                    assert(log_fft_len <= MAX_LOG_LEN);
                    F64 factor = F64(1) / (size_t(1) << (log_fft_len - log_max_iter));
                    table[0] = getOmega(2, 1, factor);
                    table[1] = getOmega(4, 1, -factor);
                    table[2] = getOmega(8, 1, factor);
                    for (int i = 3; i < MAX_LOG_LEN; i++)
                    {
                        const size_t rev_indx = 1;
                        const size_t last_indx = ((size_t(1) << i) - 1) << 1;
                        const size_t shift = size_t(6) << (i - 2);
                        // Parentheses make the intended shift count (i + 1) explicit.
                        table[i] = getOmega(size_t(1) << (i + 1), last_indx - rev_indx - shift, -factor);
                    }
                    reset();
                }

                // Re-seeds the generator so that the next iterate() returns the group that
                // starts at index i. i must be a multiple of 4.
                inline void reset(size_t i = 0)
                {
                    assert(i % 4 == 0);
                    // These depend on this object's log_max_iter, so they must be ordinary
                    // locals: function-local statics would freeze the values computed for
                    // whichever instance called reset() first.
                    const int shift = log_max_iter - 2;
                    const size_t rev4[4]{0, size_t(2) << shift, size_t(1) << shift, size_t(3) << shift};
                    size_t start_bit = 0;
                    if (i != 0)
                    {
                        const int log_i = hint_log2(i);
                        start_bit = size_t(1) << (log_max_iter - 1 - log_i);
                    }
                    // Layout expected by C64X4: four real parts followed by four imaginary parts.
                    F64 omegaX4[8];
                    for (int j = 0; j < 4; j++)
                    {
                        auto omega = getOmega(size_t(1) << log_fft_len, rev4[j] | start_bit);
                        omegaX4[j] = omega.real();
                        omegaX4[j + 4] = omega.imag();
                    }
                    cur = C64X4(omegaX4);
                    index = i;
                }

                // Returns the current group of four factors and advances by four indices.
                C64X4 iterate()
                {
                    C64X4 diff, res = cur;
                    index += 4;
                    // std::complex<double> is laid out as {real, imag}, so p / p + 1 address
                    // the two parts of the selected step factor.
                    auto p = reinterpret_cast<F64 *>(&table[hint_ctz(index)]);
                    diff.load1(p, p + 1);
                    cur = cur.mul(diff);
                    return res;
                }

                // Unit-magnitude complex number with angle -2*pi * factor * index / n.
                static C64 getOmega(size_t n, size_t index, F64 factor = 1)
                {
                    const F64 theta = -HINT_2PI * factor * index / n;
                    return std::polar<F64>(1.0, theta);
                }

            private:
                C64X4 cur;              // group returned by the next iterate()
                size_t index;           // start index of `cur`
                C64 table[MAX_LOG_LEN]; // step factors, selected by hint_ctz(index)
                int log_max_iter, log_fft_len;
            };

            // Scalar pairwise product step. Reads one complex value from each of inout0,
            // inout1, in0 and in1 (real part at [0], imaginary part at [RI_DIFF]), merges
            // each pair with omega0, multiplies the merged values of the two operands, and
            // splits the product back into inout0 / inout1. All inputs are read before
            // anything is written.
            template <size_t RI_DIFF = 1, typename FloatTy>
            inline void dot_rfft(FloatTy *inout0, FloatTy *inout1, const FloatTy *in0, const FloatTy *in1, const std::complex<FloatTy> &omega0)
            {
                using Complex = std::complex<FloatTy>;
                const FloatTy w_re = omega0.real(), w_im = omega0.imag();

                // Merge a pair (a, b) into two complex values using omega0.
                const auto merge_pair = [w_re, w_im](FloatTy a_re, FloatTy a_im, FloatTy b_re, FloatTy b_im, Complex &first, Complex &second)
                {
                    const FloatTy sum_re = a_re + b_re, sum_im = a_im - b_im;
                    const FloatTy dif_re = a_re - b_re, dif_im = a_im + b_im;

                    const FloatTy rot_re = dif_im * w_re + dif_re * w_im;
                    const FloatTy rot_im = dif_im * w_im - dif_re * w_re;

                    first = Complex(sum_re + rot_re, sum_im + rot_im);
                    second = Complex(sum_re - rot_re, rot_im - sum_im);
                };

                Complex lhs_first, lhs_second, rhs_first, rhs_second;
                merge_pair(inout0[0], inout0[RI_DIFF], inout1[0], inout1[RI_DIFF], lhs_first, lhs_second);
                merge_pair(in0[0], in0[RI_DIFF], in1[0], in1[RI_DIFF], rhs_first, rhs_second);

                lhs_first *= rhs_first;
                lhs_second *= rhs_second;

                // Split the two products back into the pair layout.
                const FloatTy p_re = lhs_first.real(), p_im = lhs_first.imag();
                const FloatTy q_re = lhs_second.real(), q_im = lhs_second.imag();
                const FloatTy sum_re = p_re + q_re, sum_im = p_im - q_im;
                const FloatTy dif_re = p_re - q_re, dif_im = p_im + q_im;

                const FloatTy rot_re = dif_re * w_im - dif_im * w_re;
                const FloatTy rot_im = dif_re * w_re + dif_im * w_im;

                inout0[0] = sum_re + rot_re;
                inout0[RI_DIFF] = sum_im + rot_im;
                inout1[0] = sum_re - rot_re;
                inout1[RI_DIFF] = rot_im - sum_im;
            }
            // Four-lane counterpart of dot_rfft. Loads one C64X4 from each pointer, reverses
            // the lane order of the values taken from inout1 / in1, merges each pair with
            // omega0, multiplies the merged values, splits the products again, scales them
            // by `inv`, restores the lane order of the second result and stores both back
            // to inout0 / inout1.
            inline void dot_rfftX4(F64 *inout0, F64 *inout1, const F64 *in0, const F64 *in1, const C64X4 &omega0, const F64X4 &inv)
            {
                // Same arithmetic as the scalar combine2 in dot_rfft, using fused multiply-add.
                auto combine2 = [&omega0](C64X4 c0, C64X4 c1, C64X4 &out0, C64X4 &out1)
                {
                    auto tr0 = c0.real + c1.real, ti0 = c0.imag - c1.imag; // sum
                    auto tr1 = c0.real - c1.real, ti1 = c0.imag + c1.imag; // diff

                    c0.real = F64X4::fmadd(ti1, omega0.real, tr1 * omega0.imag);
                    c0.imag = F64X4::fmsub(ti1, omega0.imag, tr1 * omega0.real);

                    out0.real = tr0 + c0.real;
                    out0.imag = ti0 + c0.imag;

                    out1.real = tr0 - c0.real;
                    out1.imag = c0.imag - ti0;
                };
                C64X4 x0, x1;
                {
                    C64X4 x2, x3, x4, x5;
                    // First operand (in/out buffers); the second vector is lane-reversed.
                    x0.load(inout0), x1.load(inout1);
                    x1.real = x1.real.reverse();
                    x1.imag = x1.imag.reverse();

                    combine2(x0, x1, x2, x3);

                    // Second operand (read-only buffers), handled the same way.
                    x0.load(in0), x1.load(in1);
                    x1.real = x1.real.reverse();
                    x1.imag = x1.imag.reverse();

                    combine2(x0, x1, x4, x5);

                    x0 = x2.mul(x4);
                    x1 = x3.mul(x5);
                }
                {                                                          // separate2
                    auto tr0 = x0.real + x1.real, ti0 = x0.imag - x1.imag; // sum
                    auto tr1 = x0.real - x1.real, ti1 = x0.imag + x1.imag; // diff
                    auto r = F64X4::fmsub(tr1, omega0.imag, ti1 * omega0.real);
                    auto i = F64X4::fmadd(tr1, omega0.real, ti1 * omega0.imag);

                    // Scaling by `inv` is folded into the split step.
                    x0.real = (tr0 + r) * inv;
                    x0.imag = (ti0 + i) * inv;

                    x1.real = (tr0 - r) * inv;
                    x1.imag = (i - ti0) * inv;
                }
                // Undo the lane reversal before writing the second result back.
                x1.real = x1.real.reverse();
                x1.imag = x1.imag.reverse();
                x0.store(inout0), x1.store(inout1);
            }

            // Multiplies the transformed data in in_out by the transformed data in `in`,
            // writing the result to in_out. The first 16 doubles are handled with scalar
            // code, the rest with dot_rfftX4 on pairs of 8-double groups taken from both
            // ends of each [2 * len, 4 * len) range.
            // NOTE(review): presumably both buffers hold forward-FFT output in the split
            // real/imag, bit-reversed layout produced by FFTFixed::difRec — confirm.
            inline void real_conv_binrev4(Float64 in_out[], Float64 in[], size_t len_complex)
            {
                const F64X4 inv4(0.125 / len_complex);
                // Built once on first call with fixed parameters (22, 23).
                static BinRevTableC64X4 table(22, 23);
                // Element 0: real parts at [0], imaginary parts at [4].
                auto t0 = C64(in_out[0], in_out[4]), t1 = C64(in[0], in[4]);
                auto t2 = (t0.real() + t0.imag()) * (t1.real() + t1.imag());
                auto t3 = (t0.real() - t0.imag()) * (t1.real() - t1.imag());
                in_out[0] = (t2 + t3) * 4, in_out[4] = (t2 - t3) * 4;
                // Element 1 is a plain complex product, scaled by 8.
                t0 = C64(in_out[1], in_out[5]) * C64(in[1], in[5]) * 8.0;
                in_out[1] = t0.real(), in_out[5] = t0.imag();
                dot_rfft<4>(in_out + 2, in_out + 3, in + 2, in + 3, C64(0, -1));
                // NOTE(review): this early return skips the 0.125 / len_complex scaling that
                // is applied to the first 16 doubles below — confirm it is intended.
                if (len_complex <= 4)
                {
                    return;
                }
                constexpr Float64 COS_16_1 = 0.92387953251128675612818318939;
                constexpr Float64 SIN_16_1 = -0.38268343236508977172845998403;
                dot_rfft<4>(in_out + 8, in_out + 11, in + 8, in + 11, C64(COS_16_1, SIN_16_1));
                dot_rfft<4>(in_out + 9, in_out + 10, in + 9, in + 10, C64(SIN_16_1, -COS_16_1));
                // The scalar results above are unscaled; dot_rfftX4 below scales via inv4.
                for (size_t i = 0; i < 16; i++)
                {
                    in_out[i] *= (0.125 / len_complex);
                }
                for (size_t len = 8; len < len_complex; len *= 2)
                {
                    size_t begin = len * 2;
                    table.reset(len);
                    // it0/it2 walk forward from the start of the range, it1/it3 walk backward
                    // from its last 8-double group.
                    auto it0 = in_out + begin, it1 = it0 + begin - 8, it2 = in + begin, it3 = it2 + begin - 8;
                    for (; it0 < it1; it0 += 8, it1 -= 8, it2 += 8, it3 -= 8)
                    {
                        auto omega = table.iterate();
                        dot_rfftX4(it0, it1, it2, it3, omega, inv4);
                    }
                }
            }

            // Convolution of two real sequences for a compile-time length of 2^LOG_CONV
            // doubles per buffer. Both buffers are transformed with a complex FFT of size
            // 2^(LOG_CONV - 1), multiplied via real_conv_binrev4, and in_out1 is transformed
            // back. in2 is overwritten with its own transform.
            // TO_INT is forwarded to FFTFixed::iditRec as its TO_INT64 flag.
            template <int LOG_CONV = 17, bool TO_INT = true>
            inline void real_conv(F64 *in_out1, F64 *in2)
            {
                // A message-less static_assert needs C++17; spell the message out so the
                // file also builds as C++14.
                static_assert((LOG_CONV - 1) <= FFTFixed::LOG_MAX, "convolution length exceeds FFTFixed::LOG_MAX");
                FFTFixed::difRec<true>((double *)(in_out1), StaticInt<LOG_CONV - 1>{});
                FFTFixed::difRec<true>((double *)(in2), StaticInt<LOG_CONV - 1>{});
                real_conv_binrev4(in_out1, in2, size_t(1) << (LOG_CONV - 1));
                FFTFixed::iditRec<true, TO_INT>((double *)(in_out1), StaticInt<LOG_CONV - 1>{});
            }
            // Run-time dispatcher: selects the compile-time real_conv instantiation whose
            // LOG_CONV equals hint_log2(conv_len). Lengths 2^11 .. 2^19 are supported; any
            // other length trips the assertion (and does nothing when assertions are off).
            template <bool TO_INT = true>
            inline void real_conv(F64 *in_out1, F64 *in2, size_t conv_len)
            {
                const int level = hint_log2(conv_len);
                switch (level)
                {
                case 11:
                    real_conv<11, TO_INT>(in_out1, in2);
                    break;
                case 12:
                    real_conv<12, TO_INT>(in_out1, in2);
                    break;
                case 13:
                    real_conv<13, TO_INT>(in_out1, in2);
                    break;
                case 14:
                    real_conv<14, TO_INT>(in_out1, in2);
                    break;
                case 15:
                    real_conv<15, TO_INT>(in_out1, in2);
                    break;
                case 16:
                    real_conv<16, TO_INT>(in_out1, in2);
                    break;
                case 17:
                    real_conv<17, TO_INT>(in_out1, in2);
                    break;
                case 18:
                    real_conv<18, TO_INT>(in_out1, in2);
                    break;
                case 19:
                    real_conv<19, TO_INT>(in_out1, in2);
                    break;
                default:
                    assert(false && "Unsupported convolution length");
                }
            }
        }
    }
    // Parses the first `dig` characters of s as an unsigned decimal number.
    // No validation is performed: every character is assumed to be '0'..'9'.
    constexpr uint64_t stoui64(const char *s, size_t dig = 4)
    {
        uint64_t value = 0;
        const char *const stop = s + dig;
        while (s != stop)
        {
            value = value * 10 + (*s - '0');
            ++s;
        }
        return value;
    }

    // Converts exactly four decimal digit characters at s into their numeric value
    // (0..9999). No validation is performed.
    constexpr uint32_t stobase10000(const char *s)
    {
        return (s[0] - '0') * 1000 + (s[1] - '0') * 100 + (s[2] - '0') * 10 + (s[3] - '0');
    }
    // Converts exactly five decimal digit characters at s into their numeric value
    // (0..99999). No validation is performed.
    constexpr uint32_t stobase100000(const char *s)
    {
        return (s[0] - '0') * 10000 + (s[1] - '0') * 1000 + (s[2] - '0') * 100 + (s[3] - '0') * 10 + (s[4] - '0');
    }
    // Number of decimal digits per limb and the matching radix.
    static constexpr int DIGIT = 4;
    constexpr uint64_t BASE = 10000;

    // Splits the decimal string buffer[0 .. str_len) into base-10000 limbs, least
    // significant limb first, and stores them as doubles in float_ary.
    // Returns the number of limbs, ceil(str_len / DIGIT).
    inline size_t char_to_float64(const char *buffer, double *float_ary, size_t str_len)
    {
        const int64_t total = str_len;
        double *dst = float_ary;
        int64_t remaining = total;
        // Full four-digit groups, taken from the end of the string.
        for (; remaining - DIGIT > 0; remaining -= DIGIT)
        {
            const uint32_t limb = stobase10000(buffer + remaining - DIGIT);
            *dst++ = limb;
        }
        // Leading group of one to four digits.
        if (remaining > 0)
        {
            const uint32_t limb = stoui64(buffer, remaining);
            *dst = limb;
        }
        return (total + DIGIT - 1) / DIGIT;
    }

    // Lookup table that maps every value 0..9999 to its four ASCII digits packed into
    // a uint32_t. The most significant digit sits in the lowest byte, so the packed
    // value reads as text when stored on a little-endian machine.
    class ItoStrBase10000
    {
    private:
        uint32_t table[10000]{};

    public:
        // Packs the four decimal digits of num (taken modulo 10000) as ASCII characters.
        static constexpr uint32_t itosbase10000(uint32_t num)
        {
            uint32_t res = '0' * 0x1010101;
            res += (num / 1000 % 10) | ((num / 100 % 10) << 8) |
                   ((num / 10 % 10) << 16) | ((num % 10) << 24);
            return res;
        }
        // Fills the whole table; usable in constant expressions.
        constexpr ItoStrBase10000()
        {
            for (size_t i = 0; i < 10000; i++)
            {
                table[i] = itosbase10000(i);
            }
        }
        // Writes the four digits of num (which must be < 10000) to str[0..3].
        // memcpy is used instead of a uint32_t store so that str may have any alignment.
        void tostr(char *str, uint32_t num) const
        {
            std::memcpy(str, &table[num], sizeof(uint32_t));
        }
        // Returns the packed digits of num (which must be < 10000).
        uint32_t tostr(uint32_t num) const
        {
            return table[num];
        }
    };
    // Locates two decimal numbers in the NUL-terminated text s.
    // On return a / b point at the first digit of the first / second number and
    // len_a / len_b hold the lengths of the two digit runs. Anything between the
    // numbers that is not a digit is skipped. If a number is missing, its pointer is
    // left at the terminating NUL and its length is 0; the text is never read past
    // its terminator.
    void read_2num_str(const char *s, const char *&a, size_t &len_a, const char *&b, size_t &len_b)
    {
        // Plain range checks avoid passing a possibly negative char to isdigit().
        const auto is_digit = [](char c)
        { return c >= '0' && c <= '9'; };
        const auto skip_non_digits = [&is_digit](const char *p)
        {
            while (*p != '\0' && !is_digit(*p))
            {
                ++p;
            }
            return p;
        };
        const auto skip_digits = [&is_digit](const char *p)
        {
            while (is_digit(*p))
            {
                ++p;
            }
            return p;
        };

        a = skip_non_digits(s);
        const char *const end_a = skip_digits(a);
        len_a = static_cast<size_t>(end_a - a);

        b = skip_non_digits(end_a);
        len_b = static_cast<size_t>(skip_digits(b) - b);
    }
}

using namespace hint;
using namespace transform;
using namespace fft;

// Reads two decimal integers from stdin, multiplies them with the FFT-based
// real_conv() and writes the decimal product to stdout (no trailing newline).
// The `out` buffer is used twice: first for the raw input text, then — once both
// operands have been converted to limbs — for the output digits, which are written
// from the end of the buffer towards the front.
// The limb arrays are function-local statics and rely on their initial zero fill
// for padding, so this routine is meant to be called once per process.
void test_big_mul()
{
    constexpr size_t STR_LEN = 2000008;
    constexpr int LOG_LEN = 18;
    constexpr size_t MAX_FFT_LEN = size_t(1) << LOG_LEN;
    constexpr size_t FLOAT_MAX_LEN = MAX_FFT_LEN * 2;
    // Smallest convolution length real_conv() dispatches (its first case is 11).
    constexpr size_t MIN_CONV_LEN = size_t(1) << 11;
    static constexpr ItoStrBase10000 transfer;
    static AlignAry<char, STR_LEN> out;
    static AlignAry<Float64, FLOAT_MAX_LEN> ary1;
    static AlignAry<Float64, FLOAT_MAX_LEN> ary2;
    uint32_t *ary = out.template cast_ptr<uint32_t>();
    size_t len_a = 0, len_b = 0;
    // Leave room for, and explicitly write, the terminator the parser depends on.
    const size_t read_len = fread(out.data(), 1, STR_LEN - 1, stdin);
    out[read_len] = '\0';
    // str_fill(out.data(), (STR_LEN - 8) / 2);
    const char *a, *b;
    read_2num_str(out.data(), a, len_a, b, len_b);

    /*
        struct stat sb;
        int fd = fileno(stdin);
        fstat(fd, &sb);
        p = (char *)mmap(0, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
        madvise(p, sb.st_size, MADV_SEQUENTIAL);
    */
    if (len_a == 0 || len_b == 0)
    {
        // Fewer than two numbers on stdin: nothing to multiply.
        return;
    }
    if (len_a == 1 && a[0] == '0')
    {
        puts("0");
        return;
    }
    if (len_b == 1 && b[0] == '0')
    {
        puts("0");
        return;
    }
    size_t len2 = char_to_float64(b, ary2.data(), len_b);
    size_t len1 = char_to_float64(a, ary1.data(), len_a);
    size_t conv_len = len1 + len2 - 1, len = int_ceil2(conv_len);
    // Short operands would give a length real_conv() has no case for; round up to
    // the smallest supported one. The extra elements are the arrays' zero padding.
    if (len < MIN_CONV_LEN)
    {
        len = MIN_CONV_LEN;
    }

    real_conv(ary1.data(), ary2.data(), len);

    // real_conv() defaults to TO_INT = true, so ary1 is read back as 64-bit integers.
    auto i64_ary1 = reinterpret_cast<uint64_t *>(ary1.data());
    uint64_t carry = 0;
    size_t pos = STR_LEN / 4 - 1;
    // Carry propagation, emitting four digits per limb from the end of `out`.
    for (size_t i = 0; i < conv_len; i++)
    {
        carry += i64_ary1[i];
        uint64_t num = carry % BASE;
        carry /= BASE;
        ary[pos] = transfer.tostr(num);
        pos--;
    }
    ary[pos] = transfer.tostr(carry);
    pos *= 4;
    // Strip leading zeros but always keep the last digit, so an all-zero result
    // (e.g. an operand written as "00") prints "0" instead of running off the buffer.
    while (pos < STR_LEN - 1 && out[pos] == '0')
    {
        pos++;
    } // 0.8ms
    fwrite(out.data() + pos, 1, STR_LEN - pos, stdout);
}

// Program entry point: runs the single multiplication task and reports success.
int main()
{
    test_big_mul();
    return 0;
}

Compilation | N/A | N/A | Compile Error | Score: N/A


Judge Duck Online | 评测鸭在线
Server Time: 2025-10-27 07:38:33 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠