2025-09-30 11:31:05 +08:00

128 lines
3.4 KiB
Rust
Executable File

use std::{net::Ipv4Addr, sync::RwLock};
use once_cell::sync::OnceCell;
use sdlan_sn_rs::utils::net_bit_len_to_mask;
use tracing::{debug, error};
/// Thread-safe table of static IPv4 routes.
///
/// Routes are stored in insertion order behind an `RwLock`, so lookups take a
/// shared lock while additions and removals take an exclusive one.
#[derive(Debug)]
pub struct RouteTable {
    /// Configured routes, in the order they were added.
    content: RwLock<Vec<RouteInfo>>,
}

/// Process-wide route table; populated once by [`init_route`].
static ROUTETABLE: OnceCell<RouteTable> = OnceCell::new();

/// Initializes the process-wide route table with an empty [`RouteTable`].
///
/// # Panics
///
/// Panics if called more than once.
pub fn init_route() {
    let rt = RouteTable::new();
    ROUTETABLE
        .set(rt)
        .expect("init_route called more than once; route table already initialized");
}

/// Returns the process-wide route table.
///
/// # Panics
///
/// Panics if [`init_route`] has not been called yet.
pub fn get_route_table() -> &'static RouteTable {
    ROUTETABLE
        .get()
        .expect("route table not initialized; call init_route first")
}
impl RouteTable {
    /// Creates an empty route table.
    pub fn new() -> Self {
        Self {
            content: RwLock::new(Vec::new()),
        }
    }

    /// Returns the gateway of the first route whose network contains `net_ip`.
    ///
    /// A route matches when `net_ip` masked with the route's mask equals the
    /// route's own masked network address. Routes are checked in insertion
    /// order and the first match wins; this is NOT a longest-prefix match.
    /// Returns `None` if no route matches.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn get_gateway_ip(&self, net_ip: u32) -> Option<u32> {
        let routes = self.content.read().unwrap();
        routes
            .iter()
            .inspect(|route| debug!("route: {:?}", route.to_string()))
            .find(|route| (route.net_ip & route.net_mask) == (net_ip & route.net_mask))
            .map(|route| route.gateway_ip)
    }

    /// Removes the first route whose `net_ip` and `net_mask` both equal the
    /// given values exactly. Does nothing if no such route exists.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn del_route(&self, net_ip: u32, net_mask: u32) {
        let mut routes = self.content.write().unwrap();
        if let Some(idx) = routes
            .iter()
            .position(|route| route.net_ip == net_ip && route.net_mask == net_mask)
        {
            routes.remove(idx);
        }
    }

    /// Appends a route to `net_ip`/`net_mask` via `gateway_ip`.
    ///
    /// # Errors
    ///
    /// Returns `Err("route exists")` if an existing route's masked network
    /// address equals `net_ip & net_mask`.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn add_route(&self, net_ip: u32, net_mask: u32, gateway_ip: u32) -> Result<(), String> {
        // Check and insert under a single write guard: dropping a read lock
        // between the check and the push would let a concurrent caller pass
        // the same check and insert a duplicate network.
        let mut routes = self.content.write().unwrap();
        let net = net_ip & net_mask;
        if routes
            .iter()
            .any(|route| (route.net_ip & route.net_mask) == net)
        {
            return Err("route exists".to_owned());
        }
        routes.push(RouteInfo {
            net_ip,
            net_mask,
            gateway_ip,
        });
        Ok(())
    }
}
/// A single static route: traffic for `net_ip`/`net_mask` goes via `gateway_ip`.
///
/// Addresses are IPv4 addresses packed into a `u32` whose most significant
/// byte is the first octet (they are rendered with `to_be_bytes`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RouteInfo {
    /// Network address of the route.
    pub net_ip: u32,
    /// Network mask applied to both `net_ip` and looked-up addresses.
    pub net_mask: u32,
    /// Next-hop gateway address.
    pub gateway_ip: u32,
}

/// Formats as `"[a, b, c, d] mask=[a, b, c, d], gateway=[a, b, c, d]"`.
///
/// Implementing `Display` (rather than an inherent `to_string`) keeps
/// `route.to_string()` working via the standard `ToString` blanket impl and
/// also allows `format!("{route}")`.
impl std::fmt::Display for RouteInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{:?} mask={:?}, gateway={:?}",
            self.net_ip.to_be_bytes(),
            self.net_mask.to_be_bytes(),
            self.gateway_ip.to_be_bytes()
        )
    }
}
// ip, mask, gateway, cidr;gateway,cidr2;gateway2
pub fn parse_route(route: String) -> Vec<(u32, u32, u32)> {
let mut result = Vec::new();
let routes: Vec<_> = route.split(",").collect();
for route in routes {
let route_info: Vec<_> = route.split(";").collect();
debug!("got route info: {:?}", route_info);
if route_info.len() != 2 {
error!("route info format error");
continue;
}
let cidr = route_info[0];
let gateway = route_info[1].parse::<Ipv4Addr>().unwrap();
let ip_and_mask: Vec<_> = cidr.split("/").collect();
if ip_and_mask.len() != 2 {
error!("route info ip/bit error");
continue;
}
let ip = ip_and_mask[0].parse::<Ipv4Addr>().unwrap();
let maskbit = ip_and_mask[1].parse::<u8>().unwrap();
result.push((
u32::from_be_bytes(ip.octets()),
net_bit_len_to_mask(maskbit),
u32::from_be_bytes(gateway.octets()),
));
}
result
}