提交记录 20122


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY mmmd1k. 测测你的双精度矩阵乘法-1k Accepted 100 147.246 ms 8224 KB C++14 4.32 KB
提交时间 评测时间
2023-09-06 23:21:56 2023-09-06 23:22:00
#include <iostream>
#include <chrono>
#include <cstdint>
#include <immintrin.h>
#pragma GCC target("fma")
#pragma GCC target("avx2")
using namespace std;
/// Compile-time-sized dense matrix kernel computing C += A * B with AVX2/FMA.
///
/// A is L x M, B is M x N and C is L x N; all three are dense row-major arrays
/// of `double` without padding between rows. The product is accumulated into
/// whatever C already holds, so a caller that wants the plain product must pass
/// a zero-filled C.
///
/// The micro-kernel handles 2 rows of C, 4 steps of the inner dimension and
/// 16 columns at a time, which is where the divisibility requirements come
/// from. All vector loads/stores use the unaligned forms, so the buffers do not
/// have to be 32-byte aligned.
template <std::size_t L, std::size_t M, std::size_t N>
struct MatMulKernel
{
    enum
    {
        row1 = L, // rows of A (and of C)
        col1 = M, // columns of A
        row2 = M, // rows of B
        col2 = N  // columns of B (and of C)
    };
    // Without these the tiled loops would run past the end of the matrices.
    // (The message form of static_assert is also valid before C++17.)
    static_assert(L % 2 == 0, "MatMulKernel: L must be a multiple of 2");
    static_assert(M % 4 == 0, "MatMulKernel: M must be a multiple of 4");
    static_assert(N % 16 == 0, "MatMulKernel: N must be a multiple of 16");

    /// Micro-kernel: C[0..1][0..15] += A[0..1][0..3] * B[0..3][0..15].
    ///
    /// @param A  top-left element of a 2x4 tile of A (row stride M)
    /// @param B  top-left element of a 4x16 tile of B (row stride N)
    /// @param C  top-left element of a 2x16 tile of C (row stride N);
    ///           read, updated and written back
    [[gnu::target("avx2,fma")]]
    static void matmul2x4x16(const double *A, const double *B, double *C)
    {
        constexpr std::size_t lda = M;
        constexpr std::size_t ldb = N;
        constexpr std::size_t ldc = N;

        double *const c_top = C;
        double *const c_bot = C + ldc;

        // Eight accumulators: 16 columns (4 vectors) for each of the 2 rows.
        __m256d t0 = _mm256_loadu_pd(c_top);
        __m256d t1 = _mm256_loadu_pd(c_top + 4);
        __m256d t2 = _mm256_loadu_pd(c_top + 8);
        __m256d t3 = _mm256_loadu_pd(c_top + 12);
        __m256d u0 = _mm256_loadu_pd(c_bot);
        __m256d u1 = _mm256_loadu_pd(c_bot + 4);
        __m256d u2 = _mm256_loadu_pd(c_bot + 8);
        __m256d u3 = _mm256_loadu_pd(c_bot + 12);

        // Four inner-dimension steps, applied in increasing order so every
        // element of C sees the same sequence of fused multiply-adds.
        fma_step(A[0], A[lda + 0], B + 0 * ldb, t0, t1, t2, t3, u0, u1, u2, u3);
        fma_step(A[1], A[lda + 1], B + 1 * ldb, t0, t1, t2, t3, u0, u1, u2, u3);
        fma_step(A[2], A[lda + 2], B + 2 * ldb, t0, t1, t2, t3, u0, u1, u2, u3);
        fma_step(A[3], A[lda + 3], B + 3 * ldb, t0, t1, t2, t3, u0, u1, u2, u3);

        _mm256_storeu_pd(c_top, t0);
        _mm256_storeu_pd(c_top + 4, t1);
        _mm256_storeu_pd(c_top + 8, t2);
        _mm256_storeu_pd(c_top + 12, t3);
        _mm256_storeu_pd(c_bot, u0);
        _mm256_storeu_pd(c_bot + 4, u1);
        _mm256_storeu_pd(c_bot + 8, u2);
        _mm256_storeu_pd(c_bot + 12, u3);
    }

    /// Full product: C += A * B for the L x M, M x N and L x N matrices.
    ///
    /// @param A  L x M input, row-major
    /// @param B  M x N input, row-major
    /// @param C  L x N accumulator, row-major; updated in place
    [[gnu::target("avx2,fma")]]
    static void matmul(const double *A, const double *B, double *C)
    {
        // Offsets are computed in std::size_t so large dimensions cannot
        // overflow int arithmetic.
        constexpr std::size_t rows = L;
        constexpr std::size_t inner = M;
        constexpr std::size_t cols = N;

        for (std::size_t r = 0; r < rows; r += 2)
        {
            double *const c_rows = C + r * cols;
            for (std::size_t k = 0; k < inner; k += 4)
            {
                const double *const a_tile = A + r * inner + k;
                const double *const b_rows = B + k * cols;
                for (std::size_t c = 0; c < cols; c += 16)
                {
                    matmul2x4x16(a_tile, b_rows + c, c_rows + c);
                }
            }
        }
    }

private:
    /// One inner-dimension step of the micro-kernel: adds a_top * b_row[0..15]
    /// to the top-row accumulators and a_bot * b_row[0..15] to the bottom-row
    /// ones. Forced inline so the accumulators stay in registers.
    [[gnu::always_inline, gnu::target("avx2,fma")]]
    static inline void fma_step(double a_top, double a_bot, const double *b_row,
                                __m256d &t0, __m256d &t1, __m256d &t2, __m256d &t3,
                                __m256d &u0, __m256d &u1, __m256d &u2, __m256d &u3)
    {
        const __m256d va_top = _mm256_set1_pd(a_top);
        const __m256d va_bot = _mm256_set1_pd(a_bot);
        const __m256d b0 = _mm256_loadu_pd(b_row);
        const __m256d b1 = _mm256_loadu_pd(b_row + 4);
        const __m256d b2 = _mm256_loadu_pd(b_row + 8);
        const __m256d b3 = _mm256_loadu_pd(b_row + 12);
        t0 = _mm256_fmadd_pd(va_top, b0, t0);
        t1 = _mm256_fmadd_pd(va_top, b1, t1);
        t2 = _mm256_fmadd_pd(va_top, b2, t2);
        t3 = _mm256_fmadd_pd(va_top, b3, t3);
        u0 = _mm256_fmadd_pd(va_bot, b0, u0);
        u1 = _mm256_fmadd_pd(va_bot, b1, u1);
        u2 = _mm256_fmadd_pd(va_bot, b2, u2);
        u3 = _mm256_fmadd_pd(va_bot, b3, u3);
    }
};
/// Judge entry point: accumulates the product of two n x n row-major matrices
/// into C, i.e. C += A * B. C is read as well as written, so it must already
/// hold the values to add to (all zeros for a plain product).
///
/// @param n  matrix dimension; 1024 takes the vectorised fixed-size kernel,
///           any other value takes a plain scalar loop
/// @param A  n x n left operand
/// @param B  n x n right operand
/// @param C  n x n accumulator, updated in place
void matrix_multiply(int n, const double *A, const double *B, double *C) {
	if (n == 1024) {
		MatMulKernel<1024, 1024, 1024>::matmul(A, B, C);
		return;
	}
	// The fixed-size kernel would touch 1024 x 1024 elements regardless of n,
	// so every other size goes through a generic loop with the same
	// accumulate-into-C semantics. Non-positive n does nothing.
	const std::size_t dim = n > 0 ? static_cast<std::size_t>(n) : 0;
	for (std::size_t i = 0; i < dim; ++i) {
		for (std::size_t k = 0; k < dim; ++k) {
			const double a = A[i * dim + k];
			for (std::size_t j = 0; j < dim; ++j) {
				C[i * dim + j] += a * B[k * dim + j];
			}
		}
	}
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 147.246 ms | 8 MB + 32 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2025-09-13 22:26:43 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠