提交记录 28803


用户 题目 状态 得分 用时 内存 语言 代码长度
platelet 1001. 测测你的排序 Accepted 100 2.271 s 683656 KB C++17 8.51 KB
提交时间 评测时间
2026-01-18 20:00:14 2026-01-18 20:00:22
#pragma GCC target("avx2,bmi,bmi2,popcnt,lzcnt")
#include <bits/stdc++.h>
#include <immintrin.h>

using namespace std;

// Look-ahead distance for software prefetch hints, counted in elements
// beyond the index currently being processed.
constexpr int PREFETCH_DIST = 64;

// Helper functions: fixed-width stores to an arbitrary (possibly unaligned)
// byte address.

// Writes all four bytes of `val` at `p`. On a little-endian target the first
// three bytes form the 3-byte record [B0, B1, B2]; the fourth byte spills one
// byte past the record, so the caller must guarantee that p[0..3] is writable
// and that p[3] may be clobbered.
// std::memcpy is used instead of `*(uint32_t*)p` because `p` is generally not
// 4-byte aligned; compilers lower this fixed-size copy to a single store.
inline void store3(uint8_t* __restrict__ p, uint32_t val) {
    std::memcpy(p, &val, sizeof(val));
}
// Writes the two bytes of `val` at `p` in native byte order.
// std::memcpy is used instead of `*(uint16_t*)p` because `p` is not guaranteed
// to be 2-byte aligned; compilers lower this fixed-size copy to a single store.
inline void store2(uint8_t* __restrict__ p, uint16_t val) {
    std::memcpy(p, &val, sizeof(val));
}
// Writes the single byte `val` at address `p`.
inline void store1(uint8_t* __restrict__ p, uint8_t val) {
    p[0] = val;
}

/**
 * Sorts the 32-bit keys a[0..n) into ascending order, in place.
 *
 * Byte-wise radix sort that shrinks each record as it goes, because the bytes
 * already used for partitioning are implied by the bucket a record sits in:
 *
 *   Pass 1  scatter by byte 3 (MSB):   a    (4 B/key) -> b    (3 B/key)
 *   Pass 2  per B3 bucket, by byte 2:  b    (3 B/key) -> tmp2 (2 B/key)
 *   Pass 3  per B2 bucket, by byte 1:  tmp2 (2 B/key) -> tmp1 (1 B/key)
 *   Pass 4  per B1 bucket: sort the remaining low bytes and write the
 *           reassembled 4-byte keys back to a, front to back.
 *
 * Each pass writes into its own scratch region, so a scatter can never
 * overwrite records that have not been consumed yet, whatever the key
 * distribution or the value of n. All of `a` is copied out by Pass 1 before
 * Pass 4 starts overwriting it.
 *
 * Scratch memory: about 3*n + 3*(largest B3 bucket) bytes plus ~40 KiB of
 * padding, obtained with a single malloc. If that allocation fails the
 * function falls back to std::sort.
 *
 * Records are packed/unpacked by copying raw bytes, so a little-endian target
 * is assumed (byte 0 of a key is its least significant byte).
 *
 * @param a  keys to sort; may be null only when n <= 0
 * @param n  number of keys; values below 2 leave `a` untouched
 */
void sort(uint* a, int n) {
    static_assert(sizeof(uint) == 4, "sort() partitions on the four bytes of a 32-bit key");
    if (a == nullptr || n < 2) return;  // nothing to reorder

    constexpr size_t kBuckets = 256;
    // Gaps left after every bucket of the respective scratch region. Pass 1
    // needs at least one byte (its 4-byte store spills one byte past a 3-byte
    // record); the remainder keeps neighbouring buckets apart.
    constexpr size_t kPadB3 = 64;
    constexpr size_t kPadB2 = 64;
    constexpr size_t kPadB1 = 32;
    // How many records ahead of the read cursor the scatter loops prefetch.
    constexpr size_t kPrefetchAhead = 64;

    const size_t total = static_cast<size_t>(n);

    // 1.1 Histogram of the most significant byte.
    uint32_t cnt_global[kBuckets] = {};
    for (size_t i = 0; i < total; ++i) {
        ++cnt_global[a[i] >> 24];
    }

    // 1.2 Byte offset of every B3 bucket inside b, plus the largest bucket,
    //     which bounds the scratch needed by Pass 2 and Pass 3.
    size_t ptr_global[kBuckets];
    size_t bytes_b = 0;
    size_t max_count = 0;
    for (size_t k = 0; k < kBuckets; ++k) {
        ptr_global[k] = bytes_b;
        bytes_b += static_cast<size_t>(cnt_global[k]) * 3 + kPadB3;
        max_count = std::max<size_t>(max_count, cnt_global[k]);
    }

    const size_t bytes_tmp2 = max_count * 2 + kBuckets * kPadB2;
    const size_t bytes_tmp1 = max_count + kBuckets * kPadB1;

    uint8_t* const b = static_cast<uint8_t*>(std::malloc(bytes_b + bytes_tmp2 + bytes_tmp1));
    if (b == nullptr) {
        // No scratch memory available: still honour the contract.
        std::sort(a, a + total);
        return;
    }
    uint8_t* const tmp2 = b + bytes_b;        // Pass 2 output, 2 B per record
    uint8_t* const tmp1 = tmp2 + bytes_tmp2;  // Pass 3 output, 1 B per record

    // ---------------------------------------------------------
    // Pass 1: global MSD partition by B3.
    // Read a (4 B) -> write b (3 B: [B0, B1, B2]).
    // ---------------------------------------------------------
    {
        const uint* __restrict__ src = a;
        uint8_t* __restrict__ dst = b;
        // Prefetch addresses are formed as integers: they may lie beyond the
        // array, which is harmless for a hint but not for pointer arithmetic.
        const uintptr_t src_addr = reinterpret_cast<uintptr_t>(a);

        size_t p[kBuckets];
        std::memcpy(p, ptr_global, sizeof(p));

        size_t i = 0;
        for (; i + 16 <= total; i += 16) {
            __builtin_prefetch(
                reinterpret_cast<const void*>(src_addr + (i + kPrefetchAhead) * sizeof(uint)), 0, 0);

            #pragma GCC unroll 16
            for (size_t j = 0; j < 16; ++j) {
                const uint32_t val = src[i + j];
                const size_t k = val >> 24;
                // One 4-byte store; the B3 byte that spills over is either
                // overwritten by the next record or lands in the bucket gap.
                std::memcpy(dst + p[k], &val, sizeof(val));
                p[k] += 3;
            }
        }
        // Remaining keys when n is not a multiple of 16.
        for (; i < total; ++i) {
            const uint32_t val = src[i];
            const size_t k = val >> 24;
            std::memcpy(dst + p[k], &val, sizeof(val));
            p[k] += 3;
        }
    }

    // ---------------------------------------------------------
    // Walk the B3 buckets in ascending order; `out` is the running write
    // position of the final, fully sorted keys.
    // ---------------------------------------------------------
    uint* out = a;

    for (size_t i_b3 = 0; i_b3 < kBuckets; ++i_b3) {
        const size_t count_b3 = cnt_global[i_b3];
        if (count_b3 == 0) continue;

        const uint8_t* const seg_in = b + ptr_global[i_b3];

        // -----------------------------------------------------
        // Pass 2: local MSD partition by B2.
        // Read b (3 B: [B0, B1, B2]) -> write tmp2 (2 B: [B0, B1]).
        // -----------------------------------------------------
        uint32_t cnt_b2[kBuckets] = {};
        for (size_t i = 0; i < count_b3; ++i) {
            ++cnt_b2[seg_in[i * 3 + 2]];
        }

        size_t start_b2[kBuckets];  // bucket starts, kept for the B2 walk below
        size_t cur_b2[kBuckets];    // write cursors advanced by the scatter
        {
            size_t off = 0;
            for (size_t k = 0; k < kBuckets; ++k) {
                start_b2[k] = off;
                cur_b2[k] = off;
                off += static_cast<size_t>(cnt_b2[k]) * 2 + kPadB2;
            }
        }

        {
            const uintptr_t seg_addr = reinterpret_cast<uintptr_t>(seg_in);

            size_t i = 0;
            for (; i + 16 <= count_b3; i += 16) {
                __builtin_prefetch(
                    reinterpret_cast<const void*>(seg_addr + (i + kPrefetchAhead) * 3), 0, 0);

                #pragma GCC unroll 16
                for (size_t j = 0; j < 16; ++j) {
                    // 4-byte load of a 3-byte record; the extra byte is the
                    // spill written by Pass 1 and is ignored.
                    uint32_t val;
                    std::memcpy(&val, seg_in + (i + j) * 3, sizeof(val));
                    const size_t key = (val >> 16) & 0xFF;  // B2
                    const uint16_t low = static_cast<uint16_t>(val);  // [B0, B1]
                    std::memcpy(tmp2 + cur_b2[key], &low, sizeof(low));
                    cur_b2[key] += 2;
                }
            }
            for (; i < count_b3; ++i) {
                uint32_t val;
                std::memcpy(&val, seg_in + i * 3, sizeof(val));
                const size_t key = (val >> 16) & 0xFF;
                const uint16_t low = static_cast<uint16_t>(val);
                std::memcpy(tmp2 + cur_b2[key], &low, sizeof(low));
                cur_b2[key] += 2;
            }
        }

        // -----------------------------------------------------
        // Walk the B2 buckets of this B3 bucket in ascending order.
        // -----------------------------------------------------
        for (size_t i_b2 = 0; i_b2 < kBuckets; ++i_b2) {
            const size_t count_b2 = cnt_b2[i_b2];
            if (count_b2 == 0) continue;

            const uint8_t* const src_b2 = tmp2 + start_b2[i_b2];

            // -------------------------------------------------
            // Pass 3: partition by B1.
            // Read tmp2 (2 B: [B0, B1]) -> write tmp1 (1 B: [B0]).
            // -------------------------------------------------
            uint32_t cnt_b1[kBuckets] = {};
            for (size_t i = 0; i < count_b2; ++i) {
                ++cnt_b1[src_b2[i * 2 + 1]];  // high byte of the pair is B1
            }

            // Offsets stay below count_b2 + 256 * 32, which fits in 32 bits.
            uint32_t start_b1[kBuckets];
            uint32_t cur_b1[kBuckets];
            {
                uint32_t off = 0;
                for (size_t k = 0; k < kBuckets; ++k) {
                    start_b1[k] = off;
                    cur_b1[k] = off;
                    off += cnt_b1[k] + static_cast<uint32_t>(kPadB1);
                }
            }

            {
                size_t i = 0;
                for (; i + 16 <= count_b2; i += 16) {
                    #pragma GCC unroll 16
                    for (size_t j = 0; j < 16; ++j) {
                        uint16_t val;
                        std::memcpy(&val, src_b2 + (i + j) * 2, sizeof(val));
                        const size_t key = val >> 8;  // B1
                        tmp1[cur_b1[key]] = static_cast<uint8_t>(val);  // B0
                        ++cur_b1[key];
                    }
                }
                for (; i < count_b2; ++i) {
                    uint16_t val;
                    std::memcpy(&val, src_b2 + i * 2, sizeof(val));
                    const size_t key = val >> 8;
                    tmp1[cur_b1[key]] = static_cast<uint8_t>(val);
                    ++cur_b1[key];
                }
            }

            // -------------------------------------------------
            // Pass 4: finalize. For every B1 bucket sort its B0 bytes and
            // emit the reassembled keys.
            // Read tmp1 (1 B: [B0]) -> write a (4 B).
            // -------------------------------------------------
            // Shifts are done on unsigned values: 255 << 24 does not fit in int.
            const uint32_t high_bits = (static_cast<uint32_t>(i_b3) << 24) |
                                       (static_cast<uint32_t>(i_b2) << 16);

            for (size_t i_b1 = 0; i_b1 < kBuckets; ++i_b1) {
                const uint32_t c = cnt_b1[i_b1];
                if (c == 0) continue;

                uint8_t* const b0 = tmp1 + start_b1[i_b1];
                const uint32_t val_base = high_bits | (static_cast<uint32_t>(i_b1) << 8);

                if (c < 32) {
                    // Small bucket: a comparison sort on raw bytes is cheapest.
                    std::sort(b0, b0 + c);
                    for (uint32_t k = 0; k < c; ++k) {
                        *out++ = val_base | b0[k];
                    }
                } else {
                    // Large bucket: counting sort over the 256 possible low bytes.
                    uint32_t cnt_b0[kBuckets] = {};
                    for (uint32_t k = 0; k < c; ++k) {
                        ++cnt_b0[b0[k]];
                    }
                    for (size_t v = 0; v < kBuckets; ++v) {
                        const uint32_t val = val_base | static_cast<uint32_t>(v);
                        for (uint32_t rep = cnt_b0[v]; rep != 0; --rep) {
                            *out++ = val;
                        }
                    }
                }
            }
        }
    }

    std::free(b);
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 2.271 s | 667 MB + 648 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2026-01-22 08:03:30 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠