#pragma GCC target("avx2,fma")
#pragma GCC optimize("Ofast")
#include <assert.h>
#include <immintrin.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define CACHELINE 128 // prefetch/alignment granularity in bytes (128B)
// Scratch-buffer size in doubles. Parenthesized: the previous `1 << 16`
// form is an operator-precedence hazard in any surrounding expression
// (e.g. `CACHESIZE * 2` would expand to `1 << 16 * 2`).
#define CACHESIZE (1 << 16)
// NOTE: fo/min/getR may evaluate their arguments more than once -- pass
// side-effect-free expressions only (all current call sites comply).
#define fo(i, L, R, STEP) for (int i = L, i##__end = R; i < i##__end; i += STEP)
#define fo1(i, L, R) fo(i, L, R, 1)
#define min(x, y) ((x) < (y) ? (x) : (y))
#define getR(x, STEP, MAX) min((x) + (STEP), (MAX))
// Thin wrappers over the AVX2 double-precision intrinsics used by the kernel.
// FLOAD/FSTORE require 32-byte-aligned pointers.
#define FMLA(result, a, b, c) ((result) = _mm256_fmadd_pd((a), (b), (c)))
#define FLOAD(where) _mm256_load_pd(where)
#define FSTORE(where, what) _mm256_store_pd(where, what)
#define FBROADCAST(where) _mm256_broadcast_sd(where)
// NOTE(review): ARM NEON leftover (vld2q_dup_f32 does not exist on x86).
// Never expanded anywhere in this file, so it compiles, but it must not be
// used on this target.
#define FBROADCAST2(where) vld2q_dup_f32(where)
// Copy `len` BYTES (not element count) from `src` to `dst`.
// Callers pass byte sizes (element counts shifted by 3); the regions never
// overlap, so memcpy is valid.
//
// The hand-unrolled AVX path that used to follow an early `return;` here was
// dead code and buggy besides: `len >>= 2` mis-converted bytes to elements
// (doubles need >>= 3), the 32-element unrolled loop could overrun when
// `len % 32 != 0`, and the scalar tail reused the wrong bound. Plain memcpy
// is what actually executed and is all that is kept.
//
// `static` added: a C11 `inline` definition without `static`/`extern` emits
// no external definition, which is a link error whenever the compiler
// declines to inline a call.
static inline void mcpy(double* dst, const double* src, int len) {
    memcpy(dst, src, len);
}
// Blocked (tiled) square matrix multiply: accumulates C += A * B for n x n
// matrices with the index convention used by the inner loops:
//     C[j*n + i] += A[k*n + i] * B[j*n + k]
// Three levels of cache blocking (ic x kc x jc), an explicitly packed and
// zero-padded copy of the current A panel in `cache`, and a 4x16-element
// AVX2 FMA micro-kernel (jd=4 rows of C, id=16 columns, kd=4 reduction
// steps per iteration).
//
// NOTE(review): with the default ic = kc = 256, ic * kc == CACHESIZE, so
// `_B`/`iC` point one past the end of `cache`; every slow-path (edge-tile)
// store through `iC` is then out of bounds. Likely the cause of wrong
// results for n > 512 -- confirm sizing (`cache` needs ic*kc + jd*id more
// doubles, or smaller blocks).
void third_try(int n, const double* A, const double* B, double* C) {
// char ___x;
// char ___str[CACHELINE - ((int)(&___x) & (CACHELINE - 1))];
double __attribute__((aligned(32))) cache[CACHESIZE]; // to make continue
// Cache-block sizes along i (C columns), k (reduction), j (C rows), and the
// fixed register-tile sizes of the micro-kernel.
int ic = 256, kc = 256, jc = 128;
const int id = 16, kd = 4, jd = 4;
// Empirically tuned overrides for specific problem sizes.
if (n <= 512) ic = 64, kc = 128, jc = 128;
if (n % 128 == 64) ic = 64;
if (n == 129) ic = 80, kc = 144;
// for (int i = 64; i <= 512; i += 16)
// if (n % ic > n % i) ic = i;
double* _A = cache;           // packed A panel: ic x kc doubles, id-padded strips
double* _B = cache + ic * kc; // scratch after the packed panel (B is NOT packed)
double* iC = _B;              // temporary C tile for ragged (edge) tiles
/*
for (int i = 8; i >= 1; --i)
if (n <= i * 16) {
ic = i * 16;
}
for (int i = 64; i >= 1; --i)
if (n <= i * 2) {
kc = i * 2;
}
for (int i = 32; i >= 1; --i)
if (n <= i * 4) {
jc = i * 4;
}*/
fo(il, 0, n, ic) {       // i cache blocks
int ir = getR(il, ic, n);
fo(kl, 0, n, kc) {       // k cache blocks
int kr = getR(kl, kc, n);
// Pack A: copy A([kl,kr), [il,ir)) into _A as id-wide column strips,
// k-major inside each strip, zero-padding both the ragged i edge (rows
// narrower than id) and the ragged k edge (fewer than kc rows) so the
// micro-kernel can always read full id x kd tiles.
fo(ill, il, ir, id) {
size_t Aoffset = kl * n + ill;     // read cursor into A
size_t _Aoffset = (ill - il) * kc; // write cursor into _A
size_t len = (getR(ill, id, ir) - ill) << 3; // strip width in bytes
fo(k, kl, kr, 1) {
// Ragged strip: zero the full id-wide row before the partial copy.
if ((id << 3) > len) memset(_A + _Aoffset, 0, id << 3);
mcpy(_A + _Aoffset, A + Aoffset, len);
Aoffset += n;
_Aoffset += id;
}
// Zero-pad the remaining (kc - (kr - kl)) rows of this strip.
if (kr - kl < kc) {
memset(_A + _Aoffset, 0, (kc - kr + kl) * id << 3);
}
}
fo(jl, 0, n, jc) {       // j cache blocks
int jr = getR(jl, jc, n);
fo(jll, jl, jr, jd) {    // j register tiles (4 C rows)
int jrr = getR(jll, jd, jr);
fo(ill, il, ir, id) {    // i register tiles (16 C columns)
int irr = getR(ill, id, ir);
double* C_offset = C + jll * n + ill;
// --- Micro-kernel macros. CID/AID/BID name per-tile register
// --- variables; LD_*/ST_* move 4-double vectors; CAL is one FMA.
#define CID(row, col) C_##row##_##col
#define LD_C(row, col) register __m256d CID(row, col) = FLOAD(C_offset + (row)*n + (col)*4)
#define ST_C(row, col) FSTORE(C_offset + (row)*n + (col)*4, CID(row, col))
#define LD_C4(row) \
LD_C(row, 0); \
LD_C(row, 1); \
LD_C(row, 2); \
LD_C(row, 3);
#define AID(k, col) A_##k##_##col
#define BID(row, k) B_##row##_##k
#define LD_A(k, col) register __m256d AID(k, col) = FLOAD(A_offset + (k)*id + (col)*4)
#define LD_B(row, k) register __m256d BID(row, k) = FBROADCAST(B_offset + (row)*n + (k))
// NOTE(review): LD_B2 is an ARM NEON leftover (double32x4x2_t is not an
// x86 type). It is never expanded, so it compiles, but it must not be used.
#define LD_B2(row) \
__m256d BID(row, 0); \
__m256d BID(row, 1); \
({ \
double32x4x2_t tmp = FBROADCAST2(B_offset + (row)*kc); \
BID(row, 0) = tmp.val[0]; \
BID(row, 1) = tmp.val[1]; \
})
#define CAL(row, col, k) FMLA(CID(row, col), AID(k, col), BID(row, k), CID(row, col))
#define LD_A4(k) \
LD_A(k, 0); \
LD_A(k, 1); \
LD_A(k, 2); \
LD_A(k, 3);
#define ST_C4(row) \
ST_C(row, 0); \
ST_C(row, 1); \
ST_C(row, 2); \
ST_C(row, 3);
#define CAL4(row, k) \
CAL(row, 0, k); \
CAL(row, 1, k); \
CAL(row, 2, k); \
CAL(row, 3, k);
#define CAL4_K(row, k) \
LD_B(row, k); \
CAL4(row, k);
// CAL4_DEF_K is currently identical to CAL4_K (kept for symmetry with the
// *_NEW variants below).
#define CAL4_DEF_K(row, k) \
LD_B(row, k); \
CAL4(row, k);
#define CAL4_K8(row) \
CAL4_K(row, 0); \
CAL4_K(row, 1); \
CAL4_K(row, 2); \
CAL4_K(row, 3);
#define CAL4_DEF_K8(row) \
CAL4_DEF_K(row, 0); \
CAL4_DEF_K(row, 1); \
CAL4_DEF_K(row, 2); \
CAL4_DEF_K(row, 3);
// *_NEW variants depend on the broken LD_B2 above -- unused.
#define CAL4_K8_NEW(row) \
LD_B2(row); \
CAL4(row, 0); \
CAL4(row, 1);
#define CAL4_DEF_K8_NEW(row) \
LD_B2(row); \
CAL4(row, 0); \
CAL4(row, 1);
// iC variants: same tile ops but against the contiguous scratch tile.
#define LD_iC(row, col) register __m256d CID(row, col) = FLOAD(iC + (row)*id + (col)*4)
#define ST_iC(row, col) FSTORE(iC + (row)*id + (col)*4, CID(row, col))
#define ST_iC4(row) \
ST_iC(row, 0); \
ST_iC(row, 1); \
ST_iC(row, 2); \
ST_iC(row, 3);
#define LD_iC4(row) \
LD_iC(row, 0); \
LD_iC(row, 1); \
LD_iC(row, 2); \
LD_iC(row, 3);
#define CAL4_iK8(row) CAL4_K8(row)
#define CAL4_DEF_iK8(row) CAL4_DEF_K8(row)
// Fast path: full 4 x 16 tile operates on C in place.
// NOTE(review): the `ill + 4 < irr` clause also admits partial i-tiles;
// those load/store elements past irr. Padded _A lanes contribute 0 so
// those elements are written back unchanged, but the last tile of the
// matrix may still read/write past the end of C -- verify bounds.
if (jll + jd == jrr && (ill + id == irr || (jll + jd < n && ill + 4 < irr))) {
LD_C4(0);
LD_C4(1);
LD_C4(2);
LD_C4(3);
const double* A_offset = _A + (ill - il) * kc; // packed A strip for this tile
const double* B_offset = B + jll * n + kl;     // unpacked B, broadcast per scalar
fo(kll, kl, kr, kd) { // kd reduction steps per iteration
__builtin_prefetch(A_offset + id * kd);      // prefetch next A tile
__builtin_prefetch(A_offset + id * kd + id);
// __builtin_prefetch(A_offset + id * kd + 4);
// __builtin_prefetch(A_offset + id * kd + 8);
// __builtin_prefetch(A_offset + id * kd + 12);
// Do (4x16x8) mm, C(j,i) += A(k,i)*B(j,k)
// C([jll,jrr],[ill,irr]) += A([kll,krr],[ill,irr])*B([jll,jrr],[kll,krr])
// A([kll,krr],[ill,irr])->_A[(ill-il)*(kr-kl)+[kll,krr]*id]
// B([jll,jrr],[kll,krr])->_B[(jll-jl)*(kr-kl)+kll+[jll,jrr]]
// double* A_offset = _A + (ill - il) * kc + (kll - kl) * id;
// double* B_offset = _B + (jll - jl) * kc + kll - kl;
LD_A4(0);
LD_A4(1);
LD_A4(2);
LD_A4(3);
CAL4_K8(0);
CAL4_DEF_K8(1);
CAL4_DEF_K8(2);
CAL4_DEF_K8(3);
A_offset += id * kd;
B_offset += kd;
}
ST_C4(0);
ST_C4(1);
ST_C4(2);
ST_C4(3);
} else {
// Slow path (ragged edge tile): stage the partial C tile into the
// contiguous iC buffer, run the same full-width kernel against it,
// then copy only the valid (jrr-jll) x (irr-ill) region back.
fo(p, 0, jrr - jll, 1) mcpy(iC + p * id, C_offset + p * n, (irr - ill) << 3);
double* r_C_offset = C_offset; // save real destination
C_offset = iC;
LD_iC4(0);
LD_iC4(1);
LD_iC4(2);
LD_iC4(3);
const double* A_offset = _A + (ill - il) * kc;
const double* B_offset = B + jll * n + kl;
fo(kll, kl, kr, kd) {
__builtin_prefetch(A_offset + id * kd);
__builtin_prefetch(A_offset + id * kd + id);
LD_A4(0);
LD_A4(1);
LD_A4(2);
LD_A4(3);
CAL4_iK8(0);
CAL4_DEF_iK8(1);
CAL4_DEF_iK8(2);
CAL4_DEF_iK8(3);
A_offset += id * kd;
B_offset += kd;
}
ST_iC4(0);
ST_iC4(1);
ST_iC4(2);
ST_iC4(3);
C_offset = r_C_offset; // restore and write the valid region back
fo(p, 0, jrr - jll, 1) mcpy(C_offset + p * n, iC + p * id, (irr - ill) << 3);
}
}
}
}
}
}
}
void matrix_multiply(int n, const double* A, const double* B, double* C) { third_try(n, A, B, C); }
/*
 * Judge output pasted below the source (not C -- kept as a comment so the
 * file remains compilable):
 * | Compilation | N/A       | N/A           | Compile OK   | Score: N/A |
 * | Testcase #1 | 72.013 ms | 8 MB + 520 KB | Wrong Answer | Score: 0   |
 */