#13759
// 41 GFLOPS (why?)
#pragma GCC target("avx2,fma")
#pragma GCC optimize("Ofast")
#include <string.h>
#include <x86intrin.h>
#define n 1024
#define n_pad (1024 + 8) // 1032 doubles (8256 bytes) per padded row; must match the row strides hard-coded in the asm kernel below
// Copy the n x n row-major input into an n x n_pad padded layout.
static inline void init_arr(const double *src, double *dst) {
for (int i = 0; i < n; i++) {
memcpy(dst + i * n_pad, src + i * n, n * sizeof(double));
}
}
// Tail kernel: updates the 4 rows i_start..i_start+3 of C for one
// 32 (k) x 64 (j) block, holding a 4x8 register tile of C per j step.
static inline void kernel_4_32_64(const double *A, const double *B, double *C,
int i_start, int k_start, int j_start) {
int k_end = k_start + 32;
int j_end = j_start + 64;
{
int i = i_start;
const double *ai = A + i * n_pad;
for (int j = j_start; j < j_end; j += 2 * 4) {
// Load cij
double *cij = C + i * n + j;
__m256d c00 = * (__m256d *) (cij + n * 0), c01 = * (__m256d *) (cij + 4 + n * 0);
__m256d c10 = * (__m256d *) (cij + n * 1), c11 = * (__m256d *) (cij + 4 + n * 1);
__m256d c20 = * (__m256d *) (cij + n * 2), c21 = * (__m256d *) (cij + 4 + n * 2);
__m256d c30 = * (__m256d *) (cij + n * 3), c31 = * (__m256d *) (cij + 4 + n * 3);
__m256d a0, a1, a2, a3;
__m256d b0, b1;
#pragma GCC unroll 4
for (int k = k_start; k < k_end; k++) {
const double *bk = B + k * n_pad;
const double *bk_s = bk + j;
// Load aik
a0 = _mm256_broadcast_sd(ai + k + n_pad * 0);
a1 = _mm256_broadcast_sd(ai + k + n_pad * 1);
a2 = _mm256_broadcast_sd(ai + k + n_pad * 2);
a3 = _mm256_broadcast_sd(ai + k + n_pad * 3);
// Load bkj
b0 = * (__m256d *) bk_s, b1 = * (__m256d *) (bk_s + 4);
// Calc cij
c00 = _mm256_fmadd_pd(a0, b0, c00); c01 = _mm256_fmadd_pd(a0, b1, c01);
c10 = _mm256_fmadd_pd(a1, b0, c10); c11 = _mm256_fmadd_pd(a1, b1, c11);
c20 = _mm256_fmadd_pd(a2, b0, c20); c21 = _mm256_fmadd_pd(a2, b1, c21);
c30 = _mm256_fmadd_pd(a3, b0, c30); c31 = _mm256_fmadd_pd(a3, b1, c31);
}
// Store cij
* (__m256d *) (cij + n * 0) = c00; * (__m256d *) (cij + 4 + n * 0) = c01;
* (__m256d *) (cij + n * 1) = c10; * (__m256d *) (cij + 4 + n * 1) = c11;
* (__m256d *) (cij + n * 2) = c20; * (__m256d *) (cij + 4 + n * 2) = c21;
* (__m256d *) (cij + n * 3) = c30; * (__m256d *) (cij + 4 + n * 3) = c31;
}
}
}
// Main kernel: updates a 255-row x 64-column block of C for one 32-wide k panel,
// processed as 51 strips of 5 rows, each strip held in a 5x8 register tile (ymm0-ymm9).
static inline void kernel_255_32_64(const double *A, const double *B, double *C,
int i_start, int k_start, int j_start) {
int i_end = i_start + 255;
int k_end = k_start + 32;
int j_end = j_start + 64;
for (int i = i_start; i < i_end; i += 5) {
const double *ai = A + i * n_pad;
for (int j = j_start; j < j_end; j += 2 * 4) {
// The four asm statements below share state: %r12-%r15 and the ymm registers
// must survive from one statement to the next; only the final (empty) asm
// declares them as clobbered.
__asm__ volatile (
"movq %0, %%r15\n" // cij
"movq %1, %%r14\n" // bk_s
"movq %2, %%r13\n" // a0_addr
"movq $32, %%r12\n"
: :
"r"(C + i * n + j),
"r"(B + k_start * n_pad + j),
"r"(ai + k_start)
:
);
__asm__ volatile (
"vmovapd (%r15), %ymm0\n" // c00
"vmovapd 32(%r15), %ymm1\n" // c01
"vmovapd 8192(%r15), %ymm2\n" // c10
"vmovapd 8224(%r15), %ymm3\n" // c11
"vmovapd 16384(%r15), %ymm4\n" // c20
"vmovapd 16416(%r15), %ymm5\n" // c21
"vmovapd 24576(%r15), %ymm6\n" // c30
"vmovapd 24608(%r15), %ymm7\n" // c31
"vmovapd 32768(%r15), %ymm8\n" // c40
"vmovapd 32800(%r15), %ymm9\n" // c41
);
__asm__ volatile (
".align 8\n"
"1:\n"
"vmovapd (%r14), %ymm14\n" // b0
"vmovapd 32(%r14), %ymm15\n" // b1
"vbroadcastsd (%r13), %ymm10\n" // a0
"vbroadcastsd 8256(%r13), %ymm11\n" // a1
"vbroadcastsd 16512(%r13), %ymm12\n" // a2
"vbroadcastsd 24768(%r13), %ymm13\n" // a3
"vfmadd231pd %ymm10, %ymm14, %ymm0\n" // c00 += a0 * b0
"vfmadd231pd %ymm10, %ymm15, %ymm1\n" // c01 += a0 * b1
"vbroadcastsd 33024(%r13), %ymm10\n" // a4 (a0)
"vfmadd231pd %ymm11, %ymm14, %ymm2\n" // c10 += a1 * b0
"vfmadd231pd %ymm11, %ymm15, %ymm3\n" // c11 += a1 * b1
"vfmadd231pd %ymm12, %ymm14, %ymm4\n" // c20 += a2 * b0
"vfmadd231pd %ymm12, %ymm15, %ymm5\n" // c21 += a2 * b1
"vfmadd231pd %ymm13, %ymm14, %ymm6\n" // c30 += a3 * b0
"vfmadd231pd %ymm13, %ymm15, %ymm7\n" // c31 += a3 * b1
"vfmadd231pd %ymm10, %ymm14, %ymm8\n" // c40 += a4 * b0
"vfmadd231pd %ymm10, %ymm15, %ymm9\n" // c41 += a4 * b1
"addq $8256, %r14\n" // bk_s += 1032
"addq $8, %r13\n" // a0_addr++
"decq %r12\n"
"jne 1b\n"
);
__asm__ volatile (
"vmovapd %ymm0, (%r15)\n" // c00
"vmovapd %ymm1, 32(%r15)\n" // c01
"vmovapd %ymm2, 8192(%r15)\n" // c10
"vmovapd %ymm3, 8224(%r15)\n" // c11
"vmovapd %ymm4, 16384(%r15)\n" // c20
"vmovapd %ymm5, 16416(%r15)\n" // c21
"vmovapd %ymm6, 24576(%r15)\n" // c30
"vmovapd %ymm7, 24608(%r15)\n" // c31
"vmovapd %ymm8, 32768(%r15)\n" // c40
"vmovapd %ymm9, 32800(%r15)\n" // c41
);
__asm__ volatile ("" : : :
"%r15","%r14","%r13","%r12",
"%ymm0","%ymm1","%ymm2","%ymm3",
"%ymm4","%ymm5","%ymm6","%ymm7",
"%ymm8","%ymm9","%ymm10","%ymm11",
"%ymm12","%ymm13","%ymm14","%ymm15"
);
}
}
}
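// --------------------------------------------------------------------------
// Illustrative only (not called anywhere): a plain-intrinsics rendering of the
// 5x8 asm micro-kernel above, using unaligned loads for simplicity. It performs
// the same math: a 5-row by 8-column tile of C accumulated over 32 values of k.
static inline void kernel_5x8_intrin_sketch(const double *A, const double *B, double *C,
int i, int k_start, int j) {
__m256d c[5][2];
for (int r = 0; r < 5; r++) { // load the 5x8 tile of C
c[r][0] = _mm256_loadu_pd(C + (i + r) * n + j);
c[r][1] = _mm256_loadu_pd(C + (i + r) * n + j + 4);
}
for (int k = k_start; k < k_start + 32; k++) {
__m256d b0 = _mm256_loadu_pd(B + k * n_pad + j); // bkj, 8 doubles
__m256d b1 = _mm256_loadu_pd(B + k * n_pad + j + 4);
for (int r = 0; r < 5; r++) {
__m256d a = _mm256_broadcast_sd(A + (i + r) * n_pad + k); // aik
c[r][0] = _mm256_fmadd_pd(a, b0, c[r][0]);
c[r][1] = _mm256_fmadd_pd(a, b1, c[r][1]);
}
}
for (int r = 0; r < 5; r++) { // store the tile back
_mm256_storeu_pd(C + (i + r) * n + j, c[r][0]);
_mm256_storeu_pd(C + (i + r) * n + j + 4, c[r][1]);
}
}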
#undef n
// Specialized for n == 1024; the kernels and the padded layout hard-code that size.
void matrix_multiply(int n, const double *_A, const double *_B, double *C) {
const int i_step = 255;
const int k_step = 32;
const int j_step = 64; // 16 * 4
// Padded, 64-byte-aligned working copies of the inputs. The kernels use aligned
// vector loads/stores (vmovapd, __m256d dereferences) on B and C, so B rows and
// the caller's C must be 32-byte aligned.
static double A[1024 * n_pad] __attribute__((aligned(64)));
static double B[1024 * n_pad] __attribute__((aligned(64)));
init_arr(_A, A);
init_arr(_B, B);
memset(C, 0, n * n * sizeof(double));
// 255-row blocks cover rows 0..1019; the remaining 4 rows are handled below.
for (int i_start = 0; i_start + i_step <= n; i_start += i_step) {
for (int k_start = 0; k_start < n; k_start += k_step) {
for (int j_start = 0; j_start < n; j_start += j_step) {
kernel_255_32_64(A, B, C, i_start, k_start, j_start);
}
}
}
{
int i_start = n - n % i_step; // = 1020: the 4 rows left over from the 255-row blocks
for (int k_start = 0; k_start < n; k_start += k_step) {
for (int j_start = 0; j_start < n; j_start += j_step) {
kernel_4_32_64(A, B, C, i_start, k_start, j_start);
}
}
}
}
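Because the asm kernel hard-codes its row strides (8256 bytes for A and B, 8192 bytes for C), any change to n_pad or to the blocking has to be mirrored in those constants, and a mismatch silently produces wrong results rather than a crash. A scalar cross-check such as the hypothetical harness below (check_against_reference and the file layout are illustrative, not part of the submission) is the quickest way to catch that and to localize the Wrong Answer reported in the results table.

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

void matrix_multiply(int n, const double *A, const double *B, double *C);

// Hypothetical harness: compare matrix_multiply against a naive O(n^3) reference.
int check_against_reference(void) {
    enum { N = 1024 };
    // 64-byte alignment because the kernels use aligned vector stores on C.
    static double A[N * N] __attribute__((aligned(64)));
    static double B[N * N] __attribute__((aligned(64)));
    static double C[N * N] __attribute__((aligned(64)));
    static double R[N * N]; // reference result, zero-initialized (static storage)
    for (int i = 0; i < N * N; i++) {
        A[i] = (double)rand() / RAND_MAX;
        B[i] = (double)rand() / RAND_MAX;
    }
    matrix_multiply(N, A, B, C);
    for (int i = 0; i < N; i++)      // naive reference: R = A * B
        for (int k = 0; k < N; k++) {
            double aik = A[i * N + k];
            for (int j = 0; j < N; j++)
                R[i * N + j] += aik * B[k * N + j];
        }
    double max_err = 0.0;
    for (int i = 0; i < N * N; i++) {
        double e = fabs(C[i] - R[i]);
        if (e > max_err) max_err = e;
    }
    printf("max abs error: %g\n", max_err);
    return max_err < 1e-6; // FMA reordering only perturbs results at the rounding level
}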
| Stage | Time | Memory | Verdict | Score |
| --- | --- | --- | --- | --- |
| Compilation | N/A | N/A | Compile OK | N/A |
| Testcase #1 | 51.866 ms | 25 MB + 8 KB | Wrong Answer | 0 |
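On the "41 GFLOPS (why?)" note at the top: a 1024 x 1024 x 1024 double-precision GEMM performs 2 * 1024^3 ≈ 2.15e9 floating-point operations, so the 51.866 ms measured for Testcase #1 works out to about 2.15e9 / 0.051866 s ≈ 41.4 GFLOP/s, which is where that figure comes from. Whether that is near peak depends on the CPU, which the submission does not state; a core with two 256-bit FMA ports can retire up to 16 double FLOPs per cycle.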