#include <iostream>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <immintrin.h>
#pragma GCC target("fma")
#pragma GCC target("avx2")
using namespace std;
/// Dense double-precision matrix product for sizes fixed at compile time,
/// vectorised with AVX2 + FMA.
///
/// All matrices are row-major and densely packed:
///   A is L x M (row stride M), B is M x N (row stride N), C is L x N (row stride N).
///
/// Every routine ACCUMULATES: it computes C += A * B (C is loaded, updated
/// with FMAs and stored back). Callers that want C = A * B must zero C first.
///
/// Unaligned loads/stores are used, so A, B and C need no special alignment.
/// Each member carries its own target attribute so the struct does not depend
/// on a file-level `#pragma GCC target`; the executing CPU must support AVX2
/// and FMA.
template <std::size_t L, std::size_t M, std::size_t N>
struct MatMulKernel
{
    static constexpr std::size_t row1 = L; // rows of A and C
    static constexpr std::size_t col1 = M; // columns of A
    static constexpr std::size_t row2 = M; // rows of B
    static constexpr std::size_t col2 = N; // columns of B and C

    /// Micro-kernel updating one 2 x 16 tile of C:
    ///   C[0..2)[0..16) += A[0..2)[0..4) * B[0..4)[0..16)
    /// `A`, `B` and `C` point at the tile's top-left element inside the full
    /// matrices (row stride col1 for A, col2 for B and C).
    [[gnu::target("avx2,fma")]] static void matmul2x4x16(const double *A, const double *B, double *C)
    {
        double *const c_top = C;
        double *const c_bot = C + col2;

        // The whole 2 x 16 C tile is held in 8 ymm accumulators.
        __m256d t0 = _mm256_loadu_pd(c_top);
        __m256d t1 = _mm256_loadu_pd(c_top + 4);
        __m256d t2 = _mm256_loadu_pd(c_top + 8);
        __m256d t3 = _mm256_loadu_pd(c_top + 12);
        __m256d u0 = _mm256_loadu_pd(c_bot);
        __m256d u1 = _mm256_loadu_pd(c_bot + 4);
        __m256d u2 = _mm256_loadu_pd(c_bot + 8);
        __m256d u3 = _mm256_loadu_pd(c_bot + 12);

        // Reduce over the 4 shared indices in ascending order (same
        // accumulation order as the original hand-unrolled version).
#pragma GCC unroll 4
        for (std::size_t p = 0; p < 4; ++p)
        {
            const double *b_row = B + p * col2;
            const __m256d b0 = _mm256_loadu_pd(b_row);
            const __m256d b1 = _mm256_loadu_pd(b_row + 4);
            const __m256d b2 = _mm256_loadu_pd(b_row + 8);
            const __m256d b3 = _mm256_loadu_pd(b_row + 12);
            const __m256d a_top = _mm256_set1_pd(A[p]);
            const __m256d a_bot = _mm256_set1_pd(A[col1 + p]);
            t0 = _mm256_fmadd_pd(a_top, b0, t0);
            t1 = _mm256_fmadd_pd(a_top, b1, t1);
            t2 = _mm256_fmadd_pd(a_top, b2, t2);
            t3 = _mm256_fmadd_pd(a_top, b3, t3);
            u0 = _mm256_fmadd_pd(a_bot, b0, u0);
            u1 = _mm256_fmadd_pd(a_bot, b1, u1);
            u2 = _mm256_fmadd_pd(a_bot, b2, u2);
            u3 = _mm256_fmadd_pd(a_bot, b3, u3);
        }

        _mm256_storeu_pd(c_top, t0);
        _mm256_storeu_pd(c_top + 4, t1);
        _mm256_storeu_pd(c_top + 8, t2);
        _mm256_storeu_pd(c_top + 12, t3);
        _mm256_storeu_pd(c_bot, u0);
        _mm256_storeu_pd(c_bot + 4, u1);
        _mm256_storeu_pd(c_bot + 8, u2);
        _mm256_storeu_pd(c_bot + 12, u3);
    }

    /// Updates an L1 x N1 block of C with an (L1 x M1) * (M1 x N1) product.
    /// Pointers address the block's top-left element inside the full matrices.
    template <std::size_t L1, std::size_t M1, std::size_t N1>
    [[gnu::target("avx2,fma")]] static void matmul_template(const double *A, const double *B, double *C)
    {
        static_assert(L1 % 2 == 0, "block rows must be a multiple of 2");
        static_assert(M1 % 4 == 0, "block depth must be a multiple of 4");
        static_assert(N1 % 16 == 0, "block columns must be a multiple of 16");
        for (std::size_t i = 0; i < L1; i += 2)
        {
            double *c_rows = C + i * col2;
            for (std::size_t p = 0; p < M1; p += 4)
            {
                const double *a_tile = A + i * col1 + p;
                const double *b_rows = B + p * col2;
                // B and C must advance together along the column axis; the
                // previous version moved only C, multiplying columns 16..63 of
                // each 64-wide chunk by the wrong B columns.
                for (std::size_t j = 0; j < N1; j += 16)
                {
                    matmul2x4x16(a_tile, b_rows + j, c_rows + j);
                }
            }
        }
    }

    /// Full product C += A * B, tiled into 16 x 16 x 512 cache blocks.
    [[gnu::target("avx2,fma")]] static void matmul(const double *A, const double *B, double *C)
    {
        constexpr std::size_t L_step = 16;
        constexpr std::size_t M_step = 16;
        constexpr std::size_t N_step = 512;
        static_assert(row1 % L_step == 0, "L must be a multiple of 16");
        static_assert(col1 % M_step == 0, "M must be a multiple of 16");
        static_assert(col2 % N_step == 0, "N must be a multiple of 512");
        for (std::size_t i = 0; i < row1; i += L_step)
        {
            double *c_rows = C + i * col2;
            for (std::size_t p = 0; p < col1; p += M_step)
            {
                const double *a_block = A + i * col1 + p;
                const double *b_rows = B + p * col2;
                for (std::size_t j = 0; j < col2; j += N_step)
                {
                    matmul_template<L_step, M_step, N_step>(a_block, b_rows + j, c_rows + j);
                }
            }
        }
    }
};
/// Multiplies two dense row-major n x n matrices, accumulating: C += A * B.
///
/// n == 1024 uses the AVX2/FMA blocked kernel. Any other size takes a portable
/// scalar path; the previous version ignored `n` and always touched
/// 1024 x 1024 elements, which is out of bounds for smaller inputs.
/// NOTE(review): both paths accumulate into C, so this assumes the caller
/// passes C zero-initialised (or wants accumulation) — confirm with the caller.
void matrix_multiply(int n, const double *A, const double *B, double *C) {
    if (n == 1024) {
        MatMulKernel<1024, 1024, 1024>::matmul(A, B, C);
        return;
    }
    if (n <= 0) {
        return;
    }
    const std::size_t dim = static_cast<std::size_t>(n);
    // i-p-j order keeps the innermost loop streaming contiguously over B and C.
    for (std::size_t i = 0; i < dim; ++i) {
        double *c_row = C + i * dim;
        const double *a_row = A + i * dim;
        for (std::size_t p = 0; p < dim; ++p) {
            const double a_ip = a_row[p];
            const double *b_row = B + p * dim;
            for (std::size_t j = 0; j < dim; ++j) {
                c_row[j] += a_ip * b_row[j];
            }
        }
    }
}
// Judge result: Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// Judge result: Testcase #1 | 86.223 ms | 8 MB + 32 KB | Wrong Answer | Score: 0 | Show more |