提交记录 14999


用户 题目 状态 得分 用时 内存 语言 代码长度
mmmd1k mmmd1k. 测测你的双精度矩阵乘法-1k Accepted 100 54.109 ms 25352 KB C++ 8.26 KB
提交时间 评测时间
2020-11-14 13:58:05 2020-11-14 13:58:08
// modified from https://duck.ac/submission/13766

#pragma GCC target("avx2,fma")
#pragma GCC optimize("Ofast")

#include <stdint.h>
#include <string.h>
#include <x86intrin.h>

/* Problem size. The code below is specialised for 1024x1024 matrices
 * (memcpy_1024 copies exactly 1024 doubles per row; matrix_multiply starts
 * the recursion at side 1 << 10). This one-letter macro is #undef'd again
 * before matrix_multiply, whose parameter is also named n. */
#define n 1024
/* Row strides, in doubles, of the private copies of A and B. Both are
 * multiples of 4 doubles (32 bytes), so every row start stays 32-byte
 * aligned whenever the buffer itself is, as vmovapd requires.
 * NOTE(review): the extra 32/64 doubles are presumably there to avoid
 * power-of-two row strides and the resulting cache-set conflicts — confirm. */
#define n_pad_a (1024 + 32)
#define n_pad_b (1024 + 64)

/*
 * Copy exactly 1024 doubles (8 KiB) from src to dst with 256-bit AVX
 * loads/stores: 8 ymm registers (256 bytes) per iteration, 32 iterations.
 *
 * Requirements:
 *   - src and dst must both be 32-byte aligned, because vmovapd faults on
 *     unaligned addresses;
 *   - the two regions must not overlap.
 *
 * Everything the asm touches is declared on this single statement:
 *   - the pointers and the loop counter are "+r" operands, so the compiler
 *     picks the registers;
 *   - the ymm registers it uses are clobbers;
 *   - "memory" tells the compiler that *dst is written and *src is read.
 * Splitting fixed-register setup, the loop, and the clobber list across
 * separate asm statements is not safe. The compiler may reuse those
 * registers in between, or hand the inputs in them.
 */
static inline void memcpy_1024(double *dst, const double *src) {
	long iters = 32;  /* 32 iterations x 256 bytes = 1024 doubles */

	__asm__ volatile (
		".align 8\n"
		"1:\n\t"
		"vmovapd 0(%[s]), %%ymm0\n\t"
		"vmovapd 32(%[s]), %%ymm1\n\t"
		"vmovapd 64(%[s]), %%ymm2\n\t"
		"vmovapd 96(%[s]), %%ymm3\n\t"
		"vmovapd 128(%[s]), %%ymm4\n\t"
		"vmovapd 160(%[s]), %%ymm5\n\t"
		"vmovapd 192(%[s]), %%ymm6\n\t"
		"vmovapd 224(%[s]), %%ymm7\n\t"
		"addq $256, %[s]\n\t"
		"vmovapd %%ymm0, 0(%[d])\n\t"
		"vmovapd %%ymm1, 32(%[d])\n\t"
		"vmovapd %%ymm2, 64(%[d])\n\t"
		"vmovapd %%ymm3, 96(%[d])\n\t"
		"vmovapd %%ymm4, 128(%[d])\n\t"
		"vmovapd %%ymm5, 160(%[d])\n\t"
		"vmovapd %%ymm6, 192(%[d])\n\t"
		"vmovapd %%ymm7, 224(%[d])\n\t"
		"addq $256, %[d]\n\t"
		"decq %[cnt]\n\t"
		"jne 1b\n\t"
		: [s] "+r"(src), [d] "+r"(dst), [cnt] "+r"(iters)
		:
		: "cc", "memory",
		  "xmm0", "xmm1", "xmm2", "xmm3",
		  "xmm4", "xmm5", "xmm6", "xmm7"
	);
}

/* Copy the dense n x n row-major matrix src into dst, whose rows are
 * n_pad_a doubles apart. The padding columns of each dst row are not
 * written. Alignment requirements are those of memcpy_1024. */
static inline void init_arr_a(const double *src, double *dst) {
	const double *const src_end = src + n * n;

	for (; src != src_end; src += n, dst += n_pad_a)
		memcpy_1024(dst, src);
}

/* Copy the dense n x n row-major matrix src into dst, whose rows are
 * n_pad_b doubles apart. The padding columns of each dst row are not
 * written. Alignment requirements are those of memcpy_1024. */
static inline void init_arr_b(const double *src, double *dst) {
	const double *const src_end = src + n * n;

	for (; src != src_end; src += n, dst += n_pad_b)
		memcpy_1024(dst, src);
}

/*
 * Block update C[i][j] += sum_k A[i][k] * B[k][j] over
 *   i in [i_start, i_start + 32),
 *   k in [k_start, k_start + 32),
 *   j in [j_start, j_start + 32).
 *
 * Row strides (doubles): A uses n_pad_a, B uses n_pad_b, C uses n. The byte
 * offsets hard-coded in the asm follow from these with n == 1024:
 *   A row = 1056 * 8 = 8448, B row = 1088 * 8 = 8704, C row = 1024 * 8 = 8192.
 * Every B and C address loaded/stored must be 32-byte aligned (vmovapd).
 *
 * Micro-kernel structure:
 *   - a 4x8 tile of C (4 rows x two ymm registers) lives in ymm0-ymm7;
 *   - ymm10-ymm13 hold broadcasts of A[i..i+3][k];
 *   - ymm14/15 and ymm8/9 alternately hold rows k and k+1 of B;
 *   - k advances two steps per loop iteration: 15 iterations plus a 2-step
 *     tail gives the 32 steps.
 *
 * The whole tile computation is one asm statement:
 *   - the pointers and counter are "+r" operands, and all ymm registers are
 *     clobbered;
 *   - "memory" tells the compiler that C is read and written and A/B are
 *     read.
 * Setting fixed registers in one asm statement and using them in another
 * is not safe.
 */
static inline void kernel_32_32_32(const double *A, const double *B, double *C,
	int i_start, int k_start, int j_start) {
	const int i_end = i_start + 32;

	for (int i = i_start; i < i_end; i += 4) {
		const double *ai = A + i * n_pad_a;

		for (int jj = 0, j = j_start; jj < 4; jj++, j += 8) {
			double *cij = C + i * n + j;                   /* top-left of the 4x8 C tile */
			const double *bk = B + k_start * n_pad_b + j;  /* B[k_start][j]              */
			const double *ak = ai + k_start;               /* A[i][k_start]              */
			long iters = 15;                               /* 15 x 2 k-steps + 2-step tail */

			__asm__ volatile (
				/* load the C tile */
				"vmovapd (%[c]), %%ymm0\n\t"        /* c00 */
				"vmovapd 32(%[c]), %%ymm1\n\t"      /* c01 */
				"vmovapd 8192(%[c]), %%ymm2\n\t"    /* c10 */
				"vmovapd 8224(%[c]), %%ymm3\n\t"    /* c11 */
				"vmovapd 16384(%[c]), %%ymm4\n\t"   /* c20 */
				"vmovapd 16416(%[c]), %%ymm5\n\t"   /* c21 */
				"vmovapd 24576(%[c]), %%ymm6\n\t"   /* c30 */
				"vmovapd 24608(%[c]), %%ymm7\n\t"   /* c31 */

				/* prologue: B row k and A[i..i+3][k] */
				"vmovapd (%[b]), %%ymm14\n\t"           /* b0 */
				"vmovapd 32(%[b]), %%ymm15\n\t"         /* b1 */
				"vbroadcastsd (%[a]), %%ymm10\n\t"      /* a0 */
				"vbroadcastsd 8448(%[a]), %%ymm11\n\t"  /* a1 */
				"vbroadcastsd 16896(%[a]), %%ymm12\n\t" /* a2 */
				"vbroadcastsd 25344(%[a]), %%ymm13\n\t" /* a3 */

				".align 8\n"
				"1:\n\t"
				/* step k: FMA with row k, prefetch row k+1 and A[.][k+1] */
				"vmovapd 8704(%[b]), %%ymm8\n\t"        /* next b0 */
				"vmovapd 8736(%[b]), %%ymm9\n\t"        /* next b1 */
				"vfmadd231pd %%ymm10, %%ymm14, %%ymm0\n\t"
				"vfmadd231pd %%ymm10, %%ymm15, %%ymm1\n\t"
				"vbroadcastsd 8(%[a]), %%ymm10\n\t"
				"vfmadd231pd %%ymm11, %%ymm14, %%ymm2\n\t"
				"vfmadd231pd %%ymm11, %%ymm15, %%ymm3\n\t"
				"vbroadcastsd 8456(%[a]), %%ymm11\n\t"
				"vfmadd231pd %%ymm12, %%ymm14, %%ymm4\n\t"
				"vfmadd231pd %%ymm12, %%ymm15, %%ymm5\n\t"
				"vbroadcastsd 16904(%[a]), %%ymm12\n\t"
				"vfmadd231pd %%ymm13, %%ymm14, %%ymm6\n\t"
				"vfmadd231pd %%ymm13, %%ymm15, %%ymm7\n\t"
				"vbroadcastsd 25352(%[a]), %%ymm13\n\t"

				/* step k+1: advance two B rows (2 * 8704 bytes) and two A columns */
				"addq $17408, %[b]\n\t"
				"addq $16, %[a]\n\t"
				"vmovapd (%[b]), %%ymm14\n\t"           /* next b0 */
				"vmovapd 32(%[b]), %%ymm15\n\t"         /* next b1 */
				"vfmadd231pd %%ymm10, %%ymm8, %%ymm0\n\t"
				"vfmadd231pd %%ymm10, %%ymm9, %%ymm1\n\t"
				"vbroadcastsd (%[a]), %%ymm10\n\t"
				"vfmadd231pd %%ymm11, %%ymm8, %%ymm2\n\t"
				"vfmadd231pd %%ymm11, %%ymm9, %%ymm3\n\t"
				"vbroadcastsd 8448(%[a]), %%ymm11\n\t"
				"vfmadd231pd %%ymm12, %%ymm8, %%ymm4\n\t"
				"vfmadd231pd %%ymm12, %%ymm9, %%ymm5\n\t"
				"vbroadcastsd 16896(%[a]), %%ymm12\n\t"
				"vfmadd231pd %%ymm13, %%ymm8, %%ymm6\n\t"
				"vfmadd231pd %%ymm13, %%ymm9, %%ymm7\n\t"
				"vbroadcastsd 25344(%[a]), %%ymm13\n\t"

				"decq %[cnt]\n\t"
				"jne 1b\n\t"

				/* tail: the last two k-steps, without further prefetch */
				"vmovapd 8704(%[b]), %%ymm8\n\t"
				"vmovapd 8736(%[b]), %%ymm9\n\t"
				"vfmadd231pd %%ymm10, %%ymm14, %%ymm0\n\t"
				"vfmadd231pd %%ymm10, %%ymm15, %%ymm1\n\t"
				"vbroadcastsd 8(%[a]), %%ymm10\n\t"
				"vfmadd231pd %%ymm11, %%ymm14, %%ymm2\n\t"
				"vfmadd231pd %%ymm11, %%ymm15, %%ymm3\n\t"
				"vbroadcastsd 8456(%[a]), %%ymm11\n\t"
				"vfmadd231pd %%ymm12, %%ymm14, %%ymm4\n\t"
				"vfmadd231pd %%ymm12, %%ymm15, %%ymm5\n\t"
				"vbroadcastsd 16904(%[a]), %%ymm12\n\t"
				"vfmadd231pd %%ymm13, %%ymm14, %%ymm6\n\t"
				"vfmadd231pd %%ymm13, %%ymm15, %%ymm7\n\t"
				"vbroadcastsd 25352(%[a]), %%ymm13\n\t"

				"vfmadd231pd %%ymm10, %%ymm8, %%ymm0\n\t"
				"vfmadd231pd %%ymm10, %%ymm9, %%ymm1\n\t"
				"vfmadd231pd %%ymm11, %%ymm8, %%ymm2\n\t"
				"vfmadd231pd %%ymm11, %%ymm9, %%ymm3\n\t"
				"vfmadd231pd %%ymm12, %%ymm8, %%ymm4\n\t"
				"vfmadd231pd %%ymm12, %%ymm9, %%ymm5\n\t"
				"vfmadd231pd %%ymm13, %%ymm8, %%ymm6\n\t"
				"vfmadd231pd %%ymm13, %%ymm9, %%ymm7\n\t"

				/* store the C tile back */
				"vmovapd %%ymm0, (%[c])\n\t"
				"vmovapd %%ymm1, 32(%[c])\n\t"
				"vmovapd %%ymm2, 8192(%[c])\n\t"
				"vmovapd %%ymm3, 8224(%[c])\n\t"
				"vmovapd %%ymm4, 16384(%[c])\n\t"
				"vmovapd %%ymm5, 16416(%[c])\n\t"
				"vmovapd %%ymm6, 24576(%[c])\n\t"
				"vmovapd %%ymm7, 24608(%[c])\n\t"
				: [b] "+r"(bk), [a] "+r"(ak), [cnt] "+r"(iters)
				: [c] "r"(cij)
				: "cc", "memory",
				  "xmm0", "xmm1", "xmm2", "xmm3",
				  "xmm4", "xmm5", "xmm6", "xmm7",
				  "xmm8", "xmm9", "xmm10", "xmm11",
				  "xmm12", "xmm13", "xmm14", "xmm15"
			);
		}
	}
}

/* Drop the one-letter size macro now that the row-copy helpers and the
 * kernel that use it are defined. Otherwise the parameter `int n` of
 * matrix_multiply below would be rewritten to `int 1024`. */
#undef n

/*
 * Recursively walk the (i, k, j) iteration cube of side 1 << s whose lowest
 * corner is (x, y, z). kernel_32_32_32 is called once per 32x32x32
 * sub-cube (s == 5), with x, y, z passed as i_start, k_start, j_start.
 *
 * (dx,dy,dz), (dx2,dy2,dz2), (dx3,dy3,dz3) are three direction vectors that
 * orient the traversal. Starting from the unit axes passed by
 * matrix_multiply, they are only permuted and negated, so their components
 * stay in {-1, 0, 1}. A negative component makes that axis start from its
 * high half.
 *
 * Each level visits the 8 child cubes in a fixed order and passes permuted
 * and negated direction vectors to the children.
 * NOTE(review): this looks like a 3-D Hilbert-curve-style ordering, chosen
 * so consecutive kernel calls reuse blocks of A, B or C — confirm.
 *
 * Child offsets are computed as d * half, not d << s. Left-shifting a
 * negative value (d == -1) is undefined behaviour in C, and in C++ before
 * C++20.
 */
static inline void gao(int s, int x, int y, int z, int dx, int dy, int dz, int dx2, int dy2, int dz2, int dx3, int dy3, int dz3, const double* A, const double* B, double* C) {
    if (s == 5) {
        kernel_32_32_32(A, B, C, x, y, z);
        return;
    }
    --s;
    const int half = 1 << s;  /* side of each child cube */

    /* The three direction vectors scaled to one child-cube step. */
    const int ux = dx * half,  uy = dy * half,  uz = dz * half;
    const int vx = dx2 * half, vy = dy2 * half, vz = dz2 * half;
    const int wx = dx3 * half, wy = dy3 * half, wz = dz3 * half;

    /* Move the anchor to the high half of any axis that is walked
     * backwards, so that every child origin below stays inside the cube. */
    if (dx < 0) x -= ux;
    if (dy < 0) y -= uy;
    if (dz < 0) z -= uz;
    if (dx2 < 0) x -= vx;
    if (dy2 < 0) y -= vy;
    if (dz2 < 0) z -= vz;
    if (dx3 < 0) x -= wx;
    if (dy3 < 0) y -= wy;
    if (dz3 < 0) z -= wz;

    /* Visit the 8 children: anchor + {0,u} + {0,v} + {0,w}. */
    gao(s, x, y, z, dx2, dy2, dz2, dx3, dy3, dz3, dx, dy, dz, A, B, C);
    gao(s, x + ux, y + uy, z + uz, dx3, dy3, dz3, dx, dy, dz, dx2, dy2, dz2, A, B, C);
    gao(s, x + ux + vx, y + uy + vy, z + uz + vz, dx3, dy3, dz3, dx, dy, dz, dx2, dy2, dz2, A, B, C);
    gao(s, x + vx, y + vy, z + vz, -dx, -dy, -dz, -dx2, -dy2, -dz2, dx3, dy3, dz3, A, B, C);
    gao(s, x + vx + wx, y + vy + wy, z + vz + wz, -dx, -dy, -dz, -dx2, -dy2, -dz2, dx3, dy3, dz3, A, B, C);
    gao(s, x + ux + vx + wx, y + uy + vy + wy, z + uz + vz + wz, -dx3, -dy3, -dz3, dx, dy, dz, -dx2, -dy2, -dz2, A, B, C);
    gao(s, x + ux + wx, y + uy + wy, z + uz + wz, -dx3, -dy3, -dz3, dx, dy, dz, -dx2, -dy2, -dz2, A, B, C);
    gao(s, x + wx, y + wy, z + wz, dx2, dy2, dz2, -dx3, -dy3, -dz3, -dx, -dy, -dz, A, B, C);
}

/* Plain i-k-j triple loop for row-major n x n matrices. It is used when the
 * AVX fast path's preconditions do not hold. C is overwritten and must not
 * overlap A or B. */
static void matrix_multiply_generic(int n, const double *A, const double *B, double *C) {
	memset(C, 0, (size_t)n * (size_t)n * sizeof(double));
	for (int i = 0; i < n; i++) {
		double *ci = C + (size_t)i * n;
		for (int k = 0; k < n; k++) {
			const double aik = A[(size_t)i * n + k];
			const double *bk = B + (size_t)k * n;
			for (int j = 0; j < n; j++)
				ci[j] += aik * bk[j];
		}
	}
}

/*
 * C = _A * _B for row-major n x n matrices of doubles. C is overwritten.
 *
 * Fast path (n == 1024 and _A, _B, C all 32-byte aligned):
 *   - the inputs are copied into padded, 64-byte-aligned scratch buffers;
 *   - C is cleared;
 *   - gao walks the 32x32x32 blocks and calls the AVX2/FMA kernel on each.
 * Any other size or alignment uses the generic loop. The fast path would
 * otherwise index out of bounds or fault on vmovapd.
 *
 * Not reentrant and not thread-safe: the scratch buffers are
 * function-static.
 */
void matrix_multiply(int n, const double *_A, const double *_B, double *C) {
	if (n <= 0)
		return;

	if (n != 1024 ||
	    (((uintptr_t)_A | (uintptr_t)_B | (uintptr_t)C) & 31u) != 0) {
		matrix_multiply_generic(n, _A, _B, C);
		return;
	}

	/* vmovapd needs 32-byte alignment, which a plain static array does not
	 * guarantee, so request it explicitly (64 = one cache line). */
	static double A[1024 * n_pad_a] __attribute__((aligned(64)));
	static double B[1024 * n_pad_b] __attribute__((aligned(64)));
	init_arr_a(_A, A);
	init_arr_b(_B, B);

	memset(C, 0, (size_t)n * (size_t)n * sizeof(double));
	/* Side 1 << 10 = 1024, origin (0,0,0), unit direction vectors. */
	gao(10, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, A, B, C);
}

Compilation  N/A  N/A  Compile OK  Score: N/A

Testcase #1  54.109 ms  24 MB + 776 KB  Accepted  Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2026-03-21 13:49:49 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠