提交记录 28797


用户 题目 状态 得分 用时 内存 语言 代码长度
platelet 1001. 测测你的排序 Accepted 100 712.426 ms 683640 KB C++17 6.24 KB
提交时间 评测时间
2026-01-18 16:53:01 2026-01-18 16:53:08
#pragma GCC target("avx2,bmi,bmi2,popcnt,lzcnt")
#include <bits/stdc++.h>
#include <immintrin.h>

using namespace std;

// Input size for this task: 10^8 keys.
constexpr int n = 100000000;
// How many 32-bit elements ahead of the read cursor to issue a prefetch hint.
constexpr int PREFETCH_DIST = 64;

// Stores the low 24 bits of `val` at p[0..2] (native byte order, i.e.
// little-endian on the x86 targets this file is built for).
//
// For speed this is a single 4-byte store. p[3] is ALSO written, with val's
// top byte, so the caller must own 4 writable bytes at p and must not need
// p[3] preserved. Typical use: records written in increasing address order
// into a buffer with trailing slack, so p[3] is overwritten by the next record
// or lands in the slack.
// memcpy (instead of *(uint32_t*)p) keeps the unaligned, type-punned store
// well-defined; compilers still emit a single mov for it.
inline void store3(uint8_t* __restrict__ p, uint32_t val) {
    std::memcpy(p, &val, sizeof val);
}

// ---- Private helpers for sort() ---------------------------------------------

// Loads 4 bytes starting at p as a native-endian uint32_t. memcpy keeps the
// unaligned, type-punned access well-defined; compilers lower it to one load.
static inline uint32_t radix_load4(const uint8_t* p) {
    uint32_t word;
    std::memcpy(&word, p, sizeof word);
    return word;
}

// Stores the 4 bytes of `word` at p (unaligned-safe counterpart of radix_load4).
static inline void radix_store4(uint8_t* p, uint32_t word) {
    std::memcpy(p, &word, sizeof word);
}

// Releases the malloc'd scratch buffer on every return path.
struct RadixFreeDeleter {
    void operator()(void* p) const noexcept { std::free(p); }
};

// One scatter pass that reads a packed buffer.
//
// `src` holds 256 segments back to back: segment `seg` is seg_cnt[seg] records
// of 3 bytes followed by 4 bytes of slack (the layout the previous pass wrote).
// Each record is read as a 4-byte word; its 4th byte belongs to the next record
// or to the slack and must be ignored by key_of / value_of.
//   key_of(word)        -> destination bucket, 0..255
//   value_of(word, seg) -> word to store; `seg` is the byte the previous pass
//                          bucketed on, which the packed record does not carry
// bucket_pos[k] is bucket k's byte write cursor in `dst`; it advances by
// out_stride per record. Segments are visited in ascending order and records
// in stored order, so the pass is stable.
template <typename KeyOf, typename ValueOf>
static void radix_packed_pass(const uint8_t* __restrict__ src,
                              uint8_t* __restrict__ dst,
                              const uint32_t* seg_cnt,
                              size_t* __restrict__ bucket_pos,
                              size_t out_stride,
                              KeyOf key_of, ValueOf value_of) {
    constexpr size_t kChunk = 21;          // 21 records * 3 B = 63 B, about one cache line
    constexpr size_t kPrefetchAhead = 63;  // records (~189 B) ahead of the read cursor
    size_t seg_off = 0;
    for (uint32_t seg = 0; seg < 256; ++seg) {
        const size_t count = seg_cnt[seg];
        const uint8_t* sp = src + seg_off;
        seg_off += count * 3 + 4;  // this segment's records plus its slack

        auto move_one = [&](size_t idx) {
            const uint32_t word = radix_load4(sp + idx * 3);
            const uint32_t k = key_of(word);
            radix_store4(dst + bucket_pos[k], value_of(word, seg));
            bucket_pos[k] += out_stride;
        };

        size_t i = 0;
        for (; i + kChunk <= count; i += kChunk) {
            __builtin_prefetch(sp + (i + kPrefetchAhead) * 3, 0, 0);  // read, non-temporal
#pragma GCC unroll 21
            for (size_t j = 0; j < kChunk; ++j) move_one(i + j);
        }
        for (; i < count; ++i) move_one(i);  // tail shorter than one chunk
    }
}

// Sorts a[0, len) ascending with a 4-pass LSD radix sort (8 bits per pass)
// that moves 3-byte packed records in its first three passes.
//
// Trick: once a pass has bucketed on byte d, a record's bucket index *is*
// byte d, so that byte need not be stored. Passes 1-3 therefore write only
// 3 bytes per key, and the next pass re-inserts the dropped byte from the
// segment index while reading. Buffers ping-pong between a scratch buffer and
// `a` itself (4*len bytes, enough for the 3*len + 1024 packed bytes once
// len >= 1024):
//   pass 1: a (uint32) -> scratch   key B0, stores [B1 B2 B3]
//   pass 2: scratch    -> a         key B1, stores [B0 B2 B3]
//   pass 3: a          -> scratch   key B2, stores [B0 B1 B3]
//   pass 4: scratch    -> a         key B3, stores the full uint32
// Each bucket of the packed layouts is followed by 4 bytes of slack, so the
// 4-byte store/load of a bucket's last 3-byte record stays in that bucket's span.
//
// @param a    keys to sort in place (may be null when len <= 0)
// @param len  number of keys; len <= 1 is a no-op
// Inputs with len < 1024 (the packed data would not fit in `a`) fall back to
// std::sort. So does a failure to allocate the 3*len + 1024 byte scratch buffer.
void sort(uint* a, int len) {
    if (len <= 1) return;
    const size_t m = static_cast<size_t>(len);

    // Size of one packed layout: 3 bytes per record + 4 bytes slack per bucket.
    const size_t packed_bytes = m * 3 + 256 * 4;
    if (packed_bytes > m * 4) {  // pass 2 would overflow `a`
        std::sort(a, a + m);
        return;
    }
    std::unique_ptr<uint8_t, RadixFreeDeleter> scratch(
        static_cast<uint8_t*>(std::malloc(packed_bytes)));
    if (!scratch) {
        std::sort(a, a + m);
        return;
    }
    uint8_t* const b = scratch.get();
    uint8_t* const a_bytes = reinterpret_cast<uint8_t*>(a);

    // Histograms of all four key bytes in one read of the input.
    uint32_t cnt[4][256] = {};
    for (size_t i = 0; i < m; ++i) {
        const uint32_t v = a[i];
        ++cnt[0][v & 255];
        ++cnt[1][(v >> 8) & 255];
        ++cnt[2][(v >> 16) & 255];
        ++cnt[3][v >> 24];
    }

    // Byte offset of every bucket in each pass's output. Packed passes use
    // stride 3 plus 4 bytes slack; the final pass is dense 4-byte words.
    // size_t offsets avoid 32-bit overflow of 4*len for large len.
    size_t pos[4][256];
    for (int d = 0; d < 4; ++d) {
        const size_t stride = (d == 3) ? 4 : 3;
        const size_t slack = (d == 3) ? 0 : 4;
        size_t off = 0;
        for (int v = 0; v < 256; ++v) {
            pos[d][v] = off;
            off += cnt[d][v] * stride + slack;
        }
    }

    // Pass 1: a (uint32) -> b, key B0, stores [B1 B2 B3].
    {
        constexpr size_t kBlock = 16;
        constexpr size_t kPrefetchAhead = 64;  // source words (256 B) ahead
        size_t* const p = pos[0];
        auto move_one = [&](size_t idx) {
            const uint32_t v = a[idx];
            const uint32_t k = v & 255;
            radix_store4(b + p[k], v >> 8);
            p[k] += 3;
        };
        size_t i = 0;
        for (; i + kBlock <= m; i += kBlock) {
            __builtin_prefetch(a + i + kPrefetchAhead, 0, 0);  // read, non-temporal
#pragma GCC unroll 16
            for (size_t j = 0; j < kBlock; ++j) move_one(i + j);
        }
        for (; i < m; ++i) move_one(i);  // tail when len % 16 != 0
    }

    // Pass 2: b -> a, word = [B1 B2 B3 x], key B1, segment = B0.
    radix_packed_pass(
        b, a_bytes, cnt[0], pos[1], 3,
        [](uint32_t w) { return w & 255u; },
        [](uint32_t w, uint32_t seg) { return (w & 0xFFFF00u) | seg; });

    // Pass 3: a -> b, word = [B0 B2 B3 x], key B2, segment = B1.
    radix_packed_pass(
        a_bytes, b, cnt[1], pos[2], 3,
        [](uint32_t w) { return (w >> 8) & 255u; },
        [](uint32_t w, uint32_t seg) { return (w & 0xFF00FFu) | (seg << 8); });

    // Pass 4: b -> a, word = [B0 B1 B3 x], key B3, segment = B2.
    // Rebuild [B0 B1 B2 B3] with masks/shifts rather than BMI2 pdep, which
    // removes the ISA dependency.
    radix_packed_pass(
        b, a_bytes, cnt[2], pos[3], 4,
        [](uint32_t w) { return (w >> 16) & 255u; },
        [](uint32_t w, uint32_t seg) {
            return (w & 0xFFFFu) | ((w & 0xFF0000u) << 8) | (seg << 16);
        });
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 712.426 ms | 667 MB + 632 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2026-01-18 20:32:31 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠