diff --git a/.gitignore b/.gitignore index b3ffe7a..6a3fdc2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ /target /Cargo.lock -/.output +.output /.vscode *.pdf /.data diff --git a/Cargo.toml b/Cargo.toml index 4d04194..1186bcd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,8 +20,8 @@ serde = { version = "1.0.196", features = ["derive"] } serde_json = "1.0.113" serde_repr = "0.1.18" sqlx = { version = "0.7.3", features = [ - "sqlx-sqlite", - "sqlite", + "sqlx-mysql", + "mysql", "runtime-tokio", ] } structopt = "0.3.26" diff --git a/src/bin/sdlan-sn/main.rs b/src/bin/sdlan-sn/main.rs index 0f9b38c..bdcc354 100644 --- a/src/bin/sdlan-sn/main.rs +++ b/src/bin/sdlan-sn/main.rs @@ -12,14 +12,16 @@ use tokio::net::UdpSocket; use structopt::StructOpt; mod config; +mod models; mod utils; +use models::create_network; +use models::init_db_pool; + use utils::license_ok; -const SERVER: &str = "127.0.0.1:7655"; - async fn client(address: &str) -> Result<()> { - let socket = UdpSocket::bind("127.0.0.1:0").await?; + let socket = UdpSocket::bind("0.0.0.0:0").await?; let id = "Dejavu"; let cmn = packet::Common { version: 1, @@ -52,14 +54,20 @@ async fn main() -> Result<()> { // check the argument let _: Ipv4Addr = args.address.parse().expect("invalid address found"); + // init database + init_db_pool().await?; + + create_network("测试network").await?; + let server = format!("{}:{}", args.address, args.port); + let cloned_server = server.clone(); let listener = UdpSocket::bind(server).await?; - let supernode = utils::SuperNode::new(listener); + let supernode = utils::SuperNode::new(listener, args.port); utils::init_supernode(supernode).expect("failed to init supernode"); - tokio::spawn(async { - client(SERVER).await.unwrap(); + tokio::spawn(async move { + client(&cloned_server).await.unwrap(); }); let listener = utils::get_supernode(); diff --git a/src/bin/sdlan-sn/migrations/20240219020143_create-network.sql b/src/bin/sdlan-sn/migrations/20240219020143_create-network.sql new file mode 100644 index 0000000..7694d43 --- /dev/null +++ b/src/bin/sdlan-sn/migrations/20240219020143_create-network.sql @@ -0,0 +1,12 @@ +-- Add migration script here +CREATE TABLE IF NOT EXISTS network ( + id bigint auto_increment primary key, + uuid CHAR(32) NOT NULL, + name VARCHAR(128) NOT NULL default "", + is_fedration TINYINT(1) NOT NULL default 0, + enabled TINYINT(1) NOT NULL default 1, + network_pass VARCHAR(64) NOT NULL, + header_pass VARCHAR(64) NOT NULL, + net_addr INT UNSIGNED NOT NULL, + net_bit_len TINYINT UNSIGNED NOT NULL +); diff --git a/src/bin/sdlan-sn/models/mod.rs b/src/bin/sdlan-sn/models/mod.rs new file mode 100644 index 0000000..77ff86c --- /dev/null +++ b/src/bin/sdlan-sn/models/mod.rs @@ -0,0 +1,52 @@ +use once_cell::sync::OnceCell; +use sdlan_sn_rs::utils::{gen_uuid, Result, SDLanError}; +use sqlx::mysql::{MySqlPool, MySqlPoolOptions}; + +use crate::utils::Network; + +const url: &'static str = "mysql://sdlan:sdlan-pass@localhost:3306/sdlan"; + +static DBPOOL: OnceCell = OnceCell::new(); + +pub async fn init_db_pool() -> Result<()> { + let pool = MySqlPoolOptions::new().connect(url).await?; + if let Err(_) = DBPOOL.set(pool) { + return Err(SDLanError::NormalError("init db pool failed")); + } + Ok(()) +} + +pub fn get_db_pool() -> &'static MySqlPool { + DBPOOL.get().expect("db pool has not been initialized") +} + +pub async fn get_network() -> Result> { + let pool = get_db_pool(); + let mut res = + sqlx::query_as::("select * from network where id = 2") + .fetch_all(pool) + .await?; + + for r in res.iter_mut() { + r.edges.insert("hello".to_string()); + println!("network is {:?}", r); + } + + Ok(res) +} + +pub async fn create_network(name: &str) -> Result { + let pool = get_db_pool(); + let uuid = gen_uuid(); + let networkpass = gen_uuid(); + let headerpass = gen_uuid(); + let res = sqlx::query(r#"insert into network (`uuid`, `name`, `network_pass`, `header_pass`, `net_addr`, `net_bit_len`) values (?,?,?,?,?,?);"#) + .bind(&uuid) + .bind(name) + .bind(networkpass) + .bind(headerpass) + .bind(0) + .bind(0) + .execute(pool).await?; + Ok(uuid) +} diff --git a/src/bin/sdlan-sn/utils/sn.rs b/src/bin/sdlan-sn/utils/sn.rs index 1fa061e..6835580 100644 --- a/src/bin/sdlan-sn/utils/sn.rs +++ b/src/bin/sdlan-sn/utils/sn.rs @@ -2,15 +2,15 @@ use dashmap::{DashMap, DashSet}; use sdlan_sn_rs::peer::IpSubnet; use sdlan_sn_rs::utils::gen_uuid; use sdlan_sn_rs::utils::MyDashMap; +use serde::Deserialize; +use serde::Serialize; +use sqlx::prelude::FromRow; use tokio::net::UdpSocket; use crate::config; use std::net::Ipv4Addr; -use std::{ - sync::atomic::{AtomicU32, AtomicU8}, - time::{SystemTime, UNIX_EPOCH}, -}; +use std::time::{SystemTime, UNIX_EPOCH}; use sdlan_sn_rs::utils::{Result, SDLanError}; @@ -85,14 +85,8 @@ impl SuperNode { fedration: Network::new("*fedration", gen_uuid(), true, 0, 0), pending: Network::new("*pending", gen_uuid(), false, 0, 0), ip_range: AutoIpAssign { - start_ip: IpSubnet { - net_addr: AtomicU32::new(startip), - net_bit_len: AtomicU8::new(config::DEFAULT_IP_NET_BIT_LEN), - }, - end_ip: IpSubnet { - net_addr: AtomicU32::new(endip), - net_bit_len: AtomicU8::new(config::DEFAULT_IP_NET_BIT_LEN), - }, + start_ip: IpSubnet::new(startip, config::DEFAULT_IP_NET_BIT_LEN), + end_ip: IpSubnet::new(endip, config::DEFAULT_IP_NET_BIT_LEN), }, sock, auth_key: String::from("encrypt!"), @@ -105,17 +99,23 @@ pub struct AutoIpAssign { pub end_ip: IpSubnet, } +#[derive(Serialize, Deserialize, FromRow, Debug)] pub struct Network { - pub id: String, + pub uuid: String, pub name: String, pub is_fedration: bool, pub enabled: bool, pub network_pass: String, pub header_pass: String, + #[sqlx(flatten)] pub auto_ip_net: IpSubnet, // 这个网络下面的节点,ip到peer uuid // pub edges: DashMap, + #[serde(skip)] + #[sqlx(skip)] pub edges: DashSet, + #[serde(skip)] + #[sqlx(skip)] pub ip_to_edge_id: DashMap, } @@ -123,15 +123,12 @@ impl Network { fn new(name: &str, id: String, is_fedration: bool, ipnet: u32, ipnet_bitlen: u8) -> Self { Self { name: name.to_string(), - id, + uuid: id, is_fedration, enabled: true, network_pass: gen_uuid(), header_pass: gen_uuid(), - auto_ip_net: IpSubnet { - net_addr: AtomicU32::new(ipnet), - net_bit_len: AtomicU8::new(ipnet_bitlen), - }, + auto_ip_net: IpSubnet::new(ipnet, ipnet_bitlen), edges: DashSet::new(), ip_to_edge_id: DashMap::new(), } diff --git a/src/packet/register_super.rs b/src/packet/register_super.rs index c1122f4..fc60ff1 100644 --- a/src/packet/register_super.rs +++ b/src/packet/register_super.rs @@ -29,8 +29,6 @@ mod test { use self::peer::SdlanSock; use crate::packet::Common; use crate::utils::Result; - use std::sync::atomic; - use std::sync::atomic::Ordering::Relaxed; use super::*; @@ -52,10 +50,7 @@ mod test { v4: [0; 4], v6: [1; 16], }), - dev_addr: peer::IpSubnet { - net_addr: atomic::AtomicU32::new(192), - net_bit_len: atomic::AtomicU8::new(24), - }, + dev_addr: peer::IpSubnet::new(192, 24), pub_key: "public key", token: "user's token", }; @@ -76,14 +71,8 @@ mod test { assert_eq!(pkt1.pass, pkt2.pass); assert_eq!(pkt1.sock, pkt2.sock); - assert_eq!( - pkt1.dev_addr.net_addr.load(Relaxed), - pkt2.dev_addr.net_addr.load(Relaxed) - ); - assert_eq!( - pkt1.dev_addr.net_bit_len.load(Relaxed), - pkt2.dev_addr.net_bit_len.load(Relaxed) - ); + assert_eq!(pkt1.dev_addr.net_addr(), pkt2.dev_addr.net_addr(),); + assert_eq!(pkt1.dev_addr.net_bit_len(), pkt2.dev_addr.net_bit_len()); assert_eq!(pkt1.pub_key, pkt2.pub_key); assert_eq!(pkt1.token, pkt2.token); } diff --git a/src/peer.rs b/src/peer.rs index 84475a8..eda6b3c 100644 --- a/src/peer.rs +++ b/src/peer.rs @@ -1,9 +1,11 @@ #![allow(unused)] use std::default::Default; -use std::sync::atomic::{AtomicI64, AtomicU32, AtomicU8}; +use std::os::unix::net; +use std::sync::atomic::{AtomicI64, AtomicU32, AtomicU8, Ordering}; use std::sync::Mutex; use serde::{Deserialize, Serialize}; +use sqlx::prelude::FromRow; #[derive(Debug)] pub struct Peer { @@ -30,10 +32,7 @@ impl Peer { pub fn new(id: &str) -> Self { Self { id: id.to_string(), - dev_addr: IpSubnet { - net_addr: AtomicU32::new(0), - net_bit_len: AtomicU8::new(0), - }, + dev_addr: IpSubnet::new(0, 0), sock: Mutex::new(SdlanSock { family: 0, port: 0, @@ -53,12 +52,51 @@ impl Peer { } /// IpSubnet, 对端ipv4信息 -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, FromRow)] pub struct IpSubnet { + net_addr: u32, + net_bit_len: u8, +} + +impl IpSubnet { + pub fn new(net_addr: u32, net_bit_len: u8) -> IpSubnet { + Self { + net_addr, + net_bit_len, + } + } + + pub fn net_addr(&self) -> u32 { + self.net_addr + } + + pub fn net_bit_len(&self) -> u8 { + self.net_bit_len + } +} +pub struct IpSubnetAtomic { pub net_addr: AtomicU32, pub net_bit_len: AtomicU8, } +impl From for IpSubnetAtomic { + fn from(value: IpSubnet) -> Self { + Self { + net_addr: AtomicU32::new(value.net_addr()), + net_bit_len: AtomicU8::new(value.net_bit_len()), + } + } +} + +impl IpSubnetAtomic { + pub fn new(net_addr: u32, net_bit_len: u8) -> Self { + Self { + net_addr: AtomicU32::new(net_addr), + net_bit_len: AtomicU8::new(net_bit_len), + } + } +} + /// SdlanSock: 对端对外的ip信息,包括ipv4和ipv6 #[derive(Debug, Serialize, Deserialize, PartialEq)] pub struct SdlanSock { diff --git a/src/utils/error.rs b/src/utils/error.rs index b97f613..d2c051f 100644 --- a/src/utils/error.rs +++ b/src/utils/error.rs @@ -1,3 +1,4 @@ +use sqlx::Error; use std::convert::From; pub type Result = std::result::Result; @@ -9,6 +10,7 @@ pub enum SDLanError { ConvertError(String), SerializeError(String), EncryptError(String), + DBError(String), } impl From for SDLanError { @@ -28,3 +30,9 @@ impl From for SDLanError { Self::SerializeError(value.to_string()) } } + +impl From for SDLanError { + fn from(value: sqlx::Error) -> Self { + Self::DBError(value.to_string()) + } +}