use std::{net::Ipv4Addr, sync::RwLock}; use once_cell::sync::OnceCell; use sdlan_sn_rs::utils::net_bit_len_to_mask; #[derive(Debug)] pub struct RouteTable { content: RwLock>, } static ROUTETABLE: OnceCell = OnceCell::new(); pub fn init_route() { let rt = RouteTable::new(); ROUTETABLE.set(rt).unwrap(); } pub fn get_route_table() -> &'static RouteTable { ROUTETABLE.get().unwrap() } impl RouteTable { pub fn new() -> Self { Self { content: RwLock::new(Vec::new()), } } pub fn get_gateway_ip(&self, net_ip: u32) -> Option { let routes = self.content.read().unwrap(); for route in &*routes { println!("route: {:?}", route.to_string()); if (route.net_ip & route.net_mask) == (net_ip & route.net_mask) { // found return Some(route.gateway_ip); } } None } pub fn del_route(&self, net_ip: u32, net_mask: u32) { let mut routes = self.content.write().unwrap(); let mut remove_idx = routes.len(); for i in 0..routes.len() { let route = &routes[i]; if route.net_ip == net_ip && route.net_mask == net_mask { remove_idx = i; break; } } if remove_idx < routes.len() { routes.remove(remove_idx); } } pub fn add_route(&self, net_ip: u32, net_mask: u32, gateway_ip: u32) -> Result<(), String> { { let cnt = self.content.read().unwrap(); let net = net_ip & net_mask; for route in &*cnt { if (route.net_ip & route.net_mask) == net { return Err("route exists".to_owned()); } } } { let mut routes = self.content.write().unwrap(); routes.push(RouteInfo { net_ip, net_mask, gateway_ip, }) } Ok(()) } } #[derive(Debug)] pub struct RouteInfo { pub net_ip: u32, pub net_mask: u32, pub gateway_ip: u32, } impl RouteInfo { pub fn to_string(&self) -> String { format!( "{:?} mask={:?}, gateway={:?}", self.net_ip.to_be_bytes(), self.net_mask.to_be_bytes(), self.gateway_ip.to_be_bytes() ) } } // ip, mask, gateway, cidr;gateway,cidr2;gateway2 pub fn parse_route(route: String) -> Vec<(u32, u32, u32)> { let mut result = Vec::new(); let routes: Vec<_> = route.split(",").collect(); for route in routes { let route_info: Vec<_> = route.split(";").collect(); println!("got route info: {:?}", route_info); if route_info.len() != 2 { println!("route info format error"); continue; } let cidr = route_info[0]; let gateway = route_info[1].parse::().unwrap(); let ip_and_mask: Vec<_> = cidr.split("/").collect(); if ip_and_mask.len() != 2 { println!("route info ip/bit error"); continue; } let ip = ip_and_mask[0].parse::().unwrap(); let maskbit = ip_and_mask[1].parse::().unwrap(); result.push(( u32::from_be_bytes(ip.octets()), net_bit_len_to_mask(maskbit), u32::from_be_bytes(gateway.octets()), )); } result }