diff --git a/.gitignore b/.gitignore index b17c6c1..b3ffe7a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ /.output /.vscode *.pdf +/.data diff --git a/Cargo.toml b/Cargo.toml index 21e50be..af3fcbb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,11 +6,15 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +aes = "0.8.4" anyhow = "1.0.79" byteorder = "1.5.0" +cbc = "0.1.2" dashmap = "5.5.3" lazy_static = "1.4.0" +rand = "0.8.5" rolling-file = { git = "https://git.asxalex.pw/rust/rolling-file" } +rsa = "0.9.6" serde = { version = "1.0.196", features = ["derive"] } serde_json = "1.0.113" serde_repr = "0.1.18" diff --git a/src/bin/sdlan-sn/main.rs b/src/bin/sdlan-sn/main.rs index c50432e..b42ec3a 100644 --- a/src/bin/sdlan-sn/main.rs +++ b/src/bin/sdlan-sn/main.rs @@ -1,13 +1,16 @@ use sdlan_sn_rs::log; use sdlan_sn_rs::packet; -use sdlan_sn_rs::utils; +use sdlan_sn_rs::utils::Result; use tracing::{debug, error}; //use std::io::Read; use std::net::SocketAddr; use std::time::Duration; use tokio::net::UdpSocket; -use utils::Result; + +mod utils; + +use utils::license_ok; const SERVER: &str = "127.0.0.1:7655"; @@ -35,6 +38,9 @@ async fn main() -> Result<()> { let _guard = log::init_log(); debug!("main starts here"); + // check license + license_ok()?; + let listener = UdpSocket::bind(SERVER).await?; tokio::spawn(async { diff --git a/src/bin/sdlan-sn/utils/mod.rs b/src/bin/sdlan-sn/utils/mod.rs new file mode 100644 index 0000000..a25b639 --- /dev/null +++ b/src/bin/sdlan-sn/utils/mod.rs @@ -0,0 +1,25 @@ +use sdlan_sn_rs::utils::{rsa_decrypt, Result}; +use serde::{Deserialize, Serialize}; +use tracing::error; + +#[derive(Serialize, Deserialize)] +struct LicenseInfo<'a> { + // 由谁颁发的 + #[serde(rename = "from")] + from: &'a str, + // 颁发给谁的 + #[serde(rename = "to")] + to: &'a str, + // 有效起始时间 + #[serde(rename = "starts")] + starts: i64, + // 有效结束时间 + #[serde(rename = "ends")] + ends: i64, +} + +pub fn license_ok() -> Result<()> { + // TODO: check license + error!("license expired"); + Ok(()) +} diff --git a/src/utils/error.rs b/src/utils/error.rs index 0107473..b97f613 100644 --- a/src/utils/error.rs +++ b/src/utils/error.rs @@ -8,6 +8,7 @@ pub enum SDLanError { NormalError(&'static str), ConvertError(String), SerializeError(String), + EncryptError(String), } impl From for SDLanError { diff --git a/src/utils/mod.rs b/src/utils/mod.rs index ba43934..899a22f 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,5 +1,23 @@ mod encode_decode; mod error; +mod myaes; +mod myrsa; pub use encode_decode::*; pub use error::*; +pub use myaes::{aes_decrypt, aes_encrypt}; +pub use myrsa::{gen_keys, load_private_key_file, load_public_key_file, rsa_decrypt, rsa_encrypt}; + +#[cfg(test)] +pub mod test_utils { + use rand::{thread_rng, Rng}; + + pub fn generate_info(n: i32) -> Vec { + let mut result = vec![]; + for i in 0..n { + let v: u8 = thread_rng().gen_range(0..255); + result.push(v); + } + result + } +} diff --git a/src/utils/myaes.rs b/src/utils/myaes.rs new file mode 100644 index 0000000..d2fb02c --- /dev/null +++ b/src/utils/myaes.rs @@ -0,0 +1,75 @@ +use super::{Result, SDLanError}; +use aes::cipher::{block_padding::Pkcs7, BlockDecryptMut, BlockEncryptMut, KeyIvInit}; + +type Aes256CbcEnc = cbc::Encryptor; +type Aes256CbcDec = cbc::Decryptor; + +pub fn aes_encrypt(key: &[u8], plain: &[u8]) -> Result> { + let mut buf = Vec::new(); + buf.resize(plain.len() + 16, 0); + + let iv = &key[..16]; + match Aes256CbcEnc::new(key.into(), iv.into()).encrypt_padded_b2b_mut::(plain, &mut buf) + { + Err(e) => Err(SDLanError::EncryptError(format!("aes encrypt: {}", e))), + Ok(v) => Ok(Vec::from(v)), + } +} + +pub fn aes_decrypt(key: &[u8], cipherd: &[u8]) -> Result> { + let mut buf = Vec::new(); + let iv = &key[..16]; + buf.resize(cipherd.len() + 16, 0); + match Aes256CbcDec::new(key.into(), iv.into()) + .decrypt_padded_b2b_mut::(cipherd, &mut buf) + { + Err(e) => Err(SDLanError::EncryptError(format!("aes decrypt: {}", e))), + Ok(v) => Ok(Vec::from(v)), + } +} + +#[cfg(test)] +mod test { + use super::*; + use rand::{thread_rng, Rng}; + + fn gen_aes_key() -> [u8; 32] { + let mut res = [0; 32]; + for i in 0..32 { + let mut temp = thread_rng().gen_range(0..63) as u8; + if temp <= 10 { + temp += 0x30; + } else if temp <= 36 { + temp += 0x61; + } else { + temp += 0x41; + } + res[i as usize] = temp; + } + + return res; + } + + #[test] + fn test_aes() -> Result<()> { + use crate::utils::test_utils::generate_info; + let key = gen_aes_key(); + println!("aes key: {:?}", key); + let zero_msg = vec![]; + let short_msg = generate_info(10); + let middle_msg = generate_info(1024); + let big_msg = generate_info(65535); + + let msgs = [zero_msg, short_msg, middle_msg, big_msg]; + + for msg in msgs.iter() { + println!("test aes decrypt with {} bytes", msg.len()); + let encrypted = aes_encrypt(&key, msg.as_slice())?; + let decrypted = aes_decrypt(&key, encrypted.as_slice())?; + assert_eq!(&decrypted, msg); + println!("test aes decrypt {} ok", msg.len()); + } + + Ok(()) + } +} diff --git a/src/utils/myrsa.rs b/src/utils/myrsa.rs new file mode 100644 index 0000000..d351016 --- /dev/null +++ b/src/utils/myrsa.rs @@ -0,0 +1,135 @@ +use super::{Result, SDLanError}; +// use config::{convert_error, ErrorKind, SDLanError}; +use rsa::pkcs1::DecodeRsaPrivateKey; +use rsa::pkcs1::EncodeRsaPrivateKey; +use rsa::pkcs8::DecodePublicKey; +use rsa::pkcs8::EncodePublicKey; +use rsa::{Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey}; + +use std::io::Read; +use std::sync::Arc; + +pub fn gen_keys() { + let mut rng = rand::thread_rng(); + let bits = 2048; + let priv_key = RsaPrivateKey::new(&mut rng, bits).unwrap(); + let public_key = RsaPublicKey::from(&priv_key); + + std::fs::create_dir_all(".data").expect("failed to create .data"); + + priv_key + .write_pkcs1_pem_file(".data/id_rsa", rsa::pkcs8::LineEnding::LF) + .unwrap(); + public_key + .write_public_key_pem_file(".data/id_rsa.pub", rsa::pkcs8::LineEnding::LF) + .unwrap(); +} + +pub fn load_public_key(pubkey: &str) -> Result { + match RsaPublicKey::from_public_key_pem(pubkey) { + Err(e) => Err(SDLanError::EncryptError(format!("load pub key: {}", e))), + Ok(v) => Ok(v), + } +} + +pub fn load_public_key_file(pubkeyfile: &str) -> Result { + let mut fp = match std::fs::File::open(pubkeyfile) { + Ok(fp) => fp, + Err(e) => { + return Err(SDLanError::EncryptError(format!( + "open pub key file: {}", + e + ))); + } + }; + let mut res = String::new(); + if let Err(e) = fp.read_to_string(&mut res) { + return Err(SDLanError::EncryptError(format!( + "read pub key file: {}", + e + ))); + } + load_public_key(res.as_str()) +} + +pub fn load_private_key_file(privkeyfile: &str) -> Result { + let mut fp = match std::fs::File::open(privkeyfile) { + Ok(fp) => fp, + Err(e) => { + return Err(SDLanError::EncryptError(format!( + "open priv key file: {}", + e + ))); + } + }; + let mut res = String::new(); + if let Err(e) = fp.read_to_string(&mut res) { + return Err(SDLanError::EncryptError(format!( + "read priv key file: {}", + e + ))); + } + load_private_key(res.as_str()) +} + +pub fn load_private_key(privkey: &str) -> Result { + match RsaPrivateKey::from_pkcs1_pem(privkey) { + Err(e) => Err(SDLanError::EncryptError(format!("load priv key: {}", e))), + Ok(v) => Ok(v), + } +} + +pub fn rsa_encrypt(pubkey: Arc, data: &[u8]) -> Result> { + let mut rng = rand::thread_rng(); + match pubkey.encrypt(&mut rng, Pkcs1v15Encrypt, data) { + Err(e) => Err(SDLanError::EncryptError(format!("rsa encrypt: {}", e))), + Ok(v) => Ok(v), + } +} + +pub fn rsa_decrypt(privkey: Arc, cipherd: &[u8]) -> Result> { + match privkey.decrypt(Pkcs1v15Encrypt, cipherd) { + Err(e) => Err(SDLanError::EncryptError(format!("rsa decrypt: {}", e))), + Ok(v) => Ok(v), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::utils::test_utils::generate_info; + + fn generate_key() { + if std::fs::File::open(".data/id_rsa").is_ok() { + return; + } + gen_keys(); + } + + #[test] + fn test_rsa() -> Result<()> { + generate_key(); + let public = load_public_key_file(".data/id_rsa.pub")?; + let private = load_private_key_file(".data/id_rsa")?; + + let zero_msg: Vec = generate_info(0).to_vec(); + let normal_msg: Vec = generate_info(100).to_vec(); + let big_msg: Vec = generate_info(128).to_vec(); + let max_msg: Vec = generate_info(256 - 16).to_vec(); + + let msgs = vec![zero_msg, normal_msg, big_msg, max_msg]; + + let public = Arc::new(public); + let private = Arc::new(private); + + for msg in msgs.iter() { + println!("testing {} size length", msg.len()); + let encrypted = rsa_encrypt(public.clone(), msg)?; + let decrypted = rsa_decrypt(private.clone(), encrypted.as_slice())?; + assert_eq!(decrypted.as_slice(), msg); + println!("testing {} ok", msg.len()); + } + + Ok(()) + } +}