sdlan-rs/src/packet/common.rs
2024-03-30 22:40:51 +08:00

319 lines
8.8 KiB
Rust

use crate::utils::Result;
use super::packet::Packet;
use byteorder::{BigEndian, ByteOrder};
use tracing::{debug, error};
use serde_repr::*;
use crate::utils::*;
/// Common header placed in front of every encoded packet body.
///
/// Wire layout as written by `Common::encode` and parsed by `Common::from_slice`
/// (47 bytes in total): packet_id (u16, big-endian), version (u8), id (fixed
/// 32-byte field; trailing NUL bytes are trimmed on decode), token (u64,
/// big-endian), ttl (u8), flags (u16, read as big-endian), packet type (u8).
/// Note that on the wire `flags` precedes the packet type, unlike the field
/// order below.
#[derive(Debug)]
pub struct Common<'a> {
pub packet_id: u16,
pub version: u8,
// client's uuid
// (borrowed; when decoded it points into the received buffer)
pub id: &'a str,
pub token: u64,
pub ttl: u8,
// packet type
pub pc: PacketType,
pub flags: u16,
}
impl<'a> Common<'a> {
    /// Width in bytes of the fixed-size `id` field on the wire.
    const ID_LEN: usize = 32;
    /// Width in bytes of the big-endian `token` field on the wire.
    const TOKEN_LEN: usize = 8;
    /// Total encoded header size:
    /// packet_id(2) + version(1) + id(32) + token(8) + ttl(1) + flags(2) + type(1).
    const HEADER_LEN: usize = 2 + 1 + Self::ID_LEN + Self::TOKEN_LEN + 1 + 2 + 1;

    /// Serialize the header into its wire form.
    ///
    /// Field order: packet_id (u16 big-endian), version, id (fixed-width field
    /// written by `encode_buf_without_size`), token (u64 big-endian), ttl,
    /// flags, packet type byte.
    pub fn encode(&self) -> Vec<u8> {
        let mut result = Vec::with_capacity(Self::HEADER_LEN);
        result.extend_from_slice(&self.packet_id.to_be_bytes());
        encode_u8(&mut result, self.version);
        // This width must stay equal to `Self::ID_LEN`, which `from_slice` reads.
        encode_buf_without_size(&mut result, self.id.as_bytes(), 32);
        result.extend_from_slice(&self.token.to_be_bytes());
        encode_u8(&mut result, self.ttl);
        encode_u16(&mut result, self.flags);
        encode_u8(&mut result, self.pc.to_u8());
        result
    }

    /// Parse a header from the front of `value`.
    ///
    /// Returns the header, whose `id` borrows from `value` with trailing NUL
    /// bytes trimmed, together with the bytes that follow the header.
    ///
    /// # Errors
    /// Fails when `value` is shorter than the 47-byte header, or with
    /// `SDLanError::ConvertError` when the id field is not valid UTF-8.
    pub fn from_slice(value: &'a [u8]) -> Result<(Common, &'a [u8])> {
        if value.len() < Self::HEADER_LEN {
            return Err("common header length error".into());
        }
        // The length check above guarantees every split below is in bounds.
        let (packet_id, rest) = value.split_at(2);
        let (version, rest) = rest.split_at(1);
        let (id_bytes, rest) = rest.split_at(Self::ID_LEN);
        let (token, rest) = rest.split_at(Self::TOKEN_LEN);
        let (ttl, rest) = rest.split_at(1);
        let (flags, rest) = rest.split_at(2);
        let (pc, rest) = rest.split_at(1);

        let id = std::str::from_utf8(id_bytes)
            .map_err(|e| SDLanError::ConvertError(e.to_string()))?
            .trim_end_matches('\0');

        let common = Self {
            packet_id: BigEndian::read_u16(packet_id),
            version: version[0],
            id,
            token: BigEndian::read_u64(token),
            ttl: ttl[0],
            pc: pc[0].into(),
            flags: BigEndian::read_u16(flags),
        };
        Ok((common, rest))
    }

    /// Copy every field of `cmn` into a new header that borrows the same `id`.
    pub fn from_old_common(cmn: &'a Common) -> Self {
        Self {
            packet_id: cmn.packet_id,
            id: cmn.id,
            token: cmn.token,
            version: cmn.version,
            ttl: cmn.ttl,
            pc: cmn.pc,
            flags: cmn.flags,
        }
    }

    /// Create a header for `id` with version 1, ttl 2, zero packet_id, token
    /// and flags, and the `PKTInvalid` packet type.
    pub fn new(id: &'a str) -> Self {
        Self {
            packet_id: 0,
            id,
            token: 0,
            version: 1,
            ttl: 2,
            pc: PacketType::PKTInvalid,
            flags: 0,
        }
    }
}
/// Packet type byte, written as the last byte of the common header.
///
/// The explicit discriminants are the wire values. They must stay in sync
/// with `From<u8> for PacketType` and `PacketType::to_u8`.
#[derive(Debug, PartialEq, Serialize_repr, Deserialize_repr, Copy, Clone)]
#[repr(u8)]
pub enum PacketType {
    /// Invalid type; also the fallback when decoding an unknown byte.
    PKTInvalid = 0,
    /// Register with the sn (super node).
    PKTRegisterSuper = 1,
    /// Data forwarding.
    PKTPacket = 2,
    /// Hole-punching message.
    PKTRegister = 3,
    /// Acknowledgement of a hole-punching message.
    PKTRegisterACK = 4,
    PKTRegisterSuperACK = 5,
    PKTRegisterSuperAcknowledge = 6,
    PKTRegisterSuperNAK = 7,
    PKTUnregisterSuper = 8,
    /// Sent by a client to the server, concerning another client.
    PKTQueryPeer = 9,
    /// Message returned by the server to the client.
    PKTPeerInfo = 10,
    /// Command sent by an sn to other sns or edges.
    PKTCommand = 11,
    PKTCommandResp = 12,
}
impl std::convert::From<u8> for PacketType {
fn from(value: u8) -> Self {
match value {
// 0 => Self::PacketInvalid,
1 => Self::PKTRegisterSuper,
2 => Self::PKTPacket,
3 => Self::PKTRegister,
4 => Self::PKTRegisterACK,
5 => Self::PKTRegisterSuperACK,
6 => Self::PKTRegisterSuperAcknowledge,
7 => Self::PKTRegisterSuperNAK,
8 => Self::PKTUnregisterSuper,
9 => Self::PKTQueryPeer,
10 => Self::PKTPeerInfo,
11 => Self::PKTCommand,
12 => Self::PKTCommandResp,
_ => Self::PKTInvalid,
}
}
}
impl PacketType {
    /// Return the wire byte for this packet type.
    ///
    /// `PacketType` is `#[repr(u8)]` with variants numbered 0 (`PKTInvalid`)
    /// through 12 (`PKTCommandResp`), so the cast yields exactly that code.
    pub fn to_u8(&self) -> u8 {
        *self as u8
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trip a header through `encode` / `from_slice`.
    ///
    /// Checks that every field survives and that the decoder consumes exactly
    /// the header bytes. `packet_id` and `token` use non-zero, asymmetric
    /// values so that a byte-order or offset mistake is detected.
    #[test]
    fn test_common_encode_and_decode() {
        let id = "hello";
        let common = Common {
            packet_id: 0x1234,
            version: 0,
            id,
            token: 0x0102_0304_0506_0708,
            ttl: 2,
            pc: 1.into(),
            flags: 0,
        };
        let value1 = common.encode();
        println!("value1.len: {}", value1.len());
        let (common2, rest) = Common::from_slice(&value1).unwrap();
        println!("common2 = {:?}", common2);
        assert!(
            rest.is_empty(),
            "header decode left {} trailing bytes",
            rest.len()
        );
        assert_eq!(common.packet_id, common2.packet_id);
        assert_eq!(common.id, common2.id);
        assert_eq!(common.token, common2.token);
        assert_eq!(common.version, common2.version);
        assert_eq!(common.ttl, common2.ttl);
        assert_eq!(common.pc, common2.pc);
        assert_eq!(common.flags, common2.flags);
    }
}
pub fn encode_packet_encrypted<T: serde::Serialize>(
cmn: &Common,
pkt: &T,
header_pass: &[u8],
) -> Result<Vec<u8>> {
let hdr = cmn.encode();
let body = serde_json::to_vec(pkt)?;
let body = aes_encrypt(header_pass, &body)?;
let mut result = Vec::with_capacity(2 + hdr.len() + body.len());
let total_size = (hdr.len() + body.len()) as u16;
// insert total size
result.extend_from_slice(&total_size.to_be_bytes());
// packet_id
// result.extend_from_slice(&[0, 0]);
// insert header
result.extend_from_slice(&hdr);
// insert body
result.extend_from_slice(&body);
Ok(result)
}
/*
pub fn encode_packet_packet_encrypt(
cmn: &Common,
pkt: &Packet<'_>,
header_pass: &[u8],
) -> Result<Vec<u8>> {
let data2 = aes_encrypt(header_pass, pkt.data)?;
pkt.data = &data2;
encode_packet_packet(cmn, pkt)
}
*/
pub fn encode_packet_packet(cmn: &Common, pkt: &Packet<'_>) -> Result<Vec<u8>> {
let hdr = cmn.encode();
let body = pkt.marshal()?;
let mut result = Vec::with_capacity(2 + hdr.len() + body.len());
let total_size = (hdr.len() + body.len()) as u16;
// insert total size
result.extend_from_slice(&total_size.to_be_bytes());
// packet_id
// result.extend_from_slice(&[0, 0]);
// insert header
result.extend_from_slice(&hdr);
// insert body
result.extend_from_slice(&body);
Ok(result)
}
pub fn encode_packet<T: serde::Serialize>(cmn: &Common, pkt: &T) -> Result<Vec<u8>> {
// header
let hdr = cmn.encode();
// body
let body = serde_json::to_vec(pkt)?;
debug!(
"body({}): {}",
body.len(),
std::str::from_utf8(&body).unwrap()
);
let mut result = Vec::with_capacity(2 + hdr.len() + body.len());
let total_size = (hdr.len() + body.len()) as u16;
// insert total size
result.extend_from_slice(&total_size.to_be_bytes());
// packet_id
// result.extend_from_slice(&[0, 0]);
// insert header
result.extend_from_slice(&hdr);
// insert body
result.extend_from_slice(&body);
debug!("enocded packet size: {}", result.len());
// debug!("encoded packet: {:?}", result);
Ok(result)
}
// decode the common, and return the encrypted bytes;
pub fn decode_common<'a>(value: &'a [u8]) -> Result<(Common<'a>, &'a [u8])> {
if value.len() < 4 {
return Err(SDLanError::NormalError("decode pkt length error"));
}
let size_bytes = value[0..2].try_into().expect("never goes here");
let size = u16::from_be_bytes(size_bytes);
if value.len() < 2 + size as usize {
error!("decode pkt header error: {}: {}", value.len(), 2 + size);
return Err(SDLanError::NormalError("decode pkt header size error"));
}
let value2 = &value[2..2 + size as usize];
let Ok((cmn, value2)) = Common::from_slice(value2) else {
error!("failed to decode common");
return Err(SDLanError::NormalError("failed to deocde common"));
};
Ok((cmn, value2))
}
/// Decode a complete, unencrypted packet.
///
/// The packet is a framed common header followed by a JSON body, which is
/// deserialized into `T`. `T` may borrow from `value` for `'a`, as the
/// header's `id` does.
///
/// # Errors
/// Propagates failures from `decode_common` and from JSON deserialization.
pub fn decode_packet<'a, T: serde::Deserialize<'a>>(value: &'a [u8]) -> Result<(Common<'a>, T)> {
    let (header, json) = decode_common(value)?;
    serde_json::from_slice::<T>(json)
        .map(|payload| (header, payload))
        .map_err(Into::into)
}