提交记录 20830


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1004. 【模板题】高精度乘法 Accepted 100 23.423 ms 16248 KB C++14 21.15 KB
提交时间 评测时间
2024-01-21 18:49:50 2024-01-21 18:49:52
#include <cctype>
#include <complex>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <limits>
#include <string>
#include <tuple>
#include <immintrin.h>
#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP

#pragma GCC target("fma")
#pragma GCC target("avx")

namespace hint_simd
{
    /// @brief Fixed-size array of LEN elements of T whose storage starts on a
    ///        4096-byte boundary (suitable for any SIMD aligned load/store).
    /// @details The constructor does not initialize the elements; objects with
    ///          static storage duration are still zero-initialized as usual.
    ///          No bounds checking is performed.
    template <typename T, size_t LEN>
    class AlignAry
    {
    private:
        alignas(4096) T ary[LEN];

    public:
        constexpr AlignAry() {}

        /// Unchecked element access.
        constexpr T &operator[](size_t index) { return ary[index]; }
        constexpr const T &operator[](size_t index) const { return ary[index]; }

        /// Pointer to the first element.
        T *data() { return ary; }
        const T *data() const { return ary; }

        /// The same storage viewed as a pointer to Ty (caller is responsible
        /// for the reinterpretation being meaningful).
        template <typename Ty>
        Ty *cast_ptr() { return reinterpret_cast<Ty *>(data()); }
        template <typename Ty>
        const Ty *cast_ptr() const { return reinterpret_cast<const Ty *>(data()); }
    };

    // Use AVX: 256-bit SIMD registers, i.e. four doubles per register.
    using HintFloat = double;
    using Complex = std::complex<HintFloat>;

    /// @brief Two complex doubles processed in parallel in one __m256d.
    /// @details Lane layout is [re0, im0, re1, im1], the same as two consecutive
    ///          std::complex<double> in memory. The pointer constructors and
    ///          store() use aligned 256-bit loads/stores, so those pointers must
    ///          be 32-byte aligned.
    struct Complex2
    {
        __m256d data;
        Complex2()
        {
            data = _mm256_setzero_pd();
        }
        /// Broadcasts input to all four lanes.
        Complex2(double input)
        {
            data = _mm256_set1_pd(input);
        }
        Complex2(__m256d input)
        {
            data = input;
        }
        Complex2(const Complex2 &) = default;
        // Construct from a contiguous, 32-byte aligned array.
        Complex2(double const *ptr)
        {
            data = _mm256_load_pd(ptr);
        }
        Complex2(const Complex *ptr)
        {
            data = _mm256_load_pd(reinterpret_cast<const double *>(ptr));
        }
        /// Sets all lanes to zero.
        void clr()
        {
            data = _mm256_setzero_pd();
        }
        /// Stores both complex values to a 32-byte aligned array.
        void store(Complex *a) const
        {
            _mm256_store_pd(reinterpret_cast<double *>(a), data);
        }
        /// Packs a into lanes 0-1 and b into lanes 2-3.
        Complex2(const Complex &a, const Complex &b)
        {
            // std::complex<double> only guarantees 8-byte alignment, so use
            // unaligned 128-bit loads; dereferencing a __m128d* assumes 16 bytes.
            const __m128d lo = _mm_loadu_pd(reinterpret_cast<const double *>(&a));
            const __m128d hi = _mm_loadu_pd(reinterpret_cast<const double *>(&b));
            data = _mm256_set_m128d(hi, lo);
        }
        /// Debug helper: prints both complex values on one line.
        void print() const
        {
            double ary[4];
            _mm256_storeu_pd(ary, data);
            printf("(%lf,%lf) (%lf,%lf)\n", ary[0], ary[1], ary[2], ary[3]);
        }
        /// Negates lane k (k = 0..3) for every bit k set in M.
        template <int M>
        Complex2 element_mask_neg() const
        {
            // Built from compile-time constants; no function-local static (and
            // hence no initialization guard) is needed.
            const __m256d neg_mask = _mm256_castsi256_pd(
                _mm256_set_epi64x((M & 8ull) << 60, (M & 4ull) << 61, (M & 2ull) << 62, (M & 1ull) << 63));
            return _mm256_xor_pd(data, neg_mask);
        }
        /// In-lane permute (_mm256_permute_pd) with control M.
        template <int M>
        Complex2 element_permute() const
        {
            return _mm256_permute_pd(data, M);
        }
        /// Cross-lane permute of the four doubles with control M.
        /// NOTE: _mm256_permute4x64_pd is an AVX2 instruction; instantiating this
        /// requires AVX2 to be enabled (this file only enables avx and fma).
        template <int M>
        Complex2 element_permute64() const
        {
            return _mm256_permute4x64_pd(data, M);
        }
        /// [re0, re0, re1, re1]
        Complex2 all_real() const
        {
            return _mm256_unpacklo_pd(data, data);
        }
        /// [im0, im0, im1, im1]
        Complex2 all_imag() const
        {
            return _mm256_unpackhi_pd(data, data);
        }
        /// Swaps real and imaginary part of each complex value.
        Complex2 swap() const
        {
            return _mm256_shuffle_pd(data, data, 5);
        }
        /// Multiplies both values by -i: (re, im) -> (im, -re).
        Complex2 mul_neg_i() const
        {
            return Complex2(_mm256_addsub_pd(_mm256_setzero_pd(), data)).swap();
        }
        /// Complex conjugate of both values (negates lanes 1 and 3).
        Complex2 conj() const
        {
            return element_mask_neg<10>();
        }
        /// Lane-wise (not complex) multiplication.
        Complex2 linear_mul(Complex2 input) const
        {
            return _mm256_mul_pd(data, input.data);
        }
        /// Complex square of both values: (re^2 - im^2, 2*re*im).
        Complex2 square() const
        {
            const __m256d rr = all_real().data;
            const __m256d ir = swap().data;
            const __m256d add = _mm256_add_pd(rr, ir);
            const __m256d sub = _mm256_sub_pd(rr, ir);
            return _mm256_mul_pd(add, _mm256_blend_pd(sub, data, 0b1010));
        }
        Complex2 operator+(const Complex2 &input) const
        {
            return _mm256_add_pd(data, input.data);
        }
        Complex2 operator-(const Complex2 &input) const
        {
            return _mm256_sub_pd(data, input.data);
        }
        /// Complex multiplication of corresponding values.
        Complex2 operator*(const Complex2 &input) const
        {
            auto imag = _mm256_mul_pd(all_imag().data, input.swap().data);
            return _mm256_fmaddsub_pd(all_real().data, input.data, imag);
        }
    };
}
#endif

#include <array>
#include <vector>
#include <iostream>
#include <future>
#include <ctime>
#include <cstring>
// #include "stopwatch.hpp"

// Build-time switches. Of these, only MULTITHREAD is tested in the code shown
// (#if MULTITHREAD == 1); TABLE_ENABLE and TABLE_PRELOAD are not referenced here.
#define TABLE_ENABLE 1
#define MULTITHREAD 0   // multithreading: 0 means no, 1 means yes
#define TABLE_PRELOAD 0 // pre-initialize the tables: 0 means no, 1 means yes
// Makes AlignAry / Complex2 visible without qualification below.
using namespace hint_simd;
namespace hint
{
    // Fixed-width integer and floating-point aliases used throughout namespace hint.
    using UINT_8 = uint8_t;
    using UINT_16 = uint16_t;
    using UINT_32 = uint32_t;
    using UINT_64 = uint64_t;
    using INT_32 = int32_t;
    using INT_64 = int64_t;
    using ULONG = unsigned long;
    using LONG = long;
    using HintFloat = double;
    using Complex = std::complex<HintFloat>;

    // Bit widths of char, short and int as assumed by this library.
    constexpr UINT_64 HINT_CHAR_BIT = 8;
    constexpr UINT_64 HINT_SHORT_BIT = 16;
    constexpr UINT_64 HINT_INT_BIT = 32;
    // Width-specific masks. Suffix convention: 0XFF = all bits set,
    // 0X10 = 2^width (one past the unsigned maximum), 0X80 = only the top bit,
    // 0X7F = signed maximum; HINT_INT32_0X01 is simply 1.
    constexpr UINT_64 HINT_INT8_0XFF = UINT8_MAX;
    constexpr UINT_64 HINT_INT8_0X10 = (UINT8_MAX + 1ull);
    constexpr UINT_64 HINT_INT16_0XFF = UINT16_MAX;
    constexpr UINT_64 HINT_INT16_0X10 = (UINT16_MAX + 1ull);
    constexpr UINT_64 HINT_INT32_0XFF = UINT32_MAX;
    constexpr UINT_64 HINT_INT32_0X01 = 1;
    constexpr UINT_64 HINT_INT32_0X80 = 0X80000000ull;
    constexpr UINT_64 HINT_INT32_0X7F = INT32_MAX;
    constexpr UINT_64 HINT_INT32_0X10 = (UINT32_MAX + 1ull);
    constexpr UINT_64 HINT_INT64_0X80 = INT64_MIN;
    constexpr UINT_64 HINT_INT64_0X7F = INT64_MAX;
    constexpr UINT_64 HINT_INT64_0XFF = UINT64_MAX;

    // pi, 2*pi and sqrt(2)/2.
    constexpr HintFloat HINT_PI = 3.1415926535897932384626433832795;
    constexpr HintFloat HINT_2PI = HINT_PI * 2;
    constexpr HintFloat HINT_HSQ_ROOT2 = 0.70710678118654752440084436210485;

    // NTT parameters (not referenced in the code shown): NTT_MOD1 = 3 * 2^30 + 1,
    // NTT_MOD2 = 13 * 2^28 + 1. NTT_ROOT1 / NTT_ROOT2 are presumably primitive
    // roots of the respective moduli — TODO confirm. NTT_MAX_LEN = 2^28 equals
    // the largest power of two dividing NTT_MOD2 - 1.
    constexpr UINT_64 NTT_MOD1 = 3221225473;
    constexpr UINT_64 NTT_ROOT1 = 5;
    constexpr UINT_64 NTT_MOD2 = 3489660929;
    constexpr UINT_64 NTT_ROOT2 = 3;
    constexpr size_t NTT_MAX_LEN = 1ull << 28;

    /// @brief Largest power of two that is not greater than n.
    /// @param n Input value.
    /// @return The largest power of two <= n, or 0 when n < 1.
    /// @note For signed T the search starts at the highest *positive* power of
    ///       two. Starting at the sign bit would produce a negative start value
    ///       that is never > n, so it would be returned for every n.
    template <typename T>
    constexpr T max_2pow(T n)
    {
        // numeric_limits<T>::digits excludes the sign bit for signed types.
        T res = static_cast<T>(T(1) << (std::numeric_limits<T>::digits - 1));
        // res != 0 stops the search for n < 1 instead of looping on 0 forever.
        while (res != 0 && res > n)
        {
            res /= 2;
        }
        return res;
    }
    /// @brief Highest power of two not exceeding n, computed by bit smearing.
    /// @details Copies the highest set bit of n into every lower bit, then keeps
    ///          only that top bit via (n >> 1) + 1. Note that n == 0 yields 1.
    template <typename T>
    constexpr T int_floor2(T n)
    {
        constexpr int total_bits = sizeof(T) * 8;
        int shift = 1;
        while (shift < total_bits)
        {
            n |= (n >> shift);
            shift <<= 1;
        }
        return (n >> 1) + 1;
    }
    /// @brief Smallest power of two not less than n (intended for n >= 1).
    /// @details Decrements n, smears its highest set bit into all lower bits,
    ///          then adds one. For unsigned n == 0 the decrement wraps, so the
    ///          result wraps around to 0.
    template <typename T>
    constexpr T int_ceil2(T n)
    {
        constexpr int total_bits = sizeof(T) * 8;
        --n;
        for (int shift = 1; shift < total_bits; shift *= 2)
        {
            n |= n >> shift;
        }
        return n + 1;
    }
    /// @brief Integer floor of log2(n); returns 0 for any n <= 1.
    template <typename T>
    constexpr T hint_log2(T n)
    {
        T exponent = 0;
        for (; n > 1; n /= 2)
        {
            ++exponent;
        }
        return exponent;
    }
    /// @brief Binary exponentiation: returns m raised to the n-th power.
    /// @details Needs only T's multiplication and construction from 1.
    ///          Overflow is not checked.
    template <typename T>
    constexpr T qpow(T m, UINT_64 n)
    {
        T acc = 1;
        for (; n != 0; n >>= 1)
        {
            if (n & 1)
            {
                acc = acc * m;
            }
            m = m * m;
        }
        return acc;
    }
    /// @brief Quotient and remainder of a / b, as (quotient, remainder).
    /// @pre b != 0.
    template <typename T>
    constexpr std::pair<T, T> div_mod(T a, T b)
    {
        const T quotient = a / b;
        const T remainder = a % b;
        return {quotient, remainder};
    }
#if MULTITHREAD == 1
    // Thread-count globals. They are compiled only when MULTITHREAD == 1, and
    // MULTITHREAD is defined as 0 in this file, so this block is currently excluded.
    // NOTE(review): hint_log2 already returns an integer floor(log2(n)), so the
    // std::ceil below has no effect and log2_threads is floor(log2(threads)).
    // Confirm the intent before enabling.
    // NOTE(review): this needs std::thread, std::atomic and std::ceil; <thread>,
    // <atomic> and <cmath> are not included directly (they may arrive
    // transitively via <future>) — verify.
    const UINT_32 hint_threads = std::thread::hardware_concurrency();
    const UINT_32 log2_threads = std::ceil(hint_log2(hint_threads));
    std::atomic<UINT_32> cur_ths;
#endif

    /// @brief Copies len elements from source to target (elements copied bytewise,
    ///        so T must be trivially copyable).
    /// @details Uses std::memmove, so partially overlapping ranges are copied
    ///          correctly (std::memcpy has undefined behavior on overlap).
    ///          Nothing happens when len is 0 or target == source.
    /// @throws const char* when len >= INT64_MAX.
    template <typename T>
    void ary_copy(T *target, const T *source, size_t len)
    {
        if (len == 0 || target == source)
        {
            return;
        }
        if (len >= INT64_MAX)
        {
            throw("Ary too long\n");
        }
        std::memmove(target, source, len * sizeof(T));
    }
    /// @brief Packs two real sequences into one complex array: target[i] gets
    ///        real part source1[i] (for i < len1) and imaginary part source2[i]
    ///        (for i < len2).
    /// @details Where both inputs have an element, target[i] is overwritten with
    ///          a new Complex. Past the shorter input only the available component
    ///          is written; the other component keeps its previous value.
    /// @param source1 / source2 Any indexable sequence whose elements convert to double.
    template <typename T>
    inline void com_ary_combine_copy(Complex *target, const T &source1, size_t len1, const T &source2, size_t len2)
    {
        const size_t common = len1 < len2 ? len1 : len2;
        for (size_t idx = 0; idx < common; ++idx)
        {
            target[idx] = Complex(source1[idx], source2[idx]);
        }
        for (size_t idx = common; idx < len1; ++idx)
        {
            target[idx].real(source1[idx]);
        }
        for (size_t idx = common; idx < len2; ++idx)
        {
            target[idx].imag(source2[idx]);
        }
    }

    // Namespace for FFT and FFT-like transforms
    namespace hint_transform
    {
        /// @brief Permutes ary[0..len) into bit-reversed index order, in place.
        /// @param ary Anything indexable with ary[i] whose elements can be std::swap'ed.
        /// @param len Number of elements; expected to be a power of two.
        /// @note len < 2 is a no-op. The guard also keeps `len - 1` from wrapping
        ///       around when len == 0, which would run the loop ~2^64 times.
        template <typename T>
        void binary_reverse_swap(T &ary, size_t len)
        {
            if (len < 2)
            {
                return;
            }
            size_t i = 0; // bit-reversed counterpart of j, updated incrementally
            for (size_t j = 1; j < len - 1; j++)
            {
                // "Reverse increment" i: flip bits starting from the top until a
                // bit flips from 0 to 1.
                size_t k = len >> 1;
                i ^= k;
                while (k > i)
                {
                    k >>= 1;
                    i ^= k;
                }
                if (j < i) // swap each pair only once
                {
                    std::swap(ary[i], ary[j]);
                }
            }
        }
        namespace hint_fft
        {
            /// @brief Radix-2 butterfly in place: (sum, diff) <- (sum + diff, sum - diff).
            template <typename T>
            inline void fft_2point(T &sum, T &diff)
            {
                const T a = sum, b = diff;
                sum = a + b;
                diff = a - b;
            }
            /// @brief Radix-2 complex FFT with a precomputed twiddle table, for
            ///        power-of-two lengths up to LEN.
            /// @details TABLE is in bit-reversed order: TABLE[2^k] = exp(-i*pi / 2^(k+1))
            ///          (taken from pre_compute), and every other TABLE[i] is the
            ///          product of the entries for i's set bits. Neither pass performs
            ///          a bit-reversal permutation or scales its output; dit_avx runs
            ///          dif_avx's stages in reverse order.
            /// @tparam LEN Largest supported length; a power of two in [2, 2^24].
            template <size_t LEN>
            struct FFT_RADIX2
            {
                static_assert(LEN >= 2 && (LEN & (LEN - 1)) == 0,
                              "FFT_RADIX2: LEN must be a power of two >= 2");
                // pre_compute below has 24 entries, enough for table_len <= 2^23.
                static_assert(LEN <= (size_t(1) << 24),
                              "FFT_RADIX2: LEN exceeds the precomputed twiddle range");
                enum
                {
                    max_fft_len = LEN,
                    table_len = max_fft_len / 2
                };
                /// Plain complex product, usable in constant expressions.
                static constexpr Complex mul(Complex a, Complex b)
                {
                    return Complex(a.real() * b.real() - a.imag() * b.imag(), a.imag() * b.real() + a.real() * b.imag());
                }
                std::array<Complex, table_len> TABLE;
                constexpr FFT_RADIX2()
                {
                    table_init();
                }
                /// Fills TABLE: power-of-two slots come from pre_compute, and every
                /// other slot is TABLE[lowest set bit of i] * TABLE[i without that bit].
                constexpr void table_init()
                {
                    constexpr Complex pre_compute[24] = {
                        Complex(1, 0),
                        Complex(0, -1),
                        Complex(0.7071067811865476, -0.7071067811865476),
                        Complex(0.9238795325112867, -0.3826834323650898),
                        Complex(0.9807852804032304, -0.19509032201612825),
                        Complex(0.9951847266721969, -0.0980171403295606),
                        Complex(0.9987954562051724, -0.049067674327418015),
                        Complex(0.9996988186962042, -0.024541228522912288),
                        Complex(0.9999247018391445, -0.012271538285719925),
                        Complex(0.9999811752826011, -0.006135884649154475),
                        Complex(0.9999952938095762, -0.003067956762965976),
                        Complex(0.9999988234517019, -0.0015339801862847655),
                        Complex(0.9999997058628822, -0.0007669903187427045),
                        Complex(0.9999999264657179, -0.00038349518757139556),
                        Complex(0.9999999816164293, -0.0001917475973107033),
                        Complex(0.9999999954041073, -9.587379909597734e-05),
                        Complex(0.9999999988510269, -4.793689960306688e-05),
                        Complex(0.9999999997127567, -2.396844980841822e-05),
                        Complex(0.9999999999281892, -1.1984224905069705e-05),
                        Complex(0.9999999999820472, -5.9921124526424275e-06),
                        Complex(0.9999999999955118, -2.996056226334661e-06),
                        Complex(0.999999999998878, -1.4980281131690111e-06),
                        Complex(0.9999999999997194, -7.490140565847157e-07),
                        Complex(0.9999999999999298, -3.7450702829238413e-07),
                    };
                    TABLE[0] = Complex(1.0, 0);
                    for (size_t i = 1, index = 1; i < table_len; i *= 2, index++)
                    {
                        TABLE[i] = pre_compute[index];
                    }
                    for (size_t i = 0; i < table_len; i++)
                    {
                        size_t j = i & (i - 1); // i with its lowest set bit cleared
                        TABLE[i] = mul(TABLE[i - j], TABLE[j]);
                    }
                }
                /// @brief Forward (decimation-in-frequency) pass, in place.
                /// @param input len complex values; must be 32-byte aligned because
                ///              the middle stages use Complex2 aligned loads/stores.
                /// @param len Power of two. If len > max_fft_len nothing is done.
                void dif_avx(Complex *input, size_t len) const
                {
                    if (len > max_fft_len)
                    {
                        return;
                    }
                    // The general path always runs a first stage of span len/2 AND a
                    // last stage of span 1. For len == 2 those are the same stage (so
                    // it would be applied twice), and for len == 1 there is no stage
                    // at all (the last loop would touch input[1]).
                    if (len <= 2)
                    {
                        if (len == 2)
                        {
                            fft_2point(input[0], input[1]);
                        }
                        return;
                    }
                    // First stage: span len/2, twiddle TABLE[0] == 1.
                    for (size_t i = 0; i < len / 2; i++)
                    {
                        fft_2point(input[i], input[i + len / 2]);
                    }
                    // Middle stages (span >= 2), two complex values per SIMD step;
                    // block number `index` uses twiddle TABLE[index].
                    for (size_t rank = len / 4; rank > 1; rank /= 2)
                    {
                        for (size_t begin = 0, index = 0; begin < len; begin += rank * 2, index++)
                        {
                            Complex *p1 = input + begin;
                            const Complex2 omega(TABLE[index], TABLE[index]);
                            for (size_t pos = 0; pos < rank; pos += 2)
                            {
                                Complex2 tmp1 = p1 + pos;
                                Complex2 tmp2 = Complex2(p1 + pos + rank) * omega;
                                (tmp1 + tmp2).store(p1 + pos);
                                (tmp1 - tmp2).store(p1 + pos + rank);
                            }
                        }
                    }
                    // Last stage: span 1, pair i/2 uses TABLE[i/2].
                    for (size_t i = 0; i < len; i += 2)
                    {
                        Complex tmp1 = input[i];
                        Complex tmp2 = input[i + 1] * TABLE[i / 2];
                        input[i] = tmp1 + tmp2;
                        input[i + 1] = tmp1 - tmp2;
                    }
                }
                /// @brief Decimation-in-time pass: dif_avx's stages in reverse
                ///        order, with the twiddle applied after each butterfly.
                /// @param input len complex values; must be 32-byte aligned.
                /// @param len Power of two. If len > max_fft_len nothing is done.
                void dit_avx(Complex *input, size_t len) const
                {
                    if (len > max_fft_len)
                    {
                        return;
                    }
                    // Same small-length special case as dif_avx: one stage for
                    // len == 2, none for len <= 1.
                    if (len <= 2)
                    {
                        if (len == 2)
                        {
                            fft_2point(input[0], input[1]);
                        }
                        return;
                    }
                    // Span-1 stage, pair i/2 uses TABLE[i/2].
                    for (size_t i = 0; i < len; i += 2)
                    {
                        Complex tmp1 = input[i];
                        Complex tmp2 = input[i + 1];
                        input[i] = tmp1 + tmp2;
                        input[i + 1] = (tmp1 - tmp2) * TABLE[i / 2];
                    }
                    // Middle stages (span >= 2), two complex values per SIMD step.
                    for (size_t rank = 2; rank < len / 2; rank *= 2)
                    {
                        for (size_t begin = 0, index = 0; begin < len; begin += rank * 2, index++)
                        {
                            Complex *p1 = input + begin;
                            const Complex2 omega(TABLE[index], TABLE[index]);
                            for (size_t pos = 0; pos < rank; pos += 2)
                            {
                                Complex2 tmp1 = p1 + pos;
                                Complex2 tmp2 = p1 + pos + rank;
                                (tmp1 + tmp2).store(p1 + pos);
                                ((tmp1 - tmp2) * omega).store(p1 + pos + rank);
                            }
                        }
                    }
                    // Final stage: span len/2, twiddle 1.
                    for (size_t i = 0; i < len / 2; i++)
                    {
                        fft_2point(input[i], input[i + len / 2]);
                    }
                }
            };
            // Maximum transform length (in complex elements) of the shared instance.
            // Note: main() declares a local `len1` that shadows this name.
            constexpr size_t len1 = 1 << 19;
            // Shared transform object; its twiddle table (len1 / 2 Complex values)
            // is filled by the FFT_RADIX2 constructor when this namespace-scope
            // object is initialized.
            static FFT_RADIX2<len1> small_fft;
        }
    }
    /// @brief Formats the lowest `digits` decimal digits of input as a string,
    ///        left-padded with '0'; higher-order digits are dropped.
    inline std::string ui64to_string(UINT_64 input, UINT_8 digits)
    {
        std::string result(digits, '0');
        // Fill from the least significant (rightmost) character leftwards.
        for (auto it = result.rbegin(); it != result.rend(); ++it)
        {
            *it = static_cast<char>(input % 10 + '0');
            input /= 10;
        }
        return result;
    }
    /// @brief Parses the first `dig` characters of s as an unsigned decimal number.
    /// @details No validation: each character is assumed to be '0'..'9'.
    constexpr UINT_64 stoui64(const char *s, size_t dig = 4)
    {
        UINT_64 value = 0;
        for (const char *end = s + dig; s != end; ++s)
        {
            value = value * 10 + (*s - '0');
        }
        return value;
    }

    /// @brief Converts exactly four ASCII digits at s into their value (0..9999).
    constexpr UINT_32 stobase10000(const char *s)
    {
        int value = 0;
        for (int k = 0; k < 4; ++k)
        {
            value = value * 10 + (s[k] - '0');
        }
        return value;
    }
    /// @brief Converts exactly five ASCII digits at s into their value (0..99999).
    constexpr UINT_32 stobase100000(const char *s)
    {
        int value = 0;
        for (int k = 0; k < 5; ++k)
        {
            value = value * 10 + (s[k] - '0');
        }
        return value;
    }
    static constexpr INT_64 DIGIT = 4;
    inline size_t char_to_real(const char *buffer, Complex *comary, size_t str_len)
    {
        hint::INT_64 len = str_len, pos = len, i = 0;
        len = (len + DIGIT - 1) / DIGIT;
        while (pos - DIGIT > 0)
        {
            // hint::UINT_64 tmp = stoui64<DIGIT>(buffer + pos - DIGIT);
            hint::UINT_32 tmp = stobase10000(buffer + pos - DIGIT);
            comary[i].real(tmp);
            i++;
            pos -= DIGIT;
        }
        if (pos > 0)
        {
            hint::UINT_32 tmp = stoui64(buffer, pos);
            comary[i].real(tmp);
        }
        return len;
    }
    /// @brief Splits a decimal string into base-10000 limbs, least significant
    ///        first, written to the imaginary parts of comary[0..].
    /// @param buffer Decimal digits, most significant first (not validated). Taken
    ///               as const char*, like char_to_real, since it is only read;
    ///               existing char* arguments still convert implicitly.
    /// @param comary Destination; only the imaginary parts of written elements change.
    /// @param str_len Number of characters of buffer to read.
    /// @return Number of limbs written, i.e. ceil(str_len / DIGIT).
    inline size_t char_to_imag(const char *buffer, Complex *comary, size_t str_len)
    {
        hint::INT_64 len = str_len, pos = len, i = 0;
        len = (len + DIGIT - 1) / DIGIT;
        // Full DIGIT-character groups from the right end.
        while (pos - DIGIT > 0)
        {
            hint::UINT_32 tmp = stobase10000(buffer + pos - DIGIT);
            comary[i].imag(tmp);
            i++;
            pos -= DIGIT;
        }
        // Leading group of 1..DIGIT characters.
        if (pos > 0)
        {
            hint::UINT_32 tmp = stoui64(buffer, pos);
            comary[i].imag(tmp);
        }
        return len;
    }
    /// @brief Writes the low DIGIT decimal digits of num into s[0..DIGIT),
    ///        most significant first and zero-padded; no terminator is written.
    inline void num_to_s(char *s, UINT_64 num)
    {
        for (char *out = s + DIGIT; out != s;)
        {
            --out;
            *out = static_cast<char>(num % 10 + '0');
            num /= 10;
        }
    }
    /// @brief Writes the low four decimal digits of num as ASCII into s[0..4),
    ///        zero-padded; no terminator is written.
    constexpr void num_to_s_base10000(char *s, UINT_32 num)
    {
        for (int k = 3; k >= 0; --k)
        {
            s[k] = '0' + num % 10;
            num /= 10;
        }
    }
    /// @brief Writes the low five decimal digits of num as ASCII into s[0..5),
    ///        zero-padded; no terminator is written.
    constexpr void num_to_s_base100000(char *s, UINT_32 num)
    {
        for (int k = 4; k >= 0; --k)
        {
            s[k] = '0' + num % 10;
            num /= 10;
        }
    }
    /// @brief Lookup table mapping each value 0..9999 to its four ASCII digits
    ///        packed in one uint32_t, so a whole limb is emitted with one 4-byte copy.
    /// @note The thousands digit is packed into the lowest byte, so the in-memory
    ///       byte order spells the number only on little-endian targets.
    class ItoStrBase10000
    {
    private:
        uint32_t table[10000]{};

    public:
        /// @brief Packs the four decimal digits of num as ASCII, thousands digit
        ///        in the low byte. Only the low four decimal digits are used.
        static constexpr uint32_t itosbase10000(uint32_t num)
        {
            uint32_t res = '0' * 0x1010101; // '0' in every byte
            res += (num / 1000 % 10) | ((num / 100 % 10) << 8) |
                   ((num / 10 % 10) << 16) | ((num % 10) << 24);
            return res;
        }
        constexpr ItoStrBase10000()
        {
            for (size_t i = 0; i < 10000; i++)
            {
                table[i] = itosbase10000(i);
            }
        }
        /// @brief Writes the four digits of num into str[0..4) (no terminator).
        /// @pre num < 10000 (not checked).
        void tostr(char *str, uint32_t num) const
        {
            // memcpy rather than a store through uint32_t*: str need not be
            // 4-byte aligned, and writing a char buffer through a uint32_t
            // lvalue violates strict aliasing.
            std::memcpy(str, &table[num], sizeof(table[num]));
        }
        /// @brief Returns the packed four ASCII digits of num.
        /// @pre num < 10000 (not checked).
        uint32_t tostr(uint32_t num) const
        {
            return table[num];
        }
    };
}

using namespace std;
using namespace hint;
using namespace hint_transform;
using namespace hint_fft;

/// Reads two non-negative decimal integers from stdin and prints their product.
/// The operands are the first two runs of digits in the input. Each is split into
/// base-10^DIGIT limbs: the first goes into the real parts and the second into the
/// imaginary parts of one complex array. Since (a + ib)^2 has imaginary part 2ab,
/// one forward pass, a pointwise square and one inverse pass leave the product
/// limbs in the imaginary parts, which are then rounded and carried.
int main()
{
    constexpr size_t STR_LEN = 2000008;
    constexpr uint64_t BASE = hint::qpow(10ull, DIGIT);
    static constexpr ItoStrBase10000 transfer;
    // `out` first holds the raw input; after parsing, the result is written into
    // its tail as 4-character limbs and printed from there.
    static AlignAry<char, STR_LEN> out;
    static AlignAry<HintFloat, 1 << 20> fft_ary_avx;
    Complex *fft_ary = fft_ary_avx.template cast_ptr<Complex>();
    uint32_t *ary = out.template cast_ptr<uint32_t>();
    size_t len_a = 0, len_b = 0;
    const size_t read_len = fread(out.data(), 1, STR_LEN, stdin);
    char *p = out.data();
    char *const input_end = p + read_len;
    /*
        Alternative mmap-based input (disabled):
        struct stat sb;
        int fd = fileno(stdin);
        fstat(fd, &sb);
        p = (char *)mmap(0, sb.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
        madvise(p, sb.st_size, MADV_SEQUENTIAL);
    */
    // Skip anything before the first operand (e.g. leading whitespace).
    // isdigit needs a value representable as unsigned char.
    while (p < input_end && !isdigit(static_cast<unsigned char>(*p)))
    {
        p++;
    }
    // All scans below are bounded by the bytes actually read.
    while (p + len_a < input_end && p[len_a] >= '0')
    {
        len_a++;
    }
    if (len_a == 1 && p[0] == '0')
    {
        puts("0");
        return 0;
    }
    char *b = p + len_a;
    while (b < input_end && !isdigit(static_cast<unsigned char>(*b)))
    {
        b++;
    }
    while (b + len_b < input_end && b[len_b] >= '0')
    {
        len_b++;
    }
    if (len_a == 0 || len_b == 0)
    {
        // Fewer than two operands in the input: nothing to multiply.
        return 1;
    }
    if (len_b == 1 && b[0] == '0')
    {
        puts("0");
        return 0;
    } // 0.46ms
    size_t len1 = char_to_real(p, fft_ary, len_a);
    size_t len2 = char_to_imag(b, fft_ary, len_b); // 1.67ms
    size_t fft_len = int_ceil2(len1 + len2 - 1);
    small_fft.dif_avx(fft_ary, fft_len);
    // inv combines the 1/fft_len normalization, the 1/2 that removes the factor
    // 2 of the 2ab cross term, and a sign that compensates for the conj() below.
    HintFloat inv = -0.5 / fft_len;
    const Complex2 invx4(inv);
    for (size_t i = 0; i < fft_len; i += 2)
    {
        Complex2 tmp = fft_ary + i;
        tmp = tmp.square().linear_mul(invx4);
        (tmp.conj()).store(fft_ary + i);
    }
    small_fft.dit_avx(fft_ary, fft_len); // 6ms
    // Round each product limb, propagate carries, and write the limbs as
    // 4-character groups backwards from the end of `out`.
    UINT_64 carry = 0;
    size_t pos = STR_LEN / 4 - 1;
    for (size_t i = 0; i < len1 + len2 - 1; i++)
    {
        carry += UINT_64(fft_ary[i].imag() + 0.5);
        UINT_64 num = 0;
        std::tie(carry, num) = div_mod<UINT_64>(carry, BASE);
        ary[pos] = transfer.tostr(num);
        pos--;
    }
    ary[pos] = transfer.tostr(carry);
    pos *= 4;
    // Strip leading zeros but always keep the last digit, so an all-zero product
    // (e.g. operands with leading zeros) prints "0" instead of running off the
    // end of the buffer.
    while (pos + 1 < STR_LEN && out[pos] == '0')
    {
        pos++;
    } // 0.8ms
    fwrite(out.data() + pos, 1, STR_LEN - pos, stdout);
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 23.423 ms | 15 MB + 888 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2025-07-22 21:14:48 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠