#pragma GCC optimize("Ofast,no-stack-protector")
#pragma GCC target("avx2,fma")
#include <string.h>
#include <x86intrin.h>
// Writes into dst a copy of the n-element row src rotated right by `step`
// positions: the last `step` elements of src land at the front of dst and the
// first n - step elements follow them. Callers must keep 0 <= step <= n and
// src/dst must not overlap (memcpy is used for both pieces).
static void rotate_row(int n, int step, const double *src, double *dst) {
    const int head_len = n - step;            // elements that slide towards the end
    const double *wrapped = src + head_len;   // tail of src that wraps to the front

    memcpy(dst, wrapped, step * sizeof(double));
    memcpy(dst + step, src, head_len * sizeof(double));
}
// Builds in dst a copy of the n x n row-major matrix src in which row r is
// rotated right by (r * j_step) & 511 elements. The 511 mask is hard-coded, so
// the shift always lies in [0, 511] regardless of n.
static void rotate(int n, int j_step, const double *src, double *dst) {
    int row = 0;
    while (row < n) {
        const int shift = (row * j_step) & 511;
        const int row_offset = row * n;
        rotate_row(n, shift, src + row_offset, dst + row_offset);
        ++row;
    }
}
// Computes C = A * _B for n x n row-major matrices of doubles using a blocked,
// register-tiled AVX2/FMA kernel.
//
//   n   matrix dimension
//   A   left operand, read one scalar at a time
//   _B  right operand; copied (row-rotated) into a static scratch buffer first
//   C   output; fully overwritten (zeroed, then accumulated into)
//
// NOTE(review): as written this looks correct only for n == 512, which is what
// the visible caller passes: the scratch buffer holds 512 * 512 doubles, the
// rotation masks are & 511, i_end is not clamped to n (n must be a multiple of
// 256) and every column tile is a full 64 doubles wide.
// C and the scratch rows are accessed through __m256d pointers, so they are
// assumed to be 32-byte aligned.
// The static scratch buffer makes the function non-reentrant.
inline void cat(int n, const double *A, const double *_B, double *C) {
// Tile sizes: rows of C per outer block, k range per block, columns per tile.
const int i_step = 256;
const int k_step = 32;
const int j_step = 64; // 16 * 4: sixteen __m256d vectors of four doubles each
// Scratch copy of _B in which row k is rotated right by (k * 64) & 511.
// Presumably this staggers where consecutive rows of a column tile sit in
// memory to reduce cache-set conflicts — TODO confirm.
static double B[512 * 512];
rotate(n, j_step, _B, B);
// The kernel below accumulates into C, so start from zero.
memset(C, 0, n * n * sizeof(double));
for (int i_start = 0; i_start < n; i_start += i_step) {
// Unlike k_end below, this bound is not clamped to n.
int i_end = i_start + i_step;
for (int k_start = 0; k_start < n; k_start += k_step) {
int k_end = k_start + k_step <= n ? k_start + k_step : n;
for (int j_start = 0; j_start < n; j_start += j_step) {
for (int i = i_start; i < i_end; i++) {
const double *ai = A + i * n;
double *ci = C + i * n;
// Start of the 64-column strip of row i of C handled by this iteration.
double *ci_s = ci + j_start;
// Unrolling helpers: LOOP(f) expands to f(0) f(1) ... f(15).
// LOOP8 is defined but not used in the visible code.
// None of these macros is #undef'd, so they remain defined after this function.
#define LOOP8(f) f(0) f(1) f(2) f(3) f(4) f(5) f(6) f(7)
#define LOOP16(f) f(0) f(1) f(2) f(3) f(4) f(5) f(6) f(7) \
f(8) f(9) f(10) f(11) f(12) f(13) f(14) f(15)
#define LOOP(f) LOOP16(f)
// CI(i): the i-th 4-double vector of the C strip in memory.
#define CI(i) (* (__m256d *) (ci_s + (i) * 4))
// CI_r(i): name of the local __m256d accumulator holding CI(i) (ci_0 .. ci_15).
#define CI_r(i) ci_##i
#define load(i) __m256d CI_r(i) = CI(i);
// Load the 16 accumulators (64 doubles of C) into locals.
LOOP(load)
for (int k = k_start; k < k_end; k++) {
// Undo the rotation applied to row k: original column j_start now lives at
// (j_start + k * 64) & 511. Both terms are multiples of 64, so the 64-wide
// strip never wraps around the end of a 512-element row.
int j_actual = (j_start + k * 64) & 511;
const double *bk = B + k * n;
const double *bk_s = bk + j_actual;
const double aik = ai[k];
// Broadcast A[i][k] to all four lanes.
__m256d K = _mm256_set1_pd(aik);
// BK(i): the i-th 4-double vector of the matching strip of row k of B.
#define BK(i) (* (__m256d *) (bk_s + (i) * 4))
// Fused multiply-add: accumulator += B strip * A[i][k].
#define add(i) CI_r(i) = _mm256_fmadd_pd(BK(i), K, CI_r(i));
LOOP(add)
}
#define store(i) CI(i) = CI_r(i);
// Write the 16 accumulators back to C.
LOOP(store)
}
}
}
}
}
// NOTE(review): object-like macro with a one-letter name. From this line on,
// every identifier spelled n in this translation unit is replaced by 1024.
// It is not referenced by the code visible after this point.
#define n 1024
// Flattens (row, column) into an offset using a row stride of 512.
// Not used in the visible code.
#define idx(i, j) (((i) * 512) + (j))
// M: number of rows and columns processed by the convert/iconvert helpers.
// R: row stride (in doubles) of the compact buffers those helpers read or
//    write; also used to size the scratch arrays in matrix_multiply.
const int M = 512, R=512;
// Copies an M x M tile out of a matrix whose rows are 1024 doubles apart into
// a compact buffer whose rows are R doubles apart. src points at the first
// element of the tile.
static void convert(const double* __restrict__ src, double* __restrict__ dst) {
    const size_t row_bytes = sizeof(double) * M;
    for (int row = 0; row < M; ++row) {
        const double *in = src + row * 1024;
        double *out = dst + row * R;
        memcpy(out, in, row_bytes);
    }
}
// Writes the element-wise sum of two M x M tiles into a compact buffer.
// Both source tiles use a row stride of 1024 doubles; dst uses a row stride
// of R doubles.
static void convertA(const double* __restrict__ src1, const double* __restrict__ src2, double* __restrict__ dst) {
    for (int row = 0; row < M; ++row) {
        const double *lhs = src1 + row * 1024;
        const double *rhs = src2 + row * 1024;
        double *out = dst + row * R;
        for (int col = 0; col < M; ++col) {
            out[col] = lhs[col] + rhs[col];
        }
    }
}
// Writes the element-wise difference src1 - src2 of two M x M tiles into a
// compact buffer. Both source tiles use a row stride of 1024 doubles; dst uses
// a row stride of R doubles.
static void convertS(const double* __restrict__ src1, const double* __restrict__ src2, double* __restrict__ dst) {
    for (int row = 0; row < M; ++row) {
        const double *minuend = src1 + row * 1024;
        const double *subtrahend = src2 + row * 1024;
        double *out = dst + row * R;
        for (int col = 0; col < M; ++col) {
            out[col] = minuend[col] - subtrahend[col];
        }
    }
}
// Inverse of convert(): copies a compact M x M buffer (row stride R) into an
// M x M tile of a matrix whose rows are 1024 doubles apart, overwriting the
// tile. dst points at the first element of the tile.
static void iconvert(const double* __restrict__ src, double* __restrict__ dst) {
    const size_t row_bytes = sizeof(double) * M;
    for (int row = 0; row < M; ++row) {
        const double *in = src + row * R;
        double *out = dst + row * 1024;
        memcpy(out, in, row_bytes);
    }
}
// Adds a compact M x M buffer (row stride R) element-wise into an M x M tile
// of a matrix whose rows are 1024 doubles apart.
static void iconvertA(const double* __restrict__ src, double* __restrict__ dst) {
    for (int row = 0; row < M; ++row) {
        const double *in = src + row * R;
        double *out = dst + row * 1024;
        for (int col = 0; col < M; ++col) {
            out[col] += in[col];
        }
    }
}
// Subtracts a compact M x M buffer (row stride R) element-wise from an M x M
// tile of a matrix whose rows are 1024 doubles apart.
static void iconvertS(const double* __restrict__ src, double* __restrict__ dst) {
    for (int row = 0; row < M; ++row) {
        const double *in = src + row * R;
        double *out = dst + row * 1024;
        for (int col = 0; col < M; ++col) {
            out[col] -= in[col];
        }
    }
}
void matrix_multiply(int, const double* _A, const double* _B, double* _C) {
double A[7][R * R], B[7][R * R], C[7][R * R] __attribute__((aligned(4096)));
const int DL = 512*1024, DR=512;
memset(C, 0, sizeof(C));
#define QAQ(k, l) cat(512, A[k], B[k], C[l]);
convertA(_A, _A+DL+DR, A[0]); convertA(_B, _B+DL+DR, B[0]); QAQ(0, 0)
convertA(_A+DL, _A+DL+DR, A[1]); convert(_B, B[1]); QAQ(1, 1)
convert(_A, A[2]); convertS(_B+DR, _B+DR+DL, B[2]); QAQ(2, 2)
convert(_A+DR+DL, A[3]); convertS(_B+DL, _B, B[3]); QAQ(3, 3)
convertA(_A, _A+DR, A[4]); convert(_B+DL+DR, B[4]); QAQ(4, 4)
convertS(_A+DL, _A, A[5]); convertA(_B, _B+DR, B[5]); QAQ(5, 5)
convertS(_A+DR, _A+DL+DR, A[6]); convertA(_B+DL, _B+DL+DR, B[6]); QAQ(6, 6)
iconvert(C[0], _C); iconvertA(C[3], _C); iconvertS(C[4], _C); iconvertA(C[6], _C);
iconvert(C[2], _C+DR); iconvertA(C[4], _C+DR);
iconvert(C[1], _C+DL); iconvertA(C[3], _C+DL);
iconvert(C[0], _C+DL+DR); iconvertS(C[1], _C+DL+DR); iconvertA(C[2], _C+DL+DR); iconvertA(C[5], _C+DL+DR);
}
// Judge output:
// | Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// | Testcase #1 | 79.179 ms | 52 MB + 16 KB | Accepted | Score: 100 | Show more |