sdlan-lib-rs/src/tcp/quic.rs

743 lines
26 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::{net::IpAddr, sync::{Arc, OnceLock, atomic::{AtomicBool, AtomicU64, Ordering}}, time::Duration};
use futures_util::pin_mut;
use prost::Message;
use quinn::SendStream;
use sdlan_sn_rs::{config::AF_INET, peer::{SdlanSock, V6Info}, utils::{Result, SDLanError, get_current_timestamp, ip_to_string, rsa_decrypt}};
use tokio::{sync::mpsc::{Receiver, Sender, channel}, time::sleep};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, warn};
use crate::{AesEncryptor, Chacha20Encryptor, ConnectionInfo, ConnectionState, MyEncryptor, RuleFromServer, config::{NULL_MAC, TCP_PING_TIME}, get_edge, network::{ARP_REPLY, ArpHdr, EthHdr, Node, RegisterSuperFeedback, StartStopInfo, check_peer_registration_needed, handle_packet_peer_info}, pb::{SdlArpResponse, SdlPolicyResponse, SdlRegisterSuper, SdlRegisterSuperAck, SdlRegisterSuperNak, SdlSendRegisterEvent, encode_to_tcp_message}, tcp::{EventType, NakMsgCode, NatType, PacketType, SdlanTcp, read_a_packet, send_stun_request}};
/// Process-wide connection handle: written once by `init_quic_conn` and read by
/// `get_quic_write_conn`.
static GLOBAL_QUIC_HANDLE: OnceLock<ReadWriterHandle> = OnceLock::new();
/// Handle through which the rest of the program queues outgoing bytes for the
/// connection actor (`ReadWriteActor`).
#[derive(Debug)]
pub struct ReadWriterHandle {
/// Shared connection flag; `send` refuses to queue data while it is `false`.
connected: Arc<AtomicBool>,
/// Sending half of the outgoing channel; the receiving half is handed to
/// `ReadWriteActor::run`.
send_to_tcp: Sender<Vec<u8>>,
// pub data_from_tcp: Receiver<SdlanTcp>,
}
impl ReadWriterHandle {
pub async fn send(&self, data: Vec<u8>) -> Result<()> {
if self.connected.load(Ordering::Relaxed) {
// connected, send to it
if let Err(e) = self.send_to_tcp.send(data).await {
error!("failed to send to send_to_tcp: {}", e.to_string());
return Err(SDLanError::NormalError("failed to send"));
};
debug!("tcp info sent");
} else {
error!("tcp not connected, so not sending data");
return Err(SDLanError::NormalError("not connected, so not sending"));
}
Ok(())
}
fn new<>(
cancel: CancellationToken,
addr: &str,
// on_connected: OnConnectedCallback<'a>,
// on_disconnected: T3,
// on_message: T2,
pong_time: Arc<AtomicU64>,
start_stop_chan: Receiver<StartStopInfo>,
// cancel: CancellationToken,
connecting_chan: Option<Sender<ConnectionInfo>>,
ipv6_network_restarter: Option<Sender<bool>>,
) -> Self {
let (send_to_tcp, to_tcp) = channel(20);
let (from_tcp, mut data_from_tcp) = channel(20);
let connected: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
let actor = ReadWriteActor::new(
cancel,
addr,
from_tcp,
connected.clone(),
pong_time,
connecting_chan,
ipv6_network_restarter,
);
tokio::spawn(async move {
actor
.run(
true,
to_tcp,
// on_connected,
// on_disconnected,
start_stop_chan
)
.await
});
tokio::spawn(async move {
loop {
if let Some(msg) = data_from_tcp.recv().await {
handle_tcp_message(msg).await;
} else {
error!("data from tcp exited");
// eprintln!("data from tcp exited");
return;
}
}
});
ReadWriterHandle {
connected,
send_to_tcp,
// data_from_tcp,
}
}
}
/// Returns the process-wide QUIC connection handle.
///
/// # Panics
/// Panics if `init_quic_conn` has not been called yet.
pub fn get_quic_write_conn() -> &'static ReadWriterHandle {
    // The message names the actual initializer (`init_quic_conn`); the previous
    // text pointed at a function that does not exist in this module.
    GLOBAL_QUIC_HANDLE
        .get()
        .expect("should call init_quic_conn first")
}
async fn handle_tcp_message(msg: SdlanTcp) {
let edge = get_edge();
// let now = get_current_timestamp();
// edge.tcp_pong.store(now, Ordering::Relaxed);
debug!("got tcp message: {:?}", msg.packet_type);
match msg.packet_type {
PacketType::RegisterSuperACK => {
let Ok(ack) = SdlRegisterSuperAck::decode(&msg.current_packet[..]) else {
error!("failed to decode REGISTER_SUPER_ACK");
return;
};
edge.send_register_super_feedback(
ack.pkt_id,
RegisterSuperFeedback {
result: 0,
message: "".to_owned(),
should_exit: false,
},
);
debug!("got register super ack: {:?}", ack);
edge.session_token.set(ack.session_token);
let Ok(key) = rsa_decrypt(&edge.rsa_private, &ack.key) else {
error!("failed to rsa decrypt aes key");
return;
};
match ack.algorithm.to_ascii_lowercase().as_str() {
"chacha20" => {
edge.encryptor.store(Arc::new(MyEncryptor::ChaChao20(Chacha20Encryptor::new(key, ack.region_id))))
// *edge.encryptor.write().unwrap() = MyEncryptor::ChaChao20(Chacha20Encryptor::new(key, ack.region_id));
}
"aes" => {
edge.encryptor.store(Arc::new(MyEncryptor::Aes(AesEncryptor::new(key))))
// *edge.encryptor.write().unwrap() = MyEncryptor::Aes(AesEncryptor::new(key));
}
_other => {
}
}
/*
let Some(dev) = ack.dev_addr else {
error!("no dev_addr is specified");
return;
};
*/
let ip = ip_to_string(&edge.device_config.get_ip());
// debug!("aes key is {:?}, ip is {}/{}", aes, ip, dev.net_bit_len,);
println!("assigned ip: {}", ip);
// let hostname = edge.hostname.read().unwrap().clone();
// println!("network is: {}.{}", hostname, dev.network_domain);
/*
edge.device_config
.ip
.net_addr
.store(dev.net_addr, Ordering::Relaxed);
*/
if let Some(ref chan) = edge.connection_chan {
let _ = chan.send(ConnectionInfo::IPInfo(ip)).await;
}
/*
let mac = match dev.mac.try_into() {
Err(_) => NULL_MAC,
Ok(m) => m,
};
*/
// *edge.device_config.mac.write().unwrap() = mac;
/*
edge.device_config
.ip
.net_bit_len
.store(dev.net_bit_len as u8, Ordering::Relaxed);
edge.network_id.store(dev.network_id, Ordering::Relaxed);
*/
// edge.device.reload_config(&edge.device_config, &dev.network_domain);
edge.device.reload_config(&edge.device_config, &edge.network_domain.read().unwrap().clone());
edge.set_authorized(true);
send_stun_request(edge).await;
tokio::spawn(async {
let nattype = edge.probe_nat_type().await;
debug!("nat type is {:?}", nattype);
// println!("nat type is: {:?}", nattype);
});
}
PacketType::ArpResponse => {
let Ok(resp) = SdlArpResponse::decode(&msg.current_packet[..]) else {
error!("failed to decode ARP RESPONSE");
return;
};
if resp.target_mac.len() != 6 {
// invalid target_mac
error!("invalid target_mac: {:?}, ip={}", resp.target_mac, ip_to_string(&resp.target_ip));
return;
}
// TODO: construct the arp reply, and write to tun;
let src_mac = resp.target_mac.try_into().unwrap();
let dst_mac = edge.device_config.get_mac();
let dst_ip = edge.device_config.get_ip();
let hdr = ArpHdr{
ethhdr: EthHdr {
dest: dst_mac,
src: src_mac,
eth_type: 0x0806,
},
hwtype: 0x0001,
protocol: 0x0800,
hwlen: 6,
protolen: 4,
opcode: ARP_REPLY,
shwaddr: src_mac,
sipaddr: [((resp.target_ip >> 16) as u16) & 0xffff, (resp.target_ip as u16) & 0xffff],
dhwaddr: dst_mac,
dipaddr: [((dst_ip >> 16) & 0x0000ffff) as u16, (dst_ip & 0x0000ffff) as u16]
};
let data = hdr.marshal_to_bytes();
if let Err(_e) = edge.device.send(&data) {
error!("failed to write arp response to device");
}
}
PacketType::PolicyReply => {
let Ok(policy) = SdlPolicyResponse::decode(&msg.current_packet[..]) else {
error!("failed to decode POLICY RESPONSE");
return;
};
let identity = policy.src_identity_id;
let mut infos = Vec::new();
let mut start = 0;
while start < policy.rules.len() {
if start + 3 > policy.rules.len() {
break;
}
let proto = policy.rules[start];
let port = u16::from_be_bytes([policy.rules[start+1], policy.rules[start+2]]);
start += 3;
infos.push(RuleFromServer{
proto,
port,
});
}
edge.rule_cache.set_identity_cache(identity, infos);
}
PacketType::RegisterSuperNAK => {
let Ok(_nak) = SdlRegisterSuperNak::decode(&msg.current_packet[..]) else {
error!("failed to decode REGISTER_SUPER_NAK");
/*
edge.send_register_super_feedback(
msg._packet_id,
RegisterSuperFeedback {
result: 1,
message: "failed to decode REGISTER SUPER NAK".to_owned(),
should_exit: false,
},
);
*/
return;
};
error!("got nak: {:?}", _nak);
let pkt_id = _nak.pkt_id;
let Ok(error_code) = NakMsgCode::try_from(_nak.error_code as u8) else {
edge.send_register_super_feedback(
//msg._packet_id,
pkt_id,
RegisterSuperFeedback {
result: 2,
message: "error_code not recognized".to_owned(),
should_exit: false,
},
);
return;
};
match error_code {
NakMsgCode::InvalidToken => {
edge.send_register_super_feedback(
// msg._packet_id,
pkt_id,
RegisterSuperFeedback {
result: 3,
message: "invalid token".to_owned(),
should_exit: true,
},
);
edge.stop().await;
}
NakMsgCode::NodeDisabled => {
edge.send_register_super_feedback(
// msg._packet_id,
pkt_id,
RegisterSuperFeedback {
result: 4,
message: "Node is disabled".to_owned(),
should_exit: true,
},
);
edge.stop().await;
}
_other => {
edge.send_register_super_feedback(
// msg._packet_id,
pkt_id,
RegisterSuperFeedback {
result: 0,
message: "".to_owned(),
should_exit: false,
},
);
}
}
/*
edge.send_register_super_feedback(msg._packet_id, RegisterSuperFeedback {
result: 1,
message: "failed to decode REGISTER SUPER NAK".to_owned(),
});
*/
edge.set_authorized(false);
edge.encryptor.store(Arc::new(MyEncryptor::Invalid));
// *edge.encryptor.write().unwrap() = MyEncryptor::Invalid;
// std::process::exit(0);
}
PacketType::Command => {
if msg.current_packet.len() < 1 {
error!("malformed COMMAND received");
return;
}
handle_tcp_command(edge, msg.current_packet[0], &msg.current_packet[1..]).await;
}
PacketType::Event => {
if msg.current_packet.len() < 1 {
error!("malformed EVENT received");
return;
}
let Ok(event) = msg.current_packet[0].try_into() else {
error!("failed to parse event type");
return;
};
handle_tcp_event(edge, event, &msg.current_packet[1..]).await;
}
PacketType::PeerInfo => {
let _ = handle_packet_peer_info(edge, &msg.current_packet[..]).await;
}
PacketType::Pong => {
debug!("tcp pong received");
let now = get_current_timestamp();
edge.tcp_pong.store(now, Ordering::Relaxed);
}
other => {
debug!("tcp not handling {:?}", other);
}
}
}
/// Placeholder for `PacketType::Command` packets: receives the command type
/// byte and its payload from `handle_tcp_message` and currently does nothing.
async fn handle_tcp_command(_edge: &Node, _cmdtype: u8, _cmdprotobuf: &[u8]) {}
/// Handles an event packet received over the control connection.
///
/// Only `EventType::SendRegister` is acted upon: the payload is decoded, the
/// peer's IPv4 (and optional IPv6) endpoint is assembled, and
/// `check_peer_registration_needed` is invoked for that peer. Every other
/// event type is logged at debug level and ignored.
async fn handle_tcp_event(edge: &'static Node, eventtype: EventType, eventprotobuf: &[u8]) {
    let EventType::SendRegister = eventtype else {
        debug!("unhandled event {:?}", eventtype);
        return;
    };

    let Ok(reg) = SdlSendRegisterEvent::decode(eventprotobuf) else {
        error!("failed to decode SendRegister Event");
        return;
    };

    // IPv4 endpoint of the peer as announced in the event.
    let peer_sock = SdlanSock {
        family: AF_INET,
        port: reg.nat_port as u16,
        v4: reg.nat_ip.to_be_bytes(),
        v6: [0; 16],
    };

    // The IPv6 endpoint is kept only when the address has the expected length.
    let v6_sock = reg.v6_info.and_then(|info| {
        let port = info.port as u16;
        info.v6.try_into().ok().map(|v6| V6Info { port, v6 })
    });

    // Fall back to neutral values when the wire data cannot be converted.
    let dst_mac = reg.dst_mac.try_into().unwrap_or(NULL_MAC);
    let remote_nat = (reg.nat_type as u8).try_into().unwrap_or(NatType::NoNat);

    check_peer_registration_needed(edge, false, dst_mac, remote_nat, &v6_sock, &peer_sock).await;
}
/// Creates the connection handle (which spawns the background connection
/// tasks) and installs it as the process-wide handle returned by
/// `get_quic_write_conn`.
///
/// # Panics
/// Panics if the global handle has already been set.
pub fn init_quic_conn(
    cancel: CancellationToken,
    addr: &str,
    pong_time: Arc<AtomicU64>,
    start_stop_chan: Receiver<StartStopInfo>,
    connecting_chan: Option<Sender<ConnectionInfo>>,
    ipv6_network_restarter: Option<Sender<bool>>,
) {
    let install_result = GLOBAL_QUIC_HANDLE.set(ReadWriterHandle::new(
        cancel,
        addr,
        pong_time,
        start_stop_chan,
        connecting_chan,
        ipv6_network_restarter,
    ));
    install_result.expect("failed to set global tcp handle");
}
/// Background actor that owns the connection: `run` connects, shuttles bytes
/// in both directions and reconnects.
pub struct ReadWriteActor {
// The receiving end for data the actor must send out is not stored here (it is
// passed to `run`); the matching sending end is kept by the handle.
// to_tcp: Receiver<Vec<u8>>,
/// Remote address string; `run` parses it as the connect target and uses the
/// part before the first `:` as the server name.
remote: String,
/// Connection flag shared with the handle; `run` sets and clears it.
connected: Arc<AtomicBool>,
/// Timestamp of the last pong; read by the pong-expiry check in `run`.
pong_time: Arc<AtomicU64>,
// Sender through which the actor forwards received packets to the upper layer;
// the receiving end is kept by the handle.
from_tcp: Sender<SdlanTcp>,
/// Stored at construction; not read anywhere in this file's visible code.
_cancel: CancellationToken,
/// Optional channel notified of `ConnectionState` changes.
connecting_chan: Option<Sender<ConnectionInfo>>,
/// Optional channel that receives `true` after each successful connect.
ipv6_network_restarter: Option<Sender<bool>>,
}
impl ReadWriteActor {
pub fn new(
cancel: CancellationToken,
remote: &str,
from_tcp: Sender<SdlanTcp>,
connected: Arc<AtomicBool>,
pong_time: Arc<AtomicU64>,
connecting_chan: Option<Sender<ConnectionInfo>>,
ipv6_network_restarter: Option<Sender<bool>>,
) -> Self {
Self {
// to_tcp,
_cancel: cancel,
pong_time,
connected,
remote: remote.to_owned(),
from_tcp,
connecting_chan,
ipv6_network_restarter,
}
}
pub async fn run<'a>(
&self,
keep_reconnect: bool,
mut to_tcp: Receiver<Vec<u8>>,
mut start_stop_chan: Receiver<StartStopInfo>,
) {
let edge = get_edge();
let mut started = false;
let mut start_pkt_id = None;
loop {
if let Some(ref connecting_chan) = self.connecting_chan {
let state = ConnectionInfo::ConnState(ConnectionState::NotConnected);
let _ = connecting_chan.send(state).await;
}
self.connected.store(false, Ordering::Relaxed);
if !started {
// println!("waiting for start");
loop {
let start_or_stop = start_stop_chan.recv().await;
if let Some(m) = start_or_stop {
if m.is_start {
started = true;
start_pkt_id = m.pkt_id;
break;
}
} else {
// None, just return
return;
}
}
/*
while let Some(m) = start_stop_chan.recv().await {
println!("4");
if m.is_start {
// println!("start received");
started = true;
start_pkt_id = m.pkt_id;
break;
} else {
// println!("stop received");
}
}
*/
debug!("start stop chan received: {}", started);
continue;
}
if let Some(ref connecting_chan) = self.connecting_chan {
let state = ConnectionInfo::ConnState(ConnectionState::Connecting);
let _ = connecting_chan.send(state).await;
}
let host = self.remote.split(":").next().unwrap();
debug!("try connecting to {}, host = {}", self.remote, host);
let conn = match edge.quic_endpoint.connect(self.remote.parse().unwrap(), host) {
Ok(conn) => conn,
Err(e) => {
error!("failed to connect: {}", e);
// println!("failed to connect: {}", e);
self.connected.store(false, Ordering::Relaxed);
if keep_reconnect {
tokio::time::sleep(Duration::from_secs(3)).await;
continue;
}
return;
}
};
let conn = match conn.await {
Err(e) => {
// println!("failed to connect await: {}", e);
error!("failed to connect await: {}", e);
self.connected.store(false, Ordering::Relaxed);
if keep_reconnect {
tokio::time::sleep(Duration::from_secs(3)).await;
continue;
}
return;
}
Ok(conn) => conn,
};
let local_ip = conn.local_ip();
let Ok((mut send, mut recv)) = conn.open_bi().await else {
println!("failed to open-bi");
self.connected.store(false, Ordering::Relaxed);
if keep_reconnect {
tokio::time::sleep(Duration::from_secs(3)).await;
continue;
}
return;
};
self.connected.store(true, Ordering::Relaxed);
debug!("connected");
sleep(Duration::from_millis(200)).await;
on_connected_callback(local_ip, &mut send, start_pkt_id.take()).await;
if let Some(ref connecting_chan) = self.connecting_chan {
let state = ConnectionInfo::ConnState(ConnectionState::Connected);
let _ = connecting_chan.send(state).await;
}
if let Some(ref ipv6_restarter) = self.ipv6_network_restarter {
let _ = ipv6_restarter.send(true).await;
}
// stream.write("hello".as_bytes()).await;
// let (reader, mut write) = stream.into_split();
let read_from_tcp = async move {
// let mut buffed_reader = BufReader::new(recv);
loop {
match read_a_packet(&mut recv).await {
Ok(packet) => {
warn!("got packet: {:?}", packet);
if let Err(_e) = self.from_tcp.send(packet).await {
error!("failed to receive a packet: {:?}", _e);
}
}
Err(e) => {
error!("failed to read a packet: {}, reconnecting...", e);
return;
}
}
}
};
let write_to_tcp = async {
while let Some(data) = to_tcp.recv().await {
match send.write(&data).await {
Ok(size) => {
debug!("{} bytes sent to tcp", size);
}
Err(e) => {
error!("failed to write to tcp: {}", e.to_string());
return;
}
}
}
error!("to_tcp recv None");
};
let check_pong = async {
loop {
tokio::time::sleep(Duration::from_secs(3600)).await;
let connected = self.connected.load(Ordering::Relaxed);
let now = get_current_timestamp();
if connected && now - self.pong_time.load(Ordering::Relaxed) > TCP_PING_TIME * 2
{
// pong time expire, need to re-connect
error!("pong check expired");
return;
}
}
};
let check_stop = async {
loop {
match start_stop_chan.recv().await {
Some(v) => {
if !v.is_start {
started = false;
return;
}
}
_other => {
// send chan is closed;
started = false;
return;
}
}
}
};
pin_mut!(read_from_tcp, write_to_tcp);
tokio::select! {
_ = read_from_tcp => {},
_ = write_to_tcp => {},
_ = check_pong => {},
_ = check_stop => {},
}
on_disconnected_callback().await;
conn.close(0u32.into(), "close".as_bytes());
debug!("connect retrying");
tokio::time::sleep(Duration::from_secs(1)).await;
debug!("disconnected");
// future::select(read_from_tcp, write_to_tcp).await;
}
}
}
/// Resets the node's security state after the connection is lost: clears the
/// authorized flag, then replaces the encryptor with `MyEncryptor::Invalid`.
async fn on_disconnected_callback() {
    let node = get_edge();
    node.set_authorized(false);
    let invalid_encryptor = Arc::new(MyEncryptor::Invalid);
    node.encryptor.store(invalid_encryptor);
}
/// Runs right after the stream is opened: records the local IPv4 address (if
/// any) as the node's outer address and sends a `SdlRegisterSuper` message on
/// `stream`.
///
/// `_pkt_id` is accepted but unused; the message's packet id comes from
/// `edge.get_next_packet_id()`. A write failure is logged, not returned.
///
/// # Panics
/// Panics if `encode_to_tcp_message` fails or the hostname lock is poisoned.
async fn on_connected_callback(local_ip: Option<IpAddr>, stream: &mut SendStream, _pkt_id: Option<u32>) {
    let edge = get_edge();
    // Only an IPv4 local address is recorded; IPv6 and "unknown" are ignored.
    if let Some(IpAddr::V4(v4)) = local_ip {
        edge.outer_ip_v4.store(v4.into(), Ordering::Relaxed);
    }
    let register_super = SdlRegisterSuper {
        mac: Vec::from(edge.device_config.get_mac()),
        pkt_id: edge.get_next_packet_id(),
        network_id: edge.network_id.load(Ordering::Relaxed),
        ip: edge.device_config.get_ip(),
        mask_len: edge.device_config.get_net_bit() as u32,
        access_token: edge.access_token.get(),
        client_id: edge.config.node_uuid.clone(),
        pub_key: edge.rsa_pubkey.clone(),
        hostname: edge.hostname.read().unwrap().clone(),
    };
    // NOTE(review): this prints the whole message, including `access_token`, to
    // stdout — confirm that exposing the token there is acceptable.
    println!("register super: {:?}", register_super);
    let data = encode_to_tcp_message(
        Some(register_super),
        PacketType::RegisterSuper as u8,
    )
    .unwrap();
    // `write_all`, not `write`: `write` may send only a prefix of `data`, which
    // would leave a truncated registration message on the stream.
    if let Err(e) = stream.write_all(&data).await {
        error!("failed to write to tcp: {}", e.to_string());
    }
}