#include <immintrin.h>
#pragma GCC target("fma")
#pragma GCC target("avx2")
using namespace std;
/// Compile-time sized AVX2/FMA kernel that accumulates a matrix product,
/// C += A * B, for row-major matrices of doubles.
///
/// @tparam L number of rows of A and C
/// @tparam M number of columns of A / rows of B
/// @tparam N number of columns of B and C; must be a multiple of 16
///
/// The result is *added* to whatever C already holds; the kernel never
/// zeroes C itself.
template <size_t L, size_t M, size_t N>
struct MatMulKernel
{
    // Dimension aliases: A is row1 x col1, B is row2 x col2.
    enum
    {
        row1 = L,
        col1 = M,
        row2 = M,
        col2 = N
    };
    static_assert(N % 16 == 0,
                  "MatMulKernel: N must be a multiple of 16 (width of one column panel)");

    /// Accumulates one 16-column panel: C[:, 0..15] += A * B[:, 0..15].
    ///
    /// @param A start of A; element (i, j) is read from A[i * M + j]
    /// @param B first column of the panel inside B; rows are N doubles apart
    /// @param C first column of the panel inside C; rows are N doubles apart
    ///
    /// Pointers may have any alignment: unaligned load/store intrinsics are
    /// used, so buffers that are not 32-byte aligned do not fault.
    /// The target attribute enables AVX2/FMA for this function directly, so it
    /// does not depend on a file-level target pragma being honoured.
    __attribute__((target("avx2,fma")))
    static void matmulLxMx16(const double *A, const double *B, double *C)
    {
        size_t i = 0;

        // Main loop: two rows of C per iteration, so each B row that is loaded
        // feeds eight accumulators (2 rows x 4 vectors of 4 doubles).
        for (; i + 1 < L; i += 2)
        {
            const double *a_row0 = A + i * M;
            const double *a_row1 = a_row0 + M;
            double *c_row0 = C + i * N;
            double *c_row1 = c_row0 + N;

            __m256d acc00 = _mm256_loadu_pd(c_row0);
            __m256d acc01 = _mm256_loadu_pd(c_row0 + 4);
            __m256d acc02 = _mm256_loadu_pd(c_row0 + 8);
            __m256d acc03 = _mm256_loadu_pd(c_row0 + 12);
            __m256d acc10 = _mm256_loadu_pd(c_row1);
            __m256d acc11 = _mm256_loadu_pd(c_row1 + 4);
            __m256d acc12 = _mm256_loadu_pd(c_row1 + 8);
            __m256d acc13 = _mm256_loadu_pd(c_row1 + 12);

            for (size_t j = 0; j < M; ++j)
            {
                const double *b_row = B + j * N;
                const __m256d a0 = _mm256_set1_pd(a_row0[j]);
                const __m256d a1 = _mm256_set1_pd(a_row1[j]);
                const __m256d b0 = _mm256_loadu_pd(b_row);
                const __m256d b1 = _mm256_loadu_pd(b_row + 4);
                const __m256d b2 = _mm256_loadu_pd(b_row + 8);
                const __m256d b3 = _mm256_loadu_pd(b_row + 12);

                acc00 = _mm256_fmadd_pd(a0, b0, acc00);
                acc01 = _mm256_fmadd_pd(a0, b1, acc01);
                acc02 = _mm256_fmadd_pd(a0, b2, acc02);
                acc03 = _mm256_fmadd_pd(a0, b3, acc03);
                acc10 = _mm256_fmadd_pd(a1, b0, acc10);
                acc11 = _mm256_fmadd_pd(a1, b1, acc11);
                acc12 = _mm256_fmadd_pd(a1, b2, acc12);
                acc13 = _mm256_fmadd_pd(a1, b3, acc13);
            }

            _mm256_storeu_pd(c_row0, acc00);
            _mm256_storeu_pd(c_row0 + 4, acc01);
            _mm256_storeu_pd(c_row0 + 8, acc02);
            _mm256_storeu_pd(c_row0 + 12, acc03);
            _mm256_storeu_pd(c_row1, acc10);
            _mm256_storeu_pd(c_row1 + 4, acc11);
            _mm256_storeu_pd(c_row1 + 8, acc12);
            _mm256_storeu_pd(c_row1 + 12, acc13);
        }

        // Tail: when L is odd one row is left over. Handling it separately
        // keeps the two-row loop from touching a row past the end of A and C.
        if (i < L)
        {
            const double *a_row = A + i * M;
            double *c_row = C + i * N;

            __m256d acc0 = _mm256_loadu_pd(c_row);
            __m256d acc1 = _mm256_loadu_pd(c_row + 4);
            __m256d acc2 = _mm256_loadu_pd(c_row + 8);
            __m256d acc3 = _mm256_loadu_pd(c_row + 12);

            for (size_t j = 0; j < M; ++j)
            {
                const double *b_row = B + j * N;
                const __m256d a = _mm256_set1_pd(a_row[j]);

                acc0 = _mm256_fmadd_pd(a, _mm256_loadu_pd(b_row), acc0);
                acc1 = _mm256_fmadd_pd(a, _mm256_loadu_pd(b_row + 4), acc1);
                acc2 = _mm256_fmadd_pd(a, _mm256_loadu_pd(b_row + 8), acc2);
                acc3 = _mm256_fmadd_pd(a, _mm256_loadu_pd(b_row + 12), acc3);
            }

            _mm256_storeu_pd(c_row, acc0);
            _mm256_storeu_pd(c_row + 4, acc1);
            _mm256_storeu_pd(c_row + 8, acc2);
            _mm256_storeu_pd(c_row + 12, acc3);
        }
    }

    /// Accumulates the full product C += A * B by running the panel kernel
    /// over every 16-column panel of B and C.
    ///
    /// @param A L x M matrix, row-major
    /// @param B M x N matrix, row-major
    /// @param C L x N matrix, row-major; read and updated in place
    ///
    /// The pointers are forwarded to matmulLxMx16, which takes `double`
    /// pointers, so Ty has to be `double` for the call to compile.
    template <typename Ty>
    __attribute__((target("avx2,fma"))) // same target as the kernel it calls
    static void matmul(const Ty *A, const Ty *B, Ty *C)
    {
        for (size_t panel = 0; panel < N; panel += 16)
        {
            matmulLxMx16(A, B + panel, C + panel);
        }
    }
};
/// Accumulates the product of two square matrices into C: C += A * B.
///
/// @param n matrix dimension — assumed (from the signature) that A, B and C
///          are n x n row-major arrays of doubles; TODO confirm against caller
/// @param A left operand
/// @param B right operand
/// @param C result; read and updated in place, never zeroed here
///
/// n == 1024 uses the fixed-size vectorised kernel. Any other positive n goes
/// through a plain scalar loop with the same accumulate semantics, rather than
/// running the 1024-sized kernel over buffers of a different size.
void matrix_multiply(int n, const double *A, const double *B, double *C) {
    constexpr int kKernelDim = 1024;
    if (n == kKernelDim) {
        MatMulKernel<kKernelDim, kKernelDim, kKernelDim>::matmul(A, B, C);
        return;
    }
    if (n <= 0) {
        return; // nothing to multiply
    }

    const size_t dim = static_cast<size_t>(n);
    // i-k-j order keeps the innermost loop walking B and C contiguously.
    for (size_t i = 0; i < dim; ++i) {
        double *c_row = C + i * dim;
        for (size_t k = 0; k < dim; ++k) {
            const double a_ik = A[i * dim + k];
            const double *b_row = B + k * dim;
            for (size_t j = 0; j < dim; ++j) {
                c_row[j] += a_ik * b_row[j];
            }
        }
    }
}
// Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// Testcase #1 | 437.463 ms | 8 MB + 8 KB | Accepted | Score: 100 | Show more |