提交记录 19408


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1004. 【模板题】高精度乘法 Runtime Error 0 34.2 us 32 KB C++14 42.46 KB
提交时间 评测时间
2023-05-08 14:46:50 2023-05-08 14:46:54
// https://github.com/With-Sky/HintFFT
// https://github.com/With-Sky/FFT-Benchmark
// https://github.com/With-Sky/HyperInt-mini
// https://space.bilibili.com/511540153

#include <tuple>
#include <iostream>
#include <complex>
#include <cstring>
#include <immintrin.h>
#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP

#pragma GCC target("fma")

// Element types for the AVX (256-bit SIMD) code in this file.
// HintFloat is the scalar type; Complex is std::complex over that scalar.
using HintFloat = double;
using Complex = std::complex<HintFloat>;
// Two complex<double> values processed in parallel in one 256-bit AVX register.
// Lane layout from low to high: re0, im0, re1, im1.
struct Complex2
{
    __m256d data;
    // Leaves `data` uninitialised (the zeroing is deliberately commented out); use clr() to zero it.
    Complex2()
    {
        // data = _mm256_setzero_pd();
    }
    // Broadcast one double into all four lanes.
    Complex2(double input)
    {
        data = _mm256_set1_pd(input);
    }
    // Wrap an existing register value.
    Complex2(__m256d input)
    {
        data = input;
    }
    Complex2(const Complex2 &input)
    {
        data = input.data;
    }
    // Load four consecutive doubles (unaligned load).
    Complex2(double const *ptr)
    {
        data = _mm256_loadu_pd(ptr);
    }
    // Put the same complex value in both halves.
    // The value is read with an unaligned load: std::complex<double> is only
    // guaranteed 8-byte alignment, so it must not be dereferenced as __m128d.
    Complex2(Complex a)
    {
        const __m128d half = _mm_loadu_pd(reinterpret_cast<const double *>(&a));
        data = _mm256_broadcast_pd(&half);
    }
    // Low half = a, high half = b (both read with unaligned loads, see above).
    Complex2(Complex a, Complex b)
    {
        const __m128d lo = _mm_loadu_pd(reinterpret_cast<const double *>(&a));
        const __m128d hi = _mm_loadu_pd(reinterpret_cast<const double *>(&b));
        data = _mm256_set_m128d(hi, lo);
    }
    // Load two consecutive complex values (unaligned load).
    Complex2(const Complex *ptr)
    {
        data = _mm256_loadu_pd(reinterpret_cast<const double *>(ptr));
    }
    // Set all lanes to zero.
    void clr()
    {
        data = _mm256_setzero_pd();
    }
    // Write both complex values to a[0] and a[1] (unaligned store).
    void store(Complex *a) const
    {
        _mm256_storeu_pd(reinterpret_cast<double *>(a), data);
    }
    // Print both values as "(re,im) (re,im)".
    void print() const
    {
        double ary[4];
        _mm256_storeu_pd(ary, data);
        printf("(%lf,%lf) (%lf,%lf)\n", ary[0], ary[1], ary[2], ary[3]);
    }
    // Permute within each 128-bit half using the immediate M of _mm256_permute_pd.
    template <int M>
    Complex2 element_permute() const
    {
        return _mm256_permute_pd(data, M);
    }
    // (re0, re0, re1, re1)
    Complex2 all_real() const
    {
        return _mm256_movedup_pd(data);
    }
    // (im0, im0, im1, im1)
    Complex2 all_imag() const
    {
        return element_permute<0XF>();
    }
    // Swap real and imaginary parts: (im0, re0, im1, re1)
    Complex2 swap() const
    {
        return element_permute<0X5>();
    }
    // Multiply both values by -i: (re, im) -> (im, -re).
    Complex2 mul_neg_i() const
    {
        return swap().conj();
    }
    // Complex conjugate: flips the sign bit of both imaginary lanes.
    Complex2 conj() const
    {
        static const __m256d inv_mask = _mm256_castsi256_pd(_mm256_set_epi64x(INT64_MIN, 0, INT64_MIN, 0));
        return _mm256_xor_pd(data, inv_mask);
    }
    // Lane-by-lane product (NOT a complex product).
    Complex2 linear_mul(Complex2 input) const
    {
        return _mm256_mul_pd(data, input.data);
    }
    Complex2 operator+(Complex2 input) const
    {
        return _mm256_add_pd(data, input.data);
    }
    Complex2 operator-(Complex2 input) const
    {
        return _mm256_sub_pd(data, input.data);
    }
    // Complex product of each pair: (a+bi)(c+di) = (ac-bd) + (ad+bc)i.
    Complex2 operator*(Complex2 input) const
    {
        // imag lanes hold (b*d, b*c); fmaddsub subtracts in even lanes and adds in odd lanes.
        auto imag = _mm256_mul_pd(all_imag().data, input.swap().data);
        return _mm256_fmaddsub_pd(all_real().data, input.data, imag);
    }
    // Lane-by-lane quotient (NOT a complex division).
    Complex2 operator/(Complex2 input) const
    {
        return _mm256_div_pd(data, input.data);
    }
};
#endif

#include <vector>
#include <complex>
#include <iostream>
#include <future>
#include <ctime>
#include <cstring>
// #include "stopwatch.hpp"

#define MULTITHREAD 0   // 多线程 0 means no, 1 means yes
#define TABLE_PRELOAD 0 // 是否提前初始化表 0 means no, 1 means yes

namespace hint
{
    // Short names for the fixed-width integer types and the complex type used in this namespace.
    using UINT_8 = uint8_t;
    using UINT_16 = uint16_t;
    using UINT_32 = uint32_t;
    using UINT_64 = uint64_t;
    using INT_32 = int32_t;
    using INT_64 = int64_t;
    using ULONG = unsigned long;
    using LONG = long;
    using Complex = std::complex<double>;

    // Bit widths and boundary values of the 8/16/32/64-bit integer types, all stored as UINT_64.
    // Suffix meaning as the initialisers show: 0XFF = all bits set, 0X10 = one past the maximum,
    // 0X80 = only the top bit set, 0X7F = signed maximum.
    constexpr UINT_64 HINT_CHAR_BIT = 8;
    constexpr UINT_64 HINT_SHORT_BIT = 16;
    constexpr UINT_64 HINT_INT_BIT = 32;
    constexpr UINT_64 HINT_INT8_0XFF = UINT8_MAX;
    constexpr UINT_64 HINT_INT8_0X10 = (UINT8_MAX + 1ull);
    constexpr UINT_64 HINT_INT16_0XFF = UINT16_MAX;
    constexpr UINT_64 HINT_INT16_0X10 = (UINT16_MAX + 1ull);
    constexpr UINT_64 HINT_INT32_0XFF = UINT32_MAX;
    constexpr UINT_64 HINT_INT32_0X01 = 1;
    constexpr UINT_64 HINT_INT32_0X80 = 0X80000000ull;
    constexpr UINT_64 HINT_INT32_0X7F = INT32_MAX;
    constexpr UINT_64 HINT_INT32_0X10 = (UINT32_MAX + 1ull);
    constexpr UINT_64 HINT_INT64_0X80 = INT64_MIN;
    constexpr UINT_64 HINT_INT64_0X7F = INT64_MAX;
    constexpr UINT_64 HINT_INT64_0XFF = UINT64_MAX;

    // pi, 2*pi and sqrt(2)/2.
    constexpr double HINT_PI = 3.1415926535897932384626433832795;
    constexpr double HINT_2PI = HINT_PI * 2;
    constexpr double HINT_HSQ_ROOT2 = 0.70710678118654752440084436210485;

    // Number-theoretic-transform parameters: 3221225473 = 3 * 2^30 + 1 and 3489660929 = 13 * 2^28 + 1.
    // NOTE(review): NTT_ROOT1 / NTT_ROOT2 are presumably primitive roots of the matching modulus - confirm.
    constexpr UINT_64 NTT_MOD1 = 3221225473;
    constexpr UINT_64 NTT_ROOT1 = 5;
    constexpr UINT_64 NTT_MOD2 = 3489660929;
    constexpr UINT_64 NTT_ROOT2 = 3;
    constexpr size_t NTT_MAX_LEN = 1ull << 28;

    /// @brief Largest power of two that does not exceed n.
    /// @param n Upper bound; any built-in integer type, signed or unsigned.
    /// @return The largest power of two <= n, or 0 when n < 1 (no such power exists).
    template <typename T>
    constexpr T max_2pow(T n)
    {
        if (n < T(1))
        {
            return 0;
        }
        // For signed types the top bit is the sign bit, so the largest representable
        // power of two sits one position lower than for unsigned types.
        constexpr bool is_signed_type = T(-1) < T(0);
        T res = 1;
        res <<= (sizeof(T) * 8 - (is_signed_type ? 2 : 1));
        while (res > n)
        {
            res /= 2;
        }
        return res;
    }
    /// @brief Smallest power of two that is not below n.
    /// @param n Lower bound to reach.
    /// @return The smallest power of two >= n; 1 when n <= 1.
    template <typename T>
    constexpr T min_2pow(T n)
    {
        T power = 1;
        // Keep doubling until the bound is reached.
        for (; power < n; power *= 2)
        {
        }
        return power;
    }
    /// @brief Integer base-2 logarithm, rounded down.
    /// @param n Input value.
    /// @return Number of times n can be halved before it drops to 1 or below; 0 when n <= 1.
    template <typename T>
    constexpr T hint_log2(T n)
    {
        T halvings = 0;
        for (; n > 1; n /= 2)
        {
            ++halvings;
        }
        return halvings;
    }
#if MULTITHREAD == 1
    // Compiled only when MULTITHREAD is defined as 1.
    // NOTE(review): <thread>, <atomic> and <cmath> are not included directly in the
    // visible part of this file - confirm they are available before enabling this.
    // Thread count reported by the standard library.
    const UINT_32 hint_threads = std::thread::hardware_concurrency();
    // Integer log2 of the thread count; hint_log2 already returns an integer, so ceil changes nothing.
    const UINT_32 log2_threads = std::ceil(hint_log2(hint_threads));
    // Atomic counter. NOTE(review): presumably the number of worker threads currently in use - confirm against its users.
    std::atomic<UINT_32> cur_ths;
#endif

    // Generic array copy: copies len elements from source to target with memcpy.
    // Does nothing for an empty range or when both pointers are equal.
    // Throws a C string when len is not below INT64_MAX.
    template <typename T>
    void ary_copy(T *target, const T *source, size_t len)
    {
        const bool nothing_to_do = (len == 0) || (target == source);
        if (!nothing_to_do)
        {
            if (!(len < INT64_MAX))
            {
                throw("Ary too long\n");
            }
            std::memcpy(target, source, len * sizeof(T));
        }
    }
    // Resize a buffer to hold len elements of T using realloc().
    // ptr must be null or come from malloc/calloc/realloc, as realloc() requires.
    // Returns the (possibly moved) buffer.
    // Throws the C string "realloc error" when the byte count would reach INT64_MAX
    // (or overflow), or when realloc() returns a null pointer.
    template <typename T>
    inline T *ary_realloc(T *ptr, size_t len)
    {
        // Check the element count before multiplying so the byte count cannot wrap around.
        // An oversize request is reported instead of silently handing back the old buffer.
        const unsigned long long max_elems = (static_cast<unsigned long long>(INT64_MAX) - 1) / sizeof(T);
        if (len > max_elems)
        {
            throw("realloc error");
        }
        T *resized = static_cast<T *>(realloc(ptr, len * sizeof(T)));
        if (resized == nullptr)
        {
            throw("realloc error");
        }
        return resized;
    }
    // Pack two indexable sequences into a complex array: source1 supplies the real
    // parts and source2 the imaginary parts. Where only one source still has
    // elements, only that component of target[i] is overwritten.
    template <typename T>
    inline void com_ary_combine_copy(Complex *target, const T &source1, size_t len1, const T &source2, size_t len2)
    {
        const size_t shared = std::min(len1, len2);
        for (size_t k = 0; k < shared; ++k)
        {
            target[k] = Complex(source1[k], source2[k]);
        }
        // At most one of the two tails below is non-empty.
        for (size_t k = shared; k < len1; ++k)
        {
            target[k].real(source1[k]);
        }
        for (size_t k = shared; k < len2; ++k)
        {
            target[k].imag(source2[k]);
        }
    }
    // De-interleave an array in place: elements at indices congruent to r modulo N
    // end up contiguous in the r-th section of length len / N.
    // The loop reads ary[i + 1 .. i + N - 1] for every i stepping by N, so len is
    // expected to be a multiple of N.
    template <UINT_64 N, typename T>
    void ary_interlace(T ary[], size_t len)
    {
        const size_t group_len = len / N;
        const size_t spill_len = len - group_len;
        T *spill = new T[spill_len]; // holds sections 1 .. N-1 until the pass is done
        for (size_t base = 0, slot = 0; base < len; base += N, ++slot)
        {
            ary[slot] = ary[base];
            for (size_t lane = 1; lane < N; ++lane)
            {
                spill[(lane - 1) * group_len + slot] = ary[base + lane];
            }
        }
        ary_copy(ary + group_len, spill, spill_len);
        delete[] spill;
    }
    // Growable array that stores its element count and a sign flag in one word.
    //   T          element type
    //   SIZE_TYPE  unsigned integer type for the capacity and the sign+length word
    // Storage is always obtained with new[] and released with delete[].
    // NOTE(review): before C++17 a plain new T[] does not honour extended alignment;
    // confirm this is acceptable for over-aligned element types such as SIMD vectors.
    template <typename T = UINT_32, typename SIZE_TYPE = UINT_32>
    class HintVector
    {
    private:
        T *ary_ptr = nullptr;     // owned storage (new[]), or nullptr when nothing is allocated
        SIZE_TYPE sign_n_len = 0; // top bit: sign flag; remaining bits: element count
        SIZE_TYPE size = 0;       // allocated capacity in elements
        static constexpr SIZE_TYPE SIZE_TYPE_BITS = sizeof(SIZE_TYPE) * 8;   // bit width of the size and length members
        static constexpr SIZE_TYPE SIZE_80 = (1ull << (SIZE_TYPE_BITS - 1)); // top bit set, all other bits clear
        static constexpr SIZE_TYPE LEN_MAX = SIZE_80 - 1;                    // largest representable length

    public:
        ~HintVector()
        {
            delete[] ary_ptr; // delete[] on nullptr is a no-op
            ary_ptr = nullptr;
        }
        HintVector()
        {
            ary_ptr = nullptr;
        }
        // Reserve capacity for ele_size elements; the length stays 0.
        HintVector(SIZE_TYPE ele_size)
        {
            if (ele_size > 0)
            {
                resize(ele_size);
            }
        }
        // ele_size copies of ele.
        HintVector(SIZE_TYPE ele_size, T ele)
        {
            if (ele_size > 0)
            {
                resize(ele_size);
                change_length(ele_size);
                fill_element(ele, ele_size);
            }
        }
        // Copy len elements from ary. Throws a C string when ary is null.
        HintVector(const T *ary, SIZE_TYPE len)
        {
            if (ary == nullptr)
            {
                throw("Can't copy from nullptr\n");
            }
            if (len > 0)
            {
                resize(len);
                change_length(len);
                ary_copy(ary_ptr, ary, len);
            }
        }
        // Deep copy of length, sign and elements; stays empty when input owns no storage.
        HintVector(const HintVector &input)
        {
            if (input.ary_ptr != nullptr)
            {
                SIZE_TYPE len = input.length();
                resize(len);
                change_length(len);
                change_sign(input.sign());
                ary_copy(ary_ptr, input.ary_ptr, len);
            }
        }
        // Take over input's storage. input is left empty (null pointer, zero length
        // and capacity) so its bookkeeping never describes storage it no longer owns.
        HintVector(HintVector &&input) noexcept
        {
            if (input.ary_ptr != nullptr)
            {
                ary_ptr = input.ary_ptr;
                size = input.size;
                sign_n_len = input.sign_n_len;
                input.ary_ptr = nullptr;
                input.size = 0;
                input.sign_n_len = 0;
            }
        }
        // Deep copy; the old storage is released first.
        HintVector &operator=(const HintVector &input)
        {
            if (this != &input)
            {
                delete[] ary_ptr;
                ary_ptr = nullptr;
                SIZE_TYPE len = input.length();
                resize(len);
                change_length(len);
                change_sign(input.sign());
                ary_copy(ary_ptr, input.ary_ptr, len);
            }
            return *this;
        }
        // Release own storage and take over input's; input is left empty.
        HintVector &operator=(HintVector &&input) noexcept
        {
            if (this != &input)
            {
                delete[] ary_ptr;
                ary_ptr = input.ary_ptr;
                size = input.size;
                sign_n_len = input.sign_n_len;
                input.ary_ptr = nullptr;
                input.size = 0;
                input.sign_n_len = 0;
            }
            return *this;
        }
        // Replace the contents with a converted copy of a vector of another element type.
        // Terminates the program when input is longer than LEN_MAX.
        // NOTE(review): ary_copy_2type is not defined in the visible part of this file -
        // confirm it is declared before this member is instantiated.
        template <typename Ty, typename SIZE_Ty>
        void copy_from(const HintVector<Ty, SIZE_Ty> &input)
        {
            // Compare as void pointers so that differing element types still compile.
            if (static_cast<const void *>(this) != static_cast<const void *>(&input))
            {
                if (input.length() > LEN_MAX)
                {
                    std::cerr << "HintVector err: too long\n";
                    exit(EXIT_FAILURE);
                }
                delete[] ary_ptr;
                ary_ptr = nullptr;
                SIZE_TYPE len = input.length();
                resize(len);
                change_length(len);
                change_sign(input.sign());
                ary_copy_2type(ary_ptr, input.raw_ptr(), len);
            }
        }
        // Overwrite the first min(len, length()) elements from a raw array.
        // Throws a C string when input is null.
        template <typename Ty>
        void copy_from(const Ty *input, SIZE_TYPE len)
        {
            if (input == nullptr)
            {
                throw("Can't copy from nullptr\n");
            }
            len = std::min<SIZE_TYPE>(len, length());
            ary_copy(ary_ptr, input, len);
        }
        // First element, or 0 when the vector is empty.
        T front() const
        {
            if (length() == 0)
            {
                return 0;
            }
            return ary_ptr[0];
        }
        // Last element, or 0 when the vector is empty.
        T back() const
        {
            if (length() == 0)
            {
                return 0;
            }
            return ary_ptr[length() - 1];
        }
        // Unchecked element access.
        T &operator[](SIZE_TYPE index) const
        {
            return ary_ptr[index];
        }
        // Untyped pointer to the storage; terminates the program when nothing is allocated.
        void *raw_ptr() const
        {
            if (ary_ptr == nullptr)
            {
                std::cerr << "HintVector err: Get a ptr from an empty vector\n";
                exit(EXIT_FAILURE);
            }
            return ary_ptr;
        }
        // Typed pointer to the storage (may be nullptr).
        constexpr T *data() const
        {
            return ary_ptr;
        }
        // Default capacity policy: 2, 4, then the smallest of {0.75 * 2^k, 2^k} that
        // holds new_size, where 2^k is the smallest power of two >= new_size.
        static SIZE_TYPE size_generator(SIZE_TYPE new_size)
        {
            new_size = std::min<SIZE_TYPE>(new_size, LEN_MAX);
            if (new_size <= 2)
            {
                return 2;
            }
            else if (new_size <= 4)
            {
                return 4;
            }
            SIZE_TYPE size1 = min_2pow(new_size);
            SIZE_TYPE size2 = size1 / 2;
            size2 = size2 + size2 / 2;
            if (new_size <= size2)
            {
                return size2;
            }
            else
            {
                return size1;
            }
        }
        // Print the elements from last to first, tab separated.
        void print() const
        {
            SIZE_TYPE i = length();
            if (i == 0 || ary_ptr == nullptr)
            {
                std::cout << "Empty vector" << std::endl;
            }
            while (i > 0)
            {
                i--;
                std::cout << ary_ptr[i] << "\t";
            }
            std::cout << std::endl;
        }
        constexpr SIZE_TYPE capacity() const
        {
            return size;
        }
        // Element count (the sign bit is masked off).
        constexpr SIZE_TYPE length() const
        {
            return sign_n_len & LEN_MAX;
        }
        // True when the sign flag is set.
        bool sign() const
        {
            return (SIZE_80 & sign_n_len) != 0;
        }
        // Change the capacity to size_func(new_size). Existing elements that still fit
        // are kept; the length is clamped to the new capacity.
        void resize(SIZE_TYPE new_size, SIZE_TYPE(size_func)(SIZE_TYPE) = size_generator)
        {
            new_size = size_func(new_size);
            if (ary_ptr == nullptr)
            {
                size = new_size;
                ary_ptr = new T[size]{};
                change_length(0);
            }
            else if (size != new_size)
            {
                // The storage comes from new[] and is released with delete[], so it must
                // not go through realloc(): allocate a new block, carry over what fits,
                // then release the old block.
                T *fresh = new T[new_size]{};
                const SIZE_TYPE keep = std::min<SIZE_TYPE>(size, new_size);
                for (SIZE_TYPE i = 0; i < keep; i++)
                {
                    fresh[i] = ary_ptr[i];
                }
                delete[] ary_ptr;
                ary_ptr = fresh;
                size = new_size;
                change_length(length());
            }
        }
        // Set the element count (clamped to the capacity), keeping the sign flag.
        // Throws a C string when new_length exceeds LEN_MAX.
        void change_length(SIZE_TYPE new_length)
        {
            if (new_length > LEN_MAX)
            {
                throw("Length too long\n");
            }
            new_length = std::min(new_length, size);
            bool sign_tmp = sign();
            sign_n_len = new_length;
            change_sign(sign_tmp);
        }
        // Set or clear the sign flag; the flag is always cleared when the length is 0.
        void change_sign(bool is_sign)
        {
            if ((!is_sign) || length() == 0)
            {
                sign_n_len = sign_n_len & LEN_MAX;
            }
            else
            {
                sign_n_len = sign_n_len | SIZE_80;
            }
        }
        // Drop trailing zero elements from the length and return the new length.
        SIZE_TYPE set_true_len()
        {
            if (ary_ptr == nullptr)
            {
                return 0;
            }
            SIZE_TYPE t_len = length();
            while (t_len > 0 && ary_ptr[t_len - 1] == 0)
            {
                t_len--;
            }
            change_length(t_len);
            return length();
        }
        // Write ele into [begin, begin + count), truncated at the capacity.
        void fill_element(T ele, SIZE_TYPE count, SIZE_TYPE begin = 0)
        {
            if (begin >= size || ary_ptr == nullptr)
            {
                return;
            }
            // begin < size here, so size - begin cannot wrap (begin + count could).
            if (count > size - begin)
            {
                count = size - begin;
            }
            std::fill(ary_ptr + begin, ary_ptr + begin + count, ele);
        }
        // NOTE(review): ary_clr is not defined in the visible part of this file -
        // presumably it zeroes `size` elements; confirm.
        void clear()
        {
            ary_clr(ary_ptr, size);
        }
    };
    // Out-of-class definitions of the static constexpr members (needed when they are odr-used before C++17).
    template <typename T, typename SIZE_TYPE>
    constexpr SIZE_TYPE HintVector<T, SIZE_TYPE>::SIZE_TYPE_BITS;
    template <typename T, typename SIZE_TYPE>
    constexpr SIZE_TYPE HintVector<T, SIZE_TYPE>::SIZE_80;
    template <typename T, typename SIZE_TYPE>
    constexpr SIZE_TYPE HintVector<T, SIZE_TYPE>::LEN_MAX;
    // FFT与类FFT变换的命名空间
    namespace hint_transform
    {
        // Lookup table of unit roots for FFT butterflies.
        // For a circle split into 1 << shift parts, table1[shift] stores the roots
        // (cos t, -sin t) for the first quarter turn and table3[shift] stores the
        // roots with three times the index. Both are stored as packed Complex2 but
        // filled and read through Complex pointers.
        class ComplexTableY
        {
        private:
            std::vector<HintVector<Complex2>> table1;
            std::vector<HintVector<Complex2>> table3;
            INT_32 max_log_size = 2; // largest shift the tables were sized for
            INT_32 cur_log_size = 2; // largest shift whose entries have been computed

            static constexpr size_t FAC = 1;

            ComplexTableY(const ComplexTableY &) = delete;
            ComplexTableY &operator=(const ComplexTableY &) = delete;

        public:
            ~ComplexTableY() {}
            // Size the tables for shifts 0 .. max_shift and seed the entries for shifts 0, 1 and 2.
            ComplexTableY(UINT_32 max_shift)
            {
                // Entries 0, 1 and 2 are written unconditionally below, so the tables
                // need at least three slots.
                max_shift = std::max<size_t>(max_shift, 2);
                max_log_size = max_shift;
                table1.resize(max_shift + 1);
                table3.resize(max_shift + 1);
                static const Complex2 one(Complex(1, 0), Complex(1, 0));
                static const Complex2 two(Complex(1, 0), Complex(-1, 0));
                table1[0] = table1[1] = table3[0] = table3[1] = HintVector<Complex2>(1, one);
                table1[2] = table3[2] = HintVector<Complex2>(1, two);
#if TABLE_PRELOAD == 1
                expand(max_shift);
#endif
            }
            // Compute all not-yet-filled levels up to `shift`.
            // Throws a C string when shift exceeds the size chosen at construction.
            void expand(INT_32 shift)
            {
                shift = std::max<INT_32>(shift, 2);
                if (shift > max_log_size)
                {
                    throw("FFT length too long for lut\n");
                }
                for (INT_32 i = cur_log_size + 1; i <= shift; i++)
                {
                    size_t len = 1ull << i, vec_size = len * FAC / 4;
                    table1[i].resize(vec_size / 2);
                    table3[i].resize(vec_size / 2);
                    auto ptr1 = reinterpret_cast<Complex *>(table1[i].data());
                    auto ptr3 = reinterpret_cast<Complex *>(table3[i].data());
                    auto last_ptr1 = reinterpret_cast<Complex *>(table1[i - 1].data());
                    auto last_ptr3 = reinterpret_cast<Complex *>(table3[i - 1].data());
                    ptr1[0] = ptr3[0] = Complex(1, 0);
                    // Even positions are taken from the previous (half-size) level.
                    for (size_t pos = 0; pos < vec_size / 2; pos++)
                    {
                        ptr1[pos * 2] = last_ptr1[pos];
                        ptr3[pos * 2] = last_ptr3[pos];
                    }
                    // Odd positions are computed; the entry at vec_size - pos is
                    // written with the cosine and sine swapped.
                    for (size_t pos = 1; pos < vec_size / 2; pos += 2)
                    {
                        double cos_theta = std::cos(HINT_2PI * pos / len);
                        double sin_theta = std::sin(HINT_2PI * pos / len);
                        ptr1[pos] = Complex(cos_theta, -sin_theta);
                        ptr1[vec_size - pos] = Complex(sin_theta, -cos_theta);
                    }
                    ptr1[vec_size / 2] = std::conj(unit_root(8, 1));
                    // table3 is derived from table1 of the same level through get_omega.
                    for (size_t pos = 1; pos < vec_size / 2; pos += 2)
                    {
                        Complex tmp = get_omega(i, pos * 3);
                        ptr3[pos] = tmp;
                        ptr3[vec_size - pos] = Complex(tmp.imag(), tmp.real());
                    }
                    ptr3[vec_size / 2] = std::conj(unit_root(8, 3));
                }
                cur_log_size = std::max(cur_log_size, shift);
            }
            // Point on the unit circle with argument theta.
            static Complex unit_root(double theta)
            {
                return std::polar<double>(1.0, theta);
            }
            // n-th point when the unit circle is split into m equal parts.
            static Complex unit_root(size_t m, size_t n)
            {
                return unit_root((HINT_2PI * n) / m);
            }
            // n-th root for a circle split into 1 << shift parts, read from table1
            // using quadrant symmetry. n is expected to be below 1 << shift: the
            // quadrant number indexes a four-entry array.
            Complex get_omega(UINT_32 shift, size_t n) const
            {
                size_t rank = 1ull << shift;
                const size_t rank_ff = rank - 1, quad_n = n << 2;
                // n &= rank_ff;
                size_t zone = quad_n >> shift; // quadrant number
                if ((quad_n & rank_ff) == 0)
                {
                    // n is an exact multiple of a quarter turn.
                    static constexpr Complex ONES[4] = {Complex(1, 0), Complex(0, -1), Complex(-1, 0), Complex(0, 1)};
                    return ONES[zone];
                }
                auto ptr1 = reinterpret_cast<const Complex *>(table1[shift].data());
                Complex tmp;
                if ((zone & 2) == 0)
                {
                    if ((zone & 1) == 0)
                    {
                        tmp = ptr1[n];
                    }
                    else
                    {
                        tmp = ptr1[rank / 2 - n];
                        tmp.real(-tmp.real());
                    }
                }
                else
                {
                    if ((zone & 1) == 0)
                    {
                        tmp = -ptr1[n - rank / 2];
                    }
                    else
                    {
                        tmp = ptr1[rank - n];
                        tmp.imag(-tmp.imag());
                    }
                }
                return tmp;
            }
            // Entry n of table3 for a circle split into 1 << shift parts (the root with index 3n).
            Complex get_omega3(UINT_32 shift, size_t n) const
            {
                return reinterpret_cast<const Complex *>(table3[shift].data())[n];
            }
            // Packed pair n of table1 for a circle split into 1 << shift parts.
            Complex2 get_omegaX2(UINT_32 shift, size_t n) const
            {
                return table1[shift][n];
            }
            // Packed pair n of table3 for a circle split into 1 << shift parts.
            Complex2 get_omega3X2(UINT_32 shift, size_t n) const
            {
                return table3[shift][n];
            }
        };

        // Largest shift (log2 of the circle division) the shared lookup table is sized for.
        constexpr size_t lut_max_rank = 23;

        // Shared unit-root table with internal linkage; its constructor only pre-computes
        // the higher levels when TABLE_PRELOAD is 1.
        static ComplexTableY TABLE(lut_max_rank);

        // Bit-reversal permutation: swaps ary[j] with the element whose index is j with
        // its bits reversed. len is expected to be a power of two; lengths below 2
        // need no reordering.
        template <typename T>
        void binary_reverse_swap(T &ary, size_t len)
        {
            if (len < 2)
            {
                return; // also keeps len - 1 below from wrapping around when len is 0
            }
            size_t i = 0; // bit-reversed counterpart of j, updated incrementally
            for (size_t j = 1; j < len - 1; j++)
            {
                // Add one to i in reversed bit order: flip bits from the top down
                // until a bit changes from 0 to 1.
                size_t k = len >> 1;
                i ^= k;
                while (k > i)
                {
                    k >>= 1;
                    i ^= k;
                };
                if (j < i) // swap each pair only once
                {
                    std::swap(ary[i], ary[j]);
                }
            }
        }
        // Base-4 digit-reversal permutation: swaps ary[i] with the element whose index
        // is i with its base-4 digits reversed.
        // NOTE(review): the index formula shifts by log2(len) - 2, so len is presumably
        // expected to be a power of four - confirm with the callers.
        template <typename SizeType = UINT_32, typename T>
        void quaternary_reverse_swap(T &ary, size_t len)
        {
            if (len < 4)
            {
                // Nothing to reorder; also avoids writing rev[0] into an empty table
                // and shifting by a wrapped-around amount.
                return;
            }
            SizeType log_n = hint_log2(len);
            std::vector<SizeType> rev(len / 4); // reversed indices of the first quarter
            rev[0] = 0;
            for (SizeType i = 1; i < len; i++)
            {
                // Reversed index of i, built from the reversed index of i / 4.
                SizeType index = (rev[i >> 2] >> 2) | ((i & 3) << (log_n - 2));
                if (i < len / 4)
                {
                    rev[i] = index;
                }
                if (i < index) // swap each pair only once
                {
                    std::swap(ary[i], ary[index]);
                }
            }
        }
        // 2-point FFT butterfly, in place: (sum, diff) becomes (sum + diff, sum - diff).
        template <typename T>
        inline void fft_2point(T &sum, T &diff)
        {
            // Copy both inputs first so the second result still sees the original values.
            const T first(sum);
            const T second(diff);
            sum = first + second;
            diff = first - second;
        }
        // 4-point FFT, in place, on input[0], input[rank], input[2 * rank], input[3 * rank].
        inline void fft_4point(Complex *input, size_t rank = 1)
        {
            Complex *const p1 = input + rank;
            Complex *const p2 = input + rank * 2;
            Complex *const p3 = input + rank * 3;
            Complex x0 = *input;
            Complex x1 = *p1;
            Complex x2 = *p2;
            Complex x3 = *p3;

            fft_2point(x0, x2);
            fft_2point(x1, x3);
            x3 = Complex(x3.imag(), -x3.real()); // multiply by -i

            *input = x0 + x1;
            *p1 = x2 + x3;
            *p2 = x0 - x1;
            *p3 = x2 - x3;
        }
        // 4-point decimation-in-time FFT, in place, on input[0], input[rank],
        // input[2 * rank], input[3 * rank].
        inline void fft_dit_4point(Complex *input, size_t rank = 1)
        {
            Complex *const p1 = input + rank;
            Complex *const p2 = input + rank * 2;
            Complex *const p3 = input + rank * 3;
            Complex x0 = *input;
            Complex x1 = *p1;
            Complex x2 = *p2;
            Complex x3 = *p3;

            // First stage pairs neighbours: (0,1) and (2,3).
            fft_2point(x0, x1);
            fft_2point(x2, x3);
            x3 = Complex(x3.imag(), -x3.real()); // multiply by -i

            *input = x0 + x2;
            *p1 = x1 + x3;
            *p2 = x0 - x2;
            *p3 = x1 - x3;
        }
        // 8-point decimation-in-time FFT, in place, on input[0], input[rank], ..., input[7 * rank].
        inline void fft_dit_8point(Complex *input, size_t rank = 1)
        {
            Complex tmp0 = input[0];
            Complex tmp1 = input[rank];
            Complex tmp2 = input[rank * 2];
            Complex tmp3 = input[rank * 3];
            Complex tmp4 = input[rank * 4];
            Complex tmp5 = input[rank * 5];
            Complex tmp6 = input[rank * 6];
            Complex tmp7 = input[rank * 7];
            // Stage 1: butterflies on neighbouring pairs.
            fft_2point(tmp0, tmp1);
            fft_2point(tmp2, tmp3);
            fft_2point(tmp4, tmp5);
            fft_2point(tmp6, tmp7);
            // Multiply by -i: (re, im) -> (im, -re).
            tmp3 = Complex(tmp3.imag(), -tmp3.real());
            tmp7 = Complex(tmp7.imag(), -tmp7.real());

            // Stage 2: butterflies at distance 2 inside each half.
            fft_2point(tmp0, tmp2);
            fft_2point(tmp1, tmp3);
            fft_2point(tmp4, tmp6);
            fft_2point(tmp5, tmp7);
            // Twiddles for the last stage: tmp5 *= (1 - i)/sqrt(2), tmp6 *= -i,
            // tmp7 *= -(1 + i)/sqrt(2).
            static constexpr double cos_1_8 = 0.70710678118654752440084436210485;
            tmp5 = cos_1_8 * Complex(tmp5.imag() + tmp5.real(), tmp5.imag() - tmp5.real());
            tmp6 = Complex(tmp6.imag(), -tmp6.real());
            tmp7 = -cos_1_8 * Complex(tmp7.real() - tmp7.imag(), tmp7.real() + tmp7.imag());

            // Stage 3: combine the two halves.
            input[0] = tmp0 + tmp4;
            input[rank] = tmp1 + tmp5;
            input[rank * 2] = tmp2 + tmp6;
            input[rank * 3] = tmp3 + tmp7;
            input[rank * 4] = tmp0 - tmp4;
            input[rank * 5] = tmp1 - tmp5;
            input[rank * 6] = tmp2 - tmp6;
            input[rank * 7] = tmp3 - tmp7;
        }

        // 4-point decimation-in-frequency FFT, in place, on input[0], input[rank],
        // input[2 * rank], input[3 * rank].
        inline void fft_dif_4point(Complex *input, size_t rank = 1)
        {
            Complex *const p1 = input + rank;
            Complex *const p2 = input + rank * 2;
            Complex *const p3 = input + rank * 3;
            Complex x0 = *input;
            Complex x1 = *p1;
            Complex x2 = *p2;
            Complex x3 = *p3;

            // First stage pairs elements two apart: (0,2) and (1,3).
            fft_2point(x0, x2);
            fft_2point(x1, x3);
            x3 = Complex(x3.imag(), -x3.real()); // multiply by -i

            *input = x0 + x1;
            *p1 = x0 - x1;
            *p2 = x2 + x3;
            *p3 = x2 - x3;
        }
        // 8-point decimation-in-frequency FFT, in place, on input[0], input[rank], ..., input[7 * rank].
        inline void fft_dif_8point(Complex *input, size_t rank = 1)
        {
            Complex tmp0 = input[0];
            Complex tmp1 = input[rank];
            Complex tmp2 = input[rank * 2];
            Complex tmp3 = input[rank * 3];
            Complex tmp4 = input[rank * 4];
            Complex tmp5 = input[rank * 5];
            Complex tmp6 = input[rank * 6];
            Complex tmp7 = input[rank * 7];

            // Stage 1: butterflies at distance 4.
            fft_2point(tmp0, tmp4);
            fft_2point(tmp1, tmp5);
            fft_2point(tmp2, tmp6);
            fft_2point(tmp3, tmp7);
            // Twiddles on the difference half: tmp5 *= (1 - i)/sqrt(2), tmp6 *= -i,
            // tmp7 *= -(1 + i)/sqrt(2).
            static constexpr double cos_1_8 = 0.70710678118654752440084436210485;
            tmp5 = cos_1_8 * Complex(tmp5.imag() + tmp5.real(), tmp5.imag() - tmp5.real());
            tmp6 = Complex(tmp6.imag(), -tmp6.real());
            tmp7 = -cos_1_8 * Complex(tmp7.real() - tmp7.imag(), tmp7.real() + tmp7.imag());

            // Stage 2: butterflies at distance 2 inside each half.
            fft_2point(tmp0, tmp2);
            fft_2point(tmp1, tmp3);
            fft_2point(tmp4, tmp6);
            fft_2point(tmp5, tmp7);
            // Multiply by -i: (re, im) -> (im, -re).
            tmp3 = Complex(tmp3.imag(), -tmp3.real());
            tmp7 = Complex(tmp7.imag(), -tmp7.real());

            // Stage 3: butterflies on neighbouring pairs, written straight to the output.
            input[0] = tmp0 + tmp1;
            input[rank] = tmp0 - tmp1;
            input[rank * 2] = tmp2 + tmp3;
            input[rank * 3] = tmp2 - tmp3;
            input[rank * 4] = tmp4 + tmp5;
            input[rank * 5] = tmp4 - tmp5;
            input[rank * 6] = tmp6 + tmp7;
            input[rank * 7] = tmp6 - tmp7;
        }

        // Radix-2 decimation-in-time butterfly: the lower element is multiplied by
        // omega before the sum and difference are formed.
        inline void fft_radix2_dit_butterfly(Complex omega, Complex *input, size_t rank)
        {
            Complex *const lower = input + rank;
            const Complex top = *input;
            const Complex twisted = *lower * omega;
            *input = top + twisted;
            *lower = top - twisted;
        }
        // Radix-2 decimation-in-frequency butterfly: the difference is multiplied by
        // omega after the sum and difference are formed.
        inline void fft_radix2_dif_butterfly(Complex omega, Complex *input, size_t rank)
        {
            Complex *const lower = input + rank;
            const Complex top = *input;
            const Complex bottom = *lower;
            *input = top + bottom;
            *lower = (top - bottom) * omega;
        }
        // Radix-2 decimation-in-frequency butterfly on two adjacent complex values at
        // once: pairs are loaded from input and input + rank.
        inline void fft_radix2_dif_butterfly(Complex2 omega, Complex *input, size_t rank)
        {
            Complex *const lower = input + rank;
            // Both pairs are loaded before anything is stored.
            const Complex2 top(input);
            const Complex2 bottom(lower);
            const Complex2 total = top + bottom;
            total.store(input);
            const Complex2 twisted = (top - bottom) * omega;
            twisted.store(lower);
        }

        // Split-radix decimation-in-time butterfly on input[0], input[rank],
        // input[2 * rank], input[3 * rank]: the last two are multiplied by omega and
        // omega_cube before being combined.
        inline void fft_split_radix_dit_butterfly(Complex omega, Complex omega_cube,
                                                  Complex *input, size_t rank)
        {
            Complex *const p1 = input + rank;
            Complex *const p2 = input + rank * 2;
            Complex *const p3 = input + rank * 3;
            const Complex x0 = *input;
            const Complex x1 = *p1;
            Complex x2 = *p2 * omega;
            Complex x3 = *p3 * omega_cube;

            fft_2point(x2, x3);
            x3 = Complex(x3.imag(), -x3.real()); // multiply by -i

            *input = x0 + x2;
            *p1 = x1 + x3;
            *p2 = x0 - x2;
            *p3 = x1 - x3;
        }
        // Split-radix decimation-in-frequency butterfly on input[0], input[rank],
        // input[2 * rank], input[3 * rank]: the last two outputs are multiplied by
        // omega and omega_cube after being combined.
        inline void fft_split_radix_dif_butterfly(Complex omega, Complex omega_cube,
                                                  Complex *input, size_t rank)
        {
            Complex *const p1 = input + rank;
            Complex *const p2 = input + rank * 2;
            Complex *const p3 = input + rank * 3;
            Complex x0 = *input;
            Complex x1 = *p1;
            Complex x2 = *p2;
            Complex x3 = *p3;

            fft_2point(x0, x2);
            fft_2point(x1, x3);
            x3 = Complex(x3.imag(), -x3.real()); // multiply by -i

            *input = x0;
            *p1 = x1;
            *p2 = (x2 + x3) * omega;
            *p3 = (x2 - x3) * omega_cube;
        }
        // Split-radix decimation-in-time butterfly processing two adjacent complex
        // values at once; pairs are loaded from input + k * rank for k = 0..3.
        inline void fft_split_radix_dit_butterfly(Complex2 omega, Complex2 omega_cube,
                                                  Complex *input, size_t rank)
        {
            Complex *const p1 = input + rank;
            Complex *const p2 = input + rank * 2;
            Complex *const p3 = input + rank * 3;
            // All four pairs are loaded before anything is stored.
            const Complex2 x0(input);
            const Complex2 x1(p1);
            Complex2 x2 = Complex2(p2) * omega;
            Complex2 x3 = Complex2(p3) * omega_cube;

            fft_2point(x2, x3);
            x3 = x3.mul_neg_i();

            (x0 + x2).store(input);
            (x1 + x3).store(p1);
            (x0 - x2).store(p2);
            (x1 - x3).store(p3);
        }
        // Split-radix decimation-in-frequency butterfly processing two adjacent complex
        // values at once; pairs are loaded from input + k * rank for k = 0..3.
        inline void fft_split_radix_dif_butterfly(Complex2 omega, Complex2 omega_cube,
                                                  Complex *input, size_t rank)
        {
            Complex *const p1 = input + rank;
            Complex *const p2 = input + rank * 2;
            Complex *const p3 = input + rank * 3;
            // All four pairs are loaded before anything is stored.
            Complex2 x0(input);
            Complex2 x1(p1);
            Complex2 x2(p2);
            Complex2 x3(p3);

            fft_2point(x0, x2);
            fft_2point(x1, x3);
            x3 = x3.mul_neg_i();

            x0.store(input);
            x1.store(p1);
            ((x2 + x3) * omega).store(p2);
            ((x2 - x3) * omega_cube).store(p3);
        }
        // Radix-4 decimation-in-time butterfly on input[0], input[rank],
        // input[2 * rank], input[3 * rank]: inputs 1..3 are multiplied by omega,
        // omega_sqr and omega_cube before being combined.
        inline void fft_radix4_dit_butterfly(Complex omega, Complex omega_sqr, Complex omega_cube,
                                             Complex *input, size_t rank)
        {
            Complex *const p1 = input + rank;
            Complex *const p2 = input + rank * 2;
            Complex *const p3 = input + rank * 3;
            Complex x0 = *input;
            Complex x1 = *p1 * omega;
            Complex x2 = *p2 * omega_sqr;
            Complex x3 = *p3 * omega_cube;

            fft_2point(x0, x2);
            fft_2point(x1, x3);
            x3 = Complex(x3.imag(), -x3.real()); // multiply by -i

            *input = x0 + x1;
            *p1 = x2 + x3;
            *p2 = x0 - x1;
            *p3 = x2 - x3;
        }
        // Radix-4 decimation-in-frequency butterfly on the scalar elements
        // input[0], input[rank], input[2 * rank] and input[3 * rank].
        inline void fft_radix4_dif_butterfly(Complex omega, Complex omega_sqr, Complex omega_cube,
                                             Complex *input, size_t rank)
        {
            Complex *const p1 = input + rank;
            Complex *const p2 = input + rank * 2;
            Complex *const p3 = input + rank * 3;

            Complex a = *input;
            Complex b = *p1;
            Complex c = *p2;
            Complex d = *p3;

            fft_2point(a, c);
            fft_2point(b, d);
            // Multiply by -i: (x + iy) * (-i) = y - ix.
            d = Complex(d.imag(), -d.real());

            // The first output is stored untwiddled, the other three are twiddled on store.
            *input = a + b;
            *p1 = (c + d) * omega;
            *p2 = (a - b) * omega_sqr;
            *p3 = (c - d) * omega_cube;
        }
        // Replaces every element by its complex conjugate divided by `div`
        // (conjugation plus normalisation, used for the inverse transform).
        inline void fft_conj(Complex *input, size_t fft_len, double div = 1)
        {
            Complex *const last = input + fft_len;
            for (Complex *it = input; it != last; ++it)
            {
                *it = std::conj(*it) / div;
            }
        }
        // Divides every element by the transform length (normalisation for the inverse transform).
        inline void fft_normalize(Complex *input, size_t fft_len)
        {
            const double scale = static_cast<double>(fft_len);
            Complex *const last = input + fft_len;
            for (Complex *it = input; it != last; ++it)
            {
                *it /= scale;
            }
        }
        // Templated split-radix decimation-in-time FFT of LEN points, recursing at compile time.
        // Runs one half-length and two quarter-length sub-transforms first, then merges them
        // with split-radix butterflies that each handle two adjacent columns.
        template <size_t LEN>
        void fft_split_radix_dit_template(Complex *input)
        {
            constexpr size_t half = LEN / 2;
            constexpr size_t quarter = LEN / 4;
            constexpr size_t log_len = hint_log2(LEN);

            fft_split_radix_dit_template<half>(input);
            fft_split_radix_dit_template<quarter>(input + half);
            fft_split_radix_dit_template<quarter>(input + half + quarter);

            // `pair` indexes the two-wide twiddle entries; it maps to column 2 * pair.
            for (size_t pair = 0; pair * 2 < quarter; pair++)
            {
                auto omega = TABLE.get_omegaX2(log_len, pair);
                auto omega_cube = TABLE.get_omega3X2(log_len, pair);
                fft_split_radix_dit_butterfly(omega, omega_cube, input + pair * 2, quarter);
            }
        }
        // Recursion anchors: lengths 0 and 1 need no work.
        template <>
        void fft_split_radix_dit_template<0>(Complex *) {}
        template <>
        void fft_split_radix_dit_template<1>(Complex *) {}
        // Length 2 is a single 2-point butterfly.
        template <>
        void fft_split_radix_dit_template<2>(Complex *input)
        {
            fft_2point(input[0], input[1]);
        }
        // Lengths 4 and 8 use the dedicated small kernels.
        template <>
        void fft_split_radix_dit_template<4>(Complex *input)
        {
            fft_dit_4point(input);
        }
        template <>
        void fft_split_radix_dit_template<8>(Complex *input)
        {
            fft_dit_8point(input);
        }

        // Templated split-radix decimation-in-frequency FFT of LEN points, recursing at compile time.
        // Applies the split-radix butterflies (two adjacent columns per call) first, then runs
        // one half-length and two quarter-length sub-transforms.
        template <size_t LEN>
        void fft_split_radix_dif_template(Complex *input)
        {
            constexpr size_t half = LEN / 2;
            constexpr size_t quarter = LEN / 4;
            constexpr size_t log_len = hint_log2(LEN);

            // `pair` indexes the two-wide twiddle entries; it maps to column 2 * pair.
            for (size_t pair = 0; pair * 2 < quarter; pair++)
            {
                Complex2 omega = TABLE.get_omegaX2(log_len, pair);
                Complex2 omega_cube = TABLE.get_omega3X2(log_len, pair);
                fft_split_radix_dif_butterfly(omega, omega_cube, input + pair * 2, quarter);
            }

            fft_split_radix_dif_template<half>(input);
            fft_split_radix_dif_template<quarter>(input + half);
            fft_split_radix_dif_template<quarter>(input + half + quarter);
        }
        // Recursion anchors: lengths 0 and 1 need no work.
        template <>
        void fft_split_radix_dif_template<0>(Complex *) {}
        template <>
        void fft_split_radix_dif_template<1>(Complex *) {}
        // Length 2 is a single 2-point butterfly.
        template <>
        void fft_split_radix_dif_template<2>(Complex *input)
        {
            fft_2point(input[0], input[1]);
        }
        // Lengths 4 and 8 use the dedicated small kernels.
        template <>
        void fft_split_radix_dif_template<4>(Complex *input)
        {
            fft_dif_4point(input);
        }
        template <>
        void fft_split_radix_dif_template<8>(Complex *input)
        {
            fft_dif_8point(input);
        }

        // Runtime-to-compile-time dispatcher for the decimation-in-time transform:
        // doubles LEN until it is at least fft_len, grows the twiddle table to that
        // size and runs the matching templated split-radix kernel.
        template <size_t LEN = 1>
        void fft_dit_template(Complex *input, size_t fft_len)
        {
            if (fft_len <= LEN)
            {
                TABLE.expand(hint_log2(LEN));
                fft_split_radix_dit_template<LEN>(input);
            }
            else
            {
                fft_dit_template<LEN * 2>(input, fft_len);
            }
        }
        // Ends the template recursion. NOTE: any fft_len above 2^23 lands here,
        // so such inputs are left untransformed without any diagnostic.
        template <>
        void fft_dit_template<1 << 24>(Complex *, size_t) {}

        // Runtime-to-compile-time dispatcher for the decimation-in-frequency transform:
        // doubles LEN until it is at least fft_len, grows the twiddle table to that
        // size and runs the matching templated split-radix kernel.
        template <size_t LEN = 1>
        void fft_dif_template(Complex *input, size_t fft_len)
        {
            if (fft_len <= LEN)
            {
                TABLE.expand(hint_log2(LEN));
                fft_split_radix_dif_template<LEN>(input);
            }
            else
            {
                fft_dif_template<LEN * 2>(input, fft_len);
            }
        }
        // Ends the template recursion. NOTE: any fft_len above 2^23 lands here,
        // so such inputs are left untransformed without any diagnostic.
        template <>
        void fft_dif_template<1 << 24>(Complex *, size_t) {}

        /// @brief In-place decimation-in-time FFT entry point.
        /// @param input complex array, transformed in place
        /// @param fft_len array length; it is passed through max_2pow before use
        /// @param bit_rev when true, binary_reverse_swap is applied to the data before the transform
        inline void fft_dit(Complex *input, size_t fft_len, bool bit_rev = true)
        {
            const size_t len = max_2pow(fft_len);
            if (bit_rev)
            {
                binary_reverse_swap(input, len);
            }
            fft_dit_template<1>(input, len);
        }

        /// @brief In-place decimation-in-frequency FFT entry point.
        /// @param input complex array, transformed in place
        /// @param fft_len array length; it is passed through max_2pow before use
        /// @param bit_rev when true, binary_reverse_swap is applied to the data after the transform
        inline void fft_dif(Complex *input, size_t fft_len, bool bit_rev = true)
        {
            const size_t len = max_2pow(fft_len);
            fft_dif_template<1>(input, len);
            if (!bit_rev)
            {
                return;
            }
            binary_reverse_swap(input, len);
        }

    }
    // Returns {a / b, a % b}. Keeping both operations side by side lets the
    // compiler produce quotient and remainder from a single division.
    template <typename T>
    constexpr std::pair<T, T> div_mod(T a, T b)
    {
        return std::pair<T, T>(a / b, a % b);
    }
    // Generic exponentiation by squaring: returns m raised to the power n,
    // starting from T(1) and using only T's operator*.
    template <typename T>
    constexpr T qpow(T m, UINT_64 n)
    {
        T result = 1;
        for (; n != 0; n >>= 1)
        {
            // Fold in the current square whenever the lowest remaining exponent bit is set.
            if (n & 1)
            {
                result = result * m;
            }
            m = m * m;
        }
        return result;
    }
    // Formats the lowest `digits` decimal digits of `input` as a zero-padded
    // string of exactly `digits` characters (higher digits are dropped).
    std::string ui64to_string(UINT_64 input, UINT_8 digits)
    {
        std::string result(digits, '0');
        // Fill from the least significant (rightmost) position towards the front.
        for (auto it = result.rbegin(); it != result.rend(); ++it)
        {
            *it = static_cast<char>(input % 10 + '0');
            input /= 10;
        }
        return result;
    }
    // Parses the first `dig` characters of `s` as a decimal number.
    // No validation is done: every character is treated as a digit.
    constexpr UINT_64 stoui64(char *s, size_t dig = 4)
    {
        UINT_64 value = 0;
        for (char *p = s, *stop = s + dig; p != stop; ++p)
        {
            value = value * 10 + (*p - '0');
        }
        return value;
    }
    // Number of decimal digits grouped into one limb by the string/number
    // conversion helpers that follow.
    static constexpr INT_64 DIGIT = 4;
    size_t char_to_real(char *buffer, Complex *comary, size_t str_len)
    {
        hint::INT_64 len = str_len, pos = len, i = 0;
        len = (len + DIGIT - 1) / DIGIT;
        while (pos - DIGIT > 0)
        {
            hint::UINT_64 tmp = stoui64(buffer + pos - DIGIT, DIGIT);
            comary[i].real(tmp);
            i++;
            pos -= DIGIT;
        }
        if (pos > 0)
        {
            hint::UINT_64 tmp = stoui64(buffer, pos);
            comary[i].real(tmp);
        }
        return len;
    }
    size_t char_to_imag(char *buffer, Complex *comary, size_t str_len)
    {
        hint::INT_64 len = str_len, pos = len, i = 0;
        len = (len + DIGIT - 1) / DIGIT;
        while (pos - DIGIT > 0)
        {
            hint::UINT_64 tmp = stoui64(buffer + pos - DIGIT, DIGIT);
            comary[i].imag(tmp);
            i++;
            pos -= DIGIT;
        }
        if (pos > 0)
        {
            hint::UINT_64 tmp = stoui64(buffer, pos);
            comary[i].imag(tmp);
        }
        return len;
    }
    void num_to_s(char *s, UINT_64 num)
    {
        char c = '0';
        int i = DIGIT;
        while (i > 0)
        {
            i--;
            std::tie(num, c) = div_mod<UINT_64>(num, 10);
            s[i] = c + '0';
        }
    }
}
using namespace std;
using namespace hint;
using namespace hint_transform;

void poly_multiply(unsigned *a, int n, unsigned *b, int m, unsigned *c)
{
    size_t len1 = n + 1, len2 = m + 1, out_len = len1 + len2 - 1;
    size_t fft_len = min_2pow(out_len);
    Complex *fft_ary = new Complex[fft_len];
    com_ary_combine_copy(fft_ary, a, len1, b, len2);
    fft_dif(fft_ary, fft_len, false); // 优化FFT
    for (size_t i = 0; i < fft_len; i++)
    {
        Complex tmp = fft_ary[i];
        fft_ary[i] = std::conj(tmp * tmp);
    }
    fft_dit(fft_ary, fft_len, false); // 优化FFT
    HintFloat inv = -0.5 / fft_len;
    for (size_t i = 0; i < out_len; i++)
    {
        c[i] = unsigned(fft_ary[i].imag() * inv + 0.5);
    }
}

// int main()
// {
//     int len1 = 4, len2 = 4;
//     unsigned a[100] = {5, 5, 5, 5, 5};
//     unsigned b[100] = {5, 5, 5, 5, 5};
//     unsigned c[100] = {};
//     poly_multiply(a, len1, b, len2, c);
//     for (int i = 0; i < len1 + len2 + 1; i++)
//     {
//         cout << c[i] << "\t";
//     }
// }
// int main()
// {
//     size_t m = 4, n = 4;
//     std::cin >> m >> n;
//     size_t len1 = m + 1, len2 = n + 1;
//     size_t fft_len = min_2pow(len1 + len2 - 1);
//     Complex *a = new Complex[fft_len];
//     for (size_t i = 0; i < len1; i++)
//     {
//         int c = 5;
//         scanf("%d", &c);
//         a[i].real(c);
//     }
//     for (size_t i = 0; i < len2; i++)
//     {
//         int c = 5;
//         scanf("%d", &c);
//         a[i].imag(c);
//     }
//     fft_dif(a, fft_len, false);
//     HintFloat inv = -1.0 / (fft_len * 2);
//     for (size_t i = 0; i < fft_len; i++)
//     {
//         Complex tmp = a[i];
//         a[i] = std::conj(tmp * tmp * inv);
//     }
//     fft_dit(a, fft_len, false);
//     for (size_t i = 0; i < len1 + len2 - 1; i++)
//     {
//         printf("%d ", int(a[i].imag() + 0.5));
//     }
// }

// Numeric base of one limb: 10^DIGIT.
constexpr uint64_t BASE = hint::qpow(10ull, DIGIT);
// Size in bytes of the shared text buffer below.
constexpr size_t STR_LEN = 2000005;
// Text buffer used by main() both for reading the operands and for assembling the result.
char out[STR_LEN];
// Backing storage for the FFT; main() views it as an array of Complex.
Complex2 avx_ary[1 << 18];

// Reads the next whitespace-delimited token into the global `out` buffer and
// returns the length of its leading run of decimal digits (0 when nothing
// could be read or the token does not start with a digit).
static size_t read_number_token()
{
    // The field width must stay one below STR_LEN so scanf cannot overrun `out`.
    static_assert(STR_LEN == 2000005, "update the scanf field width together with STR_LEN");
    if (scanf("%2000004s", out) != 1)
    {
        return 0;
    }
    size_t digits = 0;
    // Cast first: passing a negative char value to isdigit is undefined behaviour.
    while (isdigit(static_cast<unsigned char>(out[digits])))
    {
        digits++;
    }
    return digits;
}

// Reads two non-negative decimal integers from stdin and prints their product.
// The first operand goes into the real parts and the second into the imaginary
// parts of one complex array, so a single forward/backward FFT pair suffices.
// Returns 1 (with a message on stderr) when an operand is missing or the
// operands do not fit into the fixed-size buffers.
int main()
{
    Complex *fft_ary = reinterpret_cast<Complex *>(avx_ary);
    constexpr size_t fft_capacity = sizeof(avx_ary) / sizeof(Complex);
    constexpr size_t limb_digits = static_cast<size_t>(DIGIT);
    // Any single token accepted by read_number_token() must fit into avx_ary,
    // because operands are converted before the combined size is known.
    static_assert((STR_LEN - 1 + limb_digits - 1) / limb_digits <= fft_capacity,
                  "one operand of maximal length must fit into avx_ary");

    const size_t len_a = read_number_token();
    if (len_a == 0)
    {
        fprintf(stderr, "expected a decimal number\n");
        return 1;
    }
    if (len_a == 1 && out[0] == '0')
    {
        puts("0");
        return 0;
    }
    const size_t len1 = char_to_real(out, fft_ary, len_a);

    // `out` is reused for the second operand; the first one already lives in fft_ary.
    const size_t len_b = read_number_token();
    if (len_b == 0)
    {
        fprintf(stderr, "expected a decimal number\n");
        return 1;
    }
    if (len_b == 1 && out[0] == '0')
    {
        puts("0");
        return 0;
    }
    const size_t len2 = char_to_imag(out, fft_ary, len_b);

    // len1 and len2 are both at least 1 here, so this cannot underflow.
    const size_t limbs = len1 + len2 - 1;
    const size_t fft_len = min_2pow(limbs);
    // The transform must fit into avx_ary and the (limbs + 1) output limbs into `out`.
    if (fft_len > fft_capacity || (limbs + 1) * limb_digits > STR_LEN - 1)
    {
        fprintf(stderr, "operands too long\n");
        return 1;
    }

    fft_dif(fft_ary, fft_len, false); // bit reversal skipped: the pointwise step does not depend on element order
    for (size_t i = 0; i < fft_len; i++)
    {
        Complex tmp = fft_ary[i];
        // Square, then conjugate so the second transform below serves as the inverse.
        fft_ary[i] = std::conj(tmp * tmp);
    }
    fft_dit(fft_ary, fft_len, false);
    // -0.5 / fft_len normalises the transform, halves the doubled cross term
    // and undoes the sign flip introduced by the conjugation.
    HintFloat inv = -0.5 / fft_len;

    // Emit the limbs from least to most significant, filling `out` back to front.
    UINT_64 carry = 0;
    size_t pos = STR_LEN - 1;
    out[pos] = '\0';
    for (size_t i = 0; i < limbs; i++)
    {
        carry += UINT_64(fft_ary[i].imag() * inv + 0.5);
        UINT_64 num = 0;
        std::tie(carry, num) = div_mod<UINT_64>(carry, BASE);
        pos -= limb_digits;
        num_to_s(out + pos, num);
    }
    pos -= limb_digits;
    num_to_s(out + pos, carry);

    // Strip leading zeros but always keep the last digit, so a zero product
    // (e.g. from an operand written as "000") still prints "0".
    while (out[pos] == '0' && out[pos + 1] != '\0')
    {
        pos++;
    }
    puts(out + pos);
    return 0;
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 34.2 us | 32 KB | Runtime Error | Score: 0


Judge Duck Online | 评测鸭在线
Server Time: 2025-09-15 12:05:02 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠