提交记录 28825


用户 题目 状态 得分 用时 内存 语言 代码长度
platelet 1001. 测测你的排序 Runtime Error 0 5.91 us 12 KB C++17 19.20 KB
提交时间 评测时间
2026-01-19 00:14:06 2026-01-19 00:14:12
#pragma GCC target("avx2,bmi,bmi2,popcnt,lzcnt")
#include <bits/stdc++.h>
#include <immintrin.h>

template <int STRIDE>
std::array<int, 256> get_population_upper_bounds(uint8_t* A, int N, int budget, int sample_size) {
    std::array<int, 256> results;
    results.fill(0);

    size_t n = (size_t)sample_size;
    
    // 1. Calculate Cache Line Alignment Info
    uintptr_t start_addr = (uintptr_t)A;
    // Align up to next 64-byte boundary
    uintptr_t aligned_start = (start_addr + 63) & ~63ULL;
    // Align down end address
    uintptr_t end_addr_exclusive = start_addr + (size_t)N * STRIDE;
    uintptr_t aligned_end = end_addr_exclusive & ~63ULL;

    if (aligned_end <= aligned_start) {
        // Not enough data for aligned sampling, fallback to full scan
         for (int i = 0; i < N; ++i) results[A[(size_t)i * STRIDE]]++;
         return results;
    }

    size_t num_lines = (aligned_end - aligned_start) / 64;
    
    // Safety check: if no cache lines available, fallback to full scan
    if (num_lines == 0) {
        for (int i = 0; i < N; ++i) results[A[(size_t)i * STRIDE]]++;
        return results;
    }
    
    // Determine the offset pattern for the FIRST aligned block
    // We need (aligned_start + off) % STRIDE == start_addr % STRIDE
    size_t diff = aligned_start - start_addr;
    int base_offset = (STRIDE - (diff % STRIDE)) % STRIDE;

    // 2. Adjust Sample Size to be in terms of cache lines
    // Average items per line
    int items_per_line_approx = 64 / STRIDE;
    size_t lines_to_sample = (n + items_per_line_approx - 1) / items_per_line_approx;
    
    // Cap at available lines
    if (lines_to_sample > num_lines) lines_to_sample = num_lines;
    
    // Recalculate actual n for statistics
    // (This is an approximation if stride=3 because different lines have different counts, 
    // but for large N it converges)
    // For Stride=4, count is always 16.
    // For Stride=3, count is 21 or 22 (avg 21.33).
    // optimizing: just counting actually sampled items is better, 
    // but user code expects 'n' to be passed to math formulas.
    // We will count exact sampled items in the loop.
    
    // 3. Sparse Sampling of Cache Lines
    std::array<int, 256> sample_counts;
    sample_counts.fill(0);
    size_t actual_sampled_count = 0;
    
    static thread_local std::mt19937 gen;
    const uint32_t mod_blocks = (uint32_t)num_lines;
    const uint64_t mu = ((unsigned __int128)1 << 64) / mod_blocks;

    for (size_t i = 0; i < lines_to_sample; ++i) {
        // Random Block Index
        uint32_t x = gen();
        uint64_t q = ((unsigned __int128)x * mu) >> 64;
        uint32_t blk_idx = x - q * mod_blocks;
        if (blk_idx >= mod_blocks) blk_idx -= mod_blocks;

        uint8_t* p_line = (uint8_t*)(aligned_start + (size_t)blk_idx * 64);

        // Calculate offset for this specific block
        // Block addr changes by 64. 64 % 3 = 1. 64 % 4 = 0.
        // offset_new = (offset_old - delta_addr) % STRIDE
        // delta_addr = blk_idx * 64
        int current_offset;
        if constexpr (STRIDE == 4) {
             current_offset = base_offset;
        } else {
             // STRIDE == 3
             // shift = (blk_idx) % 3
             // off = (base - shift) % 3
             int shift = blk_idx % 3;
             current_offset = base_offset - shift;
             if (current_offset < 0) current_offset += 3;
        }

        // Fetch fixed number of items per cache line
        // Safe max index check:
        // Stride 4: offset max 3. count 16. max idx = 3 + 15*4 = 63 < 64.
        // Stride 3: offset max 2. count 21. max idx = 2 + 20*3 = 62 < 64.
        const int ITEMS = 64 / STRIDE;
        
        #pragma GCC unroll 21
        for (int k = 0; k < ITEMS; ++k) {
            sample_counts[p_line[current_offset + k * STRIDE]]++;
        }
        actual_sampled_count += ITEMS;
    }
    
    n = actual_sampled_count;

    // 2. 二分查找最优 Z 值
    // 目标:找到最大的 Z,使得 Sum(UpperBounds(Z)) <= Budget
    double low_z = 0.0;
    double high_z = 10.0;
    double best_z = 0.0;
    
    double n_double = (double)n;
    double N_double = (double)N;
    double fpc = (double)(N - n) / (double)(N - 1);
    if (fpc < 0) fpc = 0; // Safety

    // 预计算 p_hat 以加速循环
    std::array<double, 256> p_hats;
    for(int i=0; i<256; ++i) p_hats[i] = sample_counts[i] / n_double;

    for (int iter = 0; iter < 20; ++iter) {
        double mid_z = (low_z + high_z) * 0.5;
        double z2 = mid_z * mid_z;
        double div_factor = 1.0 / (1.0 + z2 / n_double);
        
        long long current_sum = 0;
        
        for (int i = 0; i < 256; ++i) {
            double p_hat = p_hats[i];
            
            // Wilson Score Interval
            double term1 = p_hat + z2 / (2.0 * n_double);
            double variance_term = (p_hat * (1.0 - p_hat) / n_double) * fpc;
            if (variance_term < 0) variance_term = 0;
            double term2 = mid_z * std::sqrt(variance_term + z2 / (4.0 * n_double * n_double));
            double p_upper = (term1 + term2) * div_factor;
            
            int limit = (int)std::ceil(N_double * p_upper);
            current_sum += limit;
        }

        if (current_sum <= budget) {
            best_z = mid_z;
            low_z = mid_z;
        } else {
            high_z = mid_z;
        }
    }

    // 3. 使用最佳 Z 生成最终结果
    double z = best_z;
    double z2 = z * z;
    double div_factor = 1.0 / (1.0 + z2 / n_double);
    
    for (int i = 0; i < 256; ++i) {
        double p_hat = p_hats[i];
        double term1 = p_hat + z2 / (2.0 * n_double);
        double variance_term = (p_hat * (1.0 - p_hat) / n_double) * fpc;
        if (variance_term < 0) variance_term = 0;
        double term2 = z * std::sqrt(variance_term + z2 / (4.0 * n_double * n_double));
        double p_upper = (term1 + term2) * div_factor;
        
        int limit = (int)std::ceil(N_double * p_upper);
        if (limit > N) limit = N;
        results[i] = limit;
    }

    return results;
}

using namespace std;

const int n = 1e8;
const int PREFETCH_DIST = 64; // 元素个数:Pass1(256B), Pass2(192B), Pass3/4(128B)

// 辅助函数:向地址 p 写入 3 字节 (利用 uint32 覆盖写,需保证 buffer 有 padding)
// Input val: [B0, B1, B2, X] (Little Endian) -> Writes B0, B1, B2
inline void store3(uint8_t* __restrict__ p, uint32_t val) {
    *(uint32_t*)p = val;
}

// 辅助函数:向地址 p 写入 2 字节
inline void store2(uint8_t* __restrict__ p, uint16_t val) {
    *(uint16_t*)p = val;
}

void foobar(uint* a, int __n) {
    // ---------------------------------------------------------
    // Pass 1: Global MSD (Partition by B3)
    // Read: a (4 bytes) -> Write: b (3 bytes: [B0, B1, B2])
    // ---------------------------------------------------------
    
    uint cnt_global[256];
    // memset(cnt_global, 0, sizeof(cnt_global)); // No longer needed beforehand
    
    // 1.1 统计 B3 (Sampling & Upper Bounds)
    // Budget set to n * 1.47 (47% over-provisioning)
    int budget = (int)(n * 1.47);
    int sample_size = 15000;
    // A 是 (uint8_t*)a + 3 (B3 byte), stride = 4
    auto bounds = get_population_upper_bounds<4>((uint8_t*)a + 3, n, budget, sample_size);
}

void sort(uint* a, int __n) {}

// void sort(uint* a, int __n) {
//     // ---------------------------------------------------------
//     // Pass 1: Global MSD (Partition by B3)
//     // Read: a (4 bytes) -> Write: b (3 bytes: [B0, B1, B2])
//     // ---------------------------------------------------------
    
//     uint cnt_global[256];
//     // memset(cnt_global, 0, sizeof(cnt_global)); // No longer needed beforehand
    
//     // 1.1 统计 B3 (Sampling & Upper Bounds)
//     // Budget set to n * 1.47 (47% over-provisioning)
//     int budget = (int)(n * 1.47);
//     int sample_size = 15000;
//     // A 是 (uint8_t*)a + 3 (B3 byte), stride = 4
//     auto bounds = get_population_upper_bounds<4>((uint8_t*)a + 3, n, budget, sample_size);
    
//     // 1.2 计算 B3 Offset (Bytes in b)
//     // 增加 4 字节 Padding 以安全使用 store3
//     uint ptr_global[256];
//     uint32_t offset_b3 = 0;
//     for (int i = 0; i < 256; i++) {
//         ptr_global[i] = offset_b3;
//         offset_b3 += bounds[i] * 3; // Use Upper Bound
//     }
    
//     // 申请 b 数组 
//     uint8_t* b = (uint8_t*)malloc(budget * 3);
    
//     // 1.3 执行 Pass 1 分发
//     {
//         uint* __restrict__ src = a;
//         uint8_t* __restrict__ dst = b;
//         uint p[256];
//         memcpy(p, ptr_global, sizeof(p));

//         for (int i = 0; i < n; i += 16) {
//             _mm_prefetch((const char*)&src[i + PREFETCH_DIST], _MM_HINT_NTA);
            
//             #pragma GCC unroll 16
//             for (int j = 0; j < 16; j++) {
//                 uint val = src[i + j];
//                 uint8_t k = val >> 24; 
//                 store3(dst + p[k], val); 
//                 p[k] += 3;
//             }
//         }
        
//         // Reconstruct exact counts from pointer progress
//         for(int k=0; k<256; ++k) {
//             cnt_global[k] = (p[k] - ptr_global[k]) / 3;
//         }
//     }

//     // ---------------------------------------------------------
//     // 分段处理:遍历 B3 的每一个 Bucket
//     // ---------------------------------------------------------
    
//     uint8_t* a_u8 = (uint8_t*)a;

//     // 局部直方图缓存
//     uint cnt0[256];
//     uint cnt1[256];
//     uint cnt2[256];
//     uint ptr0[256]; // Pass 2 (Write a) pointers
//     uint ptr1[256]; // Pass 3 (Write b) pointers
//     uint ptr2[256]; // Pass 4 (Write a Final) pointers

//     uint32_t b_offset_start = 0; // byte offset in b (Reading)
//     uint32_t a_offset_start = 0; // index offset in a (Writing Final)

//     for (int i_b3 = 0; i_b3 < 256; i_b3++) {
//         int count = cnt_global[i_b3];
//         if (count == 0) continue;

//         uint8_t* seg_b_in = b + ptr_global[i_b3]; // Pass 1 Output -> Pass 2 Input
        
//         // Pass 2 Output (Temporary in a)
//         // 使用与 Final Output 相同的基地址,因为 4N > 2N + Padding,安全
//         uint8_t* seg_a_temp = a_u8 + (a_offset_start * 4);
        
//         // -----------------------------------------------------
//         // Step 2: Use Sampling for Pointers & Init Counters
//         // -----------------------------------------------------
        
//         // Use sampling to estimate B0 upper bounds for Pass 2 (ptr0)
//         // Budget logic: 
//         // Pass 2 writes 2 bytes per item into a space reserved for 4 bytes per item (Final Output array).
//         // Effectively, we have capacity for count * 2 items of size 2 bytes.
//         int budget_pass2 = count * 2;
//         int sample_size = 5000;
        
//         // get_population_upper_bounds will default to stride 3 because we pass stride=3?
//         // Wait, function signature is (uint8_t* A, int N, int budget, int stride, int sample_size)
//         // seg_b_in has data [B0, B1, B2]... So stride=3, offset=0 is B0.
//         auto bounds = get_population_upper_bounds<3>(seg_b_in, count, budget_pass2, sample_size);
        
//         // Calculate ptr0 (bucket start offsets in 'a')
//         uint32_t tmp = 0;
//         for(int k=0; k<256; k++) { 
//             ptr0[k] = tmp; 
//             tmp += bounds[k] * 2;
//         }
        
//         // Initialize cnt1, cnt2 for exact counting during Pass 2
//         memset(cnt1, 0, sizeof(cnt1));
//         memset(cnt2, 0, sizeof(cnt2));
//         // cnt0 will be recovered from pointer progress after Pass 2

//         // -----------------------------------------------------
//         // Pass 2: LSD Step 1 (Key B0)
//         // Read b (3B: B0,B1,B2) -> Write a (2B: B1,B2)
//         // AND compute exact histograms for B1, B2
//         // -----------------------------------------------------
//         {
//             uint p[256];
//             memcpy(p, ptr0, sizeof(p));
//             uint8_t* src = seg_b_in;
//             uint8_t* dst = seg_a_temp;

//             int k = 0;
//             for (; k <= count - 16; k += 16) {
//                 _mm_prefetch((const char*)(src + (k + PREFETCH_DIST) * 3), _MM_HINT_T0);
//                 #pragma GCC unroll 16
//                 for (int j = 0; j < 16; j++) {
//                     uint32_t val = *(uint32_t*)(src + (k + j) * 3);
//                     uint8_t key = val & 0xFF; // B0
                    
//                     // Count B1 (val >> 8 & 0xFF) and B2 (val >> 16 & 0xFF)
//                     cnt1[(val >> 8) & 0xFF]++;
//                     cnt2[(val >> 16) & 0xFF]++;
                    
//                     // Store [B1, B2] (val >> 8)
//                     store2(dst + p[key], val >> 8);
//                     p[key] += 2;
//                 }
//             }
//             for (; k < count; k++) {
//                 uint32_t val = *(uint32_t*)(src + k * 3);
//                 uint8_t key = val & 0xFF;
                
//                 cnt1[(val >> 8) & 0xFF]++;
//                 cnt2[(val >> 16) & 0xFF]++;
                
//                 store2(dst + p[key], val >> 8);
//                 p[key] += 2;
//             }
            
//             // Post-Pass 2: Recover exact cnt0 and compute ptr1, ptr2
//             bool retry = false;
//             uint32_t tmp1 = 0;
//             uint32_t tmp2 = 0;
//             for(int k=0; k<256; k++) {
//                 // Recover cnt0 from pointer progress (p - ptr0) / 2
//                 cnt0[k] = (p[k] - ptr0[k]) >> 1; 
//                 if (cnt0[k] > bounds[k]) retry = true;
                
//                 ptr1[k] = tmp1; 
//                 tmp1 += cnt1[k] * 2 + 4; // Padding
                
//                 ptr2[k] = tmp2; 
//                 tmp2 += cnt2[k];         // No Padding (Dense)
//             }

//             if (retry) {
//                 // Recalculate ptr0 with exact counts
//                 uint32_t tmp = 0;
//                 for(int k=0; k<256; k++) {
//                     ptr0[k] = tmp;
//                     tmp += cnt0[k] * 2 + 4; 
//                 }

//                 // Rerun Pass 2 (Distribution Only, no counting)
//                 uint p_retry[256];
//                 memcpy(p_retry, ptr0, sizeof(p_retry));
//                 uint8_t* src = seg_b_in;
//                 uint8_t* dst = seg_a_temp;

//                 int k = 0;
//                 for (; k <= count - 16; k += 16) {
//                     _mm_prefetch((const char*)(src + (k + PREFETCH_DIST) * 3), _MM_HINT_T0);
//                     #pragma GCC unroll 16
//                     for (int j = 0; j < 16; j++) {
//                         uint32_t val = *(uint32_t*)(src + (k + j) * 3);
//                         uint8_t key = val & 0xFF; // B0
//                         store2(dst + p_retry[key], val >> 8);
//                         p_retry[key] += 2;
//                     }
//                 }
//                 for (; k < count; k++) {
//                     uint32_t val = *(uint32_t*)(src + k * 3);
//                     uint8_t key = val & 0xFF;
//                     store2(dst + p_retry[key], val >> 8);
//                     p_retry[key] += 2;
//                 }
//             }
//         }

//         // -----------------------------------------------------
//         // Pass 3: LSD Step 2 (Key B1)
//         // Read a (2B: B1,B2) -> Write b (2B: B0,B2)
//         // Iterate B0 buckets to restore B0
//         // -----------------------------------------------------
//         {
//             // Reuse seg_b_in for output. Size is 3N > 2N, safe.
//             uint8_t* dst_base = seg_b_in; 
            
//             uint p[256];
//             memcpy(p, ptr1, sizeof(p));

//             // Iterate over B0 buckets (Pass 2 output)
//             for (int b0 = 0; b0 < 256; b0++) {
//                 int c = cnt0[b0];
//                 if (c == 0) continue;

//                 uint8_t* src = seg_a_temp + ptr0[b0];
//                 uint16_t val_b0 = b0; 

//                 int k = 0;
//                 for (; k <= c - 21; k += 21) {
//                     _mm_prefetch((const char*)(src + (k + PREFETCH_DIST) * 2), _MM_HINT_T0);
//                     #pragma GCC unroll 21
//                     for (int j = 0; j < 21; j++) {
//                         // Read [B1, B2]
//                         uint16_t val = *(uint16_t*)(src + (k + j) * 2);
//                         uint8_t key = val & 0xFF; // B1
                        
//                         // Construct [B0, B2]
//                         // We store B0 at low byte, B2 at high byte.
//                         // val & 0xFF00 is (B2 << 8).
//                         uint16_t new_val = val_b0 | (val & 0xFF00);
                        
//                         store2(dst_base + p[key], new_val);
//                         p[key] += 2;
//                     }
//                 }
//                 for (; k < c; k++) {
//                     uint16_t val = *(uint16_t*)(src + k * 2);
//                     uint8_t key = val & 0xFF;
//                     store2(dst_base + p[key], val_b0 | (val & 0xFF00));
//                     p[key] += 2;
//                 }
//             }
//         }

//         // -----------------------------------------------------
//         // Pass 4: LSD Step 3 (Key B2) & Finalize
//         // Read b (2B: B0,B2) -> Write a (4B: Full)
//         // Iterate B1 buckets to restore B1
//         // -----------------------------------------------------
//         {
//             uint* dst_base = a + a_offset_start;
//             uint8_t* src_base = seg_b_in; 

//             uint p[256];
//             memcpy(p, ptr2, sizeof(p));

//             uint32_t val_b3_shifted = i_b3 << 24;

//             // Iterate over B1 buckets (Pass 3 output)
//             for (int b1 = 0; b1 < 256; b1++) {
//                 int c = cnt1[b1];
//                 if (c == 0) continue;

//                 uint8_t* src = src_base + ptr1[b1];
                
//                 // Common High bits: B3 | B1 << 8
//                 // We will OR this with PDEP result
//                 uint32_t common_bits = val_b3_shifted | (b1 << 8);

//                 int k = 0;
//                 for (; k <= c - 32; k += 32) {
//                     _mm_prefetch((const char*)(src + (k + PREFETCH_DIST) * 2), _MM_HINT_T0);
//                     #pragma GCC unroll 32
//                     for (int j = 0; j < 32; j++) {
//                         // Read [B0, B2]
//                         uint32_t val = *(uint16_t*)(src + (k + j) * 2);
                        
//                         // Key is B2 (High byte of val)
//                         uint8_t key = val >> 8;
                        
//                         // Reconstruct:
//                         // val has bits: 0-7 (B0), 8-15 (B2)
//                         // Target: 0-7 (B0), 16-23 (B2)
//                         // PDEP Mask: 0x00FF00FF (Deposit val bits 0-7 to 0-7, 8-15 to 16-23)
//                         uint32_t scattered = _pdep_u32(val, 0x00FF00FF);
                        
//                         dst_base[p[key]] = common_bits | scattered;
//                         p[key]++;
//                     }
//                 }
//                 for (; k < c; k++) {
//                     uint32_t val = *(uint16_t*)(src + k * 2);
//                     uint8_t key = val >> 8;
//                     uint32_t scattered = _pdep_u32(val, 0x00FF00FF);
//                     dst_base[p[key]] = common_bits | scattered;
//                     p[key]++;
//                 }
//             }
//         }

//         // Update offsets
//         a_offset_start += count;
//     }
// }

CompilationN/AN/ACompile OKScore: N/A

Testcase #15.91 us12 KBRuntime ErrorScore: 0


Judge Duck Online | 评测鸭在线
Server Time: 2026-02-06 22:17:55 | Loaded in 1 ms | Server Status
个人娱乐项目,仅供学习交流使用 | 捐赠