From 46d44e9f96ef0b8e8204eb548a84d55d6af0c208 Mon Sep 17 00:00:00 2001 From: asxalex Date: Wed, 21 Feb 2024 12:37:44 +0800 Subject: [PATCH] added encrypt encode and decode --- src/packet/common.rs | 49 +++++++++++++++++++++++++++++++++++- src/packet/register_super.rs | 32 ++++++++++++++++++++++- src/peer.rs | 4 +-- src/utils/helper.rs | 26 ++++++++++++++++++- 4 files changed, 106 insertions(+), 5 deletions(-) diff --git a/src/packet/common.rs b/src/packet/common.rs index ded10cf..876514b 100644 --- a/src/packet/common.rs +++ b/src/packet/common.rs @@ -61,13 +61,26 @@ impl<'a> Common<'a> { }; Ok((common, &value[5 + id_len..])) } + + fn new(id: &'a str) -> Self { + return Common { + id, + version: 1, + ttl: 2, + pc: PacketType::PKTInvalid, + flags: 0, + }; + } } #[derive(Debug, PartialEq, Serialize_repr, Deserialize_repr)] #[repr(u8)] pub enum PacketType { PKTInvalid, + // 向sn注册 PKTRegisterSuper, + // 数据转发 + PKTPacket, } impl std::convert::From for PacketType { @@ -75,6 +88,7 @@ impl std::convert::From for PacketType { match value { // 0 => Self::PacketInvalid, 1 => Self::PKTRegisterSuper, + 2 => Self::PKTPacket, _ => Self::PKTInvalid, } } @@ -85,6 +99,7 @@ impl PacketType { match *self { Self::PKTInvalid => 0, Self::PKTRegisterSuper => 1, + Self::PKTPacket => 2, } } } @@ -114,6 +129,29 @@ mod test { } } +pub fn encode_packet_encrypted( + cmn: &Common, + pkt: &T, + header_pass: &[u8], +) -> Result> { + let hdr = cmn.encode(); + let body = serde_json::to_vec(pkt)?; + let body = aes_encrypt(header_pass, &body)?; + + let mut result = Vec::with_capacity(4 + hdr.len() + body.len()); + let total_size = (2 + hdr.len() + body.len()) as u16; + // insert total size + result.extend_from_slice(&total_size.to_be_bytes()); + // packet_id + result.extend_from_slice(&[0, 0]); + // insert header + result.extend_from_slice(&hdr); + // insert body + result.extend_from_slice(&body); + + Ok(result) +} + pub fn encode_packet(cmn: &Common, pkt: &T) -> Result> { // header let hdr = cmn.encode(); @@ -134,7 +172,8 @@ pub fn encode_packet(cmn: &Common, pkt: &T) -> Result>(value: &'a [u8]) -> Result<(Common<'a>, T)> { +// decode the common, and return the encrypted bytes; +pub fn decode_common<'a>(value: &'a [u8]) -> Result<(Common<'a>, &'a [u8])> { if value.len() < 4 { return Err(SDLanError::NormalError("decode pkt length error")); } @@ -146,6 +185,14 @@ pub fn decode_packet<'a, T: serde::Deserialize<'a>>(value: &'a [u8]) -> Result<( let value2 = &value[4..2 + size as usize]; let (cmn, value2) = Common::from_slice(value2)?; + Ok((cmn, value2)) +} + +// decode the whole packet, suitable for un-encrypted packet info +pub fn decode_packet<'a, T: serde::Deserialize<'a>>(value: &'a [u8]) -> Result<(Common<'a>, T)> { + // let value2 = &value[4..2 + size as usize]; + let (cmn, value2) = decode_common(value)?; + // let (cmn, value2) = Common::from_slice(value2)?; let res = serde_json::from_slice(value2)?; Ok((cmn, res)) } diff --git a/src/packet/register_super.rs b/src/packet/register_super.rs index fc60ff1..59ea651 100644 --- a/src/packet/register_super.rs +++ b/src/packet/register_super.rs @@ -22,9 +22,13 @@ pub struct RegisterSuper<'a> { #[cfg(test)] mod test { use crate::config; + use crate::packet::decode_common; use crate::packet::decode_packet; use crate::packet::encode_packet; + use crate::packet::encode_packet_encrypted; use crate::packet::PacketType; + use crate::utils::aes_decrypt; + use crate::utils::gen_uuid; use self::peer::SdlanSock; use crate::packet::Common; @@ -32,7 +36,7 @@ mod test { use super::*; - fn do_encode() -> Result<(Vec, Common<'static>, RegisterSuper<'static>)> { + fn prepare_data() -> (Common<'static>, RegisterSuper<'static>) { let cmn1 = Common { version: 1, id: "asxalex", @@ -54,10 +58,23 @@ mod test { pub_key: "public key", token: "user's token", }; + (cmn1, pkt1) + } + + fn do_encode() -> Result<(Vec, Common<'static>, RegisterSuper<'static>)> { + let (cmn1, pkt1) = prepare_data(); let res = encode_packet(&cmn1, &pkt1).unwrap(); Ok((res, cmn1, pkt1)) } + fn do_encode_encrypted( + pass: &[u8], + ) -> Result<(Vec, Common<'static>, RegisterSuper<'static>)> { + let (cmn1, pkt1) = prepare_data(); + let res = encode_packet_encrypted(&cmn1, &pkt1, pass).unwrap(); + Ok((res, cmn1, pkt1)) + } + fn do_decode<'a, T: serde::Deserialize<'a>>(value: &'a [u8]) -> Result<(Common, T)> { decode_packet(value) } @@ -100,4 +117,17 @@ mod test { let (cmn2, pkt2) = do_decode::(&value).unwrap(); compare_two(cmn1, cmn2, pkt1, pkt2); } + + #[test] + fn test_encrypt_common_and_packet() -> Result<()> { + let pass = gen_uuid().into_bytes(); + let (data, cmn1, pkt1) = do_encode_encrypted(&pass)?; + + let (cmn2, encryptedinfo) = decode_common(&data)?; + let jsonstr = aes_decrypt(&pass, encryptedinfo)?; + let pkt2: RegisterSuper = serde_json::from_slice(&jsonstr)?; + + compare_two(cmn1, cmn2, pkt1, pkt2); + Ok(()) + } } diff --git a/src/peer.rs b/src/peer.rs index d11ce7d..4998997 100644 --- a/src/peer.rs +++ b/src/peer.rs @@ -109,9 +109,9 @@ impl IpSubnet { #[derive(Debug, Serialize, Deserialize, PartialEq, sqlx::FromRow)] pub struct SdlanSock { pub family: u8, - pub port: u32, + pub port: u16, pub has_v6: bool, - pub v6_port: u32, + pub v6_port: u16, // pub v4: Vec, // pub v6: Vec, #[sqlx(try_from = "Vec")] diff --git a/src/utils/helper.rs b/src/utils/helper.rs index cc664f8..7203bfe 100644 --- a/src/utils/helper.rs +++ b/src/utils/helper.rs @@ -1,7 +1,8 @@ +use super::SDLanError; use dashmap::DashMap; +use std::net::{IpAddr, SocketAddr}; use std::path::Path; use std::sync::Arc; -use tokio::sync::mpsc::UnboundedReceiver; pub struct MyDashMap(DashMap>); @@ -38,6 +39,8 @@ pub fn get_current_timestamp() -> u64 { .as_secs() } +use crate::peer::SdlanSock; + use super::{gen_uuid, Result}; use std::fs::{File, OpenOptions}; use std::io::{Read, Write}; @@ -59,3 +62,24 @@ pub fn create_or_load_uuid() -> Result { return Ok(uuid); } } + +// get sdlansock info from socketaddr +pub fn get_sdlan_sock_from_socketaddr(addr: SocketAddr) -> Result { + let port = addr.port(); + let ip = addr.ip(); + match ip { + IpAddr::V4(ipv4) => { + let v4: u32 = ipv4.into(); + let res = SdlanSock { + family: 0, + port: port, + has_v6: false, + v6_port: 0, + v4: v4.to_be_bytes(), + v6: [0; 16], + }; + Ok(res) + } + IpAddr::V6(_ipv6) => Err(SDLanError::NormalError("ipv6 found")), + } +}