#include "router.h"
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
#include <time.h>
// Smaller of two values. Note that one of the two arguments is evaluated twice.
// NOTE(review): names beginning with "__" are reserved for the implementation,
// and some toolchains (e.g. MSVC's <stdlib.h>) already define a macro with
// this name; consider renaming (all uses in this file would need to change).
#define __min(a,b) (((a) < (b)) ? (a) : (b))
/*
RoutingTableEntry is defined as follows (presumably in router.h):
typedef struct {
unsigned addr;
unsigned char len;
unsigned nexthop;
} __attribute__((packed)) RoutingTableEntry;
*/
/**
 * @brief Reverses the byte order of a 32-bit value (0xAABBCCDD -> 0xDDCCBBAA).
 *
 * The routing code applies this to incoming addresses so that the first
 * address byte ends up in the most-significant bits, where prefix
 * comparisons look.
 */
uint32_t endian_revert(uint32_t addr)
{
    const uint32_t byte0 = addr & 0xffu;          // lowest byte
    const uint32_t byte1 = (addr >> 8) & 0xffu;
    const uint32_t byte2 = (addr >> 16) & 0xffu;
    const uint32_t byte3 = (addr >> 24) & 0xffu;  // highest byte
    return (byte0 << 24) | (byte1 << 16) | (byte2 << 8) | byte3;
}
/**
 * @brief One (address, next hop) record stored in a trie node.
 *
 * Records that belong to the same node form a singly linked list through
 * `next_entry`.
 */
struct RoutingNodeEntry
{
    uint32_t addr_lit;             // address after endian_revert (prefix in the high bits)
    uint32_t nexthop;              // next hop, stored exactly as supplied
    RoutingNodeEntry* next_entry;  // following record of the same node, or nullptr

    RoutingNodeEntry(uint32_t addr, uint32_t hop, RoutingNodeEntry* following)
        : addr_lit{addr}, nexthop{hop}, next_entry{following}
    {
    }
};
struct RoutingNode
{
RoutingNodeEntry* first_entry;
RoutingNode* lc;
RoutingNode* rc;
RoutingNode* parent;
uint8_t len;
bool is_fake;
RoutingNode(RoutingTableEntry entry, RoutingNode* lc, RoutingNode* rc, RoutingNode* parent) :
first_entry(new RoutingNodeEntry(endian_revert(entry.addr), entry.nexthop, nullptr)),
lc(lc), rc(rc), parent(parent),
len((uint8_t)entry.len), is_fake(false)
{ }
RoutingNode(uint32_t addr_lit, uint8_t len, RoutingNode* lc, RoutingNode* rc, RoutingNode* parent) :
first_entry(new RoutingNodeEntry(addr_lit, 0, nullptr)),
lc(lc), rc(rc), parent(parent),
len(len), is_fake(true)
{ }
~RoutingNode()
{
RoutingNodeEntry* prev_entry = nullptr;
RoutingNodeEntry* cur_entry = first_entry;
while (cur_entry != nullptr) {
prev_entry = cur_entry;
cur_entry = cur_entry->next_entry;
delete prev_entry;
}
}
void show()
{
printf("<fake: %d; len: %d; [", is_fake, len);
RoutingNodeEntry* entry = first_entry;
while (entry) {
if (entry != first_entry) printf(", ");
printf("(0x%.8x, 0x%.8x)",
entry->addr_lit, entry->nexthop);
entry = entry->next_entry;
}
printf("] >\n");
}
};
/**
 * @brief Longest-prefix-match routing table backed by a binary prefix trie.
 *
 * The root is a length-0 node that starts out fake, i.e. there is no default
 * route until one with len == 0 is inserted. The table owns every node
 * reachable from the root and frees them all in its destructor.
 */
class RoutingTable
{
    RoutingNode* ROOT;
public:
    RoutingTable() :
        ROOT(new RoutingNode(0, 0, nullptr, nullptr, nullptr))
    { }
    ~RoutingTable() { save_delete_node(ROOT); }

    // The table owns the whole trie through ROOT; the implicit shallow copy
    // would free every node twice.
    RoutingTable(const RoutingTable&) = delete;
    RoutingTable& operator=(const RoutingTable&) = delete;

    /// Adds `entry` (entry.addr is byte-reversed with endian_revert internally).
    void insert(RoutingTableEntry entry);
    /// Removes the entry with the same prefix length and the same (byte-reversed)
    /// address as `entry`; returns false if there is none.
    bool remove(RoutingTableEntry entry);
    /// Longest-prefix match for `addr_big`. On success, stores the next hop of
    /// the deepest matching real node's first entry and returns true.
    /// Otherwise it stores 0 and returns false.
    bool query(uint32_t addr_big, uint32_t* nexthop);

    /// Debug dump of the subtree rooted at `node` (pre-order, lc before rc).
    void show_node(RoutingNode* node) {
        if (node == nullptr) printf("null pointer!\n");
        else {
            node->show();
            if (node->lc == nullptr && node->rc == nullptr) printf("leaf\n");
            else {
                printf("lc: ");
                if (node->lc) {
                    show_node(node->lc);
                }
                else printf("null pointer!\n");
                printf("rc: ");
                if (node->rc) {
                    show_node(node->rc);
                }
                else printf("null pointer!\n");
            }
        }
    }
    RoutingNode* get_root() { return ROOT; }
private:
    void insert_same_entry(RoutingNode* node, RoutingTableEntry entry);
    void replace_fake_node(RoutingNode* node, RoutingTableEntry entry);
    bool remove_entry(RoutingNode* node, uint32_t addr_tar);
    void remove_node(RoutingNode* node);
    void get_node_data(RoutingNode* node, uint32_t* nexthop);
    void save_delete_node(RoutingNode* node);
};
/**
 * @brief Inserts a route into the trie.
 *
 * Walks down from the root; `cur_node` always has a shorter prefix that
 * covers the new one. Depending on how the new prefix relates to
 * `cur_node`'s children, the route is attached in one of these ways:
 * - as a new leaf child;
 * - as a sibling of a child (possibly through a new fake branch node);
 * - as the new parent of one or both children;
 * - or the walk descends into the child covering it.
 * When a node with the same prefix length is reached, the entry is added
 * to that node.
 *
 * @param entry route to insert; entry.addr is byte-reversed with
 *              endian_revert before any prefix comparison.
 */
void RoutingTable::insert(RoutingTableEntry entry)
{
    RoutingNode* cur_node = this->ROOT;
    uint32_t addr_tar = endian_revert(entry.addr);
    uint8_t len_tar = entry.len;
    while (true) {
        /* cur_node is always an ancestor of the prefix being inserted */
        if (cur_node->len == len_tar) { // same prefix length as cur_node
            insert_same_entry(cur_node, entry);
            break;
        }
        assert(cur_node->len < len_tar);
        if (cur_node->lc == nullptr) { // cur_node is a leaf
            cur_node->lc = new RoutingNode(entry, nullptr, nullptr, cur_node);
            break;
        }
        uint32_t addr_lc = cur_node->lc->first_entry->addr_lit;
        uint8_t len_lc = cur_node->lc->len;
        if (cur_node->rc == nullptr) { // cur_node has exactly one child
            uint32_t res = addr_tar ^ addr_lc;
            res >>= (32 - __min(len_tar, len_lc));
            if (res != 0) { // disjoint from the child: becomes its sibling
                cur_node->rc = new RoutingNode(entry, nullptr, nullptr, cur_node);
                break;
            }
            else if (len_tar < len_lc) { // covers the child: becomes its parent
                cur_node->lc->parent = new RoutingNode(entry, cur_node->lc, nullptr, cur_node);
                cur_node->lc = cur_node->lc->parent;
                break;
            }
            else { // the child covers the new prefix: descend into it
                cur_node = cur_node->lc;
                continue;
            }
        }
        else { // cur_node has two children
            uint32_t addr_rc = cur_node->rc->first_entry->addr_lit;
            uint8_t len_rc = cur_node->rc->len;
            // Masked XORs over the leading bits each pair can be compared on;
            // they are compared below to decide which pair shares the longest prefix.
            uint32_t res_lc = addr_tar ^ addr_lc;
            uint32_t res_rc = addr_tar ^ addr_rc;
            uint32_t res_branch = addr_lc ^ addr_rc;
            uint32_t mask_lc = 0xffffffff << (32 - __min(len_tar, len_lc));
            uint32_t mask_rc = 0xffffffff << (32 - __min(len_tar, len_rc));
            uint32_t mask_branch = 0xffffffff << (32 - __min(len_lc, len_rc));
            res_lc &= mask_lc;
            res_rc &= mask_rc;
            res_branch &= mask_branch; // expected to be non-zero: siblings are disjoint
            if (res_branch < res_lc && res_branch < res_rc) {
                // The two children are closer to each other than to the new
                // prefix. Group them under a fake branch node and make the new
                // route that node's sibling (the children's "uncle").
                /* Find the branch point as close to the children as possible:
                   start at min(len_lc, len_rc) and move up. */
                uint8_t len_min = __min(len_lc, len_rc);
                res_branch >>= (32 - len_min);
                uint8_t branch_bit = len_min;
                do {
                    branch_bit -= 1;
                    res_branch >>= 1;
                } while (res_branch != 0);
                RoutingNode* fake_node = new RoutingNode(addr_lc, branch_bit, cur_node->lc, cur_node->rc, cur_node);
                // Re-parent the moved children. Otherwise their parent pointers
                // keep naming cur_node, and a later remove_node() would unlink or
                // delete the wrong nodes.
                fake_node->lc->parent = fake_node;
                fake_node->rc->parent = fake_node;
                cur_node->lc = fake_node;
                cur_node->rc = new RoutingNode(entry, nullptr, nullptr, cur_node);
                break;
            }
            else if (res_lc < res_rc) { // closer to the left child
                uint8_t len_min = __min(len_tar, len_lc);
                res_lc >>= (32 - len_min); // keep the top len_min bits of res_lc, moved to the low end
                if (res_lc != 0) { // sibling of the left child
                    /* Find the new branch point, again as close to the children as possible. */
                    uint8_t branch_bit = len_min;
                    do {
                        branch_bit -= 1;
                        res_lc >>= 1;
                    } while (res_lc != 0);
                    RoutingNode* new_node = new RoutingNode(entry, nullptr, nullptr, nullptr);
                    RoutingNode* fake_node = new RoutingNode(addr_lc, branch_bit, cur_node->lc, new_node, cur_node);
                    cur_node->lc->parent = fake_node;
                    new_node->parent = fake_node;
                    cur_node->lc = fake_node;
                    break;
                }
                else if (len_tar < len_lc) { // parent of the left child
                    if (cur_node->lc->is_fake) {
                        replace_fake_node(cur_node->lc, entry);
                    }
                    else {
                        cur_node->lc->parent = new RoutingNode(entry, cur_node->lc, nullptr, cur_node);
                        cur_node->lc = cur_node->lc->parent;
                    }
                    break;
                }
                else { // descendant of the left child
                    cur_node = cur_node->lc;
                    continue;
                }
            }
            else if (res_rc < res_lc) { // closer to the right child
                uint8_t len_min = __min(len_tar, len_rc);
                res_rc >>= (32 - len_min);
                if (res_rc != 0) { // sibling of the right child
                    uint8_t branch_bit = len_min;
                    do {
                        branch_bit -= 1;
                        res_rc >>= 1;
                    } while (res_rc != 0);
                    RoutingNode* new_node = new RoutingNode(entry, nullptr, nullptr, nullptr);
                    RoutingNode* fake_node = new RoutingNode(addr_rc, branch_bit, cur_node->rc, new_node, cur_node);
                    cur_node->rc->parent = fake_node;
                    new_node->parent = fake_node;
                    cur_node->rc = fake_node;
                    break;
                }
                else if (len_tar < len_rc) { // parent of the right child
                    if (cur_node->rc->is_fake) {
                        replace_fake_node(cur_node->rc, entry);
                    }
                    else {
                        cur_node->rc->parent = new RoutingNode(entry, cur_node->rc, nullptr, cur_node);
                        cur_node->rc = cur_node->rc->parent;
                    }
                    break;
                }
                else { // descendant of the right child
                    cur_node = cur_node->rc;
                    continue;
                }
            }
            else { // res_lc == res_rc (expected: both 0): parent of both children
                RoutingNode* new_node = new RoutingNode(entry, cur_node->lc, cur_node->rc, cur_node);
                cur_node->lc = new_node;
                cur_node->rc = nullptr;
                new_node->lc->parent = new_node;
                new_node->rc->parent = new_node;
                break;
            }
        }
    }
}
/**
 * @brief Adds `entry` to `node`, whose prefix length must equal entry.len.
 *
 * A fake node is converted into a real node holding `entry`. A real node gets
 * `entry` pushed onto the front of its entry list.
 */
void RoutingTable::insert_same_entry(RoutingNode* node, RoutingTableEntry entry)
{
    assert(node->len == (uint8_t)entry.len);
    if (!node->is_fake) {
        const uint32_t reversed_addr = endian_revert(entry.addr);
        node->first_entry = new RoutingNodeEntry(reversed_addr, entry.nexthop, node->first_entry);
        return;
    }
    replace_fake_node(node, entry);
}
/**
 * @brief Turns fake `node` into a real node carrying `entry`.
 *
 * The node's existing single entry record is reused: its address and next
 * hop are overwritten, and the node takes entry.len as its prefix length.
 */
void RoutingTable::replace_fake_node(RoutingNode* node, RoutingTableEntry entry)
{
    assert(node->is_fake);
    RoutingNodeEntry* record = node->first_entry;
    record->addr_lit = endian_revert(entry.addr);
    record->nexthop = entry.nexthop;
    node->len = static_cast<uint8_t>(entry.len);
    node->is_fake = false;
}
/**
 * @brief Removes the entry matching `entry`'s prefix length and address.
 *
 * Descends from the root through children whose prefix covers the target
 * until it reaches the node of length entry.len.
 *
 * @return false if no such node exists, if that node is fake, or if it holds
 *         no entry with the same byte-reversed address; true otherwise.
 */
bool RoutingTable::remove(RoutingTableEntry entry)
{
    const uint32_t addr_tar = endian_revert(entry.addr);
    const uint8_t len_tar = entry.len;

    // A child lies on the path to the target when its prefix is no longer
    // than the target's and agrees with the target on its own `len` leading bits.
    auto on_path = [&](const RoutingNode* child) -> bool {
        return child != nullptr
            && child->len <= len_tar
            && ((child->first_entry->addr_lit ^ addr_tar) >> (32 - child->len)) == 0;
    };

    RoutingNode* node = this->ROOT;
    while (node->len != len_tar) {
        assert(node->len < len_tar);
        if (on_path(node->lc)) {
            node = node->lc;
        }
        else if (on_path(node->rc)) {
            node = node->rc;
        }
        else {
            return false;
        }
    }
    return !node->is_fake && remove_entry(node, addr_tar);
}
/**
 * @brief Deletes the first entry of `node` whose address equals `addr_tar`.
 *
 * If that entry is the node's only entry, the node itself is handed to
 * remove_node() instead.
 *
 * @return true if a matching entry was found and removed, false otherwise.
 */
bool RoutingTable::remove_entry(RoutingNode* node, uint32_t addr_tar)
{
    RoutingNodeEntry* before = nullptr;
    RoutingNodeEntry* victim = node->first_entry;
    for (; victim != nullptr; victim = victim->next_entry) {
        if (victim->addr_lit == addr_tar) break;
        before = victim;
    }
    if (victim == nullptr) return false;

    if (before != nullptr) {
        // Not the head: unlink from its predecessor.
        before->next_entry = victim->next_entry;
        delete victim;
    }
    else if (victim->next_entry != nullptr) {
        // Head, but more entries follow: advance the head.
        node->first_entry = victim->next_entry;
        delete victim;
    }
    else {
        // Head and only entry: the whole node goes away.
        remove_node(node);
    }
    return true;
}
/**
 * @brief Restructures the trie after `node` lost its last real entry.
 *
 * Precondition: `node` is real (not fake) and its entry list holds exactly
 * the one entry being removed (see remove_entry).
 *
 * - Leaf (non-root): the node is unlinked and deleted. If the parent is a
 *   non-root fake node, the parent is replaced by the sibling. If instead
 *   the sibling is fake, the sibling's two children move up into the parent.
 * - Exactly one child (non-root): the child is spliced into node's place.
 * - Two children, or the root: the node stays and becomes a fake branch node.
 */
void RoutingTable::remove_node(RoutingNode* node)
{
    /* node must not be fake */
    RoutingNode* parent_node = node->parent;
    assert(!node->is_fake);
    if (node->lc == nullptr && node != ROOT) { // leaf (children are always filled lc first)
        RoutingNode* bro_node = (parent_node->lc == node) ? parent_node->rc : parent_node->lc;
        if (parent_node->is_fake && parent_node != ROOT) { // a non-root fake node always has two children
            // The fake parent would be left with one child: replace it by the sibling.
            assert(parent_node->rc != nullptr);
            assert(parent_node->lc != nullptr);
            RoutingNode* grandpa_node = parent_node->parent;
            assert(grandpa_node);
            if (grandpa_node->lc == parent_node) {
                grandpa_node->lc = bro_node;
            }
            else {
                assert(grandpa_node->rc == parent_node);
                grandpa_node->rc = bro_node;
            }
            bro_node->parent = grandpa_node;
            delete parent_node;
            delete node;
        }
        else if (bro_node != nullptr && bro_node->is_fake) {
            // The fake sibling is no longer needed as a branch point: its two
            // children become the parent's children.
            assert(bro_node->rc != nullptr);
            assert(bro_node->lc != nullptr);
            parent_node->lc = bro_node->lc;
            parent_node->rc = bro_node->rc;
            bro_node->lc->parent = parent_node;
            bro_node->rc->parent = parent_node;
            delete bro_node;
            delete node;
        }
        else {
            // Plain unlink; a remaining child is kept in lc.
            if (parent_node->lc == node) {
                parent_node->lc = parent_node->rc;
                parent_node->rc = nullptr;
            }
            else {
                assert(parent_node->rc == node);
                parent_node->rc = nullptr;
            }
            delete node;
        }
    }
    else if (node->rc == nullptr && node != ROOT) { // exactly one child: splice it into node's place
        // The child may be fake. A two-child node that loses its last entry is
        // turned fake in place (last branch below) even when its own parent has
        // no other child. Splicing a fake child upward keeps the trie valid, so
        // only the child's existence is asserted.
        assert(node->lc != nullptr);
        if (parent_node->lc == node) {
            parent_node->lc = node->lc;
        }
        else {
            assert(parent_node->rc == node);
            parent_node->rc = node->lc;
        }
        node->lc->parent = parent_node;
        delete node;
    }
    else { // two children, or the root: node becomes fake
        if (node != ROOT) {
            /* Branch point of the two children, as close to them as possible */
            uint32_t addr_lc = node->lc->first_entry->addr_lit;
            uint32_t addr_rc = node->rc->first_entry->addr_lit;
            uint8_t len_lc = node->lc->len;
            uint8_t len_rc = node->rc->len;
            uint8_t branch_bit = __min(len_lc, len_rc);
            uint32_t res_branch = (addr_lc ^ addr_rc) >> (32 - branch_bit);
            assert(res_branch != 0);
            do {
                branch_bit -= 1;
                res_branch >>= 1;
            } while (res_branch != 0);
            /* Reuse the single remaining entry record as the fake node's address */
            assert(node->first_entry->next_entry == nullptr);
            node->first_entry->addr_lit = addr_lc;
            assert(node->len <= branch_bit);
            node->len = branch_bit;
        }
        else {
            assert(node->len == 0);
        }
        node->is_fake = true;
    }
}
bool RoutingTable::query(uint32_t addr_big, uint32_t* nexthop)
{
RoutingNode* cur_node = ROOT;
RoutingNode* target_node = nullptr;
uint32_t addr_tar = endian_revert(addr_big);
while (true) {
if (!cur_node->is_fake) {
target_node = cur_node;
}
if (cur_node->lc) {
uint32_t addr_lc = cur_node->lc->first_entry->addr_lit;
uint8_t len_lc = cur_node->lc->len;
if ((addr_lc ^ addr_tar) >> (32 - len_lc) == 0) {
cur_node = cur_node->lc;
continue;
}
}
if (cur_node->rc) {
uint32_t addr_rc = cur_node->rc->first_entry->addr_lit;
uint8_t len_rc = cur_node->rc->len;
if ((addr_rc ^ addr_tar) >> (32 - len_rc) == 0) {
cur_node = cur_node->rc;
continue;
}
}
break;
}
if (target_node == nullptr) {
*nexthop = 0;
return false;
}
get_node_data(target_node, nexthop);
return true;
}
/// Copies the next hop of `node`'s first (most recently added) entry into *nexthop.
void RoutingTable::get_node_data(RoutingNode* node, uint32_t* nexthop)
{
    const RoutingNodeEntry* head = node->first_entry;
    *nexthop = head->nexthop;
}
/// Recursively frees the subtree rooted at `node`: left subtree, then right
/// subtree, then the node itself.
void RoutingTable::save_delete_node(RoutingNode* node)
{
    RoutingNode* const children[2] = { node->lc, node->rc };
    for (RoutingNode* child : children) {
        if (child != nullptr) {
            save_delete_node(child);
        }
    }
    delete node;
}
// The single global table used by init(), query() and the debug helpers below.
RoutingTable routing_table;
/**
 * @brief Shuffles a[0..len) in place using rand() (Fisher–Yates).
 *
 * The earlier version swapped every a[i] with a[rand() % len]. That produces
 * len^len equally likely swap sequences over len! permutations, so some
 * orders were more likely than others. Fisher–Yates draws the element for
 * position i only from the not-yet-fixed prefix a[0..i]. (rand() % k still
 * carries a small modulo bias for large k.)
 *
 * @param a   array to shuffle
 * @param len number of elements; len <= 1 leaves the array untouched
 */
void shuffle(int* a, int len)
{
    for (int i = len - 1; i > 0; i--) {
        int j = rand() % (i + 1);
        int swap = a[i];
        a[i] = a[j];
        a[j] = swap;
    }
}
void remove_en(int j, const RoutingTableEntry* a)
{
printf("\nremove <addr: 0x%.8x, len: %d, nexthop: 0x%.8x>\n", endian_revert(a[j].addr), a[j].len, a[j].nexthop);
bool res = routing_table.remove(a[j]);
printf("remove success %d\n", res);
routing_table.show_node(routing_table.get_root());
printf("\n");
}
/**
 * @brief Self-test: removes a[left_bound..right_bound) from the global table,
 *        then inserts the same entries again in the same order.
 *
 * Each removal is expected to succeed. This is checked with assert, so only in
 * debug builds. A negative left_bound is clamped to 0 so the loop never reads
 * before the start of `a`.
 * To debug removal problems, shuffle() and remove_en() can be used to remove
 * entries in random order while printing the trie after each step.
 */
void remove_insert(int left_bound, int right_bound, const RoutingTableEntry* a)
{
    if (left_bound < 0) left_bound = 0;
    for (int j = left_bound; j < right_bound; j++) {
        bool res = routing_table.remove(a[j]);
        assert(res);
        (void)res; // unused when NDEBUG disables assert
    }
    for (int j = left_bound; j < right_bound; j++) {
        routing_table.insert(a[j]);
    }
}
/**
 * @brief Fills the global routing table with the n entries in a[0..n).
 *
 * When n > 5, it then removes and re-inserts the last (at most) 10 entries
 * as a self-check of remove().
 *
 * @param n number of entries in `a`
 * @param q unused here (NOTE(review): presumably the number of queries that follow — confirm)
 * @param a entries to insert
 */
void init(int n, int q, const RoutingTableEntry *a) {
    (void)q;
    for (int i = 0; i < n; i++) {
        routing_table.insert(a[i]);
    }
    if (n > 5) {
        // For 5 < n < 10 the old bound n - 10 was negative and indexed before
        // the start of `a`; clamp it to the first entry.
        int left_bound = (n > 10) ? n - 10 : 0;
        remove_insert(left_bound, n, a);
    }
}
/**
 * @brief Looks up `addr` in the global routing table.
 * @return the next hop of the longest matching prefix; RoutingTable::query
 *         stores 0 when nothing matches.
 */
unsigned query(unsigned addr) {
    unsigned hop = 0;
    routing_table.query(addr, &hop);
    return hop;
}
/*
 * Online-judge results for this submission:
 * Compilation | N/A        | N/A             | Compile OK | Score: N/A
 * Testcase #1 | 12.51 us   | 24 KB           | Accepted   | Score: 25
 * Testcase #2 | 147.936 ms | 123 MB + 780 KB | Accepted   | Score: 25
 * Testcase #3 | 494.299 ms | 123 MB + 780 KB | Accepted   | Score: 25
 * Testcase #4 | 840.893 ms | 123 MB + 780 KB | Accepted   | Score: 25
 */