提交记录 20297


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1004. 【模板题】高精度乘法 Accepted 100 13.461 ms 20324 KB C++14 44.72 KB
提交时间 评测时间
2023-10-11 21:59:47 2023-10-11 21:59:49
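// If you find this code useful, please give the repository a star.
// At the very least, if you copy this code, please keep the links below. Thank you!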
// https://github.com/With-Sky/HintFFT
// https://github.com/With-Sky/FFT-Benchmark
// https://github.com/With-Sky/HyperInt-mini
// https://space.bilibili.com/511540153
#include <tuple>
#include <iostream>
#include <complex>
#include <cstring>
#include <immintrin.h>
#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP

#pragma GCC target("fma")
#pragma GCC target("avx2")

namespace hint_simd
{
    // Alias for the built-in array type T[LEN], so an array type can be spelled as a
    // template-id (e.g. Ary<double, 4>) inside declarations and the macros below.
    template <typename T, size_t LEN>
    using Ary = T[LEN];
// ALIGN_ARY(type, len, x): declares an array `x` of `len` elements of `type` aligned to
// 32 bytes, the alignment required by aligned 256-bit AVX loads/stores (_mm256_load_pd).
#define ALIGN_ARY(type, len, x) alignas(32) Ary<type, len>(x)
// STATIC_ALIGN_ARY(type, len, x): same as ALIGN_ARY, but with static storage duration.
#define STATIC_ALIGN_ARY(type, len, x) alignas(32) static Ary<type, len>(x)
    // Fixed-size array of LEN elements of T whose storage is aligned to 64 bytes
    // (typically one cache line). Any element offset that is a multiple of 32 bytes
    // can therefore be used with aligned 256-bit SIMD loads/stores.
    // The default constructor leaves the elements uninitialized.
    template <typename T, size_t LEN>
    class AlignAry
    {
    private:
        alignas(64) T ary[LEN];

    public:
        constexpr AlignAry() {}

        // Unchecked element access.
        constexpr T &operator[](size_t index)
        {
            return *(ary + index);
        }
        constexpr const T &operator[](size_t index) const
        {
            return *(ary + index);
        }

        // Pointer to the first element.
        T *data()
        {
            return ary;
        }
        const T *data() const
        {
            return ary;
        }

        // View the same storage as an array of Ty (e.g. doubles as complex<double>).
        template <typename Ty>
        Ty *cast_ptr()
        {
            return reinterpret_cast<Ty *>(data());
        }
        template <typename Ty>
        const Ty *cast_ptr() const
        {
            return reinterpret_cast<const Ty *>(data());
        }
    };

    // SIMD backend: AVX/AVX2 with 256-bit registers (one __m256d holds four doubles,
    // i.e. two complex<double> values).
    // Scalar floating-point type used by the FFT code.
    using HintFloat = double;
    // Scalar complex type; Complex2 below loads/stores two of these at a time.
    using Complex = std::complex<HintFloat>;

    // Two complex<double> values processed in parallel in one 256-bit AVX register.
    // Lane layout of `data`: [re0, im0, re1, im1] (element 0 in the low 128 bits).
    // The pointer constructors and store() use aligned 256-bit accesses, so the
    // pointer must be 32-byte aligned.
    struct Complex2
    {
        __m256d data;
        // All lanes zero.
        Complex2()
        {
            data = _mm256_setzero_pd();
        }
        // Broadcast `input` into all four lanes (real and imaginary parts alike).
        Complex2(double input)
        {
            data = _mm256_set1_pd(input);
        }
        // Wrap a raw register.
        Complex2(__m256d input)
        {
            data = input;
        }
        Complex2(const Complex2 &input) = default;
        // Load 4 contiguous doubles (two complex numbers); ptr must be 32-byte aligned.
        Complex2(double const *ptr)
        {
            data = _mm256_load_pd(ptr);
        }
        // Load ptr[0] and ptr[1]; ptr must be 32-byte aligned.
        Complex2(const Complex *ptr)
        {
            data = _mm256_load_pd(reinterpret_cast<const double *>(ptr));
        }
        // Pack two arbitrary complex values: a goes to the low half, b to the high half.
        // std::complex<double> is not guaranteed to be 16-byte aligned, so each value is
        // read with an unaligned 128-bit load. Dereferencing a __m128d pointer requires
        // 16-byte alignment and may compile to a faulting aligned move.
        Complex2(const Complex &a, const Complex &b)
        {
            const __m128d lo = _mm_loadu_pd(reinterpret_cast<const double *>(&a));
            const __m128d hi = _mm_loadu_pd(reinterpret_cast<const double *>(&b));
            data = _mm256_set_m128d(hi, lo);
        }
        // Reset all lanes to zero.
        void clr()
        {
            data = _mm256_setzero_pd();
        }
        // Store both complex values to a[0], a[1]; a must be 32-byte aligned.
        void store(Complex *a) const
        {
            _mm256_store_pd(reinterpret_cast<double *>(a), data);
        }
        // Debug print of both complex values.
        void print() const
        {
            double ary[4];
            _mm256_storeu_pd(ary, data);
            printf("(%lf,%lf) (%lf,%lf)\n", ary[0], ary[1], ary[2], ary[3]);
        }
        // Flip the sign of lane k (0..3) for every set bit k of M.
        // A plain local is used because the mask is a compile-time constant. A
        // function-local static would add a thread-safe init guard check on every call.
        template <int M>
        Complex2 element_mask_neg() const
        {
            const __m256d neg_mask = _mm256_castsi256_pd(
                _mm256_set_epi64x((M & 8ull) << 60, (M & 4ull) << 61, (M & 2ull) << 62, (M & 1ull) << 63));
            return _mm256_xor_pd(data, neg_mask);
        }
        // In-lane permute (_mm256_permute_pd) with immediate M.
        template <int M>
        Complex2 element_permute() const
        {
            return _mm256_permute_pd(data, M);
        }
        // Cross-lane 64-bit permute (_mm256_permute4x64_pd) with immediate M.
        template <int M>
        Complex2 element_permute64() const
        {
            return _mm256_permute4x64_pd(data, M);
        }
        // [re0, re0, re1, re1]
        Complex2 all_real() const
        {
            return _mm256_unpacklo_pd(data, data);
        }
        // [im0, im0, im1, im1]
        Complex2 all_imag() const
        {
            return _mm256_unpackhi_pd(data, data);
        }
        // Swap real and imaginary parts of each complex: [im0, re0, im1, re1].
        Complex2 swap() const
        {
            return _mm256_shuffle_pd(data, data, 5);
        }
        // Multiply both values by -i: (re, im) -> (im, -re).
        // addsub(0, z) gives (-re, im) per complex; swapping then yields (im, -re).
        Complex2 mul_neg_i() const
        {
            return Complex2(_mm256_addsub_pd(_mm256_setzero_pd(), data)).swap();
        }
        // Complex conjugate: mask 0b1010 negates lanes 1 and 3 (the imaginary parts).
        Complex2 conj() const
        {
            return element_mask_neg<10>();
        }
        // Lane-wise product (not a complex multiplication).
        Complex2 linear_mul(Complex2 input) const
        {
            return _mm256_mul_pd(data, input.data);
        }
        // z^2 = ((re + im)(re - im), 2 * re * im) for each complex.
        Complex2 square() const
        {
            const __m256d rr = all_real().data;
            const __m256d ir = swap().data;
            const __m256d add = _mm256_add_pd(rr, ir);
            const __m256d sub = _mm256_sub_pd(rr, ir);
            return _mm256_mul_pd(add, _mm256_blend_pd(sub, data, 0b1010));
        }
        Complex2 operator+(const Complex2 &input) const
        {
            return _mm256_add_pd(data, input.data);
        }
        Complex2 operator-(const Complex2 &input) const
        {
            return _mm256_sub_pd(data, input.data);
        }
        // Complex product of each pair: (ar*br - ai*bi, ar*bi + ai*br), using one FMA addsub.
        Complex2 operator*(const Complex2 &input) const
        {
            auto imag = _mm256_mul_pd(all_imag().data, input.swap().data);
            return _mm256_fmaddsub_pd(all_real().data, input.data, imag);
        }
    };
}
#endif

#include <array>
#include <vector>
#include <iostream>
#include <future>
#include <ctime>
#include <cstring>
// #include "stopwatch.hpp"

#define TABLE_ENABLE 1
#define MULTITHREAD 0   // 多线程 0 means no, 1 means yes
#define TABLE_PRELOAD 0 // 是否提前初始化表 0 means no, 1 means yes
using namespace hint_simd;
namespace hint
{
    // Fixed-width integer aliases used throughout the library.
    using UINT_8 = uint8_t;
    using UINT_16 = uint16_t;
    using UINT_32 = uint32_t;
    using UINT_64 = uint64_t;
    using INT_32 = int32_t;
    using INT_64 = int64_t;
    // Platform-dependent width (32 bits on LLP64 targets such as Windows, 64 on LP64).
    using ULONG = unsigned long;
    using LONG = long;
    // Floating-point and complex types used by the FFT code.
    using HintFloat = double;
    using Complex = std::complex<HintFloat>;

    // Bit widths of char / short / 32-bit int.
    constexpr UINT_64 HINT_CHAR_BIT = 8;
    constexpr UINT_64 HINT_SHORT_BIT = 16;
    constexpr UINT_64 HINT_INT_BIT = 32;
    // Naming scheme for the k-bit masks below:
    //   HINT_INTk_0XFF = all ones (2^k - 1), HINT_INTk_0X10 = 2^k,
    //   HINT_INTk_0X80 = only the top bit set, HINT_INTk_0X7F = largest signed k-bit value.
    constexpr UINT_64 HINT_INT8_0XFF = UINT8_MAX;
    constexpr UINT_64 HINT_INT8_0X10 = (UINT8_MAX + 1ull);
    constexpr UINT_64 HINT_INT16_0XFF = UINT16_MAX;
    constexpr UINT_64 HINT_INT16_0X10 = (UINT16_MAX + 1ull);
    constexpr UINT_64 HINT_INT32_0XFF = UINT32_MAX;
    constexpr UINT_64 HINT_INT32_0X01 = 1;
    constexpr UINT_64 HINT_INT32_0X80 = 0X80000000ull;
    constexpr UINT_64 HINT_INT32_0X7F = INT32_MAX;
    constexpr UINT_64 HINT_INT32_0X10 = (UINT32_MAX + 1ull);
    constexpr UINT_64 HINT_INT64_0X80 = INT64_MIN; // converts to 0x8000000000000000
    constexpr UINT_64 HINT_INT64_0X7F = INT64_MAX;
    constexpr UINT_64 HINT_INT64_0XFF = UINT64_MAX;

    // pi, 2*pi and sqrt(2)/2 (= cos(pi/4)).
    constexpr HintFloat HINT_PI = 3.1415926535897932384626433832795;
    constexpr HintFloat HINT_2PI = HINT_PI * 2;
    constexpr HintFloat HINT_HSQ_ROOT2 = 0.70710678118654752440084436210485;

    // NTT moduli: 3221225473 = 3 * 2^30 + 1 and 3489660929 = 13 * 2^28 + 1, with
    // roots 5 and 3 (presumably primitive roots of the respective moduli; verify if changed).
    constexpr UINT_64 NTT_MOD1 = 3221225473;
    constexpr UINT_64 NTT_ROOT1 = 5;
    constexpr UINT_64 NTT_MOD2 = 3489660929;
    constexpr UINT_64 NTT_ROOT2 = 3;
    // 2^28 is the largest power of two dividing NTT_MOD2 - 1, which bounds the NTT length.
    constexpr size_t NTT_MAX_LEN = 1ull << 28;

    /// @brief Largest power of two that is not greater than n.
    /// @param n Input value of any integer type.
    /// @return The largest 2^k <= n, or 0 when n < 1.
    template <typename T>
    constexpr T max_2pow(T n)
    {
        // For signed T the top bit is the sign bit, so the largest representable power
        // of two is one bit lower. Starting at the sign bit would produce a negative
        // value: the loop would never run, or never terminate for n < 0.
        constexpr bool is_signed_type = T(-1) < T(0);
        constexpr int value_bits = int(sizeof(T) * 8) - (is_signed_type ? 1 : 0);
        T res = T(1) << (value_bits - 1);
        // Stop at 0 so that n <= 0 terminates (returns 0).
        while (res != 0 && res > n)
        {
            res /= 2;
        }
        return res;
    }
    /// @brief Largest power of two not greater than n (bit floor).
    /// @param n Non-negative input.
    /// @return 2^floor(log2 n) for n >= 1, and 0 for n == 0.
    template <typename T>
    constexpr T int_floor2(T n)
    {
        constexpr int bits = sizeof(n) * 8;
        // Smear the highest set bit into every lower position: n becomes 2^(k+1) - 1.
        for (int i = 1; i < bits; i *= 2)
        {
            n |= (n >> i);
        }
        // XOR with itself shifted keeps only the top bit (2^k). Unlike (n >> 1) + 1,
        // this maps 0 to 0 instead of returning 1, which is not <= 0.
        return n ^ (n >> 1);
    }
    /// @brief Smallest power of two not less than n (bit ceil).
    /// @param n Input; the result must be representable in T.
    /// @return 2^ceil(log2 n) for n >= 1. For n == 0 the initial decrement wraps and
    ///         the result is 0 (not 1).
    template <typename T>
    constexpr T int_ceil2(T n)
    {
        T v = n - 1;
        // Fill every bit below the highest set bit of n - 1.
        for (int shift = 1; shift < int(sizeof(v) * 8); shift <<= 1)
        {
            v |= v >> shift;
        }
        return v + 1;
    }
    /// @brief Floor of the base-2 logarithm, computed by repeated halving.
    /// @return floor(log2 n) for n >= 1; 0 when n <= 1.
    template <typename T>
    constexpr T hint_log2(T n)
    {
        T count = 0;
        for (; n > 1; n /= 2)
        {
            ++count;
        }
        return count;
    }
    /// @brief Exponentiation by squaring: m^n with O(log n) multiplications.
    /// @details Requires only that T is constructible from the literal 1 and
    ///          provides operator*.
    template <typename T>
    constexpr T qpow(T m, UINT_64 n)
    {
        T acc = 1;
        for (; n != 0; n >>= 1)
        {
            // Multiply in the current power of m when this bit of the exponent is set.
            if (n & 1)
            {
                acc = acc * m;
            }
            m = m * m;
        }
        return acc;
    }
    /// @brief Quotient and remainder in one call.
    /// @return (a / b, a % b); b must be non-zero.
    template <typename T>
    constexpr std::pair<T, T> div_mod(T a, T b)
    {
        const T quot = a / b;
        const T rem = a % b;
        return std::pair<T, T>(quot, rem);
    }
#if MULTITHREAD == 1
    // Thread-count configuration; compiled only when MULTITHREAD == 1.
    // NOTE(review): std::thread and std::atomic need <thread> and <atomic>, which are not
    // included directly in the visible headers; confirm before enabling MULTITHREAD.
    const UINT_32 hint_threads = std::thread::hardware_concurrency();
    // NOTE(review): hint_log2 already returns an integer floor(log2), so std::ceil is a
    // no-op here. If ceil(log2(threads)) was intended, use hint_log2(int_ceil2(hint_threads)).
    const UINT_32 log2_threads = std::ceil(hint_log2(hint_threads));
    // Presumably counts the worker threads currently in use; usage is not visible here.
    std::atomic<UINT_32> cur_ths;
#endif

    /// @brief Copy len elements from source to target as raw bytes.
    /// @details T must be trivially copyable. The ranges may overlap (memmove is used).
    ///          Copying zero elements, or copying an array onto itself, is a no-op.
    /// @throws const char* "Ary too long\n" when len >= INT64_MAX or when
    ///         len * sizeof(T) does not fit in size_t.
    template <typename T>
    void ary_copy(T *target, const T *source, size_t len)
    {
        if (len == 0 || target == source)
        {
            return;
        }
        // Reject lengths whose byte count would overflow size_t. Without this check the
        // product below wraps and a wrong (smaller) number of bytes is copied silently.
        constexpr size_t max_elems = static_cast<size_t>(-1) / sizeof(T);
        if (len >= INT64_MAX || len > max_elems)
        {
            throw("Ary too long\n");
        }
        // memmove rather than memcpy: partially overlapping ranges stay well-defined.
        std::memmove(target, source, len * sizeof(T));
    }
    /// @brief Interleave two real sequences into one complex array:
    ///        target[i] = (source1[i], source2[i]).
    /// @details Past the end of the shorter input, only the component from the longer
    ///          input is written; the other component of target[i] is left untouched.
    ///          The caller presumably zero-fills target beforehand.
    template <typename T>
    inline void com_ary_combine_copy(Complex *target, const T &source1, size_t len1, const T &source2, size_t len2)
    {
        const size_t common = len1 < len2 ? len1 : len2;
        for (size_t idx = 0; idx < common; ++idx)
        {
            target[idx] = Complex(source1[idx], source2[idx]);
        }
        // At most one of the two tail loops below does any work.
        for (size_t idx = common; idx < len1; ++idx)
        {
            target[idx].real(source1[idx]);
        }
        for (size_t idx = common; idx < len2; ++idx)
        {
            target[idx].imag(source2[idx]);
        }
    }

    // FFT与类FFT变换的命名空间
    namespace hint_transform
    {
        /// @brief In-place bit-reversal permutation of the first len elements of ary.
        /// @param ary Indexable container (array, pointer, std::vector, ...).
        /// @param len Number of elements; must be a power of two. len < 2 is a no-op.
        template <typename T>
        void binary_reverse_swap(T &ary, size_t len)
        {
            // Guard: with len == 0, `len - 1` wraps to SIZE_MAX and the loop would run ~2^64 times.
            if (len < 2)
            {
                return;
            }
            size_t i = 0; // bit-reversed counterpart of j, maintained incrementally
            for (size_t j = 1; j < len - 1; j++)
            {
                // Reverse-increment i: add 1 starting from the most significant bit.
                size_t k = len >> 1;
                i ^= k;
                while (k > i)
                {
                    k >>= 1;
                    i ^= k;
                }
                // Swap each pair only once.
                if (j < i)
                {
                    std::swap(ary[i], ary[j]);
                }
            }
        }
        namespace hint_fft
        {
            // Point on the unit circle with argument theta (radians): e^{i*theta}.
            static Complex unit_root(HintFloat theta)
            {
                const HintFloat radius = 1.0;
                return std::polar<HintFloat>(radius, theta);
            }
            // The n-th of m equally spaced points on the unit circle: e^{2*pi*i*n/m}.
            static Complex unit_root(size_t m, size_t n)
            {
                const HintFloat angle = HINT_2PI * n / m;
                return unit_root(angle);
            }
            // Twiddle-factor lookup table for power-of-two FFTs of length up to 2^MAX_SHIFT.
            // For each level `shift` (len = 2^shift, vec_size = len / 4, omega = e^{-2*pi*i/len}):
            //   table1[vec_size + n] = omega^n       for n in [0, vec_size)
            //   table3[vec_size + n] = omega^(3n)    for n in [0, vec_size)
            // Levels are stored back to back: level `shift` occupies [2^shift / 4, 2^shift / 2).
            // The constructor only seeds the first two entries; call expand(shift) before
            // reading level `shift`.
            template <UINT_32 MAX_SHIFT>
            class ComplexTableC
            {
            private:
                enum
                {
                    TABLE_LEN = (size_t(1) << MAX_SHIFT) / 2
                };
                // TABLE_LEN complex numbers each, stored as doubles in 64-byte aligned storage,
                // so aligned Complex2 loads/stores work at 32-byte-aligned offsets.
                AlignAry<HintFloat, TABLE_LEN * 2> table1_ary;
                AlignAry<HintFloat, TABLE_LEN * 2> table3_ary;

                Complex *table1 = nullptr, *table3 = nullptr;
                INT_32 max_log_size = 2; // largest level this table can hold (MAX_SHIFT, at least 1)
                INT_32 cur_log_size = 2; // largest level computed so far

                static constexpr size_t FAC = 1;

                ComplexTableC(const ComplexTableC &) = delete;
                ComplexTableC &operator=(const ComplexTableC &) = delete;

            public:
                // Create a table able to hold the roots of unity for up to 2^MAX_SHIFT points.
                // Levels are filled lazily by expand().
                constexpr ComplexTableC()
                {
                    table1 = table1_ary.template cast_ptr<Complex>();
                    table3 = table3_ary.template cast_ptr<Complex>();
                    max_log_size = std::max<size_t>(MAX_SHIFT, 1);
                    table1[0] = table1[1] = table3[0] = table3[1] = Complex(1);
                    // expand(max_log_size);
                }
                // Mutable access to entry n of level `shift`.
                constexpr Complex &table1_access(int shift, size_t n)
                {
                    return table1[(1 << shift) / 4 + n];
                }
                constexpr Complex &table3_access(int shift, size_t n)
                {
                    return table3[(1 << shift) / 4 + n];
                }
                // Make all levels up to `shift` (clamped to max_log_size) available.
                constexpr void expand(INT_32 shift)
                {
                    expand_topdown(shift);
                }
                // Compute level `shift` directly, then derive every lower level that is not yet
                // computed (down to cur_log_size + 1) from the level above it.
                constexpr void expand_topdown(INT_32 shift)
                {
                    shift = std::min(shift, max_log_size);
                    if (shift <= cur_log_size)
                    {
                        return;
                    }
                    size_t len = 1ull << shift, vec_size = len * FAC / 4;
                    table1_access(shift, 0) = table3_access(shift, 0) = Complex(1, 0);
                    const HintFloat inv = -HINT_2PI / len;
                    // Exact roots at power-of-two positions.
                    for (size_t pos = 1; pos < vec_size / 2; pos *= 2)
                    {
                        table1_access(shift, pos) = unit_root(inv * pos);
                    }
                    // NOTE(review): for small shifts, position 3 lies outside the range filled
                    // above. It is corrected by the mirror loop (shift 4) or lands in the next
                    // level's region (shift 3). Verify intent before changing.
                    table1_access(shift, 3) = table1_access(shift, 1) * table1_access(shift, 2);
                    // Remaining positions, two per Complex2 multiply:
                    // omega^pos = omega^(lowest set bit of pos) * omega^(pos with that bit cleared).
                    for (size_t pos = 4; pos < vec_size / 2; pos += 2)
                    {
                        size_t sub_pos = pos & (pos - 1);
                        Complex tmp = table1_access(shift, pos - sub_pos);
                        Complex2 tmpx2(tmp, tmp);
                        Complex2 tmpx2_sub = &table1_access(shift, sub_pos);
                        (tmpx2 * tmpx2_sub).store(&table1_access(shift, pos));
                    }
                    // Mirror: omega^(vec_size - pos) = -(imag, real) of omega^pos.
                    for (size_t pos = 1; pos < vec_size / 2; pos++)
                    {
                        Complex tmp = table1_access(shift, pos);
                        table1_access(shift, vec_size - pos) = -Complex(tmp.imag(), tmp.real());
                    }
                    // table3: omega^(3 pos), and its mirror omega^(3 (vec_size - pos)) = (imag, real) of omega^(3 pos).
                    for (size_t pos = 1; pos < vec_size / 2; pos++)
                    {
                        Complex tmp = get_omega(shift, pos * 3);
                        table3_access(shift, pos) = tmp;
                        table3_access(shift, vec_size - pos) = Complex(tmp.imag(), tmp.real());
                    }
                    // Midpoints: omega^(len/8) = e^{-i*pi/4}, omega^(3*len/8) = e^{-3i*pi/4}.
                    table1_access(shift, vec_size / 2) = std::conj(unit_root(8, 1));
                    table3_access(shift, vec_size / 2) = std::conj(unit_root(8, 3));
                    // Lower levels: table1 is a stride-2 subsample of the level above; table3 is recomputed.
                    for (INT_32 log = shift - 1; log > cur_log_size; log--)
                    {
                        len = 1ull << log, vec_size = len / 4;
                        for (size_t pos = 0; pos < vec_size; pos++)
                        {
                            table1_access(log, pos) = table1_access(log + 1, pos * 2);
                        }
                        for (size_t pos = 1; pos < vec_size / 2; pos++)
                        {
                            Complex tmp = get_omega(log, pos * 3);
                            table3_access(log, pos) = tmp;
                            table3_access(log, vec_size - pos) = Complex(tmp.imag(), tmp.real());
                        }
                        table3_access(log, 0) = Complex(1, 0);
                        table3_access(log, vec_size / 2) = std::conj(unit_root(8, 3));
                    }
                    cur_log_size = std::max(cur_log_size, shift);
                }
                // omega^n at level `shift` (circle divided into 2^shift parts), for 0 <= n <= 2^shift / 2.
                // Above the quarter point it uses omega^n = -conj(omega^(len/2 - n)).
                constexpr Complex get_omega(UINT_32 shift, size_t n) const
                {
                    size_t vec_size = (size_t(1) << shift) / 4;
                    if (n < vec_size)
                    {
                        return table1[vec_size + n];
                    }
                    else if (n > vec_size)
                    {
                        Complex tmp = table1[vec_size + vec_size * 2 - n];
                        return Complex(-tmp.real(), tmp.imag());
                    }
                    else
                    {
                        return Complex(0, -1);
                    }
                }
                // omega^(3n) at level `shift` (circle divided into 2^shift parts).
                Complex get_omega3(UINT_32 shift, size_t n) const
                {
                    // Index table3 directly: table3_access() is non-const and cannot be called
                    // from this const member (a compile error once instantiated).
                    return table3[(size_t(1) << shift) / 4 + n];
                }
                // Complex2 holding omega^n and omega^(n+1) at level `shift`.
                // The address must be 32-byte aligned (n even for shift >= 3).
                Complex2 get_omegaX2(UINT_32 shift, size_t n) const
                {
                    return Complex2(table1 + (1 << (shift - 2)) + n);
                }
                // Complex2 holding omega^(3n) and omega^(3(n+1)) at level `shift`; same alignment rule.
                Complex2 get_omega3X2(UINT_32 shift, size_t n) const
                {
                    return Complex2(table3 + (1 << (shift - 2)) + n);
                }
                // Pointer to omega^n at level `shift`.
                const Complex *get_omega_ptr(UINT_32 shift, size_t n) const
                {
                    return table1 + (1 << (shift - 2)) + n;
                }
                // Pointer to omega^(3n) at level `shift`.
                const Complex *get_omega3_ptr(UINT_32 shift, size_t n) const
                {
                    return table3 + (1 << (shift - 2)) + n;
                }
            };
            // The twiddle table covers transform lengths up to 2^lut_max_rank.
            constexpr size_t lut_max_rank{19};
            // Shared twiddle-factor table. Its constructor only seeds the first entries;
            // larger levels are produced by TABLE.expand(shift).
            static ComplexTableC<lut_max_rank> TABLE{};
            // Radix-2 butterfly in place: (sum, diff) <- (sum + diff, sum - diff).
            template <typename T>
            inline void fft_2point(T &sum, T &diff)
            {
                const T a = sum;
                const T b = diff;
                sum = a + b;
                diff = a - b;
            }
            // In-place 4-point DIT FFT (forward, twiddle -i) on input[0], input[rank],
            // input[2*rank], input[3*rank]. The result is the DFT, in natural order, of the
            // sequence (input[0], input[2*rank], input[rank], input[3*rank]).
            inline void fft_dit_4point(Complex *input, size_t rank = 1)
            {
                Complex *const p0 = input;
                Complex *const p1 = input + rank;
                Complex *const p2 = input + rank * 2;
                Complex *const p3 = input + rank * 3;

                // First stage: radix-2 butterflies on the (0,1) and (2,3) pairs.
                const Complex lo_sum = *p0 + *p1;
                const Complex lo_diff = *p0 - *p1;
                const Complex hi_sum = *p2 + *p3;
                const Complex hi_raw = *p2 - *p3;
                // Multiply by -i: (re, im) -> (im, -re).
                const Complex hi_diff(hi_raw.imag(), -hi_raw.real());

                *p0 = lo_sum + hi_sum;
                *p1 = lo_diff + hi_diff;
                *p2 = lo_sum - hi_sum;
                *p3 = lo_diff - hi_diff;
            }
            // In-place 4-point DIF FFT (forward, twiddle -i) on input[0], input[rank],
            // input[2*rank], input[3*rank]. The input is in natural order; the result is stored
            // in bit-reversed order: X0, X2, X1, X3.
            inline void fft_dif_4point(Complex *input, size_t rank = 1)
            {
                Complex *const p0 = input;
                Complex *const p1 = input + rank;
                Complex *const p2 = input + rank * 2;
                Complex *const p3 = input + rank * 3;

                // First stage: radix-2 butterflies on the (0,2) and (1,3) pairs.
                const Complex even_sum = *p0 + *p2;
                const Complex even_diff = *p0 - *p2;
                const Complex odd_sum = *p1 + *p3;
                const Complex odd_raw = *p1 - *p3;
                // Multiply by -i: (re, im) -> (im, -re).
                const Complex odd_diff(odd_raw.imag(), -odd_raw.real());

                *p0 = even_sum + odd_sum;
                *p1 = even_sum - odd_sum;
                *p2 = even_diff + odd_diff;
                *p3 = even_diff - odd_diff;
            }
            // AVX version of fft_dit_4point(input, 1): in-place 4-point DIT FFT on 4 contiguous
            // complex values c0..c3. input must be 32-byte aligned (aligned loads/stores).
            inline void fft_dit_4point_avx(Complex *input)
            {
                // Sign bit only in lane 3 (imaginary part of the upper complex).
                static const __m256d neg_mask = _mm256_castsi256_pd(
                    _mm256_set_epi64x(INT64_MIN, 0, 0, 0));
                __m256d tmp0 = _mm256_load_pd(reinterpret_cast<double *>(input));     // c0,c1
                __m256d tmp1 = _mm256_load_pd(reinterpret_cast<double *>(input + 2)); // c2,c3

                __m256d tmp2 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20); // c0,c2
                __m256d tmp3 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31); // c1,c3

                tmp0 = _mm256_add_pd(tmp2, tmp3); // c0+c1,c2+c3
                tmp1 = _mm256_sub_pd(tmp2, tmp3); // c0-c1,c2-c3

                tmp2 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20); // c0+c1,c0-c1;(A,B)
                tmp3 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31); // c2+c3,c2-c3

                // Swap re/im of the upper complex, then negate its new imaginary part:
                // the upper complex (c2-c3) is multiplied by -i.
                tmp3 = _mm256_permute_pd(tmp3, 0b0110);
                tmp3 = _mm256_xor_pd(tmp3, neg_mask); // (C,D)

                tmp0 = _mm256_add_pd(tmp2, tmp3); // A+C,B+D
                tmp1 = _mm256_sub_pd(tmp2, tmp3); // A-C,B-D

                _mm256_store_pd(reinterpret_cast<double *>(input), tmp0);
                _mm256_store_pd(reinterpret_cast<double *>(input + 2), tmp1);
            }
            // In-place 8-point DIT FFT on 8 contiguous complex values (input 32-byte aligned).
            // Two 4-point DIT FFTs (same scheme as fft_dit_4point_avx) give A = FFT4(input[0..3])
            // and B = FFT4(input[4..7]). They are then combined with w = e^{-2*pi*i/8}:
            //   out[k] = A[k] + w^k B[k],  out[k+4] = A[k] - w^k B[k].
            inline void fft_dit_8point_avx(Complex *input)
            {
                // Sign bit only in lane 3 (imaginary part of the upper complex).
                static const __m256d neg_mask = _mm256_castsi256_pd(_mm256_set_epi64x(INT64_MIN, 0, 0, 0));
                // Lanes [0, 0, sqrt(2)/2, sqrt(2)/2]: used to form w^1 * B1.
                static const __m256d mul1 = _mm256_set_pd(0.70710678118654752440084436210485, 0.70710678118654752440084436210485, 0, 0);
                // Lanes [1, -1, -sqrt(2)/2, -sqrt(2)/2]: used to form w^2 * B2 = -i*B2 and w^3 * B3.
                static const __m256d mul2 = _mm256_set_pd(-0.70710678118654752440084436210485, -0.70710678118654752440084436210485, -1, 1);
                __m256d tmp0 = _mm256_load_pd(reinterpret_cast<double *>(input));     // c0,c1
                __m256d tmp1 = _mm256_load_pd(reinterpret_cast<double *>(input + 2)); // c2,c3
                __m256d tmp2 = _mm256_load_pd(reinterpret_cast<double *>(input + 4)); // c4,c5
                __m256d tmp3 = _mm256_load_pd(reinterpret_cast<double *>(input + 6)); // c6,c7

                __m256d tmp4 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20); // c0,c2
                __m256d tmp5 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31); // c1,c3
                __m256d tmp6 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20); // c4,c6
                __m256d tmp7 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31); // c5,c7

                tmp0 = _mm256_add_pd(tmp4, tmp5); // c0+c1,c2+c3
                tmp1 = _mm256_sub_pd(tmp4, tmp5); // c0-c1,c2-c3
                tmp2 = _mm256_add_pd(tmp6, tmp7); // c4+c5,c6+c7
                tmp3 = _mm256_sub_pd(tmp6, tmp7); // c4-c5,c6-c7

                tmp4 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20); // c0+c1,c0-c1;(A,B)
                tmp5 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31); // c2+c3,c2-c3
                tmp6 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20); // c4+c5,c4-c5;(A,B)
                tmp7 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31); // c6+c7,c6-c7

                // Multiply the upper complex by -i (swap re/im, negate new imag).
                tmp5 = _mm256_permute_pd(tmp5, 0b0110);
                tmp5 = _mm256_xor_pd(tmp5, neg_mask); // (C,D)
                tmp7 = _mm256_permute_pd(tmp7, 0b0110);
                tmp7 = _mm256_xor_pd(tmp7, neg_mask); // (C,D)

                tmp0 = _mm256_add_pd(tmp4, tmp5); // A0,A1
                tmp1 = _mm256_sub_pd(tmp4, tmp5); // A2,A3
                tmp2 = _mm256_add_pd(tmp6, tmp7); // B0,B1
                tmp3 = _mm256_sub_pd(tmp6, tmp7); // B2,B3

                // 2X4point-done
                // tmp2 <- (B0, w*B1): w*B1 = (h*(re+im), h*(im-re)) with h = sqrt(2)/2.
                tmp6 = _mm256_permute_pd(tmp2, 0b0110);
                tmp6 = _mm256_addsub_pd(tmp6, tmp2);
                tmp6 = _mm256_permute_pd(tmp6, 0b0110);
                tmp6 = _mm256_mul_pd(tmp6, mul1);
                tmp2 = _mm256_blend_pd(tmp2, tmp6, 0b1100);

                // tmp3 <- (-i*B2, w^3*B3) = ((im, -re), -h*(re-im, re+im)).
                tmp7 = _mm256_permute_pd(tmp3, 0b0101);
                tmp3 = _mm256_addsub_pd(tmp3, tmp7);
                tmp3 = _mm256_blend_pd(tmp7, tmp3, 0b1100);
                tmp3 = _mm256_mul_pd(tmp3, mul2);

                tmp4 = _mm256_add_pd(tmp0, tmp2); // out0,out1
                tmp5 = _mm256_add_pd(tmp1, tmp3); // out2,out3
                tmp6 = _mm256_sub_pd(tmp0, tmp2); // out4,out5
                tmp7 = _mm256_sub_pd(tmp1, tmp3); // out6,out7
                _mm256_store_pd(reinterpret_cast<double *>(input), tmp4);
                _mm256_store_pd(reinterpret_cast<double *>(input + 2), tmp5);
                _mm256_store_pd(reinterpret_cast<double *>(input + 4), tmp6);
                _mm256_store_pd(reinterpret_cast<double *>(input + 6), tmp7);
            }
            // AVX version of fft_dif_4point(input, 1): in-place 4-point DIF FFT on 4 contiguous
            // complex values (input 32-byte aligned). Output order matches the scalar version:
            // X0, X2, X1, X3.
            inline void fft_dif_4point_avx(Complex *input)
            {
                __m256d tmp0 = _mm256_load_pd(reinterpret_cast<double *>(input));     // c0,c1
                __m256d tmp1 = _mm256_load_pd(reinterpret_cast<double *>(input + 2)); // c2,c3

                __m256d tmp2 = _mm256_add_pd(tmp0, tmp1); // c0+c2,c1+c3;
                __m256d tmp3 = _mm256_sub_pd(tmp0, tmp1); // c0-c2,c1-c3;
                tmp3 = _mm256_permute_pd(tmp3, 0b0110);   // c0-c2,r(c1-c3);

                // Negate lane 3: together with the swap above, the upper complex is multiplied by -i.
                static const __m256d neg_mask = _mm256_castsi256_pd(
                    _mm256_set_epi64x(INT64_MIN, 0, 0, 0));
                tmp3 = _mm256_xor_pd(tmp3, neg_mask);

                tmp0 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20); // A,C
                tmp1 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31); // B,D

                tmp2 = _mm256_add_pd(tmp0, tmp1); // A+B,C+D
                tmp3 = _mm256_sub_pd(tmp0, tmp1); // A-B,C-D

                tmp0 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20); // X0,X2
                tmp1 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31); // X1,X3

                _mm256_store_pd(reinterpret_cast<double *>(input), tmp0);
                _mm256_store_pd(reinterpret_cast<double *>(input + 2), tmp1);
            }
            // In-place 8-point DIF FFT on 8 contiguous complex values (input 32-byte aligned).
            // First the radix-2 stage e_k = x_k + x_{k+4}, d_k = (x_k - x_{k+4}) * w^k with
            // w = e^{-2*pi*i/8}, then a 4-point DIF FFT on e and on d. For natural-order input the
            // result is in bit-reversed order: X0, X4, X2, X6, X1, X5, X3, X7.
            inline void fft_dif_8point_avx(Complex *input)
            {
                // Sign bit only in lane 3 (imaginary part of the upper complex).
                static const __m256d neg_mask = _mm256_castsi256_pd(_mm256_set_epi64x(INT64_MIN, 0, 0, 0));
                // Lanes [0, 0, sqrt(2)/2, sqrt(2)/2] and [1, -1, -sqrt(2)/2, -sqrt(2)/2]: twiddle helpers.
                static const __m256d mul1 = _mm256_set_pd(0.70710678118654752440084436210485, 0.70710678118654752440084436210485, 0, 0);
                static const __m256d mul2 = _mm256_set_pd(-0.70710678118654752440084436210485, -0.70710678118654752440084436210485, -1, 1);
                __m256d tmp0 = _mm256_load_pd(reinterpret_cast<double *>(input));     // c0,c1
                __m256d tmp1 = _mm256_load_pd(reinterpret_cast<double *>(input + 2)); // c2,c3
                __m256d tmp2 = _mm256_load_pd(reinterpret_cast<double *>(input + 4)); // c4,c5
                __m256d tmp3 = _mm256_load_pd(reinterpret_cast<double *>(input + 6)); // c6,c7

                __m256d tmp4 = _mm256_add_pd(tmp0, tmp2); // e0,e1
                __m256d tmp5 = _mm256_add_pd(tmp1, tmp3); // e2,e3
                __m256d tmp6 = _mm256_sub_pd(tmp0, tmp2); // raw d0,d1
                __m256d tmp7 = _mm256_sub_pd(tmp1, tmp3); // raw d2,d3

                // tmp6 <- (d0, w*d1)
                tmp2 = _mm256_permute_pd(tmp6, 0b0110);
                tmp2 = _mm256_addsub_pd(tmp2, tmp6);
                tmp2 = _mm256_permute_pd(tmp2, 0b0110);
                tmp2 = _mm256_mul_pd(tmp2, mul1);
                tmp6 = _mm256_blend_pd(tmp6, tmp2, 0b1100);

                // tmp7 <- (-i*d2, w^3*d3)
                tmp3 = _mm256_permute_pd(tmp7, 0b0101);
                tmp7 = _mm256_addsub_pd(tmp7, tmp3);
                tmp7 = _mm256_blend_pd(tmp3, tmp7, 0b1100);
                tmp7 = _mm256_mul_pd(tmp7, mul2);

                // 2X4point
                // Stage 2 of each 4-point DIF: sums, and differences with the upper complex times -i.
                tmp0 = _mm256_add_pd(tmp4, tmp5);
                tmp1 = _mm256_sub_pd(tmp4, tmp5);
                tmp1 = _mm256_permute_pd(tmp1, 0b0110);
                tmp1 = _mm256_xor_pd(tmp1, neg_mask);

                tmp2 = _mm256_add_pd(tmp6, tmp7);
                tmp3 = _mm256_sub_pd(tmp6, tmp7);
                tmp3 = _mm256_permute_pd(tmp3, 0b0110);
                tmp3 = _mm256_xor_pd(tmp3, neg_mask);

                // Regroup 128-bit halves so the final radix-2 butterflies act on matching pairs.
                tmp4 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20);
                tmp5 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31);
                tmp6 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20);
                tmp7 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31);

                tmp0 = _mm256_add_pd(tmp4, tmp5);
                tmp1 = _mm256_sub_pd(tmp4, tmp5);
                tmp2 = _mm256_add_pd(tmp6, tmp7);
                tmp3 = _mm256_sub_pd(tmp6, tmp7);

                tmp4 = _mm256_permute2f128_pd(tmp0, tmp1, 0x20); // X0,X4
                tmp5 = _mm256_permute2f128_pd(tmp0, tmp1, 0x31); // X2,X6
                tmp6 = _mm256_permute2f128_pd(tmp2, tmp3, 0x20); // X1,X5
                tmp7 = _mm256_permute2f128_pd(tmp2, tmp3, 0x31); // X3,X7

                _mm256_store_pd(reinterpret_cast<double *>(input), tmp4);
                _mm256_store_pd(reinterpret_cast<double *>(input + 2), tmp5);
                _mm256_store_pd(reinterpret_cast<double *>(input + 4), tmp6);
                _mm256_store_pd(reinterpret_cast<double *>(input + 6), tmp7);
            }

            // fft分裂基时间抽取蝶形变换
            inline void fft_split_radix_dit_butterfly(const Complex2 &omega, const Complex2 &omega_cube,
                                                      Complex *input, size_t rank)
            {
                Complex2 tmp0 = input;
                Complex2 tmp1 = input + rank;
                Complex2 tmp2 = Complex2(input + rank * 2) * omega;
                Complex2 tmp3 = Complex2(input + rank * 3) * omega_cube;

                fft_2point(tmp2, tmp3);
                tmp3 = tmp3.mul_neg_i();

                (tmp0 + tmp2).store(input);
                (tmp1 + tmp3).store(input + rank);
                (tmp0 - tmp2).store(input + rank * 2);
                (tmp1 - tmp3).store(input + rank * 3);
            }
            // fft分裂基频率抽取蝶形变换
            inline void fft_split_radix_dif_butterfly(const Complex2 &omega, const Complex2 &omega_cube,
                                                      Complex *input, size_t rank)
            {
                Complex2 tmp0 = (input);
                Complex2 tmp1 = (input + rank);
                Complex2 tmp2 = (input + rank * 2);
                Complex2 tmp3 = (input + rank * 3);

                fft_2point(tmp0, tmp2);
                fft_2point(tmp1, tmp3);
                tmp3 = tmp3.mul_neg_i();

                tmp0.store(input);
                tmp1.store(input + rank);
                ((tmp2 + tmp3) * omega).store(input + rank * 2);
                ((tmp2 - tmp3) * omega_cube).store(input + rank * 3);
            }
            // fft分裂基时间抽取蝶形变换
            inline void fft_split_radix_dit_butterfly(const Complex *omega, const Complex *omega_cube,
                                                      Complex *input, size_t rank)
            {
                Complex2 tmp0 = input;
                Complex2 tmp4 = input + 2;
                Complex2 tmp1 = input + rank;
                Complex2 tmp5 = input + rank + 2;
                Complex2 tmp2 = Complex2(input + rank * 2) * Complex2(omega);
                Complex2 tmp6 = Complex2(input + rank * 2 + 2) * Complex2(omega + 2);
                Complex2 tmp3 = Complex2(input + rank * 3) * Complex2(omega_cube);
                Complex2 tmp7 = Complex2(input + rank * 3 + 2) * Complex2(omega_cube + 2);

                fft_2point(tmp2, tmp3);
                fft_2point(tmp6, tmp7);
                tmp3 = tmp3.mul_neg_i();
                tmp7 = tmp7.mul_neg_i();

                (tmp0 + tmp2).store(input);
                (tmp4 + tmp6).store(input + 2);
                (tmp1 + tmp3).store(input + rank);
                (tmp5 + tmp7).store(input + rank + 2);
                (tmp0 - tmp2).store(input + rank * 2);
                (tmp4 - tmp6).store(input + rank * 2 + 2);
                (tmp1 - tmp3).store(input + rank * 3);
                (tmp5 - tmp7).store(input + rank * 3 + 2);
            }
            // fft分裂基频率抽取蝶形变换
            inline void fft_split_radix_dif_butterfly(const Complex *omega, const Complex *omega_cube,
                                                      Complex *input, size_t rank)
            {
                Complex2 tmp0 = input;
                Complex2 tmp4 = input + 2;
                Complex2 tmp1 = input + rank;
                Complex2 tmp5 = input + rank + 2;
                Complex2 tmp2 = input + rank * 2;
                Complex2 tmp6 = input + rank * 2 + 2;
                Complex2 tmp3 = input + rank * 3;
                Complex2 tmp7 = input + rank * 3 + 2;

                fft_2point(tmp0, tmp2);
                fft_2point(tmp1, tmp3);
                fft_2point(tmp4, tmp6);
                fft_2point(tmp5, tmp7);
                tmp3 = tmp3.mul_neg_i();
                tmp7 = tmp7.mul_neg_i();

                tmp0.store(input);
                tmp4.store(input + 2);
                tmp1.store(input + rank);
                tmp5.store(input + rank + 2);
                ((tmp2 + tmp3) * Complex2(omega)).store(input + rank * 2);
                ((tmp6 + tmp7) * Complex2(omega + 2)).store(input + rank * 2 + 2);
                ((tmp2 - tmp3) * Complex2(omega_cube)).store(input + rank * 3);
                ((tmp6 - tmp7) * Complex2(omega_cube + 2)).store(input + rank * 3 + 2);
            }

            // Constant twiddles for length-16 transforms, w = e^{-2*pi*i/16}:
            //   omega1_table = {w^0, w^1, w^2, w^3}, omega3_table = {w^0, w^3, w^6, w^9}.
            static constexpr HintFloat cos_1_8 = 0.70710678118654752440084436210485;
            static constexpr HintFloat cos_1_16 = 0.92387953251128675612818318939679;
            static constexpr HintFloat sin_1_16 = 0.3826834323650897717284599840304;
            static constexpr Complex w1(cos_1_16, -sin_1_16), w3(sin_1_16, -cos_1_16), w9(-cos_1_16, sin_1_16);
            // alignas(32): Complex2(const Complex*) performs an aligned 256-bit load
            // (_mm256_load_pd), so table and table + 2 must sit on 32-byte boundaries.
            // std::complex<double> alone only guarantees the alignment of double.
            alignas(32) static constexpr Complex omega1_table[4] = {Complex(1), w1, Complex(cos_1_8, -cos_1_8), w3};
            alignas(32) static constexpr Complex omega3_table[4] = {Complex(1), w3, Complex(-cos_1_8, -cos_1_8), w9};
            static const Complex2 omega0(omega1_table), omega1(omega1_table + 2);
            static const Complex2 omega_cu0(omega3_table), omega_cu1(omega3_table + 2);
            // Recursive split-radix DIT FFT of length LEN (a power of two) on contiguous complex
            // data, in place. The last two quarters are transformed at length LEN/4 and the first
            // half at length LEN/2; then the results are combined with split-radix butterflies,
            // four columns per call. Twiddles are w^k and w^(3k), w = e^{-2*pi*i/LEN}.
            // NOTE(review): the loop steps by 4 and each butterfly touches columns i..i+3, so
            // quarter_len must be a multiple of 4 (LEN >= 16). Smaller lengths presumably rely on
            // explicit specializations. The input permutation this recursion expects is set up by
            // the caller, which is not visible here.
            template <size_t LEN>
            void fft_split_radix_dit_template(Complex *input)
            {
                constexpr size_t log_len = hint_log2(LEN);
                constexpr size_t half_len = LEN / 2, quarter_len = LEN / 4;
                fft_split_radix_dit_template<quarter_len>(input + half_len + quarter_len);
                fft_split_radix_dit_template<quarter_len>(input + half_len);
                fft_split_radix_dit_template<half_len>(input);
#if TABLE_ENABLE == 1
                // Twiddles read from the global TABLE at level log_len. This function does not call
                // TABLE.expand(), so the table is assumed to be expanded already (presumably by the caller).
                auto omega = TABLE.get_omega_ptr(log_len, 0);
                auto omega_cube = TABLE.get_omega3_ptr(log_len, 0);
                for (size_t i = 0; i < quarter_len; i += 4)
                {
                    fft_split_radix_dit_butterfly(omega + i, omega_cube + i, input + i, quarter_len);
                }
#else
                // Table-free variant: twiddles are advanced by repeated multiplication with w^4 / w^12,
                // so rounding error accumulates along i.
                static const Complex unit1 = std::conj(unit_root(LEN, 1));
                static const Complex unit2 = std::conj(unit_root(LEN, 2));
                static const Complex unit3 = std::conj(unit_root(LEN, 3));
                static const Complex unit6 = std::conj(unit_root(LEN, 6));
                static const Complex unit9 = std::conj(unit_root(LEN, 9));
                static const Complex unit4 = std::conj(unit_root(LEN, 4));
                static const Complex unit12 = std::conj(unit_root(LEN, 12));

                static const Complex2 unit(unit4, unit4);
                static const Complex2 unit_cube(unit12, unit12);
                // omega1 = (w^i, w^(i+1)), omega2 = (w^(i+2), w^(i+3)); cube variants hold w^(3k).
                Complex2 omega1(Complex(1, 0), unit1);
                Complex2 omega2(unit2, unit3);
                Complex2 omega_cube1(Complex(1, 0), unit3);
                Complex2 omega_cube2(unit6, unit9);
                for (size_t i = 0; i < quarter_len; i += 4)
                {
                    fft_split_radix_dit_butterfly(omega1, omega_cube1, input + i, quarter_len);
                    fft_split_radix_dit_butterfly(omega2, omega_cube2, input + i + 2, quarter_len);
                    omega1 = omega1 * unit;
                    omega2 = omega2 * unit;
                    omega_cube1 = omega_cube1 * unit_cube;
                    omega_cube2 = omega_cube2 * unit_cube;
                }
#endif
            }
            // Base cases of the DIT recursion. Lengths 0 and 1 are no-ops.
            template <>
            inline void fft_split_radix_dit_template<0>(Complex *input) {}
            template <>
            inline void fft_split_radix_dit_template<1>(Complex *input) {}
            // Length 2: a single 2-point butterfly.
            template <>
            inline void fft_split_radix_dit_template<2>(Complex *input)
            {
                fft_2point(input[0], input[1]);
            }
            // Lengths 4 and 8: dedicated AVX kernels.
            template <>
            inline void fft_split_radix_dit_template<4>(Complex *input)
            {
                fft_dit_4point_avx(input);
            }
            template <>
            inline void fft_split_radix_dit_template<8>(Complex *input)
            {
                fft_dit_8point_avx(input);
            }
            // Length 16: same shape as the generic template with quarter_len = 4. It transforms the
            // two 4-point tails and the 8-point head, then merges with the constant twiddle pairs
            // omega0/omega_cu0 (at input) and omega1/omega_cu1 (at input + 2), stride 4.
            template <>
            inline void fft_split_radix_dit_template<16>(Complex *input)
            {
                fft_dit_4point_avx(input + 12);
                fft_dit_4point_avx(input + 8);
                fft_dit_8point_avx(input);
                fft_split_radix_dit_butterfly(omega0, omega_cu0, input, 4);
                fft_split_radix_dit_butterfly(omega1, omega_cu1, input + 2, 4);
            }
            // Templated decimation-in-frequency split-radix FFT.
            /// Split-radix decimation-in-frequency FFT of compile-time length LEN, computed in place
            /// on input[0 .. LEN).
            ///
            /// This is the mirror image of fft_split_radix_dit_template. It first applies
            /// fft_split_radix_dif_butterfly across quarter_len positions (four per loop iteration,
            /// stride quarter_len). It then recursively transforms the half-length head and the two
            /// quarter-length tails. Lengths 0, 1, 2, 4, 8 and 16 are handled by the explicit
            /// specializations that follow. No bit-reversal permutation is done here.
            /// NOTE(review): LEN is presumably a power of two (log_len = hint_log2(LEN)).
            template <size_t LEN>
            void fft_split_radix_dif_template(Complex *input)
            {
                constexpr size_t log_len = hint_log2(LEN);
                constexpr size_t half_len = LEN / 2, quarter_len = LEN / 4;
#if TABLE_ENABLE == 1
                // Precomputed twiddles for this size from TABLE (first-power and cube twiddles,
                // judging by get_omega_ptr / get_omega3_ptr), indexed by position i.
                auto omega = TABLE.get_omega_ptr(log_len, 0);
                auto omega_cube = TABLE.get_omega3_ptr(log_len, 0);
                for (size_t i = 0; i < quarter_len; i += 4)
                {
                    fft_split_radix_dif_butterfly(omega + i, omega_cube + i, input + i, quarter_len);
                }
#else
                // On-the-fly twiddles: unitK = conj(unit_root(LEN, K)), computed once per LEN.
                // omega1/omega2 start at powers (0, 1)/(2, 3) and omega_cube1/omega_cube2 at their
                // cubes (0, 3)/(6, 9). After every group of four positions they are multiplied by
                // unit4 / unit12.
                // NOTE(review): the local omega1 shadows the namespace-scope Complex2 omega1, and the
                // multiplicative recurrence accumulates rounding error as LEN grows.
                static const Complex unit1 = std::conj(unit_root(LEN, 1));
                static const Complex unit2 = std::conj(unit_root(LEN, 2));
                static const Complex unit3 = std::conj(unit_root(LEN, 3));
                static const Complex unit6 = std::conj(unit_root(LEN, 6));
                static const Complex unit9 = std::conj(unit_root(LEN, 9));
                static const Complex unit4 = std::conj(unit_root(LEN, 4));
                static const Complex unit12 = std::conj(unit_root(LEN, 12));

                static const Complex2 unit(unit4, unit4);
                static const Complex2 unit_cube(unit12, unit12);
                Complex2 omega1(Complex(1, 0), unit1);
                Complex2 omega2(unit2, unit3);
                Complex2 omega_cube1(Complex(1, 0), unit3);
                Complex2 omega_cube2(unit6, unit9);
                for (size_t i = 0; i < quarter_len; i += 4)
                {
                    fft_split_radix_dif_butterfly(omega1, omega_cube1, input + i, quarter_len);
                    fft_split_radix_dif_butterfly(omega2, omega_cube2, input + i + 2, quarter_len);
                    omega1 = omega1 * unit;
                    omega2 = omega2 * unit;
                    omega_cube1 = omega_cube1 * unit_cube;
                    omega_cube2 = omega_cube2 * unit_cube;
                }
#endif
                // Sub-transforms: [0, 1/2 LEN), [1/2 LEN, 3/4 LEN), then [3/4 LEN, LEN).
                fft_split_radix_dif_template<half_len>(input);
                fft_split_radix_dif_template<quarter_len>(input + half_len);
                fft_split_radix_dif_template<quarter_len>(input + half_len + quarter_len);
            }
            // Base cases of the DIF recursion. Lengths 0 and 1 are no-ops.
            template <>
            inline void fft_split_radix_dif_template<0>(Complex *input) {}
            template <>
            inline void fft_split_radix_dif_template<1>(Complex *input) {}
            // Length 2: a single 2-point butterfly.
            template <>
            inline void fft_split_radix_dif_template<2>(Complex *input)
            {
                fft_2point(input[0], input[1]);
            }
            // Lengths 4 and 8: dedicated AVX kernels.
            template <>
            inline void fft_split_radix_dif_template<4>(Complex *input)
            {
                fft_dif_4point_avx(input);
            }
            template <>
            inline void fft_split_radix_dif_template<8>(Complex *input)
            {
                fft_dif_8point_avx(input);
            }
            // Length 16: mirror of the DIT <16> case. It first merges with the constant twiddle pairs
            // omega0/omega_cu0 (at input) and omega1/omega_cu1 (at input + 2), stride 4. It then
            // transforms the 8-point head and the two 4-point tails.
            template <>
            inline void fft_split_radix_dif_template<16>(Complex *input)
            {
                fft_split_radix_dif_butterfly(omega0, omega_cu0, input, 4);
                fft_split_radix_dif_butterfly(omega1, omega_cu1, input + 2, 4);
                fft_dif_8point_avx(input);
                fft_dif_4point_avx(input + 8);
                fft_dif_4point_avx(input + 12);
            }

            // Helper that selects a compile-time size at run time (DIT). Starting from LEN, it halves
            // the template size until it no longer exceeds fft_len, then runs
            // fft_split_radix_dit_template at that size. For a power-of-two fft_len <= LEN this
            // transforms exactly fft_len points; fft_len == 0 reaches the <0> no-op below.
            // NOTE(review): a non-power-of-two fft_len is effectively rounded down to a power of two,
            // and fft_len > LEN still transforms only LEN points.
            template <size_t LEN = 1>
            void fft_split_radix_dit_template_alt(Complex *input, size_t fft_len)
            {
                if (fft_len < LEN)
                {
                    fft_split_radix_dit_template_alt<LEN / 2>(input, fft_len);
                    return;
                }
                fft_split_radix_dit_template<LEN>(input);
            }
            // Recursion terminator. It is marked inline like every other explicit specialization in
            // this file, because an explicit specialization is not implicitly inline. Defined in a
            // header without `inline`, it would violate the one-definition rule as soon as the header
            // is included from two translation units.
            template <>
            inline void fft_split_radix_dit_template_alt<0>(Complex *input, size_t fft_len) {}

            // Helper that selects a compile-time size at run time (DIF). Starting from LEN, it halves
            // the template size until it no longer exceeds fft_len, then runs
            // fft_split_radix_dif_template at that size. For a power-of-two fft_len <= LEN this
            // transforms exactly fft_len points; fft_len == 0 reaches the <0> no-op below.
            // NOTE(review): a non-power-of-two fft_len is effectively rounded down to a power of two,
            // and fft_len > LEN still transforms only LEN points.
            template <size_t LEN = 1>
            void fft_split_radix_dif_template_alt(Complex *input, size_t fft_len)
            {
                if (fft_len < LEN)
                {
                    fft_split_radix_dif_template_alt<LEN / 2>(input, fft_len);
                    return;
                }
                fft_split_radix_dif_template<LEN>(input);
            }
            // Recursion terminator. It is marked inline like every other explicit specialization in
            // this file, because an explicit specialization is not implicitly inline. Defined in a
            // header without `inline`, it would violate the one-definition rule as soon as the header
            // is included from two translation units.
            template <>
            inline void fft_split_radix_dif_template_alt<0>(Complex *input, size_t fft_len) {}

            // Runtime-length entry points, stored as plain function pointers. Call them as
            // fft_split_radix_dit(input, fft_len) / fft_split_radix_dif(input, fft_len). Both start
            // at 2^lut_max_rank and halve the template size until it no longer exceeds fft_len.
            auto fft_split_radix_dit = &fft_split_radix_dit_template_alt<(size_t(1) << lut_max_rank)>;
            auto fft_split_radix_dif = &fft_split_radix_dif_template_alt<(size_t(1) << lut_max_rank)>;
        }
    }
    // Formats `input` as exactly `digits` decimal characters, left-padded with '0'.
    // Digits above the requested width are silently dropped (only the low `digits` digits are kept).
    inline std::string ui64to_string(UINT_64 input, UINT_8 digits)
    {
        std::string text(digits, '0');
        // Fill from the least significant (rightmost) position toward the front.
        for (auto slot = text.rbegin(); slot != text.rend(); ++slot)
        {
            *slot = static_cast<char>('0' + input % 10);
            input /= 10;
        }
        return text;
    }
    // Parses the first `dig` characters of `s` as an unsigned decimal number (default: 4 digits).
    // No validation is performed; every character is treated as a digit.
    constexpr UINT_64 stoui64(const char *s, size_t dig = 4)
    {
        UINT_64 value = 0;
        for (const char *const stop = s + dig; s != stop; ++s)
        {
            value = value * 10 + (*s - '0');
        }
        return value;
    }

    // Converts exactly four ASCII decimal digits s[0..3] into one base-10000 limb (0..9999).
    // Non-digit input is not detected and yields a meaningless value.
    constexpr UINT_32 stobase10000(const char *s)
    {
        return (s[0] - '0') * 1000 +
               (s[1] - '0') * 100 +
               (s[2] - '0') * 10 +
               (s[3] - '0');
    }
    // Converts exactly five ASCII decimal digits s[0..4] into one base-100000 limb (0..99999).
    // Non-digit input is not detected and yields a meaningless value.
    constexpr UINT_32 stobase100000(const char *s)
    {
        return (s[0] - '0') * 10000 +
               (s[1] - '0') * 1000 +
               (s[2] - '0') * 100 +
               (s[3] - '0') * 10 +
               (s[4] - '0');
    }
    // Decimal digits per limb. char_to_real/char_to_imag (via stobase10000) and ItoStrBase10000
    // are written for exactly 4; change them together.
    static constexpr INT_64 DIGIT{4};
    inline size_t char_to_real(const char *buffer, Complex *comary, size_t str_len)
    {
        hint::INT_64 len = str_len, pos = len, i = 0;
        len = (len + DIGIT - 1) / DIGIT;
        while (pos - DIGIT > 0)
        {
            // hint::UINT_64 tmp = stoui64<DIGIT>(buffer + pos - DIGIT);
            hint::UINT_32 tmp = stobase10000(buffer + pos - DIGIT);
            comary[i].real(tmp);
            i++;
            pos -= DIGIT;
        }
        if (pos > 0)
        {
            hint::UINT_32 tmp = stoui64(buffer, pos);
            comary[i].real(tmp);
        }
        return len;
    }
    inline size_t char_to_imag(const char *buffer, Complex *comary, size_t str_len)
    {
        hint::INT_64 len = str_len, pos = len, i = 0;
        len = (len + DIGIT - 1) / DIGIT;
        while (pos - DIGIT > 0)
        {
            // hint::UINT_64 tmp = stoui64<DIGIT>(buffer + pos - DIGIT);
            hint::UINT_32 tmp = stobase10000(buffer + pos - DIGIT);
            comary[i].imag(tmp);
            i++;
            pos -= DIGIT;
        }
        if (pos > 0)
        {
            hint::UINT_32 tmp = stoui64(buffer, pos);
            comary[i].imag(tmp);
        }
        return len;
    }
    // Reads two number strings.
    // Locates the two non-negative decimal integers in the NUL-terminated text `s`.
    // On return, a/len_a and b/len_b describe the first and second maximal runs of ASCII digits.
    // Any other bytes (whitespace, CR/LF, a UTF-8 BOM, ...) are skipped. If a number is missing,
    // its pointer refers to the terminating NUL and its length is 0.
    void read_2num_str(const char *s, const char *&a, size_t &len_a, const char *&b, size_t &len_b)
    {
        // Plain range test instead of isdigit(). isdigit() has undefined behavior for negative char
        // values (e.g. the bytes of a UTF-8 BOM) and also consults the locale.
        auto is_digit = [](char c)
        { return c >= '0' && c <= '9'; };
        // Advance to the next digit, stopping at the terminator instead of running past it.
        auto skip_non_digits = [&](const char *p)
        {
            while (*p != '\0' && !is_digit(*p))
            {
                ++p;
            }
            return p;
        };
        // Length of the digit run starting at p ('\0' is not a digit, so this stops at the end).
        auto digit_run = [&](const char *p)
        {
            const char *q = p;
            while (is_digit(*q))
            {
                ++q;
            }
            return static_cast<size_t>(q - p);
        };

        a = skip_non_digits(s);
        len_a = digit_run(a);
        // The second number is delimited exactly like the first, so trailing tokens or whitespace
        // after it can never be counted as part of it.
        b = skip_non_digits(a + len_a);
        len_b = digit_run(b);
    }
    /// Lookup table converting a base-10000 limb (0..9999) into its four zero-padded ASCII
    /// decimal digits, packed into one uint32_t.
    ///
    /// The thousands digit goes in the least significant byte. Storing a table word to memory
    /// therefore yields the digits in reading order only on a little-endian target (e.g. x86,
    /// which the AVX code in this file already requires).
    class ItoStrBase10000
    {
    private:
        uint32_t table[10000]{};

    public:
        /// Packs the four low decimal digits of `num` as ASCII, thousands digit in byte 0.
        static constexpr uint32_t itosbase10000(uint32_t num)
        {
            uint32_t res = '0' * 0x1010101; // "0000" in every byte
            res += (num / 1000 % 10) | ((num / 100 % 10) << 8) |
                   ((num / 10 % 10) << 16) | ((num % 10) << 24);
            return res;
        }
        constexpr ItoStrBase10000()
        {
            for (size_t i = 0; i < 10000; i++)
            {
                table[i] = itosbase10000(i);
            }
        }
        /// Writes the four digits of `num` to str[0..3]; precondition: num < 10000.
        /// No terminator is written, and `str` need not be 4-byte aligned.
        void tostr(char *str, uint32_t num) const
        {
            // memcpy instead of storing through a uint32_t* cast of a char buffer. That cast is a
            // strict-aliasing violation and may be misaligned; memcpy of 4 bytes compiles to the
            // same single 32-bit store.
            std::memcpy(str, &table[num], sizeof(uint32_t));
        }
        /// Returns the packed digits of `num`; precondition: num < 10000.
        uint32_t tostr(uint32_t num) const
        {
            return table[num];
        }
    };
}

using namespace std;
using namespace hint;
using namespace hint_transform;
using namespace hint_fft;

int main()
{
    constexpr size_t STR_LEN = 2000008;
    constexpr uint64_t BASE = hint::qpow(10ull, DIGIT);
    static constexpr ItoStrBase10000 transfer;
    static AlignAry<char, STR_LEN> out;
    static AlignAry<HintFloat, 1 << 20> fft_ary_avx;
    Complex *fft_ary = fft_ary_avx.template cast_ptr<Complex>();
    uint32_t *ary = out.template cast_ptr<uint32_t>();
    size_t len_a = 0, len_b = 0;
    size_t len = fread(out.data(), 1, STR_LEN, stdin);
    const char *a, *b;
    read_2num_str(out.data(), a, len_a, b, len_b);
    /*
        struct stat sb;
        int fd = fileno(stdin);
        fstat(fd, &sb);
        p = (char *)mmap(0, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
        madvise(p, sb.st_size, MADV_SEQUENTIAL);
    */
    if (len_a == 1 && a[0] == '0')
    {
        puts("0");
        return 0;
    }
    if (len_b == 1 && b[0] == '0')
    {
        puts("0");
        return 0;
    } // 0.46ms
    size_t len1 = char_to_real(a, fft_ary, len_a);
    size_t len2 = char_to_imag(b, fft_ary, len_b); // 1.67ms
    size_t fft_len = int_ceil2(len1 + len2 - 1);
    TABLE.expand(lut_max_rank); // 3.5ms
    fft_split_radix_dif_template<1 << lut_max_rank>(fft_ary);
    HintFloat inv = -0.5 / fft_len;
    const Complex2 invx4(inv);
    for (size_t i = 0; i < fft_len; i += 2)
    {
        Complex2 tmp = fft_ary + i;
        tmp = tmp.square().linear_mul(invx4);
        (tmp.conj()).store(fft_ary + i);
    }
    fft_split_radix_dit_template<1 << lut_max_rank>(fft_ary); // 6ms
    UINT_64 carry = 0;
    size_t pos = STR_LEN / 4 - 1;
    for (size_t i = 0; i < len1 + len2 - 1; i++)
    {
        carry += UINT_64(fft_ary[i].imag() + 0.5);
        UINT_64 num = 0;
        std::tie(carry, num) = div_mod<UINT_64>(carry, BASE);
        ary[pos] = transfer.tostr(num);
        pos--;
    }
    ary[pos] = transfer.tostr(carry);
    pos *= 4;
    while (out[pos] == '0')
    {
        pos++;
    } // 0.8ms
    fwrite(out.data() + pos, 1, STR_LEN - pos, stdout);
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 13.461 ms | 19 MB + 868 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2024-12-04 02:03:51 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠