support the feedback

This commit is contained in:
asxalex 2024-07-19 19:51:56 +08:00
parent 5408cfa286
commit bdde2cb99a
5 changed files with 155 additions and 29 deletions

View File

@ -34,7 +34,7 @@ async fn main() {
let _ = rx.recv();
let edge = get_edge();
edge.start(cmd.token).await;
let _ = edge.start_without_feedback(cmd.token).await;
/*
tokio::time::sleep(Duration::from_secs(20)).await;

View File

@ -3,15 +3,16 @@ use std::time::Duration;
use std::net::IpAddr;
use crate::config::TCP_PING_TIME;
use crate::network::{get_edge, ping_to_sn, read_and_parse_packet};
use crate::network::{get_edge, ping_to_sn, read_and_parse_packet, RegisterSuperFeedback};
use crate::pb::{
encode_to_tcp_message, encode_to_udp_message, SdlData, SdlDevAddr, SdlRegisterSuper,
SdlRegisterSuperAck, SdlRegisterSuperNak, SdlSendRegisterEvent, SdlStunRequest,
};
use crate::tcp::{init_tcp_conn, EventType, PacketType, SdlanTcp};
use crate::tcp::{init_tcp_conn, EventType, NakMsgCode, PacketType, SdlanTcp};
use crate::utils::{send_to_sock, CommandLine};
use etherparse::IpHeaders;
use sdlan_sn_rs::config::{AF_INET, SDLAN_DEFAULT_TTL};
use sdlan_sn_rs::packet::Register;
use sdlan_sn_rs::peer::SdlanSock;
use sdlan_sn_rs::utils::{
aes_encrypt, get_current_timestamp,
@ -22,7 +23,7 @@ use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio_util::sync::CancellationToken;
use super::{check_peer_registration_needed, packet, Node };
use super::{check_peer_registration_needed, packet, Node, StartStopInfo };
use crate::utils::Socket;
use prost::Message;
@ -37,6 +38,10 @@ async fn handle_tcp_message(msg: SdlanTcp) {
debug!("got tcp message: {:?}", msg.packet_type);
match msg.packet_type {
PacketType::RegisterSuperACK => {
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 0,
message: "".to_owned(),
});
let Ok(ack) = SdlRegisterSuperAck::decode(&msg.current_packet[..]) else {
error!("failed to decode REGISTER_SUPER_ACK");
return;
@ -74,12 +79,47 @@ async fn handle_tcp_message(msg: SdlanTcp) {
PacketType::RegisterSuperNAK => {
let Ok(_nak) = SdlRegisterSuperNak::decode(&msg.current_packet[..]) else {
error!("failed to decode REGISTER_SUPER_NAK");
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 1,
message: "failed to decode REGISTER SUPER NAK".to_owned(),
});
return;
};
error!(
"RegisterSuperNAK received: {}({})",
_nak.error_code, _nak.error_message
);
let Ok(error_code) = NakMsgCode::try_from(_nak.error_code as u8) else {
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 2,
message: "error_code not recognized".to_owned(),
});
return;
};
match error_code {
NakMsgCode::InvalidToken => {
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 3,
message: "invalid token".to_owned(),
});
}
NakMsgCode::NodeDisabled => {
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 4,
message: "Node is disabled".to_owned(),
});
}
_other => {
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 0,
message: "".to_owned(),
});
}
}
/*
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 1,
message: "failed to decode REGISTER SUPER NAK".to_owned(),
});
*/
edge.set_authorized(false, Vec::new());
// std::process::exit(0);
}
@ -148,7 +188,7 @@ async fn handle_tcp_event(edge: &Node, eventtype: EventType, eventprotobuf: &[u8
pub async fn async_main(
install_channel: String,
args: CommandLine,
start_stop_chan: Receiver<bool>,
start_stop_chan: Receiver<StartStopInfo>,
cancel: CancellationToken,
) -> Result<()> {
// let _ = PidRecorder::new(".pid");
@ -185,9 +225,9 @@ pub async fn async_main(
// let token = args.token.clone();
init_tcp_conn(
&args.tcp,
move |stream| {
move |stream, pkt_id| {
let installed_channel = install_channel.to_owned();
Box::pin(async {
Box::pin(async move {
let token = edge._token.lock().unwrap().clone();
// let edge = get_edge();
// let edge = get_edge();
@ -215,7 +255,11 @@ pub async fn async_main(
token,
};
// debug!("send register super: {:?}", register_super);
let packet_id = edge.get_next_packet_id();
let packet_id = match pkt_id {
Some(id) => id,
None => edge.get_next_packet_id(),
};
// let packet_id = edge.get_next_packet_id();
let data = encode_to_tcp_message(
Some(register_super),
packet_id,

View File

@ -33,7 +33,7 @@ pub async fn init_edge(
token: &str,
node_conf: NodeConfig,
tos: u32,
start_stop: Sender<bool>,
start_stop: Sender<StartStopInfo>,
) -> Result<()> {
let _ = PidRecorder::new(".pid");
@ -89,6 +89,17 @@ pub fn get_edge() -> &'static Node {
edge.unwrap()
}
pub struct RegisterSuperFeedback {
// 0 for success, other for error
pub result: u8,
pub message: String,
}
pub struct StartStopInfo {
pub is_start: bool,
pub pkt_id: Option<u32>,
}
pub struct Node {
packet_id: AtomicU32,
@ -96,7 +107,7 @@ pub struct Node {
pub tcp_pong: Arc<AtomicU64>,
start_stop_sender: Sender<bool>,
start_stop_sender: Sender<StartStopInfo>,
// user token info
pub _token: Mutex<String>,
@ -139,22 +150,65 @@ pub struct Node {
nat_cookie: AtomicU32,
cookie_match: DashMap<u32, oneshot::Sender<SdlStunProbeReply>>,
packet_id_match: DashMap<u32, oneshot::Sender<RegisterSuperFeedback>>,
}
unsafe impl Sync for Node {}
impl Node {
pub fn send_register_super_feedback(&self, pktid: u32, feed: RegisterSuperFeedback) {
match self.packet_id_match.remove(&pktid) {
Some(sender) => {
let _ = sender.1.send(feed);
}
None => {
return;
}
}
}
pub fn get_nat_type(&self) -> NatType {
self.nat_type.lock().unwrap().clone()
}
pub async fn start(&self, token: String) {
pub async fn start_without_feedback(&self, token: String) -> Result<()> {
*self._token.lock().unwrap() = token;
let _ = self.start_stop_sender.send(true).await;
let _ = self.start_stop_sender.send(StartStopInfo{
is_start: true,
pkt_id: None,
}).await;
Ok(())
}
pub async fn start_with_feedback(&self, token: String, timeout: Duration) -> Result<RegisterSuperFeedback> {
*self._token.lock().unwrap() = token;
let (tx, mut rx) = oneshot::channel();
let id = self.get_next_packet_id();
self.packet_id_match.insert(id, tx);
let _ = self.start_stop_sender.send(StartStopInfo{
is_start: true,
pkt_id: Some(id),
}).await;
tokio::select! {
Ok(result) = rx => {
self.packet_id_match.remove(&id);
Ok(result)
}
_ = tokio::time::sleep(timeout) => {
Err(SDLanError::NormalError("timed out"))
}
}
}
pub async fn stop(&self) {
*self._token.lock().unwrap() = "".to_owned();
let _ = self.start_stop_sender.send(false).await;
let _ = self.start_stop_sender.send(StartStopInfo{
is_start: false,
pkt_id: None,
}).await;
}
pub fn new(
@ -166,7 +220,7 @@ impl Node {
token: &str,
private: RsaPrivateKey,
tcp_pong: Arc<AtomicU64>,
start_stop: Sender<bool>,
start_stop: Sender<StartStopInfo>,
) -> Self {
Self {
packet_id: AtomicU32::new(1),
@ -209,6 +263,7 @@ impl Node {
stats: NodeStats::new(),
_last_register_req: AtomicU64::new(0),
packet_id_match: DashMap::new(),
nat_cookie: AtomicU32::new(1),
cookie_match: DashMap::new(),
}

View File

@ -46,6 +46,16 @@ pub enum NatType {
Symmetric = 5,
}
#[derive(Debug, Copy, Clone, TryFromPrimitive)]
#[repr(u8)]
pub enum NakMsgCode {
InvalidToken = 1,
NodeDisabled = 2,
NoIpAddress = 3,
NetworkFault = 4,
InternalFault = 5,
}
#[derive(Debug, Copy, Clone, TryFromPrimitive)]
#[repr(u8)]
pub enum PacketType {

View File

@ -19,6 +19,7 @@ use tokio::{
use tracing::error;
use crate::config::TCP_PING_TIME;
use crate::network::StartStopInfo;
use crate::tcp::read_a_packet;
use super::tcp_codec::SdlanTcp;
@ -57,23 +58,25 @@ impl ReadWriteActor {
mut to_tcp: Receiver<Vec<u8>>,
on_connected: T,
on_disconnected: T2,
mut start_stop_chan: Receiver<bool>,
mut start_stop_chan: Receiver<StartStopInfo>,
// cancel: CancellationToken,
) where
T: for<'b> Fn(&'b mut TcpStream) -> BoxFuture<'b, ()>,
T: for<'b> Fn(&'b mut TcpStream, Option<u32>) -> BoxFuture<'b, ()>,
T2: Fn() -> F,
F: Future<Output = ()>,
{
// let (tx, rx) = channel(20);
let mut started = false;
let mut start_pkt_id = None;
loop {
self.connected.store(false, Ordering::Relaxed);
if !started {
// println!("waiting for start");
while let Some(m) = start_stop_chan.recv().await {
if m {
if m.is_start {
// println!("start received");
started = true;
start_pkt_id = m.pkt_id;
break;
} else {
// println!("stop received");
@ -97,7 +100,7 @@ impl ReadWriteActor {
return;
};
self.connected.store(true, Ordering::Relaxed);
on_connected(&mut stream).await;
on_connected(&mut stream, start_pkt_id.take()).await;
// stream.write("hello".as_bytes()).await;
let (reader, mut write) = stream.into_split();
@ -149,15 +152,29 @@ impl ReadWriteActor {
}
};
let check_stop = async {
loop {
match start_stop_chan.recv().await {
Some(v) => {
if !v.is_start {
return;
}
}
other => {
// send chan is closed;
return;
}
}
}
};
pin_mut!(read_from_tcp, write_to_tcp);
tokio::select! {
_ = read_from_tcp => {},
_ = write_to_tcp => {},
_ = check_pong => {},
Some(false) = start_stop_chan.recv() => {
started = false;
}
_ = check_stop => {},
}
on_disconnected().await;
debug!("connect retrying");
@ -196,11 +213,11 @@ impl ReadWriterHandle {
on_disconnected: T3,
on_message: T2,
pong_time: Arc<AtomicU64>,
start_stop_chan: Receiver<bool>,
start_stop_chan: Receiver<StartStopInfo>,
// cancel: CancellationToken,
) -> Self
where
T: for<'b> Fn(&'b mut TcpStream) -> BoxFuture<'b, ()> + Send + 'static,
T: for<'b> Fn(&'b mut TcpStream, Option<u32>) -> BoxFuture<'b, ()> + Send + 'static,
T3: Fn() -> F2 + Send + 'static,
T2: Fn(SdlanTcp) -> F + Send + 'static,
F: Future<Output = ()> + Send,
@ -242,9 +259,9 @@ pub fn init_tcp_conn<'a, T, T3, T2, F, F2>(
on_message: T2,
pong_time: Arc<AtomicU64>,
// cancel: CancellationToken,
start_stop_chan: Receiver<bool>,
start_stop_chan: Receiver<StartStopInfo>,
) where
T: for<'b> Fn(&'b mut TcpStream) -> BoxFuture<'b, ()> + Send + 'static,
T: for<'b> Fn(&'b mut TcpStream, Option<u32>) -> BoxFuture<'b, ()> + Send + 'static,
T3: Fn() -> F2 + Send + 'static,
T2: Fn(SdlanTcp) -> F + Send + 'static,
F: Future<Output = ()> + Send,