提交记录 20196


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY mmmd1k. 测测你的双精度矩阵乘法-1k Accepted 100 50.296 ms 16424 KB C++14 3.60 KB
提交时间 评测时间
2023-09-17 23:07:53 2023-09-17 23:07:56
#include <iostream>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <immintrin.h>
#pragma GCC target("fma")
#pragma GCC optimize("inline")
#pragma GCC optimize("O3")

/// Blocked AVX/FMA double-precision matrix-multiply kernel for compile-time
/// sizes. matmul() accumulates
///     C (L x N) += A (L x M) * B (M x N)
/// with all three matrices row-major and densely stored (row strides M, N, N).
/// C is read as well as written, so pass a zero-filled C for a plain product.
///
/// Tiling (see matmul()):
///   * the k dimension is split into panels of block_b rows of B; the first
///     row block that touches a panel repacks it into PACKED_B as contiguous
///     block_b x 16 strips, and later row blocks reuse the packed copy;
///   * rows of A/C are processed block_a at a time, two rows per call to the
///     2x16 micro-kernel add_dot_2x16().
///
/// Size requirements are enforced by the static_asserts below.
/// Not thread-safe: PACKED_B is per-object scratch rewritten by matmul(), so
/// concurrent matmul() calls on the same object race.
template <size_t L, size_t M, size_t N>
struct MatMulKernel
{
    enum
    {
        row1 = L,
        col1 = M,
        row2 = M,
        col2 = N,
        // Rows of B per k-panel: one block_b x 16 strip of doubles is 32 KiB.
        block_b = (32 << 10) / (16 * 8),
        // Rows of A per row block: block_a x block_b doubles of A is 128 KiB.
        block_a = (128 << 10) / (block_b * 8)
    };
    // None of the loops below handle a partial tile, so every dimension must
    // divide evenly into its tile size (block_a is also even, as the
    // micro-kernel consumes rows in pairs).
    static_assert(L % block_a == 0, "MatMulKernel: L must be a multiple of block_a (64)");
    static_assert(M % block_b == 0, "MatMulKernel: M must be a multiple of block_b (256)");
    static_assert(N % 16 == 0, "MatMulKernel: N must be a multiple of 16");

    // Packed copy of B: for each k-panel, N / 16 strips of block_b x 16
    // doubles. 4096-byte alignment keeps every strip 32-byte aligned, which
    // the aligned B loads in add_dot_2x16() rely on.
    alignas(4096) double PACKED_B[M * N]{};

    /// Copies a rows x cols sub-block of B (source row stride col2) into
    /// pack_ptr as a dense rows x cols array (row stride cols).
    void pack_b(const double *B, double *pack_ptr, size_t rows, size_t cols)
    {
        for (size_t i = 0; i < rows; i++)
        {
            std::memcpy(pack_ptr + i * cols, B + i * col2, cols * sizeof(double));
        }
    }

    /// Micro-kernel: C[0..1][0..15] += A[0..1][0..BLOCK) * Bstrip.
    ///   A: first of two rows, row stride col1.
    ///   B: packed BLOCK x 16 strip (row stride 16); must be 32-byte aligned.
    ///   C: first of two rows, row stride col2; any alignment is accepted.
    /// The 2x16 C tile stays in eight ymm registers for the whole BLOCK loop.
    template <int BLOCK>
    static void add_dot_2x16(const double *A, const double *B, double *C)
    {
        // C comes from the caller, so its alignment is not guaranteed: use
        // unaligned loads/stores. They execute once per BLOCK iterations, so
        // they are off the hot path.
        __m256d C0 = _mm256_loadu_pd(C);
        __m256d C1 = _mm256_loadu_pd(C + 4);
        __m256d C2 = _mm256_loadu_pd(C + 8);
        __m256d C3 = _mm256_loadu_pd(C + 12);
        __m256d C4 = _mm256_loadu_pd(C + col2);
        __m256d C5 = _mm256_loadu_pd(C + col2 + 4);
        __m256d C6 = _mm256_loadu_pd(C + col2 + 8);
        __m256d C7 = _mm256_loadu_pd(C + col2 + 12);
#pragma GCC unroll 8
        for (int i = 0; i < BLOCK; i++)
        {
            // B is the packed, aligned strip: aligned loads are safe here.
            const double *b_ptr = B + i * 16;
            __m256d B0 = _mm256_load_pd(b_ptr);
            __m256d B1 = _mm256_load_pd(b_ptr + 4);
            __m256d B2 = _mm256_load_pd(b_ptr + 8);
            __m256d B3 = _mm256_load_pd(b_ptr + 12);

            __m256d A0 = _mm256_set1_pd(A[i]);
            C0 = _mm256_fmadd_pd(A0, B0, C0);
            C1 = _mm256_fmadd_pd(A0, B1, C1);
            C2 = _mm256_fmadd_pd(A0, B2, C2);
            C3 = _mm256_fmadd_pd(A0, B3, C3);

            __m256d A1 = _mm256_set1_pd(A[i + col1]);
            C4 = _mm256_fmadd_pd(A1, B0, C4);
            C5 = _mm256_fmadd_pd(A1, B1, C5);
            C6 = _mm256_fmadd_pd(A1, B2, C6);
            C7 = _mm256_fmadd_pd(A1, B3, C7);
        }
        _mm256_storeu_pd(C, C0);
        _mm256_storeu_pd(C + 4, C1);
        _mm256_storeu_pd(C + 8, C2);
        _mm256_storeu_pd(C + 12, C3);
        _mm256_storeu_pd(C + col2, C4);
        _mm256_storeu_pd(C + col2 + 4, C5);
        _mm256_storeu_pd(C + col2 + 8, C6);
        _mm256_storeu_pd(C + col2 + 12, C7);
    }

    /// C[BLOCK_A x 16] += A[BLOCK_A x BLOCK_B] * (packed BLOCK_B x 16 strip),
    /// two rows of A/C per micro-kernel call.
    template <size_t BLOCK_A, size_t BLOCK_B>
    static void matmulAxBx16(const double *A, const double *B, double *C)
    {
        for (size_t i = 0; i < BLOCK_A; i += 2)
        {
            add_dot_2x16<BLOCK_B>(A + i * col1, B, C + i * col2);
        }
    }

    /// One (row block, k-panel) step: C[BLOCK_A x BLOCK_C] +=
    /// A[BLOCK_A x BLOCK_B] * B[BLOCK_B x BLOCK_C], walked in 16-column strips.
    /// Despite the name, the row block is BLOCK_A (block_a = 64) rows tall;
    /// the name is kept for compatibility.
    ///   row:   first row of this row block; the panel is packed only when 0.
    ///   b_col: first k index of this panel; selects its slot in PACKED_B.
    template <size_t BLOCK_A, size_t BLOCK_B, size_t BLOCK_C>
    void matmul128x256xBLK_C(const double *A, const double *B, double *C, size_t row, size_t b_col)
    {
        double *pack_b_ptr = PACKED_B + b_col * col2;
        for (size_t i = 0; i < BLOCK_C; i += 16)
        {
            // matmul() visits row 0 first for every panel, so the packed strip
            // is ready before any later row block reads it.
            if (row == 0)
            {
                pack_b(B + i, pack_b_ptr + i * BLOCK_B, BLOCK_B, 16);
            }
            matmulAxBx16<BLOCK_A, BLOCK_B>(A, pack_b_ptr + i * BLOCK_B, C + i);
        }
    }

    /// Accumulates C += A * B (see the struct comment for layouts).
    /// k-panels form the outer loop and row blocks the inner one, so each
    /// panel of B is packed once and reused for every row block.
    void matmul(const double *A, const double *B, double *C)
    {
        for (size_t k = 0; k < M; k += block_b)
        {
            for (size_t r = 0; r < L; r += block_a)
            {
                matmul128x256xBLK_C<block_a, block_b, col2>(A + r * col1 + k, B + k * col2, C + r * col2, r, k);
            }
        }
    }
};
// Square problem size the blocked fast path is compiled for.
constexpr int kKernelN = 1024;

// Single kernel instance for the kKernelN^3 problem. It owns an 8 MiB
// PACKED_B buffer, far larger than a typical stack, so it lives in static
// storage.
static MatMulKernel<kKernelN, kKernelN, kKernelN> mm;

/// Computes C += A * B for n x n row-major matrices.
///
/// For n == kKernelN this dispatches to the blocked AVX/FMA kernel `mm`.
/// Any other n takes a plain scalar path instead of running the fixed-size
/// kernel out of bounds. Both paths accumulate into C, so the caller must pass
/// a zero-filled C to obtain the plain product A * B.
/// NOTE(review): presumably the judge harness zero-fills C; confirm.
///
/// Not thread-safe on the fast path: `mm` holds shared packing scratch.
void matrix_multiply(int n, const double *A, const double *B, double *C)
{
    if (n == kKernelN)
    {
        mm.matmul(A, B, C);
        return;
    }
    // Fallback for sizes the kernel was not instantiated for. The i-p-j loop
    // order keeps the B and C accesses unit-stride.
    for (int i = 0; i < n; ++i)
    {
        const double *a_row = A + static_cast<size_t>(i) * n;
        double *c_row = C + static_cast<size_t>(i) * n;
        for (int p = 0; p < n; ++p)
        {
            const double a = a_row[p];
            const double *b_row = B + static_cast<size_t>(p) * n;
            for (int j = 0; j < n; ++j)
            {
                c_row[j] += a * b_row[j];
            }
        }
    }
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 50.296 ms | 16 MB + 40 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2024-07-27 15:10:38 | Loaded in 0 ms | Server Status
个人娱乐项目,仅供学习交流使用