// #10604 + ortc-yjp
#include <bits/stdc++.h>
#include <arpa/inet.h>
using namespace std;
// One routing-table row exactly as laid out in memory: 4-byte address,
// 1-byte prefix length, 3 padding bytes, 4-byte next hop (12 bytes total with
// a 4-byte `unsigned`, no compiler-inserted padding because of `packed`).
// NOTE(review): router.h (included further down this file) may declare its own
// RoutingTableEntry; two definitions in one translation unit will not compile — confirm.
struct __attribute__((packed)) RoutingTableEntry {
    unsigned addr;       // prefix address; byte order is decided by the caller
    unsigned char len;   // prefix length in bits
    char pad[3];         // explicit filler so nexthop sits at offset 8
    unsigned nexthop;    // next-hop value
};
// ORTC merge rule for two candidate next-hop sets: their intersection when it
// is non-empty, otherwise their union.
set<unsigned> merge(const set<unsigned> &A, const set<unsigned> &B)
{
    set<unsigned> common;
    set_intersection(A.begin(), A.end(), B.begin(), B.end(),
                     inserter(common, common.end()));
    if (!common.empty()) return common;

    set<unsigned> everything(A);
    everything.insert(B.begin(), B.end());
    return everything;
}
// ORTC merge of a candidate set with a single next hop B: just {B} when B is
// already a candidate, otherwise A with B added.
set<unsigned> merge(const set<unsigned> &A, unsigned B)
{
    if (A.count(B)) return set<unsigned>{B};
    set<unsigned> widened(A);
    widened.insert(B);
    return widened;
}
// Binary-trie node used by compress() to run ORTC over the input routes:
//   insert(): builds the trie, storing each route's next hop at its prefix node;
//   pass2():  pushes next hops down into nodes that have none and computes,
//             bottom-up, every node's candidate next-hop set;
//   pass3():  walks top-down, picking one next hop per node and emitting a
//             route only where the pick differs from what is already in effect.
// A nexthop of ~0u means "not set"; it takes part in the candidate sets like
// any other value.
struct Node
{
    Node *sons[2];           // child for the next address bit (0 / 1); null when absent
    unsigned nexthop;        // next hop at this prefix; ~0u until assigned
    set<unsigned> nexthops;  // candidate set built by pass2, released by pass3

    Node()
    {
        sons[0] = sons[1] = 0;
        nexthop = ~0u;
    }

    // Returns child x (0 or 1), allocating it on first access.
    Node *getSon(int x)
    {
        Node *&son = sons[x];
        if(!son) son = new Node();
        return son;
    }

    // ORTC pass 2. A node without its own next hop takes `inherited` from its
    // parent. A leaf's candidates are {nexthop}; an inner node merges its
    // children's sets, treating a missing child as a leaf carrying this node's
    // next hop.
    void pass2(unsigned inherited)
    {
        if(nexthop == ~0u) nexthop = inherited;
        if(!sons[0] && !sons[1])
        {
            nexthops.insert(nexthop);
            return;
        }
        if(!sons[0])
        {
            sons[1]->pass2(nexthop);
            nexthops = merge(sons[1]->nexthops, nexthop);
        }
        else if(!sons[1])
        {
            sons[0]->pass2(nexthop);
            nexthops = merge(sons[0]->nexthops, nexthop);
        }
        else
        {
            sons[0]->pass2(nexthop);
            sons[1]->pass2(nexthop);
            nexthops = merge(sons[0]->nexthops, sons[1]->nexthops);
        }
    }

    // Appends the route (addr / len -> nexthop) to buf. The entry is
    // value-initialised so padding and any field not set below are zero
    // instead of indeterminate bytes leaking into the caller's output table.
    void report(vector<RoutingTableEntry> &buf, unsigned addr, unsigned len, unsigned nexthop)
    {
        RoutingTableEntry entry = {};
        entry.addr = addr;
        entry.len = len;
        entry.nexthop = nexthop;
        buf.push_back(entry);
    }

    // ORTC pass 3. `inherited` is the next hop in effect from the nearest
    // ancestor route (~0u at the root). (addr, len) is this node's prefix,
    // host order.
    // If `inherited` is one of this node's candidates nothing is emitted here.
    // Otherwise the smallest candidate is chosen and emitted as a route.
    // A missing child is emitted only when its implicit next hop (this node's
    // nexthop) differs from the hop in effect.
    void pass3(unsigned inherited, vector<RoutingTableEntry> &buf, unsigned addr, unsigned len)
    {
        if(nexthops.count(inherited)) nexthops.clear();
        else
        {
            inherited = *nexthops.begin();
            nexthops.clear();
            nexthops.insert(inherited);
            report(buf, addr, len, inherited);
        }
        if(!sons[0] && !sons[1]) return;
        // Child 1's prefix is this prefix with bit `len` (counted from the MSB) set.
        if(sons[0]) sons[0]->pass3(inherited, buf, addr, len + 1);
        else if(nexthop != inherited) report(buf, addr, len + 1, nexthop);
        if(sons[1]) sons[1]->pass3(inherited, buf, addr | 1u << (31 - len), len + 1);
        else if(nexthop != inherited) report(buf, addr | 1u << (31 - len), len + 1, nexthop);
    }
};
// Adds one route to the trie rooted at rt: descends one level per prefix bit,
// most significant bit first, creating nodes as needed. The final node then
// takes the route's next hop. entry->addr is expected in host byte order.
void insert(Node *rt, RoutingTableEntry *entry)
{
    Node *cur = rt;
    unsigned bits = entry->addr;
    int remaining = entry->len;
    while (remaining-- > 0)
    {
        cur = cur->getSon((bits >> 31) & 1);  // current top bit picks the child
        bits <<= 1;
    }
    cur->nexthop = entry->nexthop;
}
void compress(int n, RoutingTableEntry *tbl, int *n_out, RoutingTableEntry **tbl_out) {
for(int i = 0; i < n; i++) tbl[i].addr = ntohl(tbl[i].addr);
Node *rt = new Node();
for(int i = 0; i < n; i++) insert(rt, tbl + i);
rt->pass2(-1);
vector<RoutingTableEntry> *buf = new vector<RoutingTableEntry>();
rt->pass3(-1, *buf, 0, 0);
*n_out = buf->size();
*tbl_out = buf->data();
for(int i = 0; i < *n_out; i++) (*tbl_out)[i].addr = htonl((*tbl_out)[i].addr);
}
// ===========================
#pragma GCC optimize("Ofast")
#pragma GCC target("popcnt")
#include <stdint.h>
#include <arpa/inet.h>
#include <algorithm>
#include "router.h"
/* HashMap: interns 32-bit values as small dense ids using chained hashing
 * over HASHMAP_MOD buckets. Slot 0 is a permanent sentinel whose key is 0:
 * every chain ends at slot 0, and the value 0 always interns to id 0. */
constexpr int HASHMAP_SIZE = 1 << 16;
constexpr uint32_t HASHMAP_MOD = 65537;
uint32_t hashmap_key[HASHMAP_SIZE];   // hashmap_key[id] = the value interned as id
uint32_t hashmap_next[HASHMAP_SIZE];  // next id in the same bucket chain
uint32_t hashmap_first[HASHMAP_MOD];  // head id of each bucket chain
int hashmap_size;                     // ids 1..hashmap_size are in use

/* Returns the id of x. A value not seen before gets the next free id, which
 * is pushed onto the front of its bucket chain.
 * NOTE(review): there is no capacity check. Once hashmap_size reaches
 * HASHMAP_SIZE - 1, the next new value writes past the arrays. */
inline uint32_t hashmap_get(uint32_t x) {
    uint32_t &head = hashmap_first[x % HASHMAP_MOD];
    for (uint32_t id = head;; id = hashmap_next[id]) {
        if (hashmap_key[id] == x) return id;
        if (id == 0) break;  // reached the sentinel without a match
    }
    const uint32_t fresh = ++hashmap_size;
    hashmap_key[fresh] = x;
    hashmap_next[fresh] = head;
    head = fresh;
    return fresh;
}
/* 3-level tree: the host-order address is split 16/8/8.
 * A table_16 slot below -TABLE_24_SIZE holds a next hop directly. Any larger
 * value is a child-table handle, with table_24 index = value + TABLE_24_SIZE
 * (mod 2^32). table_24 slots use the same scheme with TABLE_32_SIZE.
 * Child tables are handed out from the top of each pool downward by
 * decrementing the *_cnt counters from 0. */
constexpr uint32_t TABLE_32_SIZE = 16384;
constexpr uint32_t TABLE_24_SIZE = 32768;
uint32_t table_32_cnt = 0;  // minus the number of level-3 tables in use
uint32_t table_24_cnt = 0;  // minus the number of level-2 tables in use
alignas(4096) uint32_t table_32[TABLE_32_SIZE][1 << 8];
alignas(4096) uint32_t table_24[TABLE_24_SIZE][1 << 8];
alignas(4096) uint32_t table_16[1 << 16];
// Stores val into a[0..n). The main loop writes four slots per step and the
// tail loop writes the remaining 0-3 slots.
inline void fill(uint32_t *a, int n, uint32_t val) {
    for (; n >= 4; n -= 4, a += 4) {
        a[0] = a[1] = a[2] = a[3] = val;
    }
    for (; n != 0; --n) {
        *a++ = val;
    }
}
// Writes one prefix (addr/len -> nexthop) into the 16/8/8 tree.
// addr is in host byte order (init passes htonl() of compress()'s
// network-order output). The start slot is taken straight from addr, so bits
// past `len` are assumed to be zero.
// Covered slots are overwritten, including child-table handles, so a shorter
// prefix must be inserted before any longer prefix it covers.
// NOTE(review): this relies on compress()/Node::pass3 emitting a node's route
// before its children's routes — confirm if the input order ever changes.
// NOTE(review): a next-hop value >= -TABLE_24_SIZE in table_16 (or
// >= -TABLE_32_SIZE in table_24) cannot be told apart from a child-table
// handle. Node uses ~0u as its "not set" marker, so such a value reaching
// here would be misread by query().
inline void ins(uint32_t addr, int len, uint32_t nexthop) {
if (len <= 16) {
// Short prefix: write 2^(16-len) level-1 slots directly.
fill(table_16 + (addr >> 16), 1u << (16 - len), nexthop);
} else if (len <= 24) {
uint32_t &t16 = table_16[addr >> 16];
addr = (addr & 65535u) >> 8; // slot index (bits 15..8) inside the level-2 table
uint32_t *tmp;
if (t16 < -TABLE_24_SIZE) {
// Slot holds a plain next hop: allocate a level-2 table from the top of the
// pool. The new prefix's range gets nexthop; every other slot keeps the old value t16.
tmp = table_24[--table_24_cnt + TABLE_24_SIZE];
fill(tmp, addr, t16);
fill(tmp + addr, 1u << (24 - len), nexthop);
fill(tmp + addr + (1u << (24 - len)), 256 - addr - (1u << (24 - len)), t16);
t16 = table_24_cnt; // slot now stores the new table's handle
} else {
// Slot already refers to a level-2 table: overwrite the covered range.
tmp = table_24[t16 + TABLE_24_SIZE];
fill(tmp + addr, 1u << (24 - len), nexthop);
}
} else {
uint32_t &t16 = table_16[addr >> 16];
addr &= 65535u;
uint32_t *tmp;
if (t16 < -TABLE_24_SIZE) {
// Need a level-2 table first; seed all 256 slots with the old next hop.
tmp = table_24[--table_24_cnt + TABLE_24_SIZE];
fill(tmp, 256, t16);
t16 = table_24_cnt;
} else {
tmp = table_24[t16 + TABLE_24_SIZE];
}
uint32_t &t24 = tmp[addr >> 8];
addr &= 255u; // slot index (bits 7..0) inside the level-3 table
if (t24 < -TABLE_32_SIZE) {
// Same split-and-seed step one level down.
tmp = table_32[--table_32_cnt + TABLE_32_SIZE];
fill(tmp, addr, t24);
fill(tmp + addr, 1u << (32 - len), nexthop);
fill(tmp + addr + (1u << (32 - len)), 256 - addr - (1u << (32 - len)), t24);
t24 = table_32_cnt;
} else {
tmp = table_32[t24 + TABLE_32_SIZE];
fill(tmp + addr, 1u << (32 - len), nexthop);
}
}
}
/* Bit sets: each level-3 table is re-encoded as runs of equal slots.
 * - Bit i of level3_bits[t] is set when slot i differs from slot i-1.
 * - level3_pointers holds one 16-bit interned next-hop id per run.
 * - level3_bit_sums[t][w] is the absolute level3_pointers index of the run
 *   covering slot 32*w - 1 (the table's first run for w == 0).
 * A lookup adds the set bits of word w up to and including its slot. */
const int MAX_N_LEVEL3_POINTERS = 1000000;
uint32_t level3_bits[TABLE_32_SIZE][(1 << 8) / 32];
uint32_t level3_bit_sums[TABLE_32_SIZE][(1 << 8) / 32];
uint16_t level3_pointers[MAX_N_LEVEL3_POINTERS];
uint32_t n_level3_pointers;

// Encodes level-3 table `level3_index` and appends its runs to level3_pointers.
// Values are interned through hashmap_get in slot order.
// NOTE(review): n_level3_pointers is never checked against MAX_N_LEVEL3_POINTERS.
inline void walk_level3(uint32_t level3_index) {
    const uint32_t *slots = table_32[level3_index];
    uint32_t *word_bits = level3_bits[level3_index];
    uint32_t *word_base = level3_bit_sums[level3_index];
    const uint32_t base = n_level3_pointers;
    int last_run = 0;  // index, relative to base, of the run in progress
    uint32_t run_value = slots[0];
    level3_pointers[base] = hashmap_get(run_value);
    for (int w = 0; w < (1 << 8) / 32; w++) {
        word_base[w] = last_run + base;
        for (int b = (w == 0) ? 1 : 0; b < 32; b++) {
            const uint32_t v = slots[w * 32 + b];
            if (v == run_value) continue;
            word_bits[w] |= 1u << b;
            run_value = v;
            level3_pointers[base + ++last_run] = hashmap_get(v);
        }
    }
    n_level3_pointers += last_run + 1;
}
/* Run-length form of the level-2 tables, laid out like level 3:
 * - level2_bits marks slots that start a new run.
 * - level2_pointers holds one 16-bit entry per run.
 * - level2_bit_sums[t][w] is the absolute index of the run covering slot
 *   32*w - 1 (the first run for w == 0).
 * Each entry is either a level-3 handle (kept as-is, truncated to 16 bits) or
 * an interned next-hop id. */
const int MAX_N_LEVEL2_POINTERS = 2000000;
uint32_t level2_bits[TABLE_24_SIZE][(1 << 8) / 32];
uint32_t level2_bit_sums[TABLE_24_SIZE][(1 << 8) / 32];
uint16_t level2_pointers[MAX_N_LEVEL2_POINTERS];
uint32_t n_level2_pointers;

// Encodes level-2 table `level2_index` and appends its run entries to
// level2_pointers. Next hops are interned through hashmap_get in slot order.
// NOTE(review): n_level2_pointers is never checked against MAX_N_LEVEL2_POINTERS.
inline void walk_level2(uint32_t level2_index) {
    auto encode = [](uint32_t v) -> uint32_t {
        return v >= -TABLE_32_SIZE ? v : hashmap_get(v);
    };
    const uint32_t *slots = table_24[level2_index];
    uint32_t *word_bits = level2_bits[level2_index];
    uint32_t *word_base = level2_bit_sums[level2_index];
    const uint32_t base = n_level2_pointers;
    int last_run = 0;  // index, relative to base, of the run in progress
    uint32_t run_value = slots[0];
    level2_pointers[base] = encode(run_value);
    for (int w = 0; w < (1 << 8) / 32; w++) {
        word_base[w] = last_run + base;
        for (int b = (w == 0) ? 1 : 0; b < 32; b++) {
            const uint32_t v = slots[w * 32 + b];
            if (v == run_value) continue;
            word_bits[w] |= 1u << b;
            run_value = v;
            level2_pointers[base + ++last_run] = encode(v);
        }
    }
    n_level2_pointers += last_run + 1;
}
void init(int _n, int q, const RoutingTableEntry *_tbl) {
int n = -1;
RoutingTableEntry *tbl = NULL;
compress(_n, (RoutingTableEntry *) _tbl, &n, &tbl);
for (int i = 0; i < n; i++) {
ins(htonl(tbl[i].addr), tbl[i].len, tbl[i].nexthop);
}
for (int i = table_32_cnt; i < 0; i++) {
walk_level3(i + TABLE_32_SIZE);
}
for (int i = table_24_cnt; i < 0; i++) {
walk_level2(i + TABLE_24_SIZE);
}
}
// Returns the number of set bits in x.
// __builtin_popcount lets GCC emit a single POPCNT instruction under the
// file's `#pragma GCC target("popcnt")`. Hand-written inline asm is avoided on
// purpose: GCC's default AT&T syntax is `popcnt src, dst`. The Intel-order
// template "popcnt %0, %1" counted the uninitialised output register and
// overwrote the input-only operand.
inline uint32_t popcount(uint32_t x) {
    return static_cast<uint32_t>(__builtin_popcount(x));
}
unsigned query(unsigned addr) {
addr = htonl(addr);
uint32_t tmp = table_16[addr >> 16];
if (tmp >= -TABLE_24_SIZE) {
uint32_t level2_index = tmp + TABLE_24_SIZE;
uint32_t addr_l2 = addr << 16 >> 24;
uint32_t off = level2_bit_sums[level2_index][addr_l2 / 32]
+ popcount(level2_bits[level2_index][addr_l2 / 32] & ((2u << (addr_l2 & 31)) - 1));
tmp = (int32_t) (int16_t) level2_pointers[off];
if (tmp >= -TABLE_32_SIZE) {
uint32_t level3_index = tmp + TABLE_32_SIZE;
uint32_t addr_l3 = addr & 255u;
uint32_t off = level3_bit_sums[level3_index][addr_l3 / 32]
+ popcount(level3_bits[level3_index][addr_l3 / 32] & ((2u << (addr_l3 & 31)) - 1));
tmp = (uint32_t) level3_pointers[off];
return hashmap_key[tmp];
} else {
return hashmap_key[tmp];
}
} else {
return tmp;
}
}
// Judge result pasted from the submission page (not code): | Compilation | N/A | N/A | Compile Error | Score: N/A | 显示更多 (Show more) |