#include "router.h"
#include <stdint.h>
#include <stdlib.h>
#include <netinet/in.h>
#include <vector>
// Mask selecting the most significant bit of a 32-bit value, i.e. bit position 0
// in the MSB-first bit numbering used by the trie code below.
#define MASK_LEFT (0x80000000u)
// NOTE(review): not referenced anywhere in this file; presumably declared in
// router.h or used by another translation unit — confirm before removing.
std::vector<RoutingTableEntry> table;
// One node of a path-compressed binary trie over IPv4 prefixes (host byte order).
//
// A node covers the bit positions [depth, depth + size), numbered MSB-first.
// Bit `depth` is the branch bit that selected this node in its parent's
// `children` array. `depth + size` is the prefix length at which the node ends:
// that is where its route (if any) lives and where its children branch off.
struct TrieNode {
    TrieNode *children[2] = {nullptr, nullptr};  // children[b]: subtree whose first bit is b
    int8_t depth = -1;      // first bit position covered by this node (root: -1)
    int8_t size = 1;        // number of bit positions covered
    bool is_entry = false;  // true if a route of length depth + size ends here
    uint32_t addr = 0;      // an address (host order) whose leading bits spell this node's path
    uint32_t nexthop = 0;   // next hop of the route; meaningful only when is_entry
};

// Trie root. depth = -1 and size = 1 make it end at bit 0, so it holds the
// zero-length (default) route and branches on the first address bit.
TrieNode root;
void initialize_TrieNode(TrieNode *node, uint32_t addr, uint32_t nexthop,
int8_t depth, int8_t size, TrieNode *chlid0, TrieNode *child1, bool isEntry) {
node->addr = addr;
node->nexthop = nexthop;
node->depth = depth;
node->size = size;
node->is_entry = isEntry;
node->children[0] = chlid0;
node->children[1] = child1;
}
/**
* @brief 插入/删除一条路由表表项
* @param insert 如果要插入则为 true ,要删除则为 false
* @param entry 要插入/删除的表项
*
* 插入时如果已经存在一条 addr 和 len 都相同的表项,则替换掉原有的。
* 删除时按照 addr 和 len **精确** 匹配。
*/
void update(bool insert, RoutingTableEntry entry) {
uint32_t addr = ntohl(entry.addr);
int8_t len = entry.len;
uint32_t nexthop = entry.nexthop;
TrieNode *prev_node;
TrieNode *curr_node = &root;
if (insert) {
uint32_t mask = MASK_LEFT;
for (int8_t i = 0; i < len; ++i, mask >>= 1) {
int8_t entry_bit = (addr & mask) >> (31 - i);
if (i == curr_node->depth + curr_node->size) { // reaching the end of a node
prev_node = curr_node;
curr_node = curr_node->children[entry_bit];
if (curr_node == NULL) {
curr_node = (TrieNode *) malloc(sizeof(TrieNode));
initialize_TrieNode(curr_node, addr, nexthop, i, len - i, NULL, NULL, true);
prev_node->children[entry_bit] = curr_node;
return;
}
// otherwise entry_bit == curr_node_bit is guaranteed
} else { // in the middle of a node
int8_t curr_node_bit = (curr_node->addr & mask) >> (31 - i);
if (entry_bit != curr_node_bit) { // split due to mismatching
TrieNode *entry_node = (TrieNode *) malloc(sizeof(TrieNode));
initialize_TrieNode(entry_node, addr, nexthop, i, len - i, NULL, NULL, true);
uint32_t curr_size = i - curr_node->depth;
TrieNode *next_node = (TrieNode *) malloc(sizeof(TrieNode));
initialize_TrieNode(next_node, curr_node->addr, curr_node->nexthop,
i, curr_node->size - curr_size,
curr_node->children[0], curr_node->children[1],
curr_node->is_entry);
curr_node->children[curr_node_bit] = next_node;
curr_node->children[entry_bit] = entry_node;
curr_node->size = curr_size;
curr_node->is_entry = false; // otherwise we should have reached the end of a node
return;
}
}
}
// all matches
if (len == curr_node->depth + curr_node->size) { // reaching the end of a node, unnecessary to split, replace
curr_node->nexthop = nexthop;
curr_node->is_entry = true;
curr_node->addr = addr;
} else { // node splitting: split a longer chain
int8_t next_node_bit = (curr_node->addr & mask) >> (31 - len);
TrieNode *next_node = (TrieNode *) malloc(sizeof(TrieNode));
uint32_t curr_size = len - curr_node->depth;
initialize_TrieNode(next_node, curr_node->addr, curr_node->nexthop,
len, curr_node->size - curr_size,
curr_node->children[0], curr_node->children[1],
curr_node->is_entry);
curr_node->children[next_node_bit] = next_node;
curr_node->children[1 - next_node_bit] = NULL;
curr_node->addr = addr;
curr_node->nexthop = nexthop;
curr_node->size = curr_size;
curr_node->is_entry = true;
}
} else { // deletion
uint32_t mask = MASK_LEFT;
int8_t entry_bit;
int8_t curr_node_bit;
int8_t curr_node_rank;
for (int8_t i = 0; i < len; ++i, mask >>= 1) {
entry_bit = (addr & mask) >> (31 - i);
if (i == curr_node->depth + curr_node->size) { // reaching the end of a node
prev_node = curr_node;
curr_node = curr_node->children[entry_bit];
curr_node_rank = entry_bit;
if (curr_node == NULL) {
return;
}
// otherwise entry_bit == curr_node_bit is guaranteed
} else {
curr_node_bit = (curr_node->addr & mask) >> (31 - i);
if (entry_bit != curr_node_bit) {
return;
}
}
}
if (curr_node->is_entry && addr == curr_node->addr &&
len == curr_node->depth + curr_node->size) { // found, delete
if (curr_node == &root || (curr_node->children[0] != NULL && curr_node->children[1] != NULL)) {
// 2 children or root
curr_node->is_entry = false;
} else {
TrieNode *child_node = curr_node->children[curr_node->children[0] == NULL ? 1 : 0];
if (child_node != NULL) { // 1 child
curr_node->addr = child_node->addr;
curr_node->size += child_node->size;
curr_node->nexthop = child_node->nexthop;
curr_node->is_entry = child_node->is_entry;
curr_node->children[0] = child_node->children[0];
curr_node->children[1] = child_node->children[1];
free(child_node);
} else { // no child
prev_node->children[curr_node_rank] = NULL;
free(curr_node);
if (!prev_node->is_entry && prev_node != &root) {
TrieNode *other = prev_node->children[1 - curr_node_rank];
initialize_TrieNode(prev_node, other->addr, other->nexthop, prev_node->depth,
prev_node->size + other->size, other->children[0], other->children[1],
other->is_entry);
free(other);
}
}
}
/*
if (curr_node == &root) { // len == 0
root.is_entry = false;
return;
}
if (curr_node->children[0]) {
if (curr_node->children[1]) { // 2 children
curr_node->is_entry = false;
} else { // 1 child
TrieNode *child_node = curr_node->children[0];
curr_node->addr = child_node->addr;
curr_node->size += child_node->size;
curr_node->nexthop = child_node->nexthop;
curr_node->is_entry = child_node->is_entry;
curr_node->children[0] = child_node->children[0];
curr_node->children[1] = child_node->children[1];
free(child_node);
}
} else {
if (curr_node->children[1]) { // 1 child
TrieNode *child_node = curr_node->children[0];
curr_node->addr = child_node->addr;
curr_node->size += child_node->size;
curr_node->nexthop = child_node->nexthop;
curr_node->is_entry = child_node->is_entry;
curr_node->children[0] = child_node->children[0];
curr_node->children[1] = child_node->children[1];
free(child_node);
} else { // no child
prev_node->children[curr_node_rank] = NULL;
free(curr_node);
if (!prev_node->is_entry && prev_node != &root) {
TrieNode *other = prev_node->children[1 - curr_node_rank];
initialize_TrieNode(prev_node, other->addr, other->nexthop, prev_node->depth,
prev_node->size + other->size, other->children[0], other->children[1],
other->is_entry);
free(other);
}
}
}
*/
}
}
}
/**
* @brief 进行一次路由表的查询,按照最长前缀匹配原则
* @param addr 需要查询的目标地址,网络字节序
* @param nexthop 如果查询到目标,把表项的 nexthop 写入
* @return 查到则返回 true ,没查到则返回 false
*/
bool prefix_query(uint32_t addr, uint32_t *nexthop) {
uint32_t nexthop_tmp, is_entry_tmp = false;
addr = ntohl(addr);
uint32_t mask = 0x80000000;
TrieNode *prev_node;
TrieNode *curr_node = &root;
for (int8_t i = 0; i < 32; ++i, mask >>= 1) {
uint8_t entry_bit = (addr & mask) >> (31 - i);
if (i == curr_node->depth + curr_node->size) { // reaching the end of a node
if (curr_node->is_entry) {
nexthop_tmp = curr_node->nexthop;
is_entry_tmp = true;
}
curr_node = curr_node->children[entry_bit];
if (curr_node == NULL) {
*nexthop = nexthop_tmp;
return is_entry_tmp;
}
} else {
uint8_t curr_node_bit = (curr_node->addr & mask) >> (31 - i);
if (entry_bit != curr_node_bit) {
*nexthop = nexthop_tmp;
return is_entry_tmp;
}
}
}
*nexthop = curr_node->nexthop;
return curr_node->is_entry;
}
/**
 * @brief Build the routing table by inserting the n entries of `a` in order.
 *
 * Insertion goes through update(), so a later entry with the same prefix
 * replaces an earlier one. (The previous delete-even/reinsert-even pass
 * tripled the work. It could also let an even-index entry override a later
 * odd-index duplicate.)
 *
 * @param n number of entries in a
 * @param q presumably the number of queries that follow; unused here
 * @param a entries to insert
 */
void init(int n, int q, const RoutingTableEntry *a) {
    (void)q;
    for (int i = 0; i < n; ++i) {
        update(true, a[i]);
    }
}
/**
 * @brief Return the next hop for `addr`, or 0 when no route matches.
 * @param addr destination address, passed unchanged to prefix_query (network byte order)
 */
unsigned query(unsigned addr) {
    uint32_t hop = 0;
    const bool hit = prefix_query(addr, &hop);
    return hit ? hop : 0;
}
/*
 * Online-judge results recorded for this implementation:
 * | Compilation | N/A | N/A | Compile OK | Score: N/A |
 * | Testcase #1 | 12.65 us | 24 KB | Accepted | Score: 25 |
 * | Testcase #2 | 141.404 ms | 78 MB + 336 KB | Accepted | Score: 25 |
 * | Testcase #3 | 391.818 ms | 78 MB + 336 KB | Accepted | Score: 25 |
 * | Testcase #4 | 642.99 ms | 78 MB + 336 KB | Accepted | Score: 25 |
 */