提交记录 13690


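用户 (User) | 题目 (Problem) | 状态 (Status) | 得分 (Score) | 用时 (Time) | 内存 (Memory) | 语言 (Language) | 代码长度 (Code size)
user1 | mmmd1k. 测测你的双精度矩阵乘法-1k (test your double-precision matrix multiplication, 1k) | Accepted | 100 | 79.179 ms | 53264 KB | C++ | 4.62 KB
提交时间 (Submitted at) | 评测时间 (Judged at)
2020-08-06 01:50:42 | 2020-08-06 01:50:46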
#pragma GCC optimize("Ofast,no-stack-protector")
#pragma GCC target("avx2,fma")
#include <string.h>
#include <x86intrin.h>



/*
 * Copy one row of n doubles from src to dst, rotated right by `step`
 * positions: dst[(j + step) % n] = src[j] for 0 <= j < n.
 * Expects 0 <= step <= n and non-overlapping src/dst.
 */
static void rotate_row(int n, int step, const double *src, double *dst) {
	const int kept = n - step;  /* leading src elements that move right */
	memcpy(dst + step, src, kept * sizeof(double));
	memcpy(dst, src + kept, step * sizeof(double));
}

/*
 * Copy the n x n row-major matrix src into dst with row i rotated right by
 * (i * j_step) mod n elements (via rotate_row), so element (i, j) of src
 * lands at column (j + i * j_step) mod n of row i in dst.
 * cat() looks element (k, j) up at column (j + k * 64) & 511 with
 * j_step == 64, which agrees with this layout for n == 512 (the size cat()
 * requires). src and dst must not overlap.
 */
static void rotate(int n, int j_step, const double *src, double *dst) {
	for (int i = 0; i < n; i++) {
		/* Reduce modulo n rather than a hard-coded 512 so the shift always
		 * lies in [0, n) and rotate_row never gets a negative copy length.
		 * For n == 512 this equals (i * j_step) & 511, negative products
		 * included, because of the normalization below. */
		int step = (i * j_step) % n;
		if (step < 0)
			step += n;
		rotate_row(n, step, src + i * n, dst + i * n);
	}
}
/*
 * cat: C = A * _B for square row-major n x n double matrices (AVX2 + FMA).
 *
 * Preconditions (not checked):
 *   - n == 512. The rotated copy of _B lives in a fixed 512 x 512 static
 *     buffer, column offsets are reduced with "& 511", and the j loop walks
 *     whole 64-column tiles.
 *   - A, _B and C do not overlap. C is zero-filled first and then fully
 *     overwritten, so its previous contents do not matter.
 *
 * _B is first copied into a static buffer with row k rotated right by
 * (k * 64) mod 512 elements (see rotate()). Element (k, j) of _B is
 * therefore read from column (j + k * 64) & 511 of row k of the copy.
 * NOTE(review): presumably this staggers the 4 KiB row stride across cache
 * sets; confirm by profiling.
 *
 * Blocking: 256 rows of C x 32 values of k x 64 columns of C. Each 64-column
 * strip of a C row is held in 16 __m256d accumulators across the k block.
 *
 * Not reentrant: the rotated copy of _B is a function-local static.
 */
inline void cat(int n, const double *A, const double *_B, double *C) {
	const int i_step = 256;  /* rows of C per outer block */
	const int k_step = 32;   /* depth of each k block */
	const int j_step = 64;   /* columns per register tile: 16 vectors x 4 doubles */

	static double B[512 * 512];
	rotate(n, j_step, _B, B);

	memset(C, 0, n * n * sizeof(double));

/* Expand f(0) .. f(15): one expansion per accumulator in a tile. */
#define CAT_LOOP16(f) f(0) f(1) f(2) f(3) f(4) f(5) f(6) f(7) \
	f(8) f(9) f(10) f(11) f(12) f(13) f(14) f(15)
#define CAT_ACC(t) acc_##t
/* Unaligned loads/stores: no 32-byte alignment is required of C or of the
 * static B copy. On AVX hardware they cost the same as aligned ones when the
 * data happens to be aligned. */
#define CAT_LOAD(t) __m256d CAT_ACC(t) = _mm256_loadu_pd(ci_s + (t) * 4);
#define CAT_FMA(t) CAT_ACC(t) = _mm256_fmadd_pd(_mm256_loadu_pd(bk_s + (t) * 4), K, CAT_ACC(t));
#define CAT_STORE(t) _mm256_storeu_pd(ci_s + (t) * 4, CAT_ACC(t));

	for (int i_start = 0; i_start < n; i_start += i_step) {
		/* Clamped like k_end, so a partial final row block stays in bounds. */
		const int i_end = i_start + i_step <= n ? i_start + i_step : n;

		for (int k_start = 0; k_start < n; k_start += k_step) {
			const int k_end = k_start + k_step <= n ? k_start + k_step : n;

			for (int j_start = 0; j_start < n; j_start += j_step) {
				for (int i = i_start; i < i_end; i++) {
					const double *ai = A + i * n;
					double *ci_s = C + i * n + j_start;

					CAT_LOOP16(CAT_LOAD)

					for (int k = k_start; k < k_end; k++) {
						/* Same offset rotate() applied to row k. Since j_start
						 * and k * j_step are multiples of 64, the 64-column
						 * tile never wraps past column 511. */
						const int j_actual = (j_start + k * j_step) & 511;
						const double *bk_s = B + k * n + j_actual;
						const __m256d K = _mm256_set1_pd(ai[k]);

						CAT_LOOP16(CAT_FMA)
					}

					CAT_LOOP16(CAT_STORE)
				}
			}
		}
	}

/* Keep these helper macros from leaking into the rest of the file. */
#undef CAT_STORE
#undef CAT_FMA
#undef CAT_LOAD
#undef CAT_ACC
#undef CAT_LOOP16
}
/* Side length of the full matrices handled by matrix_multiply().
 * NOTE(review): an object-like macro named `n` rewrites every later use of
 * the identifier `n`; no visible code below this point uses it. */
#define n 1024
/* Row-major flat index into a buffer with 512 columns (no visible use). */
#define idx(i, j) ((i) * 512 + (j))

/* Side of one quadrant (512 x 512) of the full matrices. */
static const int M = 512;
/* Row stride of the packed quadrant scratch buffers. */
static const int R = 512;
/* Pack one quadrant: copy M rows of M doubles from src (row stride 1024)
 * into dst (row stride R). src and dst must not overlap. */
static void convert(const double* __restrict__ src, double* __restrict__ dst) {
    const double *from = src;
    double *to = dst;
    for (int rows_left = M; rows_left > 0; --rows_left) {
        memcpy(to, from, M * sizeof(double));
        from += 1024;
        to += R;
    }
}
/* Pack the element-wise sum of two quadrants: dst = src1 + src2 over an
 * M x M block, reading with row stride 1024 and writing with row stride R. */
static void convertA(const double* __restrict__ src1, const double* __restrict__ src2, double* __restrict__ dst) {
    for (int r = 0; r < M; ++r) {
        const double *lhs = src1 + r * 1024;
        const double *rhs = src2 + r * 1024;
        double *out = dst + r * R;
        for (int c = 0; c < M; ++c) {
            out[c] = lhs[c] + rhs[c];
        }
    }
}
/* Pack the element-wise difference of two quadrants: dst = src1 - src2 over
 * an M x M block, reading with row stride 1024 and writing with row stride R. */
static void convertS(const double* __restrict__ src1, const double* __restrict__ src2, double* __restrict__ dst) {
    for (int r = 0; r < M; ++r) {
        const double *lhs = src1 + r * 1024;
        const double *rhs = src2 + r * 1024;
        double *out = dst + r * R;
        for (int c = 0; c < M; ++c) {
            out[c] = lhs[c] - rhs[c];
        }
    }
}
/* Unpack: overwrite an M x M quadrant of dst (row stride 1024) with the
 * packed block src (row stride R). src and dst must not overlap. */
static void iconvert(const double* __restrict__ src, double* __restrict__ dst) {
    const double *from = src;
    double *to = dst;
    for (int rows_left = M; rows_left > 0; --rows_left) {
        memcpy(to, from, M * sizeof(double));
        from += R;
        to += 1024;
    }
}
/* Accumulate: add the packed block src (row stride R) into an M x M quadrant
 * of dst (row stride 1024). */
static void iconvertA(const double* __restrict__ src, double* __restrict__ dst) {
    for (int r = 0; r < M; ++r) {
        const double *from = src + r * R;
        double *to = dst + r * 1024;
        for (int c = 0; c < M; ++c) {
            to[c] += from[c];
        }
    }
}
/* Accumulate negatively: subtract the packed block src (row stride R) from
 * an M x M quadrant of dst (row stride 1024). */
static void iconvertS(const double* __restrict__ src, double* __restrict__ dst) {
    for (int r = 0; r < M; ++r) {
        const double *from = src + r * R;
        double *to = dst + r * 1024;
        for (int c = 0; c < M; ++c) {
            to[c] -= from[c];
        }
    }
}
/*
 * matrix_multiply: _C = _A * _B for 1024 x 1024 row-major double matrices,
 * using one level of Strassen's algorithm (7 half-size products via cat()).
 *
 * The leading int argument is ignored; the size is hard-wired to 1024.
 * NOTE(review): presumably the caller always passes 1024; confirm.
 *
 * Quadrants of a 1024 x 1024 matrix X (row stride 1024):
 *   X11 = X, X12 = X + DR, X21 = X + DL, X22 = X + DL + DR.
 *
 *   P0 = (A11 + A22)(B11 + B22)    P4 = (A11 + A12) B22
 *   P1 = (A21 + A22) B11           P5 = (A21 - A11)(B11 + B12)
 *   P2 = A11 (B12 - B22)           P6 = (A12 - A22)(B21 + B22)
 *   P3 = A22 (B21 - B11)
 *
 *   C11 = P0 + P3 - P4 + P6        C12 = P2 + P4
 *   C21 = P1 + P3                  C22 = P0 - P1 + P2 + P5
 *
 * Scratch storage is static: 9 buffers of 2 MiB each would overflow typical
 * 8 MiB thread stacks if automatic. cat() already keeps a static buffer, so
 * this function was never reentrant; nothing is lost.
 */
void matrix_multiply(int, const double* _A, const double* _B, double* _C) {
    enum { HALF = 512, FULL = 1024 };
    const int DL = HALF * FULL;  /* offset from X11 down to X21 */
    const int DR = HALF;         /* offset from X11 right to X12 */

    /* Each product is formed right after its operands are packed, so a
     * single A/B operand pair is reused for all seven products. */
    static double A[HALF * HALF] __attribute__((aligned(64)));
    static double B[HALF * HALF] __attribute__((aligned(64)));
    /* The seven products P0..P6. Kept 4096-byte aligned for cat()'s 256-bit
     * vector accesses. No zeroing is needed: cat() zero-fills its output. */
    static double C[7][HALF * HALF] __attribute__((aligned(4096)));

    /* P0 = (A11 + A22)(B11 + B22) */
    convertA(_A, _A + DL + DR, A);
    convertA(_B, _B + DL + DR, B);
    cat(HALF, A, B, C[0]);
    /* P1 = (A21 + A22) B11 */
    convertA(_A + DL, _A + DL + DR, A);
    convert(_B, B);
    cat(HALF, A, B, C[1]);
    /* P2 = A11 (B12 - B22) */
    convert(_A, A);
    convertS(_B + DR, _B + DR + DL, B);
    cat(HALF, A, B, C[2]);
    /* P3 = A22 (B21 - B11) */
    convert(_A + DR + DL, A);
    convertS(_B + DL, _B, B);
    cat(HALF, A, B, C[3]);
    /* P4 = (A11 + A12) B22 */
    convertA(_A, _A + DR, A);
    convert(_B + DL + DR, B);
    cat(HALF, A, B, C[4]);
    /* P5 = (A21 - A11)(B11 + B12) */
    convertS(_A + DL, _A, A);
    convertA(_B, _B + DR, B);
    cat(HALF, A, B, C[5]);
    /* P6 = (A12 - A22)(B21 + B22) */
    convertS(_A + DR, _A + DL + DR, A);
    convertA(_B + DL, _B + DL + DR, B);
    cat(HALF, A, B, C[6]);

    /* Assemble the quadrants of _C, in the same order of additions as before. */
    /* C11 = P0 + P3 - P4 + P6 */
    iconvert(C[0], _C); iconvertA(C[3], _C); iconvertS(C[4], _C); iconvertA(C[6], _C);
    /* C12 = P2 + P4 */
    iconvert(C[2], _C + DR); iconvertA(C[4], _C + DR);
    /* C21 = P1 + P3 */
    iconvert(C[1], _C + DL); iconvertA(C[3], _C + DL);
    /* C22 = P0 - P1 + P2 + P5 */
    iconvert(C[0], _C + DL + DR); iconvertS(C[1], _C + DL + DR);
    iconvertA(C[2], _C + DL + DR); iconvertA(C[5], _C + DL + DR);
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 79.179 ms | 52 MB + 16 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2026-03-23 22:29:07 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠