From bdde2cb99a7d2ee64b1f35d7835cb8576a5aca3b Mon Sep 17 00:00:00 2001 From: asxalex Date: Fri, 19 Jul 2024 19:51:56 +0800 Subject: [PATCH] support the feedback --- src/bin/sdlan/main.rs | 2 +- src/network/async_main.rs | 66 +++++++++++++++++++++++++++++++------- src/network/node.rs | 67 +++++++++++++++++++++++++++++++++++---- src/tcp/tcp_codec.rs | 10 ++++++ src/tcp/tcp_conn.rs | 39 ++++++++++++++++------- 5 files changed, 155 insertions(+), 29 deletions(-) diff --git a/src/bin/sdlan/main.rs b/src/bin/sdlan/main.rs index 6f519a5..f0044f4 100644 --- a/src/bin/sdlan/main.rs +++ b/src/bin/sdlan/main.rs @@ -34,7 +34,7 @@ async fn main() { let _ = rx.recv(); let edge = get_edge(); - edge.start(cmd.token).await; + let _ = edge.start_without_feedback(cmd.token).await; /* tokio::time::sleep(Duration::from_secs(20)).await; diff --git a/src/network/async_main.rs b/src/network/async_main.rs index 0351d2a..91a8ecf 100644 --- a/src/network/async_main.rs +++ b/src/network/async_main.rs @@ -3,15 +3,16 @@ use std::time::Duration; use std::net::IpAddr; use crate::config::TCP_PING_TIME; -use crate::network::{get_edge, ping_to_sn, read_and_parse_packet}; +use crate::network::{get_edge, ping_to_sn, read_and_parse_packet, RegisterSuperFeedback}; use crate::pb::{ encode_to_tcp_message, encode_to_udp_message, SdlData, SdlDevAddr, SdlRegisterSuper, SdlRegisterSuperAck, SdlRegisterSuperNak, SdlSendRegisterEvent, SdlStunRequest, }; -use crate::tcp::{init_tcp_conn, EventType, PacketType, SdlanTcp}; +use crate::tcp::{init_tcp_conn, EventType, NakMsgCode, PacketType, SdlanTcp}; use crate::utils::{send_to_sock, CommandLine}; use etherparse::IpHeaders; use sdlan_sn_rs::config::{AF_INET, SDLAN_DEFAULT_TTL}; +use sdlan_sn_rs::packet::Register; use sdlan_sn_rs::peer::SdlanSock; use sdlan_sn_rs::utils::{ aes_encrypt, get_current_timestamp, @@ -22,7 +23,7 @@ use tokio::io::AsyncWriteExt; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio_util::sync::CancellationToken; -use super::{check_peer_registration_needed, packet, Node }; +use super::{check_peer_registration_needed, packet, Node, StartStopInfo }; use crate::utils::Socket; use prost::Message; @@ -37,6 +38,10 @@ async fn handle_tcp_message(msg: SdlanTcp) { debug!("got tcp message: {:?}", msg.packet_type); match msg.packet_type { PacketType::RegisterSuperACK => { + edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback { + result: 0, + message: "".to_owned(), + }); let Ok(ack) = SdlRegisterSuperAck::decode(&msg.current_packet[..]) else { error!("failed to decode REGISTER_SUPER_ACK"); return; @@ -74,12 +79,47 @@ async fn handle_tcp_message(msg: SdlanTcp) { PacketType::RegisterSuperNAK => { let Ok(_nak) = SdlRegisterSuperNak::decode(&msg.current_packet[..]) else { error!("failed to decode REGISTER_SUPER_NAK"); + edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback { + result: 1, + message: "failed to decode REGISTER SUPER NAK".to_owned(), + }); return; }; - error!( - "RegisterSuperNAK received: {}({})", - _nak.error_code, _nak.error_message - ); + + let Ok(error_code) = NakMsgCode::try_from(_nak.error_code as u8) else { + edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback { + result: 2, + message: "error_code not recognized".to_owned(), + }); + return; + }; + match error_code { + NakMsgCode::InvalidToken => { + edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback { + result: 3, + message: "invalid token".to_owned(), + }); + } + NakMsgCode::NodeDisabled => { + edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback { + result: 4, + message: "Node is disabled".to_owned(), + }); + } + _other => { + edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback { + result: 0, + message: "".to_owned(), + + }); + } + } + /* + edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback { + result: 1, + message: "failed to decode REGISTER SUPER NAK".to_owned(), + }); + */ edge.set_authorized(false, Vec::new()); // std::process::exit(0); } @@ -148,7 +188,7 @@ async fn handle_tcp_event(edge: &Node, eventtype: EventType, eventprotobuf: &[u8 pub async fn async_main( install_channel: String, args: CommandLine, - start_stop_chan: Receiver, + start_stop_chan: Receiver, cancel: CancellationToken, ) -> Result<()> { // let _ = PidRecorder::new(".pid"); @@ -185,9 +225,9 @@ pub async fn async_main( // let token = args.token.clone(); init_tcp_conn( &args.tcp, - move |stream| { + move |stream, pkt_id| { let installed_channel = install_channel.to_owned(); - Box::pin(async { + Box::pin(async move { let token = edge._token.lock().unwrap().clone(); // let edge = get_edge(); // let edge = get_edge(); @@ -215,7 +255,11 @@ pub async fn async_main( token, }; // debug!("send register super: {:?}", register_super); - let packet_id = edge.get_next_packet_id(); + let packet_id = match pkt_id { + Some(id) => id, + None => edge.get_next_packet_id(), + }; + // let packet_id = edge.get_next_packet_id(); let data = encode_to_tcp_message( Some(register_super), packet_id, diff --git a/src/network/node.rs b/src/network/node.rs index 9990318..30e1bcc 100644 --- a/src/network/node.rs +++ b/src/network/node.rs @@ -33,7 +33,7 @@ pub async fn init_edge( token: &str, node_conf: NodeConfig, tos: u32, - start_stop: Sender, + start_stop: Sender, ) -> Result<()> { let _ = PidRecorder::new(".pid"); @@ -89,6 +89,17 @@ pub fn get_edge() -> &'static Node { edge.unwrap() } +pub struct RegisterSuperFeedback { + // 0 for success, other for error + pub result: u8, + pub message: String, +} + +pub struct StartStopInfo { + pub is_start: bool, + pub pkt_id: Option, +} + pub struct Node { packet_id: AtomicU32, @@ -96,7 +107,7 @@ pub struct Node { pub tcp_pong: Arc, - start_stop_sender: Sender, + start_stop_sender: Sender, // user token info pub _token: Mutex, @@ -139,22 +150,65 @@ pub struct Node { nat_cookie: AtomicU32, cookie_match: DashMap>, + + packet_id_match: DashMap>, } unsafe impl Sync for Node {} impl Node { + pub fn send_register_super_feedback(&self, pktid: u32, feed: RegisterSuperFeedback) { + match self.packet_id_match.remove(&pktid) { + Some(sender) => { + let _ = sender.1.send(feed); + } + None => { + return; + } + } + } + pub fn get_nat_type(&self) -> NatType { self.nat_type.lock().unwrap().clone() } - pub async fn start(&self, token: String) { + + pub async fn start_without_feedback(&self, token: String) -> Result<()> { *self._token.lock().unwrap() = token; - let _ = self.start_stop_sender.send(true).await; + let _ = self.start_stop_sender.send(StartStopInfo{ + is_start: true, + pkt_id: None, + }).await; + Ok(()) + } + + pub async fn start_with_feedback(&self, token: String, timeout: Duration) -> Result { + *self._token.lock().unwrap() = token; + let (tx, mut rx) = oneshot::channel(); + let id = self.get_next_packet_id(); + self.packet_id_match.insert(id, tx); + let _ = self.start_stop_sender.send(StartStopInfo{ + is_start: true, + pkt_id: Some(id), + }).await; + + tokio::select! { + Ok(result) = rx => { + self.packet_id_match.remove(&id); + Ok(result) + } + _ = tokio::time::sleep(timeout) => { + Err(SDLanError::NormalError("timed out")) + } + } + } pub async fn stop(&self) { *self._token.lock().unwrap() = "".to_owned(); - let _ = self.start_stop_sender.send(false).await; + let _ = self.start_stop_sender.send(StartStopInfo{ + is_start: false, + pkt_id: None, + }).await; } pub fn new( @@ -166,7 +220,7 @@ impl Node { token: &str, private: RsaPrivateKey, tcp_pong: Arc, - start_stop: Sender, + start_stop: Sender, ) -> Self { Self { packet_id: AtomicU32::new(1), @@ -209,6 +263,7 @@ impl Node { stats: NodeStats::new(), _last_register_req: AtomicU64::new(0), + packet_id_match: DashMap::new(), nat_cookie: AtomicU32::new(1), cookie_match: DashMap::new(), } diff --git a/src/tcp/tcp_codec.rs b/src/tcp/tcp_codec.rs index c3530ef..355e551 100644 --- a/src/tcp/tcp_codec.rs +++ b/src/tcp/tcp_codec.rs @@ -46,6 +46,16 @@ pub enum NatType { Symmetric = 5, } +#[derive(Debug, Copy, Clone, TryFromPrimitive)] +#[repr(u8)] +pub enum NakMsgCode { + InvalidToken = 1, + NodeDisabled = 2, + NoIpAddress = 3, + NetworkFault = 4, + InternalFault = 5, +} + #[derive(Debug, Copy, Clone, TryFromPrimitive)] #[repr(u8)] pub enum PacketType { diff --git a/src/tcp/tcp_conn.rs b/src/tcp/tcp_conn.rs index f9ad6c1..ccf8db5 100644 --- a/src/tcp/tcp_conn.rs +++ b/src/tcp/tcp_conn.rs @@ -19,6 +19,7 @@ use tokio::{ use tracing::error; use crate::config::TCP_PING_TIME; +use crate::network::StartStopInfo; use crate::tcp::read_a_packet; use super::tcp_codec::SdlanTcp; @@ -57,23 +58,25 @@ impl ReadWriteActor { mut to_tcp: Receiver>, on_connected: T, on_disconnected: T2, - mut start_stop_chan: Receiver, + mut start_stop_chan: Receiver, // cancel: CancellationToken, ) where - T: for<'b> Fn(&'b mut TcpStream) -> BoxFuture<'b, ()>, + T: for<'b> Fn(&'b mut TcpStream, Option) -> BoxFuture<'b, ()>, T2: Fn() -> F, F: Future, { // let (tx, rx) = channel(20); let mut started = false; + let mut start_pkt_id = None; loop { self.connected.store(false, Ordering::Relaxed); if !started { // println!("waiting for start"); while let Some(m) = start_stop_chan.recv().await { - if m { + if m.is_start { // println!("start received"); started = true; + start_pkt_id = m.pkt_id; break; } else { // println!("stop received"); @@ -97,7 +100,7 @@ impl ReadWriteActor { return; }; self.connected.store(true, Ordering::Relaxed); - on_connected(&mut stream).await; + on_connected(&mut stream, start_pkt_id.take()).await; // stream.write("hello".as_bytes()).await; let (reader, mut write) = stream.into_split(); @@ -149,15 +152,29 @@ impl ReadWriteActor { } }; + let check_stop = async { + loop { + match start_stop_chan.recv().await { + Some(v) => { + if !v.is_start { + return; + } + } + other => { + // send chan is closed; + return; + } + } + } + }; + pin_mut!(read_from_tcp, write_to_tcp); tokio::select! { _ = read_from_tcp => {}, _ = write_to_tcp => {}, _ = check_pong => {}, - Some(false) = start_stop_chan.recv() => { - started = false; - } + _ = check_stop => {}, } on_disconnected().await; debug!("connect retrying"); @@ -196,11 +213,11 @@ impl ReadWriterHandle { on_disconnected: T3, on_message: T2, pong_time: Arc, - start_stop_chan: Receiver, + start_stop_chan: Receiver, // cancel: CancellationToken, ) -> Self where - T: for<'b> Fn(&'b mut TcpStream) -> BoxFuture<'b, ()> + Send + 'static, + T: for<'b> Fn(&'b mut TcpStream, Option) -> BoxFuture<'b, ()> + Send + 'static, T3: Fn() -> F2 + Send + 'static, T2: Fn(SdlanTcp) -> F + Send + 'static, F: Future + Send, @@ -242,9 +259,9 @@ pub fn init_tcp_conn<'a, T, T3, T2, F, F2>( on_message: T2, pong_time: Arc, // cancel: CancellationToken, - start_stop_chan: Receiver, + start_stop_chan: Receiver, ) where - T: for<'b> Fn(&'b mut TcpStream) -> BoxFuture<'b, ()> + Send + 'static, + T: for<'b> Fn(&'b mut TcpStream, Option) -> BoxFuture<'b, ()> + Send + 'static, T3: Fn() -> F2 + Send + 'static, T2: Fn(SdlanTcp) -> F + Send + 'static, F: Future + Send,