sdlan-rs/src/packet/common.rs

203 lines
5.3 KiB
Rust

use crate::utils::Result;
use byteorder::{BigEndian, ByteOrder};
use serde_repr::*;
use crate::utils::*;
/// Common header that `encode_packet` / `encode_packet_encrypted` place in front of
/// every packet body.
///
/// `id` is borrowed: a header produced by `Common::from_slice` points into the
/// buffer it was decoded from, so it cannot outlive that buffer.
#[derive(Debug)]
pub struct Common<'a> {
/// Version byte (`Common::new` sets 1).
pub version: u8,
/// Identifier string; occupies a fixed 32-byte field on the wire, and trailing NUL
/// bytes are stripped when decoding.
pub id: &'a str,
/// TTL byte (`Common::new` sets 2).
pub ttl: u8,
// packet type
pub pc: PacketType,
/// 16-bit flags field; `from_slice` decodes it as big-endian.
pub flags: u16,
}
impl<'a> Common<'a> {
    /// Serializes the header into its wire form.
    ///
    /// Layout, in order:
    /// `version` (1 byte) · `id` written into a 32-byte field by
    /// `encode_buf_without_size` · `ttl` (1 byte) · `flags` (2 bytes, via `encode_u16`) ·
    /// packet type (1 byte, `PacketType::to_u8`).
    ///
    /// Note that `flags` precedes the packet type on the wire even though the struct
    /// declares them in the opposite order.
    /// NOTE(review): what happens to an `id` longer than 32 bytes depends on
    /// `encode_buf_without_size` — confirm it truncates or rejects rather than overflows.
    pub fn encode(&self) -> Vec<u8> {
        // version + 32-byte id + ttl + flags(2) + packet type
        let mut result = Vec::with_capacity(5 + 32);
        encode_u8(&mut result, self.version);
        encode_buf_without_size(&mut result, self.id.as_bytes(), 32);
        encode_u8(&mut result, self.ttl);
        encode_u16(&mut result, self.flags);
        encode_u8(&mut result, self.pc.to_u8());
        result
    }

    /// Parses a header from the front of `value`.
    ///
    /// Returns the decoded header (its `id` borrows from `value`) and the bytes that
    /// follow it. Unknown packet-type bytes decode as `PacketType::PKTInvalid`.
    ///
    /// # Errors
    /// Fails when `value` is shorter than the 37-byte header, or when the id field is
    /// not valid UTF-8 (`SDLanError::ConvertError`).
    pub fn from_slice(value: &'a [u8]) -> Result<(Common<'a>, &'a [u8])> {
        // Width of the zero-padded id field; must match the size used in `encode`.
        let id_len = 32;
        let header_len = 5 + id_len;
        if value.len() < header_len {
            return Err("common header length error".into());
        }
        let version = value[0];
        let id = std::str::from_utf8(&value[1..1 + id_len])
            .map_err(|e| SDLanError::ConvertError(e.to_string()))?
            // The id is padded with NULs up to the fixed field width.
            .trim_end_matches('\0');
        let ttl = value[1 + id_len];
        let flags = BigEndian::read_u16(&value[2 + id_len..4 + id_len]);
        let pc = PacketType::from(value[4 + id_len]);
        let common = Self {
            version,
            id,
            ttl,
            pc,
            flags,
        };
        Ok((common, &value[header_len..]))
    }

    /// Creates a header for `id` with version 1, ttl 2, zero flags and an
    /// unset (`PKTInvalid`) packet type that callers are expected to fill in.
    pub fn new(id: &'a str) -> Self {
        Common {
            id,
            version: 1,
            ttl: 2,
            pc: PacketType::PKTInvalid,
            flags: 0,
        }
    }
}
/// Packet type, carried as the last byte of the encoded [`Common`] header.
///
/// `#[repr(u8)]` with implicit discriminants numbers the variants 0..=3 in declaration
/// order; `serde_repr` (de)serializes the enum as that integer, and `to_u8` /
/// `From<u8>` use the same numbering.
#[derive(Debug, PartialEq, Serialize_repr, Deserialize_repr, Copy, Clone)]
#[repr(u8)]
pub enum PacketType {
/// Unset or unrecognised packet type (`From<u8>` maps unknown codes here).
PKTInvalid,
/// Register with the SN (super node).
PKTRegisterSuper,
/// Data forwarding.
PKTPacket,
/// Hole-punching message.
PKTRegister,
}
impl std::convert::From<u8> for PacketType {
fn from(value: u8) -> Self {
match value {
// 0 => Self::PacketInvalid,
1 => Self::PKTRegisterSuper,
2 => Self::PKTPacket,
3 => Self::PKTRegister,
_ => Self::PKTInvalid,
}
}
}
impl PacketType {
    /// Returns the wire code for this packet type (0..=3).
    ///
    /// The enum is `#[repr(u8)]` with implicit, declaration-ordered discriminants, so
    /// the discriminant is the wire code. That is also the integer `serde_repr` emits.
    /// This is the inverse of `From<u8>` for the codes 0..=3.
    pub fn to_u8(&self) -> u8 {
        *self as u8
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Encodes a header with the given packet type and flags, decodes it again, and
    /// checks that every field survives and that no bytes are left over.
    fn assert_round_trip(pc: PacketType, flags: u16) {
        let common = Common {
            version: 0,
            id: "hello",
            ttl: 2,
            pc,
            flags,
        };
        let encoded = common.encode();
        let (decoded, rest) = Common::from_slice(&encoded).unwrap();
        assert_eq!(common.id, decoded.id);
        assert_eq!(common.version, decoded.version);
        assert_eq!(common.ttl, decoded.ttl);
        assert_eq!(common.pc, decoded.pc);
        assert_eq!(common.flags, decoded.flags);
        // `encode` emits exactly one header, so decoding must consume all of it.
        assert!(rest.is_empty());
    }

    #[test]
    fn test_common_encode_and_decode() {
        assert_round_trip(PacketType::PKTRegisterSuper, 0);
    }

    #[test]
    fn test_common_round_trips_every_packet_type_and_nonzero_flags() {
        // A flags value with distinct bytes catches encode/decode byte-order mismatches.
        for code in 0..=3u8 {
            assert_round_trip(PacketType::from(code), 0x1234);
        }
    }

    #[test]
    fn test_common_rejects_truncated_header() {
        let encoded = Common::new("hello").encode();
        assert!(Common::from_slice(&encoded[..encoded.len() - 1]).is_err());
        assert!(Common::from_slice(&[]).is_err());
    }

    #[test]
    fn test_packet_type_u8_round_trip() {
        for code in 0..=3u8 {
            assert_eq!(PacketType::from(code).to_u8(), code);
        }
        assert_eq!(PacketType::from(4u8), PacketType::PKTInvalid);
    }
}
pub fn encode_packet_encrypted<T: serde::Serialize>(
cmn: &Common,
pkt: &T,
header_pass: &[u8],
) -> Result<Vec<u8>> {
let hdr = cmn.encode();
let body = serde_json::to_vec(pkt)?;
let body = aes_encrypt(header_pass, &body)?;
let mut result = Vec::with_capacity(4 + hdr.len() + body.len());
let total_size = (2 + hdr.len() + body.len()) as u16;
// insert total size
result.extend_from_slice(&total_size.to_be_bytes());
// packet_id
result.extend_from_slice(&[0, 0]);
// insert header
result.extend_from_slice(&hdr);
// insert body
result.extend_from_slice(&body);
Ok(result)
}
pub fn encode_packet<T: serde::Serialize>(cmn: &Common, pkt: &T) -> Result<Vec<u8>> {
// header
let hdr = cmn.encode();
// body
let body = serde_json::to_vec(pkt)?;
let mut result = Vec::with_capacity(4 + hdr.len() + body.len());
let total_size = (2 + hdr.len() + body.len()) as u16;
// insert total size
result.extend_from_slice(&total_size.to_be_bytes());
// packet_id
result.extend_from_slice(&[0, 0]);
// insert header
result.extend_from_slice(&hdr);
// insert body
result.extend_from_slice(&body);
Ok(result)
}
/// Decodes the frame prefix and the [`Common`] header of a packet.
///
/// Expected layout (as produced by `encode_packet`): big-endian `u16` size · 2-byte
/// packet id · `size - 2` bytes of header + body. The size counts the packet id, header
/// and body, but not itself. Bytes beyond `2 + size` are ignored.
///
/// Returns the header and the body bytes that follow it. The body may still be
/// encrypted; this function does not decrypt.
///
/// # Errors
/// Fails when `value` is shorter than the 4-byte prefix, when the declared size is
/// inconsistent with the buffer, or when the header itself cannot be parsed.
pub fn decode_common<'a>(value: &'a [u8]) -> Result<(Common<'a>, &'a [u8])> {
    if value.len() < 4 {
        return Err(SDLanError::NormalError("decode pkt length error"));
    }
    let size = usize::from(u16::from_be_bytes([value[0], value[1]]));
    // The size must at least cover the 2-byte packet id. A smaller value would make the
    // slice below start after its end and panic on malformed input.
    if size < 2 || value.len() < 2 + size {
        return Err(SDLanError::NormalError("decode pkt header size error"));
    }
    Common::from_slice(&value[4..2 + size])
}
/// Decodes a whole un-encrypted packet.
///
/// The frame prefix and [`Common`] header are parsed by [`decode_common`]. Whatever
/// follows the header is deserialized from JSON into `T`.
///
/// # Errors
/// Propagates any error from [`decode_common`] and any JSON deserialization error.
pub fn decode_packet<'a, T: serde::Deserialize<'a>>(value: &'a [u8]) -> Result<(Common<'a>, T)> {
    let (header, body) = decode_common(value)?;
    let payload: T = serde_json::from_slice(body)?;
    Ok((header, payload))
}