#define USE_TRIE
#include "router.h"
#include <stdint.h>
#include <netinet/in.h>
#include <new>
#include <unordered_set>
#include <algorithm>
#ifdef DEBUG
#include <stdio.h>
#endif
/// Returns a 32-bit value with the top `len` bits set and the rest clear
/// (e.g. mask(8) == 0xFF000000), used to compare host-order addresses against
/// prefixes. len == 0 yields 0; any len >= 32 yields all ones.
inline uint32_t mask(unsigned char len) {
    // Shift an *unsigned* all-ones value: left-shifting the signed `~0` (-1)
    // is undefined behaviour before C++20. Shift counts of 32 (len == 0) or
    // negative counts (len > 32) are undefined too, so handle those explicitly.
    if (len == 0) return 0;
    if (len >= 32) return 0xFFFFFFFFu;
    return 0xFFFFFFFFu << (32 - len);
}
/// One node of the binary prefix trie. A child may carry a `len` several bits
/// longer than its parent, so runs of single-child levels are not stored.
struct Node {
    uint32_t prefix;    // prefix bits in host byte order; trie_find compares it with addr & mask(len)
    unsigned char len;  // number of significant leading bits
    bool valid;         // true when this node is a route, false for a pure branch point
    uint32_t nexthop;   // next hop of the route; only meaningful when `valid`
    Node *next[2];      // children, selected by the first bit after the leading `len` bits
};

/// Bump-allocation cursor and one-past-the-end of the current node pool;
/// both start out null so the first alloc() grabs a pool.
Node *pool = nullptr;
Node *pool_end = nullptr;

/// Number of Node slots requested from operator new[] per pool refill.
const size_t ALLOC_AMOUNT = 1 << 20;

/// Hands out storage for one Node from the current pool, allocating a fresh
/// block of ALLOC_AMOUNT nodes whenever the current one is used up. Pools are
/// never deleted; callers placement-new into the returned slot.
Node *alloc() {
    if (pool != pool_end) {
        return pool++;
    }
    Node *const block = new Node[ALLOC_AMOUNT];
    pool = block + 1;
    pool_end = block + ALLOC_AMOUNT;
    return block;
}

/// Root of the routing trie; null while no route has been added.
Node *root = nullptr;
#ifdef DEBUG
// Debug dump of the trie rooted at `cur` to stderr. Each node prints as
// "prefix/len [valid]", followed by its two child slots labelled " 0" and " 1";
// each recursion level is indented by one more space, and an empty slot is
// printed as "-".
void trie_print(Node *cur = root, int indent = 0) {
    auto pad = [indent] {
        for (int k = 0; k < indent; ++k) fputc(' ', stderr);
    };
    if (cur == nullptr) {
        pad();
        fprintf(stderr, "-\n");
        return;
    }
    pad();
    fprintf(stderr, "%08x/%d %s\n", cur->prefix, cur->len, cur->valid ? "valid" : "");
    for (size_t child = 0; child < 2; ++child) {
        pad();
        fprintf(stderr, " %zu\n", child);
        trie_print(cur->next[child], indent + 1);
    }
}
#endif
/// Walks the trie from `root` along the bits of `addr` (host byte order),
/// stopping at the first node whose len reaches `lenlim`.
///
/// insert == false (lookup): returns the slot of the last valid node on the
///   walk whose prefix matches `addr` (the longest-prefix match), or nullptr
///   if no valid node matched.
/// insert == true: returns the slot where the walk stopped. That is either an
///   empty child pointer, the first node whose prefix does not match `addr`, or
///   the first matching node with len >= lenlim. trie_add inserts there.
Node **trie_find(uint32_t addr, unsigned char lenlim = 32, bool insert = false) {
    Node **slot = &root;
    Node **best = nullptr;
    for (Node *node = *slot; node != nullptr; node = *slot) {
#ifdef DEBUG
        fprintf(stderr, "cur %08x/%d %s\n", node->prefix, node->len, node->valid ? "valid" : "");
        if (best) fprintf(stderr, "val %08x/%d %s\n", (*best)->prefix, (*best)->len, (*best)->valid ? "valid" : "");
#endif
        const bool matches = (addr & mask(node->len)) == node->prefix;
        if (!matches) break;
        if (node->valid) best = slot;
        // Checked before the shift below, so node->len <= 31 when it runs.
        if (node->len >= lenlim) break;
        // Choose the child by the first bit after this node's prefix.
        const size_t bit = (addr >> (31 - node->len)) & 1u;
        slot = &node->next[bit];
    }
    return insert ? slot : best;
}
/// Inserts the route prefix/len -> nexthop, or overwrites it if already present.
/// `prefix` is in host byte order; bits beyond `len` are ignored.
///
/// trie_find(prefix, len, true) yields the slot where the walk stopped, and
/// three cases follow:
///   * the slot is empty: hang a new leaf there;
///   * the node in the slot has exactly this prefix/len: overwrite its route;
///   * otherwise the node and the new route share only `common` < node->len
///     leading bits. Splice a branch node of length `common` into the slot and
///     put the old node under *its own* bit at position `common`. Then either
///     make the branch node the route itself (common == len) or hang a new
///     leaf under the other child.
void trie_add(uint32_t prefix, unsigned char len, uint32_t nexthop) {
    Node **f = trie_find(prefix, len, true);
#ifdef DEBUG
    fprintf(stderr, "add %08x/%d\n", prefix, len);
    if (*f)
        fprintf(stderr, "found %08x/%d %s\n", (*f)->prefix, (*f)->len, (*f)->valid ? "valid" : "");
    else
        fprintf(stderr, "found = NULL\n");
#endif
    // trie_find matches a node only when (addr & mask(len)) == node->prefix,
    // so stored prefixes must have every bit beyond `len` cleared.
    const uint32_t canon = prefix & mask(len);
    if (*f == nullptr) {
        *f = new(alloc()) Node {
            .prefix = canon,
            .len = len,
            .valid = 1,
            .nexthop = nexthop,
            .next { nullptr, nullptr }
        };
    } else {
        // Number of leading bits shared by the new prefix and the found node,
        // capped at the shorter length (the ~mask term sets every later bit).
        uint32_t pattern = (prefix ^ (*f)->prefix) | ~mask(std::min((*f)->len, len));
        unsigned char common = pattern == 0 ? 32 : __builtin_clz(pattern);
#ifdef DEBUG
        fprintf(stderr, "common = %d\n", common);
#endif
        if (common == (*f)->len) {
            // trie_find only stops on a matching node when its len >= len, so
            // here the node has exactly this prefix and length: redefine it.
            (*f)->valid = true;
            (*f)->nexthop = nexthop;
        } else {
            // Here common < (*f)->len <= 32, so the shift is in range. The old
            // node keeps its place under its own bit at position `common`.
            const size_t old_bit = ((*f)->prefix >> (31 - common)) & 1u;
            Node *new_node = new(alloc()) Node {
                .prefix = prefix & mask(common),
                .len = common,
                .valid = 0,
                .nexthop = 0,
                .next { nullptr, nullptr }
            };
            new_node->next[old_bit] = *f;
            *f = new_node;
            if (common == len) {
                // The new route is a strict prefix of the old node.
                new_node->valid = true;
                new_node->nexthop = nexthop;
            } else {
                // The routes diverge at bit `common`, so the new leaf takes the
                // other child.
                new_node->next[!old_bit] = new(alloc()) Node {
                    .prefix = canon,
                    .len = len,
                    .valid = 1,
                    .nexthop = nexthop,
                    .next { nullptr, nullptr }
                };
            }
        }
    }
#ifdef DEBUG
    trie_print();
#endif
}
/// Hash functor for hash_table, keyed on (addr, len) only. `len` is shifted
/// into the bits above 32 of a 64-bit word, `addr` is OR-ed into it, and the
/// word is fed to std::hash<uint64_t>. nexthop is not part of the key, which
/// keeps this consistent with EntryEquals.
struct EntryHash {
    size_t operator()(const RoutingTableEntry &entry) const {
        const uint64_t high = static_cast<uint64_t>(entry.len) << 32;
        const uint64_t low = static_cast<uint64_t>(entry.addr);
        return std::hash<uint64_t>{}(high | low);
    }
};
struct EntryEquals {
bool operator()(const RoutingTableEntry &a, const RoutingTableEntry &b) const {
return a.addr == b.addr && a.len == b.len;
}
};
std::unordered_set<RoutingTableEntry, EntryHash, EntryEquals> hash_table;
/// Loads the `n` routing entries in `a` into the active backend. With
/// USE_TRIE, each entry's addr is converted with ntohl and passed to trie_add;
/// otherwise the entries are inserted into hash_table unchanged. `q` is not
/// used here (presumably the number of queries to follow; verify against the
/// router.h contract).
void init(int n, int q, const RoutingTableEntry *a) {
    (void) q;
    // A negative n would otherwise be converted to a huge size_t, which would
    // drive reserve() and the loop far past the end of `a`.
    if (n <= 0) return;
#ifndef USE_TRIE
    hash_table.reserve(static_cast<size_t>(n));
#endif
    for (int i = 0; i < n; ++i) {
#ifdef USE_TRIE
        trie_add(ntohl(a[i].addr), a[i].len, a[i].nexthop);
#else
        hash_table.insert(a[i]);
#endif
    }
}
/// Longest-prefix-match lookup. `addr` is in network byte order (it is passed
/// through ntohl before matching). Returns the nexthop of the longest route
/// covering `addr`, or 0 when no route matches.
unsigned query(unsigned addr) {
#ifdef USE_TRIE
    Node **node = trie_find(ntohl(addr));
    if (! node || ! *node) return 0;
#ifdef DEBUG
    fprintf(stderr, "%08x in %08x/%d\n", ntohl(addr), (*node)->prefix, (*node)->len);
#endif
    return (*node)->nexthop;
#else
    // Probe from the longest prefix length down. The first hit is the longest
    // match, so shorter lengths need not be examined.
    const uint32_t host = ntohl(addr);
    for (int i = 32; i >= 0; --i) {
        const unsigned char len = static_cast<unsigned char>(i);
        auto it = hash_table.find(RoutingTableEntry {
            .addr = htonl(host & mask(len)), .len = len
        });
        if (it != hash_table.end()) {
#ifdef DEBUG
            fprintf(stderr, "%08x/%d\n", ntohl(it->addr), it->len);
#endif
            return it->nexthop;
        }
    }
    return 0;
#endif
}
#ifdef LOCAL
/// Upper bound on the number of routing entries read from out.bin.
const size_t MAX_ENTS = 1000000;
RoutingTableEntry entries[MAX_ENTS];
/// Local test driver. Reads up to MAX_ENTS raw RoutingTableEntry records from
/// out.bin and builds the table. It then reads unsigned integers (host byte
/// order) from stdin until input ends or fails to parse, printing one nexthop
/// per line.
int main() {
    FILE *f = fopen("out.bin", "rb");
    if (f == nullptr) {
        // Without this check fread/fclose would be handed a null FILE*.
        perror("out.bin");
        return 1;
    }
    size_t num = fread(entries, sizeof(RoutingTableEntry), MAX_ENTS, f);
    fclose(f);
    init(static_cast<int>(num), 0, entries);
    unsigned a;
    while (scanf("%u", &a) == 1) {
        printf("%u\n", query(htonl(a)));
        fflush(stdout);
    }
}
#endif
Compilation | N/A | N/A | Compile OK | Score: N/A | 显示更多 |
Testcase #1 | 13.03 us | 28 KB | Accepted | Score: 25 | 显示更多 |
Testcase #2 | 57.759 ms | 55 MB + 396 KB | Accepted | Score: 25 | 显示更多 |
Testcase #3 | 268.617 ms | 55 MB + 396 KB | Accepted | Score: 25 | 显示更多 |
Testcase #4 | 478.717 ms | 55 MB + 396 KB | Accepted | Score: 25 | 显示更多 |