提交记录 28804


用户 题目 状态 得分 用时 内存 语言 代码长度
platelet 1001. 测测你的排序 Wrong Answer 0 687.894 ms 683656 KB C++17 9.65 KB
提交时间 评测时间
2026-01-18 20:13:28 2026-01-18 20:13:34
#pragma GCC target("avx2,bmi,bmi2,popcnt,lzcnt")
#include <bits/stdc++.h>
#include <immintrin.h>

using namespace std;

// 预取距离:Pass 1 需要,后续 Pass 根据指示移除 Prefetch
const int PREFETCH_DIST = 64; 

// 写入 3 字节 (b 数组)
inline void store3(uint8_t* __restrict__ p, uint32_t val) {
    *(uint32_t*)p = val;
}

// 写入 2 字节 (a 数组临时, b 数组临时)
inline void store2(uint8_t* __restrict__ p, uint16_t val) {
    *(uint16_t*)p = val;
}

void sort(uint* a, int n) {
    // ---------------------------------------------------------
    // 内存规划
    // a: 输入数组 (4N)。
    //    Pass 1: 只读输入。
    //    Pass 2: 输出 buffer (2N 布局,含 Padding)。
    //    Pass 4: 最终输出 buffer (4N 布局,紧凑)。
    // b: 临时数组 (3N + Padding)。
    //    Pass 1: 输出 buffer (3N 布局)。
    //    Pass 3: 输出 buffer (2N 布局,复用空间)。
    // ---------------------------------------------------------

    // 申请 b 数组,包含足够的 Padding 用于溢出写
    uint8_t* b = (uint8_t*)malloc((size_t)n * 3 + 4096 * 256);
    
    // 全局 B3 直方图
    uint cnt_global[256];
    memset(cnt_global, 0, sizeof(cnt_global));
    
    // 1.1 统计 B3 (Global MSD)
    for (int i = 0; i < n; i++) {
        cnt_global[a[i] >> 24]++;
    }
    
    // 1.2 计算 B3 Offset (Pass 1 Write ptr)
    uint ptr_global[256];
    size_t offset_b3 = 0;
    for (int i = 0; i < 256; i++) {
        ptr_global[i] = (uint)offset_b3;
        offset_b3 += cnt_global[i] * 3 + 64; // Padding
    }
    
    // ---------------------------------------------------------
    // Pass 1: Global MSD (Partition by B3)
    // Read: a (4B) -> Write: b (3B: [B0, B1, B2])
    // ---------------------------------------------------------
    {
        uint* __restrict__ src = a;
        uint8_t* __restrict__ dst = b;
        uint p[256];
        memcpy(p, ptr_global, sizeof(p));

        for (int i = 0; i < n; i += 16) {
            _mm_prefetch((const char*)&src[i + PREFETCH_DIST], _MM_HINT_NTA);
            
            #pragma GCC unroll 16
            for (int j = 0; j < 16; j++) {
                uint val = src[i + j];
                uint8_t k = val >> 24; 
                store3(dst + p[k], val); 
                p[k] += 3;
            }
        }
    }

    // ---------------------------------------------------------
    // 分段处理:遍历 B3 的每一个 Bucket
    // ---------------------------------------------------------
    
    uint8_t* a_u8 = (uint8_t*)a;
    size_t a_offset_base = 0; // 当前 B3 Bucket 在 a 中的起始元素下标 (用于 Pass 4)

    // 局部直方图/指针缓存
    uint cnt_local[256];
    uint ptr_local[256];

    for (int i_b3 = 0; i_b3 < 256; i_b3++) {
        int count_b3 = cnt_global[i_b3];
        if (count_b3 == 0) continue;

        // Pass 2 输入源 (b)
        uint8_t* seg_b_in = b + ptr_global[i_b3];
        
        // Pass 2 输出目标 (a 的当前段)
        // 注意:Pass 2 需要 Padding,可能会稍微超出 4*count_b3 的前半部分,
        // 但 a 是连续大数组,且后续 B3 bucket 此时为空(数据在 b),所以溢出到下一个 bucket 区域是安全的。
        // 唯独最后一个 bucket 需要注意,但 malloc 一般有少量余量,或者 n 很大时 padding 占比极小。
        // 这里为了绝对安全,如果涉及最后一个 bucket 且空间紧张,可以使用备用 buffer,
        // 但通常 OJ 环境 a 后面不会紧贴着不可写页。我们假设安全。
        uint8_t* seg_a_temp = a_u8 + a_offset_base * 4;

        // -----------------------------------------------------
        // Pass 2: Local MSD (Partition by B2)
        // Read b (3B: [B0, B1, B2]) -> Write a (2B: [B0, B1])
        // -----------------------------------------------------
        
        memset(cnt_local, 0, sizeof(cnt_local));
        
        // 统计 B2 (Offset 2 in b)
        for (int i = 0; i < count_b3; i++) {
            cnt_local[*(seg_b_in + i * 3 + 2)]++;
        }

        uint ptr_b2_start[256];
        uint32_t tmp_offset = 0;
        // 计算 Pass 2 写入偏移 (需要 Padding)
        for (int k = 0; k < 256; k++) {
            ptr_local[k] = tmp_offset; 
            ptr_b2_start[k] = tmp_offset; // 保存起始偏移供 Pass 3 读取
            tmp_offset += cnt_local[k] * 2 + 64; 
        }

        // 保存 B2 计数,供后续反向遍历使用
        uint cnt_b2_saved[256];
        memcpy(cnt_b2_saved, cnt_local, sizeof(cnt_local));

        // 执行 Pass 2 Scatter
        {
            uint p[256];
            memcpy(p, ptr_local, sizeof(p));
            uint8_t* src = seg_b_in;
            uint8_t* dst = seg_a_temp;

            // Pass 2 不需要 Prefetch (Local scan)
            int i = 0;
            for (; i <= count_b3 - 16; i += 16) {
                #pragma GCC unroll 16
                for (int j = 0; j < 16; j++) {
                    uint32_t val = *(uint32_t*)(src + (i + j) * 3);
                    uint8_t key = (val >> 16) & 0xFF; // B2
                    store2(dst + p[key], val); // Store [B0, B1] (low 16 bits of val)
                    p[key] += 2;
                }
            }
            for (; i < count_b3; i++) {
                uint32_t val = *(uint32_t*)(src + i * 3);
                uint8_t key = (val >> 16) & 0xFF;
                store2(dst + p[key], val);
                p[key] += 2;
            }
        }
        
        // Pass 2 完成,b 的对应区域现在空闲,可用作 Pass 3 的输出 Buffer
        uint8_t* seg_b_reuse = seg_b_in;

        // -----------------------------------------------------
        // Loop over B2 Buckets (Reverse Order!)
        // 反向遍历是为了让 Pass 4 的写入 (Write a, dense 4B) 
        // 始终在 Pass 3 读取 (Read a, sparse 2B) 之后或是高地址区域,避免覆盖。
        // Pass 3 读取完 a 的数据后,该区域即可被 Pass 4 覆盖。
        // -----------------------------------------------------
        
        // Pass 4 的当前写入指针 (以 int 为单位),初始指向当前 B3 Bucket 的末尾
        size_t offset_final_end = count_b3; 

        for (int i_b2 = 255; i_b2 >= 0; i_b2--) {
            int count_b2 = cnt_b2_saved[i_b2];
            if (count_b2 == 0) continue;

            // Pass 3 输入: a 中对应 B2 的段
            uint8_t* src_pass3 = seg_a_temp + ptr_b2_start[i_b2];
            // Pass 3 输出: b (复用)
            // 由于每次处理一个 B2 bucket,我们可以每次都从 seg_b_reuse 的头部开始写
            // 只要空间够大 (3*N_B3 vs 2*N_B2,肯定够)
            uint8_t* dst_pass3 = seg_b_reuse;

            // -------------------------------------------------
            // Pass 3: LSD Step 1 (Partition by B0)
            // Read a (2B: [B0, B1]) -> Write b (2B: [B0, B1])
            // Group by B0
            // -------------------------------------------------
            
            memset(cnt_local, 0, sizeof(cnt_local));
            for (int i = 0; i < count_b2; i++) {
                cnt_local[src_pass3[i * 2]]++; // Count B0
            }

            tmp_offset = 0;
            // 计算 Pass 3 写入偏移 (需要 Padding)
            for (int k = 0; k < 256; k++) {
                ptr_local[k] = tmp_offset;
                tmp_offset += cnt_local[k] * 2 + 64; 
            }
            
            // 执行 Pass 3 Scatter
            {
                uint p[256];
                memcpy(p, ptr_local, sizeof(p));
                uint8_t* src = src_pass3;
                uint8_t* dst = dst_pass3;

                int i = 0;
                for (; i <= count_b2 - 16; i += 16) {
                    #pragma GCC unroll 16
                    for (int j = 0; j < 16; j++) {
                        uint16_t val = *(uint16_t*)(src + (i + j) * 2);
                        uint8_t key = val & 0xFF; // B0
                        store2(dst + p[key], val);
                        p[key] += 2;
                    }
                }
                for (; i < count_b2; i++) {
                    uint16_t val = *(uint16_t*)(src + i * 2);
                    uint8_t key = val & 0xFF;
                    store2(dst + p[key], val);
                    p[key] += 2;
                }
            }

            // -------------------------------------------------
            // Pass 4: LSD Step 2 (Partition by B1) & Finalize
            // Read b (2B: [B0, B1]) -> Write a (4B: Full)
            // Stable Sort by B1 (Input is already sorted by B0 via Pass 3)
            // -------------------------------------------------
            
            // 统计 B1 (Pass 4 Key)
            // 需重新扫描 dst_pass3,因为 Pass 3 是按 B0 写入的
            memset(cnt_local, 0, sizeof(cnt_local));
            for (int i = 0; i < count_b2; i++) {
                cnt_local[dst_pass3[i * 2 + 1]]++; // Count B1
            }

            // 计算 Pass 4 写入偏移 (紧凑写入 a,不需要 Padding)
            // 指针是相对于该 B2 Bucket 最终在 a 中起始位置的偏移 (uint 单位)
            uint ptr_pass4[256];
            tmp_offset = 0;
            for (int k = 0; k < 256; k++) {
                ptr_pass4[k] = tmp_offset;
                tmp_offset += cnt_local[k]; 
            }
            
            // 确定 Pass 4 写入的基地址
            size_t offset_final_start = offset_final_end - count_b2;
            uint* dst_final_base = (uint*)a_u8 + a_offset_base + offset_final_start;
            
            uint32_t high_bits = (i_b3 << 24) | (i_b2 << 16);

            // 执行 Pass 4 Scatter
            {
                uint p[256];
                memcpy(p, ptr_pass4, sizeof(p));
                uint8_t* src = dst_pass3;

                int i = 0;
                for (; i <= count_b2 - 16; i += 16) {
                    #pragma GCC unroll 16
                    for (int j = 0; j < 16; j++) {
                        uint16_t val = *(uint16_t*)(src + (i + j) * 2);
                        uint8_t key = val >> 8; // B1
                        
                        // Construct Full Int: High | B1<<8 | B0
                        // val is (B1<<8) | B0. So just High | val.
                        dst_final_base[p[key]] = high_bits | val;
                        p[key]++;
                    }
                }
                for (; i < count_b2; i++) {
                    uint16_t val = *(uint16_t*)(src + i * 2);
                    uint8_t key = val >> 8;
                    dst_final_base[p[key]] = high_bits | val;
                    p[key]++;
                }
            }
            
            // 更新 End 指针
            offset_final_end = offset_final_start;
        }
        
        // 更新全局 Base
        a_offset_base += count_b3;
    }
    
    free(b);
}

CompilationN/AN/ACompile OKScore: N/A

Testcase #1687.894 ms667 MB + 648 KBWrong AnswerScore: 0


Judge Duck Online | 评测鸭在线
Server Time: 2026-01-22 07:59:25 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠