#pragma GCC target("avx2,bmi,bmi2,popcnt,lzcnt")
#include <bits/stdc++.h>
#include <immintrin.h>
using namespace std;
// 预取距离
const int PREFETCH_DIST = 64;
// 写入 3 字节 (b 数组)
inline void store3(uint8_t* __restrict__ p, uint32_t val) {
*(uint32_t*)p = val;
}
// 写入 2 字节 (a 数组临时)
inline void store2(uint8_t* __restrict__ p, uint16_t val) {
*(uint16_t*)p = val;
}
// 写入 1 字节 (b 数组临时)
inline void store1(uint8_t* __restrict__ p, uint8_t val) {
*p = val;
}
void sort(uint* a, int n) {
// ---------------------------------------------------------
// 内存规划
// a: 输入数组 (4N 字节)。Pass 1 后数据移至 b,a 可作为临时空间。
// b: 临时数组。
// Pass 1 Output (b): 存储 [B0, B1, B2]。大小 3N。
// Pass 2 Output (a): 存储 [B0, B1]。大小 2N。存储在 a 的前半部分或分段复用。
// Pass 3 Output (b): 存储 [B0]。大小 1N。复用 b 的空间 (Pass 2 读取后释放)。
// ---------------------------------------------------------
// 申请 b 数组。
// 大小 = 3N (Pass 1) + Padding。
// Pass 3 也会复用这块内存,但只需要 1N,所以 3N 足够。
// 需要足够的 Padding 防止 Scatter 溢出。
uint8_t* b = (uint8_t*)malloc((size_t)n * 3 + 4096 * 256);
// 全局计数器 (B3)
uint cnt_global[256];
memset(cnt_global, 0, sizeof(cnt_global));
// 1.1 统计 B3
for (int i = 0; i < n; i++) {
cnt_global[a[i] >> 24]++;
}
// 1.2 计算 Pass 1 (B3) 写入指针
// 每个 Bucket 预留一些 Padding 用于 SIMD/Unroll 溢出写
uint ptr_global[256];
size_t offset_b3 = 0;
for (int i = 0; i < 256; i++) {
ptr_global[i] = (uint)offset_b3;
// 3 bytes per element + 64 bytes padding
offset_b3 += cnt_global[i] * 3 + 64;
}
// 1.3 Pass 1: Global Scatter (Partition by B3) -> Write to b (3 Bytes)
{
uint* __restrict__ src = a;
uint8_t* __restrict__ dst = b;
uint p[256];
memcpy(p, ptr_global, sizeof(p));
for (int i = 0; i < n; i += 16) {
_mm_prefetch((const char*)&src[i + PREFETCH_DIST], _MM_HINT_NTA);
#pragma GCC unroll 16
for (int j = 0; j < 16; j++) {
uint val = src[i + j];
uint8_t k = val >> 24;
store3(dst + p[k], val);
p[k] += 3;
}
}
}
// ---------------------------------------------------------
// 分段处理:遍历 B3 的每一个 Bucket
// ---------------------------------------------------------
// a 数组视作字节数组用于 Pass 2 临时存储
uint8_t* a_u8 = (uint8_t*)a;
// 用于 Pass 4 (Final Write) 的全局偏移 (以 uint 为单位)
int a_final_offset = 0;
// 辅助直方图和指针数组
uint cnt_local[256];
uint ptr_local[256];
for (int i_b3 = 0; i_b3 < 256; i_b3++) {
int count_b3 = cnt_global[i_b3];
if (count_b3 == 0) continue;
// 当前 B3 Bucket 在 b 中的起始位置
uint8_t* seg_b_in = b + ptr_global[i_b3];
// Pass 2 Output (Temp in a): 存储 [B0, B1]
// 我们需要一块临时空间。由于我们是顺序处理 B3 Bucket,
// 且 Pass 4 的写入 (4N) 总是快于 Pass 2 的需求 (2N),
// 我们可以安全地利用 a 数组当前处理进度之后的一段空间,
// 或者简单地利用 a_final_offset 对应的字节位置?
// 风险:Pass 4 写 full int 会覆盖 Pass 2 写 short。
// 解决方案:利用 a 的偏移量。
// Pass 4 写入起始位置: a + a_final_offset (字节位置 a_final_offset * 4)
// Pass 2 写入起始位置: 我们必须保证 Pass 2 的数据在被 Pass 3 读取完之前不被 Pass 4 覆盖。
// 对于当前 Bucket,Pass 4 写入 4*count_b3 字节。
// Pass 2 写入 2*count_b3 字节。
// Pass 3 读取 Pass 2 数据,写入 Pass 3 数据 (到 b)。
// 只有 Pass 3 完成后,Pass 4 才开始写。
// 所以,对于 *同一个* Bucket,我们可以复用 a 的同一块内存区域,
// 只要 Pass 2 的数据 (2B) 放在 Pass 4 (4B) 将要写入的区域内即可。
// 因为 Pass 2 写完 -> Pass 3 读完 -> Pass 4 才写。无冲突。
uint8_t* seg_a_temp = a_u8 + (size_t)a_final_offset * 4;
// =====================================================
// Step 2: Pass 2 (Partition by B2)
// Read b (3B: B0,B1,B2) -> Write a (2B: B0,B1)
// =====================================================
// 2.1 统计 B2
memset(cnt_local, 0, sizeof(cnt_local));
for (int i = 0; i < count_b3; i++) {
// b 存储的是 [B0, B1, B2]
// B2 是第 3 个字节 (offset 2)
uint8_t val_b2 = *(seg_b_in + i * 3 + 2);
cnt_local[val_b2]++;
}
// 2.2 计算 Offset (for a temp)
uint32_t tmp_offset = 0;
for (int k = 0; k < 256; k++) {
ptr_local[k] = tmp_offset;
tmp_offset += cnt_local[k] * 2 + 64; // Padding
}
// 保存 B2 的计数,供后续步骤使用
uint cnt_b2_saved[256];
memcpy(cnt_b2_saved, cnt_local, sizeof(cnt_local));
// 保存 B2 的指针基址 (相对 seg_a_temp),供 Step 3 读取
uint ptr_b2_start[256];
memcpy(ptr_b2_start, ptr_local, sizeof(ptr_local));
// 2.3 Scatter B2
{
uint p[256];
memcpy(p, ptr_local, sizeof(p));
uint8_t* src = seg_b_in;
uint8_t* dst = seg_a_temp;
for (int i = 0; i < count_b3; i += 16) {
// Prefetch b
_mm_prefetch((const char*)&src[(i + PREFETCH_DIST) * 3], _MM_HINT_NTA);
#pragma GCC unroll 16
for (int j = 0; j < 16; j++) {
if (i + j >= count_b3) break;
// Read [B0, B1, B2]
uint32_t val = *(uint32_t*)(src + (i + j) * 3); // Reads 4 bytes, safe due to padding
uint8_t key = (val >> 16) & 0xFF; // B2
// Write [B0, B1]
store2(dst + p[key], val); // val's low 16 bits are B0, B1
p[key] += 2;
}
}
}
// Pass 2 完成,b 的当前段数据已无用。
// Pass 3 将输出到 b 的当前段 (复用内存)。
uint8_t* seg_b_temp = seg_b_in;
// =====================================================
// Loop over B2 Buckets
// =====================================================
for (int i_b2 = 0; i_b2 < 256; i_b2++) {
int count_b2 = cnt_b2_saved[i_b2];
if (count_b2 == 0) continue;
uint8_t* seg_a_in = seg_a_temp + ptr_b2_start[i_b2];
// =================================================
// Step 3: Pass 3 (Partition by B1)
// Read a (2B: B0,B1) -> Write b (1B: B0)
// =================================================
// 3.1 统计 B1
// 此时只需用一个小数组,因为在 B2 循环内,重置开销 256 次 * 256 * 256 太大?
// 不,总循环次数是 B3(256) * B2(256)。
// 统计 B1 需要清空 cnt。
// 优化:只有 256 个 bin,memset很快。
memset(cnt_local, 0, sizeof(cnt_local));
for (int i = 0; i < count_b2; i++) {
// a 存储 [B0, B1]
uint16_t val = *(uint16_t*)(seg_a_in + i * 2);
cnt_local[val >> 8]++; // High byte is B1
}
// 3.2 计算 Offset (for b temp)
// 复用 b 的空间。由于 b 原本分配给 B3 bucket 的空间是 3*count_b3。
// 现在我们只处理其中一个 B2 bucket,且只写 1 字节/元素。
// 我们可以直接从 seg_b_temp 开始写吗?
// 我们需要为每个 B2 bucket 分配 b 的空间吗?
// 不需要,我们可以复用整个 seg_b_temp 区域,只要不同 B2 bucket 之间不冲突。
// 但是这里是一个串行过程。我们在处理 B2 bucket i 时,
// 我们可以使用 seg_b_temp 的开头作为临时空间。
// 只要大小够 (1 * count_b2 + padding),绝对够。
tmp_offset = 0;
for (int k = 0; k < 256; k++) {
ptr_local[k] = tmp_offset;
tmp_offset += cnt_local[k] + 32; // 1 byte element + padding
}
uint cnt_b1_saved[256];
memcpy(cnt_b1_saved, cnt_local, sizeof(cnt_local));
uint ptr_b1_start[256];
memcpy(ptr_b1_start, ptr_local, sizeof(ptr_local));
// 3.3 Scatter B1
{
uint p[256];
memcpy(p, ptr_local, sizeof(p));
uint8_t* src = seg_a_in;
uint8_t* dst = seg_b_temp;
for (int i = 0; i < count_b2; i += 16) {
_mm_prefetch((const char*)&src[(i + PREFETCH_DIST) * 2], _MM_HINT_NTA);
#pragma GCC unroll 16
for (int j = 0; j < 16; j++) {
if (i + j >= count_b2) break;
uint16_t val = *(uint16_t*)(src + (i + j) * 2);
uint8_t key = val >> 8; // B1
store1(dst + p[key], val & 0xFF); // Store B0
p[key]++;
}
}
}
// =================================================
// Step 4: Pass 4 (Generate Final Output)
// Loop B1 -> Sort B0 -> Write Full Int
// =================================================
// 当前的高位部分
uint32_t high_parts = (i_b3 << 24) | (i_b2 << 16); // | (i_b1 << 8) inside loop
for (int i_b1 = 0; i_b1 < 256; i_b1++) {
int count_b1 = cnt_b1_saved[i_b1];
if (count_b1 == 0) continue;
uint8_t* b0_src = seg_b_temp + ptr_b1_start[i_b1];
uint* dst_final = a + a_final_offset;
uint32_t current_high = high_parts | (i_b1 << 8);
// 我们需要对 b0_src 中的 count_b1 个 B0 进行排序
// 如果数量少,直接排序
// 如果数量多,计数排序
if (count_b1 < 64) {
// Small sort
// Copy to stack buffer to avoid modifying source if strict, but here src is temp.
// Sort in place?
// b0_src point to b array.
// bubble sort or std::sort
// std::sort on uint8 is fast
std::sort(b0_src, b0_src + count_b1);
for(int k=0; k<count_b1; k++) {
dst_final[k] = current_high | b0_src[k];
}
} else {
// Counting Sort for B0
uint cnt_b0[256] = {0};
for(int k=0; k<count_b1; k++) cnt_b0[b0_src[k]]++;
int k = 0;
for(int v=0; v<256; v++) {
int c = cnt_b0[v];
if (c == 0) continue;
uint32_t val = current_high | v;
while(c--) {
dst_final[k++] = val;
}
}
}
a_final_offset += count_b1;
}
}
}
free(b);
}
| Compilation | N/A | N/A | Compile OK | Score: N/A | 显示更多 |
| Testcase #1 | 840.106 ms | 667 MB + 648 KB | Wrong Answer | Score: 0 | 显示更多 |