提交记录 13620


用户 题目 状态 得分 用时 内存 语言 代码长度
user1 mmmd1k. 测测你的双精度矩阵乘法-1k Wrong Answer 0 76.345 ms 62640 KB C++ 8.86 KB
提交时间 评测时间
2020-08-05 18:58:49 2020-08-05 18:58:52
#pragma GCC optimize("Ofast,no-stack-protector")
#pragma GCC target("avx2,fma")
#include <string.h>
#include <x86intrin.h>

#define n 1024
// Flat offset of element (i, j) in a row-major matrix whose rows are 576
// doubles apart. The pitch has to equal R (576), the pitch the convert helpers
// use when packing the 512x512 operands into R*R-element buffers; the former
// value 1088 made the kernel read the wrong elements and index past the end
// of those buffers.
#define idx(i, j) (((i) * 576) + (j))

// Accumulates one 32x32x32 tile of a matrix product into C:
//   C[x+i][z+c] += sum over r of A[x+i][y+r] * B[y+r][z+c]   for 0 <= i, r, c < 32.
// x is the first row of the tile in A and C, y the first index along the shared
// dimension, z the first column in B and C. Every element is addressed through
// the idx(i, j) macro. B and C are accessed with the aligned 256-bit
// load/store intrinsics, so each addressed group of four doubles must be
// 32-byte aligned; A is read one scalar at a time and broadcast.
static void kernel32(long x, long y, long z, const double* __restrict__ A, const double* __restrict__ B, double* __restrict__ C) {
    const long m = 32;
    // Two column panels of 16 doubles each (z advances by 16 per pass).
    for (long k = 0; k < 2; ++k, z += 16) {
        // Four slabs of 8 consecutive indices along the shared dimension.
        for (long j = 0; j < 4; ++j, y += 8) {
            // Cache the 8x16 block of B for this slab/panel in vector variables:
            // bRC holds row y+R, columns z+4*C .. z+4*C+3.
            __m256d b00 = _mm256_load_pd(&B[idx(y + 0, z + 4 * 0)]);
            __m256d b01 = _mm256_load_pd(&B[idx(y + 0, z + 4 * 1)]);
            __m256d b02 = _mm256_load_pd(&B[idx(y + 0, z + 4 * 2)]);
            __m256d b03 = _mm256_load_pd(&B[idx(y + 0, z + 4 * 3)]);
            __m256d b10 = _mm256_load_pd(&B[idx(y + 1, z + 4 * 0)]);
            __m256d b11 = _mm256_load_pd(&B[idx(y + 1, z + 4 * 1)]);
            __m256d b12 = _mm256_load_pd(&B[idx(y + 1, z + 4 * 2)]);
            __m256d b13 = _mm256_load_pd(&B[idx(y + 1, z + 4 * 3)]);
            __m256d b20 = _mm256_load_pd(&B[idx(y + 2, z + 4 * 0)]);
            __m256d b21 = _mm256_load_pd(&B[idx(y + 2, z + 4 * 1)]);
            __m256d b22 = _mm256_load_pd(&B[idx(y + 2, z + 4 * 2)]);
            __m256d b23 = _mm256_load_pd(&B[idx(y + 2, z + 4 * 3)]);
            __m256d b30 = _mm256_load_pd(&B[idx(y + 3, z + 4 * 0)]);
            __m256d b31 = _mm256_load_pd(&B[idx(y + 3, z + 4 * 1)]);
            __m256d b32 = _mm256_load_pd(&B[idx(y + 3, z + 4 * 2)]);
            __m256d b33 = _mm256_load_pd(&B[idx(y + 3, z + 4 * 3)]);
            __m256d b40 = _mm256_load_pd(&B[idx(y + 4, z + 4 * 0)]);
            __m256d b41 = _mm256_load_pd(&B[idx(y + 4, z + 4 * 1)]);
            __m256d b42 = _mm256_load_pd(&B[idx(y + 4, z + 4 * 2)]);
            __m256d b43 = _mm256_load_pd(&B[idx(y + 4, z + 4 * 3)]);
            __m256d b50 = _mm256_load_pd(&B[idx(y + 5, z + 4 * 0)]);
            __m256d b51 = _mm256_load_pd(&B[idx(y + 5, z + 4 * 1)]);
            __m256d b52 = _mm256_load_pd(&B[idx(y + 5, z + 4 * 2)]);
            __m256d b53 = _mm256_load_pd(&B[idx(y + 5, z + 4 * 3)]);
            __m256d b60 = _mm256_load_pd(&B[idx(y + 6, z + 4 * 0)]);
            __m256d b61 = _mm256_load_pd(&B[idx(y + 6, z + 4 * 1)]);
            __m256d b62 = _mm256_load_pd(&B[idx(y + 6, z + 4 * 2)]);
            __m256d b63 = _mm256_load_pd(&B[idx(y + 6, z + 4 * 3)]);
            __m256d b70 = _mm256_load_pd(&B[idx(y + 7, z + 4 * 0)]);
            __m256d b71 = _mm256_load_pd(&B[idx(y + 7, z + 4 * 1)]);
            __m256d b72 = _mm256_load_pd(&B[idx(y + 7, z + 4 * 2)]);
            __m256d b73 = _mm256_load_pd(&B[idx(y + 7, z + 4 * 3)]);
            // Update the 32 rows of C against the cached block of B.
            for (long i = 0; i < m; ++i) {
                // Current 16 doubles of row x+i of C, as four vectors.
                __m256d c0 = _mm256_load_pd(&C[idx(x + i, z + 4 * 0)]);
                __m256d c1 = _mm256_load_pd(&C[idx(x + i, z + 4 * 1)]);
                __m256d c2 = _mm256_load_pd(&C[idx(x + i, z + 4 * 2)]);
                __m256d c3 = _mm256_load_pd(&C[idx(x + i, z + 4 * 3)]);
                // For each of the 8 slab rows: broadcast A[x+i][y+r] to all
                // lanes and accumulate it times row r of the cached B block.
                __m256d a;
                a = _mm256_set1_pd(A[idx(x + i, y + 0)]);
                c0 += a * b00;
                c1 += a * b01;
                c2 += a * b02;
                c3 += a * b03;
                a = _mm256_set1_pd(A[idx(x + i, y + 1)]);
                c0 += a * b10;
                c1 += a * b11;
                c2 += a * b12;
                c3 += a * b13;
                a = _mm256_set1_pd(A[idx(x + i, y + 2)]);
                c0 += a * b20;
                c1 += a * b21;
                c2 += a * b22;
                c3 += a * b23;
                a = _mm256_set1_pd(A[idx(x + i, y + 3)]);
                c0 += a * b30;
                c1 += a * b31;
                c2 += a * b32;
                c3 += a * b33;
                a = _mm256_set1_pd(A[idx(x + i, y + 4)]);
                c0 += a * b40;
                c1 += a * b41;
                c2 += a * b42;
                c3 += a * b43;
                a = _mm256_set1_pd(A[idx(x + i, y + 5)]);
                c0 += a * b50;
                c1 += a * b51;
                c2 += a * b52;
                c3 += a * b53;
                a = _mm256_set1_pd(A[idx(x + i, y + 6)]);
                c0 += a * b60;
                c1 += a * b61;
                c2 += a * b62;
                c3 += a * b63;
                a = _mm256_set1_pd(A[idx(x + i, y + 7)]);
                c0 += a * b70;
                c1 += a * b71;
                c2 += a * b72;
                c3 += a * b73;
                // Write the updated 16 doubles of row x+i back to C.
                _mm256_store_pd(&C[idx(x + i, z + 4 * 0)], c0);
                _mm256_store_pd(&C[idx(x + i, z + 4 * 1)], c1);
                _mm256_store_pd(&C[idx(x + i, z + 4 * 2)], c2);
                _mm256_store_pd(&C[idx(x + i, z + 4 * 3)], c3);
            }
        }
        // The j loop advanced y by 4 * 8; rewind it for the next column panel.
        y -= 32;
    }
}

// Recursively walks a cube of the (x, y, z) index space and calls kernel32 on
// every leaf tile.
//
//   s              - recursion level; the cube is split in half along each
//                    axis until s == 5, where kernel32 is invoked. Callers
//                    must pass s >= 5 (a smaller value never reaches the base
//                    case).
//   x, y, z        - lowest-index corner of the current cube.
//   (dx, dy, dz), (dx2, dy2, dz2), (dx3, dy3, dz3)
//                  - three signed direction vectors that fix the order in
//                    which the eight sub-cubes are visited. The visible caller
//                    passes the three unit axis vectors; children receive
//                    permuted and/or negated copies.
//   A, B, C        - operand and result matrices, forwarded to kernel32.
//
// The eight children are visited at offsets 0, d1, d1+d2, d2, d2+d3, d1+d2+d3,
// d1+d3, d3 (each di scaled to half the cube edge), so consecutive offsets
// differ by exactly one step vector.
static void gao(int s, int x, int y, int z, int dx, int dy, int dz, int dx2, int dy2, int dz2, int dx3, int dy3, int dz3, const double* __restrict__ A, const double* __restrict__ B, double* __restrict__ C) {
    if (s == 5) {
        kernel32(x, y, z, A, B, C);
        return;
    }
    --s;

    // Direction components may be negative. Scale them by multiplication:
    // left-shifting a negative int is undefined before C++20 (and in C).
    const int half = 1 << s;
    const int ux = dx * half, uy = dy * half, uz = dz * half;
    const int vx = dx2 * half, vy = dy2 * half, vz = dz2 * half;
    const int wx = dx3 * half, wy = dy3 * half, wz = dz3 * half;

    // For every negative direction component, move the starting corner to the
    // far half along that axis so each child still receives its low corner.
    if (dx < 0) x -= ux;
    if (dy < 0) y -= uy;
    if (dz < 0) z -= uz;
    if (dx2 < 0) x -= vx;
    if (dy2 < 0) y -= vy;
    if (dz2 < 0) z -= vz;
    if (dx3 < 0) x -= wx;
    if (dy3 < 0) y -= wy;
    if (dz3 < 0) z -= wz;

    // offset 0
    gao(s, x, y, z,
        dx2, dy2, dz2, dx3, dy3, dz3, dx, dy, dz, A, B, C);
    // offset d1
    gao(s, x + ux, y + uy, z + uz,
        dx3, dy3, dz3, dx, dy, dz, dx2, dy2, dz2, A, B, C);
    // offset d1 + d2
    gao(s, x + ux + vx, y + uy + vy, z + uz + vz,
        dx3, dy3, dz3, dx, dy, dz, dx2, dy2, dz2, A, B, C);
    // offset d2
    gao(s, x + vx, y + vy, z + vz,
        -dx, -dy, -dz, -dx2, -dy2, -dz2, dx3, dy3, dz3, A, B, C);
    // offset d2 + d3
    gao(s, x + vx + wx, y + vy + wy, z + vz + wz,
        -dx, -dy, -dz, -dx2, -dy2, -dz2, dx3, dy3, dz3, A, B, C);
    // offset d1 + d2 + d3
    gao(s, x + ux + vx + wx, y + uy + vy + wy, z + uz + vz + wz,
        -dx3, -dy3, -dz3, dx, dy, dz, -dx2, -dy2, -dz2, A, B, C);
    // offset d1 + d3
    gao(s, x + ux + wx, y + uy + wy, z + uz + wz,
        -dx3, -dy3, -dz3, dx, dy, dz, -dx2, -dy2, -dz2, A, B, C);
    // offset d3
    gao(s, x + wx, y + wy, z + wz,
        dx2, dy2, dz2, -dx3, -dy3, -dz3, -dx, -dy, -dz, A, B, C);
}
// M: number of rows and columns in one packed block.
// R: distance, in doubles, between consecutive rows of a packed block.
const int M = 512;
const int R = 576;

// Packs an M x M block into dst. src rows are 1024 doubles apart; dst rows are
// R doubles apart. Only the first M doubles of each dst row are written.
static void convert(const double* __restrict__ src, double* __restrict__ dst) {
    const size_t row_bytes = sizeof(double) * M;
    int row = 0;
    while (row < M) {
        const double* from = src + row * 1024;
        double* to = dst + row * R;
        memcpy(to, from, row_bytes);
        ++row;
    }
}
// Packs the element-wise sum of two M x M blocks into dst. Both sources have
// rows 1024 doubles apart; dst rows are R doubles apart.
static void convertA(const double* __restrict__ src1, const double* __restrict__ src2, double* __restrict__ dst) {
    for (int row = 0; row < M; ++row) {
        const double* lhs = src1 + row * 1024;
        const double* rhs = src2 + row * 1024;
        double* out = dst + row * R;
        for (int col = 0; col < M; ++col) {
            out[col] = lhs[col] + rhs[col];
        }
    }
}
// Packs the element-wise difference src1 - src2 of two M x M blocks into dst.
// Both sources have rows 1024 doubles apart; dst rows are R doubles apart.
static void convertS(const double* __restrict__ src1, const double* __restrict__ src2, double* __restrict__ dst) {
    for (int row = 0; row < M; ++row) {
        const double* minuend = src1 + row * 1024;
        const double* subtrahend = src2 + row * 1024;
        double* out = dst + row * R;
        for (int col = 0; col < M; ++col) {
            out[col] = minuend[col] - subtrahend[col];
        }
    }
}
// Unpacks an M x M block: copies M doubles per row from src (rows R doubles
// apart) to dst (rows 1024 doubles apart), overwriting what was there.
static void iconvert(const double* __restrict__ src, double* __restrict__ dst) {
    const size_t row_bytes = sizeof(double) * M;
    int row = 0;
    while (row < M) {
        const double* from = src + row * R;
        double* to = dst + row * 1024;
        memcpy(to, from, row_bytes);
        ++row;
    }
}
// Adds a packed M x M block (rows R doubles apart) element-wise into dst
// (rows 1024 doubles apart).
static void iconvertA(const double* __restrict__ src, double* __restrict__ dst) {
    for (int row = 0; row < M; ++row) {
        const double* in = src + row * R;
        double* acc = dst + row * 1024;
        for (int col = 0; col < M; ++col) {
            acc[col] += in[col];
        }
    }
}
// Subtracts a packed M x M block (rows R doubles apart) element-wise from dst
// (rows 1024 doubles apart).
static void iconvertS(const double* __restrict__ src, double* __restrict__ dst) {
    for (int row = 0; row < M; ++row) {
        const double* in = src + row * R;
        double* acc = dst + row * 1024;
        for (int col = 0; col < M; ++col) {
            acc[col] -= in[col];
        }
    }
}
// Multiplies one pair of packed 512x512 operands: dst += lhs * rhs.
// The three buffers use row pitch R; the kernel reached through gao() indexes
// them via idx(), so the pitch baked into idx() must equal R.
static void strassen_product(const double* __restrict__ lhs, const double* __restrict__ rhs, double* __restrict__ dst) {
    gao(9, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, lhs, rhs, dst);
}

// Computes _C = _A * _B for 1024x1024 row-major matrices (row pitch 1024
// doubles) using one level of Strassen's algorithm: seven 512x512 products
// instead of eight.
//
//   (unnamed int) - ignored; presumably the matrix dimension supplied by the
//                   judge harness (TODO confirm).
//   _A, _B        - input matrices, read only.
//   _C            - output matrix; every element is overwritten.
//
// Not reentrant: the scratch buffers below have static storage duration.
void matrix_multiply(int, const double* _A, const double* _B, double* _C) {
    // 21 buffers of R*R doubles are about 53 MiB in total, which does not fit
    // on an ordinary thread stack, so they are static. Each array carries its
    // own alignment attribute: in a comma-separated declaration the attribute
    // binds only to the declarator it follows, which previously left A and B
    // without the alignment the kernel's aligned vector loads require.
    static double A[7][R * R] __attribute__((aligned(4096)));
    static double B[7][R * R] __attribute__((aligned(4096)));
    static double C[7][R * R] __attribute__((aligned(4096)));

    const int DL = 512 * 1024, DR = 512;  // offsets of the lower / right halves

    // Quadrants of the inputs and the output.
    const double* A11 = _A;
    const double* A12 = _A + DR;
    const double* A21 = _A + DL;
    const double* A22 = _A + DL + DR;
    const double* B11 = _B;
    const double* B12 = _B + DR;
    const double* B21 = _B + DL;
    const double* B22 = _B + DL + DR;
    double* C11 = _C;
    double* C12 = _C + DR;
    double* C21 = _C + DL;
    double* C22 = _C + DL + DR;

    // The products accumulate into C[k], so start from zero on every call.
    memset(C, 0, sizeof(C));

    // M1 = (A11 + A22)(B11 + B22)
    convertA(A11, A22, A[0]); convertA(B11, B22, B[0]); strassen_product(A[0], B[0], C[0]);
    // M2 = (A21 + A22) B11
    convertA(A21, A22, A[1]); convert(B11, B[1]);       strassen_product(A[1], B[1], C[1]);
    // M3 = A11 (B12 - B22)
    convert(A11, A[2]);       convertS(B12, B22, B[2]); strassen_product(A[2], B[2], C[2]);
    // M4 = A22 (B21 - B11)
    convert(A22, A[3]);       convertS(B21, B11, B[3]); strassen_product(A[3], B[3], C[3]);
    // M5 = (A11 + A12) B22
    convertA(A11, A12, A[4]); convert(B22, B[4]);       strassen_product(A[4], B[4], C[4]);
    // M6 = (A21 - A11)(B11 + B12)
    convertS(A21, A11, A[5]); convertA(B11, B12, B[5]); strassen_product(A[5], B[5], C[5]);
    // M7 = (A12 - A22)(B21 + B22); the left factor is a difference, not a sum.
    convertS(A12, A22, A[6]); convertA(B21, B22, B[6]); strassen_product(A[6], B[6], C[6]);

    // C11 = M1 + M4 - M5 + M7
    iconvert(C[0], C11); iconvertA(C[3], C11); iconvertS(C[4], C11); iconvertA(C[6], C11);
    // C12 = M3 + M5
    iconvert(C[2], C12); iconvertA(C[4], C12);
    // C21 = M2 + M4
    iconvert(C[1], C21); iconvertA(C[3], C21);
    // C22 = M1 - M2 + M3 + M6
    iconvert(C[0], C22); iconvertS(C[1], C22); iconvertA(C[2], C22); iconvertA(C[5], C22);
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 76.345 ms | 61 MB + 176 KB | Wrong Answer | Score: 0


Judge Duck Online | 评测鸭在线
Server Time: 2026-03-24 01:02:24 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠