use std::{net::IpAddr, sync::{Arc, OnceLock, atomic::{AtomicBool, AtomicU64, Ordering}}, time::Duration}; use futures_util::pin_mut; use prost::Message; use quinn::SendStream; use sdlan_sn_rs::{config::AF_INET, peer::{SdlanSock, V6Info}, utils::{Result, SDLanError, get_current_timestamp, ip_to_string, rsa_decrypt}}; use tokio::{sync::mpsc::{Receiver, Sender, channel}, time::sleep}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, warn}; use crate::{AesEncryptor, Chacha20Encryptor, ConnectionInfo, ConnectionState, MyEncryptor, RuleFromServer, config::{NULL_MAC, TCP_PING_TIME}, get_edge, network::{ARP_REPLY, ArpHdr, EthHdr, Node, RegisterSuperFeedback, StartStopInfo, check_peer_registration_needed, handle_packet_peer_info}, pb::{SdlArpResponse, SdlPolicyResponse, SdlRegisterSuper, SdlRegisterSuperAck, SdlRegisterSuperNak, SdlSendRegisterEvent, encode_to_tcp_message}, tcp::{EventType, NakMsgCode, NatType, PacketType, SdlanTcp, read_a_packet, send_stun_request}}; static GLOBAL_QUIC_HANDLE: OnceLock = OnceLock::new(); #[derive(Debug)] pub struct ReadWriterHandle { connected: Arc, send_to_tcp: Sender>, // pub data_from_tcp: Receiver, } impl ReadWriterHandle { pub async fn send(&self, data: Vec) -> Result<()> { if self.connected.load(Ordering::Relaxed) { // connected, send to it if let Err(e) = self.send_to_tcp.send(data).await { error!("failed to send to send_to_tcp: {}", e.to_string()); return Err(SDLanError::NormalError("failed to send")); }; debug!("tcp info sent"); } else { error!("tcp not connected, so not sending data"); return Err(SDLanError::NormalError("not connected, so not sending")); } Ok(()) } fn new<>( cancel: CancellationToken, addr: &str, // on_connected: OnConnectedCallback<'a>, // on_disconnected: T3, // on_message: T2, pong_time: Arc, start_stop_chan: Receiver, // cancel: CancellationToken, connecting_chan: Option>, ipv6_network_restarter: Option>, ) -> Self { let (send_to_tcp, to_tcp) = channel(20); let (from_tcp, mut data_from_tcp) = channel(20); let connected: Arc = Arc::new(AtomicBool::new(false)); let actor = ReadWriteActor::new( cancel, addr, from_tcp, connected.clone(), pong_time, connecting_chan, ipv6_network_restarter, ); tokio::spawn(async move { actor .run( true, to_tcp, // on_connected, // on_disconnected, start_stop_chan ) .await }); tokio::spawn(async move { loop { if let Some(msg) = data_from_tcp.recv().await { handle_tcp_message(msg).await; } else { error!("data from tcp exited"); // eprintln!("data from tcp exited"); return; } } }); ReadWriterHandle { connected, send_to_tcp, // data_from_tcp, } } } pub fn get_quic_write_conn() -> &'static ReadWriterHandle { match GLOBAL_QUIC_HANDLE.get() { Some(v) => v, None => panic!("should call init_tcp_conn first"), } } async fn handle_tcp_message(msg: SdlanTcp) { let edge = get_edge(); // let now = get_current_timestamp(); // edge.tcp_pong.store(now, Ordering::Relaxed); debug!("got tcp message: {:?}", msg.packet_type); match msg.packet_type { PacketType::RegisterSuperACK => { let Ok(ack) = SdlRegisterSuperAck::decode(&msg.current_packet[..]) else { error!("failed to decode REGISTER_SUPER_ACK"); return; }; edge.send_register_super_feedback( ack.pkt_id, RegisterSuperFeedback { result: 0, message: "".to_owned(), should_exit: false, }, ); debug!("got register super ack: {:?}", ack); edge.session_token.set(ack.session_token); let Ok(key) = rsa_decrypt(&edge.rsa_private, &ack.key) else { error!("failed to rsa decrypt aes key"); return; }; match ack.algorithm.to_ascii_lowercase().as_str() { "chacha20" => { edge.encryptor.store(Arc::new(MyEncryptor::ChaChao20(Chacha20Encryptor::new(key, ack.region_id)))) // *edge.encryptor.write().unwrap() = MyEncryptor::ChaChao20(Chacha20Encryptor::new(key, ack.region_id)); } "aes" => { edge.encryptor.store(Arc::new(MyEncryptor::Aes(AesEncryptor::new(key)))) // *edge.encryptor.write().unwrap() = MyEncryptor::Aes(AesEncryptor::new(key)); } _other => { } } /* let Some(dev) = ack.dev_addr else { error!("no dev_addr is specified"); return; }; */ let ip = ip_to_string(&edge.device_config.get_ip()); // debug!("aes key is {:?}, ip is {}/{}", aes, ip, dev.net_bit_len,); println!("assigned ip: {}", ip); // let hostname = edge.hostname.read().unwrap().clone(); // println!("network is: {}.{}", hostname, dev.network_domain); /* edge.device_config .ip .net_addr .store(dev.net_addr, Ordering::Relaxed); */ if let Some(ref chan) = edge.connection_chan { let _ = chan.send(ConnectionInfo::IPInfo(ip)).await; } /* let mac = match dev.mac.try_into() { Err(_) => NULL_MAC, Ok(m) => m, }; */ // *edge.device_config.mac.write().unwrap() = mac; /* edge.device_config .ip .net_bit_len .store(dev.net_bit_len as u8, Ordering::Relaxed); edge.network_id.store(dev.network_id, Ordering::Relaxed); */ // edge.device.reload_config(&edge.device_config, &dev.network_domain); edge.device.reload_config(&edge.device_config, &edge.network_domain.read().unwrap().clone()); edge.set_authorized(true); send_stun_request(edge).await; tokio::spawn(async { let nattype = edge.probe_nat_type().await; debug!("nat type is {:?}", nattype); // println!("nat type is: {:?}", nattype); }); } PacketType::ArpResponse => { let Ok(resp) = SdlArpResponse::decode(&msg.current_packet[..]) else { error!("failed to decode ARP RESPONSE"); return; }; if resp.target_mac.len() != 6 { // invalid target_mac error!("invalid target_mac: {:?}, ip={}", resp.target_mac, ip_to_string(&resp.target_ip)); return; } // TODO: construct the arp reply, and write to tun; let src_mac = resp.target_mac.try_into().unwrap(); let dst_mac = edge.device_config.get_mac(); let dst_ip = edge.device_config.get_ip(); let hdr = ArpHdr{ ethhdr: EthHdr { dest: dst_mac, src: src_mac, eth_type: 0x0806, }, hwtype: 0x0001, protocol: 0x0800, hwlen: 6, protolen: 4, opcode: ARP_REPLY, shwaddr: src_mac, sipaddr: [((resp.target_ip >> 16) as u16) & 0xffff, (resp.target_ip as u16) & 0xffff], dhwaddr: dst_mac, dipaddr: [((dst_ip >> 16) & 0x0000ffff) as u16, (dst_ip & 0x0000ffff) as u16] }; let data = hdr.marshal_to_bytes(); if let Err(_e) = edge.device.send(&data) { error!("failed to write arp response to device"); } } PacketType::PolicyReply => { let Ok(policy) = SdlPolicyResponse::decode(&msg.current_packet[..]) else { error!("failed to decode POLICY RESPONSE"); return; }; let identity = policy.src_identity_id; let mut infos = Vec::new(); let mut start = 0; while start < policy.rules.len() { if start + 3 > policy.rules.len() { break; } let proto = policy.rules[start]; let port = u16::from_be_bytes([policy.rules[start+1], policy.rules[start+2]]); start += 3; infos.push(RuleFromServer{ proto, port, }); } edge.rule_cache.set_identity_cache(identity, infos); } PacketType::RegisterSuperNAK => { let Ok(_nak) = SdlRegisterSuperNak::decode(&msg.current_packet[..]) else { error!("failed to decode REGISTER_SUPER_NAK"); /* edge.send_register_super_feedback( msg._packet_id, RegisterSuperFeedback { result: 1, message: "failed to decode REGISTER SUPER NAK".to_owned(), should_exit: false, }, ); */ return; }; error!("got nak: {:?}", _nak); let pkt_id = _nak.pkt_id; let Ok(error_code) = NakMsgCode::try_from(_nak.error_code as u8) else { edge.send_register_super_feedback( //msg._packet_id, pkt_id, RegisterSuperFeedback { result: 2, message: "error_code not recognized".to_owned(), should_exit: false, }, ); return; }; match error_code { NakMsgCode::InvalidToken => { edge.send_register_super_feedback( // msg._packet_id, pkt_id, RegisterSuperFeedback { result: 3, message: "invalid token".to_owned(), should_exit: true, }, ); edge.stop().await; } NakMsgCode::NodeDisabled => { edge.send_register_super_feedback( // msg._packet_id, pkt_id, RegisterSuperFeedback { result: 4, message: "Node is disabled".to_owned(), should_exit: true, }, ); edge.stop().await; } _other => { edge.send_register_super_feedback( // msg._packet_id, pkt_id, RegisterSuperFeedback { result: 0, message: "".to_owned(), should_exit: false, }, ); } } /* edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback { result: 1, message: "failed to decode REGISTER SUPER NAK".to_owned(), }); */ edge.set_authorized(false); edge.encryptor.store(Arc::new(MyEncryptor::Invalid)); // *edge.encryptor.write().unwrap() = MyEncryptor::Invalid; // std::process::exit(0); } PacketType::Command => { if msg.current_packet.len() < 1 { error!("malformed COMMAND received"); return; } handle_tcp_command(edge, msg.current_packet[0], &msg.current_packet[1..]).await; } PacketType::Event => { if msg.current_packet.len() < 1 { error!("malformed EVENT received"); return; } let Ok(event) = msg.current_packet[0].try_into() else { error!("failed to parse event type"); return; }; handle_tcp_event(edge, event, &msg.current_packet[1..]).await; } PacketType::PeerInfo => { let _ = handle_packet_peer_info(edge, &msg.current_packet[..]).await; } PacketType::Pong => { debug!("tcp pong received"); let now = get_current_timestamp(); edge.tcp_pong.store(now, Ordering::Relaxed); } other => { debug!("tcp not handling {:?}", other); } } } async fn handle_tcp_command(_edge: &Node, _cmdtype: u8, _cmdprotobuf: &[u8]) {} async fn handle_tcp_event(edge: &'static Node, eventtype: EventType, eventprotobuf: &[u8]) { match eventtype { EventType::SendRegister => { let Ok(reg) = SdlSendRegisterEvent::decode(eventprotobuf) else { error!("failed to decode SendRegister Event"); return; }; let v4 = reg.nat_ip.to_be_bytes(); let mut v6_sock = None; if let Some(v6_info) = reg.v6_info { if let Ok(v6_bytes) = v6_info.v6.try_into() { v6_sock = Some(V6Info { port: v6_info.port as u16, v6: v6_bytes, }); } } let dst_mac = match reg.dst_mac.try_into() { Ok(m) => m, Err(_e) => NULL_MAC, }; let remote_nat_byte = reg.nat_type as u8; let remote_nat = match remote_nat_byte.try_into() { Ok(t) => t, Err(_) => NatType::NoNat, }; check_peer_registration_needed( edge, false, dst_mac, // &v6_sock, remote_nat, &v6_sock, &SdlanSock { family: AF_INET, port: reg.nat_port as u16, v4, v6: [0; 16], }, ) .await; } other => { debug!("unhandled event {:?}", other); } } } pub fn init_quic_conn( cancel: CancellationToken, addr: &str, // on_connected: OnConnectedCallback<'a>, // on_disconnected: T3, // on_message: T2, pong_time: Arc, // cancel: CancellationToken, start_stop_chan: Receiver, connecting_chan: Option>, ipv6_network_restarter: Option>, ) // T2: Fn(SdlanTcp) -> F + Send + 'static, // F: Future + Send, { let tcp_handle = ReadWriterHandle::new( cancel, addr, // on_connected, // on_disconnected, // on_message, pong_time, start_stop_chan, connecting_chan, ipv6_network_restarter, ); GLOBAL_QUIC_HANDLE .set(tcp_handle) .expect("failed to set global tcp handle"); } pub struct ReadWriteActor { // actor接收的发送给tcp的接收端,由handle存放发送端 // to_tcp: Receiver>, remote: String, connected: Arc, pong_time: Arc, // actor收到数据之后,发送给上层的发送端口,接收端由handle保存 from_tcp: Sender, _cancel: CancellationToken, connecting_chan: Option>, ipv6_network_restarter: Option>, } impl ReadWriteActor { pub fn new( cancel: CancellationToken, remote: &str, from_tcp: Sender, connected: Arc, pong_time: Arc, connecting_chan: Option>, ipv6_network_restarter: Option>, ) -> Self { Self { // to_tcp, _cancel: cancel, pong_time, connected, remote: remote.to_owned(), from_tcp, connecting_chan, ipv6_network_restarter, } } pub async fn run<'a>( &self, keep_reconnect: bool, mut to_tcp: Receiver>, mut start_stop_chan: Receiver, ) { let edge = get_edge(); let mut started = false; let mut start_pkt_id = None; loop { if let Some(ref connecting_chan) = self.connecting_chan { let state = ConnectionInfo::ConnState(ConnectionState::NotConnected); let _ = connecting_chan.send(state).await; } self.connected.store(false, Ordering::Relaxed); if !started { // println!("waiting for start"); loop { let start_or_stop = start_stop_chan.recv().await; if let Some(m) = start_or_stop { if m.is_start { started = true; start_pkt_id = m.pkt_id; break; } } else { // None, just return return; } } /* while let Some(m) = start_stop_chan.recv().await { println!("4"); if m.is_start { // println!("start received"); started = true; start_pkt_id = m.pkt_id; break; } else { // println!("stop received"); } } */ debug!("start stop chan received: {}", started); continue; } if let Some(ref connecting_chan) = self.connecting_chan { let state = ConnectionInfo::ConnState(ConnectionState::Connecting); let _ = connecting_chan.send(state).await; } let host = self.remote.split(":").next().unwrap(); debug!("try connecting to {}, host = {}", self.remote, host); let conn = match edge.quic_endpoint.connect(self.remote.parse().unwrap(), host) { Ok(conn) => conn, Err(e) => { error!("failed to connect: {}", e); // println!("failed to connect: {}", e); self.connected.store(false, Ordering::Relaxed); if keep_reconnect { tokio::time::sleep(Duration::from_secs(3)).await; continue; } return; } }; let conn = match conn.await { Err(e) => { // println!("failed to connect await: {}", e); error!("failed to connect await: {}", e); self.connected.store(false, Ordering::Relaxed); if keep_reconnect { tokio::time::sleep(Duration::from_secs(3)).await; continue; } return; } Ok(conn) => conn, }; let local_ip = conn.local_ip(); let Ok((mut send, mut recv)) = conn.open_bi().await else { println!("failed to open-bi"); self.connected.store(false, Ordering::Relaxed); if keep_reconnect { tokio::time::sleep(Duration::from_secs(3)).await; continue; } return; }; self.connected.store(true, Ordering::Relaxed); debug!("connected"); sleep(Duration::from_millis(200)).await; on_connected_callback(local_ip, &mut send, start_pkt_id.take()).await; if let Some(ref connecting_chan) = self.connecting_chan { let state = ConnectionInfo::ConnState(ConnectionState::Connected); let _ = connecting_chan.send(state).await; } if let Some(ref ipv6_restarter) = self.ipv6_network_restarter { let _ = ipv6_restarter.send(true).await; } // stream.write("hello".as_bytes()).await; // let (reader, mut write) = stream.into_split(); let read_from_tcp = async move { // let mut buffed_reader = BufReader::new(recv); loop { match read_a_packet(&mut recv).await { Ok(packet) => { warn!("got packet: {:?}", packet); if let Err(_e) = self.from_tcp.send(packet).await { error!("failed to receive a packet: {:?}", _e); } } Err(e) => { error!("failed to read a packet: {}, reconnecting...", e); return; } } } }; let write_to_tcp = async { while let Some(data) = to_tcp.recv().await { match send.write(&data).await { Ok(size) => { debug!("{} bytes sent to tcp", size); } Err(e) => { error!("failed to write to tcp: {}", e.to_string()); return; } } } error!("to_tcp recv None"); }; let check_pong = async { loop { tokio::time::sleep(Duration::from_secs(3600)).await; let connected = self.connected.load(Ordering::Relaxed); let now = get_current_timestamp(); if connected && now - self.pong_time.load(Ordering::Relaxed) > TCP_PING_TIME * 2 { // pong time expire, need to re-connect error!("pong check expired"); return; } } }; let check_stop = async { loop { match start_stop_chan.recv().await { Some(v) => { if !v.is_start { started = false; return; } } _other => { // send chan is closed; started = false; return; } } } }; pin_mut!(read_from_tcp, write_to_tcp); tokio::select! { _ = read_from_tcp => {}, _ = write_to_tcp => {}, _ = check_pong => {}, _ = check_stop => {}, } on_disconnected_callback().await; conn.close(0u32.into(), "close".as_bytes()); debug!("connect retrying"); tokio::time::sleep(Duration::from_secs(1)).await; debug!("disconnected"); // future::select(read_from_tcp, write_to_tcp).await; } } } async fn on_disconnected_callback() { let edge = get_edge(); edge.set_authorized(false); edge.encryptor.store(Arc::new(MyEncryptor::Invalid)); } async fn on_connected_callback(local_ip: Option, stream: &mut SendStream, _pkt_id: Option) { let edge = get_edge(); // let installed_channel = install_channel.to_owned(); // let token = edge._token.lock().unwrap().clone(); // let code = edge.network_code.lock().unwrap().clone(); // let edge = get_edge(); // let edge = get_edge(); // let token = args.token.clone(); if let Some(ipaddr) = local_ip { match ipaddr { IpAddr::V4(v4) => { let ip = v4.into(); // println!("outer ip is {} => {}", v4, ip); edge.outer_ip_v4.store(ip, Ordering::Relaxed); } _other => {} } } let register_super = SdlRegisterSuper { mac: Vec::from(edge.device_config.get_mac()), pkt_id: edge.get_next_packet_id(), network_id: edge.network_id.load(Ordering::Relaxed), ip: edge.device_config.get_ip(), mask_len: edge.device_config.get_net_bit() as u32, access_token: edge.access_token.get(), // installed_channel, client_id: edge.config.node_uuid.clone(), pub_key: edge.rsa_pubkey.clone(), hostname: edge.hostname.read().unwrap().clone(), }; println!("register super: {:?}", register_super); // debug!("send register super: {:?}", register_super); // let packet_id = edge.get_next_packet_id(); let data = encode_to_tcp_message( Some(register_super), PacketType::RegisterSuper as u8, ) .unwrap(); if let Err(e) = stream.write(&data).await { error!("failed to write to tcp: {}", e.to_string()); } }