提交记录 27826
提交时间 | 评测时间 |
2025-01-27 15:33:03 | 2025-01-27 15:33:05 |
//written by deepseek-r1
#include <cstring>
#include <immintrin.h>
#include <algorithm>
#include <malloc.h>
#pragma GCC target("avx2")
/// Sorts a[0..n) in ascending order with an LSD radix sort.
///
/// The 32-bit key is processed in passes of RADIX_BITS bits, least significant
/// digit first. Each pass:
///   1. builds a histogram of the current digit (AVX2 extracts the digits of
///      BLOCK elements at a time; a scalar loop handles the tail),
///   2. turns the histogram into exclusive prefix offsets, and
///   3. stably scatters every element into the other buffer.
/// Work ping-pongs between `a` and a scratch buffer of n elements. If the data
/// ends in the scratch buffer, it is copied back to `a`.
///
/// @param a  array sorted in place. It needs no particular alignment, because
///           all vector loads are unaligned.
/// @param n  number of elements. Values <= 1 return immediately.
///
/// Compiled with target("avx2"), so it must only run on a CPU with AVX2.
/// If the scratch buffer cannot be allocated, it falls back to std::sort.
__attribute__((target("avx2")))
inline void radix_sort(unsigned* a, int n) {
    static_assert(sizeof(unsigned) == 4, "radix_sort assumes 32-bit unsigned keys");
    constexpr int RADIX_BITS = 8;
    constexpr int BUCKETS = 1 << RADIX_BITS;
    constexpr unsigned DIGIT_MASK = BUCKETS - 1;  // derived from RADIX_BITS, not hard-coded
    constexpr int KEY_BITS = 32;
    constexpr int BLOCK = 32;  // elements per SIMD histogram step (4 vectors x 8 lanes)

    if (n <= 1) {
        return;  // already sorted; also avoids a zero-byte allocation
    }

    unsigned* tmp = static_cast<unsigned*>(
        _mm_malloc(static_cast<std::size_t>(n) * sizeof(unsigned), 32));
    if (tmp == nullptr) {
        // Out of memory for the scratch buffer: degrade instead of writing through null.
        std::sort(a, a + n);
        return;
    }

    unsigned* src = a;
    unsigned* dst = tmp;
    const __m256i digit_mask = _mm256_set1_epi32(static_cast<int>(DIGIT_MASK));
    const int main_end = n & ~(BLOCK - 1);  // largest multiple of BLOCK <= n

    for (int shift = 0; shift < KEY_BITS; shift += RADIX_BITS) {
        alignas(32) unsigned count[BUCKETS] = {0};

        // Histogram, vector part: shift and mask BLOCK keys at once, spill the
        // digits to a small aligned buffer, then bump the counters.
        for (int i = 0; i < main_end; i += BLOCK) {
            alignas(32) unsigned digits[BLOCK];
#pragma GCC unroll 4
            for (int k = 0; k < BLOCK; k += 8) {
                __m256i v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i + k));
                v = _mm256_and_si256(_mm256_srli_epi32(v, shift), digit_mask);
                _mm256_store_si256(reinterpret_cast<__m256i*>(digits + k), v);
            }
#pragma GCC unroll 8
            for (int j = 0; j < BLOCK; ++j) {
                ++count[digits[j]];
            }
        }
        // Histogram, scalar tail: the last n % BLOCK elements.
        for (int i = main_end; i < n; ++i) {
            ++count[(src[i] >> shift) & DIGIT_MASK];
        }

        // Exclusive prefix sum -> first output slot of each bucket in dst.
        unsigned* pos[BUCKETS];
        unsigned running = 0;
        for (int b = 0; b < BUCKETS; ++b) {
            pos[b] = dst + running;
            running += count[b];
        }

        // Stable scatter. There is deliberately no `ivdep` here: equal digits in
        // different iterations update the same pos[b], which is a real
        // loop-carried dependency.
        for (int i = 0; i < n; ++i) {
            const unsigned val = src[i];
            *pos[(val >> shift) & DIGIT_MASK]++ = val;
        }
        std::swap(src, dst);
    }

    // Make sure the sorted data ends up in the caller's array.
    if (src != a) {
        std::memcpy(a, src, static_cast<std::size_t>(n) * sizeof(unsigned));
    }
    _mm_free(tmp);
}
/// Entry point: sorts a[0..n) in ascending order.
///
/// Dispatches at run time. It uses the AVX2 radix sort when the CPU reports
/// AVX2 support and std::sort otherwise. This means the AVX2-only code is never
/// executed on a CPU that lacks it.
///
/// The former target("default") attribute has been removed. GCC treats a
/// function carrying it as multi-versioned, but there is no other version of
/// `sort` to dispatch between. Such a definition can also clash with a plain
/// declaration of `sort` elsewhere.
///
/// @param a  array sorted in place.
/// @param n  number of elements. Values <= 1 are a no-op.
void sort(unsigned* a, int n) {
    if (n <= 1) {
        return;
    }
    if (__builtin_cpu_supports("avx2")) {
        radix_sort(a, n);
    } else {
        std::sort(a, a + n);
    }
}
Compilation | N/A | N/A | Compile Error | Score: N/A | 显示更多 |
Judge Duck Online | 评测鸭在线
Server Time: 2025-04-05 13:49:17 | Loaded in 0 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠