#pragma GCC optimize("Ofast,no-stack-protector")
#pragma GCC target("avx2,fma")
#include <string.h>
#include <x86intrin.h>
/* Matrix dimension.  The code in this file works on square n x n matrices
 * stored row-major with n = 1024.  n is an enum constant rather than an
 * object-like macro so the one-letter name cannot silently rewrite every
 * later token spelled `n`; both n and idx() derive from LOG2_N so the
 * dimension and the index stride cannot drift apart. */
enum { LOG2_N = 10, n = 1 << LOG2_N };
/* Linear offset of element (i, j) in a row-major n x n matrix: i * n + j. */
#define idx(i, j) (((i) << LOG2_N) + (j))
/*
 * Accumulate C += A * B over one cube of the (row, inner, column) iteration
 * space, recursing into its eight half-size sub-cubes.
 *
 * The sub-cubes are visited in an order fixed by three direction vectors.
 * Consecutive sub-cubes differ by one step along a single direction, so
 * neighbouring calls reuse one of the A, B or C tiles.  This looks like a 3-D
 * Hilbert-curve ordering.  The visiting order only affects floating-point
 * rounding, not which products are summed.
 *
 * s          log2 of the cube's edge length.  Recursion stops at s == 5,
 *            i.e. 32 x 32 x 32 leaf blocks.
 * x, y, z    Low corner of the cube.  x indexes rows of A and C, y the shared
 *            inner dimension (columns of A, rows of B), and z columns of B
 *            and C.  The cube covers [x, x + 2^s) x [y, y + 2^s) x
 *            [z, z + 2^s) whatever the signs of the direction vectors.
 * (dx, dy, dz), (dx2, dy2, dz2), (dx3, dy3, dz3)
 *            Direction vectors d1, d2, d3.  They are assumed to be signed unit
 *            vectors along three distinct axes.  This holds for the call in
 *            matrix_multiply, and each level only permutes and negates them.
 * A, B       Input matrices, read only (row-major, addressed via idx()).
 * C          Output matrix.  Updated in place: leaves load, add to, and store
 *            back, so C must be initialised by the caller.
 */
static void gao(int s, long x, long y, long z, int dx, int dy, int dz, int dx2, int dy2, int dz2, int dx3, int dy3, int dz3, const double* __restrict__ A, const double* __restrict__ B, double* __restrict__ C) {
    enum {
        LEAF_LOG2 = 5,
        LEAF = 1 << LEAF_LOG2, /* leaf edge length: 32 */
        LEAF_VECS = LEAF / 4   /* 4-double AVX vectors per leaf row: 8 */
    };

    if (s == LEAF_LOG2) {
        /* Leaf kernel.  For each of the 32 rows of C, keep its 32-column slice
         * in eight vector accumulators while sweeping 32 inner indices.
         * Unaligned loads/stores are used because SOURCE never establishes
         * 32-byte alignment of B or C.  They cost nothing extra on aligned
         * addresses with AVX2-class hardware. */
        for (long i = 0; i < LEAF; ++i) {
            double *c_row = &C[idx(x + i, z)];
            const double *a_row = &A[idx(x + i, y)];
            __m256d acc[LEAF_VECS];
#pragma GCC unroll 8
            for (int v = 0; v < LEAF_VECS; ++v)
                acc[v] = _mm256_loadu_pd(c_row + 4 * v);
            for (long j = 0; j < LEAF; ++j) {
                const __m256d a = _mm256_set1_pd(a_row[j]);
                const double *b_row = &B[idx(y + j, z)];
#pragma GCC unroll 8
                for (int v = 0; v < LEAF_VECS; ++v)
                    acc[v] = _mm256_fmadd_pd(a, _mm256_loadu_pd(b_row + 4 * v), acc[v]);
            }
#pragma GCC unroll 8
            for (int v = 0; v < LEAF_VECS; ++v)
                _mm256_storeu_pd(c_row + 4 * v, acc[v]);
        }
        return;
    }

    --s;
    /* Offset of the far half along each direction vector at this level.
     * Multiplying by the half size (instead of writing `dx << s`) avoids
     * left-shifting negative values, which is undefined behaviour in C. */
    const long h = 1L << s;
    const long ax = dx * h,  ay = dy * h,  az = dz * h;   /* along d1 */
    const long bx = dx2 * h, by = dy2 * h, bz = dz2 * h;  /* along d2 */
    const long cx = dx3 * h, cy = dy3 * h, cz = dz3 * h;  /* along d3 */

    /* The origin passed in is the low corner.  A negative direction walks
     * from the upper half downward, so move the start point there first. */
    if (ax < 0) x -= ax;
    if (ay < 0) y -= ay;
    if (az < 0) z -= az;
    if (bx < 0) x -= bx;
    if (by < 0) y -= by;
    if (bz < 0) z -= bz;
    if (cx < 0) x -= cx;
    if (cy < 0) y -= cy;
    if (cz < 0) z -= cz;

    /* Visit the eight sub-cubes at offsets
     * 0, d1, d1+d2, d2, d2+d3, d1+d2+d3, d1+d3, d3,
     * each with the rotated/reflected directions that keep the walk connected. */
    gao(s, x, y, z,
        dx2, dy2, dz2, dx3, dy3, dz3, dx, dy, dz, A, B, C);
    gao(s, x + ax, y + ay, z + az,
        dx3, dy3, dz3, dx, dy, dz, dx2, dy2, dz2, A, B, C);
    gao(s, x + ax + bx, y + ay + by, z + az + bz,
        dx3, dy3, dz3, dx, dy, dz, dx2, dy2, dz2, A, B, C);
    gao(s, x + bx, y + by, z + bz,
        -dx, -dy, -dz, -dx2, -dy2, -dz2, dx3, dy3, dz3, A, B, C);
    gao(s, x + bx + cx, y + by + cy, z + bz + cz,
        -dx, -dy, -dz, -dx2, -dy2, -dz2, dx3, dy3, dz3, A, B, C);
    gao(s, x + ax + bx + cx, y + ay + by + cy, z + az + bz + cz,
        -dx3, -dy3, -dz3, dx, dy, dz, -dx2, -dy2, -dz2, A, B, C);
    gao(s, x + ax + cx, y + ay + cy, z + az + cz,
        -dx3, -dy3, -dz3, dx, dy, dz, -dx2, -dy2, -dz2, A, B, C);
    gao(s, x + cx, y + cy, z + cz,
        dx2, dy2, dz2, -dx3, -dy3, -dz3, -dx, -dy, -dz, A, B, C);
}
void matrix_multiply(int, const double* A, const double* B, double* C) {
memset(C, 0, 1024 * 1024 * sizeof(double));
gao(10, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, A, B, C);
}
/* Online-judge result for this submission:
| Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
| Testcase #1 | 142.438 ms | 8 MB + 8 KB | Accepted | Score: 100 | Show more |
*/