#pragma GCC target("avx2,bmi,bmi2,popcnt,lzcnt")
#include <bits/stdc++.h>
#include <immintrin.h>
using namespace std;
// Prefetch distance, in elements ahead of the read cursor.
// Software prefetching is used in Pass 1 only.
const int PREFETCH_DIST = 64;

// ---------------------------------------------------------------------------
// Unaligned access helpers.
// The packed 3-byte / 2-byte records live at arbitrary byte addresses, so every
// access goes through memcpy instead of dereferencing a casted pointer (which
// would be a misaligned, type-punned access). Optimizing compilers typically
// lower each of these to a single move. All helpers assume a little-endian
// host: byte 0 of a record is the least significant key byte.
// ---------------------------------------------------------------------------

// Stores `val` at `p` with one 4-byte write. Only the low three bytes are the
// payload; p[3] is overwritten too, so the caller must own p[0..3].
inline void store3(uint8_t* __restrict__ p, uint32_t val) {
    std::memcpy(p, &val, sizeof(val));
}

// Stores the two bytes of `val` at `p`.
inline void store2(uint8_t* __restrict__ p, uint16_t val) {
    std::memcpy(p, &val, sizeof(val));
}

// Loads four bytes starting at `p`. For a packed 3-byte record the top byte of
// the result belongs to whatever follows the record and must be ignored.
inline uint32_t load3(const uint8_t* p) {
    uint32_t v;
    std::memcpy(&v, p, sizeof(v));
    return v;
}

// Loads the two bytes starting at `p`.
inline uint16_t load2(const uint8_t* p) {
    uint16_t v;
    std::memcpy(&v, p, sizeof(v));
    return v;
}

// Sorts the `n` 32-bit keys in `a` ascending, in place.
//
// Key bytes are named B3 (most significant) .. B0 (least significant).
//
// Memory plan
//   a: input array (4N bytes).
//        Pass 1: read only.
//        Pass 2: output buffer for 2-byte records [B0, B1], written at the
//                start of the current B3 bucket's final 4N region.
//        Pass 4: final, compact 4-byte output.
//   b: temporary buffer (3N bytes + padding).
//        Pass 1: output buffer for 3-byte records [B0, B1, B2].
//        Pass 3: output buffer for 2-byte records (space is reused).
//
// Pass 2 and Pass 3 insert 64 bytes of padding between buckets only when the
// padded layout provably fits and cannot collide with data that is still
// needed; otherwise they use a compact (unpadded) layout, which always fits.
//
// Falls back to std::sort when the scratch buffer cannot be allocated or when
// byte offsets into it would not fit in 32 bits.
void sort(uint* a, int n) {
    if (a == nullptr || n < 2) return;

    const uint32_t kPad = 64;                    // padding between buckets, bytes
    const size_t kSlack = (size_t)4096 * 256;    // extra bytes allocated behind b

    uint8_t* b = nullptr;
    if ((uint64_t)n * 3 + 256 * (uint64_t)kPad <= UINT32_MAX) {
        b = (uint8_t*)std::malloc((size_t)n * 3 + kSlack);
    }
    if (b == nullptr) {
        std::sort(a, a + n);
        return;
    }

    // 1.1 Global histogram of B3.
    uint32_t cnt_global[256] = {0};
    for (int i = 0; i < n; i++) {
        cnt_global[a[i] >> 24]++;
    }

    // 1.2 Byte offset of every B3 bucket inside b. The padding is required:
    //     store3 writes one byte past each 3-byte record.
    uint32_t ptr_global[256];
    {
        uint32_t off = 0;
        for (int k = 0; k < 256; k++) {
            ptr_global[k] = off;
            off += cnt_global[k] * 3 + kPad;
        }
    }

    // ---------------------------------------------------------
    // Pass 1: global MSD partition by B3.
    // Read a (4 bytes) -> write b (3 bytes: [B0, B1, B2]).
    // ---------------------------------------------------------
    {
        const uint* __restrict__ src = a;
        uint8_t* __restrict__ dst = b;
        uint32_t p[256];
        std::memcpy(p, ptr_global, sizeof(p));

        int i = 0;
        for (; i <= n - 16; i += 16) {
            // Read prefetch with no temporal locality; a prefetch never faults,
            // so running ahead of the end of the array is harmless.
            __builtin_prefetch(src + i + PREFETCH_DIST, 0, 0);
#pragma GCC unroll 16
            for (int j = 0; j < 16; j++) {
                const uint32_t val = src[i + j];
                const uint8_t k = val >> 24;
                store3(dst + p[k], val);
                p[k] += 3;
            }
        }
        // Tail: the last n % 16 elements.
        for (; i < n; i++) {
            const uint32_t val = src[i];
            const uint8_t k = val >> 24;
            store3(dst + p[k], val);
            p[k] += 3;
        }
    }

    // ---------------------------------------------------------
    // Process every B3 bucket in turn.
    // ---------------------------------------------------------
    uint8_t* a_u8 = (uint8_t*)a;
    size_t a_base = 0;  // index in a of the current B3 bucket's first element

    for (int i_b3 = 0; i_b3 < 256; i_b3++) {
        const size_t count_b3 = cnt_global[i_b3];
        if (count_b3 == 0) continue;

        uint8_t* seg_b = b + ptr_global[i_b3];  // Pass 2 input, Pass 3 output
        uint8_t* seg_a = a_u8 + a_base * 4;     // Pass 2 output, Pass 4 output

        // -----------------------------------------------------
        // Pass 2: local MSD partition by B2.
        // Read b (3 bytes: [B0, B1, B2]) -> write a (2 bytes: [B0, B1]).
        // -----------------------------------------------------
        uint32_t cnt_b2[256] = {0};
        for (size_t i = 0; i < count_b3; i++) {
            cnt_b2[seg_b[i * 3 + 2]]++;  // B2 is byte 2 of the record
        }

        // With padding P, bucket k ends at byte 2*S + P*k, where S is the number
        // of elements in buckets 0..k. Pass 4 later writes the final 4-byte
        // output of the next non-empty bucket starting at byte 4*S, so bucket k
        // stays intact (and inside the segment) iff P*k <= 2*S.
        uint32_t pad2 = kPad;
        {
            uint64_t seen = 0;
            for (int k = 0; k < 256; k++) {
                if (cnt_b2[k] == 0) continue;
                seen += cnt_b2[k];
                if ((uint64_t)kPad * (uint64_t)k > 2 * seen) {
                    pad2 = 0;
                    break;
                }
            }
        }

        // Start offset of each B2 bucket, relative to seg_a.
        uint32_t ptr_b2[256];
        {
            uint32_t off = 0;
            for (int k = 0; k < 256; k++) {
                ptr_b2[k] = off;
                off += cnt_b2[k] * 2 + pad2;
            }
        }

        // Pass 2 scatter.
        {
            uint32_t p[256];
            std::memcpy(p, ptr_b2, sizeof(p));
            const uint8_t* src = seg_b;
            uint8_t* dst = seg_a;

            size_t i = 0;
            for (; i + 16 <= count_b3; i += 16) {
#pragma GCC unroll 16
                for (size_t j = 0; j < 16; j++) {
                    const uint32_t val = load3(src + (i + j) * 3);
                    const uint8_t key = (val >> 16) & 0xFF;  // B2
                    store2(dst + p[key], (uint16_t)val);     // keep [B0, B1]
                    p[key] += 2;
                }
            }
            for (; i < count_b3; i++) {
                const uint32_t val = load3(src + i * 3);
                const uint8_t key = (val >> 16) & 0xFF;
                store2(dst + p[key], (uint16_t)val);
                p[key] += 2;
            }
        }

        // Pass 2 is done: this bucket's region of b is free again and serves as
        // the Pass 3 output buffer.

        // -----------------------------------------------------
        // Walk the B2 buckets in reverse order, so that the final 4-byte output
        // fills the segment from the back while the not yet processed 2-byte
        // records sit at the front.
        // -----------------------------------------------------
        size_t final_end = count_b3;
        for (int i_b2 = 255; i_b2 >= 0; i_b2--) {
            const size_t count_b2 = cnt_b2[i_b2];
            if (count_b2 == 0) continue;

            const uint8_t* src3 = seg_a + ptr_b2[i_b2];

            // Histograms of B0 and B1 for this B2 bucket.
            uint32_t cnt_b0[256] = {0};
            uint32_t cnt_b1[256] = {0};
            for (size_t i = 0; i < count_b2; i++) {
                const uint16_t v = load2(src3 + i * 2);
                cnt_b0[v & 0xFF]++;
                cnt_b1[v >> 8]++;
            }

            // The padded Pass 3 layout is used only when it fits into the
            // 3 * count_b3 bytes of b owned by this B3 bucket.
            const uint32_t pad3 =
                (2 * (uint64_t)count_b2 + 256 * (uint64_t)kPad <= 3 * (uint64_t)count_b3)
                    ? kPad
                    : 0;

            // Start offset of each B0 bucket, relative to seg_b.
            uint32_t ptr_b0[256];
            {
                uint32_t off = 0;
                for (int k = 0; k < 256; k++) {
                    ptr_b0[k] = off;
                    off += cnt_b0[k] * 2 + pad3;
                }
            }

            // -------------------------------------------------
            // Pass 3: LSD step 1, partition by B0.
            // Read a (2 bytes: [B0, B1]) -> write b (2 bytes: [B0, B1]).
            // -------------------------------------------------
            {
                uint32_t p[256];
                std::memcpy(p, ptr_b0, sizeof(p));

                size_t i = 0;
                for (; i + 16 <= count_b2; i += 16) {
#pragma GCC unroll 16
                    for (size_t j = 0; j < 16; j++) {
                        const uint16_t val = load2(src3 + (i + j) * 2);
                        const uint8_t key = val & 0xFF;  // B0
                        store2(seg_b + p[key], val);
                        p[key] += 2;
                    }
                }
                for (; i < count_b2; i++) {
                    const uint16_t val = load2(src3 + i * 2);
                    const uint8_t key = val & 0xFF;
                    store2(seg_b + p[key], val);
                    p[key] += 2;
                }
            }

            // -------------------------------------------------
            // Pass 4: LSD step 2, partition by B1, and finalize.
            // Read b (2 bytes: [B0, B1]) -> write a (full 4-byte keys).
            // The output is compact, so the offsets carry no padding.
            // -------------------------------------------------
            uint32_t ptr_b1[256];
            {
                uint32_t off = 0;
                for (int k = 0; k < 256; k++) {
                    ptr_b1[k] = off;
                    off += cnt_b1[k];
                }
            }

            const size_t final_start = final_end - count_b2;
            uint* dst_final = a + a_base + final_start;
            const uint32_t high_bits = ((uint32_t)i_b3 << 24) | ((uint32_t)i_b2 << 16);

            // Visit the B0 buckets in ascending order (skipping any gaps in b),
            // so records reach the B1 scatter already ordered by B0.
            for (int b0 = 0; b0 < 256; b0++) {
                const size_t c = cnt_b0[b0];
                if (c == 0) continue;
                const uint8_t* src = seg_b + ptr_b0[b0];

                size_t i = 0;
                for (; i + 16 <= c; i += 16) {
#pragma GCC unroll 16
                    for (size_t j = 0; j < 16; j++) {
                        const uint16_t val = load2(src + (i + j) * 2);
                        const uint8_t b1 = val >> 8;
                        dst_final[ptr_b1[b1]] = high_bits | val;
                        ptr_b1[b1]++;
                    }
                }
                for (; i < c; i++) {
                    const uint16_t val = load2(src + i * 2);
                    const uint8_t b1 = val >> 8;
                    dst_final[ptr_b1[b1]] = high_bits | val;
                    ptr_b1[b1]++;
                }
            }
            final_end = final_start;
        }
        a_base += count_b3;
    }
    std::free(b);
}
// | Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// | Testcase #1 | 739.174 ms | 667 MB + 648 KB | Accepted | Score: 100 | Show more |