提交记录 17530


用户 题目 状态 得分 用时 内存 语言 代码长度
yfzcsc mmmd1k. 测测你的双精度矩阵乘法-1k Wrong Answer 0 82.397 ms 25096 KB C++ 13.82 KB
提交时间 评测时间
2022-03-24 20:22:27 2022-03-24 20:22:30
#pragma GCC target("avx2,fma")
#pragma GCC optimize("Ofast")
#include <assert.h>
#include <immintrin.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define CACHELINE 128  // 128B
// Element count of the scratch buffer declared with this size. Parenthesised
// so the shift survives when the macro is used inside a larger expression
// (e.g. CACHESIZE * 2).
#define CACHESIZE (1 << 16)

// fo(i, L, R, STEP): for-loop with index i running over [L, R) in steps of
// STEP. R is evaluated once and cached in i__end.
#define fo(i, L, R, STEP) for (int i = (L), i##__end = (R); i < i##__end; i += (STEP))
// fo1(i, L, R): same as fo with a stride of 1.
#define fo1(i, L, R) fo(i, L, R, 1)
// NOTE: evaluates the selected argument twice; do not pass expressions with
// side effects.
#define min(x, y) ((x) < (y) ? (x) : (y))
// Exclusive end of the block that starts at x, clamped to MAX.
#define getR(x, STEP, MAX) min((x) + (STEP), (MAX))

// Thin wrappers around the AVX2/FMA double-precision intrinsics.
// FMLA: result = a * b + c.
#define FMLA(result, a, b, c) ((result) = _mm256_fmadd_pd((a), (b), (c)))
// FLOAD / FSTORE use the aligned load/store forms of the intrinsics.
#define FLOAD(where) _mm256_load_pd(where)
#define FSTORE(where, what) _mm256_store_pd(where, what)
// Broadcast one double to all four lanes.
#define FBROADCAST(where) _mm256_broadcast_sd(where)
// NOTE(review): vld2q_dup_f32 is an ARM NEON intrinsic; this macro looks like a
// leftover from another port and is not expanded by any code path visible here.
#define FBROADCAST2(where) vld2q_dup_f32(where)

/*
 * Copy `len` BYTES (not elements) from src to dst.
 *
 * The regions must not overlap (memcpy semantics). Callers in this file pass
 * an element count shifted left by 3, i.e. count * sizeof(double).
 *
 * `static` gives the helper internal linkage so a definition is always
 * available even when the compiler decides not to inline the call.
 */
static inline void mcpy(double* dst, const double* src, int len) {
    memcpy(dst, src, len);
}

/*
 * Blocked AVX2/FMA multiplication of two n x n double matrices.
 *
 * Parameters
 *   n   matrix dimension.
 *   zA  left operand, n*n doubles, read as zA[r * n + c].
 *   zB  right operand, n*n doubles, read as zB[r * n + c].
 *   C   n*n doubles, updated in place.
 *
 * Outline
 *   1. zA and zB are copied, transposed, into the local arrays A and B.
 *   2. The blocked loops perform C[j*n + i] += A[k*n + i] * B[j*n + k]
 *      using 4 (j) x 16 (i) register tiles; slices of A are first packed
 *      into the aligned scratch buffer `cache`.
 *   3. C is transposed in place.
 *
 * The kernels load the current contents of C and accumulate into them, and
 * step 3 transposes whatever ends up in C, so the output equals zA * zB only
 * when C is all zeros on entry (presumably guaranteed by the caller - confirm).
 *
 * NOTE(review): A and B are fixed 1024*1024 arrays with automatic storage
 * (8 MiB each); n > 1024 overflows them and the stack must be large enough.
 * NOTE(review): the two fast paths access C through FLOAD/FSTORE (aligned
 * AVX load/store), which presumably needs C 32-byte aligned and n a multiple
 * of 4 - confirm for the sizes this is run with.
 */
void third_try(int n, const double* zA, const double* zB, double* C) {
    // Scratch: packed slice of A at the front, temporary C tile (iC) after it.
    double __attribute__((aligned(32))) cache[CACHESIZE];

    // Block sizes along i / k / j, and the register-tile sizes below them.
    int ic = 128, kc = 128, jc = 128;
    const int id = 16, kd = 4, jd = 4;

    if (n <= 512) ic = 64, kc = 128, jc = 128;
    if (n % 128 == 64) ic = 64;
    if (n == 129) ic = 80, kc = 144;

    double* _A = cache;
    double* _B = cache + ic * kc;
    double* iC = _B;  // B is never packed; this region only serves as the C scratch tile

    double A[1024*1024];
    double B[1024*1024];

    // Out-of-place transpose of both inputs. Every (i, j) pair must be
    // visited: a j < i bound is only valid for an in-place swap and would
    // leave the diagonal and half of A / B uninitialised here.
    fo(i, 0, n, 1) fo(j, 0, n, 1) A[j * n + i] = zA[i * n + j];
    fo(i, 0, n, 1) fo(j, 0, n, 1) B[j * n + i] = zB[i * n + j];

    fo(il, 0, n, ic) {
        int ir = getR(il, ic, n);
        fo(kl, 0, n, kc) {
            int kr = getR(kl, kc, n);
            // Pack A: for each strip of id columns starting at ill, store the
            // rows k in [kl, kr) contiguously, id doubles per row, at
            // _A + (ill - il) * kc. Short strips and missing rows are zero-filled.
            fo(ill, il, ir, id) {
                size_t Aoffset = kl * n + ill;
                size_t _Aoffset = (ill - il) * kc;
                size_t len = (getR(ill, id, ir) - ill) << 3;  // strip width in bytes
                fo(k, kl, kr, 1) {
                    if ((id << 3) > len) memset(_A + _Aoffset, 0, id << 3);
                    mcpy(_A + _Aoffset, A + Aoffset, len);
                    Aoffset += n;
                    _Aoffset += id;
                }
                if (kr - kl < kc) {
                    // Zero the rows of the panel that lie past kr.
                    memset(_A + _Aoffset, 0, (kc - kr + kl) * id << 3);
                }
            }

            fo(jl, 0, n, jc) {
                int jr = getR(jl, jc, n);

                fo(jll, jl, jr, jd) {
                    int jrr = getR(jll, jd, jr);
                    fo(ill, il, ir, id) {
                        int irr = getR(ill, id, ir);
                        double* C_offset = C + jll * n + ill;
// Register-tile helpers. CID(row, col) names the accumulator for row `row` of
// the tile and the col-th group of four doubles within the strip.
#define CID(row, col) C_##row##_##col
#define LD_C(row, col) register __m256d CID(row, col) = FLOAD(C_offset + (row)*n + (col)*4)
#define ST_C(row, col) FSTORE(C_offset + (row)*n + (col)*4, CID(row, col))
#define LD_C4(row) \
    LD_C(row, 0);  \
    LD_C(row, 1);  \
    LD_C(row, 2);  \
    LD_C(row, 3);

// AID(k, col): vector of packed A for step k; BID(row, k): B value broadcast
// to all lanes for tile row `row` and step k.
#define AID(k, col) A_##k##_##col
#define BID(row, k) B_##row##_##k
#define LD_A(k, col) register __m256d AID(k, col) = FLOAD(A_offset + (k)*id + (col)*4)
#define LD_B(row, k) register __m256d BID(row, k) = FBROADCAST(B_offset + (row)*n + (k))
// NOTE(review): LD_B2 and the *_NEW macros below rely on NEON names
// (double32x4x2_t, FBROADCAST2) and are not expanded anywhere in this function.
#define LD_B2(row)                                            \
    __m256d BID(row, 0);                                  \
    __m256d BID(row, 1);                                  \
    ({                                                        \
        double32x4x2_t tmp = FBROADCAST2(B_offset + (row)*kc); \
        BID(row, 0) = tmp.val[0];                             \
        BID(row, 1) = tmp.val[1];                             \
    })
#define CAL(row, col, k) FMLA(CID(row, col), AID(k, col), BID(row, k), CID(row, col))

#define LD_A4(k) \
    LD_A(k, 0);  \
    LD_A(k, 1);  \
    LD_A(k, 2);  \
    LD_A(k, 3);

#define ST_C4(row) \
    ST_C(row, 0);  \
    ST_C(row, 1);  \
    ST_C(row, 2);  \
    ST_C(row, 3);

#define CAL4(row, k) \
    CAL(row, 0, k);  \
    CAL(row, 1, k);  \
    CAL(row, 2, k);  \
    CAL(row, 3, k);

#define CAL4_K(row, k) \
    LD_B(row, k);      \
    CAL4(row, k);

#define CAL4_DEF_K(row, k) \
    LD_B(row, k);          \
    CAL4(row, k);

#define CAL4_K8(row) \
    CAL4_K(row, 0);  \
    CAL4_K(row, 1);  \
    CAL4_K(row, 2);  \
    CAL4_K(row, 3);

#define CAL4_DEF_K8(row) \
    CAL4_DEF_K(row, 0);  \
    CAL4_DEF_K(row, 1);  \
    CAL4_DEF_K(row, 2);  \
    CAL4_DEF_K(row, 3);

#define CAL4_K8_NEW(row) \
    LD_B2(row);          \
    CAL4(row, 0);        \
    CAL4(row, 1);

#define CAL4_DEF_K8_NEW(row) \
    LD_B2(row);              \
    CAL4(row, 0);            \
    CAL4(row, 1);

// Same accumulator names as LD_C / ST_C, but backed by the scratch tile iC
// (row stride id) instead of C itself.
#define LD_iC(row, col) register __m256d CID(row, col) = FLOAD(iC + (row)*id + (col)*4)
#define ST_iC(row, col) FSTORE(iC + (row)*id + (col)*4, CID(row, col))

#define ST_iC4(row) \
    ST_iC(row, 0);  \
    ST_iC(row, 1);  \
    ST_iC(row, 2);  \
    ST_iC(row, 3);

#define LD_iC4(row) \
    LD_iC(row, 0);  \
    LD_iC(row, 1);  \
    LD_iC(row, 2);  \
    LD_iC(row, 3);
#define CAL4_iK8(row) CAL4_K8(row)

#define CAL4_DEF_iK8(row) CAL4_DEF_K8(row)

                        if (jll + jd == jrr && (ill + id == irr || (jll + jd < n && ill + 4 < irr))) {
                            // Path 1: four full rows of the tile, handled 16 doubles
                            // wide directly on C. Also taken for a short strip (more
                            // than 4 columns left) when this is not the last row group;
                            // the packed A columns past irr are zero in that case.
                            LD_C4(0);
                            LD_C4(1);
                            LD_C4(2);
                            LD_C4(3);
                            double* A_offset = _A + (ill - il) * kc;
                            double* B_offset = B + jll * n + kl;
                            fo(kll, kl, kr, kd) {
                                __builtin_prefetch(A_offset + id * kd);
                                __builtin_prefetch(A_offset + id * kd + id);
                                // One 4x16 tile update over kd = 4 steps of k:
                                // C(j,i) += A(k,i)*B(j,k)
                                // C([jll,jrr],[ill,irr]) += A([kll,krr],[ill,irr])*B([jll,jrr],[kll,krr])
                                // A comes from the packed panel; B is read unpacked.
                                LD_A4(0);
                                LD_A4(1);
                                LD_A4(2);
                                LD_A4(3);
                                CAL4_K8(0);
                                CAL4_DEF_K8(1);
                                CAL4_DEF_K8(2);
                                CAL4_DEF_K8(3);
                                A_offset += id * kd;
                                B_offset += kd;
                            }
                            ST_C4(0);
                            ST_C4(1);
                            ST_C4(2);
                            ST_C4(3);
                        } else if (jll + jd == jrr && jll + jd < n && ill + 4 >= irr) {
                            // Path 2: at most 4 columns left in the strip and not the
                            // last row group: a 4x4 tile, k unrolled by 2 * kd = 8.
                            //assume n > 4
                            LD_C(0, 0);
                            LD_C(1, 0);
                            LD_C(2, 0);
                            LD_C(3, 0);
                            double* A_offset = _A + (ill - il) * kc;
                            double* B_offset = B + jll * n + kl;
                            fo(kll, kl, kr, kd * 2) {
                                __builtin_prefetch(A_offset + id * kd * 2);
                                __builtin_prefetch(A_offset + id * kd * 2 + id);
                                LD_A(0, 0);
                                LD_A(1, 0);
                                LD_A(2, 0);
                                LD_A(3, 0);
                                LD_A(4, 0);
                                LD_A(5, 0);
                                LD_A(6, 0);
                                LD_A(7, 0);

                                LD_B(0, 0);
                                LD_B(0, 1);
                                LD_B(0, 2);
                                LD_B(0, 3);
                                LD_B(1, 0);
                                LD_B(1, 1);
                                LD_B(1, 2);
                                LD_B(1, 3);
                                LD_B(2, 0);
                                LD_B(2, 1);
                                LD_B(2, 2);
                                LD_B(2, 3);
                                LD_B(3, 0);
                                LD_B(3, 1);
                                LD_B(3, 2);
                                LD_B(3, 3);

                                CAL(0, 0, 0);
                                CAL(1, 0, 0);
                                CAL(2, 0, 0);
                                CAL(3, 0, 0);
                                CAL(0, 0, 1);
                                CAL(1, 0, 1);
                                CAL(2, 0, 1);
                                CAL(3, 0, 1);
                                CAL(0, 0, 2);
                                CAL(1, 0, 2);
                                CAL(2, 0, 2);
                                CAL(3, 0, 2);
                                CAL(0, 0, 3);
                                CAL(1, 0, 3);
                                CAL(2, 0, 3);
                                CAL(3, 0, 3);

                                LD_B(0, 4);
                                LD_B(0, 5);
                                LD_B(0, 6);
                                LD_B(0, 7);
                                LD_B(1, 4);
                                LD_B(1, 5);
                                LD_B(1, 6);
                                LD_B(1, 7);
                                LD_B(2, 4);
                                LD_B(2, 5);
                                LD_B(2, 6);
                                LD_B(2, 7);
                                LD_B(3, 4);
                                LD_B(3, 5);
                                LD_B(3, 6);
                                LD_B(3, 7);

                                CAL(0, 0, 4);
                                CAL(1, 0, 4);
                                CAL(2, 0, 4);
                                CAL(3, 0, 4);
                                CAL(0, 0, 5);
                                CAL(1, 0, 5);
                                CAL(2, 0, 5);
                                CAL(3, 0, 5);
                                CAL(0, 0, 6);
                                CAL(1, 0, 6);
                                CAL(2, 0, 6);
                                CAL(3, 0, 6);
                                CAL(0, 0, 7);
                                CAL(1, 0, 7);
                                CAL(2, 0, 7);
                                CAL(3, 0, 7);
                                A_offset += id * kd * 2;
                                B_offset += kd * 2;
                            }
                            ST_C(0, 0);
                            ST_C(1, 0);
                            ST_C(2, 0);
                            ST_C(3, 0);
                        } else {
                            // Path 3: remaining edge tiles. The valid part of C is
                            // copied into the scratch tile iC, a full 4x16 update runs
                            // on the scratch, and only the valid part is copied back.
                            fo(p, 0, jrr - jll, 1) mcpy(iC + p * id, C_offset + p * n, (irr - ill) << 3);
                            double* r_C_offset = C_offset;
                            C_offset = iC;
                            LD_iC4(0);
                            LD_iC4(1);
                            LD_iC4(2);
                            LD_iC4(3);
                            double* A_offset = _A + (ill - il) * kc;
                            double* B_offset = B + jll * n + kl;
                            fo(kll, kl, kr, kd) {
                                __builtin_prefetch(A_offset + id * kd);
                                __builtin_prefetch(A_offset + id * kd + id);
                                LD_A4(0);
                                LD_A4(1);
                                LD_A4(2);
                                LD_A4(3);
                                CAL4_iK8(0);
                                CAL4_DEF_iK8(1);
                                CAL4_DEF_iK8(2);
                                CAL4_DEF_iK8(3);
                                A_offset += id * kd;
                                B_offset += kd;
                            }
                            ST_iC4(0);
                            ST_iC4(1);
                            ST_iC4(2);
                            ST_iC4(3);
                            C_offset = r_C_offset;
                            fo(p, 0, jrr - jll, 1) mcpy(C_offset + p * n, iC + p * id, (irr - ill) << 3);
                        }
                    }
                }
            }
        }
    }
    // In-place transpose of C: swapping each pair once (j < i) is sufficient here.
    fo(i, 0, n, 1) fo(j, 0, i, 1) {
        double t = C[i * n + j];
        C[i * n + j] = C[j * n + i];
        C[j * n + i] = t;
    }
}

/* Entry point: forwards the arguments unchanged to third_try. */
void matrix_multiply(int n, const double* A, const double* B, double* C) {
    third_try(n, A, B, C);
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 82.397 ms | 24 MB + 520 KB | Wrong Answer | Score: 0


Judge Duck Online | 评测鸭在线
Server Time: 2026-03-17 10:18:18 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠