#include <iostream>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <immintrin.h>
#pragma GCC target("fma")
#pragma GCC optimize("inline")
// #pragma GCC optimize("O3")
template <size_t L, size_t M, size_t N>
struct MatMulKernel
{
enum
{
row1 = L,
col1 = M,
row2 = M,
col2 = N
};
alignas(4096) double PACKED_B[M * N]{};
void pack_b(const double *B, double *pack_ptr, size_t rows, size_t cols)
{
for (size_t i = 0; i < rows; i++)
{
std::memcpy(pack_ptr + i * cols, B + i * col2, cols * sizeof(double));
}
}
template <int BLOCK>
static void add_dot_2x16(const double *A, const double *B, double *C)
{
__m256d C0 = _mm256_load_pd(C);
__m256d C1 = _mm256_load_pd(C + 4);
__m256d C2 = _mm256_load_pd(C + 8);
__m256d C3 = _mm256_load_pd(C + 12);
__m256d C4 = _mm256_load_pd(C + col2);
__m256d C5 = _mm256_load_pd(C + col2 + 4);
__m256d C6 = _mm256_load_pd(C + col2 + 8);
__m256d C7 = _mm256_load_pd(C + col2 + 12);
#pragma GCC unroll 8
for (int i = 0; i < BLOCK; i++)
{
const double *b_ptr = B + i * 16;
__m256d B0 = _mm256_load_pd(b_ptr);
__m256d B1 = _mm256_load_pd(b_ptr + 4);
__m256d B2 = _mm256_load_pd(b_ptr + 8);
__m256d B3 = _mm256_load_pd(b_ptr + 12);
__m256d A0 = _mm256_set1_pd(A[i]);
C0 = _mm256_fmadd_pd(A0, B0, C0);
C1 = _mm256_fmadd_pd(A0, B1, C1);
C2 = _mm256_fmadd_pd(A0, B2, C2);
C3 = _mm256_fmadd_pd(A0, B3, C3);
__m256d A1 = _mm256_set1_pd(A[i + col1]);
C4 = _mm256_fmadd_pd(A1, B0, C4);
C5 = _mm256_fmadd_pd(A1, B1, C5);
C6 = _mm256_fmadd_pd(A1, B2, C6);
C7 = _mm256_fmadd_pd(A1, B3, C7);
}
_mm256_store_pd(C, C0);
_mm256_store_pd(C + 4, C1);
_mm256_store_pd(C + 8, C2);
_mm256_store_pd(C + 12, C3);
_mm256_store_pd(C + col2, C4);
_mm256_store_pd(C + col2 + 4, C5);
_mm256_store_pd(C + col2 + 8, C6);
_mm256_store_pd(C + col2 + 12, C7);
}
// BLOCK_AxBLOCK_Bx16 L1L2
template <size_t BLOCK_A, size_t BLOCK_B>
static void matmulAxBx16(const double *A, const double *B, double *C)
{
for (size_t i = 0; i < BLOCK_A; i += 2)
{
add_dot_2x16<BLOCK_B>(A + i * col1, B, C + i * col2);
}
}
// 128x256xBLOCK_C L2L3
template <size_t BLOCK_A, size_t BLOCK_B, size_t BLOCK_C>
void matmul128x256xBLK_C(const double *A, const double *B, double *C, size_t row, size_t b_col)
{
double *pack_b_ptr = PACKED_B + b_col * col2;
for (size_t i = 0; i < BLOCK_C; i += 16)
{
if (row == 0)
{
pack_b(B + i, pack_b_ptr + i * BLOCK_B, BLOCK_B, 16);
}
matmulAxBx16<BLOCK_A, BLOCK_B>(A, pack_b_ptr + i * BLOCK_B, C + i);
}
}
void matmul(const double *A, const double *B, double *C)
{
constexpr size_t BLOCK_C = col2;
constexpr size_t BLOCK_B = (32 << 10) / (16 * 8); // L1为32KB
constexpr size_t BLOCK_A = (128 << 10) / (BLOCK_B * 8); // L2为128KB
for (int i = 0; i < col1; i += BLOCK_B)
{
for (int j = 0; j < row1; j += BLOCK_A)
{
matmul128x256xBLK_C<BLOCK_A, BLOCK_B, BLOCK_C>(A + j * col1 + i, B + i * col2, C + j * col2, j, i);
}
}
}
};
static MatMulKernel<1024, 1024, 1024> mm;
void matrix_multiply(int n, const double *A, const double *B, double *C)
{
mm.matmul(A, B, C);
}
Compilation | N/A | N/A | Compile OK | Score: N/A | 显示更多 |
Testcase #1 | 50.288 ms | 16 MB + 40 KB | Accepted | Score: 100 | 显示更多 |