support the feedback

This commit is contained in:
asxalex 2024-07-19 19:51:56 +08:00
parent 5408cfa286
commit bdde2cb99a
5 changed files with 155 additions and 29 deletions

View File

@ -34,7 +34,7 @@ async fn main() {
let _ = rx.recv(); let _ = rx.recv();
let edge = get_edge(); let edge = get_edge();
edge.start(cmd.token).await; let _ = edge.start_without_feedback(cmd.token).await;
/* /*
tokio::time::sleep(Duration::from_secs(20)).await; tokio::time::sleep(Duration::from_secs(20)).await;

View File

@ -3,15 +3,16 @@ use std::time::Duration;
use std::net::IpAddr; use std::net::IpAddr;
use crate::config::TCP_PING_TIME; use crate::config::TCP_PING_TIME;
use crate::network::{get_edge, ping_to_sn, read_and_parse_packet}; use crate::network::{get_edge, ping_to_sn, read_and_parse_packet, RegisterSuperFeedback};
use crate::pb::{ use crate::pb::{
encode_to_tcp_message, encode_to_udp_message, SdlData, SdlDevAddr, SdlRegisterSuper, encode_to_tcp_message, encode_to_udp_message, SdlData, SdlDevAddr, SdlRegisterSuper,
SdlRegisterSuperAck, SdlRegisterSuperNak, SdlSendRegisterEvent, SdlStunRequest, SdlRegisterSuperAck, SdlRegisterSuperNak, SdlSendRegisterEvent, SdlStunRequest,
}; };
use crate::tcp::{init_tcp_conn, EventType, PacketType, SdlanTcp}; use crate::tcp::{init_tcp_conn, EventType, NakMsgCode, PacketType, SdlanTcp};
use crate::utils::{send_to_sock, CommandLine}; use crate::utils::{send_to_sock, CommandLine};
use etherparse::IpHeaders; use etherparse::IpHeaders;
use sdlan_sn_rs::config::{AF_INET, SDLAN_DEFAULT_TTL}; use sdlan_sn_rs::config::{AF_INET, SDLAN_DEFAULT_TTL};
use sdlan_sn_rs::packet::Register;
use sdlan_sn_rs::peer::SdlanSock; use sdlan_sn_rs::peer::SdlanSock;
use sdlan_sn_rs::utils::{ use sdlan_sn_rs::utils::{
aes_encrypt, get_current_timestamp, aes_encrypt, get_current_timestamp,
@ -22,7 +23,7 @@ use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use super::{check_peer_registration_needed, packet, Node }; use super::{check_peer_registration_needed, packet, Node, StartStopInfo };
use crate::utils::Socket; use crate::utils::Socket;
use prost::Message; use prost::Message;
@ -37,6 +38,10 @@ async fn handle_tcp_message(msg: SdlanTcp) {
debug!("got tcp message: {:?}", msg.packet_type); debug!("got tcp message: {:?}", msg.packet_type);
match msg.packet_type { match msg.packet_type {
PacketType::RegisterSuperACK => { PacketType::RegisterSuperACK => {
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 0,
message: "".to_owned(),
});
let Ok(ack) = SdlRegisterSuperAck::decode(&msg.current_packet[..]) else { let Ok(ack) = SdlRegisterSuperAck::decode(&msg.current_packet[..]) else {
error!("failed to decode REGISTER_SUPER_ACK"); error!("failed to decode REGISTER_SUPER_ACK");
return; return;
@ -74,12 +79,47 @@ async fn handle_tcp_message(msg: SdlanTcp) {
PacketType::RegisterSuperNAK => { PacketType::RegisterSuperNAK => {
let Ok(_nak) = SdlRegisterSuperNak::decode(&msg.current_packet[..]) else { let Ok(_nak) = SdlRegisterSuperNak::decode(&msg.current_packet[..]) else {
error!("failed to decode REGISTER_SUPER_NAK"); error!("failed to decode REGISTER_SUPER_NAK");
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 1,
message: "failed to decode REGISTER SUPER NAK".to_owned(),
});
return; return;
}; };
error!(
"RegisterSuperNAK received: {}({})", let Ok(error_code) = NakMsgCode::try_from(_nak.error_code as u8) else {
_nak.error_code, _nak.error_message edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
); result: 2,
message: "error_code not recognized".to_owned(),
});
return;
};
match error_code {
NakMsgCode::InvalidToken => {
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 3,
message: "invalid token".to_owned(),
});
}
NakMsgCode::NodeDisabled => {
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 4,
message: "Node is disabled".to_owned(),
});
}
_other => {
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 0,
message: "".to_owned(),
});
}
}
/*
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 1,
message: "failed to decode REGISTER SUPER NAK".to_owned(),
});
*/
edge.set_authorized(false, Vec::new()); edge.set_authorized(false, Vec::new());
// std::process::exit(0); // std::process::exit(0);
} }
@ -148,7 +188,7 @@ async fn handle_tcp_event(edge: &Node, eventtype: EventType, eventprotobuf: &[u8
pub async fn async_main( pub async fn async_main(
install_channel: String, install_channel: String,
args: CommandLine, args: CommandLine,
start_stop_chan: Receiver<bool>, start_stop_chan: Receiver<StartStopInfo>,
cancel: CancellationToken, cancel: CancellationToken,
) -> Result<()> { ) -> Result<()> {
// let _ = PidRecorder::new(".pid"); // let _ = PidRecorder::new(".pid");
@ -185,9 +225,9 @@ pub async fn async_main(
// let token = args.token.clone(); // let token = args.token.clone();
init_tcp_conn( init_tcp_conn(
&args.tcp, &args.tcp,
move |stream| { move |stream, pkt_id| {
let installed_channel = install_channel.to_owned(); let installed_channel = install_channel.to_owned();
Box::pin(async { Box::pin(async move {
let token = edge._token.lock().unwrap().clone(); let token = edge._token.lock().unwrap().clone();
// let edge = get_edge(); // let edge = get_edge();
// let edge = get_edge(); // let edge = get_edge();
@ -215,7 +255,11 @@ pub async fn async_main(
token, token,
}; };
// debug!("send register super: {:?}", register_super); // debug!("send register super: {:?}", register_super);
let packet_id = edge.get_next_packet_id(); let packet_id = match pkt_id {
Some(id) => id,
None => edge.get_next_packet_id(),
};
// let packet_id = edge.get_next_packet_id();
let data = encode_to_tcp_message( let data = encode_to_tcp_message(
Some(register_super), Some(register_super),
packet_id, packet_id,

View File

@ -33,7 +33,7 @@ pub async fn init_edge(
token: &str, token: &str,
node_conf: NodeConfig, node_conf: NodeConfig,
tos: u32, tos: u32,
start_stop: Sender<bool>, start_stop: Sender<StartStopInfo>,
) -> Result<()> { ) -> Result<()> {
let _ = PidRecorder::new(".pid"); let _ = PidRecorder::new(".pid");
@ -89,6 +89,17 @@ pub fn get_edge() -> &'static Node {
edge.unwrap() edge.unwrap()
} }
pub struct RegisterSuperFeedback {
// 0 for success, other for error
pub result: u8,
pub message: String,
}
pub struct StartStopInfo {
pub is_start: bool,
pub pkt_id: Option<u32>,
}
pub struct Node { pub struct Node {
packet_id: AtomicU32, packet_id: AtomicU32,
@ -96,7 +107,7 @@ pub struct Node {
pub tcp_pong: Arc<AtomicU64>, pub tcp_pong: Arc<AtomicU64>,
start_stop_sender: Sender<bool>, start_stop_sender: Sender<StartStopInfo>,
// user token info // user token info
pub _token: Mutex<String>, pub _token: Mutex<String>,
@ -139,22 +150,65 @@ pub struct Node {
nat_cookie: AtomicU32, nat_cookie: AtomicU32,
cookie_match: DashMap<u32, oneshot::Sender<SdlStunProbeReply>>, cookie_match: DashMap<u32, oneshot::Sender<SdlStunProbeReply>>,
packet_id_match: DashMap<u32, oneshot::Sender<RegisterSuperFeedback>>,
} }
unsafe impl Sync for Node {} unsafe impl Sync for Node {}
impl Node { impl Node {
pub fn send_register_super_feedback(&self, pktid: u32, feed: RegisterSuperFeedback) {
match self.packet_id_match.remove(&pktid) {
Some(sender) => {
let _ = sender.1.send(feed);
}
None => {
return;
}
}
}
pub fn get_nat_type(&self) -> NatType { pub fn get_nat_type(&self) -> NatType {
self.nat_type.lock().unwrap().clone() self.nat_type.lock().unwrap().clone()
} }
pub async fn start(&self, token: String) {
pub async fn start_without_feedback(&self, token: String) -> Result<()> {
*self._token.lock().unwrap() = token; *self._token.lock().unwrap() = token;
let _ = self.start_stop_sender.send(true).await; let _ = self.start_stop_sender.send(StartStopInfo{
is_start: true,
pkt_id: None,
}).await;
Ok(())
}
pub async fn start_with_feedback(&self, token: String, timeout: Duration) -> Result<RegisterSuperFeedback> {
*self._token.lock().unwrap() = token;
let (tx, mut rx) = oneshot::channel();
let id = self.get_next_packet_id();
self.packet_id_match.insert(id, tx);
let _ = self.start_stop_sender.send(StartStopInfo{
is_start: true,
pkt_id: Some(id),
}).await;
tokio::select! {
Ok(result) = rx => {
self.packet_id_match.remove(&id);
Ok(result)
}
_ = tokio::time::sleep(timeout) => {
Err(SDLanError::NormalError("timed out"))
}
}
} }
pub async fn stop(&self) { pub async fn stop(&self) {
*self._token.lock().unwrap() = "".to_owned(); *self._token.lock().unwrap() = "".to_owned();
let _ = self.start_stop_sender.send(false).await; let _ = self.start_stop_sender.send(StartStopInfo{
is_start: false,
pkt_id: None,
}).await;
} }
pub fn new( pub fn new(
@ -166,7 +220,7 @@ impl Node {
token: &str, token: &str,
private: RsaPrivateKey, private: RsaPrivateKey,
tcp_pong: Arc<AtomicU64>, tcp_pong: Arc<AtomicU64>,
start_stop: Sender<bool>, start_stop: Sender<StartStopInfo>,
) -> Self { ) -> Self {
Self { Self {
packet_id: AtomicU32::new(1), packet_id: AtomicU32::new(1),
@ -209,6 +263,7 @@ impl Node {
stats: NodeStats::new(), stats: NodeStats::new(),
_last_register_req: AtomicU64::new(0), _last_register_req: AtomicU64::new(0),
packet_id_match: DashMap::new(),
nat_cookie: AtomicU32::new(1), nat_cookie: AtomicU32::new(1),
cookie_match: DashMap::new(), cookie_match: DashMap::new(),
} }

View File

@ -46,6 +46,16 @@ pub enum NatType {
Symmetric = 5, Symmetric = 5,
} }
#[derive(Debug, Copy, Clone, TryFromPrimitive)]
#[repr(u8)]
pub enum NakMsgCode {
InvalidToken = 1,
NodeDisabled = 2,
NoIpAddress = 3,
NetworkFault = 4,
InternalFault = 5,
}
#[derive(Debug, Copy, Clone, TryFromPrimitive)] #[derive(Debug, Copy, Clone, TryFromPrimitive)]
#[repr(u8)] #[repr(u8)]
pub enum PacketType { pub enum PacketType {

View File

@ -19,6 +19,7 @@ use tokio::{
use tracing::error; use tracing::error;
use crate::config::TCP_PING_TIME; use crate::config::TCP_PING_TIME;
use crate::network::StartStopInfo;
use crate::tcp::read_a_packet; use crate::tcp::read_a_packet;
use super::tcp_codec::SdlanTcp; use super::tcp_codec::SdlanTcp;
@ -57,23 +58,25 @@ impl ReadWriteActor {
mut to_tcp: Receiver<Vec<u8>>, mut to_tcp: Receiver<Vec<u8>>,
on_connected: T, on_connected: T,
on_disconnected: T2, on_disconnected: T2,
mut start_stop_chan: Receiver<bool>, mut start_stop_chan: Receiver<StartStopInfo>,
// cancel: CancellationToken, // cancel: CancellationToken,
) where ) where
T: for<'b> Fn(&'b mut TcpStream) -> BoxFuture<'b, ()>, T: for<'b> Fn(&'b mut TcpStream, Option<u32>) -> BoxFuture<'b, ()>,
T2: Fn() -> F, T2: Fn() -> F,
F: Future<Output = ()>, F: Future<Output = ()>,
{ {
// let (tx, rx) = channel(20); // let (tx, rx) = channel(20);
let mut started = false; let mut started = false;
let mut start_pkt_id = None;
loop { loop {
self.connected.store(false, Ordering::Relaxed); self.connected.store(false, Ordering::Relaxed);
if !started { if !started {
// println!("waiting for start"); // println!("waiting for start");
while let Some(m) = start_stop_chan.recv().await { while let Some(m) = start_stop_chan.recv().await {
if m { if m.is_start {
// println!("start received"); // println!("start received");
started = true; started = true;
start_pkt_id = m.pkt_id;
break; break;
} else { } else {
// println!("stop received"); // println!("stop received");
@ -97,7 +100,7 @@ impl ReadWriteActor {
return; return;
}; };
self.connected.store(true, Ordering::Relaxed); self.connected.store(true, Ordering::Relaxed);
on_connected(&mut stream).await; on_connected(&mut stream, start_pkt_id.take()).await;
// stream.write("hello".as_bytes()).await; // stream.write("hello".as_bytes()).await;
let (reader, mut write) = stream.into_split(); let (reader, mut write) = stream.into_split();
@ -149,15 +152,29 @@ impl ReadWriteActor {
} }
}; };
let check_stop = async {
loop {
match start_stop_chan.recv().await {
Some(v) => {
if !v.is_start {
return;
}
}
other => {
// send chan is closed;
return;
}
}
}
};
pin_mut!(read_from_tcp, write_to_tcp); pin_mut!(read_from_tcp, write_to_tcp);
tokio::select! { tokio::select! {
_ = read_from_tcp => {}, _ = read_from_tcp => {},
_ = write_to_tcp => {}, _ = write_to_tcp => {},
_ = check_pong => {}, _ = check_pong => {},
Some(false) = start_stop_chan.recv() => { _ = check_stop => {},
started = false;
}
} }
on_disconnected().await; on_disconnected().await;
debug!("connect retrying"); debug!("connect retrying");
@ -196,11 +213,11 @@ impl ReadWriterHandle {
on_disconnected: T3, on_disconnected: T3,
on_message: T2, on_message: T2,
pong_time: Arc<AtomicU64>, pong_time: Arc<AtomicU64>,
start_stop_chan: Receiver<bool>, start_stop_chan: Receiver<StartStopInfo>,
// cancel: CancellationToken, // cancel: CancellationToken,
) -> Self ) -> Self
where where
T: for<'b> Fn(&'b mut TcpStream) -> BoxFuture<'b, ()> + Send + 'static, T: for<'b> Fn(&'b mut TcpStream, Option<u32>) -> BoxFuture<'b, ()> + Send + 'static,
T3: Fn() -> F2 + Send + 'static, T3: Fn() -> F2 + Send + 'static,
T2: Fn(SdlanTcp) -> F + Send + 'static, T2: Fn(SdlanTcp) -> F + Send + 'static,
F: Future<Output = ()> + Send, F: Future<Output = ()> + Send,
@ -242,9 +259,9 @@ pub fn init_tcp_conn<'a, T, T3, T2, F, F2>(
on_message: T2, on_message: T2,
pong_time: Arc<AtomicU64>, pong_time: Arc<AtomicU64>,
// cancel: CancellationToken, // cancel: CancellationToken,
start_stop_chan: Receiver<bool>, start_stop_chan: Receiver<StartStopInfo>,
) where ) where
T: for<'b> Fn(&'b mut TcpStream) -> BoxFuture<'b, ()> + Send + 'static, T: for<'b> Fn(&'b mut TcpStream, Option<u32>) -> BoxFuture<'b, ()> + Send + 'static,
T3: Fn() -> F2 + Send + 'static, T3: Fn() -> F2 + Send + 'static,
T2: Fn(SdlanTcp) -> F + Send + 'static, T2: Fn(SdlanTcp) -> F + Send + 'static,
F: Future<Output = ()> + Send, F: Future<Output = ()> + Send,