sdlan-lib-rs/src/tcp/quic.rs

775 lines
26 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::{net::IpAddr, sync::{Arc, OnceLock, atomic::{AtomicBool, AtomicU64, Ordering}}, time::Duration};
use futures_util::pin_mut;
use prost::Message;
use quinn::SendStream;
use sdlan_sn_rs::{config::AF_INET, peer::{SdlanSock, V6Info}, utils::{Result, SDLanError, get_current_timestamp, ip_to_string, rsa_decrypt}};
use tokio::{sync::mpsc::{Receiver, Sender, channel}, time::sleep};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, warn, info};
#[cfg(target_os = "linux")]
use crate::network::{set_allow_routing, set_disallow_routing};
use crate::{AesEncryptor, Chacha20Encryptor, ConnectionInfo, ConnectionState, MyEncryptor, RuleFromServer, config::{NULL_MAC, TCP_PING_TIME}, get_edge, load_configuration, network::{ARP_REPLY, ArpHdr, EthHdr, Node, RegisterSuperFeedback, StartStopInfo, arp_reply_arrived, check_peer_registration_needed, handle_packet_peer_info}, pb::{SdlArpResponse, SdlCommand, SdlCommandAck, SdlEvent, SdlPolicyResponse, SdlRegisterSuper, SdlRegisterSuperAck, SdlRegisterSuperNak, encode_to_tcp_message, sdl_command, sdl_event::{self, Event, SendRegister}}, store_configuration, tcp::{EventType, NakMsgCode, NatType, PacketType, SdlanTcp, read_a_packet, send_stun_request}};
/// Process-wide handle to the QUIC control connection.
///
/// Set exactly once by `init_quic_conn` and read through `get_quic_write_conn`.
static GLOBAL_QUIC_HANDLE: OnceLock<ReadWriterHandle> = OnceLock::new();
/// Handle for queueing outbound packets onto the QUIC control stream.
///
/// Packets are passed to the background `ReadWriteActor` through a bounded
/// channel created in `ReadWriterHandle::new`; `connected` mirrors the actor's
/// view of the link so that `send` can refuse data while disconnected.
#[derive(Debug)]
pub struct ReadWriterHandle {
    /// Shared with the actor, which sets it to `true` once a stream is open and
    /// back to `false` when (re)connecting.
    connected: Arc<AtomicBool>,
    /// Sending half of the actor's outbound queue.
    send_to_tcp: Sender<Vec<u8>>,
    // pub data_from_tcp: Receiver<SdlanTcp>,
}
impl ReadWriterHandle {
pub async fn send(&self, data: Vec<u8>) -> Result<()> {
if self.connected.load(Ordering::Relaxed) {
// connected, send to it
if let Err(e) = self.send_to_tcp.send(data).await {
error!("failed to send to send_to_tcp: {}", e.to_string());
return Err(SDLanError::NormalError("failed to send"));
};
// debug!("tcp info sent");
} else {
error!("tcp not connected, so not sending data");
return Err(SDLanError::NormalError("not connected, so not sending"));
}
Ok(())
}
fn new(
cancel: CancellationToken,
addr: &str,
domain: &str,
// on_connected: OnConnectedCallback<'a>,
// on_disconnected: T3,
// on_message: T2,
pong_time: Arc<AtomicU64>,
start_stop_chan: Receiver<StartStopInfo>,
// cancel: CancellationToken,
connecting_chan: Option<Sender<ConnectionInfo>>,
ipv6_network_restarter: Option<Sender<bool>>,
) -> Self {
let (send_to_tcp, to_tcp) = channel(20);
let (from_tcp, mut data_from_tcp) = channel(20);
let connected: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
let actor = ReadWriteActor::new(
cancel,
addr,
domain,
from_tcp,
connected.clone(),
pong_time,
connecting_chan,
ipv6_network_restarter,
);
tokio::spawn(async move {
actor
.run(
true,
to_tcp,
// on_connected,
// on_disconnected,
start_stop_chan
)
.await
});
tokio::spawn(async move {
loop {
if let Some(msg) = data_from_tcp.recv().await {
handle_tcp_message(msg).await;
} else {
error!("data from tcp exited");
// println!("data from tcp exited");
// eprintln!("data from tcp exited");
return;
}
}
});
ReadWriterHandle {
connected,
send_to_tcp,
// data_from_tcp,
}
}
}
/// Returns the process-wide QUIC write handle.
///
/// # Panics
///
/// Panics if `init_quic_conn` has not been called yet (the handle is only ever
/// stored there).
pub fn get_quic_write_conn() -> &'static ReadWriterHandle {
    GLOBAL_QUIC_HANDLE
        .get()
        .expect("should call init_quic_conn first")
}
async fn handle_tcp_message(msg: SdlanTcp) {
let edge = get_edge();
// let now = get_current_timestamp();
// edge.tcp_pong.store(now, Ordering::Relaxed);
debug!("handling tcp message: {:?}", msg.packet_type);
match msg.packet_type {
PacketType::RegisterSuperACK => {
let Ok(ack) = SdlRegisterSuperAck::decode(&msg.current_packet[..]) else {
error!("failed to decode REGISTER_SUPER_ACK");
return;
};
edge.send_register_super_feedback(
0,
RegisterSuperFeedback {
result: 0,
message: "".to_owned(),
should_exit: false,
},
);
debug!("got register super ack");
edge.session_token.set(ack.session_token);
let Ok(key) = rsa_decrypt(&edge.rsa_private, &ack.key) else {
error!("failed to rsa decrypt aes key");
return;
};
match ack.algorithm.to_ascii_lowercase().as_str() {
"chacha20" => {
edge.encryptor.store(Arc::new(MyEncryptor::ChaChao20(Chacha20Encryptor::new(key, ack.region_id))))
// *edge.encryptor.write().unwrap() = MyEncryptor::ChaChao20(Chacha20Encryptor::new(key, ack.region_id));
}
"aes" => {
edge.encryptor.store(Arc::new(MyEncryptor::Aes(AesEncryptor::new(key))))
// *edge.encryptor.write().unwrap() = MyEncryptor::Aes(AesEncryptor::new(key));
}
_other => {
}
}
/*
let Some(dev) = ack.dev_addr else {
error!("no dev_addr is specified");
return;
};
*/
let ip = ip_to_string(&edge.device_config.get_ip());
// debug!("aes key is {:?}, ip is {}/{}", aes, ip, dev.net_bit_len,);
println!("assigned ip: {}", ip);
debug!("assigned ip: {}", ip);
// let hostname = edge.hostname.read().unwrap().clone();
// println!("network is: {}.{}", hostname, dev.network_domain);
/*
edge.device_config
.ip
.net_addr
.store(dev.net_addr, Ordering::Relaxed);
*/
if let Some(ref chan) = edge.connection_chan {
let _ = chan.send(ConnectionInfo::IPInfo(ip)).await;
}
/*
let mac = match dev.mac.try_into() {
Err(_) => NULL_MAC,
Ok(m) => m,
};
*/
// *edge.device_config.mac.write().unwrap() = mac;
/*
edge.device_config
.ip
.net_bit_len
.store(dev.net_bit_len as u8, Ordering::Relaxed);
edge.network_id.store(dev.network_id, Ordering::Relaxed);
*/
// edge.device.reload_config(&edge.device_config, &dev.network_domain);
edge.device.reload_config(edge, &edge.device_config, &edge.network_domain.read().unwrap().clone());
edge.set_authorized(true);
send_stun_request(edge).await;
tokio::spawn(async {
let nattype = edge.probe_nat_type().await;
debug!("nat type is {:?}", nattype);
// println!("nat type is: {:?}", nattype);
});
}
PacketType::ArpResponse => {
let Ok(res) = SdlArpResponse::decode(&msg.current_packet[..]) else {
error!("failed to decode ARP RESPONSE");
return;
};
arp_reply_arrived(edge, res).await;
return;
}
PacketType::PolicyReply => {
let Ok(policy) = SdlPolicyResponse::decode(&msg.current_packet[..]) else {
error!("failed to decode POLICY RESPONSE");
return;
};
let identity = policy.src_identity_id;
let mut infos = Vec::new();
let mut start = 0;
while start < policy.rules.len() {
if start + 3 > policy.rules.len() {
break;
}
let proto = policy.rules[start];
let port = u16::from_be_bytes([policy.rules[start+1], policy.rules[start+2]]);
start += 3;
infos.push(RuleFromServer{
proto,
port,
});
}
edge.rule_cache.set_identity_cache(identity, infos);
}
PacketType::RegisterSuperNAK => {
let Ok(_nak) = SdlRegisterSuperNak::decode(&msg.current_packet[..]) else {
error!("failed to decode REGISTER_SUPER_NAK");
/*
edge.send_register_super_feedback(
msg._packet_id,
RegisterSuperFeedback {
result: 1,
message: "failed to decode REGISTER SUPER NAK".to_owned(),
should_exit: false,
},
);
*/
return;
};
error!("got nak: {:?}", _nak);
// let pkt_id = _nak.pkt_id;
let Ok(error_code) = NakMsgCode::try_from(_nak.error_code as u8) else {
edge.send_register_super_feedback(
//msg._packet_id,
0,
RegisterSuperFeedback {
result: 2,
message: "error_code not recognized".to_owned(),
should_exit: false,
},
);
return;
};
match error_code {
NakMsgCode::InvalidToken => {
edge.send_register_super_feedback(
// msg._packet_id,
0,
RegisterSuperFeedback {
result: 3,
message: "invalid token".to_owned(),
should_exit: true,
},
);
edge.stop().await;
}
NakMsgCode::NodeDisabled => {
edge.send_register_super_feedback(
// msg._packet_id,
0,
RegisterSuperFeedback {
result: 4,
message: "Node is disabled".to_owned(),
should_exit: true,
},
);
edge.stop().await;
}
_other => {
edge.send_register_super_feedback(
// msg._packet_id,
0,
RegisterSuperFeedback {
result: 0,
message: "".to_owned(),
should_exit: false,
},
);
}
}
/*
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 1,
message: "failed to decode REGISTER SUPER NAK".to_owned(),
});
*/
edge.set_authorized(false);
edge.encryptor.store(Arc::new(MyEncryptor::Invalid));
// *edge.encryptor.write().unwrap() = MyEncryptor::Invalid;
// std::process::exit(0);
}
PacketType::Command => {
if msg.current_packet.len() < 1 {
error!("malformed COMMAND received");
return;
}
handle_tcp_command(edge, &msg.current_packet[..]).await;
}
PacketType::Event => {
if msg.current_packet.len() < 1 {
error!("malformed EVENT received");
return;
}
handle_tcp_event(edge, &msg.current_packet[..]).await;
}
PacketType::PeerInfo => {
let _ = handle_packet_peer_info(edge, &msg.current_packet[..]).await;
}
PacketType::Pong => {
debug!("tcp pong received");
let now = get_current_timestamp();
edge.tcp_pong.store(now, Ordering::Relaxed);
}
other => {
debug!("tcp not handling {:?}", other);
}
}
}
async fn handle_tcp_command(edge: &Node, cmdprotobuf: &[u8]) {
let Ok(cmd) = SdlCommand::decode(cmdprotobuf) else {
error!("failed to decode SdlCommand");
return;
};
let pkt_id = cmd.pkt_id;
let Some(command) = cmd.command else {
error!("command type is none");
return;
};
match command {
sdl_command::Command::ExitNode(node) => {
debug!("got exit node command: {:?}", node);
// println!("got exit node command: {:?}", node);
// std::process::exit(0);
if node.action == 0 {
// stop
let origin = edge.config.allow_routing.fetch_and(false, Ordering::Relaxed);
let mut config = load_configuration();
config.allow_routing = Some(false);
let _ = store_configuration(&config);
if origin {
#[cfg(target_os = "linux")]
set_disallow_routing();
}
} else {
// start
let origin = edge.config.allow_routing.fetch_or(true, Ordering::Relaxed);
let mut config = load_configuration();
config.allow_routing = Some(true);
let _ = store_configuration(&config);
if !origin {
#[cfg(target_os = "linux")]
set_allow_routing();
}
}
let ack = SdlCommandAck{
pkt_id,
code: 0,
message: "ok".to_owned(),
data: vec![],
};
let msg = encode_to_tcp_message(Some(ack), PacketType::CommandACK as u8).unwrap();
let stream = get_quic_write_conn();
if let Err(e) = stream.send(msg).await {
error!("failed to write command ack to quic: {}", e.as_str());
}
}
}
}
/// Handles an event pushed by the server over the QUIC stream.
///
/// Only `SendRegister` is acted on: its peer description (NAT IPv4 address and
/// port, NAT type, optional IPv6 endpoint, destination MAC) is forwarded to
/// `check_peer_registration_needed`. Any other event is logged at debug level.
async fn handle_tcp_event(edge: &'static Node, eventprotobuf: &[u8]) {
    let decoded = match SdlEvent::decode(eventprotobuf) {
        Ok(ev) => ev,
        Err(_) => {
            error!("failed to decode SdlEvent");
            return;
        }
    };
    let event = match decoded.event {
        Some(ev) => ev,
        None => {
            error!("event type is none");
            return;
        }
    };

    match event {
        sdl_event::Event::SendRegister(reg) => {
            // An IPv6 payload of the wrong length is dropped rather than rejected.
            let v6_sock = match reg.v6_info {
                Some(info) => match info.v6.try_into() {
                    Ok(v6) => Some(V6Info {
                        port: info.port as u16,
                        v6,
                    }),
                    Err(_) => None,
                },
                None => None,
            };
            let dst_mac = reg.dst_mac.try_into().unwrap_or(NULL_MAC);
            let remote_nat = NatType::try_from(reg.nat_type as u8).unwrap_or(NatType::NoNat);
            let nat_sock = SdlanSock {
                family: AF_INET,
                port: reg.nat_port as u16,
                v4: reg.nat_ip.to_be_bytes(),
                v6: [0; 16],
            };
            check_peer_registration_needed(edge, false, dst_mac, remote_nat, &v6_sock, &nat_sock)
                .await;
        }
        other => {
            debug!("unhandled event {:?}", other);
        }
    }
}
/// Creates the QUIC connection actor and publishes its handle globally.
///
/// Builds a `ReadWriterHandle`, which spawns the background tasks, and stores it
/// so `get_quic_write_conn` can reach it.
///
/// # Panics
///
/// Panics if called more than once, because the global handle can only be set once.
pub fn init_quic_conn(
    cancel: CancellationToken,
    addr: &str,
    domain: &str,
    pong_time: Arc<AtomicU64>,
    start_stop_chan: Receiver<StartStopInfo>,
    connecting_chan: Option<Sender<ConnectionInfo>>,
    ipv6_network_restarter: Option<Sender<bool>>,
) {
    let handle = ReadWriterHandle::new(
        cancel,
        addr,
        domain,
        pong_time,
        start_stop_chan,
        connecting_chan,
        ipv6_network_restarter,
    );
    let registration = GLOBAL_QUIC_HANDLE.set(handle);
    registration.expect("failed to set global tcp handle");
}
/// Background actor that owns the QUIC connection lifecycle.
///
/// Driven by `run`: it waits for a start request, connects, sends REGISTER_SUPER,
/// then pumps packets in both directions until the link fails, the pong check
/// expires or a stop request arrives, and then loops to reconnect.
pub struct ReadWriteActor {
    // Receiving end for data the actor writes to the stream; the handle keeps the sending end.
    // to_tcp: Receiver<Vec<u8>>,
    // remote is the ip version of the remote server
    remote: String,
    // hostname is the domain name of the remote server
    domain: String,
    // Shared with `ReadWriterHandle`; true while a stream is open.
    connected: Arc<AtomicBool>,
    // Timestamp of the last pong, compared against `TCP_PING_TIME` by the pong check.
    pong_time: Arc<AtomicU64>,
    // Sender the actor uses to pass received packets up to the upper layer; the handle keeps the receiving end.
    from_tcp: Sender<SdlanTcp>,
    // NOTE(review): stored but never read by `run`; confirm whether cancellation should stop the loop.
    _cancel: CancellationToken,
    // Optional channel that receives `ConnectionState` updates.
    connecting_chan: Option<Sender<ConnectionInfo>>,
    // Optional channel signalled with `true` after each successful connect.
    ipv6_network_restarter: Option<Sender<bool>>,
}
impl ReadWriteActor {
pub fn new(
cancel: CancellationToken,
remote: &str,
domain: &str,
from_tcp: Sender<SdlanTcp>,
connected: Arc<AtomicBool>,
pong_time: Arc<AtomicU64>,
connecting_chan: Option<Sender<ConnectionInfo>>,
ipv6_network_restarter: Option<Sender<bool>>,
) -> Self {
Self {
// to_tcp,
_cancel: cancel,
pong_time,
connected,
domain: domain.to_owned(),
remote: remote.to_owned(),
from_tcp,
connecting_chan,
ipv6_network_restarter,
}
}
pub async fn run<'a>(
&self,
keep_reconnect: bool,
mut to_tcp: Receiver<Vec<u8>>,
mut start_stop_chan: Receiver<StartStopInfo>,
) {
let edge = get_edge();
let mut started = false;
// let mut start_pkt_id = None;
loop {
if let Some(ref connecting_chan) = self.connecting_chan {
let state = ConnectionInfo::ConnState(ConnectionState::NotConnected);
let _ = connecting_chan.send(state).await;
}
self.connected.store(false, Ordering::Relaxed);
if !started {
// println!("waiting for start");
loop {
let start_or_stop = start_stop_chan.recv().await;
if let Some(m) = start_or_stop {
if m.is_start {
started = true;
// start_pkt_id = m.pkt_id;
break;
}
} else {
// None, just return
info!("start or stop is None");
return;
}
}
debug!("start stop chan received: {}", started);
continue;
}
if let Some(ref connecting_chan) = self.connecting_chan {
let state = ConnectionInfo::ConnState(ConnectionState::Connecting);
let _ = connecting_chan.send(state).await;
}
debug!("try connecting to {}", self.domain);
let conn = match edge.quic_endpoint.connect(self.remote.parse().unwrap(), &self.domain) {
Ok(conn) => conn,
Err(e) => {
error!("failed to connect: {}", e);
// println!("failed to connect: {}", e);
self.connected.store(false, Ordering::Relaxed);
if keep_reconnect {
tokio::time::sleep(Duration::from_secs(3)).await;
continue;
}
return;
}
};
let conn = match conn.await {
Err(e) => {
// println!("failed to connect await: {}", e);
error!("failed to connect await: {}", e);
self.connected.store(false, Ordering::Relaxed);
if keep_reconnect {
tokio::time::sleep(Duration::from_secs(3)).await;
continue;
}
return;
}
Ok(conn) => conn,
};
let local_ip = conn.local_ip();
let Ok((mut send, mut recv)) = conn.open_bi().await else {
error!("failed to open-bi");
self.connected.store(false, Ordering::Relaxed);
if keep_reconnect {
tokio::time::sleep(Duration::from_secs(3)).await;
continue;
}
return;
};
self.connected.store(true, Ordering::Relaxed);
debug!("connected");
sleep(Duration::from_millis(200)).await;
on_connected_callback(local_ip, &mut send).await;
if let Some(ref connecting_chan) = self.connecting_chan {
let state = ConnectionInfo::ConnState(ConnectionState::Connected);
let _ = connecting_chan.send(state).await;
}
if let Some(ref ipv6_restarter) = self.ipv6_network_restarter {
let _ = ipv6_restarter.send(true).await;
}
// stream.write("hello".as_bytes()).await;
// let (reader, mut write) = stream.into_split();
let read_from_tcp = async move {
// let mut buffed_reader = BufReader::new(recv);
loop {
match read_a_packet(&mut recv).await {
Ok(packet) => {
warn!("got packet: {:?}", packet.packet_type);
if let Err(_e) = self.from_tcp.send(packet).await {
error!("failed to receive a packet: {:?}", _e);
}
}
Err(e) => {
error!("failed to read a packet: {}, reconnecting...", e);
return;
}
}
}
};
let write_to_tcp = async {
while let Some(data) = to_tcp.recv().await {
// debug!("data size = {}", data.len());
match send.write(&data).await {
Ok(_size) => {
// debug!("{} bytes sent to tcp", size);
}
Err(e) => {
error!("failed to write to tcp: {}", e.to_string());
return;
}
}
}
error!("to_tcp recv None");
};
let check_pong = async {
loop {
tokio::time::sleep(Duration::from_secs(3600)).await;
let connected = self.connected.load(Ordering::Relaxed);
let now = get_current_timestamp();
if connected && now - self.pong_time.load(Ordering::Relaxed) > TCP_PING_TIME * 2
{
// pong time expire, need to re-connect
error!("pong check expired");
return;
}
}
};
let check_stop = async {
loop {
match start_stop_chan.recv().await {
Some(v) => {
if !v.is_start {
started = false;
return;
}
}
_other => {
// send chan is closed;
started = false;
return;
}
}
}
};
pin_mut!(read_from_tcp, write_to_tcp);
tokio::select! {
_ = read_from_tcp => {},
_ = write_to_tcp => {},
_ = check_pong => {},
_ = check_stop => {},
}
on_disconnected_callback().await;
conn.close(0u32.into(), "close".as_bytes());
debug!("connect retrying");
tokio::time::sleep(Duration::from_secs(1)).await;
debug!("disconnected");
// future::select(read_from_tcp, write_to_tcp).await;
}
}
}
/// Invalidates the session after the QUIC connection is torn down: the node is
/// de-authorized and its encryptor replaced with `MyEncryptor::Invalid`.
async fn on_disconnected_callback() {
    let node = get_edge();
    let invalid_encryptor = Arc::new(MyEncryptor::Invalid);
    node.set_authorized(false);
    node.encryptor.store(invalid_encryptor);
}
async fn on_connected_callback(local_ip: Option<IpAddr>, stream: &mut SendStream) {
let edge = get_edge();
// let installed_channel = install_channel.to_owned();
// let token = edge._token.lock().unwrap().clone();
// let code = edge.network_code.lock().unwrap().clone();
// let edge = get_edge();
// let edge = get_edge();
// let token = args.token.clone();
if let Some(ipaddr) = local_ip {
match ipaddr {
IpAddr::V4(v4) => {
let ip = v4.into();
// println!("outer ip is {} => {}", v4, ip);
edge.outer_ip_v4.store(ip, Ordering::Relaxed);
}
_other => {}
}
}
let register_super = SdlRegisterSuper {
mac: Vec::from(edge.device_config.get_mac()),
// pkt_id: edge.get_next_packet_id(),
network_id: edge.network_id.load(Ordering::Relaxed),
ip: edge.device_config.get_ip(),
mask_len: edge.device_config.get_net_bit() as u32,
access_token: edge.access_token.get(),
// installed_channel,
client_id: edge.config.node_uuid.clone(),
pub_key: edge.rsa_pubkey.clone(),
hostname: edge.hostname.read().unwrap().clone(),
};
// println!("register super: {:?}", register_super);
// debug!("send register super: {:?}", register_super);
// let packet_id = edge.get_next_packet_id();
let data = encode_to_tcp_message(
Some(register_super),
PacketType::RegisterSuper as u8,
)
.unwrap();
if let Err(e) = stream.write(&data).await {
error!("failed to write to tcp: {}", e.to_string());
}
}