Compare commits

...

4 Commits

13 changed files with 380 additions and 125 deletions

View File

@ -1,4 +1,4 @@
{ {
// "rust-analyzer.cargo.target": "x86_64-pc-windows-gnu", "rust-analyzer.cargo.target": "x86_64-pc-windows-gnu",
// "rust-analyzer.cargo.features": ["tun"] // "rust-analyzer.cargo.features": ["tun"]
} }

View File

@ -55,8 +55,12 @@ message SDLRegisterSuper {
// quic去通讯session_token校验 // quic去通讯session_token校验
message SDLRegisterSuperAck { message SDLRegisterSuperAck {
uint32 pkt_id = 1; uint32 pkt_id = 1;
bytes aes_key = 2; // aes chacha20
bytes session_token = 3; string algorithm = 2;
bytes key = 3;
// chacha20加密算法需要使用该字段
uint32 region_id = 4;
bytes session_token = 5;
} }
message SDLRegisterSuperNak { message SDLRegisterSuperNak {

View File

@ -2,18 +2,19 @@ mod api;
mod local_udp_info; mod local_udp_info;
use std::fs; use std::fs;
use std::fs::File;
use std::fs::OpenOptions; use std::fs::OpenOptions;
use std::io::stdout;
use std::process; use std::process;
use std::env; use std::env;
use std::time::Duration;
use clap::Parser; use clap::Parser;
#[cfg(not(target_os = "windows"))]
use daemonize::Daemonize; use daemonize::Daemonize;
use futures_util::io; use futures_util::io;
use libc::SIGTERM;
use libc::kill; #[cfg(not(target_os = "windows"))]
use libc::{SIGTERM, kill};
use punchnet::CachedLoginInfo; use punchnet::CachedLoginInfo;
use punchnet::CommandLineInput2; use punchnet::CommandLineInput2;
use punchnet::Commands; use punchnet::Commands;
@ -22,26 +23,20 @@ use punchnet::get_access_token;
use punchnet::get_base_dir; use punchnet::get_base_dir;
use punchnet::get_edge; use punchnet::get_edge;
use punchnet::ip_string_to_u32; use punchnet::ip_string_to_u32;
use punchnet::mod_hostname;
use punchnet::restore_dns; use punchnet::restore_dns;
use punchnet::run_sdlan; use punchnet::run_sdlan;
use punchnet::set_access_token; use punchnet::set_access_token;
use punchnet::set_base_dir; use punchnet::set_base_dir;
use punchnet::CommandLine; use punchnet::CommandLine;
use punchnet::CommandLineInput;
use sdlan_sn_rs::log; use sdlan_sn_rs::log;
use sdlan_sn_rs::utils::Mac; use sdlan_sn_rs::utils::Mac;
use sdlan_sn_rs::utils::Result; use sdlan_sn_rs::utils::Result;
use sdlan_sn_rs::utils::create_or_load_uuid; use sdlan_sn_rs::utils::create_or_load_uuid;
use tokio::io::AsyncWriteExt;
use tokio::net::UdpSocket;
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
use tokio::time::sleep;
use tracing::error; use tracing::error;
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
use structopt::StructOpt;
use crate::api::ConnectData; use crate::api::ConnectData;
use crate::api::ConnectResponse; use crate::api::ConnectResponse;
@ -272,15 +267,27 @@ async fn daemonize_me(
} }
#[cfg(target_os = "windows")]
const SYSTEM: &'static str = "windows";
#[cfg(target_os = "windows")]
const DEFAULT_BASE_DIR: &'static str = ".";
#[cfg(not(target_os = "windows"))]
const SYSTEM: &'static str = "linux";
#[cfg(target_os = "linux")]
const DEFAULT_BASE_DIR: &'static str = "/usr/local/punchnet";
fn main() { fn main() {
set_base_dir("/usr/local/punchnet"); set_base_dir(DEFAULT_BASE_DIR);
// let _guard = log::init_log(&format!("{}/.output", get_base_dir())); // let _guard = log::init_log(&format!("{}/.output", get_base_dir()));
let client_id = create_or_load_uuid(&format!("{}/.id", get_base_dir()), None).unwrap(); let client_id = create_or_load_uuid(&format!("{}/.id", get_base_dir()), None).unwrap();
let mac = create_or_load_mac(); let mac = create_or_load_mac();
let system = "linux"; let system = SYSTEM;
let version = "1.0.0"; let version = env!("CARGO_PKG_VERSION");
// let cmd = CommandLineInput::from_args(); // let cmd = CommandLineInput::from_args();
let cmd = CommandLineInput2::parse(); let cmd = CommandLineInput2::parse();
@ -306,6 +313,8 @@ fn main() {
}); });
process::exit(0); process::exit(0);
} }
#[cfg(not(target_os = "windows"))]
Commands::Stop => { Commands::Stop => {
match fs::read_to_string("/tmp/punchnet.pid") { match fs::read_to_string("/tmp/punchnet.pid") {
Ok(content) => { Ok(content) => {
@ -345,83 +354,99 @@ fn main() {
} }
} }
let out = OpenOptions::new() let should_daemonize = true;
.create(true)
.truncate(true)
.write(true)
.open("/tmp/punchnet.out").unwrap();
let err = OpenOptions::new()
.create(true)
.truncate(true)
.write(true)
.open("/tmp/punchnet.err").unwrap();
let daemonize = Daemonize::new() #[cfg(not(target_os = "windows"))]
.pid_file("/tmp/punchnet.pid") if should_daemonize {
.chown_pid_file(true) let out = OpenOptions::new()
.working_directory(get_base_dir()) .create(true)
.stdout(out) .truncate(true)
.stderr(err) .write(true)
.privileged_action(|| { .open("/tmp/punchnet.out").unwrap();
}); let err = OpenOptions::new()
.create(true)
.truncate(true)
.write(true)
.open("/tmp/punchnet.err").unwrap();
match daemonize.start() { let daemonize = Daemonize::new()
Ok(_) => { .pid_file("/tmp/punchnet.pid")
let rt = Runtime::new().unwrap(); .chown_pid_file(true)
match &cmd.cmd { .working_directory(get_base_dir())
Commands::Start => { .stdout(out)
rt.block_on(async move { .stderr(err)
let remembered_token = get_access_token(); .privileged_action(|| {
if remembered_token.is_none() { });
eprintln!("not logged in, should login with user/pass or token first");
process::exit(-2);
}
let remembered = remembered_token.unwrap(); match daemonize.start() {
Ok(_) => {
let connect_info = parse_connect_result( run_it(cmd, client_id, mac, system, version);
connect(TEST_PREFIX, &client_id, &remembered.access_token).await
);
daemonize_me(connect_info, remembered, client_id, mac).await;
})
}
Commands::AutoRun(tk) => {
rt.block_on(async move {
let mut remembered_token = get_access_token();
if remembered_token.is_none() {
let data = parse_login_result(
login_with_token(TEST_PREFIX, &client_id, &tk.token, mac, system, version).await
);
remembered_token = Some(CachedLoginInfo{
access_token: data.access_token,
username: data.username,
user_type: data.user_type,
audit: data.audit,
network_id: data.network_id,
network_name: data.network_name,
});
}
let remembered = remembered_token.unwrap();
let connect_info = parse_connect_result(
connect(TEST_PREFIX, &client_id, &remembered.access_token).await
);
daemonize_me(connect_info, remembered, client_id, mac).await;
})
}
other => {
eprintln!("should not comes here");
process::exit(-1);
}
} }
Err(e) => {
eprintln!("failed to daemonize");
}
}
} else {
run_it(cmd, client_id, mac, system, version);
}
#[cfg(target_os = "windows")]
run_it(cmd, client_id, mac, system, version);
}
fn run_it(cmd: CommandLineInput2, client_id: String, mac: Mac, system: &str, version: &str) {
let rt = Runtime::new().unwrap();
match &cmd.cmd {
Commands::Start => {
rt.block_on(async move {
let remembered_token = get_access_token();
if remembered_token.is_none() {
eprintln!("not logged in, should login with user/pass or token first");
process::exit(-2);
}
let remembered = remembered_token.unwrap();
let connect_info = parse_connect_result(
connect(TEST_PREFIX, &client_id, &remembered.access_token).await
);
daemonize_me(connect_info, remembered, client_id, mac).await;
})
}
Commands::AutoRun(tk) => {
rt.block_on(async move {
let mut remembered_token = get_access_token();
if remembered_token.is_none() {
let data = parse_login_result(
login_with_token(TEST_PREFIX, &client_id, &tk.token, mac, system, version).await
);
remembered_token = Some(CachedLoginInfo{
access_token: data.access_token,
username: data.username,
user_type: data.user_type,
audit: data.audit,
network_id: data.network_id,
network_name: data.network_name,
});
}
let remembered = remembered_token.unwrap();
let connect_info = parse_connect_result(
connect(TEST_PREFIX, &client_id, &remembered.access_token).await
);
daemonize_me(connect_info, remembered, client_id, mac).await;
})
} }
Err(e) => { other => {
eprintln!("failed to daemonize"); eprintln!("should not comes here");
process::exit(-1);
} }
} }
} }
pub fn delete_pid_file() { pub fn delete_pid_file() {

View File

@ -252,7 +252,7 @@ async fn loop_tap(eee: &'static Node, cancel: CancellationToken) {
packet.extend_from_slice(&reply); packet.extend_from_slice(&reply);
/// TODO: check the packet should /// TODO: check the packet should
if let Err(_e) = eee.device.handle_packet_from_net(&packet, &Vec::new()).await { if let Err(_e) = eee.device.handle_packet_from_net(&packet).await {
error!("failed to write dns packet to device"); error!("failed to write dns packet to device");
} }
} }
@ -311,14 +311,16 @@ async fn read_and_parse_tun_packet(eee: &'static Node, buf: Vec<u8>) {
async fn edge_send_packet_to_net(eee: &Node, data: Vec<u8>) { async fn edge_send_packet_to_net(eee: &Node, data: Vec<u8>) {
// debug!("edge send packet to net({} bytes): {:?}", data.len(), data); // debug!("edge send packet to net({} bytes): {:?}", data.len(), data);
/*
let encrypt_key = eee.get_encrypt_key(); let encrypt_key = eee.get_encrypt_key();
if encrypt_key.len() == 0 { if encrypt_key.len() == 0 {
error!("drop tun packet due to encrypt key len is 0"); error!("drop tun packet due to encrypt key len is 0");
return; return;
} }
*/
if let Err(e) = eee if let Err(e) = eee
.device .device
.handle_packet_from_device(data, encrypt_key.as_slice()) .handle_packet_from_device(data)
.await .await
{ {
error!("failed to handle packet from device: {}", e.to_string()); error!("failed to handle packet from device: {}", e.to_string());

View File

@ -13,7 +13,7 @@ use tokio::sync::oneshot;
use tracing::{debug, error}; use tracing::{debug, error};
use crate::quic::quic_init; use crate::quic::quic_init;
use crate::{ConnectionInfo, get_base_dir}; use crate::{ConnectionInfo, Encryptor, MyEncryptor, get_base_dir};
use crate::pb::{ use crate::pb::{
encode_to_tcp_message, encode_to_udp_message, SdlEmpty, SdlStunProbe, SdlStunProbeReply, encode_to_tcp_message, encode_to_udp_message, SdlEmpty, SdlStunProbe, SdlStunProbeReply,
}; };
@ -162,6 +162,8 @@ impl IdentityID {
pub struct Node { pub struct Node {
packet_id: AtomicU32, packet_id: AtomicU32,
pub encryptor: RwLock<MyEncryptor>,
pub network_id: AtomicU32, pub network_id: AtomicU32,
pub network_domain: RwLock<String>, pub network_domain: RwLock<String>,
@ -192,7 +194,7 @@ pub struct Node {
// authorize related // authorize related
pub authorized: AtomicBool, pub authorized: AtomicBool,
// pub header_key: RwLock<Arc<Vec<u8>>>, // pub header_key: RwLock<Arc<Vec<u8>>>,
pub encrypt_key: RwLock<Arc<Vec<u8>>>, // pub encrypt_key: RwLock<Arc<Vec<u8>>>,
pub rsa_pubkey: String, pub rsa_pubkey: String,
pub rsa_private: RsaPrivateKey, pub rsa_private: RsaPrivateKey,
@ -384,6 +386,8 @@ impl Node {
Self { Self {
packet_id: AtomicU32::new(1), packet_id: AtomicU32::new(1),
encryptor: RwLock::new(MyEncryptor::new()),
network_id: AtomicU32::new(0), network_id: AtomicU32::new(0),
hostname: RwLock::new(hostname), hostname: RwLock::new(hostname),
@ -411,7 +415,7 @@ impl Node {
device: new_iface("dev", mode), device: new_iface("dev", mode),
authorized: AtomicBool::new(false), authorized: AtomicBool::new(false),
encrypt_key: RwLock::new(Arc::new(Vec::new())), // encrypt_key: RwLock::new(Arc::new(Vec::new())),
// rsa_pubkey: // rsa_pubkey:
rsa_pubkey: pubkey, rsa_pubkey: pubkey,
rsa_private: private, rsa_private: private,
@ -456,9 +460,9 @@ impl Node {
self.authorized.load(Ordering::Relaxed) self.authorized.load(Ordering::Relaxed)
} }
pub fn set_authorized(&self, authorized: bool, encrypt_key: Vec<u8>) { pub fn set_authorized(&self, authorized: bool) {
self.authorized.store(authorized, Ordering::Relaxed); self.authorized.store(authorized, Ordering::Relaxed);
*(self.encrypt_key.write().unwrap()) = Arc::new(encrypt_key); // *(self.encrypt_key.write().unwrap()) = Arc::new(encrypt_key);
} }
/* /*
@ -467,9 +471,11 @@ impl Node {
} }
*/ */
/*
pub fn get_encrypt_key(&self) -> Arc<Vec<u8>> { pub fn get_encrypt_key(&self) -> Arc<Vec<u8>> {
self.encrypt_key.read().unwrap().clone() self.encrypt_key.read().unwrap().clone()
} }
*/
/* /*
pub fn sn_is_known(&self, sock: &SdlanSock) -> bool { pub fn sn_is_known(&self, sock: &SdlanSock) -> bool {

View File

@ -21,7 +21,7 @@ use sdlan_sn_rs::{
config::{AF_INET, AF_INET6}, config::{AF_INET, AF_INET6},
peer::{is_sdlan_sock_equal, SdlanSock, V6Info}, peer::{is_sdlan_sock_equal, SdlanSock, V6Info},
utils::{ utils::{
aes_decrypt, get_current_timestamp, get_sdlan_sock_from_socketaddr, is_multi_broadcast, get_current_timestamp, get_sdlan_sock_from_socketaddr, is_multi_broadcast,
Mac, Result, SDLanError, Mac, Result, SDLanError,
}, },
}; };
@ -883,15 +883,18 @@ async fn handle_tun_packet(
pkt: SdlData, //orig_sender: &SdlanSock pkt: SdlData, //orig_sender: &SdlanSock
) { ) {
let payload = pkt.data; let payload = pkt.data;
let key = eee.get_encrypt_key(); //let key = eee.get_encrypt_key();
if key.len() == 0 {
// if key.len() == 0 {
// check the encrypt key // check the encrypt key
error!("packet encrypt key not provided"); // error!("packet encrypt key not provided");
return; // return;
} // }
// test_aes(key.as_slice()); // test_aes(key.as_slice());
let origin = aes_decrypt(key.as_slice(), &payload);
let origin = eee.encryptor.read().unwrap().decrypt(&payload);
// let origin = aes_decrypt(&payload);
if let Err(_e) = origin { if let Err(_e) = origin {
error!("failed to decrypt original data"); error!("failed to decrypt original data");
return; return;
@ -947,7 +950,7 @@ async fn handle_tun_packet(
debug!("sending packet to tun, {} bytes", data.len()); debug!("sending packet to tun, {} bytes", data.len());
if let Err(e) = eee if let Err(e) = eee
.device .device
.handle_packet_from_net(&data, key.as_slice()) .handle_packet_from_net(&data)
.await .await
{ {
error!("failed to handle packet from net: {}", e.to_string()); error!("failed to handle packet from net: {}", e.to_string());

View File

@ -197,7 +197,7 @@ impl Iface {
#[cfg(not(feature = "tun"))] #[cfg(not(feature = "tun"))]
impl TunTapPacketHandler for Iface { impl TunTapPacketHandler for Iface {
async fn handle_packet_from_net(&self, data: &[u8], _: &[u8]) -> std::io::Result<()> { async fn handle_packet_from_net(&self, data: &[u8]) -> std::io::Result<()> {
// debug!("in tap mode, got data: {:?}", data); // debug!("in tap mode, got data: {:?}", data);
match self.send(data) { match self.send(data) {
Err(e) => { Err(e) => {
@ -207,10 +207,11 @@ impl TunTapPacketHandler for Iface {
Ok(_) => return Ok(()), Ok(_) => return Ok(()),
} }
} }
async fn handle_packet_from_device( async fn handle_packet_from_device(
&self, &self,
data: Vec<u8>, data: Vec<u8>,
encrypt_key: &[u8], // encrypt_key: &[u8],
) -> std::io::Result<()> { ) -> std::io::Result<()> {
use etherparse::PacketHeaders; use etherparse::PacketHeaders;
@ -288,7 +289,8 @@ impl TunTapPacketHandler for Iface {
} }
let size = data.len(); let size = data.len();
let Ok(encrypted) = aes_encrypt(encrypt_key, &data) else { let Ok(encrypted) = edge.encryptor.read().unwrap().encrypt(&data) else {
// let Ok(encrypted) = aes_encrypt(encrypt_key, &data) else {
error!("failed to encrypt packet request"); error!("failed to encrypt packet request");
return Ok(()); return Ok(());
}; };
@ -319,7 +321,7 @@ impl TunTapPacketHandler for Iface {
#[cfg(feature = "tun")] #[cfg(feature = "tun")]
impl TunTapPacketHandler for Iface { impl TunTapPacketHandler for Iface {
async fn handle_packet_from_net(&self, data: &[u8], key: &[u8]) -> std::io::Result<()> { async fn handle_packet_from_net(&self, data: &[u8]) -> std::io::Result<()> {
debug!("in tun mode"); debug!("in tun mode");
// got layer 2 frame // got layer 2 frame
@ -399,7 +401,8 @@ impl TunTapPacketHandler for Iface {
[((self_ip >> 16) & 0xffff) as u16, (self_ip & 0xffff) as u16]; [((self_ip >> 16) & 0xffff) as u16, (self_ip & 0xffff) as u16];
let data = arp.marshal_to_bytes(); let data = arp.marshal_to_bytes();
let Ok(encrypted) = aes_encrypt(key, &data) else { // let Ok(encrypted) = aes_encrypt(key, &data) else {
let Ok(encrypted) = edge.encryptor.read().unwrap().encrypt(&data) else {
error!("failed to encrypt arp reply"); error!("failed to encrypt arp reply");
return Ok(()); return Ok(());
}; };
@ -495,7 +498,6 @@ impl TunTapPacketHandler for Iface {
async fn handle_packet_from_device( async fn handle_packet_from_device(
&self, &self,
data: Vec<u8>, data: Vec<u8>,
encrypt_key: &[u8],
) -> std::io::Result<()> { ) -> std::io::Result<()> {
use etherparse::IpHeaders; use etherparse::IpHeaders;
@ -558,7 +560,8 @@ impl TunTapPacketHandler for Iface {
); );
let arp_msg = let arp_msg =
generate_arp_request(src_mac, ip, eee.device_config.get_ip()); generate_arp_request(src_mac, ip, eee.device_config.get_ip());
let Ok(encrypted) = aes_encrypt(&encrypt_key, &arp_msg) else { // let Ok(encrypted) = aes_encrypt(&encrypt_key, &arp_msg) else {
let Ok(encrypted) = eee.encryptor.read().unwrap().encrypt(&arp_msg) else {
error!("failed to encrypt arp request"); error!("failed to encrypt arp request");
return Ok(()); return Ok(());
}; };
@ -598,7 +601,8 @@ impl TunTapPacketHandler for Iface {
let pkt_size = packet.len(); let pkt_size = packet.len();
// println!("sending data with mac"); // println!("sending data with mac");
let Ok(encrypted) = aes_encrypt(&encrypt_key, &packet) else { // let Ok(encrypted) = aes_encrypt(&encrypt_key, &packet) else {
let Ok(encrypted) = eee.encryptor.read().unwrap().encrypt(&packet) else {
error!("failed to encrypt packet request"); error!("failed to encrypt packet request");
return Ok(()); return Ok(());
}; };

View File

@ -133,7 +133,7 @@ impl Iface {
} }
impl TunTapPacketHandler for Iface { impl TunTapPacketHandler for Iface {
async fn handle_packet_from_net(&self, data: &[u8], key: &[u8]) -> std::io::Result<()> { async fn handle_packet_from_net(&self, data: &[u8]) -> std::io::Result<()> {
// got layer 2 frame // got layer 2 frame
match Ethernet2Header::from_slice(&data) { match Ethernet2Header::from_slice(&data) {
Ok((hdr, rest)) => { Ok((hdr, rest)) => {
@ -199,7 +199,10 @@ impl TunTapPacketHandler for Iface {
[((self_ip >> 16) & 0xffff) as u16, (self_ip & 0xffff) as u16]; [((self_ip >> 16) & 0xffff) as u16, (self_ip & 0xffff) as u16];
let data = arp.marshal_to_bytes(); let data = arp.marshal_to_bytes();
let Ok(encrypted) = aes_encrypt(key, &data) else {
// let Ok(encrypted) = aes_encrypt(key, &data) else {
let Ok(encrypted) = edge.encryptor.read().unwrap().encrypt(&data) else {
error!("failed to encrypt arp reply"); error!("failed to encrypt arp reply");
return Ok(()); return Ok(());
}; };
@ -285,7 +288,7 @@ impl TunTapPacketHandler for Iface {
async fn handle_packet_from_device( async fn handle_packet_from_device(
&self, &self,
data: Vec<u8>, data: Vec<u8>,
encrypt_key: &[u8], // encrypt_key: &[u8],
) -> std::io::Result<()> { ) -> std::io::Result<()> {
let eee = get_edge(); let eee = get_edge();
@ -336,7 +339,9 @@ impl TunTapPacketHandler for Iface {
); );
let arp_msg = let arp_msg =
generate_arp_request(src_mac, ip, eee.device_config.get_ip()); generate_arp_request(src_mac, ip, eee.device_config.get_ip());
let Ok(encrypted) = aes_encrypt(&encrypt_key, &arp_msg) else {
let Ok(encrypted) = eee.encryptor.read().unwrap().encrypt(&arp_msg) else {
// let Ok(encrypted) = aes_encrypt(&encrypt_key, &arp_msg) else {
error!("failed to encrypt arp request"); error!("failed to encrypt arp request");
return Ok(()); return Ok(());
}; };
@ -380,7 +385,8 @@ impl TunTapPacketHandler for Iface {
let pkt_size = packet.len(); let pkt_size = packet.len();
// println!("sending data with mac"); // println!("sending data with mac");
let Ok(encrypted) = aes_encrypt(&encrypt_key, &packet) else { // let Ok(encrypted) = aes_encrypt(&encrypt_key, &packet) else {
let Ok(encrypted) = eee.encryptor.read().unwrap().encrypt(&packet) else {
error!("failed to encrypt packet request"); error!("failed to encrypt packet request");
return Ok(()); return Ok(());
}; };

View File

@ -24,8 +24,8 @@ use super::get_edge;
pub const MAX_WAIT_PACKETS: usize = 100; pub const MAX_WAIT_PACKETS: usize = 100;
pub trait TunTapPacketHandler { pub trait TunTapPacketHandler {
async fn handle_packet_from_net(&self, data: &[u8], key: &[u8]) -> std::io::Result<()>; async fn handle_packet_from_net(&self, data: &[u8]) -> std::io::Result<()>;
async fn handle_packet_from_device(&self, data: Vec<u8>, key: &[u8]) -> std::io::Result<()>; async fn handle_packet_from_device(&self, data: Vec<u8>) -> std::io::Result<()>;
} }
static ARP_WAIT_LIST: OnceCell<ArpWaitList> = OnceCell::new(); static ARP_WAIT_LIST: OnceCell<ArpWaitList> = OnceCell::new();
@ -75,7 +75,7 @@ impl ArpWaitList {
return; return;
} }
let encrypt_key = edge.get_encrypt_key(); // let encrypt_key = edge.get_encrypt_key();
let network_id = edge.network_id.load(Ordering::Relaxed); let network_id = edge.network_id.load(Ordering::Relaxed);
let src_mac = edge.device_config.get_mac(); let src_mac = edge.device_config.get_mac();
@ -88,7 +88,8 @@ impl ArpWaitList {
let pkt_size = packet.len(); let pkt_size = packet.len();
let Ok(encrypted) = aes_encrypt(&encrypt_key, &packet) else { let Ok(encrypted) = edge.encryptor.read().unwrap().encrypt(&packet) else {
// let Ok(encrypted) = aes_encrypt(&encrypt_key, &packet) else {
error!("failed to encrypt packet request"); error!("failed to encrypt packet request");
return; return;
}; };

View File

@ -73,9 +73,15 @@ pub struct SdlRegisterSuper {
pub struct SdlRegisterSuperAck { pub struct SdlRegisterSuperAck {
#[prost(uint32, tag = "1")] #[prost(uint32, tag = "1")]
pub pkt_id: u32, pub pkt_id: u32,
#[prost(bytes = "vec", tag = "2")] /// 目前支持aes chacha20
pub aes_key: ::prost::alloc::vec::Vec<u8>, #[prost(string, tag = "2")]
pub algorithm: ::prost::alloc::string::String,
#[prost(bytes = "vec", tag = "3")] #[prost(bytes = "vec", tag = "3")]
pub key: ::prost::alloc::vec::Vec<u8>,
/// 逻辑分段chacha20加密算法需要使用该字段
#[prost(uint32, tag = "4")]
pub region_id: u32,
#[prost(bytes = "vec", tag = "5")]
pub session_token: ::prost::alloc::vec::Vec<u8>, pub session_token: ::prost::alloc::vec::Vec<u8>,
} }
#[allow(clippy::derive_partial_eq_without_eq)] #[allow(clippy::derive_partial_eq_without_eq)]

View File

@ -9,7 +9,7 @@ use tokio::{sync::mpsc::{Receiver, Sender, channel}, time::sleep};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{debug, error, warn}; use tracing::{debug, error, warn};
use crate::{ConnectionInfo, ConnectionState, config::{NULL_MAC, TCP_PING_TIME}, get_edge, network::{ARP_REPLY, ArpHdr, EthHdr, Node, RegisterSuperFeedback, StartStopInfo, check_peer_registration_needed, handle_packet_peer_info}, pb::{SdlArpResponse, SdlPolicyResponse, SdlRegisterSuper, SdlRegisterSuperAck, SdlRegisterSuperNak, SdlSendRegisterEvent, encode_to_tcp_message}, tcp::{EventType, NakMsgCode, NatType, PacketType, RuleInfo, SdlanTcp, read_a_packet, send_stun_request, set_identity_cache}}; use crate::{AesEncryptor, Chacha20Encryptor, ConnectionInfo, ConnectionState, MyEncryptor, config::{NULL_MAC, TCP_PING_TIME}, get_edge, network::{ARP_REPLY, ArpHdr, EthHdr, Node, RegisterSuperFeedback, StartStopInfo, check_peer_registration_needed, handle_packet_peer_info}, pb::{SdlArpResponse, SdlPolicyResponse, SdlRegisterSuper, SdlRegisterSuperAck, SdlRegisterSuperNak, SdlSendRegisterEvent, encode_to_tcp_message}, tcp::{EventType, NakMsgCode, NatType, PacketType, RuleInfo, SdlanTcp, read_a_packet, send_stun_request, set_identity_cache}};
static GLOBAL_QUIC_HANDLE: OnceLock<ReadWriterHandle> = OnceLock::new(); static GLOBAL_QUIC_HANDLE: OnceLock<ReadWriterHandle> = OnceLock::new();
@ -126,11 +126,23 @@ async fn handle_tcp_message(msg: SdlanTcp) {
); );
debug!("got register super ack: {:?}", ack); debug!("got register super ack: {:?}", ack);
edge.session_token.set(ack.session_token); edge.session_token.set(ack.session_token);
let Ok(aes) = rsa_decrypt(&edge.rsa_private, &ack.aes_key) else { let Ok(key) = rsa_decrypt(&edge.rsa_private, &ack.key) else {
error!("failed to rsa decrypt aes key"); error!("failed to rsa decrypt aes key");
return; return;
}; };
match ack.algorithm.to_ascii_lowercase().as_str() {
"chacha20" => {
*edge.encryptor.write().unwrap() = MyEncryptor::ChaChao20(Chacha20Encryptor::new(key, ack.region_id));
}
"aes" => {
*edge.encryptor.write().unwrap() = MyEncryptor::Aes(AesEncryptor::new(key));
}
_other => {
}
}
/* /*
let Some(dev) = ack.dev_addr else { let Some(dev) = ack.dev_addr else {
@ -170,7 +182,8 @@ async fn handle_tcp_message(msg: SdlanTcp) {
// edge.device.reload_config(&edge.device_config, &dev.network_domain); // edge.device.reload_config(&edge.device_config, &dev.network_domain);
edge.device.reload_config(&edge.device_config, &edge.network_domain.read().unwrap().clone()); edge.device.reload_config(&edge.device_config, &edge.network_domain.read().unwrap().clone());
edge.set_authorized(true, aes); edge.set_authorized(true);
send_stun_request(edge).await; send_stun_request(edge).await;
tokio::spawn(async { tokio::spawn(async {
let nattype = edge.probe_nat_type().await; let nattype = edge.probe_nat_type().await;
@ -316,7 +329,8 @@ async fn handle_tcp_message(msg: SdlanTcp) {
message: "failed to decode REGISTER SUPER NAK".to_owned(), message: "failed to decode REGISTER SUPER NAK".to_owned(),
}); });
*/ */
edge.set_authorized(false, Vec::new()); edge.set_authorized(false);
*edge.encryptor.write().unwrap() = MyEncryptor::Invalid;
// std::process::exit(0); // std::process::exit(0);
} }
PacketType::Command => { PacketType::Command => {
@ -674,7 +688,8 @@ impl ReadWriteActor {
async fn on_disconnected_callback() { async fn on_disconnected_callback() {
let edge = get_edge(); let edge = get_edge();
edge.set_authorized(false, vec![]); edge.set_authorized(false);
*edge.encryptor.write().unwrap() = MyEncryptor::Invalid;
} }
async fn on_connected_callback(local_ip: Option<IpAddr>, stream: &mut SendStream, pkt_id: Option<u32>) { async fn on_connected_callback(local_ip: Option<IpAddr>, stream: &mut SendStream, pkt_id: Option<u32>) {

181
src/utils/encrypter.rs Normal file
View File

@ -0,0 +1,181 @@
use std::{sync::{Arc, OnceLock, RwLock, atomic::{AtomicBool, AtomicU32, Ordering}}, time::{SystemTime, UNIX_EPOCH}};
use tracing::debug;
use chacha20poly1305::{KeyInit, aead::Aead};
use dashmap::DashSet;
use sdlan_sn_rs::utils::{Result, SDLanError, aes_decrypt, aes_encrypt};
const CounterMask: u32 = (1<<24) - 1;
pub trait Encryptor {
fn is_setted(&self) -> bool;
fn set_key(&mut self, region_id: u32, key:Vec<u8>);
fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>>;
fn decrypt(&self, ciphered: &[u8]) -> Result<Vec<u8>>;
}
pub enum MyEncryptor {
Invalid,
ChaChao20(Chacha20Encryptor),
Aes(AesEncryptor),
}
impl MyEncryptor {
pub fn new() -> Self {
Self::Invalid
}
pub fn is_setted(&self) -> bool {
match self {
Self::Invalid => false,
Self::Aes(aes) => {
aes.is_setted()
}
Self::ChaChao20(cha) => {
cha.is_setted()
}
}
}
pub fn set_key(&mut self, region_id: u32, key:Vec<u8>) {
match self {
Self::Invalid => {}
Self::Aes(aes) => {
aes.set_key(region_id, key);
}
Self::ChaChao20(cha) => {
cha.set_key(region_id, key);
}
}
}
pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
match self {
Self::Invalid => {
Err(SDLanError::EncryptError("invalid encryptor".to_owned()))
}
Self::Aes(aes) => {
aes.encrypt(data)
}
Self::ChaChao20(cha) => {
cha.encrypt(data)
}
}
}
pub fn decrypt(&self, ciphered: &[u8]) -> Result<Vec<u8>> {
match self {
Self::Invalid => {
Err(SDLanError::EncryptError("invalid encryptor".to_owned()))
}
Self::Aes(aes) => {
aes.decrypt(ciphered)
}
Self::ChaChao20(cha) => {
cha.decrypt(ciphered)
}
}
}
}
pub struct Chacha20Encryptor {
key: Vec<u8>,
is_setted: bool,
next_counter: AtomicU32,
region_id: u32,
}
impl Chacha20Encryptor {
pub fn new(key: Vec<u8>, region_id: u32) -> Self {
Self {
key,
is_setted: true,
next_counter: AtomicU32::new(0),
region_id,
}
}
}
impl Encryptor for Chacha20Encryptor {
fn set_key(&mut self, region_id: u32, key:Vec<u8>) {
self.key = key;
self.region_id = region_id;
}
fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
let cipher = chacha20poly1305::ChaCha20Poly1305::new(self.key.as_slice().into());
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as u64;
let next_counter = self.next_counter.fetch_update(Ordering::Release, Ordering::Acquire, |current| {
Some((current + 1) & CounterMask)
}).unwrap() as u64;
let mut nonce = Vec::new();
let region_id = self.region_id.to_be_bytes();
nonce.extend_from_slice(&region_id);
let next_data = (now<<24) | next_counter;
nonce.extend_from_slice(&next_data.to_be_bytes());
match cipher.encrypt(nonce.as_slice().into(), data) {
Ok(data) => {
nonce.extend_from_slice(&data);
Ok(nonce)
},
Err(e) => {
Err(SDLanError::EncryptError(e.to_string()))
}
}
}
fn decrypt(&self, ciphered: &[u8]) -> Result<Vec<u8>> {
if ciphered.len() < 12 {
return Err(SDLanError::EncryptError("ciphered text size error".to_owned()))
}
let cipher = chacha20poly1305::ChaCha20Poly1305::new(self.key.as_slice().into());
let nonce = &ciphered[0..12];
match cipher.decrypt(nonce.into(), &ciphered[12..]) {
Ok(data) => Ok(data),
Err(e) => {
Err(SDLanError::EncryptError(format!("failed to decyrpt: {}", e.to_string())))
}
}
}
fn is_setted(&self) -> bool {
self.is_setted
}
}
pub struct AesEncryptor {
key: Vec<u8>,
is_setted: bool,
}
impl AesEncryptor {
pub fn new(key: Vec<u8>) -> Self {
Self {
key,
is_setted: true,
}
}
}
impl Encryptor for AesEncryptor {
fn decrypt(&self, ciphered: &[u8]) -> Result<Vec<u8>> {
aes_decrypt(&self.key, ciphered)
}
fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
aes_encrypt(&self.key, data)
}
fn is_setted(&self) -> bool {
self.is_setted
}
fn set_key(&mut self, _region_id: u32, key:Vec<u8>) {
self.key = key;
self.is_setted = true;
}
}

View File

@ -1,7 +1,9 @@
mod command; mod command;
mod encrypter;
use std::{fs::OpenOptions, io::Write, net::Ipv4Addr, path::Path}; use std::{fs::OpenOptions, io::Write, net::Ipv4Addr, path::Path};
pub use encrypter::*;
pub use command::*; pub use command::*;
mod socks; mod socks;