提交记录 17522


用户 题目 状态 得分 用时 内存 语言 代码长度
yfzcsc mmmd1k. 测测你的双精度矩阵乘法-1k Compile Error 0 0 ns 0 KB C++ 7.10 KB
提交时间 评测时间
2022-03-22 00:21:07 2022-03-22 00:21:08
#pragma GCC target("avx2,fma")
#pragma GCC optimize("Ofast")
#include <assert.h>
#include <immintrin.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define CACHELINE 128  // cache-line stride (bytes) for the disabled stack-alignment trick below

/* Scratch-buffer size in doubles (64 Ki doubles = 512 KiB).
 * Parenthesized: the original `1 << 16` would silently mis-expand in any
 * surrounding expression with higher-precedence operators (e.g. `CACHESIZE * 2`). */
#define CACHESIZE (1 << 16)

// Half-open counted loop: `i` runs over [L, R) in steps of STEP.
#define fo(i, L, R, STEP) for (int i = L, i##__end = R; i < i##__end; i += STEP)
#define fo1(i, L, R) fo(i, L, R, 1)
// NOTE: both arguments are evaluated twice — keep side effects out.
#define min(x, y) ((x) < (y) ? (x) : (y))
// Right edge of a block starting at x: clamp x+STEP to MAX.
#define getR(x, STEP, MAX) min((x) + (STEP), (MAX))

// AVX2 double-precision wrappers; loads/stores require 32-byte alignment.
#define FMLA(result, a, b, c) ((result) = _mm256_fmadd_pd((a), (b), (c)))
#define FLOAD(where) _mm256_load_pd(where)
#define FSTORE(where, what) _mm256_store_pd(where, what)
#define FBROADCAST(where) _mm256_broadcast_sd(where)
// NOTE(review): ARM NEON leftover (vld2q_dup_f32) — unused, and would not
// compile on x86 if ever expanded. Kept only so existing references still work.
#define FBROADCAST2(where) vld2q_dup_f32(where)

void inline mcpy(double* dst, const double* src, int len) {
    memcpy(dst, src, len);
    return;
    len >>= 2;
    int i = 0;
    for (; i < len; i += 32) {
        __m256d a0 = FLOAD(src);
        __m256d a1 = FLOAD(src + 4);
        __m256d a2 = FLOAD(src + 8);
        __m256d a3 = FLOAD(src + 12);
        __m256d a4 = FLOAD(src + 16);
        __m256d a5 = FLOAD(src + 20);
        __m256d a6 = FLOAD(src + 24);
        __m256d a7 = FLOAD(src + 28);
        FSTORE(dst, a0);
        FSTORE(dst + 4, a1);
        FSTORE(dst + 8, a2);
        FSTORE(dst + 12, a3);
        FSTORE(dst + 16, a4);
        FSTORE(dst + 20, a5);
        FSTORE(dst + 24, a6);
        FSTORE(dst + 28, a7);
        src += 32;
        dst += 32;
    }
    for (; i < len; ++i) *dst++ = *src++;
}

/*
 * third_try: cache-blocked double-precision GEMM kernel (AVX2 + FMA).
 *
 * Literal computation performed by the inner kernel:
 *     C[j*n + i] += sum_k A[k*n + i] * B[j*n + k]
 * NOTE(review): whether this matches the judge's "C = A*B" depends on its
 * row/column-major convention — confirm against the problem statement.
 *
 * Strategy: A and B are repacked into the 32-byte-aligned stack buffer
 * `cache` (zero-padded at block edges) so the register tile streams over
 * contiguous memory; a jd x id (4 rows x 16 columns) tile of C is held in
 * 16 YMM accumulators across the packed K dimension.
 *
 * Fixes vs. the submitted version (which received a Compile Error):
 *  - CAL4_K8/CAL4_DEF_K8 expand CAL(...) with AID(k, col) for k = 0..3,
 *    but only LD_A4(0) and LD_A4(1) were issued, leaving A_2_* / A_3_*
 *    undeclared. Added LD_A4(2) and LD_A4(3).
 *  - Dropped the `register` storage class from the vector temporaries:
 *    a no-op hint on modern compilers, and ill-formed in C++17 and later.
 *
 * NOTE(review): irr/jrr are computed but unused — each tile unconditionally
 * processes id x jd elements of C, so n must be a multiple of 16 (and the
 * aligned loads need C, A, B 32-byte aligned). Confirm the judge guarantees.
 */
void third_try(int n, const double* A, const double* B, double* C) {
    // char ___x;
    // char ___str[CACHELINE - ((int)(&___x) & (CACHELINE - 1))];
    // Packed-panel scratch: _A (ic*kc doubles), _B (jc*kc), iC (remainder).
    // 512 KiB on the stack — assumes a generous stack limit.
    double __attribute__((aligned(32))) cache[CACHESIZE];

    // Outer (cache) block sizes and register-tile dimensions.
    int ic = 128, kc = 128, jc = 128;
    const int id = 16, kd = 4, jd = 4;

    double* _A = cache;            // packed panel of A
    double* _B = cache + ic * kc;  // packed panel of B
    double* iC = _B + jc * kc;     // scratch C tile (unused in this code path)

    memset(C, 0, n * n * sizeof(double));
    fo(il, 0, n, ic) {
        int ir = getR(il, ic, n);
        fo(kl, 0, n, kc) {
            int kr = getR(kl, kc, n);
            // Pack A([kl,kr), [il,ir)) into _A: id-wide column groups,
            // each group laid out k-major, zero-padded to id columns / kc rows.
            fo(ill, il, ir, id) {
                size_t Aoffset = kl * n + ill;
                size_t _Aoffset = (ill - il) * kc;
                size_t len = (getR(ill, id, ir) - ill) << 3;  // bytes per packed row
                fo(k, kl, kr, 1) {
                    if ((id << 3) > len) memset(_A + _Aoffset, 0, id << 3);
                    mcpy(_A + _Aoffset, A + Aoffset, len);
                    Aoffset += n;
                    _Aoffset += id;
                }
                if (kr - kl < kc) {
                    memset(_A + _Aoffset, 0, (kc - kr + kl) * id << 3);
                }
            }

            fo(jl, 0, n, jc) {
                int jr = getR(jl, jc, n);
                // Pack B([jl,jr), [kl,kr)) into _B: jd rows per group,
                // each row kc doubles, zero-padded at the K and J edges.
                fo(jll, jl, jr, jd) {
                    size_t Boffset = jll * n + kl;
                    size_t _Boffset = (jll - jl) * kc;
                    size_t len = (kr - kl) << 3;  // bytes per packed row
                    fo(j, jll, getR(jll, jd, jr), 1) {
                        if ((kc << 3) > len) memset(_B + _Boffset, 0, kc << 3);
                        mcpy(_B + _Boffset, B + Boffset, len);
                        Boffset += n;
                        _Boffset += kc;
                    }
                    if (jr - jll < jd) {
                        memset(_B + _Boffset, 0, (jd - jr + jll) * kc << 3);
                    }
                }

                fo(jll, jl, jr, jd) {
                    int jrr = getR(jll, jd, jr);
                    fo(ill, il, ir, id) {
                        int irr = getR(ill, id, ir);
                        double* C_offset = C + jll * n + ill;
// Register-tile macro kit: CID/AID/BID name the YMM temporaries,
// LD_*/ST_* move 4-double lanes, CAL issues one FMA.
#define CID(row, col) C_##row##_##col
#define LD_C(row, col) __m256d CID(row, col) = FLOAD(C_offset + (row)*n + (col)*4)
#define ST_C(row, col) FSTORE(C_offset + (row)*n + (col)*4, CID(row, col))
#define LD_C4(row) \
    LD_C(row, 0);  \
    LD_C(row, 1);  \
    LD_C(row, 2);  \
    LD_C(row, 3);

#define AID(k, col) A_##k##_##col
#define BID(row, k) B_##row##_##k
#define LD_A(k, col) __m256d AID(k, col) = FLOAD(A_offset + (k)*id + (col)*4)
#define LD_B(row, k) __m256d BID(row, k) = FBROADCAST(B_offset + (row)*kc + (k))

#define CAL(row, col, k) FMLA(CID(row, col), AID(k, col), BID(row, k), CID(row, col))

#define LD_A4(k) \
    LD_A(k, 0);  \
    LD_A(k, 1);  \
    LD_A(k, 2);  \
    LD_A(k, 3);

#define ST_C4(row) \
    ST_C(row, 0);  \
    ST_C(row, 1);  \
    ST_C(row, 2);  \
    ST_C(row, 3);

#define CAL4(row, k) \
    CAL(row, 0, k);  \
    CAL(row, 1, k);  \
    CAL(row, 2, k);  \
    CAL(row, 3, k);

#define CAL4_K(row, k) \
    LD_B(row, k);      \
    CAL4(row, k);

#define CAL4_DEF_K(row, k) \
    LD_B(row, k);          \
    CAL4(row, k);

#define CAL4_K8(row) \
    CAL4_K(row, 0);  \
    CAL4_K(row, 1);\
    CAL4_K(row, 2);\
    CAL4_K(row, 3);

#define CAL4_DEF_K8(row) \
    CAL4_DEF_K(row, 0);  \
    CAL4_DEF_K(row, 1);\
    CAL4_DEF_K(row, 2);\
    CAL4_DEF_K(row, 3);

// iC variants: unused alternative path accumulating into the scratch tile.
#define LD_iC(row, col) __m256d CID(row, col) = FLOAD(iC + (row)*id + (col)*4)
#define ST_iC(row, col) FSTORE(iC + (row)*id + (col)*4, CID(row, col))

#define ST_iC4(row) \
    ST_iC(row, 0);  \
    ST_iC(row, 1);  \
    ST_iC(row, 2);  \
    ST_iC(row, 3);

#define LD_iC4(row) \
    LD_iC(row, 0);  \
    LD_iC(row, 1);  \
    LD_iC(row, 2);  \
    LD_iC(row, 3);
#define CAL4_iK8(row) \
    CAL4_K(row, 0);   \
    CAL4_K(row, 1);

#define CAL4_DEF_iK8(row) \
    CAL4_DEF_K(row, 0);   \
    CAL4_DEF_K(row, 1);

                        // Load the 4x16 accumulator tile of C.
                        LD_C4(0);
                        LD_C4(1);
                        LD_C4(2);
                        LD_C4(3);
                        double* A_offset = _A + (ill - il) * kc;
                        double* B_offset = _B + (jll - jl) * kc;
                        fo(kll, kl, kr, kd) {
                            __builtin_prefetch(B_offset + kd);
                            __builtin_prefetch(A_offset + id * kd);
                            __builtin_prefetch(A_offset + id * kd + id);
                            // Do (4x16x4) mm, C(j,i) += A(k,i)*B(j,k)
                            // C([jll,jrr],[ill,irr]) += A([kll,krr],[ill,irr])*B([jll,jrr],[kll,krr])
                            // A([kll,krr],[ill,irr])->_A[(ill-il)*(kr-kl)+[kll,krr]*id]
                            // B([jll,jrr],[kll,krr])->_B[(jll-jl)*(kr-kl)+kll+[jll,jrr]]

                            // BUGFIX: CAL4_K8 consumes A_0..A_3; the submitted
                            // code loaded only A_0 and A_1, so A_2_*/A_3_* were
                            // undeclared identifiers (the compile error).
                            LD_A4(0);
                            LD_A4(1);
                            LD_A4(2);
                            LD_A4(3);
                            CAL4_K8(0);
                            CAL4_DEF_K8(1);
                            CAL4_DEF_K8(2);
                            CAL4_DEF_K8(3);
                            A_offset += id * kd;
                            B_offset += kd;
                        }
                        // Write the tile back.
                        ST_C4(0);
                        ST_C4(1);
                        ST_C4(2);
                        ST_C4(3);

                    }
                }
            }
        }
    }
}

/* Entry point expected by the judge harness; delegates to the blocked kernel. */
void matrix_multiply(int n, const double* A, const double* B, double* C) {
    third_try(n, A, B, C);
}

Compilation: N/A | N/A | Compile Error | Score: N/A


Judge Duck Online | 评测鸭在线
Server Time: 2026-03-17 10:16:52 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠