#include "router.h"
#include <stdint.h>
#include <stdlib.h>
#include <stdint.h>
#include <arpa/inet.h>
#include <stdio.h>
// One node of the path-compressed binary trie used for longest-prefix lookup.
struct TrieNode {
  uint32_t addr{0};  // network byte order; only the first `len` bits are compared
  uint32_t len{0};   // prefix length in bits (the root has length 0)
  // Children, indexed by the address bit at position `len` of this node.
  TrieNode* ch[2]{nullptr, nullptr};
  // Route stored for exactly this prefix, or nullptr for a pure branching node.
  RoutingTableEntry* entry{nullptr};
};
// Forward declarations of the mask helpers defined later in this file;
// len_to_mask is needed by bPrefixEqual below.
int mask_to_len(uint32_t mask);
uint32_t len_to_mask(int len);
// Root of the routing trie. It represents the empty prefix (len 0); every
// search and insertion starts from here, and it is never deleted.
TrieNode g_lookup_root; // trie root
// Returns bit `pos` of the network-order address `x`, counting from the most
// significant bit (pos 0) down to the least significant one (pos 31).
// `pos` must be in [0, 31].
bool getNetBit(uint32_t x, uint32_t pos) {
  const uint32_t host = ntohl(x);
  return (host >> (31 - pos)) & 1u;
}
// True when the network-order addresses x and y agree on their first `len`
// bits. A length of 0 always matches.
bool bPrefixEqual(uint32_t x, uint32_t y, uint32_t len) {
  const uint32_t differing = x ^ y;
  return (differing & len_to_mask(len)) == 0;
}
// Length of the longest common prefix of two network-order addresses, in bits
// (0..32). Identical addresses yield 32.
uint32_t getCommonLength(uint32_t x, uint32_t y) {
  const uint32_t diff = ntohl(x) ^ ntohl(y);
  uint32_t len = 0;
  // Count leading zero bits of the XOR. The `bit != 0` bound stops the loop
  // after 32 steps; without it, equal inputs (diff == 0) would never terminate.
  for (uint32_t bit = uint32_t(1) << 31; bit != 0 && (diff & bit) == 0;
       bit >>= 1) {
    ++len;
  }
  return len;
}
// Descends from the root toward the deepest node whose prefix is a prefix of
// `addr` and whose length does not exceed `len`, and returns that node.
// If both `stack` and `stack_size` are given, every visited node (root first)
// is recorded in `stack` and `*stack_size` is set to the number recorded.
// The caller's buffer must hold the whole path (root plus one node per length).
TrieNode* searchPrefix(uint32_t addr, uint32_t len, TrieNode** stack = nullptr,
                       uint32_t* stack_size = nullptr) {
  const bool record = stack && stack_size;
  if (stack_size) *stack_size = 0;
  TrieNode* cur = &g_lookup_root;
  for (;;) {
    if (record) stack[(*stack_size)++] = cur;
    // Try child 0, then child 1; take the first one that still matches addr.
    TrieNode* next = nullptr;
    for (int b = 0; b < 2 && next == nullptr; ++b) {
      TrieNode* c = cur->ch[b];
      if (c && c->len <= len && bPrefixEqual(c->addr, addr, c->len)) next = c;
    }
    if (next == nullptr) return cur;
    cur = next;
  }
}
// Returns the trie node for exactly (addr, len), creating it if necessary.
// A newly created node carries no entry. When the new prefix diverges from an
// existing child partway, an entry-less branching node is inserted at the
// common prefix length.
TrieNode* insertPrefix(uint32_t addr, uint32_t len) {
  TrieNode* parent = searchPrefix(addr, len);
  if (parent->len == len) return parent;  // node for this exact prefix exists

  // Here parent->len < len.
  TrieNode* fresh = new TrieNode;
  fresh->addr = addr;
  fresh->len = len;

  // Child slot of `parent` selected by addr's bit at parent->len.
  TrieNode*& slot = parent->ch[getNetBit(addr, parent->len)];
  TrieNode* old = slot;

  if (old == nullptr) {  // empty slot: just hang the new node there
    slot = fresh;
    return fresh;
  }

  if (bPrefixEqual(old->addr, addr, len)) {
    // The new prefix is a prefix of the existing child: put it in between.
    fresh->ch[getNetBit(old->addr, len)] = old;
    slot = fresh;
    return fresh;
  }

  // Prefixes diverge: add a branching node at their common length.
  const uint32_t split = getCommonLength(old->addr, addr);
  TrieNode* fork = new TrieNode;
  fork->addr = addr;
  fork->len = split;
  fork->ch[getNetBit(old->addr, split)] = old;
  fork->ch[getNetBit(addr, split)] = fresh;
  slot = fork;
  return fresh;
}
/**
* @brief 插入/删除一条路由表表项
* @param insert 如果要插入则为 true ,要删除则为 false
* @param entry 要插入/删除的表项
*
* 插入时如果已经存在一条 addr 和 len 都相同的表项,则替换掉原有的。
* 删除时按照 addr 和 len **精确** 匹配。
*/
void update(bool insert, RoutingTableEntry entry) {
if (insert) {
TrieNode* node = insertPrefix(entry.addr, entry.len);
node->entry = new RoutingTableEntry;
*node->entry = entry;
} else {
TrieNode* node;
TrieNode* stack[34];
uint32_t stack_size;
node = searchPrefix(entry.addr, entry.len, stack, &stack_size);
if (node->len == entry.len) {
delete node->entry;
node->entry = nullptr;
} else {
// no match
return;
}
// compress path
while (--stack_size > 0) {
node = stack[stack_size];
if (node->entry) break;
if (node->ch[0] && node->ch[1]) break;
TrieNode* pa = stack[stack_size - 1];
if (node->ch[0]) {
pa->ch[node == pa->ch[1]] = node->ch[0];
node->ch[0] = nullptr;
} else if (node->ch[1]) {
pa->ch[node == pa->ch[1]] = node->ch[1];
node->ch[1] = nullptr;
} else {
pa->ch[node == pa->ch[1]] = nullptr;
}
delete node;
node = nullptr;
}
}
}
/**
 * @brief Convert a subnet mask to a prefix length.
 * @param mask the mask to convert, in network byte order
 * @return the prefix length if the mask is valid, or -1 if it is not
 *
 * A valid mask is a run of one-bits followed only by zero-bits.
 */
int mask_to_len(uint32_t mask) {
  const uint32_t host = ntohl(mask);
  // For a valid mask, ~host is 2^k - 1 (all ones at the bottom), so adding 1
  // clears every set bit; any leftover overlap means a hole in the ones.
  const uint32_t inverted = ~host;
  if ((inverted & (inverted + 1)) != 0) return -1;
  int ones = 0;
  for (uint32_t m = host; m != 0; m <<= 1) ++ones;
  return ones;
}
/**
 * @brief Convert a prefix length to a subnet mask; valid lengths are [0, 32].
 * @param len the prefix length to convert
 * @return the corresponding mask in network byte order, or 0 if len is invalid
 *         (a length of 0 also yields the all-zero mask)
 */
uint32_t len_to_mask(int len) {
  if (len < 1 || len > 32) return 0;
  // Shifting a 32-bit value by 32 is undefined, so the full mask is special-cased.
  const uint32_t host = (len == 32) ? ~uint32_t(0) : ~uint32_t(0) << (32 - len);
  return htonl(host);
}
/**
 * @brief Populate the routing table by inserting a[0..n-1] in order.
 * @param n number of entries in a
 * @param q not used by this implementation
 * @param a entries to insert; a later entry with the same addr/len replaces
 *          an earlier one
 */
void init(int n, int q, const RoutingTableEntry* a) {
  (void)q;  // unused here
  int i = 0;
  while (i < n) update(true, a[i++]);
}
unsigned query(unsigned addr) {
TrieNode* stack[34];
uint32_t stack_size;
searchPrefix(addr, 32, stack, &stack_size);
while (stack_size > 0) {
--stack_size;
if (stack[stack_size]->entry) {
return stack[stack_size]->entry->nexthop;
}
}
return 0;
}
/*
 * Online-judge results for this implementation (pasted output, not code):
 * Compilation | N/A | N/A | Compile OK | Score: N/A | 显示更多 |
 * Testcase #1 | 12.23 us | 24 KB | Accepted | Score: 25 | 显示更多 |
 * Testcase #2 | 76.519 ms | 103 MB + 584 KB | Accepted | Score: 25 | 显示更多 |
 * Testcase #3 | 393.723 ms | 103 MB + 584 KB | Accepted | Score: 25 | 显示更多 |
 * Testcase #4 | 708.943 ms | 103 MB + 584 KB | Accepted | Score: 25 | 显示更多 |
 */