#include <immintrin.h>

#include <complex>
#include <cstdio>
#ifndef HINT_SIMD_HPP
#define HINT_SIMD_HPP
#pragma GCC target("fma")
#pragma GCC target("avx2")
#pragma GCC target("avx512f")
namespace hint_simd
{
// Use AVX
using Complex = std::complex<double>;
// Two complex numbers processed in parallel
/// Two std::complex<double> values packed into one 256-bit register.
///
/// Lane layout, low to high: re0, im0, re1, im1.
/// The converting constructors are intentionally implicit so that raw
/// registers, scalars and pointers can be used directly in expressions.
struct Complex2
{
    __m256d data;

    /// Both packed numbers are zero.
    Complex2() : data(_mm256_setzero_pd()) {}

    /// Broadcasts `input` into all four lanes (real AND imaginary parts).
    Complex2(double input) : data(_mm256_set1_pd(input)) {}

    /// Wraps an existing register.
    Complex2(const __m256d &input) : data(input) {}

    /// Loads four consecutive doubles from `ptr` (unaligned load).
    Complex2(double const *ptr) : data(_mm256_loadu_pd(ptr)) {}

    /// Both packed numbers are set to `a`.
    Complex2(Complex a)
        : data(_mm256_set_pd(a.imag(), a.real(), a.imag(), a.real())) {}

    /// `a` goes into the low half, `b` into the high half.
    /// Built from the components rather than by reinterpreting the objects
    /// as __m128d, which would imply a 16-byte alignment that
    /// std::complex<double> does not guarantee.
    Complex2(Complex a, Complex b)
        : data(_mm256_set_pd(b.imag(), b.real(), a.imag(), a.real())) {}

    /// Loads two consecutive complex numbers from `ptr` (unaligned load).
    Complex2(const Complex *ptr)
        : data(_mm256_loadu_pd(reinterpret_cast<const double *>(ptr))) {}

    /// Resets every lane to zero.
    void clr()
    {
        data = _mm256_setzero_pd();
    }

    /// Writes the two packed numbers to a[0] and a[1] (unaligned store).
    void store(Complex *a) const
    {
        _mm256_storeu_pd(reinterpret_cast<double *>(a), data);
    }

    /// Prints both numbers as "(re,im) (re,im)" followed by a newline.
    void print() const
    {
        double lanes[4];
        _mm256_storeu_pd(lanes, data);
        std::printf("(%lf,%lf) (%lf,%lf)\n", lanes[0], lanes[1], lanes[2], lanes[3]);
    }

    /// Flips the sign of every lane k for which bit k of M is set.
    template <int M>
    Complex2 element_mask_neg() const
    {
        // Only the IEEE sign bit is set in the lanes selected by M.
        const long long sign_bit = static_cast<long long>(0x8000000000000000ull);
        const __m256i flip = _mm256_set_epi64x((M & 8) ? sign_bit : 0,
                                               (M & 4) ? sign_bit : 0,
                                               (M & 2) ? sign_bit : 0,
                                               (M & 1) ? sign_bit : 0);
        return Complex2(_mm256_xor_pd(data, _mm256_castsi256_pd(flip)));
    }

    /// Thin wrapper around _mm256_permute_pd with immediate M.
    template <int M>
    Complex2 element_permute() const
    {
        return Complex2(_mm256_permute_pd(data, M));
    }

    /// Thin wrapper around _mm256_permute4x64_pd with immediate M.
    template <int M>
    Complex2 element_permute64() const
    {
        return Complex2(_mm256_permute4x64_pd(data, M));
    }

    /// Returns (re0, re0, re1, re1).
    Complex2 all_real() const
    {
        return Complex2(_mm256_unpacklo_pd(data, data));
    }

    /// Returns (im0, im0, im1, im1).
    Complex2 all_imag() const
    {
        return Complex2(_mm256_unpackhi_pd(data, data));
    }

    /// Exchanges real and imaginary parts: (im0, re0, im1, re1).
    Complex2 swap() const
    {
        return Complex2(_mm256_shuffle_pd(data, data, 5));
    }

    /// Multiplies both numbers by -i: (re, im) -> (im, -re).
    Complex2 mul_neg_i() const
    {
        // addsub yields (0 - re, 0 + im); swapping gives (im, -re).
        const __m256d negated_real = _mm256_addsub_pd(_mm256_setzero_pd(), data);
        return Complex2(negated_real).swap();
    }

    /// Complex conjugate: negates the imaginary lanes (1 and 3).
    Complex2 conj() const
    {
        return element_mask_neg<0b1010>();
    }

    /// Lane-wise product (NOT a complex multiplication).
    Complex2 linear_mul(const Complex2 &input) const
    {
        return Complex2(_mm256_mul_pd(data, input.data));
    }

    /// Complex square of both numbers: (re^2 - im^2, 2*re*im).
    Complex2 square() const
    {
        const __m256d re_re = all_real().data; // (re, re)
        const __m256d im_re = swap().data;     // (im, re)
        const __m256d sum = _mm256_add_pd(re_re, im_re);  // (re+im, 2*re)
        const __m256d diff = _mm256_sub_pd(re_re, im_re); // (re-im, 0)
        // Keep (re-im) in the even lanes, take im from data in the odd lanes.
        const __m256d rhs = _mm256_blend_pd(diff, data, 0b1010);
        return Complex2(_mm256_mul_pd(sum, rhs));
    }

    /// Lane-wise sum, which is also the complex sum.
    Complex2 operator+(const Complex2 &input) const
    {
        return Complex2(_mm256_add_pd(data, input.data));
    }

    /// Lane-wise difference, which is also the complex difference.
    Complex2 operator-(const Complex2 &input) const
    {
        return Complex2(_mm256_sub_pd(data, input.data));
    }

    /// Complex product of the two pairs of numbers.
    Complex2 operator*(const Complex2 &input) const
    {
        const __m256d lhs_re = all_real().data;      // (a.re, a.re)
        const __m256d lhs_im = all_imag().data;      // (a.im, a.im)
        const __m256d rhs_swapped = input.swap().data; // (b.im, b.re)
        // even lanes: a.re*b.re - a.im*b.im, odd lanes: a.re*b.im + a.im*b.re
        return Complex2(_mm256_addsub_pd(_mm256_mul_pd(lhs_re, input.data),
                                         _mm256_mul_pd(lhs_im, rhs_swapped)));
    }

    /// Lane-wise quotient (NOT a complex division); each lane of `data` is
    /// divided by the matching lane of `input`.
    Complex2 operator/(const Complex2 &input) const
    {
        return Complex2(_mm256_div_pd(data, input.data));
    }
};
// Four complex numbers processed in parallel
/// Four std::complex<double> values packed into one 512-bit register.
///
/// Lane layout, low to high: re0, im0, re1, im1, re2, im2, re3, im3.
/// Every member uses AVX-512F intrinsics, so the CPU must support them.
/// The converting constructors are intentionally implicit.
struct Complex4
{
    __m512d data;

    /// All four packed numbers are zero.
    Complex4() : data(_mm512_setzero_pd()) {}

    /// Broadcasts `input` into all eight lanes (real AND imaginary parts).
    Complex4(double input) : data(_mm512_set1_pd(input)) {}

    /// Wraps an existing register.
    Complex4(const __m512d &input) : data(input) {}

    /// Loads eight consecutive doubles from `ptr` (unaligned load).
    Complex4(double const *ptr) : data(_mm512_loadu_pd(ptr)) {}

    /// Loads four consecutive complex numbers from `ptr` (unaligned load).
    Complex4(const Complex *ptr)
        : data(_mm512_loadu_pd(reinterpret_cast<const double *>(ptr))) {}

    /// Resets every lane to zero.
    void clr()
    {
        data = _mm512_setzero_pd();
    }

    /// Writes the four packed numbers to a[0..3] (unaligned store).
    void store(Complex *a) const
    {
        _mm512_storeu_pd(reinterpret_cast<double *>(a), data);
    }

    /// Prints all four numbers as "(re,im)" separated by spaces, then a newline.
    void print() const
    {
        // The store writes 8 doubles, so the buffer must hold 8 of them.
        double lanes[8];
        _mm512_storeu_pd(lanes, data);
        std::printf("(%lf,%lf) (%lf,%lf) (%lf,%lf) (%lf,%lf)\n",
                    lanes[0], lanes[1], lanes[2], lanes[3],
                    lanes[4], lanes[5], lanes[6], lanes[7]);
    }

    /// Thin wrapper around _mm512_permute_pd with immediate M.
    template <int M>
    Complex4 element_permute() const
    {
        return Complex4(_mm512_permute_pd(data, M));
    }

    /// Returns every real part duplicated: (re, re) per number.
    Complex4 all_real() const
    {
        return Complex4(_mm512_unpacklo_pd(data, data));
    }

    /// Returns every imaginary part duplicated: (im, im) per number.
    Complex4 all_imag() const
    {
        return Complex4(_mm512_unpackhi_pd(data, data));
    }

    /// Exchanges real and imaginary parts: (im, re) per number.
    Complex4 swap() const
    {
        return Complex4(_mm512_shuffle_pd(data, data, 0b01010101));
    }

    /// Multiplies every number by -i: (re, im) -> (im, -re).
    Complex4 mul_neg_i() const
    {
        // fmaddsub(0, 1, data) yields (0 - re, 0 + im); swapping gives (im, -re).
        const __m512d negated_real =
            _mm512_fmaddsub_pd(_mm512_setzero_pd(), _mm512_set1_pd(1.0), data);
        return Complex4(negated_real).swap();
    }

    /// Lane-wise product (NOT a complex multiplication).
    Complex4 linear_mul(Complex4 input) const
    {
        return Complex4(_mm512_mul_pd(data, input.data));
    }

    /// Complex square of every number: (re^2 - im^2, 2*re*im).
    Complex4 square() const
    {
        const __m512d re_re = all_real().data; // (re, re)
        const __m512d im_re = swap().data;     // (im, re)
        const __m512d sum = _mm512_add_pd(re_re, im_re);  // (re+im, 2*re)
        const __m512d diff = _mm512_sub_pd(re_re, im_re); // (re-im, 0)
        // Mask bits set on the odd lanes pick im from data; even lanes keep (re-im).
        const __m512d rhs = _mm512_mask_blend_pd(0b10101010, diff, data);
        return Complex4(_mm512_mul_pd(sum, rhs));
    }

    /// Lane-wise sum, which is also the complex sum.
    Complex4 operator+(Complex4 input) const
    {
        return Complex4(_mm512_add_pd(data, input.data));
    }

    /// Lane-wise difference, which is also the complex difference.
    Complex4 operator-(Complex4 input) const
    {
        return Complex4(_mm512_sub_pd(data, input.data));
    }

    /// Complex product of the four pairs of numbers.
    Complex4 operator*(Complex4 input) const
    {
        const __m512d lhs_re = all_real().data;        // (a.re, a.re)
        const __m512d lhs_im = all_imag().data;        // (a.im, a.im)
        const __m512d rhs_swapped = input.swap().data; // (b.im, b.re)
        // even lanes: a.re*b.re - a.im*b.im, odd lanes: a.re*b.im + a.im*b.re
        return Complex4(_mm512_fmaddsub_pd(lhs_re, input.data,
                                           _mm512_mul_pd(lhs_im, rhs_swapped)));
    }

    /// Lane-wise quotient (NOT a complex division); each lane of `data` is
    /// divided by the matching lane of `input`.
    Complex4 operator/(Complex4 input) const
    {
        return Complex4(_mm512_div_pd(data, input.data));
    }
};
}
#endif
using namespace hint_simd;
#include<iostream>
int main()
{
Complex ary[4]={1,2,3,4};
Complex4 a = ary;
a = a * a;
a.store(ary);
for(auto i:ary)
{
std::cout<<i;
}
}
// Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// Testcase #1 | 32.2 us | 32 KB | Runtime Error | Score: 0 | Show more |