// ikj registers + L1
#pragma GCC target("avx2,fma")
#include <string.h>
#include <x86intrin.h>
/*
 * Copy the n-element row `src` into `dst`, rotated right by `step` slots:
 * src[j] lands at dst[j + step], and the last `step` elements of `src`
 * wrap around to the front of `dst`.
 *
 * Preconditions: 0 <= step <= n, and the two rows must not overlap
 * (the copy is done with memcpy). Nothing is checked here.
 */
static void rotate_row(int n, int step, const double *src, double *dst) {
    const int kept = n - step;           /* leading elements that do not wrap */
    const double *wrapped = src + kept;  /* trailing `step` elements that do  */

    memcpy(dst + step, src, kept * sizeof(double));
    memcpy(dst, wrapped, step * sizeof(double));
}
/*
 * Produce a row-rotated copy of the n x n row-major matrix `src` in `dst`:
 * row r is rotated right by (r * j_step) & 1023 elements via rotate_row().
 *
 * NOTE(review): the 1023 mask is hard-coded, so a shift can be as large as
 * 1023 regardless of n; this looks like it assumes n == 1024 -- confirm
 * against the caller before using other sizes.
 */
static void rotate(int n, int j_step, const double *src, double *dst) {
    const double *in_row = src;
    double *out_row = dst;

    for (int row = 0; row < n; row++, in_row += n, out_row += n) {
        const int shift = (row * j_step) & 1023;
        rotate_row(n, shift, in_row, out_row);
    }
}
/*
 * 16-way unrolling helpers for the AVX2 kernel in matrix_multiply().
 * They expand inside that function and refer to its locals `c_tile`,
 * `b_tile` and `a_ik`. All of them are #undef'd right after the function so
 * they do not leak into the rest of the translation unit.
 */
#define MM_FOR16(f) f(0) f(1) f(2) f(3) f(4) f(5) f(6) f(7) \
                    f(8) f(9) f(10) f(11) f(12) f(13) f(14) f(15)
#define MM_ACC(v) mm_acc_##v
#define MM_LOAD(v) __m256d MM_ACC(v) = _mm256_loadu_pd(c_tile + (v) * 4);
#define MM_FMA(v) \
    MM_ACC(v) = _mm256_fmadd_pd(_mm256_loadu_pd(b_tile + (v) * 4), a_ik, MM_ACC(v));
#define MM_STORE(v) _mm256_storeu_pd(c_tile + (v) * 4, MM_ACC(v));

/*
 * Compute C = A * _B for n x n row-major matrices of doubles.
 *
 *   n   matrix dimension; n <= 0 is a no-op
 *   A   left operand, n*n doubles (read only)
 *   _B  right operand, n*n doubles (read only)
 *   C   output, n*n doubles; overwritten. Must not overlap A or _B.
 *
 * Two code paths:
 *   - n == 1024: blocked i/k/j kernel that keeps a 1 x 64 tile of C in sixteen
 *     AVX accumulators while sweeping a block of k. It works on a copy of _B
 *     whose row k is rotated right by (k * 64) & 1023 elements, and undoes
 *     that shift when addressing the tile. (Presumably the rotation is there
 *     to spread consecutive rows across cache sets -- the code itself only
 *     shows the shift.)
 *   - any other n: plain scalar i-k-j loops. The blocked kernel's masks, tile
 *     sizes and scratch buffer are all hard-wired to 1024, so running it on
 *     another size would read and write out of bounds.
 *
 * The rotated copy lives in a function-static buffer, so the n == 1024 path
 * is not reentrant. Unaligned load/store intrinsics are used, so neither C nor
 * the scratch buffer needs any particular alignment.
 */
void matrix_multiply(int n, const double *A, const double *_B, double *C) {
    enum {
        FAST_N = 1024, /* the only size the blocked kernel supports        */
        I_STEP = 256,  /* rows of C per outer block                        */
        K_STEP = 32,   /* k values accumulated per register-resident tile  */
        J_STEP = 64    /* tile width: 16 AVX vectors * 4 doubles           */
    };
    static double B[FAST_N * FAST_N];

    if (n <= 0) {
        return;
    }
    /* Do the size arithmetic in size_t so n * n cannot overflow int. */
    const size_t dim = (size_t)n;

    if (n != FAST_N) {
        memset(C, 0, dim * dim * sizeof(double));
        for (size_t i = 0; i < dim; i++) {
            double *c_row = C + i * dim;
            for (size_t k = 0; k < dim; k++) {
                const double a = A[i * dim + k];
                const double *b_row = _B + k * dim;
                for (size_t j = 0; j < dim; j++) {
                    c_row[j] += a * b_row[j];
                }
            }
        }
        return;
    }

    /* Copy _B before clearing C, as the original ordering did. */
    rotate(n, J_STEP, _B, B);
    memset(C, 0, dim * dim * sizeof(double));

    /* FAST_N is a multiple of every tile size, so no tile needs clamping. */
    for (int i_start = 0; i_start < n; i_start += I_STEP) {
        const int i_end = i_start + I_STEP;
        for (int k_start = 0; k_start < n; k_start += K_STEP) {
            const int k_end = k_start + K_STEP;
            for (int j_start = 0; j_start < n; j_start += J_STEP) {
                for (int i = i_start; i < i_end; i++) {
                    const double *a_row = A + (size_t)i * dim;
                    double *c_tile = C + (size_t)i * dim + j_start;

                    /* Pull the 64-wide tile of C into registers. */
                    MM_FOR16(MM_LOAD)

                    for (int k = k_start; k < k_end; k++) {
                        /* Row k of B was shifted right by (k * J_STEP) & 1023,
                         * so original column j_start now sits at this offset.
                         * Both terms are multiples of 64, hence the 64-wide
                         * tile never crosses the end of the row. */
                        const int j_rot = (j_start + k * J_STEP) & (FAST_N - 1);
                        const double *b_tile = B + (size_t)k * dim + j_rot;
                        const __m256d a_ik = _mm256_set1_pd(a_row[k]);

                        MM_FOR16(MM_FMA)
                    }

                    /* Write the updated tile back to C. */
                    MM_FOR16(MM_STORE)
                }
            }
        }
    }
}

#undef MM_STORE
#undef MM_FMA
#undef MM_LOAD
#undef MM_ACC
#undef MM_FOR16
/*
 * Online-judge result pasted with the submission:
 * | Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
 * | Testcase #1 | 69.893 ms | 16 MB + 8 KB | Accepted | Score: 100 | Show more |
 */