提交记录 28801


用户 题目 状态 得分 用时 内存 语言 代码长度
platelet 1001. 测测你的排序 Wrong Answer 0 840.106 ms 683656 KB C++17 10.33 KB
提交时间 评测时间
2026-01-18 19:32:45 2026-01-18 19:32:52
#pragma GCC target("avx2,bmi,bmi2,popcnt,lzcnt")
#include <bits/stdc++.h>
#include <immintrin.h>

using namespace std;

// 预取距离
const int PREFETCH_DIST = 64; 

// 写入 3 字节 (b 数组)
inline void store3(uint8_t* __restrict__ p, uint32_t val) {
    *(uint32_t*)p = val;
}

// 写入 2 字节 (a 数组临时)
inline void store2(uint8_t* __restrict__ p, uint16_t val) {
    *(uint16_t*)p = val;
}

// 写入 1 字节 (b 数组临时)
inline void store1(uint8_t* __restrict__ p, uint8_t val) {
    *p = val;
}

void sort(uint* a, int n) {
    // ---------------------------------------------------------
    // 内存规划
    // a: 输入数组 (4N 字节)。Pass 1 后数据移至 b,a 可作为临时空间。
    // b: 临时数组。
    // Pass 1 Output (b): 存储 [B0, B1, B2]。大小 3N。
    // Pass 2 Output (a): 存储 [B0, B1]。大小 2N。存储在 a 的前半部分或分段复用。
    // Pass 3 Output (b): 存储 [B0]。大小 1N。复用 b 的空间 (Pass 2 读取后释放)。
    // ---------------------------------------------------------

    // 申请 b 数组。
    // 大小 = 3N (Pass 1) + Padding。
    // Pass 3 也会复用这块内存,但只需要 1N,所以 3N 足够。
    // 需要足够的 Padding 防止 Scatter 溢出。
    uint8_t* b = (uint8_t*)malloc((size_t)n * 3 + 4096 * 256);

    // 全局计数器 (B3)
    uint cnt_global[256];
    memset(cnt_global, 0, sizeof(cnt_global));
    
    // 1.1 统计 B3
    for (int i = 0; i < n; i++) {
        cnt_global[a[i] >> 24]++;
    }
    
    // 1.2 计算 Pass 1 (B3) 写入指针
    // 每个 Bucket 预留一些 Padding 用于 SIMD/Unroll 溢出写
    uint ptr_global[256];
    size_t offset_b3 = 0;
    for (int i = 0; i < 256; i++) {
        ptr_global[i] = (uint)offset_b3;
        // 3 bytes per element + 64 bytes padding
        offset_b3 += cnt_global[i] * 3 + 64; 
    }
    
    // 1.3 Pass 1: Global Scatter (Partition by B3) -> Write to b (3 Bytes)
    {
        uint* __restrict__ src = a;
        uint8_t* __restrict__ dst = b;
        uint p[256];
        memcpy(p, ptr_global, sizeof(p));

        for (int i = 0; i < n; i += 16) {
            _mm_prefetch((const char*)&src[i + PREFETCH_DIST], _MM_HINT_NTA);
            
            #pragma GCC unroll 16
            for (int j = 0; j < 16; j++) {
                uint val = src[i + j];
                uint8_t k = val >> 24; 
                store3(dst + p[k], val); 
                p[k] += 3;
            }
        }
    }

    // ---------------------------------------------------------
    // 分段处理:遍历 B3 的每一个 Bucket
    // ---------------------------------------------------------
    
    // a 数组视作字节数组用于 Pass 2 临时存储
    uint8_t* a_u8 = (uint8_t*)a;
    
    // 用于 Pass 4 (Final Write) 的全局偏移 (以 uint 为单位)
    int a_final_offset = 0;

    // 辅助直方图和指针数组
    uint cnt_local[256];
    uint ptr_local[256];

    for (int i_b3 = 0; i_b3 < 256; i_b3++) {
        int count_b3 = cnt_global[i_b3];
        if (count_b3 == 0) continue;

        // 当前 B3 Bucket 在 b 中的起始位置
        uint8_t* seg_b_in = b + ptr_global[i_b3];
        
        // Pass 2 Output (Temp in a): 存储 [B0, B1]
        // 我们需要一块临时空间。由于我们是顺序处理 B3 Bucket,
        // 且 Pass 4 的写入 (4N) 总是快于 Pass 2 的需求 (2N),
        // 我们可以安全地利用 a 数组当前处理进度之后的一段空间,
        // 或者简单地利用 a_final_offset 对应的字节位置?
        // 风险:Pass 4 写 full int 会覆盖 Pass 2 写 short。
        // 解决方案:利用 a 的偏移量。
        // Pass 4 写入起始位置: a + a_final_offset (字节位置 a_final_offset * 4)
        // Pass 2 写入起始位置: 我们必须保证 Pass 2 的数据在被 Pass 3 读取完之前不被 Pass 4 覆盖。
        // 对于当前 Bucket,Pass 4 写入 4*count_b3 字节。
        // Pass 2 写入 2*count_b3 字节。
        // Pass 3 读取 Pass 2 数据,写入 Pass 3 数据 (到 b)。
        // 只有 Pass 3 完成后,Pass 4 才开始写。
        // 所以,对于 *同一个* Bucket,我们可以复用 a 的同一块内存区域,
        // 只要 Pass 2 的数据 (2B) 放在 Pass 4 (4B) 将要写入的区域内即可。
        // 因为 Pass 2 写完 -> Pass 3 读完 -> Pass 4 才写。无冲突。
        
        uint8_t* seg_a_temp = a_u8 + (size_t)a_final_offset * 4;

        // =====================================================
        // Step 2: Pass 2 (Partition by B2)
        // Read b (3B: B0,B1,B2) -> Write a (2B: B0,B1)
        // =====================================================
        
        // 2.1 统计 B2
        memset(cnt_local, 0, sizeof(cnt_local));
        for (int i = 0; i < count_b3; i++) {
            // b 存储的是 [B0, B1, B2]
            // B2 是第 3 个字节 (offset 2)
            uint8_t val_b2 = *(seg_b_in + i * 3 + 2);
            cnt_local[val_b2]++;
        }

        // 2.2 计算 Offset (for a temp)
        uint32_t tmp_offset = 0;
        for (int k = 0; k < 256; k++) {
            ptr_local[k] = tmp_offset;
            tmp_offset += cnt_local[k] * 2 + 64; // Padding
        }
        
        // 保存 B2 的计数,供后续步骤使用
        uint cnt_b2_saved[256];
        memcpy(cnt_b2_saved, cnt_local, sizeof(cnt_local));
        // 保存 B2 的指针基址 (相对 seg_a_temp),供 Step 3 读取
        uint ptr_b2_start[256];
        memcpy(ptr_b2_start, ptr_local, sizeof(ptr_local));

        // 2.3 Scatter B2
        {
            uint p[256];
            memcpy(p, ptr_local, sizeof(p));
            uint8_t* src = seg_b_in;
            uint8_t* dst = seg_a_temp;

            for (int i = 0; i < count_b3; i += 16) {
                // Prefetch b
                _mm_prefetch((const char*)&src[(i + PREFETCH_DIST) * 3], _MM_HINT_NTA);

                #pragma GCC unroll 16
                for (int j = 0; j < 16; j++) {
                    if (i + j >= count_b3) break;
                    // Read [B0, B1, B2]
                    uint32_t val = *(uint32_t*)(src + (i + j) * 3); // Reads 4 bytes, safe due to padding
                    uint8_t key = (val >> 16) & 0xFF; // B2
                    
                    // Write [B0, B1]
                    store2(dst + p[key], val); // val's low 16 bits are B0, B1
                    p[key] += 2;
                }
            }
        }
        
        // Pass 2 完成,b 的当前段数据已无用。
        // Pass 3 将输出到 b 的当前段 (复用内存)。
        uint8_t* seg_b_temp = seg_b_in;

        // =====================================================
        // Loop over B2 Buckets
        // =====================================================
        for (int i_b2 = 0; i_b2 < 256; i_b2++) {
            int count_b2 = cnt_b2_saved[i_b2];
            if (count_b2 == 0) continue;

            uint8_t* seg_a_in = seg_a_temp + ptr_b2_start[i_b2];
            
            // =================================================
            // Step 3: Pass 3 (Partition by B1)
            // Read a (2B: B0,B1) -> Write b (1B: B0)
            // =================================================
            
            // 3.1 统计 B1
            // 此时只需用一个小数组,因为在 B2 循环内,重置开销 256 次 * 256 * 256 太大?
            // 不,总循环次数是 B3(256) * B2(256)。
            // 统计 B1 需要清空 cnt。
            // 优化:只有 256 个 bin,memset很快。
            memset(cnt_local, 0, sizeof(cnt_local));
            for (int i = 0; i < count_b2; i++) {
                // a 存储 [B0, B1]
                uint16_t val = *(uint16_t*)(seg_a_in + i * 2);
                cnt_local[val >> 8]++; // High byte is B1
            }

            // 3.2 计算 Offset (for b temp)
            // 复用 b 的空间。由于 b 原本分配给 B3 bucket 的空间是 3*count_b3。
            // 现在我们只处理其中一个 B2 bucket,且只写 1 字节/元素。
            // 我们可以直接从 seg_b_temp 开始写吗?
            // 我们需要为每个 B2 bucket 分配 b 的空间吗?
            // 不需要,我们可以复用整个 seg_b_temp 区域,只要不同 B2 bucket 之间不冲突。
            // 但是这里是一个串行过程。我们在处理 B2 bucket i 时,
            // 我们可以使用 seg_b_temp 的开头作为临时空间。
            // 只要大小够 (1 * count_b2 + padding),绝对够。
            
            tmp_offset = 0;
            for (int k = 0; k < 256; k++) {
                ptr_local[k] = tmp_offset;
                tmp_offset += cnt_local[k] + 32; // 1 byte element + padding
            }
            
            uint cnt_b1_saved[256];
            memcpy(cnt_b1_saved, cnt_local, sizeof(cnt_local));
            uint ptr_b1_start[256];
            memcpy(ptr_b1_start, ptr_local, sizeof(ptr_local));

            // 3.3 Scatter B1
            {
                uint p[256];
                memcpy(p, ptr_local, sizeof(p));
                uint8_t* src = seg_a_in;
                uint8_t* dst = seg_b_temp;

                for (int i = 0; i < count_b2; i += 16) {
                    _mm_prefetch((const char*)&src[(i + PREFETCH_DIST) * 2], _MM_HINT_NTA);
                    
                    #pragma GCC unroll 16
                    for (int j = 0; j < 16; j++) {
                        if (i + j >= count_b2) break;
                        uint16_t val = *(uint16_t*)(src + (i + j) * 2);
                        uint8_t key = val >> 8; // B1
                        store1(dst + p[key], val & 0xFF); // Store B0
                        p[key]++;
                    }
                }
            }

            // =================================================
            // Step 4: Pass 4 (Generate Final Output)
            // Loop B1 -> Sort B0 -> Write Full Int
            // =================================================
            
            // 当前的高位部分
            uint32_t high_parts = (i_b3 << 24) | (i_b2 << 16); // | (i_b1 << 8) inside loop

            for (int i_b1 = 0; i_b1 < 256; i_b1++) {
                int count_b1 = cnt_b1_saved[i_b1];
                if (count_b1 == 0) continue;

                uint8_t* b0_src = seg_b_temp + ptr_b1_start[i_b1];
                uint* dst_final = a + a_final_offset;
                
                uint32_t current_high = high_parts | (i_b1 << 8);

                // 我们需要对 b0_src 中的 count_b1 个 B0 进行排序
                // 如果数量少,直接排序
                // 如果数量多,计数排序
                if (count_b1 < 64) {
                    // Small sort
                    // Copy to stack buffer to avoid modifying source if strict, but here src is temp.
                    // Sort in place?
                    // b0_src point to b array.
                    // bubble sort or std::sort
                    // std::sort on uint8 is fast
                    std::sort(b0_src, b0_src + count_b1);
                    
                    for(int k=0; k<count_b1; k++) {
                        dst_final[k] = current_high | b0_src[k];
                    }
                } else {
                    // Counting Sort for B0
                    uint cnt_b0[256] = {0};
                    for(int k=0; k<count_b1; k++) cnt_b0[b0_src[k]]++;
                    
                    int k = 0;
                    for(int v=0; v<256; v++) {
                        int c = cnt_b0[v];
                        if (c == 0) continue;
                        uint32_t val = current_high | v;
                        while(c--) {
                            dst_final[k++] = val;
                        }
                    }
                }

                a_final_offset += count_b1;
            }
        }
    }
    
    free(b);
}

CompilationN/AN/ACompile OKScore: N/A

Testcase #1840.106 ms667 MB + 648 KBWrong AnswerScore: 0


Judge Duck Online | 评测鸭在线
Server Time: 2026-01-22 08:01:05 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠