#pragma GCC target("avx,fma,avx512f")
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
#include <emmintrin.h>
#include <immintrin.h>
// 64-byte vector of doubles (8 lanes), declared with GCC's vector_size
// extension. Arithmetic on it is element-wise.
typedef double vec __attribute__((vector_size(64)));
// micro kernel
// update c[x : x + 6][y : y + 16]
// using a[x : x + 6][l : r] and b[l : r][y : y + 16]
// Helper macros for the micro kernel. They paste identifiers together and
// rely on these names being in scope where they are expanded:
// a, c, x, y, n, k, b0, b1 and the accumulators t<row><col>.
//
// _upd_t(idx): one k-step for tile row idx. Declares a<idx> as a vec whose
// lanes all hold the scalar a[(x + idx) * n + k] (vector + scalar broadcasts
// the scalar under GCC's vector extension), then fused-multiply-adds it with
// b0 and b1 into the accumulators t<idx>0 and t<idx>1.
#define _upd_t(idx) \
vec a##idx = vec{} + a[(x + idx) * n + k]; \
t##idx##0 = _mm512_fmadd_pd(a##idx, b0, t##idx##0); \
t##idx##1 = _mm512_fmadd_pd(a##idx, b1, t##idx##1);
// _set_t(i, j): declares accumulator t<i><j> and loads it from c, which is
// indexed in units of vec; ((x + i) * n + y) / 8 is the vec index of element
// (x + i, y) and j selects the first or second vec of the tile row.
#define _set_t(i, j) vec t##i##j = c[((x + i) * n + y) / 8 + j];
// _save_t(i, j): writes accumulator t<i><j> back to the slot _set_t read.
#define _save_t(i, j) c[((x + i) * n + y) / 8 + j] = t##i##j;
// clang-format off
// Micro kernel: accumulates into the 6 x 16 tile of c whose top-left element
// is (x, y):
//   c[x + i][y + j] += sum over k in [l, r) of a[x + i][k] * b[k][y + j]
// for i in [0, 6) and j in [0, 16).
//
//   a     row-major doubles, row stride n
//   b, c  row-major doubles with row stride n, addressed as arrays of vec
//         (8 doubles each)
//   x, y  top-left row / column of the tile in c
//   l, r  half-open range of the shared dimension k to accumulate over
//   n     row stride in doubles
//
// The vec index is computed with integer division by 8, so (row * n + y) is
// presumably expected to be a multiple of 8 (callers in this file pass y as a
// multiple of 16). NOTE(review): b and c are dereferenced as vec, so they
// likely need vec's alignment (64 bytes) -- confirm every caller provides it.
inline void kernel_6x16(double *a, vec *b, vec *c, int x, int y, int l, int r, int n) {
// Load the 6 x 2 vec accumulators (12 vectors) from c once, up front.
_set_t(0, 0); _set_t(0, 1); _set_t(1, 0); _set_t(1, 1);
_set_t(2, 0); _set_t(2, 1); _set_t(3, 0); _set_t(3, 1);
_set_t(4, 0); _set_t(4, 1); _set_t(5, 0); _set_t(5, 1);
for (int k = l; k < r; k++) {
// Row k of b, columns [y, y + 16), as two vecs shared by all six tile rows.
vec b0 = b[(k * n + y) / 8], b1 = b[(k * n + y) / 8 + 1];
_upd_t(0); _upd_t(1); _upd_t(2);
_upd_t(3); _upd_t(4); _upd_t(5);
}
// Write the accumulators back to c only after the whole k range is done.
_save_t(0, 0); _save_t(0, 1); _save_t(1, 0); _save_t(1, 1);
_save_t(2, 0); _save_t(2, 1); _save_t(3, 0); _save_t(3, 1);
_save_t(4, 0); _save_t(4, 1); _save_t(5, 0); _save_t(5, 1);
}
// clang-format on
// Compile-time constants.
// n, N : matrix dimension constant (1024); N is an alias of n.
// s1   : tile extent along the shared (k) dimension in matrix_multiply.
// s2   : tile extent along the rows of c (72 = 12 * 6 kernel rows).
// s3   : tile extent along the columns of c (32 = 2 * 16 kernel columns).
constexpr int n = 1024;
constexpr int N = n;
constexpr int s1 = 120;
constexpr int s2 = 72;
constexpr int s3 = 32;
void matrix_multiply(int _n, const double *_a, const double *_b, double *_c) {
int nx = (n + 5) / 6 * 6, ny = n;
double a[nx * ny], b[nx * ny], c[nx * ny];
memset(a, 0, sizeof a);
memset(b, 0, sizeof b);
memset(c, 0, sizeof c);
for (int i = 0; i < n; i++) {
memcpy(&a[i * ny], _a + n * i, n * 8);
memcpy(&b[i * ny], _b + n * i, n * 8);
}
for (int i3 = 0; i3 < ny; i3 += s3)
// b[:][i3 : i3 + s3]
for (int i2 = 0; i2 < nx; i2 += s2)
// a[i2 : i2 + s2][:]
for (int i1 = 0; i1 < ny; i1 += s1)
// b[i1 : i1 + s1][i3 : i3 + s3]
// update c[i2 : i2 + s2][i3 : i3+s3] with [l : r] = [i1 : i1 + s1]
for (int x = i2; x < std::min(i2 + s2, nx); x += 6)
for (int y = i3; y < std::min(i3 + s3, ny); y += 16)
kernel_6x16(a, (vec *)b, (vec *)c, x, y, i1, std::min(i1 + s1, n),
ny);
for (int i = 0; i < n; i++)
memcpy(_c + n * i, &c[i * ny], n * 8);
std::free(a), std::free(b), std::free(c);
}
// Judge output pasted with the original submission:
//   Compilation | N/A      | N/A           | Compile OK    | Score: N/A | show more
//   Testcase #1 | 5.475 ms | 24 MB + 80 KB | Runtime Error | Score: 0   | show more