提交记录 19817


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY 1004. 【模板题】高精度乘法 Runtime Error 0 32.2 us 32 KB C++14 7.16 KB
提交时间 评测时间
2023-08-03 23:10:59 2023-08-03 23:11:01
#include <complex>
#include <cstdio>
#include <immintrin.h>
#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP

#pragma GCC target("fma")
#pragma GCC target("avx2")
#pragma GCC target("avx512f")

namespace hint_simd
{
    // Use AVX
    using Complex = std::complex<double>;
    // Two complex<double> values processed in parallel in one 256-bit register.
    // Lane order (low to high): re0, im0, re1, im1.
    struct Complex2
    {
        __m256d data;

        // All four lanes zero.
        Complex2() : data(_mm256_setzero_pd()) {}
        // Broadcasts one double into every lane (real AND imaginary parts), which makes
        // the result a lane-wise scale factor for linear_mul() / operator/.
        Complex2(double input) : data(_mm256_set1_pd(input)) {}
        // Wraps a raw register; intentionally implicit so intrinsics results can be returned directly.
        Complex2(const __m256d &input) : data(input) {}
        Complex2(const Complex2 &input) = default;
        // Construct from 4 contiguous doubles (no alignment requirement).
        Complex2(double const *ptr) : data(_mm256_loadu_pd(ptr)) {}
        // Same complex value in both halves.
        Complex2(Complex a)
        {
            data = _mm256_broadcast_pd(reinterpret_cast<const __m128d *>(&a));
        }
        // a goes to the low half, b to the high half.
        Complex2(Complex a, Complex b)
        {
            // std::complex<double> is not guaranteed to be 16-byte aligned, so it must not be
            // read through a dereferenced __m128d pointer; use explicit unaligned loads.
            const __m128d lo = _mm_loadu_pd(reinterpret_cast<const double *>(&a));
            const __m128d hi = _mm_loadu_pd(reinterpret_cast<const double *>(&b));
            data = _mm256_set_m128d(hi, lo);
        }
        // Construct from 2 contiguous complex values (no alignment requirement).
        Complex2(const Complex *ptr) : data(_mm256_loadu_pd(reinterpret_cast<const double *>(ptr))) {}

        // Resets all lanes to zero.
        void clr()
        {
            data = _mm256_setzero_pd();
        }
        // Writes both complex values to a[0] and a[1] (unaligned store).
        void store(Complex *a) const
        {
            _mm256_storeu_pd(reinterpret_cast<double *>(a), data);
        }
        // Debug helper: prints both values as "(re,im) (re,im)".
        void print() const
        {
            double ary[4];
            _mm256_storeu_pd(ary, data);
            printf("(%lf,%lf) (%lf,%lf)\n", ary[0], ary[1], ary[2], ary[3]);
        }
        // Flips the sign of lane k for every bit k set in M (bit 0 = re0 ... bit 3 = im1).
        template <int M>
        Complex2 element_mask_neg() const
        {
            // Each selected bit of M is moved to bit 63, the IEEE-754 sign bit of its lane.
            const __m256d neg_mask = _mm256_castsi256_pd(
                _mm256_set_epi64x((M & 8ull) << 60, (M & 4ull) << 61, (M & 2ull) << 62, (M & 1ull) << 63));
            return _mm256_xor_pd(data, neg_mask);
        }
        // Thin wrapper over _mm256_permute_pd with immediate M (permutes within each 128-bit half).
        template <int M>
        Complex2 element_permute() const
        {
            return _mm256_permute_pd(data, M);
        }
        // Thin wrapper over _mm256_permute4x64_pd with immediate M (permutes across all four lanes).
        template <int M>
        Complex2 element_permute64() const
        {
            return _mm256_permute4x64_pd(data, M);
        }
        // (re0, re0, re1, re1)
        Complex2 all_real() const
        {
            return _mm256_unpacklo_pd(data, data);
        }
        // (im0, im0, im1, im1)
        Complex2 all_imag() const
        {
            return _mm256_unpackhi_pd(data, data);
        }
        // Swaps real and imaginary parts: (im0, re0, im1, re1).
        Complex2 swap() const
        {
            return _mm256_shuffle_pd(data, data, 5);
        }
        // Multiplies each value by -i: (re, im) -> (im, -re).
        Complex2 mul_neg_i() const
        {
            // addsub(0, x) yields (0 - re, 0 + im); swapping then gives (im, -re).
            return Complex2(_mm256_addsub_pd(_mm256_setzero_pd(), data)).swap();
        }
        // Complex conjugate: negates the imaginary lanes (mask 0b1010).
        Complex2 conj() const
        {
            return element_mask_neg<10>();
        }
        // Lane-wise product (NOT a complex product).
        Complex2 linear_mul(const Complex2 &input) const
        {
            return _mm256_mul_pd(data, input.data);
        }
        // Complex square of each value: (re^2 - im^2, 2*re*im).
        Complex2 square() const
        {
            const __m256d rr = all_real().data;        // (re, re)
            const __m256d ir = swap().data;            // (im, re)
            const __m256d add = _mm256_add_pd(rr, ir); // (re + im, 2*re)
            const __m256d sub = _mm256_sub_pd(rr, ir); // (re - im, 0)
            // Blend picks (re - im) in the real lane and im in the imaginary lane.
            return _mm256_mul_pd(add, _mm256_blend_pd(sub, data, 0b1010));
        }
        Complex2 operator+(const Complex2 &input) const
        {
            return _mm256_add_pd(data, input.data);
        }
        Complex2 operator-(const Complex2 &input) const
        {
            return _mm256_sub_pd(data, input.data);
        }
        // True complex product of each pair of values.
        Complex2 operator*(const Complex2 &input) const
        {
            const __m256d a_rr = all_real().data;
            const __m256d a_ii = all_imag().data;
            const __m256d b_ir = input.swap().data;
            // Real lane: ar*br - ai*bi; imaginary lane: ar*bi + ai*br.
            return _mm256_addsub_pd(_mm256_mul_pd(a_rr, input.data), _mm256_mul_pd(a_ii, b_ir));
        }
        // Lane-wise division (NOT complex division); pair with Complex2(double) to scale.
        Complex2 operator/(const Complex2 &input) const
        {
            return _mm256_div_pd(data, input.data);
        }
    };
    // Four complex<double> values processed in parallel in one 512-bit register.
    // Lane order (low to high): re0, im0, re1, im1, re2, im2, re3, im3.
    struct Complex4
    {
        __m512d data;

        // All eight lanes zero.
        Complex4() : data(_mm512_setzero_pd()) {}
        // Broadcasts one double into every lane (real AND imaginary parts), which makes
        // the result a lane-wise scale factor for linear_mul() / operator/.
        Complex4(double input) : data(_mm512_set1_pd(input)) {}
        // Wraps a raw register; intentionally implicit so intrinsics results can be returned directly.
        Complex4(const __m512d &input) : data(input) {}
        Complex4(const Complex4 &input) = default;
        // Construct from 8 contiguous doubles (no alignment requirement).
        Complex4(double const *ptr) : data(_mm512_loadu_pd(ptr)) {}
        // Construct from 4 contiguous complex values (no alignment requirement).
        Complex4(const Complex *ptr) : data(_mm512_loadu_pd(reinterpret_cast<const double *>(ptr))) {}

        // Resets all lanes to zero.
        void clr()
        {
            data = _mm512_setzero_pd();
        }
        // Writes the four complex values to a[0..3] (unaligned store).
        void store(Complex *a) const
        {
            _mm512_storeu_pd(reinterpret_cast<double *>(a), data);
        }
        // Debug helper: prints all four values as "(re,im) (re,im) (re,im) (re,im)".
        void print() const
        {
            // The register holds 8 doubles, so the scratch buffer must hold 8 as well.
            double ary[8];
            _mm512_storeu_pd(ary, data);
            printf("(%lf,%lf) (%lf,%lf) (%lf,%lf) (%lf,%lf)\n",
                   ary[0], ary[1], ary[2], ary[3], ary[4], ary[5], ary[6], ary[7]);
        }
        // Thin wrapper over _mm512_permute_pd with immediate M (permutes within each 128-bit lane).
        template <int M>
        Complex4 element_permute() const
        {
            return _mm512_permute_pd(data, M);
        }
        // (re0, re0, re1, re1, ...)
        Complex4 all_real() const
        {
            return _mm512_unpacklo_pd(data, data);
        }
        // (im0, im0, im1, im1, ...)
        Complex4 all_imag() const
        {
            return _mm512_unpackhi_pd(data, data);
        }
        // Swaps real and imaginary parts of every value: (im0, re0, im1, re1, ...).
        Complex4 swap() const
        {
            return _mm512_shuffle_pd(data, data, 0b01010101);
        }
        // Multiplies each value by -i: (re, im) -> (im, -re).
        Complex4 mul_neg_i() const
        {
            // fmaddsub(0, 1, x) yields (0*1 - re, 0*1 + im); swapping then gives (im, -re).
            const __m512d zero = _mm512_setzero_pd();
            const __m512d one = _mm512_set1_pd(1.0);
            return Complex4(_mm512_fmaddsub_pd(zero, one, data)).swap();
        }
        // Lane-wise product (NOT a complex product).
        Complex4 linear_mul(Complex4 input) const
        {
            return _mm512_mul_pd(data, input.data);
        }
        // Complex square of each value: (re^2 - im^2, 2*re*im).
        Complex4 square() const
        {
            const __m512d rr = all_real().data;        // (re, re)
            const __m512d ir = swap().data;            // (im, re)
            const __m512d add = _mm512_add_pd(rr, ir); // (re + im, 2*re)
            const __m512d sub = _mm512_sub_pd(rr, ir); // (re - im, 0)
            // Mask 0xAA takes `data` in the imaginary lanes and `sub` in the real lanes.
            return _mm512_mul_pd(add, _mm512_mask_blend_pd(0xAA, sub, data));
        }
        Complex4 operator+(Complex4 input) const
        {
            return _mm512_add_pd(data, input.data);
        }
        Complex4 operator-(Complex4 input) const
        {
            return _mm512_sub_pd(data, input.data);
        }
        // True complex product of each pair of values.
        Complex4 operator*(Complex4 input) const
        {
            const __m512d a_rr = all_real().data;
            const __m512d a_ii = all_imag().data;
            const __m512d b_ir = input.swap().data;
            // Real lane: ar*br - ai*bi; imaginary lane: ar*bi + ai*br.
            return _mm512_fmaddsub_pd(a_rr, input.data, _mm512_mul_pd(a_ii, b_ir));
        }
        // Lane-wise division (NOT complex division); pair with Complex4(double) to scale.
        Complex4 operator/(Complex4 input) const
        {
            return _mm512_div_pd(data, input.data);
        }
    };
}
#endif
using namespace hint_simd;
#include<iostream>
int main()
{
    Complex ary[4]={1,2,3,4};
    Complex4 a = ary;
    a = a * a;
    a.store(ary);
    for(auto i:ary)
    {
        std::cout<<i;
    }
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 32.2 us | 32 KB | Runtime Error | Score: 0


Judge Duck Online | 评测鸭在线
Server Time: 2025-09-14 14:59:47 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠