提交记录 20442


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY mmmd1k. 测测你的双精度矩阵乘法-1k Accepted 100 46.396 ms 16424 KB C++14 6.27 KB
提交时间 评测时间
2023-10-16 19:31:19 2023-10-16 19:31:22
#include <iostream>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <immintrin.h>
#pragma GCC target("fma")
#pragma GCC optimize("inline")
// #pragma GCC optimize("Ofast")

// Multi-level cache-blocked double-precision GEMM for fixed compile-time sizes.
// Computes C += A * B (C is accumulated into, never cleared), where A is L x M,
// B is M x N and C is L x N, all row-major with row stride equal to their
// column count. Blocking is intended to exploit the L1/L2/L3 caches; see
// matmul() for how the block sizes are chosen.
template <size_t L, size_t M, size_t N>
struct MatMulKernel
{
    enum
    {
        row1 = L,
        col1 = M,
        row2 = M,
        col2 = N
    };
    // Packed copy of B. The BLOCK_B-row slice of B starting at row k lives at
    // offset k * col2 and holds BLOCK_B * col2 doubles: one BLOCK_B * 32 chunk
    // per 32 columns, each chunk split into 12-, 12- and 8-wide panels stored
    // row by row (see matmulBLKABC_x32). The 4096-byte alignment guarantees every
    // panel row starts on a 32-byte boundary, so the kernels can use aligned loads.
    alignas(4096) double PACKED_B[M * N]{};

    // Copies a rows x cols sub-matrix of B (row stride col2) into pack_ptr as a
    // dense panel with row stride cols.
    void pack_b(const double *B, double *pack_ptr, size_t rows, size_t cols)
    {
#pragma GCC unroll 8
        for (size_t i = 0; i < rows; i++)
        {
            std::memcpy(pack_ptr + i * cols, B + i * col2, cols * sizeof(double));
        }
    }

    // Micro-kernel: C[0..4)[0..12) += A[0..4)[0..BLOCK) * Bp.
    // A has row stride col1, C has row stride col2, and Bp is a packed
    // BLOCK x 12 panel (row stride 12, 32-byte aligned). The 4x12 tile of C is
    // kept in twelve YMM accumulators for the whole k loop.
    template <int BLOCK>
    static void add_dot_4x12(const double *A, const double *B, double *C)
    {
        // C is caller-provided memory with no alignment guarantee, so use
        // unaligned loads/stores (typically no extra cost on aligned data).
        __m256d C0 = _mm256_loadu_pd(C);
        __m256d C1 = _mm256_loadu_pd(C + 4);
        __m256d C2 = _mm256_loadu_pd(C + 8);

        __m256d C3 = _mm256_loadu_pd(C + col2);
        __m256d C4 = _mm256_loadu_pd(C + col2 + 4);
        __m256d C5 = _mm256_loadu_pd(C + col2 + 8);

        __m256d C6 = _mm256_loadu_pd(C + col2 * 2);
        __m256d C7 = _mm256_loadu_pd(C + col2 * 2 + 4);
        __m256d C8 = _mm256_loadu_pd(C + col2 * 2 + 8);

        __m256d C9 = _mm256_loadu_pd(C + col2 * 3);
        __m256d C10 = _mm256_loadu_pd(C + col2 * 3 + 4);
        __m256d C11 = _mm256_loadu_pd(C + col2 * 3 + 8);
#pragma GCC unroll 8
        for (int i = 0; i < BLOCK; i++)
        {
            // Packed rows are 12 doubles (96 bytes) from an aligned base, so
            // aligned loads are valid here. Only 12 values per row are needed.
            const double *b_ptr = B + i * 12;
            __m256d B0 = _mm256_load_pd(b_ptr);
            __m256d B1 = _mm256_load_pd(b_ptr + 4);
            __m256d B2 = _mm256_load_pd(b_ptr + 8);

            __m256d A0 = _mm256_set1_pd(A[i]);
            C0 = _mm256_fmadd_pd(A0, B0, C0);
            C1 = _mm256_fmadd_pd(A0, B1, C1);
            C2 = _mm256_fmadd_pd(A0, B2, C2);

            A0 = _mm256_set1_pd(A[i + col1]);
            C3 = _mm256_fmadd_pd(A0, B0, C3);
            C4 = _mm256_fmadd_pd(A0, B1, C4);
            C5 = _mm256_fmadd_pd(A0, B2, C5);

            A0 = _mm256_set1_pd(A[i + col1 * 2]);
            C6 = _mm256_fmadd_pd(A0, B0, C6);
            C7 = _mm256_fmadd_pd(A0, B1, C7);
            C8 = _mm256_fmadd_pd(A0, B2, C8);

            A0 = _mm256_set1_pd(A[i + col1 * 3]);
            C9 = _mm256_fmadd_pd(A0, B0, C9);
            C10 = _mm256_fmadd_pd(A0, B1, C10);
            C11 = _mm256_fmadd_pd(A0, B2, C11);
        }
        _mm256_storeu_pd(C, C0);
        _mm256_storeu_pd(C + 4, C1);
        _mm256_storeu_pd(C + 8, C2);

        _mm256_storeu_pd(C + col2, C3);
        _mm256_storeu_pd(C + col2 + 4, C4);
        _mm256_storeu_pd(C + col2 + 8, C5);

        _mm256_storeu_pd(C + col2 * 2, C6);
        _mm256_storeu_pd(C + col2 * 2 + 4, C7);
        _mm256_storeu_pd(C + col2 * 2 + 8, C8);

        _mm256_storeu_pd(C + col2 * 3, C9);
        _mm256_storeu_pd(C + col2 * 3 + 4, C10);
        _mm256_storeu_pd(C + col2 * 3 + 8, C11);
    }

    // Micro-kernel: C[0..4)[0..8) += A[0..4)[0..BLOCK) * Bp, with Bp a packed
    // BLOCK x 8 panel (row stride 8, 32-byte aligned). Same strides as add_dot_4x12.
    template <int BLOCK>
    static void add_dot_4x8(const double *A, const double *B, double *C)
    {
        // Unaligned access for caller-provided C (see add_dot_4x12).
        __m256d C0 = _mm256_loadu_pd(C);
        __m256d C1 = _mm256_loadu_pd(C + 4);
        __m256d C2 = _mm256_loadu_pd(C + col2);
        __m256d C3 = _mm256_loadu_pd(C + col2 + 4);
        __m256d C4 = _mm256_loadu_pd(C + col2 * 2);
        __m256d C5 = _mm256_loadu_pd(C + col2 * 2 + 4);
        __m256d C6 = _mm256_loadu_pd(C + col2 * 3);
        __m256d C7 = _mm256_loadu_pd(C + col2 * 3 + 4);
#pragma GCC unroll 8
        for (int i = 0; i < BLOCK; i++)
        {
            const double *b_ptr = B + i * 8;
            __m256d B0 = _mm256_load_pd(b_ptr);
            __m256d B1 = _mm256_load_pd(b_ptr + 4);

            __m256d A0 = _mm256_set1_pd(A[i]);
            C0 = _mm256_fmadd_pd(A0, B0, C0);
            C1 = _mm256_fmadd_pd(A0, B1, C1);

            __m256d A1 = _mm256_set1_pd(A[i + col1]);
            C2 = _mm256_fmadd_pd(A1, B0, C2);
            C3 = _mm256_fmadd_pd(A1, B1, C3);

            __m256d A2 = _mm256_set1_pd(A[i + col1 * 2]);
            C4 = _mm256_fmadd_pd(A2, B0, C4);
            C5 = _mm256_fmadd_pd(A2, B1, C5);

            __m256d A3 = _mm256_set1_pd(A[i + col1 * 3]);
            C6 = _mm256_fmadd_pd(A3, B0, C6);
            C7 = _mm256_fmadd_pd(A3, B1, C7);
        }
        _mm256_storeu_pd(C, C0);
        _mm256_storeu_pd(C + 4, C1);
        _mm256_storeu_pd(C + col2, C2);
        _mm256_storeu_pd(C + col2 + 4, C3);
        _mm256_storeu_pd(C + col2 * 2, C4);
        _mm256_storeu_pd(C + col2 * 2 + 4, C5);
        _mm256_storeu_pd(C + col2 * 3, C6);
        _mm256_storeu_pd(C + col2 * 3 + 4, C7);
    }

    // Multiplies a BLOCK_A x BLOCK_B block of A by one packed BLOCK_B x 32 chunk
    // of B (laid out as 12/12/8-wide panels), accumulating into a BLOCK_A x 32
    // tile of C, four rows at a time.
    template <size_t BLOCK_A, size_t BLOCK_B>
    static void matmulAxBx32(const double *A, const double *B, double *C)
    {
        static_assert(BLOCK_A % 4 == 0, "BLOCK_A must be a multiple of the 4-row micro-tile");
        for (size_t i = 0; i < BLOCK_A; i += 4)
        {
            const double *a_rows = A + i * col1;
            double *c_rows = C + i * col2;
            add_dot_4x12<BLOCK_B>(a_rows, B, c_rows);
            add_dot_4x12<BLOCK_B>(a_rows, B + BLOCK_B * 12, c_rows + 12);
            add_dot_4x8<BLOCK_B>(a_rows, B + BLOCK_B * 24, c_rows + 24);
        }
    }

    // Multiplies a BLOCK_A x BLOCK_B block of A by the BLOCK_B x BLOCK_C slice
    // of B starting at row b_col, walking the slice in 32-column chunks.
    // The slice is packed into PACKED_B only for the first row block
    // (row == 0) and reused by every later row block of the same matmul() call.
    template <size_t BLOCK_A, size_t BLOCK_B, size_t BLOCK_C>
    void matmulBLKABC_x32(const double *A, const double *B, double *C, size_t row, size_t b_col)
    {
        double *pack_b_ptr = PACKED_B + b_col * col2;
        for (size_t i = 0; i < BLOCK_C; i += 32)
        {
            if (row == 0)
            {
                pack_b(B + i, pack_b_ptr + i * BLOCK_B, BLOCK_B, 12);
                pack_b(B + i + 12, pack_b_ptr + i * BLOCK_B + BLOCK_B * 12, BLOCK_B, 12);
                pack_b(B + i + 24, pack_b_ptr + i * BLOCK_B + BLOCK_B * 24, BLOCK_B, 8);
            }
            matmulAxBx32<BLOCK_A, BLOCK_B>(A, pack_b_ptr + i * BLOCK_B, C + i);
        }
    }

    // Top-level driver: C += A * B.
    // C must be initialized by the caller (zero it for a plain product).
    // Packs B on the first row block, so B must not change during the call.
    void matmul(const double *A, const double *B, double *C)
    {
        constexpr size_t BLOCK_C = col2;
        // A packed BLOCK_B x 32 chunk of B is 32 KiB (presumably sized for L1).
        constexpr size_t BLOCK_B = (32 << 10) / (32 * 8);
        // A BLOCK_A x BLOCK_B block of A is 128 KiB (presumably sized for L2).
        constexpr size_t BLOCK_A = (128 << 10) / (BLOCK_B * 8);
        // There is no remainder handling: the sizes must tile exactly, or the
        // kernels would read/write past the ends of A, B, C and PACKED_B.
        static_assert(L % BLOCK_A == 0, "L must be a multiple of the A row block");
        static_assert(M % BLOCK_B == 0, "M must be a multiple of the B row block");
        static_assert(N % 32 == 0, "N must be a multiple of 32");
        for (size_t j = 0; j < L; j += BLOCK_A)
        {
            for (size_t i = 0; i < M; i += BLOCK_B)
            {
                matmulBLKABC_x32<BLOCK_A, BLOCK_B, BLOCK_C>(A + j * col1 + i, B + i * col2, C + j * col2, j, i);
            }
        }
    }
};
// Side length of the square matrices the blocked kernel is instantiated for.
static constexpr int kKernelSize = 1024;
static MatMulKernel<kKernelSize, kKernelSize, kKernelSize> mm;

// Accumulates the n x n row-major product A * B into C (C += A * B). Callers
// that want a plain product must pass a zero-filled C.
// When n matches the kernel's compile-time size, the blocked AVX/FMA kernel is
// used. Any other n previously ran that fixed-size kernel out of bounds; it now
// takes a simple scalar fallback with the same accumulate semantics.
void matrix_multiply(int n, const double *A, const double *B, double *C)
{
    if (n == kKernelSize)
    {
        mm.matmul(A, B, C);
        return;
    }
    const size_t dim = n > 0 ? static_cast<size_t>(n) : 0;
    // i-k-j order streams rows of B and C contiguously.
    for (size_t i = 0; i < dim; i++)
    {
        double *c_row = C + i * dim;
        for (size_t k = 0; k < dim; k++)
        {
            const double a = A[i * dim + k];
            const double *b_row = B + k * dim;
            for (size_t j = 0; j < dim; j++)
            {
                c_row[j] += a * b_row[j];
            }
        }
    }
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 46.396 ms | 16 MB + 40 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2025-01-18 21:14:37 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠