diff --git a/src/bin/punchnet/main.rs b/src/bin/punchnet/main.rs index 7a611f2..582a3ca 100755 --- a/src/bin/punchnet/main.rs +++ b/src/bin/punchnet/main.rs @@ -354,7 +354,7 @@ fn main() { } } - let should_daemonize = true; + let should_daemonize = false; #[cfg(not(target_os = "windows"))] if should_daemonize { diff --git a/src/network/async_main.rs b/src/network/async_main.rs index 9ae80fa..b35444a 100755 --- a/src/network/async_main.rs +++ b/src/network/async_main.rs @@ -250,7 +250,7 @@ async fn loop_tap(eee: &'static Node, cancel: CancellationToken) { } } } - debug!("loop_tap exited"); + error!("loop_tap exited"); } async fn get_tun_flow(eee: &'static Node, tx: Sender>) { diff --git a/src/network/node.rs b/src/network/node.rs index 6e78940..867e4ec 100755 --- a/src/network/node.rs +++ b/src/network/node.rs @@ -1,9 +1,12 @@ use arc_swap::ArcSwap; use dashmap::DashMap; +use prost::Message; use quinn::Endpoint; use rsa::RsaPrivateKey; use sdlan_sn_rs::config::{AF_INET, AF_INET6}; use tokio::net::UdpSocket; +use std::any::Any; +use std::future::Future; use std::net::SocketAddr; use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, AtomicU8, Ordering}; use std::sync::{Arc, Mutex, RwLock}; @@ -233,15 +236,20 @@ pub struct Node { nat_type: Mutex, nat_cookie: AtomicU32, - cookie_match: DashMap>, + // cookie_match: DashMap>, + cookie_match: Queryer, - packet_id_match: DashMap>, + // packet_id_match: Queryer, + // packet_id_match: DashMap>, } unsafe impl Sync for Node {} impl Node { pub fn send_register_super_feedback(&self, pktid: u32, feed: RegisterSuperFeedback) { + // self.packet_id_match.write_feedback(pktid, feed); + self.cookie_match.write_feedback(pktid, feed); + /* match self.packet_id_match.remove(&pktid) { Some(sender) => { let _ = sender.1.send(feed); @@ -250,6 +258,7 @@ impl Node { return; } } + */ } pub fn get_nat_type(&self) -> NatType { @@ -325,8 +334,30 @@ impl Node { self.identity_id.store(identity_id); // *self._token.lock().unwrap() = token; // *self.network_code.lock().unwrap() = network_code; - let (tx, rx) = oneshot::channel(); let id = self.get_next_packet_id(); + + // let result = self.packet_id_match.do_action_and_wait_for( + let result = self.cookie_match.do_action_and_wait_for( + id, + || async { + let _ = self + .start_stop_sender + .send(StartStopInfo { + is_start: true, + pkt_id: Some(id), + }) + .await; + debug!("start with feedback"); + }, + timeout + ).await?; + + if let Ok(data) = result.downcast() { + return Ok(*data); + } + Err(SDLanError::NormalError("timed out")) + /* + let (tx, rx) = oneshot::channel(); self.packet_id_match.insert(id, tx); let _ = self .start_stop_sender @@ -350,6 +381,7 @@ impl Node { Err(SDLanError::NormalError("timed out")) } } + */ } pub async fn stop(&self) { @@ -453,9 +485,9 @@ impl Node { stats: NodeStats::new(), _last_register_req: AtomicU64::new(0), - packet_id_match: DashMap::new(), + // packet_id_match: Queryer::new(), nat_cookie: AtomicU32::new(1), - cookie_match: DashMap::new(), + cookie_match: Queryer::new(), server_ip, install_channel, } @@ -537,11 +569,14 @@ impl Node { */ pub async fn send_nat_probe_reply(&self, cookie: u32, buf: SdlStunProbeReply) { + self.cookie_match.write_feedback(cookie, buf); + /* if let Some((_key, chan)) = self.cookie_match.remove(&cookie) { let _ = chan.send(buf); return; } - error!("failed to get such cookie stun probe"); + */ + // error!("failed to get such cookie stun probe"); } pub async fn probe_nat_type(&self) -> NatType { @@ -619,38 +654,30 @@ impl Node { to_server: &SocketAddr, ) -> Result { let cookie = self.nat_cookie.fetch_add(1, Ordering::Relaxed); - let probe = SdlStunProbe { - attr: msgattr as u32, - cookie, - step: 0, - }; // println!("==> sending probe request: {:?}", probe); - let (tx, rx) = oneshot::channel(); - self.cookie_match.insert(cookie, tx); - // let cookie = msg.cookie; - let msg = encode_to_udp_message(Some(probe), PacketType::StunProbe as u8).unwrap(); - if let Err(_e) = self.udp_sock_v4.send_to(&msg, to_server).await { - self.cookie_match.remove(&cookie); - return Err(SDLanError::NormalError("send error")); + let res = self.cookie_match.do_action_and_wait_for( + cookie, + || async { + let probe = SdlStunProbe { + attr: msgattr as u32, + cookie, + step: 0, + }; + let msg = encode_to_udp_message(Some(probe), PacketType::StunProbe as u8).unwrap(); + if let Err(_e) = self.udp_sock_v4.send_to(&msg, to_server).await { + error!("failed to send StunProbe"); + } + }, + Duration::from_secs(3), + ).await?; + if let Ok(data) = res.downcast() { + return Ok(*data); } - tokio::select! { - _ = tokio::time::sleep(Duration::from_secs(3)) => { - self.cookie_match.remove(&cookie); - return Err(SDLanError::NormalError("timed out")); - } - reply = rx => { - self.cookie_match.remove(&cookie); - if let Ok(reply) = reply { - // reply received, - return Ok(reply); - // println!("got nat ip: {}:{}", ip_to_string(&reply.ip), reply.port); - } - return Err(SDLanError::NormalError("reply recv error")); + + Err(SDLanError::NormalError("reply recv error")) // step 1 received - } - } } } @@ -833,3 +860,110 @@ impl EdgePeer { self.nat_type } } + +type BoxedProstMessage = Box; + +pub struct Queryer { + pub mailbox: DashMap>, +} + +impl Queryer { + pub fn new() -> Self { + Self { + mailbox: DashMap::new(), + } + } + + pub fn write_feedback(&self, id: u32, data: T) { + if let Some((_, tx)) = self.mailbox.remove(&id) { + if let Err(e) = tx.send(Box::new(data)) { + error!("failed to write feedback"); + } + } + } + + pub async fn send_message_to_udp_and_wait_for(&self, sock: &Socket, id: u32, message: T, packet_type: u8, to_server: &SocketAddr, timeout: Duration) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.mailbox.insert(id, tx); + + let content = encode_to_udp_message(Some(message), packet_type)?; + + if let Err(_e) = sock.send_to(&content, to_server).await { + self.mailbox.remove(&id); + return Err(SDLanError::NormalError("send error")); + } + + let quic_conn = get_quic_write_conn(); + quic_conn.send(content).await?; + + tokio::select! { + data = rx => { + if let Ok(data) = data { + self.mailbox.remove(&id); + Ok(data) + } else { + self.mailbox.remove(&id); + Err(SDLanError::IOError("rx receive failed".to_string())) + } + } + _ = tokio::time::sleep(timeout) => { + self.mailbox.remove(&id); + Err(SDLanError::IOError("timed out".to_string())) + } + } + + } + + pub async fn do_action_and_wait_for(&self, id: u32, action: T, timeout: Duration) -> Result + where + F: Future, + T: Fn() -> F, + { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.mailbox.insert(id, tx); + + action().await; + + tokio::select! { + data = rx => { + if let Ok(data) = data { + self.mailbox.remove(&id); + Ok(data) + } else { + self.mailbox.remove(&id); + Err(SDLanError::IOError("rx receive failed".to_string())) + } + } + _ = tokio::time::sleep(timeout) => { + self.mailbox.remove(&id); + Err(SDLanError::IOError("timed out".to_string())) + } + } + + } + + pub async fn send_message_to_quic_and_wait_for(&self, id: u32, message: T, packet_type: u8, timeout: Duration) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.mailbox.insert(id, tx); + + let content = encode_to_tcp_message(Some(message), packet_type)?; + let quic_conn = get_quic_write_conn(); + quic_conn.send(content).await?; + + tokio::select! { + data = rx => { + if let Ok(data) = data { + self.mailbox.remove(&id); + Ok(data) + } else { + self.mailbox.remove(&id); + Err(SDLanError::IOError("rx receive failed".to_string())) + } + } + _ = tokio::time::sleep(timeout) => { + self.mailbox.remove(&id); + Err(SDLanError::IOError("timed out".to_string())) + } + } + } +} \ No newline at end of file diff --git a/src/network/route.rs b/src/network/route.rs index 0ac0039..4bcc478 100755 --- a/src/network/route.rs +++ b/src/network/route.rs @@ -1,13 +1,15 @@ -use std::{collections::{HashMap}, net::Ipv4Addr, sync::{atomic::{AtomicBool, Ordering}}}; +use std::{collections::HashMap, net::Ipv4Addr, sync::atomic::{AtomicBool, Ordering}, time::Duration}; use ahash::RandomState; use dashmap::{DashMap}; use ipnet::Ipv4Net; use sdlan_sn_rs::utils::{Result, SDLanError}; +use tokio::sync::oneshot::{Receiver, Sender, channel}; use tracing::{debug, error}; -use crate::{RouteTableTrie, network::tun::add_route}; +use crate::{RouteTableTrie, network::tun::add_route, pb::{SdlArpResponse, SdlStunReply}}; + pub struct RouteTable2 { pub cache_table: DashMap<(Ipv4Net, Ipv4Addr), AtomicBool, RandomState>, diff --git a/src/tcp/quic.rs b/src/tcp/quic.rs index 8ec79ad..1332c8e 100644 --- a/src/tcp/quic.rs +++ b/src/tcp/quic.rs @@ -36,7 +36,7 @@ impl ReadWriterHandle { Ok(()) } - fn new<>( + fn new( cancel: CancellationToken, addr: &str, // on_connected: OnConnectedCallback<'a>, @@ -78,8 +78,10 @@ impl ReadWriterHandle { loop { if let Some(msg) = data_from_tcp.recv().await { handle_tcp_message(msg).await; + println!("handle_tcp_message ok"); } else { error!("data from tcp exited"); + println!("data from tcp exited"); // eprintln!("data from tcp exited"); return; } @@ -109,6 +111,7 @@ async fn handle_tcp_message(msg: SdlanTcp) { // edge.tcp_pong.store(now, Ordering::Relaxed); debug!("got tcp message: {:?}", msg.packet_type); + println!("got tcp message: {:?}", msg.packet_type); match msg.packet_type { PacketType::RegisterSuperACK => { let Ok(ack) = SdlRegisterSuperAck::decode(&msg.current_packet[..]) else { @@ -116,6 +119,7 @@ async fn handle_tcp_message(msg: SdlanTcp) { return; }; + debug!("got register super ack1: {:?}", ack); edge.send_register_super_feedback( ack.pkt_id, RegisterSuperFeedback { @@ -192,12 +196,14 @@ async fn handle_tcp_message(msg: SdlanTcp) { debug!("nat type is {:?}", nattype); // println!("nat type is: {:?}", nattype); }); + println!("register message handled"); } PacketType::ArpResponse => { let Ok(resp) = SdlArpResponse::decode(&msg.current_packet[..]) else { error!("failed to decode ARP RESPONSE"); return; }; + debug!("got arp response: {:?}", resp); if resp.target_mac.len() != 6 { // invalid target_mac error!("invalid target_mac: {:?}, ip={}", resp.target_mac, ip_to_string(&resp.target_ip)); @@ -518,6 +524,7 @@ impl ReadWriteActor { } } else { // None, just return + println!("start or stop is None"); return; } } @@ -625,6 +632,7 @@ impl ReadWriteActor { let write_to_tcp = async { while let Some(data) = to_tcp.recv().await { + debug!("data size = {}", data.len()); match send.write(&data).await { Ok(size) => { debug!("{} bytes sent to tcp", size);