#pragma GCC target("avx2,bmi,bmi2,popcnt,lzcnt")
#include <bits/stdc++.h>
#include <immintrin.h>
using namespace std;
// Number of records the scatter loops prefetch ahead of the current read position.
const int PREFETCH_DIST = 64;

// ---------------------------------------------------------------------------
// Helper functions.
//
// The radix passes keep records packed at 3, 2 and 1 bytes, so almost every
// access is unaligned. std::memcpy is used for the multi-byte accesses instead
// of casting the byte pointer to a wider integer type: such a cast is undefined
// behaviour (alignment and aliasing), while a fixed-size memcpy is well defined
// and is compiled to a single move instruction.
//
// All helpers assume a little-endian host (the low byte of a value is stored
// at the lowest address).
// ---------------------------------------------------------------------------

// Writes all four bytes of `val` at `p`. Callers treat the record as 3 bytes
// wide: the 4th byte spills into the following slot (and is overwritten by the
// next store) or into the padding that follows every bucket.
inline void store3(uint8_t* __restrict__ p, uint32_t val) {
    std::memcpy(p, &val, sizeof(val));
}

// Writes the two bytes of `val` at `p`.
inline void store2(uint8_t* __restrict__ p, uint16_t val) {
    std::memcpy(p, &val, sizeof(val));
}

// Writes the single byte `val` at `p`.
inline void store1(uint8_t* __restrict__ p, uint8_t val) {
    *p = val;
}

// Reads four bytes starting at `p` (one byte more than a 3-byte record).
inline uint32_t load_u32(const uint8_t* p) {
    uint32_t val;
    std::memcpy(&val, p, sizeof(val));
    return val;
}

// Reads two bytes starting at `p`.
inline uint16_t load_u16(const uint8_t* p) {
    uint16_t val;
    std::memcpy(&val, p, sizeof(val));
    return val;
}

// Issues a non-temporal read prefetch for `base + byte_offset`.
// The address is computed with integer arithmetic because it may lie past the
// end of the buffer, where pointer arithmetic would be undefined. The prefetch
// itself is only a hint and does not fault on an invalid address.
inline void prefetch_ahead(const void* base, size_t byte_offset) {
    __builtin_prefetch(
        reinterpret_cast<const void*>(reinterpret_cast<uintptr_t>(base) + byte_offset), 0, 0);
}

// Sorts the `n` 32-bit unsigned integers at `a` into ascending order, in place.
//
// Algorithm (byte-wise radix sort, most significant byte first). B3 is the top
// byte of a value, B0 the lowest:
//   Pass 1: partition all values by B3, storing only [B0,B1,B2] (3 bytes each).
//   Pass 2: per B3 bucket, partition by B2, storing only [B0,B1] (2 bytes each).
//   Pass 3: per B2 sub-bucket, partition by B1, storing only [B0] (1 byte each).
//   Pass 4: per B1 group, sort the B0 bytes and write the rebuilt 32-bit
//           values back to `a` sequentially.
//
// Every pass writes into its own scratch region, so no pass can overwrite data
// that is still unread, regardless of `n` or of how skewed the key distribution
// is. Scratch memory is 3*n + 3*(largest B3 bucket) bytes plus about 40 KiB.
//
// Parameters:
//   a - pointer to `n` values; a null pointer is treated as empty input.
//   n - element count; values <= 1 return immediately.
//
// If the scratch buffer cannot be allocated, the function falls back to
// std::sort, which needs no extra memory.
void sort(uint* a, int n) {
    static_assert(sizeof(uint) == 4, "sort() assumes 32-bit unsigned int");
    if (a == nullptr || n < 2) return;

    constexpr size_t kRadix = 256;
    constexpr size_t kBlock = 16;       // records handled per unrolled block
    constexpr size_t kPad3 = 64;        // gap after each B3 bucket; absorbs the 1-byte
                                        // overshoot of the 4-byte store/load
    constexpr size_t kPad2 = 64;        // gap after each B2 sub-bucket
    constexpr size_t kPad1 = 32;        // gap after each B1 group
    constexpr size_t kSmallGroup = 32;  // below this, std::sort beats counting sort

    const size_t total = static_cast<size_t>(n);
    const size_t ahead = static_cast<size_t>(PREFETCH_DIST);

    // Histogram of the top byte.
    uint32_t hist3[kRadix] = {};
    for (size_t i = 0; i < total; ++i) ++hist3[a[i] >> 24];

    // Byte offset of every B3 bucket inside the pass-1 region. Offsets are
    // size_t so that 3*n cannot wrap for very large inputs.
    size_t start3[kRadix];
    size_t bytes3 = 0;
    uint32_t largest = 0;
    for (size_t k = 0; k < kRadix; ++k) {
        start3[k] = bytes3;
        bytes3 += static_cast<size_t>(hist3[k]) * 3 + kPad3;
        largest = std::max(largest, hist3[k]);
    }
    // Pass-2 and pass-3 regions are reused for every bucket, so they are sized
    // for the largest B3 bucket (a B2 sub-bucket can never be larger than that).
    const size_t bytes2 = static_cast<size_t>(largest) * 2 + kRadix * kPad2;
    const size_t bytes1 = static_cast<size_t>(largest) + kRadix * kPad1;

    uint8_t* const scratch = static_cast<uint8_t*>(std::malloc(bytes3 + bytes2 + bytes1));
    if (scratch == nullptr) {
        std::sort(a, a + total);
        return;
    }
    uint8_t* const stage3 = scratch;           // pass-1 output: 3-byte records
    uint8_t* const stage2 = stage3 + bytes3;   // pass-2 output: 2-byte records
    uint8_t* const stage1 = stage2 + bytes2;   // pass-3 output: 1-byte records

    // ---------------------------------------------------------
    // Pass 1: global partition by B3.
    // Read a (4 bytes) -> write stage3 (3 bytes: [B0, B1, B2]).
    // ---------------------------------------------------------
    {
        const uint* __restrict__ src = a;
        uint8_t* __restrict__ dst = stage3;
        size_t cursor[kRadix];
        std::memcpy(cursor, start3, sizeof(cursor));

        size_t i = 0;
        for (; i + kBlock <= total; i += kBlock) {
            prefetch_ahead(src, (i + ahead) * sizeof(uint));
#pragma GCC unroll 16
            for (size_t j = 0; j < kBlock; ++j) {
                const uint32_t val = src[i + j];
                const uint8_t key = static_cast<uint8_t>(val >> 24);
                store3(dst + cursor[key], val);
                cursor[key] += 3;
            }
        }
        // Tail: the remaining (total % 16) values.
        for (; i < total; ++i) {
            const uint32_t val = src[i];
            const uint8_t key = static_cast<uint8_t>(val >> 24);
            store3(dst + cursor[key], val);
            cursor[key] += 3;
        }
    }

    size_t out = 0;  // index in `a` of the next sorted value to write

    // ---------------------------------------------------------
    // Process every B3 bucket in ascending order.
    // ---------------------------------------------------------
    for (size_t b3 = 0; b3 < kRadix; ++b3) {
        const size_t n3 = hist3[b3];
        if (n3 == 0) continue;
        const uint8_t* const in3 = stage3 + start3[b3];

        // -----------------------------------------------------
        // Pass 2: partition this bucket by B2.
        // Read stage3 (3 bytes) -> write stage2 (2 bytes: [B0, B1]).
        // -----------------------------------------------------
        uint32_t hist2[kRadix] = {};
        for (size_t i = 0; i < n3; ++i) ++hist2[in3[i * 3 + 2]];

        size_t start2[kRadix];
        {
            size_t run = 0;
            for (size_t k = 0; k < kRadix; ++k) {
                start2[k] = run;
                run += static_cast<size_t>(hist2[k]) * 2 + kPad2;
            }
        }
        {
            size_t cursor[kRadix];
            std::memcpy(cursor, start2, sizeof(cursor));

            size_t i = 0;
            for (; i + kBlock <= n3; i += kBlock) {
                prefetch_ahead(in3, (i + ahead) * 3);
#pragma GCC unroll 16
                for (size_t j = 0; j < kBlock; ++j) {
                    const uint32_t val = load_u32(in3 + (i + j) * 3);
                    const uint8_t key = static_cast<uint8_t>((val >> 16) & 0xFF);  // B2
                    store2(stage2 + cursor[key], static_cast<uint16_t>(val));      // [B0, B1]
                    cursor[key] += 2;
                }
            }
            for (; i < n3; ++i) {
                const uint32_t val = load_u32(in3 + i * 3);
                const uint8_t key = static_cast<uint8_t>((val >> 16) & 0xFF);
                store2(stage2 + cursor[key], static_cast<uint16_t>(val));
                cursor[key] += 2;
            }
        }

        // -----------------------------------------------------
        // Process every B2 sub-bucket of this B3 bucket.
        // -----------------------------------------------------
        for (size_t b2 = 0; b2 < kRadix; ++b2) {
            const size_t n2 = hist2[b2];
            if (n2 == 0) continue;
            const uint8_t* const in2 = stage2 + start2[b2];

            // -------------------------------------------------
            // Pass 3: partition this sub-bucket by B1.
            // Read stage2 (2 bytes) -> write stage1 (1 byte: [B0]).
            // -------------------------------------------------
            uint32_t hist1[kRadix] = {};
            for (size_t i = 0; i < n2; ++i) ++hist1[in2[i * 2 + 1]];

            // n2 + 256 * kPad1 always fits in 32 bits because n2 <= INT_MAX.
            uint32_t start1[kRadix];
            {
                uint32_t run = 0;
                for (size_t k = 0; k < kRadix; ++k) {
                    start1[k] = run;
                    run += hist1[k] + static_cast<uint32_t>(kPad1);
                }
            }
            {
                uint32_t cursor[kRadix];
                std::memcpy(cursor, start1, sizeof(cursor));

                size_t i = 0;
                for (; i + kBlock <= n2; i += kBlock) {
#pragma GCC unroll 16
                    for (size_t j = 0; j < kBlock; ++j) {
                        const uint16_t val = load_u16(in2 + (i + j) * 2);
                        const uint8_t key = static_cast<uint8_t>(val >> 8);  // B1
                        store1(stage1 + cursor[key], static_cast<uint8_t>(val & 0xFF));
                        ++cursor[key];
                    }
                }
                for (; i < n2; ++i) {
                    const uint16_t val = load_u16(in2 + i * 2);
                    const uint8_t key = static_cast<uint8_t>(val >> 8);
                    store1(stage1 + cursor[key], static_cast<uint8_t>(val & 0xFF));
                    ++cursor[key];
                }
            }

            // -------------------------------------------------
            // Pass 4: sort the B0 bytes of every B1 group and emit the
            // rebuilt 32-bit values. Read stage1 (1 byte) -> write a (4 bytes).
            // -------------------------------------------------
            // Shifts are done on unsigned values: shifting an int left by 24
            // overflows for b3 >= 128.
            const uint32_t high_bits =
                (static_cast<uint32_t>(b3) << 24) | (static_cast<uint32_t>(b2) << 16);
            for (size_t b1 = 0; b1 < kRadix; ++b1) {
                const size_t n1 = hist1[b1];
                if (n1 == 0) continue;
                uint8_t* const group = stage1 + start1[b1];
                const uint32_t base = high_bits | (static_cast<uint32_t>(b1) << 8);

                if (n1 < kSmallGroup) {
                    // Typical case: only a handful of bytes per group.
                    std::sort(group, group + n1);
                    for (size_t k = 0; k < n1; ++k) a[out++] = base | group[k];
                } else {
                    // Counting sort for the rare large group.
                    uint32_t hist0[kRadix] = {};
                    for (size_t k = 0; k < n1; ++k) ++hist0[group[k]];
                    for (uint32_t v = 0; v < kRadix; ++v) {
                        const uint32_t word = base | v;
                        for (uint32_t reps = hist0[v]; reps != 0; --reps) a[out++] = word;
                    }
                }
            }
        }
    }

    std::free(scratch);
}
| Compilation | N/A | N/A | Compile OK | Score: N/A | 显示更多 |
| Testcase #1 | 2.271 s | 667 MB + 648 KB | Accepted | Score: 100 | 显示更多 |