#pragma GCC target("avx2,bmi,bmi2,popcnt,lzcnt")
#include <bits/stdc++.h>
#include <immintrin.h>
using namespace std;
// Problem size: one hundred million (1e8) keys.
constexpr int n = 100000000;
// Look-ahead distance, in array elements, for software prefetching.
constexpr int PREFETCH_DIST = 64;
// Stores the low 24 bits of `val` at p[0..2] (least-significant byte first
// on a little-endian host).
//
// NOTE: for speed this issues a single 4-byte store, so p[3] is overwritten
// as well (with bits 24..31 of `val`). Callers must guarantee that p[3] is
// writable and does not hold data that is still needed. An exact 3-byte
// variant would need a 2-byte store followed by a 1-byte store.
//
// std::memcpy is used instead of `*(uint32_t*)p = val` because `p` is in
// general not 4-byte aligned, and a type-punned store through a misaligned
// uint32_t* is undefined behaviour. Optimizing compilers typically lower this
// fixed-size memcpy to the same single store instruction.
inline void store3(uint8_t* __restrict__ p, uint32_t val) {
    std::memcpy(p, &val, sizeof(val));
}
// Bytes of slack that follow every bucket in the packed (3 bytes per key)
// layout. Packed keys are written with 4-byte stores, so the last key of a
// bucket spills one byte past its slot; the slack absorbs that spill.
static constexpr uint32_t kRadixPad = 4;

// Prefetch look-ahead: pass 1 counts in 4-byte keys, the packed passes in bytes.
static constexpr std::size_t kRadixPrefetchKeys = 64;
static constexpr std::size_t kRadixPrefetchBytes = 63 * 3;

// Unaligned 32-bit load. memcpy keeps the access well-defined for any address.
static inline uint32_t radix_load32(const uint8_t* p) {
    uint32_t v;
    std::memcpy(&v, p, sizeof(v));
    return v;
}

// Unaligned 32-bit store (writes p[0..3]).
static inline void radix_store32(uint8_t* p, uint32_t v) {
    std::memcpy(p, &v, sizeof(v));
}

// One packed -> packed scatter pass (3-byte keys in, 3-byte keys out).
//
// `src` holds 256 buckets produced by the previous pass: bucket `seg` contains
// in_count[seg] packed keys, and in_end[seg] is the byte offset just past its
// last key (the previous pass's write cursor). Consecutive buckets are
// separated by kRadixPad bytes. Every key of bucket `seg` is missing the byte
// whose value is `seg`; this pass extracts the next sort byte (at bit
// KeyShift), replaces it in the packed word by `seg << TagShift`, and appends
// the result to the output bucket selected by the extracted byte.
// out_pos[k] is the byte offset of the next free slot of output bucket k.
template <int KeyShift, uint32_t KeepMask, int TagShift>
static void radix_repack(const uint8_t* __restrict__ src,
                         uint8_t* __restrict__ dst,
                         const uint32_t* in_count,
                         const uint32_t* in_end,
                         uint32_t* __restrict__ out_pos) {
    for (uint32_t seg = 0; seg < 256; seg++) {
        std::size_t left = in_count[seg];
        if (left == 0) continue;
        const uint32_t tag = seg << TagShift;
        // Bucket `seg` starts right after the padded end of bucket seg-1.
        const uint8_t* rp = src + (seg == 0 ? 0u : in_end[seg - 1] + kRadixPad);
        // Main loop: 21 keys (63 bytes) per iteration, one prefetch each.
        for (; left >= 21; left -= 21, rp += 63) {
            __builtin_prefetch(rp + kRadixPrefetchBytes, 0, 0);
            #pragma GCC unroll 21
            for (int j = 0; j < 21; j++) {
                // The 4th loaded byte belongs to the next slot; KeepMask drops it.
                const uint32_t val = radix_load32(rp + 3 * j);
                const uint32_t key = (val >> KeyShift) & 255;
                radix_store32(dst + out_pos[key], (val & KeepMask) | tag);
                out_pos[key] += 3;
            }
        }
        // Remaining < 21 keys of this bucket.
        for (; left != 0; --left, rp += 3) {
            const uint32_t val = radix_load32(rp);
            const uint32_t key = (val >> KeyShift) & 255;
            radix_store32(dst + out_pos[key], (val & KeepMask) | tag);
            out_pos[key] += 3;
        }
    }
}

// Sorts a[0 .. __n) in ascending order.
//
// Algorithm: 4-pass LSD radix sort on bytes B0 (lowest) .. B3 (highest).
// Between passes the keys are kept in a packed 3-byte form: the byte that was
// just used as the sort key is implied by the bucket a key sits in, so it is
// dropped and re-inserted one pass later.
//
//   pass 1: a (4-byte keys)  -> b (packed [B1][B2][B3]), bucketed by B0
//   pass 2: b                -> a (packed [B0][B2][B3]), bucketed by B1
//   pass 3: a                -> b (packed [B0][B1][B3]), bucketed by B2
//   pass 4: b                -> a (full   [B0][B1][B2][B3]), bucketed by B3
//
// `a` itself is reused as the packed buffer of pass 2, which needs
// 3*__n + 256*kRadixPad + 4 bytes. Inputs for which that does not fit into the
// 4*__n bytes of `a`, inputs whose byte offsets would not fit in 32 bits, and
// the case where the scratch allocation fails are handled by std::sort.
//
// Assumes a little-endian host (byte 0 of a packed key is its low byte).
//
// @param a    array to sort in place
// @param __n  number of elements in `a`; values <= 1 leave `a` untouched
void sort(uint* a, int __n) {
    static_assert(sizeof(uint) == 4, "radix passes assume 32-bit keys");
    if (__n <= 1) return;

    const std::uint64_t count64 = static_cast<std::uint64_t>(__n);
    // Packed layout size: 3 bytes per key, kRadixPad per bucket, plus 4 bytes
    // so the 4-byte load/store of the very last key stays in bounds.
    const std::uint64_t packed_bytes = count64 * 3 + 256 * kRadixPad + 4;

    uint8_t* b = nullptr;
    if (packed_bytes <= count64 * 4 && packed_bytes <= UINT32_MAX) {
        b = static_cast<uint8_t*>(std::malloc(static_cast<std::size_t>(packed_bytes)));
    }
    if (b == nullptr) {
        std::sort(a, a + __n);
        return;
    }

    const std::size_t count = static_cast<std::size_t>(__n);

    // 1. Histogram of every byte position.
    uint32_t cnt[4][256];
    uint32_t ptr[4][256];
    std::memset(cnt, 0, sizeof(cnt));
    for (std::size_t i = 0; i < count; i++) {
        const uint32_t val = a[i];
        cnt[0][val & 255]++;
        cnt[1][(val >> 8) & 255]++;
        cnt[2][(val >> 16) & 255]++;
        cnt[3][val >> 24]++;
    }

    // 2. Write cursors. Passes 1-3 use byte offsets into the packed layout
    //    (3 bytes per key + kRadixPad per bucket); pass 4 uses element indices.
    for (int k = 0; k < 3; k++) {
        uint32_t offset = 0;
        for (int i = 0; i < 256; i++) {
            ptr[k][i] = offset;
            offset += cnt[k][i] * 3 + kRadixPad;
        }
    }
    {
        uint32_t index = 0;
        for (int i = 0; i < 256; i++) {
            ptr[3][i] = index;
            index += cnt[3][i];
        }
    }

    uint8_t* a_u8 = reinterpret_cast<uint8_t*>(a);

    // --- Pass 1: read a (4-byte keys), write b (packed). Key: B0. ---
    {
        const uint* __restrict__ src = a;
        uint8_t* __restrict__ dst = b;
        uint32_t* __restrict__ pos = ptr[0];
        std::size_t i = 0;
        for (; i + 16 <= count; i += 16) {
            __builtin_prefetch(src + i + kRadixPrefetchKeys, 0, 0);
            #pragma GCC unroll 16
            for (int j = 0; j < 16; j++) {
                const uint32_t val = src[i + j];
                const uint32_t key = val & 255;
                radix_store32(dst + pos[key], val >> 8);
                pos[key] += 3;
            }
        }
        // Tail when the length is not a multiple of 16.
        for (; i < count; i++) {
            const uint32_t val = src[i];
            const uint32_t key = val & 255;
            radix_store32(dst + pos[key], val >> 8);
            pos[key] += 3;
        }
    }

    // --- Pass 2: b [B1][B2][B3] -> a [B0][B2][B3]. Key: B1, bucket tag: B0. ---
    radix_repack<0, 0xFFFF00u, 0>(b, a_u8, cnt[0], ptr[0], ptr[1]);

    // --- Pass 3: a [B0][B2][B3] -> b [B0][B1][B3]. Key: B2, bucket tag: B1. ---
    radix_repack<8, 0xFF00FFu, 8>(a_u8, b, cnt[1], ptr[1], ptr[2]);

    // --- Pass 4: b [B0][B1][B3] -> a, full 32-bit keys. Key: B3, tag: B2. ---
    {
        const uint8_t* __restrict__ src = b;
        uint* __restrict__ dst = a;
        uint32_t* __restrict__ pos = ptr[3];
        for (uint32_t seg = 0; seg < 256; seg++) {
            std::size_t left = cnt[2][seg];
            if (left == 0) continue;
            const uint32_t b2 = seg << 16;
            const uint8_t* rp = src + (seg == 0 ? 0u : ptr[2][seg - 1] + kRadixPad);
            for (; left >= 21; left -= 21, rp += 63) {
                __builtin_prefetch(rp + kRadixPrefetchBytes, 0, 0);
                #pragma GCC unroll 21
                for (int j = 0; j < 21; j++) {
                    const uint32_t val = radix_load32(rp + 3 * j);
                    const uint32_t key = (val >> 16) & 255;
                    // Same result as _pdep_u32(val, 0xFF00FFFF) | b2, written
                    // with plain shifts so BMI2 is not required: keep B0/B1,
                    // move B3 from bits 16..23 to bits 24..31, insert B2.
                    dst[pos[key]++] = (val & 0x0000FFFFu) | ((val & 0x00FF0000u) << 8) | b2;
                }
            }
            for (; left != 0; --left, rp += 3) {
                const uint32_t val = radix_load32(rp);
                const uint32_t key = (val >> 16) & 255;
                dst[pos[key]++] = (val & 0x0000FFFFu) | ((val & 0x00FF0000u) << 8) | b2;
            }
        }
    }

    std::free(b);
}
// Online-judge result for this submission:
// | Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// | Testcase #1 | 712.426 ms | 667 MB + 632 KB | Accepted | Score: 100 | Show more |