#include <iostream>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <immintrin.h>
#pragma GCC target("fma")
#pragma GCC optimize("inline")
// #pragma GCC optimize("Ofast")
// Multi-level cache-blocked matrix multiply: C (L x N) += A (L x M) * B (M x N).
// All matrices are dense row-major doubles. The blocking aims to exploit the
// L1/L2/L3 caches (block sizes are the author's targets, not measured here).
//
// Compile-time requirements (checked by static_assert in matmul()):
//   * N is a multiple of 32 (each micro-kernel pass covers 32 columns: 8 + 12 + 12),
//   * M is a multiple of the k-block size BLOCK_B,
//   * L is a multiple of the row-block size BLOCK_A, which is a multiple of 4.
// Not thread-safe: every matmul() call re-packs B into the shared PACKED_B buffer.
template <size_t L, size_t M, size_t N>
struct MatMulKernel
{
    enum
    {
        row1 = L,
        col1 = M,
        row2 = M,
        col2 = N
    };

    // Packed copy of B. The k-block starting at row k is stored at offset k * col2.
    // Inside it, each 32-column strip (starting at column c) sits at c * BLOCK_B and
    // is split into three contiguous row-major sub-panels of width 8, 12 and 12.
    // With N % 32 == 0 every sub-panel row starts on a 32-byte boundary, so the
    // micro-kernels may use aligned AVX loads on it.
    alignas(4096) double PACKED_B[M * N]{};

    // Copies a rows x cols sub-matrix of B (row stride col2) into pack_ptr as a
    // dense row-major block with row stride cols.
    void pack_b(const double *B, double *pack_ptr, size_t rows, size_t cols)
    {
#pragma GCC unroll 8
        for (size_t i = 0; i < rows; i++)
        {
            std::memcpy(pack_ptr + i * cols, B + i * col2, cols * sizeof(double));
        }
    }

    // Micro-kernel: C[0..3][0..11] += A[0..3][0..BLOCK-1] * Bp.
    //   A  : row stride col1; one scalar per row is broadcast each k step.
    //   B  : packed BLOCK x 12 row-major panel (row stride 12, 32-byte aligned).
    //   C  : row stride col2; may have any alignment.
    // The 12 accumulators are intended to stay in ymm registers across the k loop.
    template <int BLOCK>
    static void add_dot_4x12(const double *A, const double *B, double *C)
    {
        // C belongs to the caller and is not guaranteed to be 32-byte aligned,
        // so unaligned loads/stores are used (typically free on aligned data).
        __m256d C0 = _mm256_loadu_pd(C);
        __m256d C1 = _mm256_loadu_pd(C + 4);
        __m256d C2 = _mm256_loadu_pd(C + 8);
        __m256d C3 = _mm256_loadu_pd(C + col2);
        __m256d C4 = _mm256_loadu_pd(C + col2 + 4);
        __m256d C5 = _mm256_loadu_pd(C + col2 + 8);
        __m256d C6 = _mm256_loadu_pd(C + col2 * 2);
        __m256d C7 = _mm256_loadu_pd(C + col2 * 2 + 4);
        __m256d C8 = _mm256_loadu_pd(C + col2 * 2 + 8);
        __m256d C9 = _mm256_loadu_pd(C + col2 * 3);
        __m256d C10 = _mm256_loadu_pd(C + col2 * 3 + 4);
        __m256d C11 = _mm256_loadu_pd(C + col2 * 3 + 8);
#pragma GCC unroll 8
        for (int i = 0; i < BLOCK; i++)
        {
            // Only the 12 values of the current packed row are read; reading
            // further would step into the next row / past the end of the panel.
            const double *b_ptr = B + i * 12;
            __m256d B0 = _mm256_load_pd(b_ptr);
            __m256d B1 = _mm256_load_pd(b_ptr + 4);
            __m256d B2 = _mm256_load_pd(b_ptr + 8);
            __m256d A0 = _mm256_set1_pd(A[i]);
            C0 = _mm256_fmadd_pd(A0, B0, C0);
            C1 = _mm256_fmadd_pd(A0, B1, C1);
            C2 = _mm256_fmadd_pd(A0, B2, C2);
            A0 = _mm256_set1_pd(A[i + col1]);
            C3 = _mm256_fmadd_pd(A0, B0, C3);
            C4 = _mm256_fmadd_pd(A0, B1, C4);
            C5 = _mm256_fmadd_pd(A0, B2, C5);
            A0 = _mm256_set1_pd(A[i + col1 * 2]);
            C6 = _mm256_fmadd_pd(A0, B0, C6);
            C7 = _mm256_fmadd_pd(A0, B1, C7);
            C8 = _mm256_fmadd_pd(A0, B2, C8);
            A0 = _mm256_set1_pd(A[i + col1 * 3]);
            C9 = _mm256_fmadd_pd(A0, B0, C9);
            C10 = _mm256_fmadd_pd(A0, B1, C10);
            C11 = _mm256_fmadd_pd(A0, B2, C11);
        }
        _mm256_storeu_pd(C, C0);
        _mm256_storeu_pd(C + 4, C1);
        _mm256_storeu_pd(C + 8, C2);
        _mm256_storeu_pd(C + col2, C3);
        _mm256_storeu_pd(C + col2 + 4, C4);
        _mm256_storeu_pd(C + col2 + 8, C5);
        _mm256_storeu_pd(C + col2 * 2, C6);
        _mm256_storeu_pd(C + col2 * 2 + 4, C7);
        _mm256_storeu_pd(C + col2 * 2 + 8, C8);
        _mm256_storeu_pd(C + col2 * 3, C9);
        _mm256_storeu_pd(C + col2 * 3 + 4, C10);
        _mm256_storeu_pd(C + col2 * 3 + 8, C11);
    }

    // Micro-kernel: C[0..3][0..7] += A[0..3][0..BLOCK-1] * Bp.
    //   A  : row stride col1.
    //   B  : packed BLOCK x 8 row-major panel (row stride 8, 32-byte aligned).
    //   C  : row stride col2; may have any alignment.
    template <int BLOCK>
    static void add_dot_4x8(const double *A, const double *B, double *C)
    {
        // Unaligned access for caller-owned C (see add_dot_4x12).
        __m256d C0 = _mm256_loadu_pd(C);
        __m256d C1 = _mm256_loadu_pd(C + 4);
        __m256d C2 = _mm256_loadu_pd(C + col2);
        __m256d C3 = _mm256_loadu_pd(C + col2 + 4);
        __m256d C4 = _mm256_loadu_pd(C + col2 * 2);
        __m256d C5 = _mm256_loadu_pd(C + col2 * 2 + 4);
        __m256d C6 = _mm256_loadu_pd(C + col2 * 3);
        __m256d C7 = _mm256_loadu_pd(C + col2 * 3 + 4);
#pragma GCC unroll 8
        for (int i = 0; i < BLOCK; i++)
        {
            const double *b_ptr = B + i * 8;
            __m256d B0 = _mm256_load_pd(b_ptr);
            __m256d B1 = _mm256_load_pd(b_ptr + 4);
            __m256d A0 = _mm256_set1_pd(A[i]);
            C0 = _mm256_fmadd_pd(A0, B0, C0);
            C1 = _mm256_fmadd_pd(A0, B1, C1);
            __m256d A1 = _mm256_set1_pd(A[i + col1]);
            C2 = _mm256_fmadd_pd(A1, B0, C2);
            C3 = _mm256_fmadd_pd(A1, B1, C3);
            __m256d A2 = _mm256_set1_pd(A[i + col1 * 2]);
            C4 = _mm256_fmadd_pd(A2, B0, C4);
            C5 = _mm256_fmadd_pd(A2, B1, C5);
            __m256d A3 = _mm256_set1_pd(A[i + col1 * 3]);
            C6 = _mm256_fmadd_pd(A3, B0, C6);
            C7 = _mm256_fmadd_pd(A3, B1, C7);
        }
        _mm256_storeu_pd(C, C0);
        _mm256_storeu_pd(C + 4, C1);
        _mm256_storeu_pd(C + col2, C2);
        _mm256_storeu_pd(C + col2 + 4, C3);
        _mm256_storeu_pd(C + col2 * 2, C4);
        _mm256_storeu_pd(C + col2 * 2 + 4, C5);
        _mm256_storeu_pd(C + col2 * 3, C6);
        _mm256_storeu_pd(C + col2 * 3 + 4, C7);
    }

    // Updates a BLOCK_A x 32 tile of C from a BLOCK_A x BLOCK_B tile of A and one
    // packed 32-column strip of B (sub-panels of width 8, 12, 12 laid out back to back).
    template <size_t BLOCK_A, size_t BLOCK_B>
    static void matmulAxBx32(const double *A, const double *B, double *C)
    {
        for (size_t i = 0; i < BLOCK_A; i += 4)
        {
            add_dot_4x8<BLOCK_B>(A + i * col1, B, C + i * col2);
            add_dot_4x12<BLOCK_B>(A + i * col1, B + BLOCK_B * 8, C + i * col2 + 8);
            add_dot_4x12<BLOCK_B>(A + i * col1, B + BLOCK_B * 20, C + i * col2 + 20);
        }
    }

    // Processes one (row block, k block) pair across all BLOCK_C columns.
    //   A     : points at A[row][b_col]; B : points at B[b_col][0]; C : at C[row][0].
    //   row   : first row of this row block; B is packed only when row == 0, which
    //           relies on matmul() visiting row block 0 first for every k block.
    //   b_col : first k index of this k block; selects the PACKED_B panel.
    template <size_t BLOCK_A, size_t BLOCK_B, size_t BLOCK_C>
    void matmulBLKABC_x32(const double *A, const double *B, double *C, size_t row, size_t b_col)
    {
        double *pack_b_ptr = PACKED_B + b_col * col2;
        for (size_t i = 0; i < BLOCK_C; i += 32)
        {
            if (row == 0)
            {
                pack_b(B + i, pack_b_ptr + i * BLOCK_B, BLOCK_B, 8);
                pack_b(B + i + 8, pack_b_ptr + i * BLOCK_B + BLOCK_B * 8, BLOCK_B, 12);
                pack_b(B + i + 20, pack_b_ptr + i * BLOCK_B + BLOCK_B * 20, BLOCK_B, 12);
            }
            matmulAxBx32<BLOCK_A, BLOCK_B>(A, pack_b_ptr + i * BLOCK_B, C + i);
        }
    }

    // Entry point: C += A * B. C is accumulated into, not overwritten.
    // A, B and C need no particular alignment.
    void matmul(const double *A, const double *B, double *C)
    {
        constexpr size_t BLOCK_C = col2;
        // Target footprint of one packed BLOCK_B x 32 strip of B.
        constexpr size_t L1_PANEL_BYTES = size_t{16} << 10;
        // Target footprint of one BLOCK_A x BLOCK_B tile of A.
        constexpr size_t L2_TILE_BYTES = size_t{256} << 10;
        constexpr size_t BLOCK_B_MAX = L1_PANEL_BYTES / (32 * sizeof(double));
        constexpr size_t BLOCK_B = M < BLOCK_B_MAX ? M : BLOCK_B_MAX;
        constexpr size_t BLOCK_A_MAX = L2_TILE_BYTES / (BLOCK_B * sizeof(double));
        constexpr size_t BLOCK_A = L < BLOCK_A_MAX ? L : BLOCK_A_MAX;

        static_assert(N % 32 == 0, "MatMulKernel: N must be a multiple of 32");
        static_assert(M % BLOCK_B == 0, "MatMulKernel: M must be a multiple of the k-block size");
        static_assert(BLOCK_A % 4 == 0, "MatMulKernel: row block must be a multiple of 4");
        static_assert(L % BLOCK_A == 0, "MatMulKernel: L must be a multiple of the row-block size");

        // k blocks outermost so each packed B panel is built once (at row block 0)
        // and then reused by every following row block.
        for (size_t k = 0; k < M; k += BLOCK_B)
        {
            for (size_t r = 0; r < L; r += BLOCK_A)
            {
                matmulBLKABC_x32<BLOCK_A, BLOCK_B, BLOCK_C>(A + r * col1 + k, B + k * col2, C + r * col2, r, k);
            }
        }
    }
};
// Shared kernel instance. Its 8 MiB packing buffer lives in static storage
// rather than on the stack.
static MatMulKernel<1024, 1024, 1024> mm;

// Accumulates the product of two n x n row-major matrices into C: C += A * B.
// NOTE(review): both paths add into C rather than overwrite it, so callers that
// want C = A * B must pass a zero-initialised C. Confirm with the caller.
// n == 1024 uses the cache-blocked AVX/FMA kernel. Any other positive n uses a
// plain scalar loop with the same semantics, instead of the kernel silently
// reading and writing outside the caller's n x n buffers. n <= 0 is a no-op.
// Not thread-safe for n == 1024: the shared kernel re-packs B into one buffer.
void matrix_multiply(int n, const double *A, const double *B, double *C)
{
    if (n == 1024)
    {
        mm.matmul(A, B, C);
        return;
    }
    if (n <= 0)
    {
        return;
    }
    const size_t dim = static_cast<size_t>(n);
    // i-k-j order keeps the innermost loop streaming along rows of B and C.
    for (size_t i = 0; i < dim; ++i)
    {
        double *c_row = C + i * dim;
        for (size_t k = 0; k < dim; ++k)
        {
            const double a = A[i * dim + k];
            const double *b_row = B + k * dim;
            for (size_t j = 0; j < dim; ++j)
            {
                c_row[j] += a * b_row[j];
            }
        }
    }
}
// Online-judge result for this submission:
// Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// Testcase #1 | 64.162 ms | 16 MB + 40 KB | Accepted | Score: 100 | Show more |