#pragma GCC optimize("Ofast,no-stack-protector")
#pragma GCC target("avx2,fma")
#include <stdlib.h>
#include <string.h>
#include <x86intrin.h>
/* Edge length of the base-case matrix: matrix_multiply() sends n == efn to
 * the tiled AVX kernel path and splits larger n recursively. */
#define efn 256
/* Row stride (leading dimension) of the padded base-case working copies made
 * by convert()/iconvert(). 276 * sizeof(double) = 2208 = 69 * 32 bytes, so each
 * row starts on a 32-byte boundary, as the aligned AVX loads/stores in
 * kernel32 require. The extra 20 columns are presumably there to avoid
 * power-of-two-stride cache-set conflicts - TODO confirm. */
#define lda 276
/* Row-major offset of element (i, j) in a buffer whose rows are lda apart. */
#define idx(i, j) (((i) * lda) + (j))
/* 32x32x32 tile update on the padded (row stride lda) buffers:
 *
 *   C[x..x+31][z..z+31] += A[x..x+31][y..y+31] * B[y..y+31][z..z+31]
 *
 * The 32 output columns are handled as two 16-column strips. For each strip,
 * the 32-deep reduction is walked in eight panels of 4 rows of B. A panel
 * (4 rows x 16 columns = 16 vectors) is loaded once. It is then applied to
 * every one of the 32 rows of C, one broadcast element of A per panel row.
 *
 * _mm256_load_pd/_mm256_store_pd need 32-byte aligned addresses. Callers must
 * therefore pass 32-byte aligned buffers and z a multiple of 4; lda * 8 is a
 * multiple of 32.
 */
static void kernel32(long x, long y, long z, const double* __restrict__ A, const double* __restrict__ B, double* __restrict__ C) {
    for (long strip = 0; strip < 2; ++strip) {
        const long col = z + 16 * strip;
        for (long panel = 0; panel < 8; ++panel) {
            const long row = y + 4 * panel;

            /* Panel of B: bp[r][q] holds B[row + r][col + 4q .. col + 4q + 3]. */
            __m256d bp[4][4];
            for (int r = 0; r < 4; ++r)
                for (int q = 0; q < 4; ++q)
                    bp[r][q] = _mm256_load_pd(&B[idx(row + r, col + 4 * q)]);

            for (long i = x; i < x + 32; ++i) {
                __m256d acc[4];
                for (int q = 0; q < 4; ++q)
                    acc[q] = _mm256_load_pd(&C[idx(i, col + 4 * q)]);
                /* Accumulate panel rows in order r = 0..3 for every vector. */
                for (int r = 0; r < 4; ++r) {
                    const __m256d a = _mm256_set1_pd(A[idx(i, row + r)]);
                    for (int q = 0; q < 4; ++q)
                        acc[q] += a * bp[r][q];
                }
                for (int q = 0; q < 4; ++q)
                    _mm256_store_pd(&C[idx(i, col + 4 * q)], acc[q]);
            }
        }
    }
}
/* Recursively walks a cube of edge 2^s in the C += A * B iteration space and
 * calls kernel32() once per 32x32x32 tile (the recursion stops at s == 5).
 * matrix_multiply() starts it with s = 8, i.e. the whole 256^3 base case.
 *
 * Axes: x indexes rows of A and C, y the reduction dimension (columns of A,
 * rows of B), and z the columns of B and C. On entry, (x, y, z) is the lowest
 * corner of the cube.
 *
 * (dx, dy, dz), (dx2, dy2, dz2) and (dx3, dy3, dz3) are signed unit axis
 * vectors describing the cube's orientation. The octants are visited in the
 * order 0, d, d+d2, d2, d2+d3, d+d2+d3, d+d3, d3, so each step moves along one
 * axis. The child orientations are rotated/reflected per call. This appears
 * to be a 3-D Hilbert curve, presumably chosen for cache locality.
 *
 * Each tile is visited exactly once. The order therefore affects only
 * locality and the floating-point summation order into C, not which products
 * are accumulated.
 */
static void gao(int s, int x, int y, int z, int dx, int dy, int dz, int dx2, int dy2, int dz2, int dx3, int dy3, int dz3, const double* __restrict__ A, const double* __restrict__ B, double* __restrict__ C) {
    if (s == 5) {
        kernel32(x, y, z, A, B, C);
        return;
    }
    --s;
    /* Edge length of each octant. Offsets are formed as d * h rather than
     * d << s because left-shifting a negative direction component is
     * undefined behaviour in ISO C (C11 6.5.7p4). */
    const int h = 1 << s;
    /* A negative direction means the walk along that axis starts in the upper
     * half. Move the start corner there, so every child corner computed below
     * is the low corner of its octant. */
    if (dx < 0) x -= dx * h;
    if (dy < 0) y -= dy * h;
    if (dz < 0) z -= dz * h;
    if (dx2 < 0) x -= dx2 * h;
    if (dy2 < 0) y -= dy2 * h;
    if (dz2 < 0) z -= dz2 * h;
    if (dx3 < 0) x -= dx3 * h;
    if (dy3 < 0) y -= dy3 * h;
    if (dz3 < 0) z -= dz3 * h;
    /* octant 0 */
    gao(s, x, y, z,
        dx2, dy2, dz2, dx3, dy3, dz3, dx, dy, dz, A, B, C);
    /* octant d */
    gao(s, x + dx * h, y + dy * h, z + dz * h,
        dx3, dy3, dz3, dx, dy, dz, dx2, dy2, dz2, A, B, C);
    /* octant d + d2 */
    gao(s, x + (dx + dx2) * h, y + (dy + dy2) * h, z + (dz + dz2) * h,
        dx3, dy3, dz3, dx, dy, dz, dx2, dy2, dz2, A, B, C);
    /* octant d2 */
    gao(s, x + dx2 * h, y + dy2 * h, z + dz2 * h,
        -dx, -dy, -dz, -dx2, -dy2, -dz2, dx3, dy3, dz3, A, B, C);
    /* octant d2 + d3 */
    gao(s, x + (dx2 + dx3) * h, y + (dy2 + dy3) * h, z + (dz2 + dz3) * h,
        -dx, -dy, -dz, -dx2, -dy2, -dz2, dx3, dy3, dz3, A, B, C);
    /* octant d + d2 + d3 */
    gao(s, x + (dx + dx2 + dx3) * h, y + (dy + dy2 + dy3) * h, z + (dz + dz2 + dz3) * h,
        -dx3, -dy3, -dz3, dx, dy, dz, -dx2, -dy2, -dz2, A, B, C);
    /* octant d + d3 */
    gao(s, x + (dx + dx3) * h, y + (dy + dy3) * h, z + (dz + dz3) * h,
        -dx3, -dy3, -dz3, dx, dy, dz, -dx2, -dy2, -dz2, A, B, C);
    /* octant d3 */
    gao(s, x + dx3 * h, y + dy3 * h, z + dz3 * h,
        dx2, dy2, dz2, -dx3, -dy3, -dz3, -dx, -dy, -dz, A, B, C);
}
/* Copies an efn x efn matrix stored densely (row stride efn) in src into dst,
 * whose rows are lda doubles apart. Only the first efn doubles of each dst
 * row are written; the padding columns keep whatever they held. */
static void convert(const double* __restrict__ src, double* __restrict__ dst) {
    const double *from = src;
    double *to = dst;
    for (int rows_left = efn; rows_left > 0; --rows_left) {
        memcpy(to, from, efn * sizeof *from);
        from += efn;
        to += lda;
    }
}
/* Inverse of convert(): gathers the first efn doubles of each of the first
 * efn rows of src (row stride lda) into the dense efn x efn matrix dst
 * (row stride efn). */
static void iconvert(const double* __restrict__ src, double* __restrict__ dst) {
    const double *from = src;
    double *to = dst;
    for (int rows_left = efn; rows_left > 0; --rows_left) {
        memcpy(to, from, efn * sizeof *to);
        from += lda;
        to += efn;
    }
}
/* Extracts quadrant (p, q), each 1 or 2, of the n x n row-major matrix A into
 * the dense (n/2) x (n/2) matrix B. For example, (1, 2) is the top-right
 * quadrant. */
static void submatrix(int n, const double* __restrict__ A, double* __restrict__ B, int p, int q) {
    const int half = n / 2;
    const int row_begin = (p - 1) * n / 2;
    const int row_end = p * n / 2;
    const int col_begin = (q - 1) * n / 2;
    double *dst = B;
    for (int row = row_begin; row < row_end; ++row, dst += half)
        memcpy(dst, &A[n * row + col_begin], half * sizeof(double));
}
/* Inverse of submatrix(): writes the dense (n/2) x (n/2) matrix B into
 * quadrant (p, q), each 1 or 2, of the n x n row-major matrix A. Other
 * quadrants of A are not touched. */
static void makematrix(int n, double* __restrict__ A, const double* __restrict__ B, int p, int q) {
    const int half = n / 2;
    const int row_begin = (p - 1) * n / 2;
    const int row_end = p * n / 2;
    const int col_begin = (q - 1) * n / 2;
    const double *src = B;
    for (int row = row_begin; row < row_end; ++row, src += half)
        memcpy(&A[n * row + col_begin], src, half * sizeof(double));
}
/* Element-wise sum C = A + B of two n x n matrices stored contiguously
 * (n * n doubles each); the three buffers must not overlap.
 *
 * Unlike the former hand-unrolled version, this is correct for any n:
 * - the old loop stepped 32 doubles at a time and ran past the end when
 *   n * n was not a multiple of 32;
 * - it required 32-byte aligned buffers because of its __m256d casts.
 * With restrict-qualified pointers and the file's Ofast/avx2 pragmas, GCC is
 * expected to auto-vectorise this loop. Each element is a single addition, so
 * results are identical. */
static void matrix_plus(int n, const double* __restrict__ A, const double* __restrict__ B, double* __restrict__ C) {
    const long total = (long)n * n; /* long: n * n in int overflows past 46340 */
    for (long i = 0; i < total; ++i)
        C[i] = A[i] + B[i];
}
/* Element-wise difference C = A - B of two n x n matrices stored contiguously
 * (n * n doubles each); the three buffers must not overlap.
 *
 * Unlike the former hand-unrolled version, this is correct for any n:
 * - the old loop stepped 32 doubles at a time and ran past the end when
 *   n * n was not a multiple of 32;
 * - it required 32-byte aligned buffers because of its __m256d casts.
 * With restrict-qualified pointers and the file's Ofast/avx2 pragmas, GCC is
 * expected to auto-vectorise this loop. Each element is a single
 * subtraction, so results are identical. */
static void matrix_minus(int n, const double* __restrict__ A, const double* __restrict__ B, double* __restrict__ C) {
    const long total = (long)n * n; /* long: n * n in int overflows past 46340 */
    for (long i = 0; i < total; ++i)
        C[i] = A[i] - B[i];
}
/* Fallback C = A * B for n x n row-major matrices (i-k-j loop order). It is
 * used for sizes the Strassen/kernel path cannot reach, and when the Strassen
 * scratch space cannot be allocated. Does nothing for n <= 0. */
static void matrix_multiply_naive(int n, const double* __restrict__ A, const double* __restrict__ B, double* __restrict__ C) {
    for (long i = 0; i < n; ++i) {
        double *Ci = C + i * n;
        for (long j = 0; j < n; ++j)
            Ci[j] = 0.0;
        for (long k = 0; k < n; ++k) {
            const double aik = A[i * n + k];
            const double *Bk = B + k * n;
            for (long j = 0; j < n; ++j)
                Ci[j] += aik * Bk[j];
        }
    }
}

/* Non-zero iff n == efn * 2^k for some k >= 0. Only for those sizes does
 * repeated halving land exactly on the n == efn base case. */
static int matrix_multiply_size_ok(int n) {
    if (n < efn || n % efn != 0)
        return 0;
    const int ratio = n / efn;
    return (ratio & (ratio - 1)) == 0;
}

/* C = A * B for n x n row-major matrices (contiguous, n * n doubles each,
 * non-overlapping).
 *
 * - n == efn: copy A and B into padded (row stride lda) buffers and run the
 *   tiled AVX kernel through gao().
 * - n == efn * 2^k, k >= 1: one level of Strassen's algorithm, i.e. seven
 *   half-size products computed recursively.
 * - any other n: plain triple loop. Halving would never reach the base case
 *   (it ran down to n == 0 and recursed without end).
 */
void matrix_multiply(int n, const double* __restrict__ A, const double* __restrict__ B, double* __restrict__ C) {
    if (n == efn) {
        /* Fixed-size base-case buffers, about 1.8 MB of stack in total. */
        double A_[lda * lda] __attribute__((aligned(4096))), B_[lda * lda] __attribute__((aligned(4096))), C_[lda * lda] __attribute__((aligned(4096)));
        convert(A, A_);
        convert(B, B_);
        memset(C_, 0, sizeof(C_));
        gao(8, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, A_, B_, C_); /* 2^8 == efn */
        iconvert(C_, C);
        return;
    }
    if (!matrix_multiply_size_ok(n)) {
        matrix_multiply_naive(n, A, B, C);
        return;
    }
    const int m = n / 2;
    const size_t s = (size_t)m * m;
    /* 17 m x m scratch matrices: 7 products, 8 quadrants and 2 temporaries.
     * They live on the heap because as VLAs they needed 17 * m * m doubles of
     * stack per recursion level (34 MiB at n = 1024).
     * Here m >= 128, so the byte size is a multiple of 4096, as C11
     * aligned_alloc requires, and every sub-array below stays 4096-aligned. */
    double *work = aligned_alloc(4096, 17 * s * sizeof *work);
    if (work == NULL) {
        matrix_multiply_naive(n, A, B, C);
        return;
    }
    double *M1 = work, *M2 = M1 + s, *M3 = M2 + s, *M4 = M3 + s, *M5 = M4 + s, *M6 = M5 + s, *M7 = M6 + s;
    double *a11 = M7 + s, *a12 = a11 + s, *a21 = a12 + s, *a22 = a21 + s;
    double *b11 = a22 + s, *b12 = b11 + s, *b21 = b12 + s, *b22 = b21 + s;
    double *t1 = b22 + s, *t2 = t1 + s;

    submatrix(n, A, a11, 1, 1);
    submatrix(n, A, a12, 1, 2);
    submatrix(n, A, a21, 2, 1);
    submatrix(n, A, a22, 2, 2);
    submatrix(n, B, b11, 1, 1);
    submatrix(n, B, b12, 1, 2);
    submatrix(n, B, b21, 2, 1);
    submatrix(n, B, b22, 2, 2);
    /* M1 = (a11 + a22)(b11 + b22) */
    matrix_plus(m, a11, a22, t1);
    matrix_plus(m, b11, b22, t2);
    matrix_multiply(m, t1, t2, M1);
    /* M2 = (a21 + a22) b11 */
    matrix_plus(m, a21, a22, t1);
    matrix_multiply(m, t1, b11, M2);
    /* M3 = a11 (b12 - b22) */
    matrix_minus(m, b12, b22, t2);
    matrix_multiply(m, a11, t2, M3);
    /* M4 = a22 (b21 - b11) */
    matrix_minus(m, b21, b11, t2);
    matrix_multiply(m, a22, t2, M4);
    /* M5 = (a11 + a12) b22 */
    matrix_plus(m, a11, a12, t1);
    matrix_multiply(m, t1, b22, M5);
    /* M6 = (a21 - a11)(b11 + b12) */
    matrix_minus(m, a21, a11, t1);
    matrix_plus(m, b11, b12, t2);
    matrix_multiply(m, t1, t2, M6);
    /* M7 = (a12 - a22)(b21 + b22) */
    matrix_minus(m, a12, a22, t1);
    matrix_plus(m, b21, b22, t2);
    matrix_multiply(m, t1, t2, M7);
    /* C11 = M1 + M7 - M5 + M4, stored in M7 */
    matrix_plus(m, M1, M7, t1);
    matrix_minus(m, t1, M5, t2);
    matrix_plus(m, t2, M4, M7);
    /* C22 = M1 + M6 + M3 - M2, stored in M6 */
    matrix_plus(m, M1, M6, t1);
    matrix_plus(m, t1, M3, t2);
    matrix_minus(m, t2, M2, M6);
    /* C21 = M2 + M4, stored in M1 (M1 is no longer needed) */
    matrix_plus(m, M2, M4, M1);
    /* C12 = M3 + M5, stored in M2 (M2 is no longer needed) */
    matrix_plus(m, M3, M5, M2);
    makematrix(n, C, M7, 1, 1);
    makematrix(n, C, M2, 1, 2);
    makematrix(n, C, M1, 2, 1);
    makematrix(n, C, M6, 2, 2);
    free(work);
}
/* Online-judge result for this submission:
 * | Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
 * | Testcase #1 | 80.378 ms | 55 MB + 936 KB | Accepted | Score: 100 | Show more |
 */