#include <stdio.h>
#include <stdlib.h>
#include <immintrin.h>
#include <string.h>
#include <pthread.h>
#pragma GCC optimize("-Ofast", "-funroll-all-loops", "-ffast-math", "-ftree-vectorize")
#pragma GCC target("avx2")
int THREAD_COUNT = 16; //The num of thread
typedef struct {
int size;
int start_row;
double *dest;
const double *src1;
const double *src2;
} ThreadData;
void mul_matrix_4(double *const dest, const double *const src1, const double *const src2, const int row1, const int col1, const int row2, const int col2, const int size) {
__m256d ymm8 = _mm256_loadu_pd(&src2[row2 * size + col2]);
__m256d ymm9 = _mm256_loadu_pd(&src2[(1 + row2) * size + col2]);
__m256d ymm10 = _mm256_loadu_pd(&src2[(2 + row2) * size + col2]);
__m256d ymm11 = _mm256_loadu_pd(&src2[(3 + row2) * size + col2]);
for (int i = 0; i < 4; ++i) {
__m256d ymm0 = _mm256_set1_pd(src1[(i + row1) * size + col1]);
__m256d ymm1 = _mm256_set1_pd(src1[(i + row1) * size + col1 + 1]);
__m256d ymm2 = _mm256_set1_pd(src1[(i + row1) * size + col1 + 2]);
__m256d ymm3 = _mm256_set1_pd(src1[(i + row1) * size + col1 + 3]);
__m256d ymm16 = _mm256_mul_pd(ymm0, ymm8);
__m256d ymm17 = _mm256_mul_pd(ymm1, ymm9);
__m256d ymm18 = _mm256_mul_pd(ymm2, ymm10);
__m256d ymm19 = _mm256_mul_pd(ymm3, ymm11);
__m256d ymm24 = _mm256_loadu_pd(&dest[(i + row1) * size + col2]);
ymm24 = _mm256_add_pd(ymm16, ymm24);
ymm24 = _mm256_add_pd(ymm17, ymm24);
ymm24 = _mm256_add_pd(ymm18, ymm24);
ymm24 = _mm256_add_pd(ymm19, ymm24);
_mm256_store_pd(&dest[(i + row1) * size + col2], ymm24);
}
}
void *thread_mul_matrix(void *arg) {
ThreadData *data = (ThreadData *)arg;
int size = data->size;
int start_row = data->start_row;
double *dest = data->dest;
const double *src1 = data->src1;
const double *src2 = data->src2;
for (int i = start_row; i < size; i += 4 * THREAD_COUNT) {
for (int j = 0; j < size; j += 4) {
for (int k = 0; k < size; k += 4) {
mul_matrix_4(dest, src1, src2, i, k, k, j, size);
}
}
}
return NULL;
}
void mul_matrix_parallel(double *const dest, const double *const src1, const double *const src2, int size) {
pthread_t threads[THREAD_COUNT];
ThreadData data[THREAD_COUNT];
for (int i = 0; i < THREAD_COUNT; i++) {
data[i] = (ThreadData){size, i * 4, dest, src1, src2};
pthread_create(&threads[i], NULL, thread_mul_matrix, &data[i]);
}
for (int i = 0; i < THREAD_COUNT; i++) {
pthread_join(threads[i], NULL);
}
}
void matrix_multiply(int n, const double *A, const double *B, double *C) {
mul_matrix_parallel(C, A, B, n);
}
Compilation | N/A | N/A | Compile OK | Score: N/A | 显示更多 |
Testcase #1 | 6.28 us | 8 KB | Runtime Error | Score: 0 | 显示更多 |