//written by deepseek-r1
#include <cstring>
#include <immintrin.h>
#include <algorithm>
#include <malloc.h>
#pragma GCC target("avx2")
inline void radix_sort(unsigned* a, int n) {
constexpr int RADIX_BITS = 11; // 改为11位基数提升TLB效率
constexpr int BUCKETS = 1 << RADIX_BITS;
constexpr int PASSES = (32 + RADIX_BITS - 1) / RADIX_BITS;
unsigned* tmp = (unsigned*)_mm_malloc(n * sizeof(unsigned), 64);
unsigned* src = a;
unsigned* dst = tmp;
for (int pass = 0; pass < PASSES; ++pass) {
const int shift = pass * RADIX_BITS;
alignas(64) unsigned count[BUCKETS] = {0};
// 优化1:使用更大的处理块(128元素/迭代)
const int SIMD_BLOCK = 128;
const int main_iters = n / SIMD_BLOCK;
// 优化2:预取指令+非临时存储
for (int i = 0; i < main_iters * SIMD_BLOCK; i += SIMD_BLOCK) {
// 预取下个块
_mm_prefetch((const char*)(src + i + 64), _MM_HINT_T0);
// 处理8个AVX2向量(32元素)
__m256i v[4];
for (int j = 0; j < 4; ++j) {
v[j] = _mm256_loadu_si256((__m256i*)(src + i + j*8));
v[j] = _mm256_srli_epi32(v[j], shift);
v[j] = _mm256_and_si256(v[j], _mm256_set1_epi32(BUCKETS-1));
}
// 展开统计
alignas(32) unsigned buf[32];
for (int j = 0; j < 4; ++j) {
_mm256_store_si256((__m256i*)(buf + j*8), v[j]);
}
// 批量更新直方图(减少cache访问)
for (int j = 0; j < 32; ++j) {
count[buf[j]]++;
}
}
// 优化3:剩余元素向量化处理
int rem = n % SIMD_BLOCK;
int i = main_iters * SIMD_BLOCK;
for (; i + 7 < n; i += 8) {
__m256i v = _mm256_loadu_si256((__m256i*)(src + i));
v = _mm256_srli_epi32(v, shift);
v = _mm256_and_si256(v, _mm256_set1_epi32(BUCKETS-1));
alignas(32) unsigned buf[8];
_mm256_store_si256((__m256i*)buf, v);
for (int j = 0; j < 8; ++j)
count[buf[j]]++;
}
for (; i < n; ++i)
count[(src[i] >> shift) & (BUCKETS-1)]++;
// 优化4:SIMD加速前缀和
unsigned prefix[BUCKETS];
unsigned sum = 0;
for (int b = 0; b < BUCKETS; b += 8) {
__m256i vsum = _mm256_setr_epi32(
sum, sum + count[b], sum + count[b] + count[b+1],
sum + count[b] + count[b+1] + count[b+2],
sum + count[b] + count[b+1] + count[b+2] + count[b+3],
sum + count[b] + count[b+1] + count[b+2] + count[b+3] + count[b+4],
sum + count[b] + count[b+1] + count[b+2] + count[b+3] + count[b+4] + count[b+5],
sum + count[b] + count[b+1] + count[b+2] + count[b+3] + count[b+4] + count[b+5] + count[b+6]
);
_mm256_storeu_si256((__m256i*)(prefix + b), vsum);
sum += count[b] + count[b+1] + count[b+2] + count[b+3]
+ count[b+4] + count[b+5] + count[b+6] + count[b+7];
}
// 优化5:消除指针跳转
unsigned* pos = dst;
for (int b = 0; b < BUCKETS; ++b) {
const unsigned cnt = count[b];
for (unsigned* p = src + prefix[b]; p < src + prefix[b] + cnt; ++p) {
*pos++ = *p;
}
}
std::swap(src, dst);
}
if (src == tmp)
memcpy(a, tmp, n * sizeof(unsigned));
_mm_free(tmp);
}
void sort(unsigned* a, int n) {
radix_sort(a, n);
}
Compilation | N/A | N/A | Compile OK | Score: N/A | 显示更多 |
Testcase #1 | 573.16 us | 528 KB | Wrong Answer | Score: 0 | 显示更多 |
Testcase #2 | 591.79 ms | 476 MB + 896 KB | Wrong Answer | Score: 0 | 显示更多 |
Testcase #3 | 1.183 s | 953 MB + 732 KB | Wrong Answer | Score: 0 | 显示更多 |