// modified from https://duck.ac/submission/13766
#pragma GCC target("avx2,fma")
#pragma GCC optimize("Ofast")
#include <string.h>
#include <x86intrin.h>
/* Matrix order the tiled code below is specialised for. NOTE: this macro is
 * #undef'd further down (before gao) so that matrix_multiply's parameter
 * `n` is not macro-replaced. */
#define n 1024
/* Row strides, in doubles, of the padded copies of A and B made by
 * init_arr_a / init_arr_b. 1056 * 8 and 1088 * 8 bytes are both multiples
 * of 32, so every padded row starts 32 bytes after an aligned base stays
 * aligned. The reason for the extra 32 / 64 doubles is not stated here —
 * presumably to stagger rows across cache sets; NOTE(review): confirm.
 * These must stay written as literals, not in terms of `n`: `n` is
 * #undef'd before matrix_multiply expands them. */
#define n_pad_a (1024 + 32)
#define n_pad_b (1024 + 64)
/*
 * Copy exactly 1024 doubles (8 KiB) from src to dst.
 *
 * dst and src must not overlap; no particular alignment is required.
 *
 * A plain memcpy replaces the previous hand-written asm, which was unsound:
 *   - it loaded the pointers into r13-r15 in one asm statement and used them
 *     in another, but GCC does not preserve registers between separate asm
 *     statements (and could even pick r14/r15 for the input operands);
 *   - it had no "memory" clobber although it read src and wrote dst;
 *   - it used vmovapd, which faults unless both pointers are 32-byte
 *     aligned, and src comes from caller-supplied matrices.
 * C libraries typically ship vectorised memcpy implementations, so this is
 * not expected to be measurably slower.
 */
static inline void memcpy_1024(double *dst, const double *src) {
    memcpy(dst, src, 1024 * sizeof *dst);
}
/*
 * Copy the dense n x n matrix `src` (row stride n) into `dst`, whose rows
 * are n_pad_a doubles apart. Only the first 1024 doubles of each dst row
 * are written (one memcpy_1024 per row); the padding is left untouched.
 */
static inline void init_arr_a(const double *src, double *dst) {
    const double *from = src;
    double *to = dst;
    for (int rows_left = n; rows_left > 0; rows_left--) {
        memcpy_1024(to, from);
        from += n;
        to += n_pad_a;
    }
}
/*
 * Copy the dense n x n matrix `src` (row stride n) into `dst`, placing
 * source row r at dst + r * n_pad_b. Only the first 1024 doubles of each
 * dst row are written; the padding is left untouched.
 */
static inline void init_arr_b(const double *src, double *dst) {
    int row = 0;
    while (row < n) {
        double *dst_row = dst + row * n_pad_b;
        const double *src_row = src + row * n;
        memcpy_1024(dst_row, src_row);
        ++row;
    }
}
/*
 * Accumulate one 32x32x32 tile of C += A * B:
 *
 *   C[i][j] += sum_{k = k_start}^{k_start+31} A[i][k] * B[k][j]
 *   for i in [i_start, i_start + 32), j in [j_start, j_start + 32).
 *
 * Row strides, in doubles: A uses n_pad_a, B uses n_pad_b, C uses n.
 * Requires AVX and FMA3 (256-bit ymm registers, vfmadd231pd).
 *
 * The tile is processed as 4x8 micro-tiles. Per micro-tile:
 *   ymm0..ymm7   C rows i..i+3, columns j..j+7 (two ymm per row);
 *   ymm10..ymm13 A[i..i+3][k] broadcast to all lanes;
 *   ymm14/15 and ymm8/9 alternately hold B[k][j..j+7] for consecutive k,
 *   so the next row of B is loaded while the current one is consumed.
 * Each loop iteration handles two k values; 15 iterations cover 30 of
 * them and the code after the loop handles the last two.
 *
 * The whole micro-kernel is one asm statement with named operands. GCC
 * does not preserve registers between separate asm statements, so the
 * pointers must be operands rather than values parked in fixed registers.
 * "memory" tells the compiler that A, B and C are read and C is written.
 * vmovupd is used instead of vmovapd so that a C (or B) that is not
 * 32-byte aligned does not fault.
 */
static inline void kernel_32_32_32(const double *A, const double *B, double *C,
                                   int i_start, int k_start, int j_start) {
    const int i_end = i_start + 32;
    const int j_end = j_start + 32;
    for (int i = i_start; i < i_end; i += 4) {
        /* &A[i][k_start]; rows i+1..i+3 are reached via the ra* offsets. */
        const double *a_row = A + i * n_pad_a + k_start;
        for (int j = j_start; j < j_end; j += 8) {
            double *c_tile = C + i * n + j;                   /* &C[i][j]       */
            const double *b_ptr = B + k_start * n_pad_b + j;  /* &B[k_start][j] */
            const double *a_ptr = a_row;
            long long pairs = 15; /* 15 x 2 k-steps in the loop + 2 after it = 32 */
            __asm__ volatile(
                /* Load the 4x8 C micro-tile. */
                "vmovupd (%[c]), %%ymm0\n\t"               /* c00 */
                "vmovupd 32(%[c]), %%ymm1\n\t"             /* c01 */
                "vmovupd %c[rc](%[c]), %%ymm2\n\t"         /* c10 */
                "vmovupd %c[rc]+32(%[c]), %%ymm3\n\t"      /* c11 */
                "vmovupd %c[rc2](%[c]), %%ymm4\n\t"        /* c20 */
                "vmovupd %c[rc2]+32(%[c]), %%ymm5\n\t"     /* c21 */
                "vmovupd %c[rc3](%[c]), %%ymm6\n\t"        /* c30 */
                "vmovupd %c[rc3]+32(%[c]), %%ymm7\n\t"     /* c31 */
                /* Operands for k = k_start. */
                "vmovupd (%[b]), %%ymm14\n\t"              /* b0 */
                "vmovupd 32(%[b]), %%ymm15\n\t"            /* b1 */
                "vbroadcastsd (%[a]), %%ymm10\n\t"         /* a0 */
                "vbroadcastsd %c[ra](%[a]), %%ymm11\n\t"   /* a1 */
                "vbroadcastsd %c[ra2](%[a]), %%ymm12\n\t"  /* a2 */
                "vbroadcastsd %c[ra3](%[a]), %%ymm13\n\t"  /* a3 */
                ".align 8\n"
                "1:\n\t"
                /* Step k: load B[k+1] into ymm8/9, multiply with B[k]. */
                "vmovupd %c[rb](%[b]), %%ymm8\n\t"
                "vmovupd %c[rb]+32(%[b]), %%ymm9\n\t"
                "vfmadd231pd %%ymm10, %%ymm14, %%ymm0\n\t"
                "vfmadd231pd %%ymm10, %%ymm15, %%ymm1\n\t"
                "vbroadcastsd 8(%[a]), %%ymm10\n\t"
                "vfmadd231pd %%ymm11, %%ymm14, %%ymm2\n\t"
                "vfmadd231pd %%ymm11, %%ymm15, %%ymm3\n\t"
                "vbroadcastsd %c[ra]+8(%[a]), %%ymm11\n\t"
                "vfmadd231pd %%ymm12, %%ymm14, %%ymm4\n\t"
                "vfmadd231pd %%ymm12, %%ymm15, %%ymm5\n\t"
                "vbroadcastsd %c[ra2]+8(%[a]), %%ymm12\n\t"
                "vfmadd231pd %%ymm13, %%ymm14, %%ymm6\n\t"
                "vfmadd231pd %%ymm13, %%ymm15, %%ymm7\n\t"
                "vbroadcastsd %c[ra3]+8(%[a]), %%ymm13\n\t"
                /* Advance both pointers by two k. */
                "addq %[rb2], %[b]\n\t"
                "addq $16, %[a]\n\t"
                /* Step k+1: load B[k+2] into ymm14/15, multiply with B[k+1]. */
                "vmovupd (%[b]), %%ymm14\n\t"
                "vmovupd 32(%[b]), %%ymm15\n\t"
                "vfmadd231pd %%ymm10, %%ymm8, %%ymm0\n\t"
                "vfmadd231pd %%ymm10, %%ymm9, %%ymm1\n\t"
                "vbroadcastsd (%[a]), %%ymm10\n\t"
                "vfmadd231pd %%ymm11, %%ymm8, %%ymm2\n\t"
                "vfmadd231pd %%ymm11, %%ymm9, %%ymm3\n\t"
                "vbroadcastsd %c[ra](%[a]), %%ymm11\n\t"
                "vfmadd231pd %%ymm12, %%ymm8, %%ymm4\n\t"
                "vfmadd231pd %%ymm12, %%ymm9, %%ymm5\n\t"
                "vbroadcastsd %c[ra2](%[a]), %%ymm12\n\t"
                "vfmadd231pd %%ymm13, %%ymm8, %%ymm6\n\t"
                "vfmadd231pd %%ymm13, %%ymm9, %%ymm7\n\t"
                "vbroadcastsd %c[ra3](%[a]), %%ymm13\n\t"
                "decq %[cnt]\n\t"
                "jne 1b\n\t"
                /* Last two k values (k_start + 30, k_start + 31). */
                "vmovupd %c[rb](%[b]), %%ymm8\n\t"
                "vmovupd %c[rb]+32(%[b]), %%ymm9\n\t"
                "vfmadd231pd %%ymm10, %%ymm14, %%ymm0\n\t"
                "vfmadd231pd %%ymm10, %%ymm15, %%ymm1\n\t"
                "vbroadcastsd 8(%[a]), %%ymm10\n\t"
                "vfmadd231pd %%ymm11, %%ymm14, %%ymm2\n\t"
                "vfmadd231pd %%ymm11, %%ymm15, %%ymm3\n\t"
                "vbroadcastsd %c[ra]+8(%[a]), %%ymm11\n\t"
                "vfmadd231pd %%ymm12, %%ymm14, %%ymm4\n\t"
                "vfmadd231pd %%ymm12, %%ymm15, %%ymm5\n\t"
                "vbroadcastsd %c[ra2]+8(%[a]), %%ymm12\n\t"
                "vfmadd231pd %%ymm13, %%ymm14, %%ymm6\n\t"
                "vfmadd231pd %%ymm13, %%ymm15, %%ymm7\n\t"
                "vbroadcastsd %c[ra3]+8(%[a]), %%ymm13\n\t"
                "vfmadd231pd %%ymm10, %%ymm8, %%ymm0\n\t"
                "vfmadd231pd %%ymm10, %%ymm9, %%ymm1\n\t"
                "vfmadd231pd %%ymm11, %%ymm8, %%ymm2\n\t"
                "vfmadd231pd %%ymm11, %%ymm9, %%ymm3\n\t"
                "vfmadd231pd %%ymm12, %%ymm8, %%ymm4\n\t"
                "vfmadd231pd %%ymm12, %%ymm9, %%ymm5\n\t"
                "vfmadd231pd %%ymm13, %%ymm8, %%ymm6\n\t"
                "vfmadd231pd %%ymm13, %%ymm9, %%ymm7\n\t"
                /* Store the updated C micro-tile. */
                "vmovupd %%ymm0, (%[c])\n\t"
                "vmovupd %%ymm1, 32(%[c])\n\t"
                "vmovupd %%ymm2, %c[rc](%[c])\n\t"
                "vmovupd %%ymm3, %c[rc]+32(%[c])\n\t"
                "vmovupd %%ymm4, %c[rc2](%[c])\n\t"
                "vmovupd %%ymm5, %c[rc2]+32(%[c])\n\t"
                "vmovupd %%ymm6, %c[rc3](%[c])\n\t"
                "vmovupd %%ymm7, %c[rc3]+32(%[c])\n\t"
                : [b] "+r"(b_ptr), [a] "+r"(a_ptr), [cnt] "+r"(pairs)
                : [c] "r"(c_tile),
                  /* Byte offsets of rows 1..3 of C and A, and of 1..2 rows of B. */
                  [rc] "i"(1 * n * sizeof(double)),
                  [rc2] "i"(2 * n * sizeof(double)),
                  [rc3] "i"(3 * n * sizeof(double)),
                  [ra] "i"(1 * n_pad_a * sizeof(double)),
                  [ra2] "i"(2 * n_pad_a * sizeof(double)),
                  [ra3] "i"(3 * n_pad_a * sizeof(double)),
                  [rb] "i"(1 * n_pad_b * sizeof(double)),
                  [rb2] "i"(2 * n_pad_b * sizeof(double))
                : "cc", "memory",
                  "xmm0", "xmm1", "xmm2", "xmm3",
                  "xmm4", "xmm5", "xmm6", "xmm7",
                  "xmm8", "xmm9", "xmm10", "xmm11",
                  "xmm12", "xmm13", "xmm14", "xmm15");
        }
    }
}
/* matrix_multiply() below takes a parameter named `n`; drop the file-level
 * size macro so it is not macro-replaced there. */
#undef n
/*
 * Recursively visit the 32x32x32 tiles of the (i, k, j) iteration cube in a
 * space-filling-curve order. Each level visits its 8 child cubes exactly
 * once; the child offsets along (d, d2, d3) follow the Gray sequence
 * 000, 100, 110, 010, 011, 111, 101, 001, so consecutive children are
 * neighbours. This looks like a 3-D Hilbert-style curve, presumably chosen
 * for cache reuse between consecutive tiles — NOTE(review): confirm intent.
 *
 * s        log2 of the current cube's edge length; at s == 5 the cube is a
 *          single 32^3 tile and is handed to kernel_32_32_32.
 * x, y, z  lowest corner of the cube, passed to the kernel as
 *          (i_start, k_start, j_start).
 * (dx,dy,dz), (dx2,dy2,dz2), (dx3,dy3,dz3)
 *          the three walk directions. matrix_multiply starts with the unit
 *          axes and the recursion only permutes/negates them, so every
 *          component is -1, 0 or 1. A negative component walks that axis
 *          downward: the origin is first shifted by one child edge so the
 *          children still tile [x, x + 2h) etc.
 * A, B, C  forwarded unchanged to the kernel.
 */
static inline void gao(int s, int x, int y, int z, int dx, int dy, int dz, int dx2, int dy2, int dz2, int dx3, int dy3, int dz3, const double* A, const double* B, double* C) {
    if (s == 5) {
        kernel_32_32_32(A, B, C, x, y, z);
        return;
    }
    --s;
    /* Edge length of each child cube. The direction components can be -1,
     * and left-shifting a negative int is undefined behaviour (C11 6.5.7p4),
     * so scale by h instead of writing `d << s`. */
    const int h = 1 << s;
    if (dx < 0) x -= dx * h;
    if (dy < 0) y -= dy * h;
    if (dz < 0) z -= dz * h;
    if (dx2 < 0) x -= dx2 * h;
    if (dy2 < 0) y -= dy2 * h;
    if (dz2 < 0) z -= dz2 * h;
    if (dx3 < 0) x -= dx3 * h;
    if (dy3 < 0) y -= dy3 * h;
    if (dz3 < 0) z -= dz3 * h;
    /* Child 000 */
    gao(s, x, y, z,
        dx2, dy2, dz2, dx3, dy3, dz3, dx, dy, dz, A, B, C);
    /* Child 100 */
    gao(s, x + dx * h, y + dy * h, z + dz * h,
        dx3, dy3, dz3, dx, dy, dz, dx2, dy2, dz2, A, B, C);
    /* Child 110 */
    gao(s, x + (dx + dx2) * h, y + (dy + dy2) * h, z + (dz + dz2) * h,
        dx3, dy3, dz3, dx, dy, dz, dx2, dy2, dz2, A, B, C);
    /* Child 010 */
    gao(s, x + dx2 * h, y + dy2 * h, z + dz2 * h,
        -dx, -dy, -dz, -dx2, -dy2, -dz2, dx3, dy3, dz3, A, B, C);
    /* Child 011 */
    gao(s, x + (dx2 + dx3) * h, y + (dy2 + dy3) * h, z + (dz2 + dz3) * h,
        -dx, -dy, -dz, -dx2, -dy2, -dz2, dx3, dy3, dz3, A, B, C);
    /* Child 111 */
    gao(s, x + (dx + dx2 + dx3) * h, y + (dy + dy2 + dy3) * h, z + (dz + dz2 + dz3) * h,
        -dx3, -dy3, -dz3, dx, dy, dz, -dx2, -dy2, -dz2, A, B, C);
    /* Child 101 */
    gao(s, x + (dx + dx3) * h, y + (dy + dy3) * h, z + (dz + dz3) * h,
        -dx3, -dy3, -dz3, dx, dy, dz, -dx2, -dy2, -dz2, A, B, C);
    /* Child 001 */
    gao(s, x + dx3 * h, y + dy3 * h, z + dz3 * h,
        dx2, dy2, dz2, -dx3, -dy3, -dz3, -dx, -dy, -dz, A, B, C);
}
/* Plain i-k-j product C = A * B for row-major n x n matrices (n <= 0: no-op).
 * Used for every size the 1024-specialised path cannot handle. */
static void matrix_multiply_fallback(int n, const double *A, const double *B, double *C) {
    for (int i = 0; i < n; i++) {
        double *c_row = C + (size_t)i * n;
        for (int j = 0; j < n; j++) {
            c_row[j] = 0.0;
        }
        for (int k = 0; k < n; k++) {
            const double a_ik = A[(size_t)i * n + k];
            const double *b_row = B + (size_t)k * n;
            for (int j = 0; j < n; j++) {
                c_row[j] += a_ik * b_row[j];
            }
        }
    }
}

/*
 * Compute C = _A * _B for square, row-major n x n matrices of doubles.
 *
 * n == 1024: copy _A and _B into padded static buffers, zero C, then
 * accumulate 32x32x32 tiles in the order produced by gao(). Because the
 * buffers are static, this path is neither reentrant nor safe to run
 * concurrently with itself.
 *
 * Any other n: the tiled path is hard-wired to 1024 (copy length, asm row
 * offsets, recursion depth) and would read past the ends of the inputs,
 * so a straightforward loop is used instead.
 */
void matrix_multiply(int n, const double *_A, const double *_B, double *C) {
    if (n != 1024) {
        matrix_multiply_fallback(n, _A, _B, C);
        return;
    }
    /* The copy and kernel asm access these rows with AVX loads/stores;
     * make the base alignment explicit instead of relying on the
     * compiler's default placement of large arrays. */
    static double A[1024 * n_pad_a] __attribute__((aligned(64)));
    static double B[1024 * n_pad_b] __attribute__((aligned(64)));
    init_arr_a(_A, A);
    init_arr_b(_B, B);
    memset(C, 0, (size_t)n * (size_t)n * sizeof *C);
    /* s = 10 covers the full 1024^3 cube; start with the unit axes. */
    gao(10, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, A, B, C);
}
/*
 * Judge results for this submission:
 * | Compilation | N/A       | N/A            | Compile OK | Score: N/A | Show more |
 * | Testcase #1 | 54.109 ms | 24 MB + 776 KB | Accepted   | Score: 100 | Show more |
 */