提交记录 20116


用户 题目 状态 得分 用时 内存 语言 代码长度
TSKY mmmd1k. 测测你的双精度矩阵乘法-1k Accepted 100 437.463 ms 8200 KB C++14 2.48 KB
提交时间 评测时间
2023-09-05 21:48:50 2023-09-05 21:48:54
#include <immintrin.h>
#pragma GCC target("fma")
#pragma GCC target("avx2")
using namespace std;
/**
 * @brief Compile-time-sized double-precision multiply-accumulate kernel: C += A * B.
 *
 * A is L x M, B is M x N and C is L x N. All three are indexed as dense
 * row-major arrays (row stride M for A, N for B and C). The current contents
 * of C seed every accumulator, so a caller that wants the plain product must
 * clear C beforehand.
 *
 * @tparam L Rows of A and C. Any value; an odd trailing row is handled on its own.
 * @tparam M Columns of A, which is also the number of rows of B.
 * @tparam N Columns of B and C. Must be a multiple of 16.
 */
template <size_t L, size_t M, size_t N>
struct MatMulKernel
{
    enum
    {
        row1 = L, // rows of A (and of C)
        col1 = M, // columns of A
        row2 = M, // rows of B
        col2 = N  // columns of B (and of C)
    };
    // The message is spelled out because the message-less form is C++17-only.
    static_assert(N % 16 == 0, "N must be a multiple of 16 (one panel is four 4-double AVX vectors)");

    /**
     * @brief Updates one 16-column panel of C with A times the matching panel of B.
     *
     * @param A Start of the full L x M matrix A.
     * @param B First element of the 16-column panel of B (rows are N doubles apart).
     * @param C First element of the 16-column panel of C (rows are N doubles apart).
     *
     * Rows are processed two at a time so each loaded panel row of B feeds eight
     * FMAs. Unaligned load/store forms are used because nothing in the pointer
     * types promises 32-byte alignment. The target attribute lets the intrinsics
     * compile even where a surrounding `#pragma GCC target` is not honoured.
     */
    [[gnu::target("avx2,fma")]]
    static void matmulLxMx16(const double *A, const double *B, double *C)
    {
        // Rows that can be handled as complete pairs; a leftover row (odd L)
        // is processed after the loop instead of touching row L out of bounds.
        const size_t pairedRows = L - L % 2;

        for (size_t i = 0; i < pairedRows; i += 2)
        {
            const double *a_row0 = A + i * M;
            const double *a_row1 = a_row0 + M;
            double *c_row0 = C + i * N;
            double *c_row1 = c_row0 + N;

            // Accumulators start from what is already stored in C.
            __m256d c00 = _mm256_loadu_pd(c_row0);
            __m256d c01 = _mm256_loadu_pd(c_row0 + 4);
            __m256d c02 = _mm256_loadu_pd(c_row0 + 8);
            __m256d c03 = _mm256_loadu_pd(c_row0 + 12);
            __m256d c10 = _mm256_loadu_pd(c_row1);
            __m256d c11 = _mm256_loadu_pd(c_row1 + 4);
            __m256d c12 = _mm256_loadu_pd(c_row1 + 8);
            __m256d c13 = _mm256_loadu_pd(c_row1 + 12);

            for (size_t j = 0; j < M; ++j)
            {
                const double *b_row = B + j * N;
                // Broadcast one scalar of A per output row across a vector.
                const __m256d a0 = _mm256_set1_pd(a_row0[j]);
                const __m256d a1 = _mm256_set1_pd(a_row1[j]);
                const __m256d b0 = _mm256_loadu_pd(b_row);
                const __m256d b1 = _mm256_loadu_pd(b_row + 4);
                const __m256d b2 = _mm256_loadu_pd(b_row + 8);
                const __m256d b3 = _mm256_loadu_pd(b_row + 12);
                c00 = _mm256_fmadd_pd(a0, b0, c00);
                c01 = _mm256_fmadd_pd(a0, b1, c01);
                c02 = _mm256_fmadd_pd(a0, b2, c02);
                c03 = _mm256_fmadd_pd(a0, b3, c03);
                c10 = _mm256_fmadd_pd(a1, b0, c10);
                c11 = _mm256_fmadd_pd(a1, b1, c11);
                c12 = _mm256_fmadd_pd(a1, b2, c12);
                c13 = _mm256_fmadd_pd(a1, b3, c13);
            }

            _mm256_storeu_pd(c_row0, c00);
            _mm256_storeu_pd(c_row0 + 4, c01);
            _mm256_storeu_pd(c_row0 + 8, c02);
            _mm256_storeu_pd(c_row0 + 12, c03);
            _mm256_storeu_pd(c_row1, c10);
            _mm256_storeu_pd(c_row1 + 4, c11);
            _mm256_storeu_pd(c_row1 + 8, c12);
            _mm256_storeu_pd(c_row1 + 12, c13);
        }

        // Compile-time condition: only instantiations with odd L keep this tail.
        if (L % 2 != 0)
        {
            const double *a_row = A + pairedRows * M;
            double *c_row = C + pairedRows * N;

            __m256d c0 = _mm256_loadu_pd(c_row);
            __m256d c1 = _mm256_loadu_pd(c_row + 4);
            __m256d c2 = _mm256_loadu_pd(c_row + 8);
            __m256d c3 = _mm256_loadu_pd(c_row + 12);

            for (size_t j = 0; j < M; ++j)
            {
                const double *b_row = B + j * N;
                const __m256d a = _mm256_set1_pd(a_row[j]);
                c0 = _mm256_fmadd_pd(a, _mm256_loadu_pd(b_row), c0);
                c1 = _mm256_fmadd_pd(a, _mm256_loadu_pd(b_row + 4), c1);
                c2 = _mm256_fmadd_pd(a, _mm256_loadu_pd(b_row + 8), c2);
                c3 = _mm256_fmadd_pd(a, _mm256_loadu_pd(b_row + 12), c3);
            }

            _mm256_storeu_pd(c_row, c0);
            _mm256_storeu_pd(c_row + 4, c1);
            _mm256_storeu_pd(c_row + 8, c2);
            _mm256_storeu_pd(c_row + 12, c3);
        }
    }

    /**
     * @brief Computes C += A * B by walking C in 16-column panels.
     *
     * @tparam Ty Element type. The panel kernel takes `double` pointers, so
     *            only `double` instantiates successfully.
     * @param A Row-major L x M input.
     * @param B Row-major M x N input.
     * @param C Row-major L x N accumulator, updated in place.
     */
    template <typename Ty>
    static void matmul(const Ty *A, const Ty *B, Ty *C)
    {
        for (size_t panel = 0; panel < N; panel += 16)
        {
            matmulLxMx16(A, B + panel, C + panel);
        }
    }
};
void matrix_multiply(int n, const double *A, const double *B, double *C) {
	MatMulKernel<1024,1024,1024>::matmul(A,B,C);
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 437.463 ms | 8 MB + 8 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2025-09-14 00:08:13 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠