提交记录 28799


用户 题目 状态 得分 用时 内存 语言 代码长度
platelet 1001. 测测你的排序 Accepted 100 660.195 ms 683644 KB C++17 9.77 KB
提交时间 评测时间
2026-01-18 19:18:54 2026-01-18 19:19:00
#pragma GCC target("avx2,bmi,bmi2,popcnt,lzcnt")
#include <bits/stdc++.h>
#include <immintrin.h>

using namespace std;

// Fixed input size: 10^8 keys.
constexpr int n = 100'000'000;
// Prefetch look-ahead, counted in elements (not bytes): 64 elements is 256 B ahead in
// Pass 1 (4-byte records), 192 B in Pass 2 (3-byte records), 128 B in Passes 3/4 (2-byte records).
constexpr int PREFETCH_DIST = 64;

// Helper: store the low 3 bytes of val at p by writing all 4 bytes of val.
// The 4th byte is junk: callers write records back to back (p, p+3, p+6, ...) so it is
// overwritten by the next record, and the buffer must have at least 1 byte of padding
// after the last record.
// On a little-endian target, val = [B0, B1, B2, X] leaves B0, B1, B2 at p[0..2].
// memcpy keeps the unaligned store well-defined. The old `*(uint32_t*)p = val` was an
// unaligned, type-punned store (UB). GCC still lowers this to a single 32-bit mov.
inline void store3(uint8_t* __restrict__ p, uint32_t val) {
    memcpy(p, &val, sizeof(val));
}

// Helper: store the 2 bytes of val at p (p need not be 2-byte aligned).
// memcpy keeps the store well-defined. The old `*(uint16_t*)p = val` was an unaligned,
// type-punned store (UB, e.g. when p points into a uint array). GCC still lowers this
// to a single 16-bit mov.
inline void store2(uint8_t* __restrict__ p, uint16_t val) {
    memcpy(p, &val, sizeof(val));
}

// Reads 4 bytes from a possibly unaligned address. memcpy keeps the access well-defined
// (no alignment or strict-aliasing UB), and GCC lowers it to a single unaligned load.
static inline uint32_t load4(const uint8_t* p) {
    uint32_t v;
    memcpy(&v, p, sizeof(v));
    return v;
}

// Reads 2 bytes from a possibly unaligned address (see load4).
static inline uint16_t load2(const uint8_t* p) {
    uint16_t v;
    memcpy(&v, p, sizeof(v));
    return v;
}

// Sorts a[0 .. __n) in ascending order, in place.
//
// Byte-wise radix sort on 32-bit keys (B0 = lowest byte, B3 = highest).
//   Pass 1: MSD scatter on B3, a (4 B/key) -> b (3 B/key: B0,B1,B2), one segment per B3 value.
//   Then, for each B3 segment independently:
//     Step 2: histograms of B0, B1 and B2 for the segment.
//     Pass 2: LSD on B0, b (B0,B1,B2) -> scratch in a (2 B/key: B1,B2).
//     Pass 3: LSD on B1, scratch (B1,B2) -> b (2 B/key: B0,B2). B0 comes from the B0 bucket read.
//     Pass 4: LSD on B2, b (B0,B2) -> final a (4 B/key). B1 comes from the B1 bucket read,
//             and B3 comes from the segment index.
// Dropping the byte implied by the current bucket keeps the intermediate records small.
// Assumes a little-endian target: store3/load4 rely on B0..B2 being the first 3 bytes in memory.
// Needs one heap buffer of about 3 * __n bytes. If it cannot be allocated, falls back to std::sort.
void sort(uint* a, int __n) {
    if (__n <= 1) return;  // nothing to sort (also rejects negative sizes)
    // Use the caller's size; the previous version ignored __n and always used the global n.
    const size_t len = static_cast<size_t>(__n);

    // ---------------------------------------------------------
    // Pass 1: global MSD partition by B3
    // Read a (4 bytes) -> write b (3 bytes: [B0, B1, B2])
    // ---------------------------------------------------------
    uint cnt_global[256] = {};
    for (size_t i = 0; i < len; i++) {
        cnt_global[a[i] >> 24]++;
    }

    // Byte offset of every B3 segment inside b. Each segment gets 4 bytes of tail padding so
    // that store3 (a 4-byte write) and the 4-byte load4 reads below stay inside their segment.
    // size_t because 3 * len can exceed 32 bits.
    size_t ptr_global[256];
    size_t b_size = 0;
    for (int k = 0; k < 256; k++) {
        ptr_global[k] = b_size;
        b_size += static_cast<size_t>(cnt_global[k]) * 3 + 4;
    }

    // b_size == 3 * len + 256 * 4; the extra 4096 bytes are a spare safety margin.
    uint8_t* b = static_cast<uint8_t*>(malloc(b_size + 4096));
    if (b == nullptr) {
        std::sort(a, a + len);  // no memory for the radix buffer: still honor the contract
        return;
    }

    {
        const uint* __restrict__ src = a;
        uint8_t* __restrict__ dst = b;
        size_t p[256];
        memcpy(p, ptr_global, sizeof(p));

        size_t i = 0;
        for (; i + 16 <= len; i += 16) {
            _mm_prefetch((const char*)&src[i + PREFETCH_DIST], _MM_HINT_NTA);

            #pragma GCC unroll 16
            for (int j = 0; j < 16; j++) {
                const uint val = src[i + j];
                const uint8_t k = val >> 24;
                store3(dst + p[k], val);
                p[k] += 3;
            }
        }
        // Tail: the previous version had no tail loop and read past the end when len % 16 != 0.
        for (; i < len; i++) {
            const uint val = src[i];
            const uint8_t k = val >> 24;
            store3(dst + p[k], val);
            p[k] += 3;
        }
    }

    // ---------------------------------------------------------
    // Per-segment processing: one iteration per B3 bucket
    // ---------------------------------------------------------
    uint8_t* const a_u8 = reinterpret_cast<uint8_t*>(a);

    // Per-segment histograms and bucket start offsets (reused across segments).
    uint cnt0[256];
    uint cnt1[256];
    uint cnt2[256];
    uint ptr0[256];  // Pass 2 output offsets, in bytes, inside the scratch area of a
    uint ptr1[256];  // Pass 3 output offsets, in bytes, inside this segment of b
    uint ptr2[256];  // Pass 4 output offsets, in elements, inside the final segment of a

    size_t a_offset_start = 0;  // element index in a where the current segment's output starts

    for (int i_b3 = 0; i_b3 < 256; i_b3++) {
        const size_t count = cnt_global[i_b3];
        if (count == 0) continue;

        uint8_t* const seg_b_in = b + ptr_global[i_b3];  // Pass 1 output -> Pass 2 input

        // Pass 2 scratch lives at the start of this segment's final destination in a. It takes
        // exactly 2 * count bytes (<= 4 * count), so it can neither clobber earlier segments'
        // finished output nor run past the end of a.
        uint8_t* const seg_a_temp = a_u8 + a_offset_start * 4;

        // -----------------------------------------------------
        // Step 2: local histograms of B0, B1, B2
        // -----------------------------------------------------
        memset(cnt0, 0, sizeof(cnt0));
        memset(cnt1, 0, sizeof(cnt1));
        memset(cnt2, 0, sizeof(cnt2));

        #pragma GCC unroll 4
        for (size_t i = 0; i < count; i++) {
            const uint32_t v = load4(seg_b_in + i * 3);  // 4th byte read comes from padding or the next record
            cnt0[v & 0xFF]++;
            cnt1[(v >> 8) & 0xFF]++;
            cnt2[(v >> 16) & 0xFF]++;
        }

        // Dense exclusive prefix sums (no padding). store2 and the final uint store write
        // exactly one record each, so per-bucket padding is never needed.
        // The previous "+4 per bucket" padding made the Pass 2/3 outputs up to 1020 bytes
        // larger than 2 * count. That overwrote the next B3 segment's input in b (and could
        // overrun a) whenever a segment held fewer than ~1000 keys.
        uint t0 = 0, t1 = 0, t2 = 0;
        for (int k = 0; k < 256; k++) {
            ptr0[k] = t0; t0 += cnt0[k] * 2;
            ptr1[k] = t1; t1 += cnt1[k] * 2;
            ptr2[k] = t2; t2 += cnt2[k];
        }

        // -----------------------------------------------------
        // Pass 2: LSD step 1 (key B0)
        // Read b (3B: B0,B1,B2) -> write a scratch (2B: B1,B2)
        // -----------------------------------------------------
        {
            uint p[256];
            memcpy(p, ptr0, sizeof(p));
            const uint8_t* src = seg_b_in;
            uint8_t* dst = seg_a_temp;

            size_t k = 0;
            for (; k + 16 <= count; k += 16) {
                _mm_prefetch((const char*)(src + (k + PREFETCH_DIST) * 3), _MM_HINT_T0);
                #pragma GCC unroll 16
                for (int j = 0; j < 16; j++) {
                    const uint32_t val = load4(src + (k + j) * 3);
                    const uint8_t key = val & 0xFF;  // B0
                    store2(dst + p[key], val >> 8);  // [B1, B2]
                    p[key] += 2;
                }
            }
            for (; k < count; k++) {
                const uint32_t val = load4(src + k * 3);
                const uint8_t key = val & 0xFF;
                store2(dst + p[key], val >> 8);
                p[key] += 2;
            }
        }

        // -----------------------------------------------------
        // Pass 3: LSD step 2 (key B1)
        // Read a scratch (2B: B1,B2) -> write b (2B: B0,B2)
        // B0 is recovered from the B0 bucket being read.
        // -----------------------------------------------------
        {
            // This segment's Pass 1 data is no longer needed. 2 * count <= 3 * count bytes,
            // so the output stays inside the segment.
            uint8_t* const dst_base = seg_b_in;

            uint p[256];
            memcpy(p, ptr1, sizeof(p));

            for (int b0 = 0; b0 < 256; b0++) {
                const size_t c = cnt0[b0];
                if (c == 0) continue;

                const uint8_t* src = seg_a_temp + ptr0[b0];
                const uint16_t val_b0 = b0;

                size_t k = 0;
                for (; k + 21 <= c; k += 21) {
                    _mm_prefetch((const char*)(src + (k + PREFETCH_DIST) * 2), _MM_HINT_T0);
                    #pragma GCC unroll 21
                    for (int j = 0; j < 21; j++) {
                        const uint16_t val = load2(src + (k + j) * 2);  // [B1, B2]
                        const uint8_t key = val & 0xFF;                 // B1
                        // [B0, B2]: B0 in the low byte, keep B2 in the high byte.
                        store2(dst_base + p[key], val_b0 | (val & 0xFF00));
                        p[key] += 2;
                    }
                }
                for (; k < c; k++) {
                    const uint16_t val = load2(src + k * 2);
                    const uint8_t key = val & 0xFF;
                    store2(dst_base + p[key], val_b0 | (val & 0xFF00));
                    p[key] += 2;
                }
            }
        }

        // -----------------------------------------------------
        // Pass 4: LSD step 3 (key B2) and reassembly
        // Read b (2B: B0,B2) -> write final a (4B)
        // B1 is recovered from the B1 bucket being read; B3 from i_b3.
        // -----------------------------------------------------
        {
            uint* const dst_base = a + a_offset_start;

            uint p[256];
            memcpy(p, ptr2, sizeof(p));

            const uint32_t b3_bits = static_cast<uint32_t>(i_b3) << 24;

            for (int b1 = 0; b1 < 256; b1++) {
                const size_t c = cnt1[b1];
                if (c == 0) continue;

                const uint8_t* src = seg_b_in + ptr1[b1];
                const uint32_t common_bits = b3_bits | (static_cast<uint32_t>(b1) << 8);

                size_t k = 0;
                for (; k + 32 <= c; k += 32) {
                    _mm_prefetch((const char*)(src + (k + PREFETCH_DIST) * 2), _MM_HINT_T0);
                    #pragma GCC unroll 32
                    for (int j = 0; j < 32; j++) {
                        const uint32_t val = load2(src + (k + j) * 2);  // [B0, B2]
                        const uint8_t key = val >> 8;                   // B2
                        // Move B0 to bits 0-7 and B2 to bits 16-23. This is what
                        // _pdep_u32(val, 0x00FF00FF) did, but without needing BMI2
                        // (PDEP is microcoded and slow on AMD before Zen 3).
                        dst_base[p[key]] = common_bits | (val & 0xFF) | ((val & 0xFF00) << 8);
                        p[key]++;
                    }
                }
                for (; k < c; k++) {
                    const uint32_t val = load2(src + k * 2);
                    const uint8_t key = val >> 8;
                    dst_base[p[key]] = common_bits | (val & 0xFF) | ((val & 0xFF00) << 8);
                    p[key]++;
                }
            }
        }

        a_offset_start += count;
    }

    free(b);  // the previous version leaked this ~3 * len byte buffer
}

Compilation | N/A | N/A | Compile OK | Score: N/A

Testcase #1 | 660.195 ms | 667 MB + 636 KB | Accepted | Score: 100


Judge Duck Online | 评测鸭在线
Server Time: 2026-01-22 08:03:30 | Loaded in 0 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠