#pragma GCC target("avx2,bmi,bmi2,popcnt,lzcnt")
#include <bits/stdc++.h>
#include <immintrin.h>
// Per-byte-value upper bounds on a strided byte population.
//
// Considers the N bytes A[0], A[stride], ..., A[(N-1)*stride] and returns, for every byte
// value v, a count results[v] intended to be >= the number of those bytes equal to v.
//
// - If N < 1000 or sample_size is outside (0, N), the counts are exact.
// - Otherwise sample_size positions are drawn independently (with replacement) using a
//   function-local, default-seeded std::mt19937. Results are therefore reproducible within
//   a process run, but the generator is shared static state and is not thread-safe. Each
//   bucket gets a Wilson score upper bound ceil(N * p_upper). The confidence multiplier z
//   is binary-searched over [2.75, 3.25] (4 probes) for the largest probed value whose
//   bounds sum to at most `budget`. Returned values are clamped to N. These bounds are
//   statistical: a true count can still exceed its bound.
// - If no probed z fits the budget, exact counts are returned instead (one extra pass over
//   A). Their sum is N, which may itself exceed `budget`.
std::array<int, 256> get_population_upper_bounds(uint8_t* A, int N, int budget, int stride, int sample_size) {
    // Exact histogram of the strided population.
    auto exact_counts = [&]() {
        std::array<int, 256> counts;
        counts.fill(0);
        for (int i = 0; i < N; ++i) counts[A[(size_t)i * stride]]++;
        return counts;
    };

    // 0. Small input or unusable sample size: count exactly.
    //    (sample_size <= 0 would otherwise divide by zero / loop ~2^64 times below.)
    if (sample_size <= 0 || sample_size >= N || N < 1000) return exact_counts();

    // 1. Sample sample_size uniform positions. x % N is computed by Barrett reduction:
    //    with mu = floor(2^64 / N) and x < 2^32, q = floor(x * mu / 2^64) is either
    //    floor(x / N) or one less, so one conditional subtraction fixes the remainder.
    const size_t n = (size_t)sample_size;
    std::array<int, 256> sample_counts;
    sample_counts.fill(0);
    static std::mt19937 gen;
    const uint32_t mod_n = (uint32_t)N;
    const uint64_t mu = (uint64_t)(((unsigned __int128)1 << 64) / mod_n);
    for (size_t i = 0; i < n; ++i) {
        const uint32_t x = gen();
        const uint64_t q = (uint64_t)(((unsigned __int128)x * mu) >> 64);
        uint32_t r = (uint32_t)(x - q * mod_n);
        if (r >= mod_n) r -= mod_n;  // correct the at-most-one underestimate of q
        sample_counts[A[(size_t)r * stride]]++;
    }

    const double n_double = (double)n;
    const double N_double = (double)N;
    std::array<double, 256> p_hats;
    for (int i = 0; i < 256; ++i) p_hats[i] = sample_counts[i] / n_double;

    // 2. Wilson score upper bound of one bucket, as an (unclamped) element count.
    //    No finite-population correction: the draws above are with replacement, and the
    //    FPC (valid only without replacement) would shrink the interval.
    auto upper_count = [&](double p_hat, double z) -> int {
        const double z2 = z * z;
        const double div_factor = 1.0 / (1.0 + z2 / n_double);
        const double term1 = p_hat + z2 / (2.0 * n_double);
        const double variance_term = p_hat * (1.0 - p_hat) / n_double;
        const double term2 = z * std::sqrt(variance_term + z2 / (4.0 * n_double * n_double));
        const double p_upper = (term1 + term2) * div_factor;
        return (int)std::ceil(N_double * p_upper);
    };

    // 3. Binary search for the largest z whose summed bounds fit the budget.
    //    Sums are unclamped so the search sees how much each z inflates the total.
    double low_z = 2.75;
    double high_z = 3.25;
    double best_z = 0.0;
    bool found = false;
    for (int iter = 0; iter < 4; ++iter) {
        const double mid_z = (low_z + high_z) * 0.5;
        long long current_sum = 0;
        for (int i = 0; i < 256; ++i) current_sum += upper_count(p_hats[i], mid_z);
        if (current_sum <= budget) {
            best_z = mid_z;  // fits: try a wider (safer) interval
            found = true;
            low_z = mid_z;
        } else {
            high_z = mid_z;  // over budget: narrow the interval
        }
    }

    // Nothing fits: z = 0 would yield bare point estimates that undercount roughly half
    // of the buckets, so return exact counts instead.
    if (!found) return exact_counts();

    // 4. Final bounds at the chosen z, clamped to the population size.
    std::array<int, 256> results;
    for (int i = 0; i < 256; ++i) results[i] = std::min(upper_count(p_hats[i], best_z), N);
    return results;
}
using namespace std;

// Fixed problem size: 10^8.
constexpr int n = 100000000;

// Prefetch look-ahead, counted in elements rather than bytes. With the element widths of
// the sort passes this is 256 B in Pass 1 (4-byte elements), 192 B in Pass 2 (3-byte)
// and 128 B in Passes 3/4 (2-byte).
constexpr int PREFETCH_DIST = 64;
// Writes the three low-order bytes of `val` to p[0], p[1], p[2]. On a little-endian
// target, with val = [B0, B1, B2, X], that is B0, B1, B2.
// This is deliberately ONE 4-byte store, so p[3] is clobbered with X: the destination
// must have at least one writable byte after the 3-byte slot.
// memcpy replaces the former `*(uint32_t*)p = val`, which is undefined behaviour
// (misaligned, type-punned access). GCC/Clang typically lower it to a single mov.
inline void store3(uint8_t* __restrict__ p, uint32_t val) {
    std::memcpy(p, &val, sizeof(val));
}
// Writes the two bytes of `val` to p[0], p[1] in native (little-endian on x86) order.
// Exactly two bytes are written; p need not be aligned.
// memcpy replaces the former `*(uint16_t*)p = val`, which is undefined behaviour
// (misaligned, type-punned access). GCC/Clang typically lower it to a single mov.
inline void store2(uint8_t* __restrict__ p, uint16_t val) {
    std::memcpy(p, &val, sizeof(val));
}
// Reads 4 bytes at p as a native-endian uint32_t (p need not be aligned).
static inline uint32_t radix_load_u32(const uint8_t* p) {
    uint32_t v;
    memcpy(&v, p, sizeof(v));
    return v;
}

// Reads 2 bytes at p as a native-endian uint16_t (p need not be aligned).
static inline uint16_t radix_load_u16(const uint8_t* p) {
    uint16_t v;
    memcpy(&v, p, sizeof(v));
    return v;
}

// Pass 1 scatter: copies the low three bytes (B0, B1, B2) of each src[0..len) into dst,
// bucketed by the top byte B3. On entry pos[k] is the byte offset in dst of the next free
// slot of bucket k. On return it is one slot past the last element written there.
// Every element costs one 4-byte store3, so each bucket region needs a spare byte after
// its last slot.
static void radix_scatter_b3(const uint* __restrict__ src, int len,
                             uint8_t* __restrict__ dst, size_t* pos) {
    size_t p[256];
    memcpy(p, pos, sizeof(p));
    int i = 0;
    for (; i <= len - 16; i += 16) {
        _mm_prefetch((const char*)&src[i + PREFETCH_DIST], _MM_HINT_NTA);
#pragma GCC unroll 16
        for (int j = 0; j < 16; j++) {
            const uint val = src[i + j];
            const uint8_t k = val >> 24;
            store3(dst + p[k], val);
            p[k] += 3;
        }
    }
    for (; i < len; i++) {  // tail when len is not a multiple of 16
        const uint val = src[i];
        const uint8_t k = val >> 24;
        store3(dst + p[k], val);
        p[k] += 3;
    }
    memcpy(pos, p, sizeof(p));
}

// Sorts a[0..len) ascending in place.
//
// One MSD pass on the top byte B3 is followed by three LSD passes (B0, B1, B2) inside each
// B3 bucket. Intermediate passes store only the bytes not implied by the bucket being
// processed:
//   Pass 1  a (4 B) -> b (3 B: B0 B1 B2), bucket = B3
//   Pass 2  b (3 B) -> a (2 B: B1 B2),    bucket = B0
//   Pass 3  a (2 B) -> b (2 B: B0 B2),    bucket = B1 (B0 restored from the source bucket)
//   Pass 4  b (2 B) -> a (4 B: full),     bucket = B2 (B1, B3 restored from source buckets)
// Pass 1 bucket capacities come from sampling (get_population_upper_bounds) instead of an
// extra counting pass over `a`. A bucket that outgrows its estimate is detected and
// Pass 1 is redone with exact sizes.
void sort(uint* a, int len) {
    if (a == nullptr || len <= 1) return;
    const size_t total = (size_t)len;

    // ---------------------------------------------------------
    // Pass 1 sizing: per-B3 bucket capacities from sampling
    // ---------------------------------------------------------
    // Budget: 47% over-provisioning relative to len.
    const int budget = (int)std::min(len * 1.47, (double)INT_MAX);
    const int sample_size = 15000;
    // On a little-endian target byte B3 of a[i] sits at byte offset 4*i + 3.
    const auto bounds = get_population_upper_bounds((uint8_t*)a + 3, len, budget, 4, sample_size);

    // Bucket k of b: bounds[k] 3-byte slots plus one spare byte, so the 4-byte store3 of
    // a full bucket's last slot cannot touch the next bucket.
    size_t ptr_global[256];
    size_t layout_bytes = 0;
    for (int k = 0; k < 256; k++) {
        ptr_global[k] = layout_bytes;
        layout_bytes += (size_t)bounds[k] * 3 + 1;
    }

    // The sampled capacities are only statistical, so a bucket may overrun its region.
    // Reserving 3*len extra bytes keeps every Pass 1 store inside b even if a bucket
    // overruns by all len elements. On systems that commit pages lazily this costs mostly
    // address space. If that reservation fails, retry with the tight size, which still
    // fits both the sampled layout and the exact layout used by the recovery below.
    size_t b_bytes = layout_bytes + total * 3 + 1;
    uint8_t* b = (uint8_t*)malloc(b_bytes);
    if (b == nullptr) {
        b_bytes = std::max(layout_bytes, total * 3 + 256);
        b = (uint8_t*)malloc(b_bytes);
    }
    if (b == nullptr) {
        std::sort(a, a + len);  // no scratch memory at all: library sort
        return;
    }

    // ---------------------------------------------------------
    // Pass 1: MSD scatter by B3   a (4 B) -> b (3 B: B0 B1 B2)
    // ---------------------------------------------------------
    size_t pos[256];
    memcpy(pos, ptr_global, sizeof(pos));
    radix_scatter_b3(a, len, b, pos);

    // Exact bucket sizes, reconstructed from how far each write offset advanced.
    uint cnt_global[256];
    bool overran = false;
    for (int k = 0; k < 256; k++) {
        cnt_global[k] = (uint)((pos[k] - ptr_global[k]) / 3);
        if ((long long)cnt_global[k] > bounds[k]) overran = true;
    }
    if (overran) {
        // A bucket spilled into its neighbour, so b is corrupt. Pass 1 only read `a`, so
        // redo it with the now-known exact sizes. (Only with the tight fallback allocation
        // could such an overrun already have run past the end of b.)
        size_t off = 0;
        for (int k = 0; k < 256; k++) {
            ptr_global[k] = off;
            off += (size_t)cnt_global[k] * 3 + 1;
        }
        memcpy(pos, ptr_global, sizeof(pos));
        radix_scatter_b3(a, len, b, pos);
    }

    // ---------------------------------------------------------
    // Passes 2-4, one B3 bucket at a time
    // ---------------------------------------------------------
    uint8_t* const a_u8 = (uint8_t*)a;
    uint cnt0[256];
    uint cnt1[256];
    uint cnt2[256];
    uint ptr0[256];  // Pass 2 output offsets (bytes, in a)
    uint ptr1[256];  // Pass 3 output offsets (bytes, in b)
    uint ptr2[256];  // Pass 4 output offsets (elements, in a)
    size_t a_offset_start = 0;  // index in a where the current B3 bucket's final output starts

    for (int i_b3 = 0; i_b3 < 256; i_b3++) {
        const int count = (int)cnt_global[i_b3];
        if (count == 0) continue;
        uint8_t* const seg_b_in = b + ptr_global[i_b3];  // Pass 1 output -> Pass 2 input
        // Pass 2 output lives where this bucket's final output will go. It needs exactly
        // 2*count bytes of that 4*count-byte area, and Pass 1 already copied everything
        // from `a` into b.
        uint8_t* const seg_a_temp = a_u8 + a_offset_start * 4;

        // Local histograms of B0, B1, B2.
        memset(cnt0, 0, sizeof(cnt0));
        memset(cnt1, 0, sizeof(cnt1));
        memset(cnt2, 0, sizeof(cnt2));
        auto tally = [&](uint32_t v) {
            cnt0[v & 0xFF]++;
            cnt1[(v >> 8) & 0xFF]++;
            cnt2[(v >> 16) & 0xFF]++;
        };
        int i = 0;
        for (; i <= count - 4; i += 4) {
            tally(radix_load_u32(seg_b_in + (size_t)i * 3));
            tally(radix_load_u32(seg_b_in + (size_t)(i + 1) * 3));
            tally(radix_load_u32(seg_b_in + (size_t)(i + 2) * 3));
            tally(radix_load_u32(seg_b_in + (size_t)(i + 3) * 3));
        }
        for (; i < count; i++) tally(radix_load_u32(seg_b_in + (size_t)i * 3));

        // Dense prefix sums. store2 writes exactly 2 bytes and Passes 3/4 read exactly
        // 2 bytes, so no inter-bucket padding is needed (padding would push Pass 2's temp
        // area past 4*count bytes for small buckets).
        uint acc = 0;
        for (int k = 0; k < 256; k++) { ptr0[k] = acc; acc += cnt0[k] * 2; }
        acc = 0;
        for (int k = 0; k < 256; k++) { ptr1[k] = acc; acc += cnt1[k] * 2; }
        acc = 0;
        for (int k = 0; k < 256; k++) { ptr2[k] = acc; acc += cnt2[k]; }

        // -----------------------------------------------------
        // Pass 2: key B0   b (3 B: B0 B1 B2) -> a (2 B: B1 B2)
        // -----------------------------------------------------
        {
            uint p[256];
            memcpy(p, ptr0, sizeof(p));
            const uint8_t* src = seg_b_in;
            uint8_t* dst = seg_a_temp;
            int k = 0;
            for (; k <= count - 16; k += 16) {
                _mm_prefetch((const char*)(src + (size_t)(k + PREFETCH_DIST) * 3), _MM_HINT_NTA);
#pragma GCC unroll 16
                for (int j = 0; j < 16; j++) {
                    const uint32_t val = radix_load_u32(src + (size_t)(k + j) * 3);
                    const uint8_t key = val & 0xFF;                  // B0
                    store2(dst + p[key], (uint16_t)(val >> 8));      // [B1, B2]
                    p[key] += 2;
                }
            }
            for (; k < count; k++) {
                const uint32_t val = radix_load_u32(src + (size_t)k * 3);
                const uint8_t key = val & 0xFF;
                store2(dst + p[key], (uint16_t)(val >> 8));
                p[key] += 2;
            }
        }

        // -----------------------------------------------------
        // Pass 3: key B1   a (2 B: B1 B2) -> b (2 B: B0 B2)
        // Walks the B0 buckets in order to restore B0.
        // -----------------------------------------------------
        {
            uint8_t* const dst_base = seg_b_in;  // this bucket's b region is free after Pass 2
            uint p[256];
            memcpy(p, ptr1, sizeof(p));
            for (int b0 = 0; b0 < 256; b0++) {
                const int c = (int)cnt0[b0];
                if (c == 0) continue;
                const uint8_t* src = seg_a_temp + ptr0[b0];
                const uint16_t val_b0 = (uint16_t)b0;
                int k = 0;
                for (; k <= c - 21; k += 21) {
                    _mm_prefetch((const char*)(src + (size_t)(k + PREFETCH_DIST) * 2), _MM_HINT_NTA);
#pragma GCC unroll 21
                    for (int j = 0; j < 21; j++) {
                        const uint16_t val = radix_load_u16(src + (size_t)(k + j) * 2);  // [B1, B2]
                        const uint8_t key = val & 0xFF;                                  // B1
                        // B0 in the low byte; (val & 0xFF00) is B2 already in the high byte.
                        store2(dst_base + p[key], (uint16_t)(val_b0 | (val & 0xFF00)));
                        p[key] += 2;
                    }
                }
                for (; k < c; k++) {
                    const uint16_t val = radix_load_u16(src + (size_t)k * 2);
                    const uint8_t key = val & 0xFF;
                    store2(dst_base + p[key], (uint16_t)(val_b0 | (val & 0xFF00)));
                    p[key] += 2;
                }
            }
        }

        // -----------------------------------------------------
        // Pass 4: key B2   b (2 B: B0 B2) -> a (4 B: full value)
        // Walks the B1 buckets in order to restore B1; B3 is i_b3.
        // -----------------------------------------------------
        {
            uint* const dst_base = a + a_offset_start;
            const uint8_t* src_base = seg_b_in;
            uint p[256];
            memcpy(p, ptr2, sizeof(p));
            const uint32_t val_b3_shifted = (uint32_t)i_b3 << 24;
            for (int b1 = 0; b1 < 256; b1++) {
                const int c = (int)cnt1[b1];
                if (c == 0) continue;
                const uint8_t* src = src_base + ptr1[b1];
                const uint32_t common_bits = val_b3_shifted | ((uint32_t)b1 << 8);
                int k = 0;
                for (; k <= c - 32; k += 32) {
                    _mm_prefetch((const char*)(src + (size_t)(k + PREFETCH_DIST) * 2), _MM_HINT_T0);
#pragma GCC unroll 32
                    for (int j = 0; j < 32; j++) {
                        const uint32_t val = radix_load_u16(src + (size_t)(k + j) * 2);  // [B0, B2]
                        const uint8_t key = val >> 8;                                    // B2
                        // Move B0 to bits 0-7 and B2 to bits 16-23 (same result as
                        // _pdep_u32(val, 0x00FF00FF), without requiring BMI2).
                        const uint32_t scattered = (val & 0xFFu) | ((val & 0xFF00u) << 8);
                        dst_base[p[key]] = common_bits | scattered;
                        p[key]++;
                    }
                }
                for (; k < c; k++) {
                    const uint32_t val = radix_load_u16(src + (size_t)k * 2);
                    const uint8_t key = val >> 8;
                    const uint32_t scattered = (val & 0xFFu) | ((val & 0xFF00u) << 8);
                    dst_base[p[key]] = common_bits | scattered;
                    p[key]++;
                }
            }
        }

        a_offset_start += (size_t)count;
    }

    free(b);
}
// | Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// | Testcase #1 | 608.537 ms | 668 MB + 648 KB | Accepted | Score: 100 | Show more |