提交记录 28802


用户 题目 状态 得分 用时 内存 语言 代码长度
platelet 1001. 测测你的排序 Runtime Error 0 786.266 ms 683656 KB C++17 8.92 KB
提交时间 评测时间
2026-01-18 19:53:01 2026-01-18 19:53:06
#pragma GCC target("avx2,bmi,bmi2,popcnt,lzcnt")
#include <bits/stdc++.h>
#include <immintrin.h>

using namespace std;

const int PREFETCH_DIST = 64; 

// 辅助函数:向地址 p 写入 3 字节
inline void store3(uint8_t* __restrict__ p, uint32_t val) {
    *(uint32_t*)p = val;
}

// 辅助函数:向地址 p 写入 2 字节
inline void store2(uint8_t* __restrict__ p, uint16_t val) {
    *(uint16_t*)p = val;
}

// 辅助函数:向地址 p 写入 1 字节
inline void store1(uint8_t* __restrict__ p, uint8_t val) {
    *p = val;
}

void sort(uint* a, int n) {
    // ---------------------------------------------------------
    // 内存规划:
    // a: 输入/最终输出 (4N)。Pass 1 之后数据移至 b,a 用作 Pass 2 的输出 (2N) 和 Pass 4 的最终输出。
    // b: 临时 buffer (3N + Padding)。用于 Pass 1 输出,并在 Pass 3 复用。
    // ---------------------------------------------------------

    uint8_t* b = (uint8_t*)malloc((size_t)n * 3 + 256 * 64 + 4096);

    uint cnt_global[256];
    memset(cnt_global, 0, sizeof(cnt_global));
    
    // 1.1 统计 B3 (Global MSD)
    for (int i = 0; i < n; i++) {
        cnt_global[a[i] >> 24]++;
    }
    
    // 1.2 计算 B3 Offset
    uint ptr_global[256];
    size_t offset_b3 = 0;
    for (int i = 0; i < 256; i++) {
        ptr_global[i] = (uint)offset_b3;
        offset_b3 += cnt_global[i] * 3 + 64; // Padding needed for store3/prefetch
    }
    
    // ---------------------------------------------------------
    // Pass 1: Global MSD (Partition by B3)
    // Read: a (4B) -> Write: b (3B: [B0, B1, B2])
    // ---------------------------------------------------------
    {
        uint* __restrict__ src = a;
        uint8_t* __restrict__ dst = b;
        uint p[256];
        memcpy(p, ptr_global, sizeof(p));

        for (int i = 0; i < n; i += 16) {
            _mm_prefetch((const char*)&src[i + PREFETCH_DIST], _MM_HINT_NTA);
            
            #pragma GCC unroll 16
            for (int j = 0; j < 16; j++) {
                uint val = src[i + j];
                uint8_t k = val >> 24; 
                store3(dst + p[k], val); 
                p[k] += 3;
            }
        }
    }

    // ---------------------------------------------------------
    // 分段处理:遍历 B3 的每一个 Bucket
    // ---------------------------------------------------------
    
    uint8_t* a_u8 = (uint8_t*)a;
    size_t a_offset_base = 0; // 当前 B3 Bucket 在 a 中的起始元素下标

    for (int i_b3 = 0; i_b3 < 256; i_b3++) {
        int count_b3 = cnt_global[i_b3];
        if (count_b3 == 0) continue;

        uint8_t* seg_b_in = b + ptr_global[i_b3];
        
        // Pass 2 输出到 a 的当前段。
        // Pass 2 写 2 字节/元素,Pass 4 写 4 字节/元素。
        // 由于是顺序处理,Pass 4 的写入会覆盖 Pass 2 的数据,但 Pass 3 会在此之前将 Pass 2 数据搬回 b。
        uint8_t* seg_a_current = a_u8 + a_offset_base * 4;

        // -----------------------------------------------------
        // Pass 2: Local MSD (Partition by B2)
        // Read b (3B: [B0, B1, B2]) -> Write a (2B: [B0, B1])
        // -----------------------------------------------------
        
        uint cnt_b2[256] = {0};
        
        // 统计 B2 (Offset 2 in b)
        for (int i = 0; i < count_b3; i++) {
            cnt_b2[*(seg_b_in + i * 3 + 2)]++;
        }

        uint ptr_b2[256];
        uint32_t tmp_offset = 0;
        // 计算 Pass 2 写入偏移。注意 a 必须紧凑 (No Padding),因为它是最终输出容器。
        for (int k = 0; k < 256; k++) {
            ptr_b2[k] = tmp_offset;
            tmp_offset += cnt_b2[k] * 4; 
        }

        // 保存 B2 的信息供 Pass 3/4 使用
        uint cnt_b2_saved[256];
        memcpy(cnt_b2_saved, cnt_b2, sizeof(cnt_b2));
        uint ptr_b2_saved[256];
        memcpy(ptr_b2_saved, ptr_b2, sizeof(ptr_b2));

        // 执行 Pass 2 Scatter
        {
            uint p[256];
            memcpy(p, ptr_b2, sizeof(p));
            uint8_t* src = seg_b_in;
            uint8_t* dst = seg_a_current;

            int i = 0;
            for (; i <= count_b3 - 16; i += 16) {
                _mm_prefetch((const char*)&src[(i + PREFETCH_DIST) * 3], _MM_HINT_NTA);
                
                #pragma GCC unroll 16
                for (int j = 0; j < 16; j++) {
                    uint32_t val = *(uint32_t*)(src + (i + j) * 3);
                    uint8_t key = (val >> 16) & 0xFF; // B2
                    store2(dst + p[key], val); // Store [B0, B1]
                    p[key] += 2;
                }
            }
            for (; i < count_b3; i++) {
                uint32_t val = *(uint32_t*)(src + i * 3);
                uint8_t key = (val >> 16) & 0xFF;
                store2(dst + p[key], val);
                p[key] += 2;
            }
        }

        // -----------------------------------------------------
        // 分段处理:遍历 B2 的每一个 Bucket
        // -----------------------------------------------------
        for (int i_b2 = 0; i_b2 < 256; i_b2++) {
            int count_b2 = cnt_b2_saved[i_b2];
            if (count_b2 == 0) continue;

            uint8_t* seg_a_in = seg_a_current + ptr_b2_saved[i_b2];
            // 复用 b 的空间作为 Pass 3 的输出 buffer。
            // 此时 b 对应的 B3 Bucket 区域已读完,可以复用。
            // Pass 3 需要 1N 空间,b 有 3N 空间,足够。
            uint8_t* seg_b_temp = seg_b_in; 

            // 同时统计 B0 (Pass 3 Key) 和 B1 (Pass 4 Key)
            uint cnt_b0[256] = {0};
            uint cnt_b1[256] = {0};

            for (int i = 0; i < count_b2; i++) {
                uint16_t val = *(uint16_t*)(seg_a_in + i * 2);
                cnt_b0[val & 0xFF]++;
                cnt_b1[val >> 8]++;
            }

            // 计算 Pass 3 (Write b) 指针 (Need Padding)
            uint ptr_b0[256];
            tmp_offset = 0;
            for (int k = 0; k < 256; k++) {
                ptr_b0[k] = tmp_offset;
                tmp_offset += cnt_b0[k] + 32; 
            }
            uint ptr_b0_read[256]; // 用于 Pass 4 读取
            memcpy(ptr_b0_read, ptr_b0, sizeof(ptr_b0));

            // 计算 Pass 4 (Write a Final) 指针 (No Padding)
            // 目标地址是 a 中对应 B2 Bucket 的位置
            size_t elem_offset_b2 = ptr_b2_saved[i_b2] / 2;
            uint* a_final_dst_base = (uint*)a_u8 + a_offset_base + elem_offset_b2;
            
            uint ptr_b1[256];
            tmp_offset = 0;
            for (int k = 0; k < 256; k++) {
                ptr_b1[k] = tmp_offset;
                tmp_offset += cnt_b1[k]; 
            }

            // -------------------------------------------------
            // Pass 3: LSD Step 1 (Key B0)
            // Read a (2B: [B0, B1]) -> Write b (1B: [B1])
            // -------------------------------------------------
            {
                uint p[256];
                memcpy(p, ptr_b0, sizeof(p));
                uint8_t* src = seg_a_in;
                uint8_t* dst = seg_b_temp;

                int i = 0;
                for (; i <= count_b2 - 16; i += 16) {
                    _mm_prefetch((const char*)&src[(i + PREFETCH_DIST) * 2], _MM_HINT_NTA);
                    
                    #pragma GCC unroll 16
                    for (int j = 0; j < 16; j++) {
                        uint16_t val = *(uint16_t*)(src + (i + j) * 2);
                        uint8_t key = val & 0xFF; // B0
                        store1(dst + p[key], val >> 8); // Store B1
                        p[key]++;
                    }
                }
                for (; i < count_b2; i++) {
                    uint16_t val = *(uint16_t*)(src + i * 2);
                    uint8_t key = val & 0xFF;
                    store1(dst + p[key], val >> 8);
                    p[key]++;
                }
            }

            // -------------------------------------------------
            // Pass 4: LSD Step 2 (Key B1) & Finalize
            // Read b (1B: [B1]) -> Write a (4B: Full)
            // -------------------------------------------------
            {
                uint p[256];
                memcpy(p, ptr_b1, sizeof(p));
                
                uint32_t high_bits = (i_b3 << 24) | (i_b2 << 16);

                // 按照 B0 的顺序遍历 (Pass 3 的输出桶)
                // 这样进入 Pass 4 Scatter 的数据流是按 B0 排序的
                // Pass 4 是稳定 Scatter,所以最终结果按 B1 排序,且相同 B1 内按 B0 排序
                for (int b0 = 0; b0 < 256; b0++) {
                    int count = cnt_b0[b0];
                    if (count == 0) continue;

                    uint8_t* src = seg_b_temp + ptr_b0_read[b0];
                    // 构建除 B1 外的部分
                    uint32_t val_base = high_bits | b0;

                    int i = 0;
                    for (; i <= count - 16; i += 16) {
                        #pragma GCC unroll 16
                        for (int j = 0; j < 16; j++) {
                            uint8_t b1 = src[i + j]; // Read B1
                            // Write Full: B3 | B2 | B1 | B0
                            a_final_dst_base[p[b1]] = val_base | (b1 << 8);
                            p[b1]++;
                        }
                    }
                    for (; i < count; i++) {
                        uint8_t b1 = src[i];
                        a_final_dst_base[p[b1]] = val_base | (b1 << 8);
                        p[b1]++;
                    }
                }
            }
        }
        
        a_offset_base += count_b3;
    }
    
    free(b);
}

CompilationN/AN/ACompile OKScore: N/A

Testcase #1786.266 ms667 MB + 648 KBRuntime ErrorScore: 0


Judge Duck Online | 评测鸭在线
Server Time: 2026-01-22 08:01:01 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠