#include <iostream>
#include <chrono>
#include <cstdint>
#include <immintrin.h>
#pragma GCC target("fma")
#pragma GCC target("avx2")
using namespace std;
/// Fixed-size dense matrix-multiply kernel using AVX2 + FMA.
///
/// Computes C += A * B for row-major matrices of doubles:
///   A is L x M (row stride M),
///   B is M x N (row stride N),
///   C is L x N (row stride N).
/// C is accumulated into, not overwritten: zero it first if a plain product
/// is wanted.
///
/// The micro-kernel updates a 2x16 tile of C from a 2x4 panel of A and a
/// 4x16 panel of B, so L, M and N must be multiples of 2, 4 and 16
/// respectively (enforced at compile time). Unaligned loads/stores are used,
/// so the buffers need no particular alignment.
template <size_t L, size_t M, size_t N>
struct MatMulKernel
{
    static constexpr size_t row1 = L; // rows of A and of C
    static constexpr size_t col1 = M; // columns of A
    static constexpr size_t row2 = M; // rows of B (must equal col1)
    static constexpr size_t col2 = N; // columns of B and of C

    static_assert(row1 % 2 == 0, "MatMulKernel: L must be a multiple of 2");
    static_assert(col1 % 4 == 0, "MatMulKernel: M must be a multiple of 4");
    static_assert(col2 % 16 == 0, "MatMulKernel: N must be a multiple of 16");

    /// C[0..1][0..15] += A[0..1][0..3] * B[0..3][0..15], where
    /// A points at the top-left of a 2x4 panel (row stride col1),
    /// B points at the top-left of a 4x16 panel (row stride col2), and
    /// C points at the top-left of a 2x16 tile (row stride col2).
    [[gnu::target("avx2,fma")]]
    static void matmul2x4x16(const double *A, const double *B, double *C)
    {
        // Accumulators: c0..c3 hold row 0 of the C tile, c4..c7 hold row 1.
        __m256d c0 = _mm256_loadu_pd(C);
        __m256d c1 = _mm256_loadu_pd(C + 4);
        __m256d c2 = _mm256_loadu_pd(C + 8);
        __m256d c3 = _mm256_loadu_pd(C + 12);
        __m256d c4 = _mm256_loadu_pd(C + col2);
        __m256d c5 = _mm256_loadu_pd(C + col2 + 4);
        __m256d c6 = _mm256_loadu_pd(C + col2 + 8);
        __m256d c7 = _mm256_loadu_pd(C + col2 + 12);

        // Reduction step p: column p of the A panel (A[p], A[col1 + p])
        // times row p of the B panel. Looping over p (instead of
        // hand-unrolling) guarantees every step reads its own A column.
        for (size_t p = 0; p < 4; ++p)
        {
            const __m256d a0 = _mm256_set1_pd(A[p]);
            const __m256d a1 = _mm256_set1_pd(A[col1 + p]);
            const double *b_row = B + p * col2;
            const __m256d b0 = _mm256_loadu_pd(b_row);
            const __m256d b1 = _mm256_loadu_pd(b_row + 4);
            const __m256d b2 = _mm256_loadu_pd(b_row + 8);
            const __m256d b3 = _mm256_loadu_pd(b_row + 12);
            c0 = _mm256_fmadd_pd(a0, b0, c0);
            c1 = _mm256_fmadd_pd(a0, b1, c1);
            c2 = _mm256_fmadd_pd(a0, b2, c2);
            c3 = _mm256_fmadd_pd(a0, b3, c3);
            c4 = _mm256_fmadd_pd(a1, b0, c4);
            c5 = _mm256_fmadd_pd(a1, b1, c5);
            c6 = _mm256_fmadd_pd(a1, b2, c6);
            c7 = _mm256_fmadd_pd(a1, b3, c7);
        }

        _mm256_storeu_pd(C, c0);
        _mm256_storeu_pd(C + 4, c1);
        _mm256_storeu_pd(C + 8, c2);
        _mm256_storeu_pd(C + 12, c3);
        _mm256_storeu_pd(C + col2, c4);
        _mm256_storeu_pd(C + col2 + 4, c5);
        _mm256_storeu_pd(C + col2 + 8, c6);
        _mm256_storeu_pd(C + col2 + 12, c7);
    }

    /// C += A * B over the full L x M by M x N product, tiled as
    /// (2 rows of C) x (4-wide reduction panel) x (16 columns of C).
    [[gnu::target("avx2,fma")]]
    static void matmul(const double *A, const double *B, double *C)
    {
        for (size_t i = 0; i < row1; i += 2)
        {
            double *c_rows = C + i * col2;
            for (size_t p = 0; p < col1; p += 4)
            {
                const double *a_panel = A + i * col1 + p;
                const double *b_panel = B + p * col2;
                for (size_t j = 0; j < col2; j += 16)
                {
                    matmul2x4x16(a_panel, b_panel + j, c_rows + j);
                }
            }
        }
    }
};
/// Computes C = A * B for square n x n row-major matrices of doubles.
///
/// C is zeroed first because the kernels below accumulate into it.
/// NOTE(review): this assumes callers want C overwritten with the product,
/// not C += A*B. Confirm against the caller/problem statement.
/// n == 1024 uses the AVX2/FMA MatMulKernel. Any other n falls back to a
/// scalar i-p-j loop, so the function is correct for every size rather than
/// silently assuming 1024. n <= 0 is a no-op.
void matrix_multiply(int n, const double *A, const double *B, double *C) {
    if (n <= 0)
        return;
    const size_t dim = static_cast<size_t>(n);

    for (size_t t = 0; t < dim * dim; ++t)
        C[t] = 0.0;

    if (dim == 1024) {
        MatMulKernel<1024, 1024, 1024>::matmul(A, B, C);
        return;
    }

    // Generic fallback: i-p-j order streams rows of B and C contiguously.
    for (size_t i = 0; i < dim; ++i) {
        double *c_row = C + i * dim;
        const double *a_row = A + i * dim;
        for (size_t p = 0; p < dim; ++p) {
            const double a = a_row[p];
            const double *b_row = B + p * dim;
            for (size_t j = 0; j < dim; ++j)
                c_row[j] += a * b_row[j];
        }
    }
}
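// Online-judge output pasted into the source (kept as a comment so the file compiles):
// Compilation | N/A | N/A | Compile OK | Score: N/A | 显示更多 (show more) |
// Testcase #1 | 142.896 ms | 8 MB + 32 KB | Wrong Answer | Score: 0 | 显示更多 (show more) |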