提交记录 13759


用户 题目 状态 得分 用时 内存 语言 代码长度
wys mmmd1k. 测测你的双精度矩阵乘法-1k Accepted 100 51.538 ms 24712 KB C++ 5.92 KB
提交时间 评测时间
2020-08-07 04:04:12 2020-08-07 04:04:15
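// Measured about 41 GFLOPS on the judge (2 * 1024^3 flops in ~51.5 ms); the author left the reason for this figure as an open question.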

#pragma GCC target("avx2,fma")
#pragma GCC optimize("Ofast")

#include <string.h>
#include <x86intrin.h>

/* Matrix dimension the kernels below are written for. Their inline-asm byte
 * offsets (8192 = n * sizeof(double)) are hard-coded from this value.
 * `n` is #undef'd again before matrix_multiply, whose parameter reuses the name. */
#define n 1024
/* Row stride, in doubles, of the padded copies of A and B (8256 bytes per row,
 * also hard-coded in the asm kernel). NOTE(review): the +8 padding presumably
 * breaks the power-of-two row stride to avoid cache-set conflicts — unverified. */
#define n_pad (1024 + 8)

/*
 * Copy an n x n row-major matrix from src (row stride n) into dst, whose rows
 * are n_pad doubles apart. The trailing n_pad - n doubles of each dst row are
 * not written.
 */
static inline void init_arr(const double *src, double *dst) {
	const size_t row_bytes = n * sizeof(double);
	const double *from = src;
	double *to = dst;

	for (int rows_left = n; rows_left > 0; rows_left--) {
		memcpy(to, from, row_bytes);
		from += n;
		to += n_pad;
	}
}

/*
 * Accumulate into a 4 x 64 tile of C the contribution of 32 consecutive k:
 *
 *   C[i][j] += sum_{k = k_start}^{k_start + 31} A[i][k] * B[k][j]
 *   for i in [i_start, i_start + 4), j in [j_start, j_start + 64).
 *
 * A and B are the padded copies (row stride n_pad doubles); C has row stride n.
 * Each iteration of the j loop keeps a 4 x 8 block of C in eight 256-bit
 * accumulators. matrix_multiply uses this for the rows left over after the
 * 255-row blocks.
 */
static inline void kernel_4_32_64(const double *A, const double *B, double *C,
	int i_start, int k_start, int j_start) {
	const int k_end = k_start + 32;
	const int j_end = j_start + 64;
	const double *ai = A + i_start * n_pad;  /* row i_start of the padded A */

	for (int j = j_start; j < j_end; j += 2 * 4) {
		double *cij = C + i_start * n + j;

		/* C comes from the caller and is not guaranteed to be 32-byte aligned;
		 * dereferencing it as __m256d implies an aligned access that faults on
		 * a misaligned address, so use the unaligned load/store intrinsics. */
		__m256d c00 = _mm256_loadu_pd(cij + n * 0), c01 = _mm256_loadu_pd(cij + 4 + n * 0);
		__m256d c10 = _mm256_loadu_pd(cij + n * 1), c11 = _mm256_loadu_pd(cij + 4 + n * 1);
		__m256d c20 = _mm256_loadu_pd(cij + n * 2), c21 = _mm256_loadu_pd(cij + 4 + n * 2);
		__m256d c30 = _mm256_loadu_pd(cij + n * 3), c31 = _mm256_loadu_pd(cij + 4 + n * 3);

		#pragma GCC unroll 4
		for (int k = k_start; k < k_end; k++) {
			/* B[k][j .. j+7] */
			const double *bk_s = B + k * n_pad + j;
			const __m256d b0 = _mm256_loadu_pd(bk_s);
			const __m256d b1 = _mm256_loadu_pd(bk_s + 4);

			/* Broadcast A[i_start + r][k] for r = 0..3. */
			const __m256d a0 = _mm256_broadcast_sd(ai + k + n_pad * 0);
			const __m256d a1 = _mm256_broadcast_sd(ai + k + n_pad * 1);
			const __m256d a2 = _mm256_broadcast_sd(ai + k + n_pad * 2);
			const __m256d a3 = _mm256_broadcast_sd(ai + k + n_pad * 3);

			c00 = _mm256_fmadd_pd(a0, b0, c00); c01 = _mm256_fmadd_pd(a0, b1, c01);
			c10 = _mm256_fmadd_pd(a1, b0, c10); c11 = _mm256_fmadd_pd(a1, b1, c11);
			c20 = _mm256_fmadd_pd(a2, b0, c20); c21 = _mm256_fmadd_pd(a2, b1, c21);
			c30 = _mm256_fmadd_pd(a3, b0, c30); c31 = _mm256_fmadd_pd(a3, b1, c31);
		}

		_mm256_storeu_pd(cij + n * 0, c00); _mm256_storeu_pd(cij + 4 + n * 0, c01);
		_mm256_storeu_pd(cij + n * 1, c10); _mm256_storeu_pd(cij + 4 + n * 1, c11);
		_mm256_storeu_pd(cij + n * 2, c20); _mm256_storeu_pd(cij + 4 + n * 2, c21);
		_mm256_storeu_pd(cij + n * 3, c30); _mm256_storeu_pd(cij + 4 + n * 3, c31);
	}
}

/*
 * Accumulate into a 255 x 64 tile of C the contribution of 32 consecutive k:
 *
 *   C[i][j] += sum_{k = k_start}^{k_start + 31} A[i][k] * B[k][j]
 *   for i in [i_start, i_start + 255), j in [j_start, j_start + 64).
 *
 * Works in 5 x 8 register blocks: ymm0..ymm9 hold C, ymm10..ymm13 hold
 * broadcast A values, ymm14..ymm15 hold a row segment of B.
 *
 * The asm hard-codes byte strides derived from the file's macros:
 *   8192 = n * sizeof(double)     (row stride of C)
 *   8256 = n_pad * sizeof(double) (row stride of the padded A and B)
 * They must be updated together with n / n_pad. Requires x86-64 with AVX2/FMA.
 *
 * Everything lives in ONE extended-asm statement with declared operands and
 * clobbers. GCC does not preserve register contents between separate asm
 * statements, and it must be told which registers and memory are modified.
 */
static inline void kernel_255_32_64(const double *A, const double *B, double *C,
	int i_start, int k_start, int j_start) {
	const int i_end = i_start + 255;
	const int j_end = j_start + 64;

	for (int i = i_start; i < i_end; i += 5) {
		const double *ai = A + i * n_pad;

		for (int j = j_start; j < j_end; j += 2 * 4) {
			double *cij = C + i * n + j;                         /* &C[i][j] */
			const double *bk_s = B + k_start * n_pad + j;        /* &B[k_start][j], advanced per k */
			const double *aik = ai + k_start;                    /* &A[i][k_start], advanced per k */
			long long k_left = 32;                               /* k iterations remaining */

			__asm__ volatile (
				/* Load C[i..i+4][j..j+7]; unaligned moves because C's alignment is unknown. */
				"vmovupd (%[c]), %%ymm0\n\t"          // c00
				"vmovupd 32(%[c]), %%ymm1\n\t"        // c01
				"vmovupd 8192(%[c]), %%ymm2\n\t"      // c10
				"vmovupd 8224(%[c]), %%ymm3\n\t"      // c11
				"vmovupd 16384(%[c]), %%ymm4\n\t"     // c20
				"vmovupd 16416(%[c]), %%ymm5\n\t"     // c21
				"vmovupd 24576(%[c]), %%ymm6\n\t"     // c30
				"vmovupd 24608(%[c]), %%ymm7\n\t"     // c31
				"vmovupd 32768(%[c]), %%ymm8\n\t"     // c40
				"vmovupd 32800(%[c]), %%ymm9\n\t"     // c41

				".align 8\n"
				"1:\n\t"
				"vmovupd (%[b]), %%ymm14\n\t"          // b0 = B[k][j..j+3]
				"vmovupd 32(%[b]), %%ymm15\n\t"        // b1 = B[k][j+4..j+7]
				"vbroadcastsd (%[a]), %%ymm10\n\t"     // a0 = A[i][k]
				"vbroadcastsd 8256(%[a]), %%ymm11\n\t" // a1 = A[i+1][k]
				"vbroadcastsd 16512(%[a]), %%ymm12\n\t"// a2 = A[i+2][k]
				"vbroadcastsd 24768(%[a]), %%ymm13\n\t"// a3 = A[i+3][k]
				"vfmadd231pd %%ymm10, %%ymm14, %%ymm0\n\t"  // c00 += a0 * b0
				"vfmadd231pd %%ymm10, %%ymm15, %%ymm1\n\t"  // c01 += a0 * b1
				"vbroadcastsd 33024(%[a]), %%ymm10\n\t"     // a4 = A[i+4][k] (reuses ymm10)
				"vfmadd231pd %%ymm11, %%ymm14, %%ymm2\n\t"  // c10 += a1 * b0
				"vfmadd231pd %%ymm11, %%ymm15, %%ymm3\n\t"  // c11 += a1 * b1
				"vfmadd231pd %%ymm12, %%ymm14, %%ymm4\n\t"  // c20 += a2 * b0
				"vfmadd231pd %%ymm12, %%ymm15, %%ymm5\n\t"  // c21 += a2 * b1
				"vfmadd231pd %%ymm13, %%ymm14, %%ymm6\n\t"  // c30 += a3 * b0
				"vfmadd231pd %%ymm13, %%ymm15, %%ymm7\n\t"  // c31 += a3 * b1
				"vfmadd231pd %%ymm10, %%ymm14, %%ymm8\n\t"  // c40 += a4 * b0
				"vfmadd231pd %%ymm10, %%ymm15, %%ymm9\n\t"  // c41 += a4 * b1
				"addq $8256, %[b]\n\t"                 // next row of B
				"addq $8, %[a]\n\t"                    // next column of A
				"decq %[cnt]\n\t"
				"jne 1b\n\t"

				/* Store the updated C block back. */
				"vmovupd %%ymm0, (%[c])\n\t"
				"vmovupd %%ymm1, 32(%[c])\n\t"
				"vmovupd %%ymm2, 8192(%[c])\n\t"
				"vmovupd %%ymm3, 8224(%[c])\n\t"
				"vmovupd %%ymm4, 16384(%[c])\n\t"
				"vmovupd %%ymm5, 16416(%[c])\n\t"
				"vmovupd %%ymm6, 24576(%[c])\n\t"
				"vmovupd %%ymm7, 24608(%[c])\n\t"
				"vmovupd %%ymm8, 32768(%[c])\n\t"
				"vmovupd %%ymm9, 32800(%[c])\n\t"
				: [b] "+r"(bk_s), [a] "+r"(aik), [cnt] "+r"(k_left)
				: [c] "r"(cij)
				: "cc", "memory",
				  "xmm0", "xmm1", "xmm2", "xmm3",
				  "xmm4", "xmm5", "xmm6", "xmm7",
				  "xmm8", "xmm9", "xmm10", "xmm11",
				  "xmm12", "xmm13", "xmm14", "xmm15"
			);
		}
	}
}

#undef n

/*
 * Portable path for sizes the tuned kernels do not handle:
 * C = A * B for n x n row-major matrices (C is overwritten). The i-k-j loop
 * order makes the innermost loop walk rows of B and C contiguously.
 */
static void matrix_multiply_generic(int n, const double *A, const double *B, double *C) {
	memset(C, 0, (size_t)n * (size_t)n * sizeof(double));
	for (int i = 0; i < n; i++) {
		double *ci = C + (size_t)i * (size_t)n;
		const double *ai = A + (size_t)i * (size_t)n;
		for (int k = 0; k < n; k++) {
			const double aik = ai[k];
			const double *bk = B + (size_t)k * (size_t)n;
			for (int j = 0; j < n; j++) {
				ci[j] += aik * bk[j];
			}
		}
	}
}

/*
 * C = A * B for n x n row-major double matrices; C is overwritten.
 *
 * Fast path (n == 1024 only): _A and _B are copied into padded static buffers
 * (row stride n_pad) and multiplied with the AVX2/FMA kernels, whose strides
 * are hard-coded for 1024 columns.
 *   - Rows 0..1019 are handled in four 255-row blocks.
 *   - Rows 1020..1023 (n % 255 == 4) go through kernel_4_32_64.
 * Any other n uses the portable fallback instead of overrunning the fixed-size
 * buffers and producing wrong results. n <= 0 is a no-op.
 *
 * Not reentrant: the padded buffers are function-static and shared by calls.
 */
void matrix_multiply(int n, const double *_A, const double *_B, double *C) {
	const int i_step = 255;
	const int k_step = 32;
	const int j_step = 64;  // 16 AVX vectors of 4 doubles

	/* Aligned to 32 bytes so every 256-bit load from these buffers is aligned:
	 * the row stride n_pad * 8 = 8256 bytes is itself a multiple of 32. */
	static double A[1024 * n_pad] __attribute__((aligned(32)));
	static double B[1024 * n_pad] __attribute__((aligned(32)));

	if (n <= 0) {
		return;
	}
	if (n != 1024) {
		matrix_multiply_generic(n, _A, _B, C);
		return;
	}

	init_arr(_A, A);
	init_arr(_B, B);

	memset(C, 0, (size_t)n * (size_t)n * sizeof(double));

	/* Rows covered by whole 255-row blocks. */
	for (int i_start = 0; i_start + i_step <= n; i_start += i_step) {
		for (int k_start = 0; k_start < n; k_start += k_step) {
			for (int j_start = 0; j_start < n; j_start += j_step) {
				kernel_255_32_64(A, B, C, i_start, k_start, j_start);
			}
		}
	}

	/* The remaining n % i_step == 4 rows. */
	{
		const int i_start = n - n % i_step;

		for (int k_start = 0; k_start < n; k_start += k_step) {
			for (int j_start = 0; j_start < n; j_start += j_step) {
				kernel_4_32_64(A, B, C, i_start, k_start, j_start);
			}
		}
	}
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 51.538 ms | 24 MB + 136 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2026-03-23 18:58:41 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠