提交记录 28793
| 提交时间 |
评测时间 |
| 2026-01-18 15:55:48 |
2026-01-18 15:55:53 |
#include <bits/stdc++.h>
#include <immintrin.h>
using namespace std;
const int n = 1e8;
const int PREFETCH_DIST = 64;
// 辅助函数:向地址 p 写入 3 字节的 val (Little Endian)
// 相当于写入 val 的低 24 位
inline void store3(uint8_t* __restrict__ p, uint32_t val) {
*(uint16_t*)p = (uint16_t)val;
*(p + 2) = (uint8_t)(val >> 16);
}
void sort(uint* a, int __n) {
// 1. 直方图统计
// ptr 在这里存储的是 "Byte Offset" (字节偏移量),而不是元素索引
uint cnt[4][256];
uint ptr[4][256];
memset(cnt, 0, sizeof(cnt));
// 统计 Pass 1-4 的分布
for (int i = 0; i < n; i++) {
uint val = a[i];
cnt[0][val & 255]++;
cnt[1][val >> 8 & 255]++;
cnt[2][val >> 16 & 255]++;
cnt[3][val >> 24]++;
}
// 计算 Offset
for (int k = 0; k < 4; k++) {
uint32_t offset = 0;
// 前 3 轮是紧凑存储,每个元素占 3 字节
// 第 4 轮是完整输出,每个元素占 4 字节
uint32_t stride = (k == 3) ? 4 : 3;
int padding = (k == 3) ? 0 : 4;
for (int i = 0; i < 256; i++) {
ptr[k][i] = offset;
offset += cnt[k][i] * stride + padding;
}
}
// 申请临时空间 b,大小为 3 * n
// 为每个 Segment 留 4 字节冗余 -> 256 * 4 = 1024
// 多申请 4 个字节,防止读取最后一个元素时越界 (Read 4 bytes trick)
uint8_t* b = (uint8_t*)malloc(n * 3 + 256 * 4 + 4);
// 我们将 a 强制转换为 uint8_t* 来看待,以便在 Pass 2 进行 3-byte 写入
// a 的原始空间是 4*n,完全足够存 3*n 的数据
uint8_t* a_u8 = (uint8_t*)a;
// --- Pass 1: Read A(4), Write B(3) ---
// Key: B0. Store: [B1][B2][B3]
{
uint* __restrict__ src = a;
uint8_t* __restrict__ dst = b;
uint* __restrict__ p = ptr[0];
for (int i = 0; i < n; i += 16) {
// _mm_prefetch(&src[i + PREFETCH_DIST], _MM_HINT_NTA);
#pragma GCC unroll 16
for (int j = 0; j < 16; j++) {
uint val = src[i + j];
uint8_t k = val & 255;
store3(dst + p[k], val >> 8);
p[k] += 3; // 步进 3 字节
}
}
}
// --- Pass 2: Read B(3), Write A(3) ---
// Input: [B1][B2][B3]. Key: B1. Seg: B0.
// Store: [B2][B3][B0]
{
uint8_t* __restrict__ src = b;
uint8_t* __restrict__ dst = a_u8; // 复用 A 的内存作为临时 3-byte buffer
uint* __restrict__ p = ptr[1];
for (int seg = 0; seg < 256; seg++) {
int count = cnt[0][seg];
if (count == 0) continue;
// 构造需要插入的高位数据 (B0 移到最高位)
// 目标结构: [B2][B3][B0] -> 对应整数 0x00B0B3B2
// seg 是 B0
uint32_t high_bits = seg << 16;
uint32_t offset = (seg == 0) ? 0 : (ptr[0][seg - 1] + 4);
uint8_t* sp = src + offset;
for (int i = 0; i < count; i++) {
// 读取 4 字节是安全的 (利用 padding)
// 读入: [B1][B2][B3][XX] -> val
uint32_t val = *(uint32_t*)(sp + i * 3);
uint8_t k = val & 255; // B1
// (val >> 8) -> [B2][B3][XX]
// | high_bits -> [B2][B3][B0] (注意: XX 被高位移出或忽略,我们只存低24位)
store3(dst + p[k], ((val >> 8) & 0xFFFF) | high_bits);
p[k] += 3;
}
}
}
// --- Pass 3: Read A(3), Write B(3) ---
// Input: [B2][B3][B0]. Key: B2. Seg: B1.
// Store: [B3][B0][B1]
{
uint8_t* __restrict__ src = a_u8;
uint8_t* __restrict__ dst = b;
uint* __restrict__ p = ptr[2];
for (int seg = 0; seg < 256; seg++) {
int count = cnt[1][seg];
if (count == 0) continue;
// seg 是 B1
uint32_t high_bits = seg << 16;
uint32_t offset = (seg == 0) ? 0 : (ptr[1][seg - 1] + 4);
uint8_t* sp = src + offset;
for (int i = 0; i < count; i++) {
uint32_t val = *(uint32_t*)(sp + i * 3);
// val: [B2][B3][B0][XX]
// key: B2
uint8_t k = val & 255;
// (val >> 8) -> [B3][B0]
// Result: [B3][B0][B1]
store3(dst + p[k], ((val >> 8) & 0xFFFF) | high_bits);
p[k] += 3;
}
}
}
// --- Pass 4: Read B(3), Write A(4) ---
// Input: [B3][B0][B1]. Key: B3. Seg: B2.
// Store: [B0][B1][B2][B3] (Full uint32)
{
uint8_t* __restrict__ src = b;
uint* __restrict__ dst = a; // 最终输出是 4 字节对齐的 int 数组
// 注意:ptr[3] 此时存储的是 index (如果 stride 设为 1) 还是 byte offset?
// 之前我们在计算 ptr[3] 时使用了 stride=4,所以这里存储的是 byte offset。
// 但 dst 是 uint*,我们需要 index。
// 修正:为了性能,直接让 ptr[3] 存 index 可能更好,或者这里转成 uint8* 操作。
// 这里为了统一,我们将 dst 转为 uint8_t* 操作,或者除以 4。
// 既然是指针操作,直接转 uint8_t* dst_u8 = (uint8_t*)a; 方便加 offset。
uint8_t* __restrict__ dst_u8 = (uint8_t*)a;
uint* __restrict__ p = ptr[3];
for (int seg = 0; seg < 256; seg++) {
int count = cnt[2][seg];
if (count == 0) continue;
// seg 是 B2
uint32_t part_mid = seg << 16;
uint32_t offset = (seg == 0) ? 0 : (ptr[2][seg - 1] + 4);
uint8_t* sp = src + offset;
for (int i = 0; i < count; i++) {
uint32_t val = *(uint32_t*)(sp + i * 3);
// val: [B3][B0][B1][XX]
uint32_t k = val & 255; // B3 (Key)
// 我们需要重组 [B0][B1][B2][B3]
// val >> 8 -> [00][B0][B1] (值: ...B1B0)
// part_mid -> [00][00][B2] (值: ...B20000) -> 这一步移位错了,内存顺序是 LE
// 让我们重新推导 LE 下的算术:
// 内存目标: [B0] [B1] [B2] [B3]
// val >> 8 内存变成: [B0] [B1] [XX] ... (实际值是 0x...XXB1B0)
// 所以 (val >> 8) 贡献了低 16 位 (B0, B1)。
// seg 是 B2,需要放到第 3 个字节。 (seg << 16)。
// k 是 B3,需要放到第 4 个字节。 (k << 24)。
uint32_t final_val = ((val >> 8) & 0xFFFF) | part_mid | (k << 24);
// 写入 4 字节 (这里可以用非对齐写,但 A 是对齐的,所以直接强转写最快)
*(uint32_t*)(dst_u8 + p[k]) = final_val;
p[k] += 4; // 步进 4 字节
}
}
}
}
| Compilation | N/A | N/A | Compile OK | Score: N/A | 显示更多 |
| Testcase #1 | 847.119 ms | 667 MB + 632 KB | Accepted | Score: 100 | 显示更多 |
Judge Duck Online | 评测鸭在线
Server Time: 2026-01-18 20:32:31 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠