From 9cba4a976dc14c89011e9d87a61915048537856a Mon Sep 17 00:00:00 2001 From: alex Date: Wed, 25 Mar 2026 16:35:11 +0800 Subject: [PATCH] added route table --- Cargo.lock | 11 +++ Cargo.toml | 2 + src/bin/punchnet/main.rs | 2 +- src/network/arp.rs | 10 +-- src/network/node.rs | 4 + src/network/route.rs | 148 +++++++++++++------------------------ src/network/tun_linux.rs | 20 ++++- src/utils/acl_session.rs | 6 +- src/utils/command.rs | 20 ++++- src/utils/mod.rs | 2 + src/utils/system_action.rs | 91 +++++++++++++++++++++++ 11 files changed, 210 insertions(+), 106 deletions(-) create mode 100644 src/utils/system_action.rs diff --git a/Cargo.lock b/Cargo.lock index c01c03d..6f9203a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -155,6 +155,15 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d67af77d68a931ecd5cbd8a3b5987d63a1d1d1278f7f6a60ae33db485cdebb69" +[[package]] +name = "arc-swap" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6" +dependencies = [ + "rustversion", +] + [[package]] name = "arrayvec" version = "0.7.6" @@ -2135,6 +2144,7 @@ name = "punchnet" version = "1.0.3" dependencies = [ "ahash", + "arc-swap", "bytes", "cargo-deb", "chacha20poly1305", @@ -2148,6 +2158,7 @@ dependencies = [ "futures-util", "hex", "hmac", + "ipnet", "libc", "local-ip-address", "md-5", diff --git a/Cargo.toml b/Cargo.toml index 7cbe37e..3e47c06 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,8 @@ hmac = "0.12.1" md-5 = "0.10.6" hex = "0.4.3" ahash = "0.8.12" +ipnet = "2.12.0" +arc-swap = "1.9.0" # rolling-file = { path = "../rolling-file" } [target.'cfg(unix)'.dependencies] diff --git a/src/bin/punchnet/main.rs b/src/bin/punchnet/main.rs index cb2253a..be8fec6 100755 --- a/src/bin/punchnet/main.rs +++ b/src/bin/punchnet/main.rs @@ -398,7 +398,7 @@ fn main() { fn run_it(cmd: CommandLineInput2, client_id: String, mac: Mac, system: &str, version: &str) { let rt = Runtime::new().unwrap(); match &cmd.cmd { - Commands::Start => { + Commands::Start(rtinfo) => { rt.block_on(async move { let remembered_token = get_access_token(); if remembered_token.is_none() { diff --git a/src/network/arp.rs b/src/network/arp.rs index 906e348..aeb9f50 100755 --- a/src/network/arp.rs +++ b/src/network/arp.rs @@ -15,11 +15,10 @@ use tokio::sync::{ oneshot, }; -use super::{get_edge, get_route_table, init_arp_wait_list, init_route}; +use super::{get_edge, init_arp_wait_list}; static GLOBAL_ARP: OnceCell = OnceCell::new(); pub fn init_arp() { - init_route(); init_arp_wait_list(); let actor = ArpActor::new(); GLOBAL_ARP.set(actor).unwrap(); @@ -198,9 +197,10 @@ impl ArpInfo { } if target_ip == 0 { - let route_table = get_route_table(); - if let Some(gateway_ip) = route_table.get_gateway_ip(ip) { - target_ip = gateway_ip; + // let route_table = get_route_table(); + if let Some((_prefix, gateway_ip)) = edge.route_table.route_table.lookup(ip) { + // if let Some(gateway_ip) = route_table.get_gateway_ip(ip) { + target_ip = gateway_ip.into(); } } if target_ip == 0 { diff --git a/src/network/node.rs b/src/network/node.rs index 19c3929..77858f9 100755 --- a/src/network/node.rs +++ b/src/network/node.rs @@ -12,6 +12,7 @@ use tokio::sync::mpsc::Sender; use tokio::sync::oneshot; use tracing::{debug, error}; +use crate::network::RouteTable2; use crate::quic::quic_init; use crate::{ConnectionInfo, MyEncryptor, RuleCache, get_base_dir}; use crate::pb::{ @@ -170,6 +171,7 @@ pub struct Node { pub identity_id: IdentityID, pub rule_cache: RuleCache, + pub route_table: RouteTable2, pub access_token: StringToken, pub session_token: StringToken>, @@ -395,6 +397,8 @@ impl Node { rule_cache: RuleCache::new(), + route_table: RouteTable2::new(), + network_domain: RwLock::new(String::new()), udp_sock_for_dns: udpsock_for_dns, diff --git a/src/network/route.rs b/src/network/route.rs index c4974c1..0ac0039 100755 --- a/src/network/route.rs +++ b/src/network/route.rs @@ -1,127 +1,81 @@ -use std::{net::Ipv4Addr, sync::RwLock}; +use std::{collections::{HashMap}, net::Ipv4Addr, sync::{atomic::{AtomicBool, Ordering}}}; -use once_cell::sync::OnceCell; -use sdlan_sn_rs::utils::net_bit_len_to_mask; +use ahash::RandomState; +use dashmap::{DashMap}; +use ipnet::Ipv4Net; +use sdlan_sn_rs::utils::{Result, SDLanError}; use tracing::{debug, error}; -#[derive(Debug)] -pub struct RouteTable { - content: RwLock>, +use crate::{RouteTableTrie, network::tun::add_route}; + +pub struct RouteTable2 { + pub cache_table: DashMap<(Ipv4Net, Ipv4Addr), AtomicBool, RandomState>, + pub route_table: RouteTableTrie, } -static ROUTETABLE: OnceCell = OnceCell::new(); - -pub fn init_route() { - let rt = RouteTable::new(); - ROUTETABLE.set(rt).unwrap(); -} - -pub fn get_route_table() -> &'static RouteTable { - ROUTETABLE.get().unwrap() -} - -impl RouteTable { +impl RouteTable2 { pub fn new() -> Self { Self { - content: RwLock::new(Vec::new()), + cache_table: DashMap::with_hasher(RandomState::new()), + route_table: RouteTableTrie::new(), } } - pub fn get_gateway_ip(&self, net_ip: u32) -> Option { - let routes = self.content.read().unwrap(); - for route in &*routes { - debug!("route: {:?}", route.to_string()); - if (route.net_ip & route.net_mask) == (net_ip & route.net_mask) { - // found - return Some(route.gateway_ip); + pub fn parse_and_add_route(&self, route_str: &str) -> Result<()> { + let routes = parse_route(route_str); + for route in routes.keys() { + if self.cache_table.get(route).is_some() { + error!("route {} {} has been added", route.0.to_string(), route.1); + return Err(SDLanError::IOError(format!("route {} already added", route.0.to_string()))); } } - None - } - pub fn del_route(&self, net_ip: u32, net_mask: u32) { - let mut routes = self.content.write().unwrap(); - let mut remove_idx = routes.len(); - for i in 0..routes.len() { - let route = &routes[i]; - if route.net_ip == net_ip && route.net_mask == net_mask { - remove_idx = i; - break; - } + for route in routes.keys() { + self.cache_table.insert(*route, AtomicBool::new(false)); + self.route_table.insert(route.0.addr().into(), route.0.prefix_len(), route.1); } - if remove_idx < routes.len() { - routes.remove(remove_idx); - } - } - - pub fn add_route(&self, net_ip: u32, net_mask: u32, gateway_ip: u32) -> Result<(), String> { - { - let cnt = self.content.read().unwrap(); - let net = net_ip & net_mask; - for route in &*cnt { - if (route.net_ip & route.net_mask) == net { - return Err("route exists".to_owned()); - } - } - } - { - let mut routes = self.content.write().unwrap(); - routes.push(RouteInfo { - net_ip, - net_mask, - gateway_ip, - }) - } - Ok(()) } -} -#[derive(Debug)] -pub struct RouteInfo { - pub net_ip: u32, - pub net_mask: u32, - pub gateway_ip: u32, -} - -impl RouteInfo { - pub fn to_string(&self) -> String { - format!( - "{:?} mask={:?}, gateway={:?}", - self.net_ip.to_be_bytes(), - self.net_mask.to_be_bytes(), - self.gateway_ip.to_be_bytes() - ) + pub fn apply_system(&self) { + for route in &self.cache_table { + let origin = route.fetch_or(true, Ordering::Relaxed); + if !origin { + // should add to system + add_route(route.key().0, route.key().1); + } + } } } -// ip, mask, gateway, cidr;gateway,cidr2;gateway2 -pub fn parse_route(route: String) -> Vec<(u32, u32, u32)> { - let mut result = Vec::new(); - let routes: Vec<_> = route.split(",").collect(); - for route in routes { - let route_info: Vec<_> = route.split(";").collect(); - debug!("got route info: {:?}", route_info); - if route_info.len() != 2 { +// ip, mask, gateway, cidr gateway,cidr2 gateway2 +pub fn parse_route(route: &str) -> HashMap<(Ipv4Net, Ipv4Addr), bool> { + let mut result = HashMap::new(); + // let routes: Vec<_> = route.split(",").collect(); - error!("route info format error"); + for route in route.trim().split(",") { + let route_info: Vec<_> = route.trim().split_whitespace().collect(); + if route_info.len() != 2 { + error!("route info format error: {}", route); continue; } + debug!("got route info: {:?}", route_info); + + let Ok(gateway) = route_info[1].parse::() else { + error!("failed to parse gw: {}", route_info[1]); + continue; + }; + let cidr = route_info[0]; - let gateway = route_info[1].parse::().unwrap(); - let ip_and_mask: Vec<_> = cidr.split("/").collect(); - if ip_and_mask.len() != 2 { - error!("route info ip/bit error"); + let Ok(net )= cidr.parse::() else { + error!("failed to parse cidr: {}, skipping", cidr); continue; + }; + let origin = result.insert((net, gateway), true); + if origin.is_some() { + error!("{} {} already added", net.to_string(), gateway.to_string()); } - let ip = ip_and_mask[0].parse::().unwrap(); - let maskbit = ip_and_mask[1].parse::().unwrap(); - result.push(( - u32::from_be_bytes(ip.octets()), - net_bit_len_to_mask(maskbit), - u32::from_be_bytes(gateway.octets()), - )); } result } diff --git a/src/network/tun_linux.rs b/src/network/tun_linux.rs index 86c793a..2c40322 100755 --- a/src/network/tun_linux.rs +++ b/src/network/tun_linux.rs @@ -1,4 +1,5 @@ use etherparse::{Ethernet2Header}; +use ipnet::Ipv4Net; use sdlan_sn_rs::config::SDLAN_DEFAULT_TTL; use sdlan_sn_rs::utils::{ ip_to_string, is_ipv6_multicast, net_bit_len_to_mask, @@ -7,6 +8,7 @@ use sdlan_sn_rs::utils::{ use std::ffi::CStr; use std::ffi::{c_char, c_int}; use std::fs::{self, OpenOptions}; +use std::net::Ipv4Addr; use std::os::unix::fs::{MetadataExt, PermissionsExt}; use std::path::Path; use std::ptr::null_mut; @@ -280,7 +282,7 @@ impl TunTapPacketHandler for Iface { // is tcp } IpNumber::UDP => { - let udp = transport.tcp().unwrap(); + let udp = transport.udp().unwrap(); let out_five_tuple = FiveTuple { src_ip: ipv4.source.into(), @@ -880,5 +882,21 @@ fn restore_resolv_conf() -> Result<()> { chown(dst, Some(uid), Some(gid))?; } + Ok(()) +} + +pub fn add_route(net: Ipv4Net, gw: Ipv4Addr) -> Result<()> { + let res = Command::new("route") + .arg("add") + .arg("-net") + .arg(net.to_string()) + .arg("gw") + .arg(gw.to_string()) + .output()?; + + Ok(()) +} + +pub fn del_route() -> Result<()> { Ok(()) } \ No newline at end of file diff --git a/src/utils/acl_session.rs b/src/utils/acl_session.rs index 7b3ee9b..4163adc 100644 --- a/src/utils/acl_session.rs +++ b/src/utils/acl_session.rs @@ -2,7 +2,7 @@ use std::{net::IpAddr, sync::{Arc, atomic::{AtomicU64, Ordering}}, time::{Durati use ahash::RandomState; use dashmap::{DashMap, DashSet}; -use tracing::debug; +use tracing::{debug, error}; const RuleValidTimeInSecs: u64 = 60; @@ -70,6 +70,7 @@ impl SessionTable { pub fn retain(&self) { let now = now_secs(); + debug!("retain session"); self.table.retain(|_, info|{ let last = info.last_active.load(Ordering::Relaxed); now-last < self.timeout_secs @@ -133,11 +134,14 @@ impl RuleCache { } pub fn touch_packet(&self, info: FiveTuple) { + error!("touch a packet: {:?}", info); self.session_table.add_session_info(info); } pub fn is_identity_ok(&self, identity: IdentityID, info: FiveTuple) -> (bool, ShouldRenew) { + error!("is identity ok? {:?}", info); if self.session_table.process_packet(&info) { + error!("identity is ok"); return (true, false); } diff --git a/src/utils/command.rs b/src/utils/command.rs index 4b08484..6204715 100755 --- a/src/utils/command.rs +++ b/src/utils/command.rs @@ -22,14 +22,24 @@ pub enum Commands { /// after login, we can use start to /// connect to the remote - Start, + Start(RouteCmdInfo), Info, + RouteAdd(RouteCmdInfo), + RouteDel(RouteCmdInfo), + RouteList, + /// exits the Stop, } +#[derive(Args, Debug)] +pub struct RouteCmdInfo { + #[arg(short, long, default_value="")] + pub route: String, +} + #[derive(Args, Debug)] pub struct UserLogin { #[arg(short, long, env = APP_USER_ENV_NAME)] @@ -39,6 +49,14 @@ pub struct UserLogin { pub password: String, } +#[derive(Args, Debug)] +pub struct AutoRunTokenLogin { + #[arg(long, env=APP_TOKEN_ENV_NAME, required=false)] + pub token: String, + + #[arg(short, long, default_value="")] + pub route: String, +} #[derive(Args, Debug)] pub struct TokenLogin { diff --git a/src/utils/mod.rs b/src/utils/mod.rs index dea7fb0..4acc172 100755 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,12 +1,14 @@ mod command; mod acl_session; mod encrypter; +mod system_action; use std::{fs::OpenOptions, io::Write, net::Ipv4Addr, path::Path}; pub use encrypter::*; pub use command::*; pub use acl_session::*; +pub use system_action::*; mod socks; use rand::Rng; diff --git a/src/utils/system_action.rs b/src/utils/system_action.rs new file mode 100644 index 0000000..fa6b62b --- /dev/null +++ b/src/utils/system_action.rs @@ -0,0 +1,91 @@ +use std::{net::Ipv4Addr, sync::Arc}; + +use arc_swap::ArcSwap; + +use crate::network::Node; + +#[derive(Default, Clone)] +pub struct TrieNode { + child: [Option>; 2], + prefix_len: u8, + nexthop: Option, +} + +#[derive(Default, Clone)] +struct IpTrie { + root: TrieNode, +} + +impl IpTrie { + fn new() -> Self { + Self { + root: TrieNode::default() + } + } + + fn insert(&mut self, prefix: u32, prefix_len: u8, nexthop: Ipv4Addr) { + if prefix_len > 32 { + return; + } + + let mut node = &mut self.root; + for i in 0..prefix_len { + let bit = ((prefix >> (31-i)) & 1) as usize; + node = node.child[bit].get_or_insert_with(|| Box::new(TrieNode::default())); + + } + if prefix_len > node.prefix_len { + node.prefix_len = prefix_len; + node.nexthop = Some(nexthop); + } + } + + fn lookup(&self, ip: u32) -> Option<(u8, Ipv4Addr)>{ + let mut node = &self.root; + let mut best = None; + + for i in 0..32 { + if node.nexthop.is_some() { + best = Some((node.prefix_len, node.nexthop.unwrap())); + } + let bit = ((ip>>(31-i)) & 1) as usize; + match &node.child[bit] { + Some(child) => { + node = child; + } + None => { + break; + } + } + } + if node.nexthop.is_some() { + best = Some((node.prefix_len, node.nexthop.unwrap())); + } + best + } +} + +pub struct RouteTableTrie { + trie: ArcSwap, +} + +impl RouteTableTrie { + pub fn new() -> Self { + Self { + trie: ArcSwap::new(Arc::new(IpTrie::default())) + } + } + + pub fn lookup(&self, ip: u32) -> Option<(u8, Ipv4Addr)> { + let trie = self.trie.load(); + trie.lookup(ip) + } + + pub fn insert(&self, prefix: u32, prefix_len: u8, nexthop: Ipv4Addr) { + let old = self.trie.load(); + let mut new_trie = (*(*old)).clone(); + + new_trie.insert(prefix, prefix_len, nexthop); + self.trie.store(Arc::new(new_trie)); + } +} \ No newline at end of file