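/*
 * 提交记录 20152 (Submission record #20152)
 *
 * 用户 (User): TSKY
 * 题目 (Problem): 1004. 【模板题】高精度乘法 (template problem: big-integer multiplication)
 * 状态 (Status): Accepted
 * 得分 (Score): 100
 * 用时 (Time): 17.059 ms
 * 内存 (Memory): 20328 KB
 * 语言 (Language): C++14
 * 代码长度 (Code length): 41.88 KB
 * 提交时间 (Submitted at): 2023-09-11 22:56:31
 * 评测时间 (Judged at): 2023-09-11 22:56:34
 */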
#include <algorithm>
#include <array>
#include <cmath>
#include <complex>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <future>
#include <iostream>
#include <vector>

#include <immintrin.h>
#pragma GCC target("fma")
#pragma GCC target("avx2")
#pragma GCC optimize("inline")
// Build-time configuration switches (0 = off, 1 = on).
#define TABLE_ENABLE 1  // whether to use the lookup table
#define MULTITHREAD 0   // whether to use multiple threads
#define TABLE_PRELOAD 1 // whether to fill the lookup table eagerly in its constructor

// Multithreaded mode forces the lookup table on. #undef first so that setting
// TABLE_ENABLE to 0 above does not turn this into an incompatible macro redefinition.
#if MULTITHREAD == 1
#undef TABLE_ENABLE
#define TABLE_ENABLE 1
#endif

namespace hint
{
    // Short aliases for fixed-width integer types used throughout the library.
    using UINT_8 = uint8_t;
    using UINT_16 = uint16_t;
    using UINT_32 = uint32_t;
    using UINT_64 = uint64_t;
    using INT_32 = int32_t;
    using INT_64 = int64_t;
    using ULONG = unsigned long;
    using LONG = long;
    // Floating-point type used by the FFT code, and the matching complex type.
    using HintFloat = double;
    using Complex = std::complex<HintFloat>;
    // Bit widths of 8-, 16- and 32-bit integers.
    constexpr UINT_64 HINT_CHAR_BIT = 8;
    constexpr UINT_64 HINT_SHORT_BIT = 16;
    constexpr UINT_64 HINT_INT_BIT = 32;
    // Naming scheme: _0XFF = all-ones maximum, _0X10 = maximum + 1 (2^bits),
    // _0X80 = only the top bit set, _0X7F = signed maximum, _0X01 = one.
    constexpr UINT_64 HINT_INT8_0XFF = UINT8_MAX;
    constexpr UINT_64 HINT_INT8_0X10 = (UINT8_MAX + 1ull);
    constexpr UINT_64 HINT_INT16_0XFF = UINT16_MAX;
    constexpr UINT_64 HINT_INT16_0X10 = (UINT16_MAX + 1ull);
    constexpr UINT_64 HINT_INT32_0XFF = UINT32_MAX;
    constexpr UINT_64 HINT_INT32_0X01 = 1;
    constexpr UINT_64 HINT_INT32_0X80 = 0X80000000ull;
    constexpr UINT_64 HINT_INT32_0X7F = INT32_MAX;
    constexpr UINT_64 HINT_INT32_0X10 = (UINT32_MAX + 1ull);
    constexpr UINT_64 HINT_INT64_0X80 = INT64_MIN; // converts to 0x8000000000000000
    constexpr UINT_64 HINT_INT64_0X7F = INT64_MAX;
    constexpr UINT_64 HINT_INT64_0XFF = UINT64_MAX;
    constexpr double HINT_PI = 3.1415926535897932384626433832795;
    constexpr double HINT_2PI = HINT_PI * 2;
    constexpr double HINT_HSQ_ROOT2 = 0.70710678118654752440084436210485; // sqrt(2) / 2
    // Moduli and generators for number-theoretic transforms.
    // 3221225473 = 3 * 2^30 + 1 and 3489660929 = 13 * 2^28 + 1.
    // NOTE(review): the roots are assumed to be primitive roots of their moduli — verify if changed.
    constexpr UINT_64 NTT_MOD1 = 3221225473;
    constexpr UINT_64 NTT_ROOT1 = 5;
    constexpr UINT_64 NTT_MOD2 = 3489660929;
    constexpr UINT_64 NTT_ROOT2 = 3;
    // Presumably bounded by the 2^28 factor of NTT_MOD2 - 1 — confirm against NTT code.
    constexpr size_t NTT_MAX_LEN = 1ull << 28;
    // Fixed-size array of LEN elements of T whose storage is aligned to 128 bytes,
    // so its data can be used with aligned SIMD loads/stores.
    template <typename T, size_t LEN>
    class AlignAry
    {
    public:
        constexpr AlignAry() {}

        // Unchecked element access.
        constexpr T &operator[](size_t index) { return ary[index]; }
        constexpr const T &operator[](size_t index) const { return ary[index]; }

        // Pointer to the first element (plain array-to-pointer decay).
        T *data() { return ary; }
        const T *data() const { return ary; }

        // The same storage viewed as a pointer to Ty.
        template <typename Ty>
        Ty *cast_ptr() { return reinterpret_cast<Ty *>(ary); }
        template <typename Ty>
        const Ty *cast_ptr() const { return reinterpret_cast<const Ty *>(ary); }

    private:
        alignas(128) T ary[LEN];
    };
    // Largest power of two that is <= n (n itself when n is already a power of two).
    // Returns 0 for n == 0, since no power of two is <= 0; intended for
    // non-negative integers.
    template <typename T>
    constexpr T int_floor2(T n)
    {
        if (n == 0)
        {
            return 0;
        }
        constexpr int bits = sizeof(n) * 8;
        // Smear the highest set bit into every lower position...
        for (int i = 1; i < bits; i *= 2)
        {
            n |= (n >> i);
        }
        // ...so n == 2^(k+1) - 1; halving and adding one leaves exactly 2^k.
        return (n >> 1) + 1;
    }
    // Smallest power of two that is >= n (n itself when already a power of two).
    // Inherited from the bit trick: n == 0 yields 0, and inputs above the largest
    // representable power of two wrap around to 0.
    template <typename T>
    constexpr T int_ceil2(T n)
    {
        T m = n - 1;
        for (int shift = 1; shift < int(sizeof(T) * 8); shift <<= 1)
        {
            m |= m >> shift;
        }
        return m + 1;
    }
    // Floor of log2(n), i.e. the index of the highest set bit of n.
    // Returns -1 when n <= 0.
    template <typename T>
    constexpr int hint_log2(T n)
    {
        if (n < T(1))
        {
            return -1;
        }
        constexpr int bits = sizeof(n) * 8;
        // Binary search: invariant (n >> l) != 0 and (n >> r) == 0.
        // Testing n >> mid (rather than comparing against T(1) << mid) never forms
        // a value with the sign bit set, which made the result one too large for
        // signed T once n >= 2^(bits-2).
        int l = -1, r = bits;
        while ((l + 1) != r)
        {
            int mid = (l + r) / 2;
            if ((n >> mid) == 0)
            {
                r = mid;
            }
            else
            {
                l = mid;
            }
        }
        return l;
    }
    // Binary exponentiation: m raised to the n-th power, using only T's
    // operator* and construction from 1 (the multiplicative identity).
    template <typename T>
    constexpr T qpow(T m, UINT_64 n)
    {
        T acc = 1;
        for (T base = m; n != 0; n >>= 1)
        {
            if (n & 1)
            {
                acc = acc * base;
            }
            base = base * base;
        }
        return acc;
    }
    // (a / b, a % b) as a (quotient, remainder) pair, using T's native
    // division semantics.
    template <typename T>
    constexpr std::pair<T, T> div_mod(T a, T b)
    {
        return std::pair<T, T>(a / b, a % b);
    }
#if MULTITHREAD == 1
    // Hardware thread count as reported by std::thread::hardware_concurrency()
    // (the standard allows 0 when it cannot be determined).
    const UINT_32 hint_threads = std::thread::hardware_concurrency();
    // NOTE(review): hint_log2 already returns floor(log2) as an int, so std::ceil
    // here is a no-op and this holds floor(log2(hint_threads)), not the ceiling.
    // If hint_threads == 0, hint_log2 returns -1 and converting -1.0 to UINT_32 is
    // out of range. Confirm the intended rounding.
    const UINT_32 log2_threads = std::ceil(hint_log2(hint_threads));
    // Shared atomic counter; presumably the number of threads in use — confirm against callers.
    std::atomic<UINT_32> cur_ths;
#endif
    // Template array copy: copies len elements of type T from source to target.
    // Does nothing when len == 0 or when both pointers are equal.
    // Throws the C string "Ary too long\n" if len >= INT64_MAX.
    // target must not point inside [source, source + len) (std::copy requirement).
    template <typename T>
    void ary_copy(T *target, const T *source, size_t len)
    {
        if (len == 0 || target == source)
        {
            return;
        }
        if (len >= INT64_MAX)
        {
            throw("Ary too long\n");
        }
        // std::copy takes (first, last, d_first) and counts elements, not bytes.
        std::copy(source, source + len, target);
    }
    // Namespace for the FFT and FFT-like transforms.
    namespace hint_transform
    {
        namespace hint_fft
        {
            // Point on the unit circle with argument theta (radians): e^{i*theta}.
            static Complex unit_root(HintFloat theta)
            {
                const Complex root = std::polar<HintFloat>(1.0, theta);
                return root;
            }
            // The n-th of m equally spaced points on the unit circle: e^{2*pi*i*n/m}.
            static Complex unit_root(size_t m, size_t n)
            {
                const HintFloat angle = (HINT_2PI * n) / m;
                return unit_root(angle);
            }
            struct RIPtr
            {
                HintFloat *real = nullptr;
                HintFloat *imag = nullptr;
                constexpr RIPtr() {}
                constexpr RIPtr(HintFloat *in_real, HintFloat *in_imag)
                    : real(in_real), imag(in_imag) {}
                template <typename DataTy>
                void load(DataTy &r, DataTy &i) const
                {
                    r.load(real);
                    i.load(imag);
                }
                template <typename DataTy>
                void save(const DataTy &r, const DataTy &i)
                {
                    r.save(real);
                    i.save(imag);
                }
                void load(HintFloat &r, HintFloat &i) const
                {
                    r = *real;
                    i = *imag;
                }
                void save(HintFloat r, HintFloat i)
                {
                    *real = r;
                    *imag = i;
                }
                constexpr RIPtr operator+(size_t offset) const
                {
                    return RIPtr(real + offset, imag + offset);
                }
                constexpr RIPtr operator-(size_t offset) const
                {
                    return RIPtr(real - offset, imag - offset);
                }
            };
            // Real part of (ar + i*ai) * (br + i*bi), i.e. ar*br - ai*bi.
            template <typename T>
            constexpr T complex_mul_real(const T &ar, const T &ai, const T &br, const T &bi)
            {
                return (ar * br) - (ai * bi);
            }
            // Imaginary part of (ar + i*ai) * (br + i*bi), i.e. ar*bi + ai*br.
            template <typename T>
            constexpr T complex_mul_imag(const T &ar, const T &ai, const T &br, const T &bi)
            {
                return (ar * bi) + (ai * br);
            }
            // Precomputed twiddle-factor table. For each level shift in [2, MAX_SHIFT]
            // (len = 1 << shift) it stores omega^n and omega^(3n) for n in [0, len/4),
            // where omega = e^{-2*pi*i/len}. Values are split: the imaginary part of an
            // entry lives RI_DIS slots after its real part.
            template <UINT_32 MAX_SHIFT>
            class ComplexTableS
            {
            public:
                enum
                {
                    TABLE_LEN = (size_t(1) << MAX_SHIFT),
                    RI_DIS = TABLE_LEN / 2 // offset from a real part to its imaginary part
                };

            private:
                // Level `shift` occupies slots [2^(shift-2), 2^(shift-1)) of each table.
                AlignAry<HintFloat, TABLE_LEN> table1; // omega^n
                AlignAry<HintFloat, TABLE_LEN> table3; // omega^(3n)
                INT_32 max_log_size = 2;
                INT_32 cur_log_size = 2; // levels <= cur_log_size are already filled

                static constexpr size_t FAC = 1;

                ComplexTableS(const ComplexTableS &) = delete;
                ComplexTableS &operator=(const ComplexTableS &) = delete;

            public:
                // Prepares a table of the unit roots obtained by dividing the circle
                // into up to 1 << MAX_SHIFT parts. With TABLE_PRELOAD == 1 every level is
                // computed here; otherwise only level 2 is available until expand() is called.
                // NOTE(review): the imaginary part of the level-2 entry is never written and
                // relies on zero-initialization (true for static-storage objects like TABLE).
                constexpr ComplexTableS()
                {
                    max_log_size = std::max<size_t>(MAX_SHIFT, 1);
                    table1[0] = table1[1] = 1;
                    table3[0] = table3[1] = 1;
#if TABLE_PRELOAD == 1
                    expand(max_log_size);
#endif
                }
                // Ensure every level up to `shift` (clamped to MAX_SHIFT) is filled.
                constexpr void expand(INT_32 shift)
                {
                    expand_topdown(shift);
                }
                // Computes the requested level directly, then derives each smaller level
                // by taking every other entry of the level above it.
                constexpr void expand_topdown(INT_32 shift)
                {
                    shift = std::min(shift, max_log_size);
                    if (shift <= cur_log_size)
                    {
                        return;
                    }
                    size_t len = 1ull << shift, vec_size = len * FAC / 4;
                    RIPtr ptr1 = get_omega_begin(shift);
                    RIPtr ptr3 = get_omega3_begin(shift);
                    ptr1.save(1.0, 0.0);
                    ptr3.save(1.0, 0.0);
                    // Seed omega^pos for power-of-two pos directly from cos/sin.
                    for (size_t pos = 1; pos < vec_size / 2; pos *= 2)
                    {
                        HintFloat theta = -HINT_2PI * pos / len;
                        HintFloat real = std::cos(theta), imag = std::sin(theta);
                        (ptr1 + pos).save(real, imag);
                    }
                    // For each pos < vec_size/2: omega^pos = omega^sub_pos * omega^(pos - sub_pos),
                    // where sub_pos clears the lowest set bit, so both factors are already known.
                    for (size_t pos = 1; pos < vec_size / 2; pos++)
                    {
                        size_t sub_pos = pos & (pos - 1);
                        HintFloat real1, imag1;
                        HintFloat real2, imag2;
                        HintFloat real3, imag3;
                        (ptr1 + sub_pos).load(real1, imag1), (ptr1 + pos - sub_pos).load(real2, imag2);
                        real3 = complex_mul_real(real1, imag1, real2, imag2);
                        imag3 = complex_mul_imag(real1, imag1, real2, imag2);
                        (ptr1 + pos).save(real3, imag3);
                    }
                    // Upper half by symmetry: omega^(vec_size - pos) = -i * conj(omega^pos).
                    for (size_t pos = 1; pos < vec_size / 2; pos++)
                    {
                        HintFloat real, imag;
                        (ptr1 + pos).load(real, imag);
                        (ptr1 + vec_size - pos).save(-imag, -real);
                    }
                    // omega^(3*pos) via get_omega; the mirrored entry is
                    // omega^(3*(vec_size - pos)) = i * conj(omega^(3*pos)).
                    for (size_t pos = 1; pos < vec_size / 2; pos++)
                    {
                        Complex tmp = get_omega(shift, pos * 3);
                        (ptr3 + pos).save(tmp.real(), tmp.imag());
                        (ptr3 + vec_size - pos).save(tmp.imag(), tmp.real());
                    }
                    // Midpoint entries: omega^(len/8) = e^{-i*pi/4}, omega^(3*len/8) = e^{-3i*pi/4}.
                    Complex tmp = std::conj(unit_root(8, 1));
                    (ptr1 + vec_size / 2).save(tmp.real(), tmp.imag());
                    tmp = std::conj(unit_root(8, 3));
                    (ptr3 + vec_size / 2).save(tmp.real(), tmp.imag());
                    // Smaller levels: omega_{len/2}^n == omega_len^(2n), so decimate by two.
                    for (INT_32 log = shift - 1; log > cur_log_size; log--)
                    {
                        len = 1ull << log, vec_size = len / 4;
                        ptr1 = get_omega_begin(log);
                        ptr3 = get_omega3_begin(log);
                        RIPtr src1 = get_omega_begin(log + 1);
                        RIPtr src3 = get_omega3_begin(log + 1);
                        for (size_t pos = 0; pos < vec_size; pos++)
                        {
                            HintFloat r1, r3, i1, i3;
                            (src1 + pos * 2).load(r1, i1);
                            (src3 + pos * 2).load(r3, i3);
                            (ptr1 + pos).save(r1, i1);
                            (ptr3 + pos).save(r3, i3);
                        }
                    }
                    cur_log_size = std::max(cur_log_size, shift);
                }
                // omega^n for the circle divided into 1 << shift parts, n in [0, len/2].
                // Entries n > len/4 are rebuilt as -conj(omega^(len/2 - n)); n == len/4 is exactly -i.
                constexpr Complex get_omega(UINT_32 shift, size_t n)
                {
                    size_t vec_size = (size_t(1) << shift) / 4;
                    RIPtr omg_ptr = get_omega_begin(shift);
                    HintFloat real, imag;
                    if (n < vec_size)
                    {
                        (omg_ptr + n).load(real, imag);
                        return Complex(real, imag);
                    }
                    else if (n > vec_size)
                    {
                        (omg_ptr + vec_size * 2 - n).load(real, imag);
                        return Complex(-real, imag);
                    }
                    else
                    {
                        return Complex(0, -1);
                    }
                }
                // Split pointer to the first omega^n entry of level `shift`.
                constexpr RIPtr get_omega_begin(UINT_32 shift)
                {
                    HintFloat *ptr = table1.data() + (1 << (shift - 2));
                    RIPtr ri_ptr(ptr, ptr + TABLE_LEN / 2);
                    return ri_ptr;
                }
                // Split pointer to the first omega^(3n) entry of level `shift`.
                constexpr RIPtr get_omega3_begin(UINT_32 shift)
                {
                    HintFloat *ptr = table3.data() + (1 << (shift - 2));
                    RIPtr ri_ptr(ptr, ptr + TABLE_LEN / 2);
                    return ri_ptr;
                }
                // Read-only pointer to the real parts of level SHIFT (imag parts at +RI_DIS).
                template <UINT_32 SHIFT>
                constexpr const HintFloat *get_omega_iter() const
                {
                    return table1.data() + (1 << (SHIFT - 2));
                }
                // Read-only pointer to the real parts of the cubed roots of level SHIFT.
                template <UINT_32 SHIFT>
                constexpr const HintFloat *get_omega3_iter() const
                {
                    return table3.data() + (1 << (SHIFT - 2));
                }
            };

            // log2 of the largest transform length served by the shared twiddle table.
            static constexpr size_t lut_max_rank = 19;
            using FFTable = ComplexTableS<lut_max_rank>;
            // Shared table instance used by the FFT kernels; with TABLE_PRELOAD == 1 its
            // constructor fills every level.
            static FFTable TABLE{};

            // Radix-2 butterfly in place: (sum, diff) <- (sum + diff, sum - diff).
            template <typename T>
            inline void fft_2point(T &sum, T &diff)
            {
                const T a = sum, b = diff;
                sum = a + b;
                diff = a - b;
            }
            // Scalar four-lane bundle of HintFloat values with element-wise arithmetic
            // (a portable counterpart of ComputeEndAVX).
            struct ComputeEnd4
            {
                HintFloat f0, f1, f2, f3;
                ComputeEnd4() = default;
                ComputeEnd4(const ComputeEnd4 &in) = default;
                ComputeEnd4(HintFloat fin0, HintFloat fin1, HintFloat fin2, HintFloat fin3)
                    : f0(fin0), f1(fin1), f2(fin2), f3(fin3) {}
                // Loads ptr[0..3].
                ComputeEnd4(const HintFloat *ptr)
                {
                    load(ptr);
                }
                void load(const HintFloat *ptr)
                {
                    f0 = ptr[0];
                    f1 = ptr[1];
                    f2 = ptr[2];
                    f3 = ptr[3];
                }
                // Stores the lanes to ptr[0..3]. const (like ComputeEnd2/ComputeEndAVX::save)
                // so it can be called through a const reference, as RIPtr::save does.
                void save(HintFloat *const ptr) const
                {
                    ptr[0] = f0;
                    ptr[1] = f1;
                    ptr[2] = f2;
                    ptr[3] = f3;
                }
                ComputeEnd4 operator+(const ComputeEnd4 &in) const
                {
                    return ComputeEnd4(
                        f0 + in.f0, f1 + in.f1, f2 + in.f2, f3 + in.f3);
                }
                ComputeEnd4 operator-(const ComputeEnd4 &in) const
                {
                    return ComputeEnd4(
                        f0 - in.f0, f1 - in.f1, f2 - in.f2, f3 - in.f3);
                }
                ComputeEnd4 operator*(const ComputeEnd4 &in) const
                {
                    return ComputeEnd4(
                        f0 * in.f0, f1 * in.f1, f2 * in.f2, f3 * in.f3);
                }
                ComputeEnd4 operator-() const
                {
                    return ComputeEnd4(-f0, -f1, -f2, -f3);
                }
                // Debug helper: prints the four lanes on one line.
                void print() const
                {
                    std::cout << f0 << " " << f1 << " " << f2 << " " << f3 << "\n";
                }
            };
            struct ComputeEndAVX
            {
                __m256d data;
                ComputeEndAVX() = default;
                ComputeEndAVX(const ComputeEndAVX &in) = default;
                ComputeEndAVX(__m256d in) : data(in) {}
                ComputeEndAVX(HintFloat fin0, HintFloat fin1, HintFloat fin2, HintFloat fin3)
                {
                    data = _mm256_set_pd(fin3, fin2, fin1, fin0);
                }
                ComputeEndAVX(HintFloat fin)
                {
                    data = _mm256_set1_pd(fin);
                }
                ComputeEndAVX(const HintFloat *ptr)
                {
                    load(ptr);
                }
                void operator=(const HintFloat *ptr)
                {
                    load(ptr);
                }
                void loadu(const HintFloat *ptr)
                {
                    data = _mm256_loadu_pd(ptr);
                }
                void load(const HintFloat *ptr)
                {
                    data = _mm256_load_pd(ptr);
                }
                void save(HintFloat *const ptr) const
                {
                    _mm256_store_pd(ptr, data);
                }
                void saveu(HintFloat *const ptr) const
                {
                    _mm256_storeu_pd(ptr, data);
                }
                ComputeEndAVX operator+(const ComputeEndAVX &in) const
                {
                    return _mm256_add_pd(data, in.data);
                }
                ComputeEndAVX operator-(const ComputeEndAVX &in) const
                {
                    return _mm256_sub_pd(data, in.data);
                }
                ComputeEndAVX operator*(const ComputeEndAVX &in) const
                {
                    return _mm256_mul_pd(data, in.data);
                }
                ComputeEndAVX operator-() const
                {
                    return _mm256_sub_pd(_mm256_setzero_pd(), data);
                }
                static ComputeEndAVX fmadd(const ComputeEndAVX &a, const ComputeEndAVX &b, const ComputeEndAVX &c)
                {
                    return _mm256_fmadd_pd(a.data, b.data, c.data);
                }
                static ComputeEndAVX fmsub(const ComputeEndAVX &a, const ComputeEndAVX &b, const ComputeEndAVX &c)
                {
                    return _mm256_fmsub_pd(a.data, b.data, c.data);
                }
                void print() const
                {
                    HintFloat ary[4];
                    saveu(ary);
                    for (int i = 0; i < 4; ++i)
                        printf("%lf\n", ary[i]);
                }
            };
            struct ComputeEnd2
            {
                HintFloat f0, f1;
                ComputeEnd2() = default;
                ComputeEnd2(const ComputeEnd2 &in) = default;
                ComputeEnd2(HintFloat fin0, HintFloat fin1)
                    : f0(fin0), f1(fin1) {}
                ComputeEnd2(const HintFloat *ptr)
                {
                    load(ptr);
                }
                void operator=(const HintFloat *ptr)
                {
                    load(ptr);
                }
                void load(const HintFloat *ptr)
                {
                    f0 = ptr[0];
                    f1 = ptr[1];
                }
                void save(HintFloat *const ptr) const
                {
                    ptr[0] = f0;
                    ptr[1] = f1;
                }
                ComputeEnd2 operator+(const ComputeEnd2 &in) const
                {
                    return ComputeEnd2(
                        f0 + in.f0, f1 + in.f1);
                }
                ComputeEnd2 operator-(const ComputeEnd2 &in) const
                {
                    return ComputeEnd2(
                        f0 - in.f0, f1 - in.f1);
                }
                ComputeEnd2 operator*(const ComputeEnd2 &in) const
                {
                    return ComputeEnd2(
                        f0 * in.f0, f1 * in.f1);
                }
                ComputeEnd2 operator-() const
                {
                    return ComputeEnd2(-f0, -f1);
                }
                void print() const
                {
                    std::cout << f0 << " " << f1 << "\n";
                }
            };
            // Read-only and mutable cursors into split real/imag HintFloat buffers.
            using ConstIter = const HintFloat *;
            using Iter = HintFloat *;
            // Compile-time-recursive split-radix FFT of length LEN on split complex data:
            // real parts at fft_input[0..LEN), imaginary parts RI_DIS slots later.
            // Both directions use twiddles omega^n and omega^(3n) from TABLE's level
            // log_len (omega = e^{-2*pi*i/LEN} per the table construction).
            // Each butterfly step handles `offset` (4) lanes at once via aligned AVX
            // loads/stores, so fft_input must be 32-byte aligned.
            // NOTE(review): the input/output index orderings are not visible in this
            // section — confirm against callers.
            template <size_t LEN, size_t RI_DIS>
            struct FFT
            {
                enum
                {
                    FT_DIS = RI_DIS,             // real -> imag distance in the data buffer
                    LUT_DIS = FFTable::RI_DIS    // real -> imag distance in the twiddle table
                };
                static constexpr size_t log_len = hint_log2(LEN);
                static constexpr size_t half_len = LEN / 2, quarter_len = LEN / 4;
                using DataTy = ComputeEndAVX;
                static constexpr size_t offset = sizeof(DataTy) / sizeof(HintFloat); // lanes per vector
                using half_fft = FFT<half_len, RI_DIS>;
                using quarter_fft = FFT<quarter_len, RI_DIS>;
                // One vector of the DIT combine step: multiplies the third quarter by omega^n
                // and the fourth by omega^(3n), then merges them with the first two quarters.
                static void fft_split_radix_dit_butterfly(ConstIter omega, ConstIter omega_cube, Iter fft_input)
                {
                    DataTy r0, r1, r2, r3, i0, i1, i2, i3;
                    DataTy tr0, tr1, tr2, tr3, ti0, ti1, ti2, ti3;

                    r0 = fft_input;
                    r1 = fft_input + quarter_len;
                    r2 = fft_input + quarter_len * 2;
                    r3 = fft_input + quarter_len * 3;
                    i0 = fft_input + FT_DIS;
                    i1 = fft_input + FT_DIS + quarter_len;
                    i2 = fft_input + FT_DIS + quarter_len * 2;
                    i3 = fft_input + FT_DIS + quarter_len * 3;

                    tr0 = omega;
                    tr1 = omega_cube;
                    ti0 = omega + LUT_DIS;
                    ti1 = omega_cube + LUT_DIS;
                    // tr2 = complex_mul_real(r2, i2, tr0, ti0);
                    // ti2 = complex_mul_imag(r2, i2, tr0, ti0);
                    // tr3 = complex_mul_real(r3, i3, tr1, ti1);
                    // ti3 = complex_mul_imag(r3, i3, tr1, ti1);
                    tr2 = DataTy::fmsub(r2, tr0, i2 * ti0);
                    ti2 = DataTy::fmadd(r2, ti0, i2 * tr0);
                    tr3 = DataTy::fmsub(r3, tr1, i3 * ti1);
                    ti3 = DataTy::fmadd(r3, ti1, i3 * tr1);

                    // Sum of the twisted quarters, and their difference times -i.
                    r2 = tr2 + tr3;
                    i2 = ti2 + ti3;
                    r3 = ti2 - ti3;
                    i3 = tr3 - tr2;

                    tr0 = r0 + r2;
                    ti0 = i0 + i2;
                    tr2 = r0 - r2;
                    ti2 = i0 - i2;

                    tr1 = r1 + r3;
                    ti1 = i1 + i3;
                    tr3 = r1 - r3;
                    ti3 = i1 - i3;

                    tr0.save(fft_input);
                    tr1.save(fft_input + quarter_len);
                    tr2.save(fft_input + quarter_len * 2);
                    tr3.save(fft_input + quarter_len * 3);
                    ti0.save(fft_input + FT_DIS);
                    ti1.save(fft_input + FT_DIS + quarter_len);
                    ti2.save(fft_input + FT_DIS + quarter_len * 2);
                    ti3.save(fft_input + FT_DIS + quarter_len * 3);
                }
                // One vector of the DIF split step: forms sums/differences of the quarters
                // first, then multiplies the two difference outputs by omega^n and omega^(3n).
                static void fft_split_radix_dif_butterfly(ConstIter omega, ConstIter omega_cube, Iter fft_input)
                {
                    DataTy r0, r1, r2, r3, i0, i1, i2, i3;
                    DataTy tr0, tr1, tr2, tr3, ti0, ti1, ti2, ti3;

                    r0 = fft_input;
                    r1 = fft_input + quarter_len;
                    r2 = fft_input + quarter_len * 2;
                    r3 = fft_input + quarter_len * 3;
                    i0 = fft_input + FT_DIS;
                    i1 = fft_input + FT_DIS + quarter_len;
                    i2 = fft_input + FT_DIS + quarter_len * 2;
                    i3 = fft_input + FT_DIS + quarter_len * 3;

                    tr0 = r0 + r2;
                    ti0 = i0 + i2;
                    tr2 = r0 - r2;
                    ti2 = i0 - i2;

                    // (x1 + x3), and (x1 - x3) multiplied by -i.
                    tr1 = r1 + r3;
                    ti1 = i1 + i3;
                    tr3 = i1 - i3;
                    ti3 = r3 - r1;

                    r2 = tr2 + tr3;
                    i2 = ti2 + ti3;
                    r3 = tr2 - tr3;
                    i3 = ti2 - ti3;

                    r0 = omega;
                    r1 = omega_cube;
                    i0 = omega + LUT_DIS;
                    i1 = omega_cube + LUT_DIS;
                    // tr2 = complex_mul_real(r2, i2, r0, i0);
                    // ti2 = complex_mul_imag(r2, i2, r0, i0);
                    // tr3 = complex_mul_real(r3, i3, r1, i1);
                    // ti3 = complex_mul_imag(r3, i3, r1, i1);
                    tr2 = DataTy::fmsub(r2, r0, i2 * i0);
                    ti2 = DataTy::fmadd(r2, i0, i2 * r0);
                    tr3 = DataTy::fmsub(r3, r1, i3 * i1);
                    ti3 = DataTy::fmadd(r3, i1, i3 * r1);

                    tr0.save(fft_input);
                    tr1.save(fft_input + quarter_len);
                    tr2.save(fft_input + quarter_len * 2);
                    tr3.save(fft_input + quarter_len * 3);
                    ti0.save(fft_input + FT_DIS);
                    ti1.save(fft_input + FT_DIS + quarter_len);
                    ti2.save(fft_input + FT_DIS + quarter_len * 2);
                    ti3.save(fft_input + FT_DIS + quarter_len * 3);
                }
                // DIT: transform the half and the two quarters recursively, then combine.
                static void fft_split_radix_dit(Iter fft_input)
                {
                    quarter_fft::fft_split_radix_dit(fft_input + half_len + quarter_len);
                    quarter_fft::fft_split_radix_dit(fft_input + half_len);
                    half_fft::fft_split_radix_dit(fft_input);
                    // Twiddle pointers for this level, looked up once per instantiation.
                    static ConstIter omg_ptr = TABLE.get_omega_iter<log_len>();
                    static ConstIter omg3_ptr = TABLE.get_omega3_iter<log_len>();
                    for (size_t i = 0; i < quarter_len; i += offset)
                    {
                        fft_split_radix_dit_butterfly(omg_ptr + i, omg3_ptr + i, fft_input + i);
                    }
                }
                // DIF: split with butterflies first, then transform the half and two quarters.
                static void fft_split_radix_dif(Iter fft_input)
                {
                    static ConstIter omg_ptr = TABLE.get_omega_iter<log_len>();
                    static ConstIter omg3_ptr = TABLE.get_omega3_iter<log_len>();
                    for (size_t i = 0; i < quarter_len; i += offset)
                    {
                        fft_split_radix_dif_butterfly(omg_ptr + i, omg3_ptr + i, fft_input + i);
                    }
                    half_fft::fft_split_radix_dif(fft_input);
                    quarter_fft::fft_split_radix_dif(fft_input + half_len);
                    quarter_fft::fft_split_radix_dif(fft_input + half_len + quarter_len);
                }
            };
            // Length-0 and length-1 transforms are the identity: nothing to do.
            template <size_t RI_DIS>
            struct FFT<0, RI_DIS>
            {
                static void fft_split_radix_dit(Iter) {}
                static void fft_split_radix_dif(Iter) {}
            };
            template <size_t RI_DIS>
            struct FFT<1, RI_DIS>
            {
                static void fft_split_radix_dif(Iter) {}
                static void fft_split_radix_dit(Iter) {}
            };
            template <size_t RI_DIS>
            struct FFT<2, RI_DIS>
            {
                static void fft_split_radix_dit(HintFloat *fft_input)
                {
                    fft_2point(fft_input[0], fft_input[1]);
                    fft_2point(fft_input[RI_DIS], fft_input[RI_DIS + 1]);
                }
                static void fft_split_radix_dif(HintFloat *fft_input)
                {
                    fft_2point(fft_input[0], fft_input[1]);
                    fft_2point(fft_input[RI_DIS], fft_input[RI_DIS + 1]);
                }
            };
            // Hand-unrolled length-4 transforms on split complex data
            // (real parts at fft_input[0..3], imaginary parts RI_DIS slots later).
            template <size_t RI_DIS>
            struct FFT<4, RI_DIS>
            {
                // Forward DFT (kernel e^{-2*pi*i*k*n/4}) of input in bit-reversed order,
                // producing natural order: butterflies on (0,1) and (2,3), then a combine
                // step where the (2,3) difference is multiplied by -i.
                static void fft_dit_4point(Iter fft_input)
                {
                    HintFloat r0 = fft_input[0];
                    HintFloat r1 = fft_input[1];
                    HintFloat r2 = fft_input[2];
                    HintFloat r3 = fft_input[3];
                    HintFloat i0 = fft_input[RI_DIS];
                    HintFloat i1 = fft_input[RI_DIS + 1];
                    HintFloat i2 = fft_input[RI_DIS + 2];
                    HintFloat i3 = fft_input[RI_DIS + 3];

                    HintFloat tr0 = r0 + r1;
                    HintFloat ti0 = i0 + i1;
                    HintFloat tr1 = r0 - r1;
                    HintFloat ti1 = i0 - i1;
                    HintFloat tr2 = r2 + r3;
                    HintFloat ti2 = i2 + i3;
                    // (x2 - x3) * (-i)
                    HintFloat tr3 = i2 - i3;
                    HintFloat ti3 = r3 - r2;

                    fft_input[0] = tr0 + tr2;
                    fft_input[1] = tr1 + tr3;
                    fft_input[2] = tr0 - tr2;
                    fft_input[3] = tr1 - tr3;
                    fft_input[RI_DIS] = ti0 + ti2;
                    fft_input[RI_DIS + 1] = ti1 + ti3;
                    fft_input[RI_DIS + 2] = ti0 - ti2;
                    fft_input[RI_DIS + 3] = ti1 - ti3;
                }
                // Forward DFT of input in natural order, producing bit-reversed order:
                // butterflies on (0,2) and (1,3) with the (1,3) difference multiplied by -i,
                // then final butterflies.
                static void fft_dif_4point(Iter fft_input)
                {
                    HintFloat r0 = fft_input[0];
                    HintFloat r1 = fft_input[1];
                    HintFloat r2 = fft_input[2];
                    HintFloat r3 = fft_input[3];
                    HintFloat i0 = fft_input[RI_DIS];
                    HintFloat i1 = fft_input[RI_DIS + 1];
                    HintFloat i2 = fft_input[RI_DIS + 2];
                    HintFloat i3 = fft_input[RI_DIS + 3];

                    HintFloat tr0 = r0 + r2;
                    HintFloat ti0 = i0 + i2;
                    HintFloat tr2 = r0 - r2;
                    HintFloat ti2 = i0 - i2;
                    HintFloat tr1 = r1 + r3;
                    HintFloat ti1 = i1 + i3;
                    // (x1 - x3) * (-i)
                    HintFloat tr3 = i1 - i3;
                    HintFloat ti3 = r3 - r1;

                    fft_input[0] = tr0 + tr1;
                    fft_input[1] = tr0 - tr1;
                    fft_input[2] = tr2 + tr3;
                    fft_input[3] = tr2 - tr3;
                    fft_input[RI_DIS] = ti0 + ti1;
                    fft_input[RI_DIS + 1] = ti0 - ti1;
                    fft_input[RI_DIS + 2] = ti2 + ti3;
                    fft_input[RI_DIS + 3] = ti2 - ti3;
                }
                static void fft_split_radix_dit(HintFloat *fft_input)
                {
                    fft_dit_4point(fft_input);
                }
                static void fft_split_radix_dif(HintFloat *fft_input)
                {
                    fft_dif_4point(fft_input);
                }
            };
            // Fixed-size specialization: one 8-point complex transform stored in split
            // layout - real parts at fft_input[0..7], imaginary parts at
            // fft_input[RI_DIS..RI_DIS + 7]. Both kernels compute the unnormalized DFT
            // X[k] = sum_n x[n] * e^{-2*pi*i*n*k/8} (the twiddles applied below are
            // e^{-i*pi/4}, -i and e^{-3i*pi/4}; no 1/8 scaling). They differ only in which
            // side of the transform is in bit-reversed order.
            // NOTE(review): `Iter` is declared outside this view; presumably HintFloat* -
            // confirm.
            template <size_t RI_DIS>
            struct FFT<8, RI_DIS>
            {
                // Decimation-in-time, in place. Expects the input in bit-reversed order
                // (slots 0..7 hold x0, x4, x2, x6, x1, x5, x3, x7) and leaves X[0..7] in
                // natural order.
                static void fft_dit_8point(Iter fft_input)
                {
                    HintFloat r0 = fft_input[0];
                    HintFloat r1 = fft_input[1];
                    HintFloat r2 = fft_input[2];
                    HintFloat r3 = fft_input[3];
                    HintFloat r4 = fft_input[4];
                    HintFloat r5 = fft_input[5];
                    HintFloat r6 = fft_input[6];
                    HintFloat r7 = fft_input[7];
                    HintFloat i0 = fft_input[RI_DIS];
                    HintFloat i1 = fft_input[RI_DIS + 1];
                    HintFloat i2 = fft_input[RI_DIS + 2];
                    HintFloat i3 = fft_input[RI_DIS + 3];
                    HintFloat i4 = fft_input[RI_DIS + 4];
                    HintFloat i5 = fft_input[RI_DIS + 5];
                    HintFloat i6 = fft_input[RI_DIS + 6];
                    HintFloat i7 = fft_input[RI_DIS + 7];
                    // Stage 1 (4 x radix-2): butterflies on slot pairs (0,1), (2,3), (4,5),
                    // (6,7). The (2,3) and (6,7) differences are immediately multiplied by -i,
                    // the twiddle the next stage needs: -i * (a + bi) = b - ai.
                    HintFloat tr0 = r0 + r1, ti0 = i0 + i1; // 0-1
                    HintFloat tr1 = r0 - r1, ti1 = i0 - i1;
                    HintFloat tr2 = r2 + r3, ti2 = i2 + i3; // 2-3
                    HintFloat tr3 = i2 - i3, ti3 = r3 - r2;
                    HintFloat tr4 = r4 + r5, ti4 = i4 + i5; // 4-5
                    HintFloat tr5 = r4 - r5, ti5 = i4 - i5;
                    HintFloat tr6 = r6 + r7, ti6 = i6 + i7; // 6-7
                    HintFloat tr7 = i6 - i7, ti7 = r7 - r6;
                    // Stage 2 (2 x radix-4 combine): pairs (0,2), (1,3) complete the 4-point
                    // DFT E of the even samples in slots 0-3; pairs (4,6), (5,7) complete the
                    // 4-point DFT O of the odd samples in slots 4-7, with the (4,6) difference
                    // multiplied by -i (= W8^2) so slot 6 already holds W8^2 * O[2].
                    r0 = tr0 + tr2, i0 = ti0 + ti2; // 0-2
                    r2 = tr0 - tr2, i2 = ti0 - ti2;
                    r1 = tr1 + tr3, i1 = ti1 + ti3; // 1-3
                    r3 = tr1 - tr3, i3 = ti1 - ti3;
                    r4 = tr4 + tr6, i4 = ti4 + ti6; // 4-6
                    r6 = ti4 - ti6, i6 = tr6 - tr4;
                    r5 = tr5 + tr7, i5 = ti5 + ti7; // 5-7
                    r7 = tr5 - tr7, i7 = ti5 - ti7;
                    // Remaining twiddles: slot 5 times W8^1 = (1 - i) / sqrt(2),
                    // slot 7 times W8^3 = -(1 + i) / sqrt(2).
                    static constexpr HintFloat cos_1_8 = 0.70710678118654752440084436210485;
                    static constexpr HintFloat cos_3_8 = -cos_1_8;
                    tr5 = cos_1_8 * (i5 + r5), ti5 = cos_1_8 * (i5 - r5);
                    tr7 = cos_3_8 * (r7 - i7), ti7 = cos_3_8 * (r7 + i7);
                    // Stage 3 (radix-2 across halves): X[k] = E[k] + W8^k * O[k] and
                    // X[k + 4] = E[k] - W8^k * O[k] for k = 0..3, written in natural order.
                    fft_input[0] = r0 + r4;
                    fft_input[1] = r1 + tr5;
                    fft_input[2] = r2 + r6;
                    fft_input[3] = r3 + tr7;
                    fft_input[4] = r0 - r4;
                    fft_input[5] = r1 - tr5;
                    fft_input[6] = r2 - r6;
                    fft_input[7] = r3 - tr7;
                    fft_input[RI_DIS] = i0 + i4;
                    fft_input[RI_DIS + 1] = i1 + ti5;
                    fft_input[RI_DIS + 2] = i2 + i6;
                    fft_input[RI_DIS + 3] = i3 + ti7;
                    fft_input[RI_DIS + 4] = i0 - i4;
                    fft_input[RI_DIS + 5] = i1 - ti5;
                    fft_input[RI_DIS + 6] = i2 - i6;
                    fft_input[RI_DIS + 7] = i3 - ti7;
                }
                // Decimation-in-frequency, in place. Takes x[0..7] in natural order and
                // leaves the result bit-reversed (slots 0..7 hold X0, X4, X2, X6, X1, X5,
                // X3, X7).
                static void fft_dif_8point(Iter fft_input)
                {
                    HintFloat r0 = fft_input[0];
                    HintFloat r1 = fft_input[1];
                    HintFloat r2 = fft_input[2];
                    HintFloat r3 = fft_input[3];
                    HintFloat r4 = fft_input[4];
                    HintFloat r5 = fft_input[5];
                    HintFloat r6 = fft_input[6];
                    HintFloat r7 = fft_input[7];
                    HintFloat i0 = fft_input[RI_DIS];
                    HintFloat i1 = fft_input[RI_DIS + 1];
                    HintFloat i2 = fft_input[RI_DIS + 2];
                    HintFloat i3 = fft_input[RI_DIS + 3];
                    HintFloat i4 = fft_input[RI_DIS + 4];
                    HintFloat i5 = fft_input[RI_DIS + 5];
                    HintFloat i6 = fft_input[RI_DIS + 6];
                    HintFloat i7 = fft_input[RI_DIS + 7];
                    // Stage 1 (radix-2 across halves): sums x[k] + x[k + 4] feed slots 0-3;
                    // differences x[k] - x[k + 4] are multiplied by W8^k - k = 2 via the -i
                    // swap in tr6/ti6, k = 1 and k = 3 by the cos constants below.
                    HintFloat tr0 = r0 + r4, ti0 = i0 + i4; // 0-4
                    HintFloat tr4 = r0 - r4, ti4 = i0 - i4;
                    HintFloat tr1 = r1 + r5, ti1 = i1 + i5; // 1-5
                    HintFloat tr5 = r1 - r5, ti5 = i1 - i5;
                    HintFloat tr2 = r2 + r6, ti2 = i2 + i6; // 2-6
                    HintFloat tr6 = i2 - i6, ti6 = r6 - r2;
                    HintFloat tr3 = r3 + r7, ti3 = i3 + i7; // 3-7
                    HintFloat tr7 = r3 - r7, ti7 = i3 - i7;
                    // W8^1 = (1 - i) / sqrt(2) applied to slot 5, W8^3 = -(1 + i) / sqrt(2)
                    // applied to slot 7.
                    static constexpr HintFloat cos_1_8 = 0.70710678118654752440084436210485;
                    static constexpr HintFloat cos_3_8 = -cos_1_8;
                    r5 = cos_1_8 * (ti5 + tr5), i5 = cos_1_8 * (ti5 - tr5);
                    r7 = cos_3_8 * (tr7 - ti7), i7 = cos_3_8 * (tr7 + ti7);
                    // Stage 2 (2 x radix-4 split): pairs (0,2), (1,3), (4,6), (5,7); the
                    // (1,3) and (5,7) differences are multiplied by -i.
                    r0 = tr0 + tr2, i0 = ti0 + ti2; // 0-2
                    r2 = tr0 - tr2, i2 = ti0 - ti2;
                    r1 = tr1 + tr3, i1 = ti1 + ti3; // 1-3
                    r3 = ti1 - ti3, i3 = tr3 - tr1;
                    r4 = tr4 + tr6, i4 = ti4 + ti6; // 4-6
                    r6 = tr4 - tr6, i6 = ti4 - ti6;
                    tr5 = r5 + r7, ti5 = i5 + i7; // 5-7
                    tr7 = i5 - i7, ti7 = r7 - r5;
                    // Stage 3 (4 x radix-2): adjacent slot pairs produce the outputs,
                    // which end up in bit-reversed order.
                    fft_input[0] = r0 + r1;
                    fft_input[1] = r0 - r1;
                    fft_input[2] = r2 + r3;
                    fft_input[3] = r2 - r3;
                    fft_input[4] = r4 + tr5;
                    fft_input[5] = r4 - tr5;
                    fft_input[6] = r6 + tr7;
                    fft_input[7] = r6 - tr7;
                    fft_input[RI_DIS] = i0 + i1;
                    fft_input[RI_DIS + 1] = i0 - i1;
                    fft_input[RI_DIS + 2] = i2 + i3;
                    fft_input[RI_DIS + 3] = i2 - i3;
                    fft_input[RI_DIS + 4] = i4 + ti5;
                    fft_input[RI_DIS + 5] = i4 - ti5;
                    fft_input[RI_DIS + 6] = i6 + ti7;
                    fft_input[RI_DIS + 7] = i6 - ti7;
                }
                // Uniform entry point: for 8 points the whole DIT transform is the fixed kernel.
                static void fft_split_radix_dit(Iter fft_input)
                {
                    fft_dit_8point(fft_input);
                }
                // Uniform entry point: for 8 points the whole DIF transform is the fixed kernel.
                static void fft_split_radix_dif(Iter fft_input)
                {
                    fft_dif_8point(fft_input);
                }
            };
            // Size dispatcher for the DIT transform. Walks the compile-time sequence
            // LEN, LEN/2, LEN/4, ... and, for the first value L with L <= fft_len, calls
            // TABLE.expand_topdown(hint_log2(L)) (presumably growing the twiddle table to
            // that rank) and then FFT<L, L>::fft_split_radix_dit(input). Only L points are
            // transformed, so fft_len is expected to be one of the sizes in that sequence.
            template <size_t LEN = 1>
            void fft_split_radix_dit_template_alt(Iter input, size_t fft_len)
            {
                if (fft_len >= LEN)
                {
                    TABLE.expand_topdown(hint_log2(LEN));
                    FFT<LEN, LEN>::fft_split_radix_dit(input);
                }
                else
                {
                    fft_split_radix_dit_template_alt<LEN / 2>(input, fft_len);
                }
            }
            // Recursion terminator: LEN has been halved down to 0 (only reached when
            // fft_len < 1), so there is nothing to transform.
            template <>
            void fft_split_radix_dit_template_alt<0>(Iter input, size_t fft_len)
            {
            }

            // Size dispatcher for the DIF transform. Walks the compile-time sequence
            // LEN, LEN/2, LEN/4, ... and, for the first value L with L <= fft_len, calls
            // TABLE.expand_topdown(hint_log2(L)) (presumably growing the twiddle table to
            // that rank) and then FFT<L, L>::fft_split_radix_dif(input). Only L points are
            // transformed, so fft_len is expected to be one of the sizes in that sequence.
            template <size_t LEN = 1>
            void fft_split_radix_dif_template_alt(Iter input, size_t fft_len)
            {
                if (fft_len >= LEN)
                {
                    TABLE.expand_topdown(hint_log2(LEN));
                    FFT<LEN, LEN>::fft_split_radix_dif(input);
                }
                else
                {
                    fft_split_radix_dif_template_alt<LEN / 2>(input, fft_len);
                }
            }
            // Recursion terminator: LEN has been halved down to 0 (only reached when
            // fft_len < 1), so there is nothing to transform.
            template <>
            void fft_split_radix_dif_template_alt<0>(Iter input, size_t fft_len)
            {
            }

            // auto fft_split_radix_dit = fft_split_radix_dit_template_alt<size_t(1) << lut_max_rank>;
            // auto fft_split_radix_dif = fft_split_radix_dif_template_alt<size_t(1) << lut_max_rank>;
        }
    }
    // Zero-fills the `len` elements starting at `ptr` by clearing their bytes.
    // Only meaningful for types whose all-zero byte pattern is a valid value
    // (integers, IEEE-754 0.0, ...).
    template <typename T>
    inline void ary_clr(T *ptr, size_t len)
    {
        const size_t byte_count = sizeof(T) * len;
        std::memset(ptr, 0, byte_count);
    }
    // Formats `input` as exactly `digits` decimal characters: left-padded with '0'
    // when shorter, and keeping only the lowest `digits` digits when longer.
    inline std::string ui64to_string(UINT_64 input, UINT_8 digits)
    {
        std::string text(digits, '0');
        // Fill from the least significant (rightmost) position towards the left.
        for (auto it = text.rbegin(); it != text.rend(); ++it)
        {
            *it = static_cast<char>('0' + input % 10);
            input /= 10;
        }
        return text;
    }
    // Parses exactly `dig` characters starting at `s` as an unsigned decimal number.
    // No validation is done: every character contributes (c - '0').
    // Usable in constant expressions.
    constexpr UINT_64 stoui64(const char *s, size_t dig = 4)
    {
        UINT_64 value = 0;
        for (const char *cur = s, *last = s + dig; cur != last; ++cur)
        {
            value = value * 10 + (*cur - '0');
        }
        return value;
    }

    // Parses the four decimal characters s[0..3] into a value in [0, 10000).
    // Horner form of s[0]*1000 + s[1]*100 + s[2]*10 + s[3]; the ASCII bias of all
    // four characters ('0' * 1111) is removed with a single subtraction at the end.
    constexpr UINT_32 stobase10000(const char *s)
    {
        return ((s[0] * 10 + s[1]) * 10 + s[2]) * 10 + s[3] - '0' * 1111;
    }
    // Number of decimal digits packed into one limb. stobase10000 and
    // ItoStrBase10000 are hard-wired to 4-character groups, so this must stay 4.
    static constexpr INT_64 DIGIT = 4;
    // Limb radix. qpow is defined outside this view - presumably an integer power,
    // giving 10^DIGIT = 10000; confirm.
    static constexpr INT_32 BASE = qpow(10, DIGIT);
    inline size_t char_to_real(const char *buffer1, size_t len1, HintFloat *ary, size_t fft_len)
    {
        hint::INT_64 len = len1, pos = len, i = 0;
        len = (len + DIGIT - 1) / DIGIT;
        while (pos - DIGIT > 0)
        {
            hint::UINT_32 tmp = stobase10000(buffer1 + pos - DIGIT);
            ary[i] = tmp;
            i++;
            pos -= DIGIT;
        }
        if (pos > 0)
        {
            hint::UINT_32 tmp = stoui64(buffer1, pos);
            ary[i] = tmp;
            i++;
        }
        ary_clr(ary + i, fft_len - i);
        return len;
    }
    // Lookup table mapping every value in [0, 10000) to its 4-character,
    // zero-padded decimal text packed into one uint32_t. The first (thousands)
    // digit is in the lowest byte, so on a little-endian machine the word's
    // in-memory bytes read as the digits in normal order (42 -> "0042").
    class ItoStrBase10000
    {
    private:
        uint32_t table[10000]{};

    public:
        // Packs the four lowest decimal digits of num as ASCII characters:
        // thousands digit in bits 0-7, hundreds in 8-15, tens in 16-23, units in 24-31.
        // Digits above the thousands place are dropped.
        static constexpr uint32_t itosbase10000(uint32_t num)
        {
            uint32_t res = '0' * 0x1010101;
            res += (num / 1000 % 10) | ((num / 100 % 10) << 8) |
                   ((num / 10 % 10) << 16) | ((num % 10) << 24);
            return res;
        }
        // Fills the whole table; constexpr so a static instance can be built at compile time.
        constexpr ItoStrBase10000()
        {
            for (size_t i = 0; i < 10000; i++)
            {
                table[i] = itosbase10000(i);
            }
        }
        // Writes the 4 characters for num (must be < 10000) to str[0..3]; no
        // terminator is written. Uses memcpy rather than storing through a
        // uint32_t*: str may be unaligned, and writing char storage through a
        // uint32_t lvalue violates strict aliasing. memcpy compiles to the same store.
        void tostr(char *str, uint32_t num) const
        {
            std::memcpy(str, &table[num], sizeof(table[num]));
        }
        // Returns the packed 4-character word for num (must be < 10000).
        uint32_t tostr(uint32_t num) const
        {
            return table[num];
        }
    };
}
using namespace std;
using namespace hint;
using namespace hint_transform;
using namespace hint_fft;

// Reads two non-negative decimal integers from stdin and writes their decimal
// product to stdout, without a trailing newline. The first operand is the
// leading run of characters >= '0'; the second starts at the next digit.
//
// Pipeline:
//   - The operands are split into base-BASE limbs: the first into
//     fft_ary[0, fft_len), the second into fft_ary[fft_len, 2 * fft_len).
//   - A DIF transform runs, then a pointwise squaring step, then a DIT transform.
//   - The product limbs are read (rounded) from fft_ary[fft_len, ...) and
//     carry-propagated.
// NOTE(review): this appears to rely on (a + ib)^2 = a^2 - b^2 + 2iab. Storing
// -conj(z^2) / (2 * fft_len) lets the second forward-direction transform act as
// a scaled inverse whose imaginary half is a * b. Confirm against the FFT and
// complex_mul_* definitions.
// The `out` buffer is reused: it first holds the raw input, then the product
// digits are written into its tail, one 4-character limb at a time.
int main()
{
    constexpr size_t STR_LEN = 2000008;
    constexpr size_t fft_len = 1 << lut_max_rank;
    static constexpr ItoStrBase10000 transfer;
    static AlignAry<char, STR_LEN> out;
    static AlignAry<HintFloat, fft_len * 2> fft_arr;
    auto *fft_ary = fft_arr.data();
    uint32_t *ary = out.cast_ptr<uint32_t>();
    size_t len_a = 0, len_b = 0;
    // `out` has static storage, so the bytes after the input read as '\0',
    // which terminates the digit scans below.
    fread(out.data(), 1, STR_LEN, stdin);
    char *p = out.data();

    while (p[len_a] >= '0')
    {
        len_a++;
    }
    if (len_a == 1 && p[0] == '0')
    {
        puts("0");
        return 0;
    }
    // Skip the separator(s) up to the first digit of the second operand.
    // isdigit requires a value representable as unsigned char.
    char *b = p + len_a;
    while (!isdigit(static_cast<unsigned char>(*b)))
    {
        b++;
    }
    while (b[len_b] >= '0')
    {
        len_b++;
    }
    if (len_b == 1 && b[0] == '0')
    {
        puts("0");
        return 0;
    }
    size_t len1 = char_to_real(p, len_a, fft_ary, fft_len);
    size_t len2 = char_to_real(b, len_b, fft_ary + fft_len, fft_len);
    TABLE.expand(hint_log2(fft_len));
    FFT<fft_len, fft_len>::fft_split_radix_dif(fft_ary);
    // Pointwise step, 4 lanes at a time: z = (real, imag) is squared and stored as
    // (-Re(z^2), Im(z^2)) * 0.5 / fft_len (complex_mul_real/imag presumably return
    // the real/imaginary part of the product).
    const ComputeEndAVX invx4(0.5 / fft_len);
    for (size_t i = 0; i < fft_len; i += 4)
    {
        ComputeEndAVX real = fft_ary + i, imag = fft_ary + i + fft_len;
        ((-complex_mul_real(real, imag, real, imag)) * invx4).save(fft_ary + i);
        (complex_mul_imag(real, imag, real, imag) * invx4).save(fft_ary + i + fft_len);
    }
    FFT<fft_len, fft_len>::fft_split_radix_dit(fft_ary);
    // Round each limb, propagate carries, and emit 4 characters per limb from the
    // end of `out` backwards (least significant limb last in the buffer).
    UINT_64 carry = 0;
    size_t pos = STR_LEN / 4 - 1;
    for (size_t i = 0; i < len1 + len2 - 1; i++)
    {
        carry += UINT_64(fft_ary[i + fft_len] + 0.5);
        UINT_64 num = 0;
        std::tie(carry, num) = div_mod<UINT_64>(carry, BASE);
        ary[pos] = transfer.tostr(num);
        pos--;
    }
    ary[pos] = transfer.tostr(carry);
    pos *= 4;
    // Strip leading zeros but always keep the last character. An all-zero product
    // (e.g. an operand written as "00") then prints "0" instead of scanning past
    // the end of `out` and making STR_LEN - pos wrap around.
    while (pos + 1 < STR_LEN && out[pos] == '0')
    {
        pos++;
    }
    fwrite(out.data() + pos, 1, STR_LEN - pos, stdout);
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 17.059 ms | 19 MB + 872 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2025-09-13 22:04:27 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠