#pragma GCC target("avx2,fma")
#pragma GCC optimize("Ofast")
#include <assert.h>
#include <immintrin.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define CACHELINE 128 // 128B
#define CACHESIZE 1 << 16
#define fo(i, L, R, STEP) for (int i = L, i##__end = R; i < i##__end; i += STEP)
#define fo1(i, L, R) fo(i, L, R, 1)
#define min(x, y) ((x) < (y) ? (x) : (y))
#define getR(x, STEP, MAX) min((x) + (STEP), (MAX))
#define FMLA(result, a, b, c) ((result) = _mm256_fmadd_pd((a), (b), (c)))
#define FLOAD(where) _mm256_load_pd(where)
#define FSTORE(where, what) _mm256_store_pd(where, what)
#define FBROADCAST(where) _mm256_broadcast_sd(where)
#define FBROADCAST2(where) vld2q_dup_f32(where)
// Copy `len` BYTES (not elements) from src to dst. Regions must not overlap.
//
// The original body contained unreachable code after the early `return`: a
// hand-vectorized AVX copy loop that was also internally inconsistent (it
// shifted the byte count with `len >>= 2`, then strode 32 doubles per
// iteration while comparing against that byte-derived bound). Only the
// memcpy ever executed, so removing the dead code preserves behavior
// exactly; libc memcpy is already vectorized.
// Declared `static inline`: a bare C99 `inline` definition provides no
// external definition and can fail to link when the call is not inlined.
static inline void mcpy(double* dst, const double* src, int len) {
    memcpy(dst, src, len);
}
// Blocked, packed AVX2 GEMM kernel. After zeroing C it accumulates
//   C[j*n + i] += A[k*n + i] * B[j*n + k]   for all i, j, k in [0, n)
// which is C = A*B when the matrices are stored column-major, or C = B*A
// when they are stored row-major.
// NOTE(review): this submission was judged "Wrong Answer". If the grader
// expects row-major C = A*B, the operand order here produces B*A instead —
// verify the expected storage convention (see the note on matrix_multiply).
//
// Blocking scheme: ic x kc panels of A and jc x kc panels of B are packed
// into the 32-byte-aligned stack buffer `cache`; the micro-kernel then
// updates a jd x id (4 rows x 16 columns) register tile of C, performing
// kd (=4) rank-1 updates per loop trip. Remainder tiles at the matrix edges
// are handled by zero-padding the packed panels.
void third_try(int n, const double* A, const double* B, double* C) {
// char ___x;
// char ___str[CACHELINE - ((int)(&___x) & (CACHELINE - 1))];
// 1<<16 doubles (512 KiB) of aligned scratch: _A takes ic*kc, _B takes jc*kc.
double __attribute__((aligned(32))) cache[CACHESIZE]; // packed-panel scratch
int ic = 128, kc = 128, jc = 128; // cache-block sizes along i, k, j
const int id = 16, kd = 4, jd = 4; // micro-kernel: 4x16 C tile, 4 k-steps/trip
double* _A = cache; // packed A panel, k-major: _A[(i-ill block)*kc + (k-kl)*id + lane]
double* _B = cache + ic * kc; // packed B panel, j-major: _B[(j-jl)*kc + (k-kl)]
double* iC = _B + jc * kc; // scratch C tile for the LD_iC/ST_iC macros (currently unused)
memset(C, 0, n * n * sizeof(double)); // kernel accumulates into C, so start from zero
fo(il, 0, n, ic) {
int ir = getR(il, ic, n);
fo(kl, 0, n, kc) {
int kr = getR(kl, kc, n);
// Pack A: copy A[kl..kr)x[il..ir) into _A as id-wide column panels,
// zero-padding short i rows and missing k rows at the edges.
// (Original comment said "Pack B" — this loop actually packs A.)
fo(ill, il, ir, id) {
size_t Aoffset = kl * n + ill;
size_t _Aoffset = (ill - il) * kc;
size_t len = (getR(ill, id, ir) - ill) << 3; // valid bytes per packed row
fo(k, kl, kr, 1) {
if ((id << 3) > len) memset(_A + _Aoffset, 0, id << 3); // zero pad lanes first
mcpy(_A + _Aoffset, A + Aoffset, len);
Aoffset += n;
_Aoffset += id;
}
if (kr - kl < kc) {
memset(_A + _Aoffset, 0, (kc - kr + kl) * id << 3); // zero missing k rows
}
}
fo(jl, 0, n, jc) {
int jr = getR(jl, jc, n);
// Pack B: copy B[jl..jr)x[kl..kr) into _B, one kc-long row per j,
// zero-padding the k remainder and any missing j rows in the last group.
// (Original comment said "Pack A" — this loop actually packs B.)
fo(jll, jl, jr, jd) {
size_t Boffset = jll * n + kl;
size_t _Boffset = (jll - jl) * kc;
size_t len = (kr - kl) << 3; // valid bytes per packed row
fo(j, jll, getR(jll, jd, jr), 1) {
if ((kc << 3) > len) memset(_B + _Boffset, 0, kc << 3); // zero the k pad first
mcpy(_B + _Boffset, B + Boffset, len);
Boffset += n;
_Boffset += kc;
}
if (jr - jll < jd) {
memset(_B + _Boffset, 0, (jd - jr + jll) * kc << 3); // zero missing j rows
}
}
// Micro-kernel sweep: each (jll, ill) pair updates a jd x id tile of C.
fo(jll, jl, jr, jd) {
int jrr = getR(jll, jd, jr);
fo(ill, il, ir, id) {
int irr = getR(ill, id, ir);
double* C_offset = C + jll * n + ill;
// ---- register-tile macros: 4x4 __m256d accumulators covering 4x16 of C ----
#define CID(row, col) C_##row##_##col
// NOTE(review): FLOAD/FSTORE are ALIGNED accesses on C; this assumes C is
// 32-byte aligned and n is a multiple of 4 — TODO confirm the caller/grader
// guarantees this. Edge tiles also load/store the full 16-wide row and all
// jd rows even when irr - ill < id or jrr - jll < jd; the extra lanes get
// zero contributions (padded panels) and are stored back unchanged, but the
// very last tile can touch memory just past the end of C.
#define LD_C(row, col) register __restrict__ __m256d CID(row, col) = FLOAD(C_offset + (row)*n + (col)*4)
#define ST_C(row, col) FSTORE(C_offset + (row)*n + (col)*4, CID(row, col))
#define LD_C4(row) \
LD_C(row, 0); \
LD_C(row, 1); \
LD_C(row, 2); \
LD_C(row, 3);
#define AID(k, col) A_##k##_##col
#define BID(row, k) B_##row##_##k
// LD_A: 4 doubles of packed A for one (k, column-group); LD_B: broadcast one
// packed B scalar B[j][k] across a vector; CAL: fused multiply-add into C.
#define LD_A(k, col) register __restrict__ __m256d AID(k, col) = FLOAD(A_offset + (k)*id + (col)*4)
#define LD_B(row, k) __m256d BID(row, k) = FBROADCAST(B_offset + (row)*kc + (k))
#define CAL(row, col, k) FMLA(CID(row, col), AID(k, col), BID(row, k), CID(row, col))
#define LD_A4(k) \
LD_A(k, 0); \
LD_A(k, 1); \
LD_A(k, 2); \
LD_A(k, 3);
#define ST_C4(row) \
ST_C(row, 0); \
ST_C(row, 1); \
ST_C(row, 2); \
ST_C(row, 3);
#define CAL4(row, k) \
CAL(row, 0, k); \
CAL(row, 1, k); \
CAL(row, 2, k); \
CAL(row, 3, k);
#define CAL4_K(row, k) \
LD_B(row, k); \
CAL4(row, k);
// NOTE(review): CAL4_DEF_K is textually identical to CAL4_K (the "DEF"
// variant presumably once differed); both broadcast B then do 4 FMAs.
#define CAL4_DEF_K(row, k) \
LD_B(row, k); \
CAL4(row, k);
#define CAL4_K8(row) \
CAL4_K(row, 0); \
CAL4_K(row, 1);\
CAL4_K(row, 2);\
CAL4_K(row, 3);
#define CAL4_DEF_K8(row) \
CAL4_DEF_K(row, 0); \
CAL4_DEF_K(row, 1);\
CAL4_DEF_K(row, 2);\
CAL4_DEF_K(row, 3);
// The iC variants below target the `iC` scratch tile; none are invoked in
// this kernel (dead macro set kept for reference).
#define LD_iC(row, col) register __m256d CID(row, col) = FLOAD(iC + (row)*id + (col)*4)
#define ST_iC(row, col) FSTORE(iC + (row)*id + (col)*4, CID(row, col))
#define ST_iC4(row) \
ST_iC(row, 0); \
ST_iC(row, 1); \
ST_iC(row, 2); \
ST_iC(row, 3);
#define LD_iC4(row) \
LD_iC(row, 0); \
LD_iC(row, 1); \
LD_iC(row, 2); \
LD_iC(row, 3);
#define CAL4_iK8(row) \
CAL4_K(row, 0); \
CAL4_K(row, 1);
#define CAL4_DEF_iK8(row) \
CAL4_DEF_K(row, 0); \
CAL4_DEF_K(row, 1);
// Load the 4x16 C tile into registers once per (jll, ill) pair.
LD_C4(0);
LD_C4(1);
LD_C4(2);
LD_C4(3);
double* A_offset = _A + (ill - il) * kc;
double* B_offset = _B + (jll - jl) * kc;
// Accumulate over k in steps of kd; trailing k beyond kr hits the
// zero-padded region of the packed panels and contributes nothing.
fo(kll, kl, kr, kd) {
__builtin_prefetch(B_offset + kd);
__builtin_prefetch(A_offset + id * kd);
__builtin_prefetch(A_offset + id * kd + id);
// Do (4x16x8) mm, C(j,i) += A(k,i)*B(j,k)
// C([jll,jrr],[ill,irr]) += A([kll,krr],[ill,irr])*B([jll,jrr],[kll,krr])
// A([kll,krr],[ill,irr])->_A[(ill-il)*(kr-kl)+[kll,krr]*id]
// B([jll,jrr],[kll,krr])->_B[(jll-jl)*(kr-kl)+kll+[jll,jrr]]
// double* A_offset = _A + (ill - il) * kc + (kll - kl) * id;
// double* B_offset = _B + (jll - jl) * kc + kll - kl;
LD_A4(0);
LD_A4(1);
LD_A4(2);
LD_A4(3);
CAL4_K8(0);
CAL4_DEF_K8(1);
CAL4_DEF_K8(2);
CAL4_DEF_K8(3);
A_offset += id * kd;
B_offset += kd;
}
// Write the finished tile back to C.
ST_C4(0);
ST_C4(1);
ST_C4(2);
ST_C4(3);
}
}
}
}
}
}
// Public entry point: forwards to the blocked AVX2 kernel.
// NOTE(review): third_try computes C[j*n+i] = sum_k A[k*n+i] * B[j*n+k],
// i.e. C = A*B for COLUMN-major storage but C = B*A for ROW-major storage.
// This submission was judged "Wrong Answer"; if the grader uses row-major
// C = A*B, the fix is to swap the operands here: third_try(n, B, A, C).
// Confirm the expected storage convention before changing.
void matrix_multiply(int n, const double* A, const double* B, double* C) { third_try(n, A, B, C); }
// ---- Judge results (pasted grader output; commented out so the file compiles) ----
// | Compilation | N/A       | N/A           | Compile OK   | Score: N/A | (show more) |
// | Testcase #1 | 72.681 ms | 8 MB + 520 KB | Wrong Answer | Score: 0   | (show more) |