提交记录 28792


用户 题目 状态 得分 用时 内存 语言 代码长度
platelet 1001. 测测你的排序 Accepted 100 746.449 ms 683640 KB C++17 6.18 KB
提交时间 评测时间
2026-01-18 15:55:07 2026-01-18 15:55:13
#include <bits/stdc++.h>
#include <immintrin.h>

using namespace std;

const int n = 1e8;
const int PREFETCH_DIST = 64;

// 辅助函数:向地址 p 写入 3 字节的 val (Little Endian)
// 相当于写入 val 的低 24 位
inline void store3(uint8_t* __restrict__ p, uint32_t val) {
    // *(uint16_t*)p = (uint16_t)val;
    // *(p + 2) = (uint8_t)(val >> 16);
    *(uint32_t*)p = val; // 直接写入低 24 位
}

void sort(uint* a, int __n) {
    // 1. 直方图统计
    // ptr 在这里存储的是 "Byte Offset" (字节偏移量),而不是元素索引
    uint cnt[4][256];
    uint ptr[4][256];
    
    memset(cnt, 0, sizeof(cnt));
    
    // 统计 Pass 1-4 的分布
    for (int i = 0; i < n; i++) {
        uint val = a[i];
        cnt[0][val & 255]++;
        cnt[1][val >> 8 & 255]++;
        cnt[2][val >> 16 & 255]++;
        cnt[3][val >> 24]++;
    }

    // 计算 Offset
    for (int k = 0; k < 4; k++) {
        uint32_t offset = 0;
        // 前 3 轮是紧凑存储,每个元素占 3 字节
        // 第 4 轮是完整输出,每个元素占 4 字节
        uint32_t stride = (k == 3) ? 4 : 3;
        int padding = (k == 3) ? 0 : 4;
        for (int i = 0; i < 256; i++) {
            ptr[k][i] = offset;
            offset += cnt[k][i] * stride + padding;
        }
    }

    // 申请临时空间 b,大小为 3 * n
    // 为每个 Segment 留 4 字节冗余 -> 256 * 4 = 1024
    // 多申请 4 个字节,防止读取最后一个元素时越界 (Read 4 bytes trick)
    uint8_t* b = (uint8_t*)malloc(n * 3 + 256 * 4 + 4);
    
    // 我们将 a 强制转换为 uint8_t* 来看待,以便在 Pass 2 进行 3-byte 写入
    // a 的原始空间是 4*n,完全足够存 3*n 的数据
    uint8_t* a_u8 = (uint8_t*)a;

    // --- Pass 1: Read A(4), Write B(3) ---
    // Key: B0. Store: [B1][B2][B3]
    {
        uint* __restrict__ src = a;
        uint8_t* __restrict__ dst = b;
        uint* __restrict__ p = ptr[0];

        for (int i = 0; i < n; i += 16) {
            // _mm_prefetch(&src[i + PREFETCH_DIST], _MM_HINT_NTA);
            
            #pragma GCC unroll 16
            for (int j = 0; j < 16; j++) {
                uint val = src[i + j];
                uint8_t k = val & 255;
                store3(dst + p[k], val >> 8);
                p[k] += 3; // 步进 3 字节
            }
        }
    }

    // --- Pass 2: Read B(3), Write A(3) ---
    // Input: [B1][B2][B3]. Key: B1. Seg: B0.
    // Store: [B2][B3][B0]
    {
        uint8_t* __restrict__ src = b;
        uint8_t* __restrict__ dst = a_u8; // 复用 A 的内存作为临时 3-byte buffer
        uint* __restrict__ p = ptr[1];
        
        for (int seg = 0; seg < 256; seg++) {
            int count = cnt[0][seg];
            if (count == 0) continue;

            // 构造需要插入的高位数据 (B0 移到最高位)
            // 目标结构: [B2][B3][B0] -> 对应整数 0x00B0B3B2
            // seg 是 B0
            uint32_t high_bits = seg << 16; 

            uint32_t offset = (seg == 0) ? 0 : (ptr[0][seg - 1] + 4);
            uint8_t* sp = src + offset;
            
            for (int i = 0; i < count; i++) {
                // 读取 4 字节是安全的 (利用 padding)
                // 读入: [B1][B2][B3][XX] -> val
                uint32_t val = *(uint32_t*)(sp + i * 3);
                
                uint8_t k = val & 255; // B1
                
                // (val >> 8) -> [B2][B3][XX]
                // | high_bits -> [B2][B3][B0] (注意: XX 被高位移出或忽略,我们只存低24位)
                store3(dst + p[k], ((val >> 8) & 0xFFFF) | high_bits);
                p[k] += 3;
            }
        }
    }

    // --- Pass 3: Read A(3), Write B(3) ---
    // Input: [B2][B3][B0]. Key: B2. Seg: B1.
    // Store: [B3][B0][B1]
    {
        uint8_t* __restrict__ src = a_u8;
        uint8_t* __restrict__ dst = b;
        uint* __restrict__ p = ptr[2];
        
        for (int seg = 0; seg < 256; seg++) {
            int count = cnt[1][seg];
            if (count == 0) continue;

            // seg 是 B1
            uint32_t high_bits = seg << 16;

            uint32_t offset = (seg == 0) ? 0 : (ptr[1][seg - 1] + 4);
            uint8_t* sp = src + offset;
            
            for (int i = 0; i < count; i++) {
                uint32_t val = *(uint32_t*)(sp + i * 3);
                // val: [B2][B3][B0][XX]
                // key: B2
                uint8_t k = val & 255;
                
                // (val >> 8) -> [B3][B0]
                // Result: [B3][B0][B1]
                store3(dst + p[k], ((val >> 8) & 0xFFFF) | high_bits);
                p[k] += 3;
            }
        }
    }

    // --- Pass 4: Read B(3), Write A(4) ---
    // Input: [B3][B0][B1]. Key: B3. Seg: B2.
    // Store: [B0][B1][B2][B3] (Full uint32)
    {
        uint8_t* __restrict__ src = b;
        uint* __restrict__ dst = a; // 最终输出是 4 字节对齐的 int 数组
        // 注意:ptr[3] 此时存储的是 index (如果 stride 设为 1) 还是 byte offset?
        // 之前我们在计算 ptr[3] 时使用了 stride=4,所以这里存储的是 byte offset。
        // 但 dst 是 uint*,我们需要 index。
        // 修正:为了性能,直接让 ptr[3] 存 index 可能更好,或者这里转成 uint8* 操作。
        // 这里为了统一,我们将 dst 转为 uint8_t* 操作,或者除以 4。
        // 既然是指针操作,直接转 uint8_t* dst_u8 = (uint8_t*)a; 方便加 offset。
        
        uint8_t* __restrict__ dst_u8 = (uint8_t*)a;
        uint* __restrict__ p = ptr[3]; 
        
        for (int seg = 0; seg < 256; seg++) {
            int count = cnt[2][seg];
            if (count == 0) continue;

            // seg 是 B2
            uint32_t part_mid = seg << 16; 

            uint32_t offset = (seg == 0) ? 0 : (ptr[2][seg - 1] + 4);
            uint8_t* sp = src + offset;
            
            for (int i = 0; i < count; i++) {
                uint32_t val = *(uint32_t*)(sp + i * 3);
                // val: [B3][B0][B1][XX]
                
                uint32_t k = val & 255; // B3 (Key)
                
                // 我们需要重组 [B0][B1][B2][B3]
                // val >> 8     -> [00][B0][B1] (值: ...B1B0)
                // part_mid     -> [00][00][B2] (值: ...B20000) -> 这一步移位错了,内存顺序是 LE
                
                // 让我们重新推导 LE 下的算术:
                // 内存目标: [B0] [B1] [B2] [B3]
                // val >> 8 内存变成: [B0] [B1] [XX] ... (实际值是 0x...XXB1B0)
                // 所以 (val >> 8) 贡献了低 16 位 (B0, B1)。
                // seg 是 B2,需要放到第 3 个字节。 (seg << 16)。
                // k   是 B3,需要放到第 4 个字节。 (k << 24)。
                
                uint32_t final_val = ((val >> 8) & 0xFFFF) | part_mid | (k << 24);
                
                // 写入 4 字节 (这里可以用非对齐写,但 A 是对齐的,所以直接强转写最快)
                *(uint32_t*)(dst_u8 + p[k]) = final_val;
                
                p[k] += 4; // 步进 4 字节
            }
        }
    }
}

CompilationN/AN/ACompile OKScore: N/A

Testcase #1746.449 ms667 MB + 632 KBAcceptedScore: 100


Judge Duck Online | 评测鸭在线
Server Time: 2026-01-18 20:38:28 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠