提交记录 21682


用户 题目 状态 得分 用时 内存 语言 代码长度
ntland mmmd1k. 测测你的双精度矩阵乘法-1k Runtime Error 0 6.19 us 8 KB C++ 2.75 KB
提交时间 评测时间
2024-04-28 20:40:59 2024-04-28 20:41:01
#include <stdio.h>
#include <stdlib.h>
#include <immintrin.h>
#include <string.h>
#include <pthread.h>
#pragma GCC optimize("-Ofast", "-funroll-all-loops", "-ffast-math", "-ftree-vectorize")
#pragma GCC target("avx2")

int THREAD_COUNT = 16; //The num of thread

typedef struct {
    int size;
    int start_row;
    double *dest;
    const double *src1;
    const double *src2;
} ThreadData;

void mul_matrix_4(double *const dest, const double *const src1, const double *const src2, const int row1, const int col1, const int row2, const int col2, const int size) {
        __m256d ymm8 = _mm256_loadu_pd(&src2[row2 * size + col2]);
        __m256d ymm9 = _mm256_loadu_pd(&src2[(1 + row2) * size + col2]);
        __m256d ymm10 = _mm256_loadu_pd(&src2[(2 + row2) * size + col2]);
        __m256d ymm11 = _mm256_loadu_pd(&src2[(3 + row2) * size + col2]);
    for (int i = 0; i < 4; ++i) {
        __m256d ymm0 = _mm256_set1_pd(src1[(i + row1) * size + col1]);
        __m256d ymm1 = _mm256_set1_pd(src1[(i + row1) * size + col1 + 1]);
        __m256d ymm2 = _mm256_set1_pd(src1[(i + row1) * size + col1 + 2]);
        __m256d ymm3 = _mm256_set1_pd(src1[(i + row1) * size + col1 + 3]);

        __m256d ymm16 = _mm256_mul_pd(ymm0, ymm8);
        __m256d ymm17 = _mm256_mul_pd(ymm1, ymm9);
        __m256d ymm18 = _mm256_mul_pd(ymm2, ymm10);
        __m256d ymm19 = _mm256_mul_pd(ymm3, ymm11);

        __m256d ymm24 = _mm256_loadu_pd(&dest[(i + row1) * size + col2]);
        ymm24 = _mm256_add_pd(ymm16, ymm24);
        ymm24 = _mm256_add_pd(ymm17, ymm24);
        ymm24 = _mm256_add_pd(ymm18, ymm24);
        ymm24 = _mm256_add_pd(ymm19, ymm24);

        _mm256_store_pd(&dest[(i + row1) * size + col2], ymm24);
    }
}


void *thread_mul_matrix(void *arg) {
    ThreadData *data = (ThreadData *)arg;
    int size = data->size;
    int start_row = data->start_row;
    double *dest = data->dest;
    const double *src1 = data->src1;
    const double *src2 = data->src2;

    for (int i = start_row; i < size; i += 4 * THREAD_COUNT) {
        for (int j = 0; j < size; j += 4) {
            for (int k = 0; k < size; k += 4) {
                mul_matrix_4(dest, src1, src2, i, k, k, j, size);
            }
        }
    }
    return NULL;
}

void mul_matrix_parallel(double *const dest, const double *const src1, const double *const src2, int size) {
    pthread_t threads[THREAD_COUNT];
    ThreadData data[THREAD_COUNT];
    for (int i = 0; i < THREAD_COUNT; i++) {
        data[i] = (ThreadData){size, i * 4, dest, src1, src2};
        pthread_create(&threads[i], NULL, thread_mul_matrix, &data[i]);
    }
    for (int i = 0; i < THREAD_COUNT; i++) {
        pthread_join(threads[i], NULL);
    }
}

void matrix_multiply(int n, const double *A, const double *B, double *C) {
    mul_matrix_parallel(C, A, B, n);
}

CompilationN/AN/ACompile OKScore: N/A

Testcase #16.19 us8 KBRuntime ErrorScore: 0


Judge Duck Online | 评测鸭在线
Server Time: 2025-07-18 11:19:47 | Loaded in 0 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠