提交记录 13681


用户 题目 状态 得分 用时 内存 语言 代码长度
wys mmmd1k. 测测你的双精度矩阵乘法-1k Accepted 100 69.893 ms 16392 KB C++ 1.96 KB
提交时间 评测时间
2020-08-06 00:18:09 2020-08-06 00:18:12
// Blocked i-k-j matrix multiply: C row segments held in AVX registers, B tiles sized to stay in L1

#pragma GCC target("avx2,fma")

#include <string.h>
#include <x86intrin.h>

// Tuning constants for the blocked AVX2 kernel in matrix_multiply().
enum {
	MM_I_STEP = 256,           // rows of A/C per outer tile
	MM_K_STEP = 32,            // rows of B per tile (32 x 64 doubles = 16 KiB of B)
	MM_J_VECS = 16,            // __m256d accumulators per C row segment (must match MM_LOOP16)
	MM_J_STEP = MM_J_VECS * 4, // C columns per segment (4 doubles per __m256d)
	MM_MAX_N = 1024            // capacity (rows and columns) of the rotated-B scratch buffer
};

// Copies one row of n doubles from src to dst, rotated right by `step`
// columns: dst[(j + step) % n] = src[j].  Requires 0 <= step <= n.
static void rotate_row(int n, int step, const double *src, double *dst) {
	memcpy(dst, src + (n - step), (size_t)step * sizeof(double));
	memcpy(dst + step, src, (size_t)(n - step) * sizeof(double));
}

// Copies the n x n row-major matrix src into dst, rotating row i right by
// (i * j_step) % n columns.  The shift is reduced modulo n (rather than
// masked with a fixed constant) so it always stays within the row.
static void rotate(int n, int j_step, const double *src, double *dst) {
	for (int i = 0; i < n; i++) {
		const size_t row = (size_t)i * n;
		rotate_row(n, (i * j_step) % n, src + row, dst + row);
	}
}

// Portable i-k-j product used for sizes the blocked kernel cannot handle.
// Overwrites C with A * B (n x n, row-major).
static void matrix_multiply_scalar(int n, const double *A, const double *B, double *C) {
	for (int i = 0; i < n; i++) {
		const double *ai = A + (size_t)i * n;
		double *ci = C + (size_t)i * n;
		for (int j = 0; j < n; j++) {
			ci[j] = 0.0;
		}
		for (int k = 0; k < n; k++) {
			const double aik = ai[k];
			const double *bk = B + (size_t)k * n;
			for (int j = 0; j < n; j++) {
				ci[j] += aik * bk[j];
			}
		}
	}
}

// Computes C = A * B for n x n row-major matrices (B is passed as _B).
// C is overwritten; its previous contents are ignored.
//
// Fast path (0 < n <= MM_MAX_N and n a multiple of MM_J_STEP): i-k-j loop
// order over tiles.  Each MM_J_STEP-wide segment of a C row is held in
// MM_J_VECS AVX registers while an MM_K_STEP-row tile of B is streamed
// through FMAs.  B is first copied into a static scratch buffer with row k
// rotated right by (k * MM_J_STEP) % n columns.  This is presumably meant to
// keep consecutive B rows (8 KiB apart when n = 1024) from mapping to the
// same cache sets.  Because that buffer is static, this function is not
// reentrant and must not be called concurrently.
//
// Every other size falls back to a plain scalar loop.  All vector loads and
// stores are unaligned, so A, _B and C only need double alignment.
// Requires AVX2 and FMA to be enabled for this translation unit.
void matrix_multiply(int n, const double *A, const double *_B, double *C) {
	if (n <= 0) {
		return;
	}
	if (n > MM_MAX_N || n % MM_J_STEP != 0) {
		matrix_multiply_scalar(n, A, _B, C);
		return;
	}

	static double B[MM_MAX_N * MM_MAX_N] __attribute__((aligned(32)));
	rotate(n, MM_J_STEP, _B, B);

	memset(C, 0, (size_t)n * n * sizeof(double));

	// Register-blocking helpers: one macro expansion per accumulator so the
	// 16 accumulators stay in named locals (registers) rather than an array.
	#define MM_LOOP16(f) f(0) f(1) f(2) f(3) f(4) f(5) f(6) f(7) \
		f(8) f(9) f(10) f(11) f(12) f(13) f(14) f(15)
	#define MM_LOAD(v) __m256d acc_##v = _mm256_loadu_pd(ci_s + (v) * 4);
	#define MM_FMA(v) acc_##v = _mm256_fmadd_pd(_mm256_loadu_pd(bk_s + (v) * 4), K, acc_##v);
	#define MM_STORE(v) _mm256_storeu_pd(ci_s + (v) * 4, acc_##v);

	for (int i_start = 0; i_start < n; i_start += MM_I_STEP) {
		const int i_end = (i_start + MM_I_STEP < n) ? i_start + MM_I_STEP : n;

		for (int k_start = 0; k_start < n; k_start += MM_K_STEP) {
			const int k_end = (k_start + MM_K_STEP < n) ? k_start + MM_K_STEP : n;

			for (int j_start = 0; j_start < n; j_start += MM_J_STEP) {
				// Column in rotated row k_start where original column j_start
				// now lives.  It is a multiple of MM_J_STEP and at most
				// n - MM_J_STEP, so the whole segment is contiguous.
				const int j_off0 = (j_start + k_start * MM_J_STEP) % n;

				for (int i = i_start; i < i_end; i++) {
					const double *ai = A + (size_t)i * n;
					double *ci_s = C + (size_t)i * n + j_start;

					MM_LOOP16(MM_LOAD)

					int j_off = j_off0;
					for (int k = k_start; k < k_end; k++) {
						const double *bk_s = B + (size_t)k * n + j_off;
						const __m256d K = _mm256_set1_pd(ai[k]);

						MM_LOOP16(MM_FMA)

						// Next row is rotated MM_J_STEP further; wrap modulo n
						// without a division (j_off < n and MM_J_STEP <= n).
						j_off += MM_J_STEP;
						if (j_off >= n) {
							j_off -= n;
						}
					}

					MM_LOOP16(MM_STORE)
				}
			}
		}
	}

	#undef MM_STORE
	#undef MM_FMA
	#undef MM_LOAD
	#undef MM_LOOP16
}

Compilation N/A N/A Compile OK Score: N/A

Testcase #1 69.893 ms 16 MB + 8 KB Accepted Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2026-03-23 22:55:17 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠