提交记录 22067


用户 题目 状态 得分 用时 内存 语言 代码长度
Xiaohuba mmmd1k. 测测你的双精度矩阵乘法-1k Runtime Error 0 5.475 ms 24656 KB C++ 2.40 KB
提交时间 评测时间
2024-07-31 18:11:16 2024-07-31 18:11:19
#pragma GCC target("avx,fma,avx512f")
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
#include <emmintrin.h>
#include <immintrin.h>

// One GCC vector of 8 doubles (64 bytes, the width of an AVX-512 register).
// Dereferencing a `vec *` presumably compiles to an aligned 64-byte load or
// store, so storage accessed through `vec *` should be 64-byte aligned.
// NOTE(review): confirm this alignment for every caller.
typedef double vec __attribute__((vector_size(64)));

// Helper macros for a 6 x 16 register-tile micro kernel.
// The kernel updates c[x : x + 6][y : y + 16]
// using a[x : x + 6][l : r] and b[l : r][y : y + 16].
// All matrices are row-major with a row stride of n doubles. Each tile row i
// is held in two vec accumulators: t<i>0 covers columns y .. y + 7 and t<i>1
// covers columns y + 8 .. y + 15.

// _upd_t(idx): broadcast a[x + idx][k] into all 8 lanes (vector + scalar),
// then fused-multiply-add it with the two row-k vectors of b (b0, b1) into the
// accumulators of tile row idx. It expects a, n, x, k, b0, b1, t<idx>0 and
// t<idx>1 to be in scope where it is expanded.
#define _upd_t(idx)                                                            \
  vec a##idx = vec{} + a[(x + idx) * n + k];                                   \
  t##idx##0 = _mm512_fmadd_pd(a##idx, b0, t##idx##0);                          \
  t##idx##1 = _mm512_fmadd_pd(a##idx, b1, t##idx##1);
// _set_t(i, j): declare accumulator t<i><j> and load it from c. The index is
// divided by 8 because c is addressed in units of vec, so (x + i) * n + y
// must be a multiple of 8 for the division to be exact.
#define _set_t(i, j) vec t##i##j = c[((x + i) * n + y) / 8 + j];
// _save_t(i, j): store accumulator t<i><j> back into the same slot of c.
#define _save_t(i, j) c[((x + i) * n + y) / 8 + j] = t##i##j;

// clang-format off

inline void kernel_6x16(double *a, vec *b, vec *c, int x, int y, int l, int r, int n) {
  _set_t(0, 0); _set_t(0, 1); _set_t(1, 0); _set_t(1, 1);
  _set_t(2, 0); _set_t(2, 1); _set_t(3, 0); _set_t(3, 1);
  _set_t(4, 0); _set_t(4, 1); _set_t(5, 0); _set_t(5, 1);

  for (int k = l; k < r; k++) {
    vec b0 = b[(k * n + y) / 8], b1 = b[(k * n + y) / 8 + 1];
    _upd_t(0); _upd_t(1); _upd_t(2);
    _upd_t(3); _upd_t(4); _upd_t(5);
  }

  _save_t(0, 0); _save_t(0, 1); _save_t(1, 0); _save_t(1, 1);
  _save_t(2, 0); _save_t(2, 1); _save_t(3, 0); _save_t(3, 1);
  _save_t(4, 0); _save_t(4, 1); _save_t(5, 0); _save_t(5, 1);
}

// clang-format on

// Problem size (1024 for this 1k task) and the cache-blocking tile sizes
// used by matrix_multiply:
//   N  - alias of n,
//   s1 - depth of one k-block (the [l, r) range handed to the micro kernel),
//   s2 - rows per block; must be a multiple of the kernel height (6),
//   s3 - columns per block; must be a multiple of the kernel width (16).
constexpr int n = 1024;
constexpr int N = n;
constexpr int s1 = 120;
constexpr int s2 = 72;
constexpr int s3 = 32;

// The blocked loops step x by 6 inside [i2, i2 + s2) and y by 16 inside
// [i3, i3 + s3). With other tile sizes the last kernel of a block would spill
// into the next block, and that part of c would be accumulated twice.
static_assert(s1 > 0, "s1 must be positive");
static_assert(s2 > 0 && s2 % 6 == 0, "s2 must be a positive multiple of 6");
static_assert(s3 > 0 && s3 % 16 == 0, "s3 must be a positive multiple of 16");

void matrix_multiply(int _n, const double *_a, const double *_b, double *_c) {
  int nx = (n + 5) / 6 * 6, ny = n;
  double a[nx * ny], b[nx * ny], c[nx * ny];

  memset(a, 0, sizeof a);
  memset(b, 0, sizeof b);
  memset(c, 0, sizeof c);

  for (int i = 0; i < n; i++) {
    memcpy(&a[i * ny], _a + n * i, n * 8);
    memcpy(&b[i * ny], _b + n * i, n * 8);
  }

  for (int i3 = 0; i3 < ny; i3 += s3)
    // b[:][i3 : i3 + s3]
    for (int i2 = 0; i2 < nx; i2 += s2)
      // a[i2 : i2 + s2][:]
      for (int i1 = 0; i1 < ny; i1 += s1)
        // b[i1 : i1 + s1][i3 : i3 + s3]
        // update c[i2 : i2 + s2][i3 : i3+s3] with [l : r] = [i1 : i1 + s1]
        for (int x = i2; x < std::min(i2 + s2, nx); x += 6)
          for (int y = i3; y < std::min(i3 + s3, ny); y += 16)
            kernel_6x16(a, (vec *)b, (vec *)c, x, y, i1, std::min(i1 + s1, n),
                        ny);

  for (int i = 0; i < n; i++)
    memcpy(_c + n * i, &c[i * ny], n * 8);

  std::free(a), std::free(b), std::free(c);
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 5.475 ms | 24 MB + 80 KB | Runtime Error | Score: 0


Judge Duck Online | 评测鸭在线
Server Time: 2025-01-24 04:25:32 | Loaded in 0 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠