packet id match is merged to cookie_match

This commit is contained in:
alex 2026-03-26 19:50:15 +08:00
parent b10c721179
commit be401afc7b
5 changed files with 182 additions and 38 deletions

View File

@ -354,7 +354,7 @@ fn main() {
} }
} }
let should_daemonize = true; let should_daemonize = false;
#[cfg(not(target_os = "windows"))] #[cfg(not(target_os = "windows"))]
if should_daemonize { if should_daemonize {

View File

@ -250,7 +250,7 @@ async fn loop_tap(eee: &'static Node, cancel: CancellationToken) {
} }
} }
} }
debug!("loop_tap exited"); error!("loop_tap exited");
} }
async fn get_tun_flow(eee: &'static Node, tx: Sender<Vec<u8>>) { async fn get_tun_flow(eee: &'static Node, tx: Sender<Vec<u8>>) {

View File

@ -1,9 +1,12 @@
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
use dashmap::DashMap; use dashmap::DashMap;
use prost::Message;
use quinn::Endpoint; use quinn::Endpoint;
use rsa::RsaPrivateKey; use rsa::RsaPrivateKey;
use sdlan_sn_rs::config::{AF_INET, AF_INET6}; use sdlan_sn_rs::config::{AF_INET, AF_INET6};
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use std::any::Any;
use std::future::Future;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, AtomicU8, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, AtomicU8, Ordering};
use std::sync::{Arc, Mutex, RwLock}; use std::sync::{Arc, Mutex, RwLock};
@ -233,15 +236,20 @@ pub struct Node {
nat_type: Mutex<NatType>, nat_type: Mutex<NatType>,
nat_cookie: AtomicU32, nat_cookie: AtomicU32,
cookie_match: DashMap<u32, oneshot::Sender<SdlStunProbeReply>>, // cookie_match: DashMap<u32, oneshot::Sender<SdlStunProbeReply>>,
cookie_match: Queryer,
packet_id_match: DashMap<u32, oneshot::Sender<RegisterSuperFeedback>>, // packet_id_match: Queryer,
// packet_id_match: DashMap<u32, oneshot::Sender<RegisterSuperFeedback>>,
} }
unsafe impl Sync for Node {} unsafe impl Sync for Node {}
impl Node { impl Node {
pub fn send_register_super_feedback(&self, pktid: u32, feed: RegisterSuperFeedback) { pub fn send_register_super_feedback(&self, pktid: u32, feed: RegisterSuperFeedback) {
// self.packet_id_match.write_feedback(pktid, feed);
self.cookie_match.write_feedback(pktid, feed);
/*
match self.packet_id_match.remove(&pktid) { match self.packet_id_match.remove(&pktid) {
Some(sender) => { Some(sender) => {
let _ = sender.1.send(feed); let _ = sender.1.send(feed);
@ -250,6 +258,7 @@ impl Node {
return; return;
} }
} }
*/
} }
pub fn get_nat_type(&self) -> NatType { pub fn get_nat_type(&self) -> NatType {
@ -325,8 +334,30 @@ impl Node {
self.identity_id.store(identity_id); self.identity_id.store(identity_id);
// *self._token.lock().unwrap() = token; // *self._token.lock().unwrap() = token;
// *self.network_code.lock().unwrap() = network_code; // *self.network_code.lock().unwrap() = network_code;
let (tx, rx) = oneshot::channel();
let id = self.get_next_packet_id(); let id = self.get_next_packet_id();
// let result = self.packet_id_match.do_action_and_wait_for(
let result = self.cookie_match.do_action_and_wait_for(
id,
|| async {
let _ = self
.start_stop_sender
.send(StartStopInfo {
is_start: true,
pkt_id: Some(id),
})
.await;
debug!("start with feedback");
},
timeout
).await?;
if let Ok(data) = result.downcast() {
return Ok(*data);
}
Err(SDLanError::NormalError("timed out"))
/*
let (tx, rx) = oneshot::channel();
self.packet_id_match.insert(id, tx); self.packet_id_match.insert(id, tx);
let _ = self let _ = self
.start_stop_sender .start_stop_sender
@ -350,6 +381,7 @@ impl Node {
Err(SDLanError::NormalError("timed out")) Err(SDLanError::NormalError("timed out"))
} }
} }
*/
} }
pub async fn stop(&self) { pub async fn stop(&self) {
@ -453,9 +485,9 @@ impl Node {
stats: NodeStats::new(), stats: NodeStats::new(),
_last_register_req: AtomicU64::new(0), _last_register_req: AtomicU64::new(0),
packet_id_match: DashMap::new(), // packet_id_match: Queryer::new(),
nat_cookie: AtomicU32::new(1), nat_cookie: AtomicU32::new(1),
cookie_match: DashMap::new(), cookie_match: Queryer::new(),
server_ip, server_ip,
install_channel, install_channel,
} }
@ -537,11 +569,14 @@ impl Node {
*/ */
pub async fn send_nat_probe_reply(&self, cookie: u32, buf: SdlStunProbeReply) { pub async fn send_nat_probe_reply(&self, cookie: u32, buf: SdlStunProbeReply) {
self.cookie_match.write_feedback(cookie, buf);
/*
if let Some((_key, chan)) = self.cookie_match.remove(&cookie) { if let Some((_key, chan)) = self.cookie_match.remove(&cookie) {
let _ = chan.send(buf); let _ = chan.send(buf);
return; return;
} }
error!("failed to get such cookie stun probe"); */
// error!("failed to get such cookie stun probe");
} }
pub async fn probe_nat_type(&self) -> NatType { pub async fn probe_nat_type(&self) -> NatType {
@ -619,38 +654,30 @@ impl Node {
to_server: &SocketAddr, to_server: &SocketAddr,
) -> Result<SdlStunProbeReply> { ) -> Result<SdlStunProbeReply> {
let cookie = self.nat_cookie.fetch_add(1, Ordering::Relaxed); let cookie = self.nat_cookie.fetch_add(1, Ordering::Relaxed);
let probe = SdlStunProbe {
attr: msgattr as u32,
cookie,
step: 0,
};
// println!("==> sending probe request: {:?}", probe); // println!("==> sending probe request: {:?}", probe);
let (tx, rx) = oneshot::channel(); let res = self.cookie_match.do_action_and_wait_for(
self.cookie_match.insert(cookie, tx); cookie,
// let cookie = msg.cookie; || async {
let msg = encode_to_udp_message(Some(probe), PacketType::StunProbe as u8).unwrap(); let probe = SdlStunProbe {
if let Err(_e) = self.udp_sock_v4.send_to(&msg, to_server).await { attr: msgattr as u32,
self.cookie_match.remove(&cookie); cookie,
return Err(SDLanError::NormalError("send error")); step: 0,
};
let msg = encode_to_udp_message(Some(probe), PacketType::StunProbe as u8).unwrap();
if let Err(_e) = self.udp_sock_v4.send_to(&msg, to_server).await {
error!("failed to send StunProbe");
}
},
Duration::from_secs(3),
).await?;
if let Ok(data) = res.downcast() {
return Ok(*data);
} }
tokio::select! {
_ = tokio::time::sleep(Duration::from_secs(3)) => { Err(SDLanError::NormalError("reply recv error"))
self.cookie_match.remove(&cookie);
return Err(SDLanError::NormalError("timed out"));
}
reply = rx => {
self.cookie_match.remove(&cookie);
if let Ok(reply) = reply {
// reply received,
return Ok(reply);
// println!("got nat ip: {}:{}", ip_to_string(&reply.ip), reply.port);
}
return Err(SDLanError::NormalError("reply recv error"));
// step 1 received // step 1 received
}
}
} }
} }
@ -833,3 +860,110 @@ impl EdgePeer {
self.nat_type self.nat_type
} }
} }
type BoxedProstMessage = Box<dyn Any + Send + Sync + 'static>;
pub struct Queryer {
pub mailbox: DashMap<u32, tokio::sync::oneshot::Sender<BoxedProstMessage>>,
}
impl Queryer {
pub fn new() -> Self {
Self {
mailbox: DashMap::new(),
}
}
pub fn write_feedback<T: Any + Sync + Send + 'static>(&self, id: u32, data: T) {
if let Some((_, tx)) = self.mailbox.remove(&id) {
if let Err(e) = tx.send(Box::new(data)) {
error!("failed to write feedback");
}
}
}
pub async fn send_message_to_udp_and_wait_for<T: Message>(&self, sock: &Socket, id: u32, message: T, packet_type: u8, to_server: &SocketAddr, timeout: Duration) -> Result<BoxedProstMessage> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.mailbox.insert(id, tx);
let content = encode_to_udp_message(Some(message), packet_type)?;
if let Err(_e) = sock.send_to(&content, to_server).await {
self.mailbox.remove(&id);
return Err(SDLanError::NormalError("send error"));
}
let quic_conn = get_quic_write_conn();
quic_conn.send(content).await?;
tokio::select! {
data = rx => {
if let Ok(data) = data {
self.mailbox.remove(&id);
Ok(data)
} else {
self.mailbox.remove(&id);
Err(SDLanError::IOError("rx receive failed".to_string()))
}
}
_ = tokio::time::sleep(timeout) => {
self.mailbox.remove(&id);
Err(SDLanError::IOError("timed out".to_string()))
}
}
}
pub async fn do_action_and_wait_for<T, F>(&self, id: u32, action: T, timeout: Duration) -> Result<BoxedProstMessage>
where
F: Future<Output = ()>,
T: Fn() -> F,
{
let (tx, rx) = tokio::sync::oneshot::channel();
self.mailbox.insert(id, tx);
action().await;
tokio::select! {
data = rx => {
if let Ok(data) = data {
self.mailbox.remove(&id);
Ok(data)
} else {
self.mailbox.remove(&id);
Err(SDLanError::IOError("rx receive failed".to_string()))
}
}
_ = tokio::time::sleep(timeout) => {
self.mailbox.remove(&id);
Err(SDLanError::IOError("timed out".to_string()))
}
}
}
pub async fn send_message_to_quic_and_wait_for<T: Message>(&self, id: u32, message: T, packet_type: u8, timeout: Duration) -> Result<BoxedProstMessage> {
let (tx, rx) = tokio::sync::oneshot::channel();
self.mailbox.insert(id, tx);
let content = encode_to_tcp_message(Some(message), packet_type)?;
let quic_conn = get_quic_write_conn();
quic_conn.send(content).await?;
tokio::select! {
data = rx => {
if let Ok(data) = data {
self.mailbox.remove(&id);
Ok(data)
} else {
self.mailbox.remove(&id);
Err(SDLanError::IOError("rx receive failed".to_string()))
}
}
_ = tokio::time::sleep(timeout) => {
self.mailbox.remove(&id);
Err(SDLanError::IOError("timed out".to_string()))
}
}
}
}

View File

@ -1,13 +1,15 @@
use std::{collections::{HashMap}, net::Ipv4Addr, sync::{atomic::{AtomicBool, Ordering}}}; use std::{collections::HashMap, net::Ipv4Addr, sync::atomic::{AtomicBool, Ordering}, time::Duration};
use ahash::RandomState; use ahash::RandomState;
use dashmap::{DashMap}; use dashmap::{DashMap};
use ipnet::Ipv4Net; use ipnet::Ipv4Net;
use sdlan_sn_rs::utils::{Result, SDLanError}; use sdlan_sn_rs::utils::{Result, SDLanError};
use tokio::sync::oneshot::{Receiver, Sender, channel};
use tracing::{debug, error}; use tracing::{debug, error};
use crate::{RouteTableTrie, network::tun::add_route}; use crate::{RouteTableTrie, network::tun::add_route, pb::{SdlArpResponse, SdlStunReply}};
pub struct RouteTable2 { pub struct RouteTable2 {
pub cache_table: DashMap<(Ipv4Net, Ipv4Addr), AtomicBool, RandomState>, pub cache_table: DashMap<(Ipv4Net, Ipv4Addr), AtomicBool, RandomState>,

View File

@ -36,7 +36,7 @@ impl ReadWriterHandle {
Ok(()) Ok(())
} }
fn new<>( fn new(
cancel: CancellationToken, cancel: CancellationToken,
addr: &str, addr: &str,
// on_connected: OnConnectedCallback<'a>, // on_connected: OnConnectedCallback<'a>,
@ -78,8 +78,10 @@ impl ReadWriterHandle {
loop { loop {
if let Some(msg) = data_from_tcp.recv().await { if let Some(msg) = data_from_tcp.recv().await {
handle_tcp_message(msg).await; handle_tcp_message(msg).await;
println!("handle_tcp_message ok");
} else { } else {
error!("data from tcp exited"); error!("data from tcp exited");
println!("data from tcp exited");
// eprintln!("data from tcp exited"); // eprintln!("data from tcp exited");
return; return;
} }
@ -109,6 +111,7 @@ async fn handle_tcp_message(msg: SdlanTcp) {
// edge.tcp_pong.store(now, Ordering::Relaxed); // edge.tcp_pong.store(now, Ordering::Relaxed);
debug!("got tcp message: {:?}", msg.packet_type); debug!("got tcp message: {:?}", msg.packet_type);
println!("got tcp message: {:?}", msg.packet_type);
match msg.packet_type { match msg.packet_type {
PacketType::RegisterSuperACK => { PacketType::RegisterSuperACK => {
let Ok(ack) = SdlRegisterSuperAck::decode(&msg.current_packet[..]) else { let Ok(ack) = SdlRegisterSuperAck::decode(&msg.current_packet[..]) else {
@ -116,6 +119,7 @@ async fn handle_tcp_message(msg: SdlanTcp) {
return; return;
}; };
debug!("got register super ack1: {:?}", ack);
edge.send_register_super_feedback( edge.send_register_super_feedback(
ack.pkt_id, ack.pkt_id,
RegisterSuperFeedback { RegisterSuperFeedback {
@ -192,12 +196,14 @@ async fn handle_tcp_message(msg: SdlanTcp) {
debug!("nat type is {:?}", nattype); debug!("nat type is {:?}", nattype);
// println!("nat type is: {:?}", nattype); // println!("nat type is: {:?}", nattype);
}); });
println!("register message handled");
} }
PacketType::ArpResponse => { PacketType::ArpResponse => {
let Ok(resp) = SdlArpResponse::decode(&msg.current_packet[..]) else { let Ok(resp) = SdlArpResponse::decode(&msg.current_packet[..]) else {
error!("failed to decode ARP RESPONSE"); error!("failed to decode ARP RESPONSE");
return; return;
}; };
debug!("got arp response: {:?}", resp);
if resp.target_mac.len() != 6 { if resp.target_mac.len() != 6 {
// invalid target_mac // invalid target_mac
error!("invalid target_mac: {:?}, ip={}", resp.target_mac, ip_to_string(&resp.target_ip)); error!("invalid target_mac: {:?}, ip={}", resp.target_mac, ip_to_string(&resp.target_ip));
@ -518,6 +524,7 @@ impl ReadWriteActor {
} }
} else { } else {
// None, just return // None, just return
println!("start or stop is None");
return; return;
} }
} }
@ -625,6 +632,7 @@ impl ReadWriteActor {
let write_to_tcp = async { let write_to_tcp = async {
while let Some(data) = to_tcp.recv().await { while let Some(data) = to_tcp.recv().await {
debug!("data size = {}", data.len());
match send.write(&data).await { match send.write(&data).await {
Ok(size) => { Ok(size) => {
debug!("{} bytes sent to tcp", size); debug!("{} bytes sent to tcp", size);