commit 8647b0435ae786ca567622ee0fd7111ab99efdfe Author: asxalex Date: Sat Feb 17 10:27:13 2024 +0800 added packet test diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7f66fba --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +/target +/Cargo.lock +/.output +/.vscode diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..21e50be --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "sdlan-sn-rs" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = "1.0.79" +byteorder = "1.5.0" +dashmap = "5.5.3" +lazy_static = "1.4.0" +rolling-file = { git = "https://git.asxalex.pw/rust/rolling-file" } +serde = { version = "1.0.196", features = ["derive"] } +serde_json = "1.0.113" +serde_repr = "0.1.18" +tokio = { version = "1.36.0", features = ["full"] } +tracing = "0.1.40" +tracing-appender = "0.2.3" diff --git a/examples/benchmark_peer/main.rs b/examples/benchmark_peer/main.rs new file mode 100644 index 0000000..b0197a1 --- /dev/null +++ b/examples/benchmark_peer/main.rs @@ -0,0 +1,61 @@ +use sdlan_sn_rs::peer::Peer; + +use dashmap::DashMap; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use lazy_static::lazy_static; + +lazy_static! { + static ref DASH: DashMap> = DashMap::new(); +} + +#[tokio::main] +async fn main() { + let numbers = 1000000; + let number_of_routine = 400; + + let start = Instant::now(); + for i in 0..numbers { + let id = i.to_string(); + let peer = Arc::new(Peer::new(&id)); + DASH.insert(id, peer); + } + + println!("insert {} record elapsed: {:?}", numbers, start.elapsed()); + + let mut handlers = vec![]; + let start = Instant::now(); + for i in 0..number_of_routine { + let handler = tokio::spawn(async move { + for j in 0..numbers { + // let info = DASH.get(&j.to_string()).unwrap().clone(); + let info = DASH.get(&j.to_string()).unwrap(); + info.last_seen.fetch_add(1, Ordering::Relaxed); + // println!("{j}"); + } + }); + handlers.push(handler); + } + + for handler in handlers { + let v = handler.await; + } + println!( + "{}x{} times add elapsed: {:?}", + number_of_routine, + numbers, + start.elapsed() + ); + for i in 0..10 { + println!( + "{:?}", + DASH.get(&i.to_string()) + .unwrap() + .clone() + .last_seen + .load(Ordering::Relaxed) + ) + } +} diff --git a/src/bin/sdlan-sn/main.rs b/src/bin/sdlan-sn/main.rs new file mode 100644 index 0000000..c50432e --- /dev/null +++ b/src/bin/sdlan-sn/main.rs @@ -0,0 +1,64 @@ +use sdlan_sn_rs::log; +use sdlan_sn_rs::packet; +use sdlan_sn_rs::utils; +use tracing::{debug, error}; + +//use std::io::Read; +use std::net::SocketAddr; +use std::time::Duration; +use tokio::net::UdpSocket; +use utils::Result; + +const SERVER: &str = "127.0.0.1:7655"; + +async fn client(address: &str) -> Result<()> { + let socket = UdpSocket::bind("127.0.0.1:0").await?; + let id = "Dejavu"; + let cmn = packet::Common { + version: 1, + id, + ttl: 128, + pc: packet::PacketType::PKTRegisterSuper, + flags: 0x0200, + }; + let value = cmn.encode(); + loop { + socket.send_to(&value, address).await?; + tokio::time::sleep(Duration::from_millis(5000)).await; + } + + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<()> { + let _guard = log::init_log(); + debug!("main starts here"); + + let listener = UdpSocket::bind(SERVER).await?; + + tokio::spawn(async { + client(SERVER).await.unwrap(); + }); + + loop { + // let mut buffer = [0; 2048]; + let mut buffer = vec![0; 2048]; + // let mut data = Vec::with_capacity(2048); + let (n, addr) = listener.recv_from(&mut buffer).await?; + buffer.truncate(n); + tokio::spawn(async move { + if let Err(e) = handle_packet(&buffer, addr).await { + error!("failed to handle packet: {:?}", e); + } + }); + } + + Ok(()) +} + +async fn handle_packet(pkt: &[u8], from: SocketAddr) -> Result<()> { + let common = packet::Common::from_slice(pkt)?; + println!("common: {:?}", common); + Ok(()) +} diff --git a/src/bin/sdlan/main.rs b/src/bin/sdlan/main.rs new file mode 100644 index 0000000..7f755fb --- /dev/null +++ b/src/bin/sdlan/main.rs @@ -0,0 +1,2 @@ +#[tokio::main] +async fn main() {} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..4ee5927 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,11 @@ +// common的flag掩码 +pub const FLAGS_TYPE_MASK: u16 = 0x001f; +pub const FLAGS_BITS_MASK: u16 = 0xffe0; + +// common头的flags里面的flag,可以组合 +pub const SDLAN_FLAGS_FROM_SN: u16 = 0x0020; +pub const SDLAN_FLAGS_SOCKET: u16 = 0x0040; +pub const SDLAN_FLAGS_OPTIONS: u16 = 0x0080; + +pub const IPV4_SIZE: u8 = 4; +pub const IPV6_SIZE: u8 = 16; diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..a21bc06 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,5 @@ +pub mod config; +pub mod log; +pub mod packet; +pub mod peer; +pub mod utils; diff --git a/src/log.rs b/src/log.rs new file mode 100644 index 0000000..4301cb1 --- /dev/null +++ b/src/log.rs @@ -0,0 +1,5 @@ +use tracing_appender::non_blocking::WorkerGuard; + +pub fn init_log() -> WorkerGuard { + rolling_file::default() +} diff --git a/src/packet/common.rs b/src/packet/common.rs new file mode 100644 index 0000000..d403687 --- /dev/null +++ b/src/packet/common.rs @@ -0,0 +1,142 @@ +use crate::config; +use crate::utils::Result; + +use byteorder::{BigEndian, ByteOrder}; + +use serde_repr::*; + +use crate::utils::*; + +#[derive(Debug)] +pub struct Common<'a> { + pub version: u8, + pub id: &'a str, + pub ttl: u8, + // packet type + pub pc: PacketType, + pub flags: u16, +} + +impl<'a> Common<'a> { + pub fn encode(&self) -> Vec { + let mut result = vec![]; + encode_u8(&mut result, self.version); + encode_buf_with_size_1(&mut result, self.id.as_bytes()); + encode_u8(&mut result, self.ttl); + + let mut flag = self.pc.to_u16() & config::FLAGS_TYPE_MASK; + flag |= self.flags & config::FLAGS_BITS_MASK; + encode_u16(&mut result, flag); + + result + } + + pub fn from_slice(value: &'a [u8]) -> Result<(Common<'a>, &'a [u8])> { + if value.len() < 2 { + return Err("common header length error".into()); + } + let version = value[0]; + let id_len = value[1] as usize; + if value.len() < (id_len + 5) as usize { + return Err("common header id length error".into()); + } + let v1 = &value[2..2 + id_len]; + let id = match std::str::from_utf8(v1) { + Ok(s) => s, + Err(e) => return Err(SDLanError::ConvertError(e.to_string())), + }; + let ttl = value[2 + id_len]; + let flags = BigEndian::read_u16(&value[3 + id_len..5 + id_len]); + let pc = flags & config::FLAGS_TYPE_MASK; + let flag = flags & config::FLAGS_BITS_MASK; + + let common = Self { + version, + id, + ttl, + pc: pc.into(), + flags: flag, + }; + Ok((common, &value[5 + id_len..])) + } +} + +#[derive(Debug, PartialEq, Serialize_repr, Deserialize_repr)] +#[repr(u8)] +pub enum PacketType { + PKTInvalid, + PKTRegisterSuper, +} + +impl std::convert::From for PacketType { + fn from(value: u16) -> Self { + match value { + // 0 => Self::PacketInvalid, + 1 => Self::PKTRegisterSuper, + _ => Self::PKTInvalid, + } + } +} + +impl PacketType { + pub fn to_u16(&self) -> u16 { + match *self { + Self::PKTInvalid => 0, + Self::PKTRegisterSuper => 1, + } + } +} + +#[cfg(test)] +mod test { + #[test] + fn test_common_encode_and_decode() { + use super::*; + + let id = "hello"; + let common = Common { + version: 0, + id, + ttl: 2, + pc: 1.into(), + flags: 0, + }; + let value1 = common.encode(); + let (common2, _) = Common::from_slice(&value1).unwrap(); + println!("common2 = {:?}", common2); + assert_eq!(common.id, common2.id); + assert_eq!(common.version, common2.version); + assert_eq!(common.ttl, common2.ttl); + assert_eq!(common.pc, common2.pc); + assert_eq!(common.flags, common2.flags); + } +} + +pub fn encode_packet(cmn: &Common, pkt: &T) -> Result> { + let mut res = cmn.encode(); + + let content = serde_json::to_vec(pkt)?; + + let size = content.len() as u32; + // add size bytes + res.extend_from_slice(&size.to_be_bytes()); + res.extend_from_slice(&content); + Ok(res) +} + +pub fn decode_packet<'a, T: serde::Deserialize<'a>>(value: &'a [u8]) -> Result<(Common<'a>, T)> { + let (cmn, value2) = Common::from_slice(value)?; + if value2.len() < 4 { + return Err(SDLanError::NormalError("decode pkt length error")); + } + let size_bytes = value2[0..4].try_into().expect("never goes here"); + let size = u32::from_be_bytes(size_bytes) as usize; + let value2 = &value2[4..]; + // if value2.len() < size { + if size > value2.len() { + return Err(SDLanError::NormalError("decode pkt size error")); + } + + let res = serde_json::from_slice(&value2[..size])?; + Ok((cmn, res)) +} diff --git a/src/packet/mod.rs b/src/packet/mod.rs new file mode 100644 index 0000000..57333c4 --- /dev/null +++ b/src/packet/mod.rs @@ -0,0 +1,5 @@ +mod common; +pub use common::*; + +mod register_super; +pub use register_super::*; diff --git a/src/packet/register_super.rs b/src/packet/register_super.rs new file mode 100644 index 0000000..9359d24 --- /dev/null +++ b/src/packet/register_super.rs @@ -0,0 +1,114 @@ +use crate::peer; + +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +pub struct RegisterSuper<'a> { + // pass, 用于给registersuper一个初步的雁阵,固定8位 + pub pass: &'a str, + // 自身的sock信息 + pub sock: peer::SdlanSock, + + // 自身的ip信息 + pub dev_addr: peer::IpSubnet, + + // 自身的公钥 + pub pub_key: &'a str, + + // user's token, can be used to specify a user + pub token: &'a str, +} + +#[cfg(test)] +mod test { + use crate::config; + use crate::packet::decode_packet; + use crate::packet::encode_packet; + use crate::packet::PacketType; + + use self::peer::SdlanSock; + use crate::packet::Common; + use crate::utils::Result; + use std::sync::atomic; + use std::sync::atomic::Ordering::Relaxed; + + use super::*; + + fn do_encode() -> Result<(Vec, Common<'static>, RegisterSuper<'static>)> { + let cmn1 = Common { + version: 1, + id: "asxalex", + ttl: 128, + pc: PacketType::PKTRegisterSuper, + flags: config::SDLAN_FLAGS_FROM_SN, + }; + let pkt1 = RegisterSuper { + pass: "encrypt!", + sock: SdlanSock { + family: 0, + port: 1, + has_v6: true, + v6_port: 2345, + v4: [0; 4], + v6: [1; 16], + }, + dev_addr: peer::IpSubnet { + net_addr: atomic::AtomicU32::new(192), + net_bit_len: atomic::AtomicU8::new(24), + }, + pub_key: "public key", + token: "user's token", + }; + let res = encode_packet(&cmn1, &pkt1).unwrap(); + Ok((res, cmn1, pkt1)) + } + + fn do_decode<'a, T: serde::Deserialize<'a>>(value: &'a [u8]) -> Result<(Common, T)> { + decode_packet(value) + } + + fn compare_two(cmn1: Common, cmn2: Common, pkt1: RegisterSuper, pkt2: RegisterSuper) { + assert_eq!(cmn1.version, cmn2.version); + assert_eq!(cmn1.id, cmn2.id); + assert_eq!(cmn1.ttl, cmn2.ttl); + assert_eq!(cmn1.pc, cmn2.pc); + assert_eq!(cmn1.flags, cmn2.flags); + + assert_eq!(pkt1.pass, pkt2.pass); + assert_eq!(pkt1.sock, pkt2.sock); + assert_eq!( + pkt1.dev_addr.net_addr.load(Relaxed), + pkt2.dev_addr.net_addr.load(Relaxed) + ); + assert_eq!( + pkt1.dev_addr.net_bit_len.load(Relaxed), + pkt2.dev_addr.net_bit_len.load(Relaxed) + ); + assert_eq!(pkt1.pub_key, pkt2.pub_key); + assert_eq!(pkt1.token, pkt2.token); + } + + #[test] + fn test_register_super_encode_decode() { + let (value, cmn1, pkt1) = do_encode().unwrap(); + let (cmn2, pkt2) = do_decode::(&value).unwrap(); + compare_two(cmn1, cmn2, pkt1, pkt2); + } + + #[test] + #[should_panic] + fn test_register_super_encode_decode_panic() { + let (mut value, cmn1, pkt1) = do_encode().unwrap(); + value.pop(); + let (cmn2, pkt2) = do_decode::(&value).unwrap(); + compare_two(cmn1, cmn2, pkt1, pkt2); + } + + #[test] + fn test_register_super_encode_decode_nopanic() { + let (mut value, cmn1, pkt1) = do_encode().unwrap(); + value.extend_from_slice(&[1; 10]); + let (cmn2, pkt2) = do_decode::(&value).unwrap(); + compare_two(cmn1, cmn2, pkt1, pkt2); + } +} diff --git a/src/peer.rs b/src/peer.rs new file mode 100644 index 0000000..84475a8 --- /dev/null +++ b/src/peer.rs @@ -0,0 +1,71 @@ +#![allow(unused)] +use std::default::Default; +use std::sync::atomic::{AtomicI64, AtomicU32, AtomicU8}; +use std::sync::Mutex; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug)] +pub struct Peer { + pub id: String, + // 对端阶段的udp信息,包括ipv4地址和子网掩码位数 + pub dev_addr: IpSubnet, + // 对端对外开放的ip和端口信息 + pub sock: Mutex, + pub pub_key: Mutex>, + pub timeout: isize, + + // 最近一次遇见 + pub last_seen: AtomicI64, + // 最近一次p2p消息 + pub last_p2p: AtomicI64, + // 最近一次发送query + pub last_send_query: AtomicI64, + + // 该节点锁属的网络 + pub network_id: Mutex, +} + +impl Peer { + pub fn new(id: &str) -> Self { + Self { + id: id.to_string(), + dev_addr: IpSubnet { + net_addr: AtomicU32::new(0), + net_bit_len: AtomicU8::new(0), + }, + sock: Mutex::new(SdlanSock { + family: 0, + port: 0, + has_v6: false, + v6_port: 0, + v4: [0; 4], + v6: [0; 16], + }), + pub_key: Mutex::new(vec![]), + timeout: 0, + last_seen: AtomicI64::new(0), + last_p2p: AtomicI64::new(0), + last_send_query: AtomicI64::new(0), + network_id: Mutex::new(String::new()), + } + } +} + +/// IpSubnet, 对端ipv4信息 +#[derive(Debug, Serialize, Deserialize)] +pub struct IpSubnet { + pub net_addr: AtomicU32, + pub net_bit_len: AtomicU8, +} + +/// SdlanSock: 对端对外的ip信息,包括ipv4和ipv6 +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub struct SdlanSock { + pub family: u8, + pub port: u32, + pub has_v6: bool, + pub v6_port: u32, + pub v4: [u8; 4], + pub v6: [u8; 16], +} diff --git a/src/utils/encode_decode.rs b/src/utils/encode_decode.rs new file mode 100644 index 0000000..ca4fdc1 --- /dev/null +++ b/src/utils/encode_decode.rs @@ -0,0 +1,62 @@ +pub fn encode_u8(v: &mut Vec, value: u8) { + v.push(value); +} + +pub fn encode_u16(v: &mut Vec, value: u16) { + let val = value.to_be_bytes(); + v.extend_from_slice(&val); +} + +pub fn encode_u32(v: &mut Vec, value: u32) { + let val = value.to_be_bytes(); + v.extend_from_slice(&val); +} + +pub fn encode_u64(v: &mut Vec, value: u64) { + let val = value.to_be_bytes(); + v.extend_from_slice(&val); +} + +pub fn encode_i8(v: &mut Vec, value: u8) { + v.push(value); +} + +pub fn encode_i16(v: &mut Vec, value: u16) { + let val = value.to_be_bytes(); + v.extend_from_slice(&val); +} + +pub fn encode_i32(v: &mut Vec, value: u32) { + let val = value.to_be_bytes(); + v.extend_from_slice(&val); +} + +pub fn encode_i64(v: &mut Vec, value: u64) { + let val = value.to_be_bytes(); + v.extend_from_slice(&val); +} + +pub fn encode_buf_withoud_size(v: &mut Vec, buf: &[u8]) { + v.extend_from_slice(buf); +} + +pub fn encode_buf_with_size_1(v: &mut Vec, buf: &[u8]) { + let l = buf.len() as u8; + let n = l.to_be_bytes(); + v.push(n[0]); + v.extend_from_slice(buf); +} + +pub fn encode_buf_with_size_2(v: &mut Vec, buf: &[u8]) { + let l = buf.len() as u16; + let n = l.to_be_bytes(); + v.extend_from_slice(&n); + v.extend_from_slice(buf); +} + +pub fn encode_buf_with_size_4(v: &mut Vec, buf: &[u8]) { + let l = buf.len() as u32; + let n = l.to_be_bytes(); + v.extend_from_slice(&n); + v.extend_from_slice(buf); +} diff --git a/src/utils/error.rs b/src/utils/error.rs new file mode 100644 index 0000000..0107473 --- /dev/null +++ b/src/utils/error.rs @@ -0,0 +1,29 @@ +use std::convert::From; + +pub type Result = std::result::Result; + +#[derive(Debug)] +pub enum SDLanError { + IOError(std::io::Error), + NormalError(&'static str), + ConvertError(String), + SerializeError(String), +} + +impl From for SDLanError { + fn from(value: std::io::Error) -> Self { + Self::IOError(value) + } +} + +impl From<&'static str> for SDLanError { + fn from(value: &'static str) -> Self { + Self::NormalError(value) + } +} + +impl From for SDLanError { + fn from(value: serde_json::Error) -> Self { + Self::SerializeError(value.to_string()) + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 0000000..ba43934 --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1,5 @@ +mod encode_decode; +mod error; + +pub use encode_decode::*; +pub use error::*;