added policy request and response, need test tomorrow

This commit is contained in:
alex 2026-03-05 22:53:42 +08:00
parent e8e8655100
commit 31845c6707
13 changed files with 272 additions and 51 deletions

1
.gitignore vendored
View File

@ -12,3 +12,4 @@ sdlan.exe
*.tar.gz
*.tgz
/punchnet
/ca

View File

@ -1,4 +1,4 @@
{
"rust-analyzer.cargo.target": "x86_64-pc-windows-gnu",
"rust-analyzer.cargo.features": ["tun"]
// "rust-analyzer.cargo.target": "x86_64-pc-windows-gnu",
// "rust-analyzer.cargo.features": ["tun"]
}

View File

@ -170,6 +170,8 @@ async fn main() {
}
}
println!("identity_id = {}", connect_info.identity_id);
let self_host_name = connect_info.hostname;
let (tx, rx) = std::sync::mpsc::channel();

View File

@ -7,7 +7,7 @@ use crate::network::ipv6::run_ipv6;
use crate::network::{
get_edge, ping_to_sn, read_and_parse_packet, TunTapPacketHandler,
};
use crate::tcp::{init_quic_conn, send_stun_request};
use crate::tcp::{init_identity_cache, init_quic_conn, send_stun_request};
use crate::utils::{send_to_sock, CommandLine};
use crate::{ConnectionInfo};
use sdlan_sn_rs::peer::{SdlanSock};
@ -32,6 +32,8 @@ pub async fn async_main(
// let _ = PidRecorder::new(".pid");
let edge = get_edge();
init_identity_cache();
// let token = args.token.clone();
let cancel_tcp = cancel.clone();
let (ipv6_network_restarter, rx) = channel(10);

View File

@ -1,6 +1,7 @@
use std::{net::SocketAddr, sync::atomic::Ordering, time::Duration};
use crate::tcp::{NatType, get_quic_write_conn};
use crate::pb::SdlPolicyRequest;
use crate::tcp::{NatType, get_quic_write_conn, is_identity_ok};
use crate::{network::TunTapPacketHandler, utils::mac_to_string};
use crate::{
@ -12,7 +13,7 @@ use crate::{
tcp::{PacketType},
utils::{send_to_sock, Socket},
};
use etherparse::Ethernet2Header;
use etherparse::{Ethernet2Header, PacketHeaders, ip_number};
use prost::Message;
use sdlan_sn_rs::utils::{BROADCAST_MAC};
use sdlan_sn_rs::{
@ -24,7 +25,7 @@ use sdlan_sn_rs::{
},
};
use tracing::{debug, error, info};
use tracing::{debug, error, info, warn};
use super::{EdgePeer, Node};
@ -111,6 +112,7 @@ pub async fn handle_packet(eee: &'static Node, addr: SocketAddr, buf: &[u8]) ->
error!("failed to convert src mac");
return Err(SDLanError::NormalError("failed to convert vec to Mac"));
};
// let from_sock = get_sdlan_sock_from_socketaddr(addr).unwrap();
if data.is_p2p {
debug!("[P2P] Rx data from {}", from_sock.to_string());
@ -818,6 +820,47 @@ pub fn print_hex(key: &[u8]) {
println!("[{}]", value.join(" "))
}
async fn check_identity_is_ok(eee: &Node, identity: u32, protocol: u8, port: u16) -> bool{
match is_identity_ok(identity, protocol, port) {
Some(true) => {
// identity is ok
true
}
Some(false) => {
// identity is not allowed
warn!("identity is not allowed for protocol={:?}, port={}", protocol, port);
false
}
None => {
let policy_request = SdlPolicyRequest {
pkt_id: eee.get_next_packet_id(),
src_identity_id: identity,
dst_identity_id: eee.identity_id.load(),
version: 1,
};
println!("policy request: {:?}", policy_request);
// debug!("send register super: {:?}", register_super);
// let packet_id = edge.get_next_packet_id();
let data = encode_to_tcp_message(
Some(policy_request),
0,
PacketType::PolicyRequest as u8,
)
.unwrap();
let stream = get_quic_write_conn();
if let Err(e) = stream.send(data).await {
error!("failed to write to quic: {}", e.as_str());
}
false
// no such identity, should request for it
}
}
}
async fn handle_tun_packet(
eee: &Node,
_from_sock: &SdlanSock,
@ -838,7 +881,47 @@ async fn handle_tun_packet(
error!("failed to decrypt original data");
return;
}
let data = origin.unwrap();
let Ok(headers) = PacketHeaders::from_ethernet_slice(&data) else {
error!("failed to parse packet");
return;
};
if let Some(ip) = headers.net {
match ip {
etherparse::NetHeaders::Ipv4(ipv4, _) => {
let protocol = ipv4.protocol;
match protocol {
ip_number::TCP => {
let tcp_header = headers.transport.unwrap().tcp().unwrap();
let port = tcp_header.destination_port;
if !check_identity_is_ok(eee, pkt.identity_id, protocol.0, port).await {
return;
}
}
ip_number::UDP => {
let udp_header = headers.transport.unwrap().udp().unwrap();
let port = udp_header.destination_port;
if !check_identity_is_ok(eee, pkt.identity_id, protocol.0, port).await {
return;
}
}
_other => {
// just ok
}
}
}
_other => {
// just ignore, ok
}
}
}
if let Err(e) = eee
.device
.handle_packet_from_net(&data, key.as_slice())

View File

@ -154,9 +154,12 @@ impl Iface {
}
}
// TODO: set dns should be opened
/*
if let Err(e) = set_dns(self, &self.name, network_domain, &ip_to_string(&default_gw)) {
error!("failed to set dns: {}", e.as_str());
}
*/
} else {
info!("set tun device");
let res = Command::new("ifconfig")
@ -219,13 +222,23 @@ impl TunTapPacketHandler for Iface {
return Ok(());
};
if let Some(eth) = headers.link {
if let Some(hdr) = eth.ethernet2() {
use bytes::Bytes;
if let Some(ip) = headers.net {
match ip {
etherparse::NetHeaders::Ipv4(ipv4, _) => {
use etherparse::ip_number::{self, ICMP};
if let Some(transport) = headers.transport {
if let Some(tcp) = transport.tcp() {
// is tcp
}
}
if u32::from_be_bytes(ipv4.destination) == DNS_IP {
// should send to dns

View File

@ -7,7 +7,7 @@ use sdlan_sn_rs::utils::Result;
// tcp message has two-byte of size at header
pub fn encode_to_tcp_message<T: Message>(
msg: Option<T>,
packet_id: u32,
_packet_id: u32,
packet_type: u8,
) -> Result<Vec<u8>> {
let mut raw_data = Vec::new();
@ -16,10 +16,10 @@ pub fn encode_to_tcp_message<T: Message>(
msg.encode(&mut raw_data)?;
}
let mut result = Vec::with_capacity(raw_data.len() + 7);
let size = u16::to_be_bytes(raw_data.len() as u16 + 5);
let mut result = Vec::with_capacity(raw_data.len() + 3);
let size = u16::to_be_bytes(raw_data.len() as u16 + 1);
result.extend_from_slice(&size);
result.extend_from_slice(&u32::to_be_bytes(packet_id));
// result.extend_from_slice(&u32::to_be_bytes(packet_id));
result.push(packet_type);
result.extend_from_slice(&raw_data);
Ok(result)

View File

@ -4,6 +4,7 @@ use std::path::Path;
use std::sync::Arc;
use quinn::Endpoint;
use quinn::TransportConfig;
use quinn::crypto::rustls::QuicClientConfig;
use rustls::crypto::CryptoProvider;
use rustls::crypto::ring;

54
src/tcp/identity_cache.rs Normal file
View File

@ -0,0 +1,54 @@
use std::{collections::HashMap, sync::OnceLock};
use dashmap::DashMap;
use tracing::debug;
type IdentityID = u32;
type Port = u16;
type Proto = u8;
#[derive(Debug)]
pub struct RuleInfo {
pub proto: Proto,
pub port: Port,
}
static RULE_CACHE: OnceLock<DashMap<IdentityID, HashMap<Port, HashMap<Proto, bool>>>> = OnceLock::new();
pub fn init_identity_cache() {
RULE_CACHE.set(DashMap::new()).unwrap();
}
pub fn set_identity_cache(identity: IdentityID, infos: Vec<RuleInfo>) {
debug!("setting identity cache for identity={}, infos: {:?}", identity, infos);
let cache = RULE_CACHE.get().expect("should set first");
let mut temp = HashMap::new();
for info in &infos {
let mut protomap = HashMap::new();
protomap.insert(info.proto, true);
temp.insert(info.port, protomap);
}
cache.remove(&identity);
cache.insert(identity, temp);
}
pub fn is_identity_ok(identity: IdentityID, proto: Proto, port: Port) -> Option<bool> {
let cache = RULE_CACHE.get().expect("should set first");
match cache.get(&identity) {
Some(data) => {
if let Some(proto_info) = data.get(&port) {
if let Some(_has) = proto_info.get(&proto) {
return Some(true);
}
}
Some(false)
}
None => {
None
}
}
}

View File

@ -2,6 +2,11 @@ mod tcp_codec;
// mod tcp_conn;
mod quic;
mod identity_cache;
pub use tcp_codec::*;
pub use quic::*;
pub use identity_cache::*;
// pub use tcp_conn::*;

View File

@ -9,7 +9,7 @@ use tokio::{io::BufReader, net::TcpStream, sync::mpsc::{Receiver, Sender, channe
use tokio_util::sync::CancellationToken;
use tracing::{debug, error};
use crate::{ConnectionInfo, ConnectionState, config::{NULL_MAC, TCP_PING_TIME}, get_edge, network::{Node, RegisterSuperFeedback, StartStopInfo, check_peer_registration_needed, handle_packet_peer_info}, pb::{SdlRegisterSuper, SdlRegisterSuperAck, SdlRegisterSuperNak, SdlSendRegisterEvent, encode_to_tcp_message}, tcp::{EventType, NakMsgCode, NatType, PacketType, SdlanTcp, read_a_packet, send_stun_request}};
use crate::{ConnectionInfo, ConnectionState, config::{NULL_MAC, TCP_PING_TIME}, get_edge, network::{Node, RegisterSuperFeedback, StartStopInfo, check_peer_registration_needed, handle_packet_peer_info}, pb::{SdlPolicyRequest, SdlPolicyResponse, SdlRegisterSuper, SdlRegisterSuperAck, SdlRegisterSuperNak, SdlSendRegisterEvent, encode_to_tcp_message}, tcp::{EventType, NakMsgCode, NatType, PacketType, RuleInfo, SdlanTcp, read_a_packet, send_stun_request, set_identity_cache}};
static GLOBAL_QUIC_HANDLE: OnceLock<ReadWriterHandle> = OnceLock::new();
@ -53,10 +53,6 @@ impl ReadWriterHandle {
let connected: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
tokio::spawn(async move {
});
let actor = ReadWriteActor::new(
cancel,
addr,
@ -115,25 +111,28 @@ async fn handle_tcp_message(msg: SdlanTcp) {
debug!("got tcp message: {:?}", msg.packet_type);
match msg.packet_type {
PacketType::RegisterSuperACK => {
let Ok(ack) = SdlRegisterSuperAck::decode(&msg.current_packet[..]) else {
error!("failed to decode REGISTER_SUPER_ACK");
return;
};
edge.send_register_super_feedback(
msg._packet_id,
ack.pkt_id,
RegisterSuperFeedback {
result: 0,
message: "".to_owned(),
should_exit: false,
},
);
let Ok(ack) = SdlRegisterSuperAck::decode(&msg.current_packet[..]) else {
error!("failed to decode REGISTER_SUPER_ACK");
return;
};
debug!("got register super ack: {:?}", ack);
edge.session_token.set(ack.session_token);
let Ok(aes) = rsa_decrypt(&edge.rsa_private, &ack.aes_key) else {
error!("failed to rsa decrypt aes key");
return;
};
/*
let Some(dev) = ack.dev_addr else {
error!("no dev_addr is specified");
@ -180,9 +179,35 @@ async fn handle_tcp_message(msg: SdlanTcp) {
// println!("nat type is: {:?}", nattype);
});
}
PacketType::PolicyReply => {
let Ok(policy) = SdlPolicyResponse::decode(&msg.current_packet[..]) else {
error!("failed to decode POLICY RESPONSE");
return;
};
let identity = policy.src_identity_id;
let mut infos = Vec::new();
let mut start = 0;
while start < policy.rules.len() {
if start + 3 > policy.rules.len() {
break;
}
let proto = policy.rules[start];
let port = u16::from_be_bytes([policy.rules[start+1], policy.rules[start+2]]);
start += 3;
infos.push(RuleInfo{
proto,
port,
});
}
set_identity_cache(identity, infos);
}
PacketType::RegisterSuperNAK => {
let Ok(_nak) = SdlRegisterSuperNak::decode(&msg.current_packet[..]) else {
error!("failed to decode REGISTER_SUPER_NAK");
/*
edge.send_register_super_feedback(
msg._packet_id,
RegisterSuperFeedback {
@ -191,12 +216,17 @@ async fn handle_tcp_message(msg: SdlanTcp) {
should_exit: false,
},
);
*/
return;
};
error!("got nak: {:?}", _nak);
let pkt_id = _nak.pkt_id;
let Ok(error_code) = NakMsgCode::try_from(_nak.error_code as u8) else {
edge.send_register_super_feedback(
msg._packet_id,
//msg._packet_id,
pkt_id,
RegisterSuperFeedback {
result: 2,
message: "error_code not recognized".to_owned(),
@ -208,7 +238,8 @@ async fn handle_tcp_message(msg: SdlanTcp) {
match error_code {
NakMsgCode::InvalidToken => {
edge.send_register_super_feedback(
msg._packet_id,
// msg._packet_id,
pkt_id,
RegisterSuperFeedback {
result: 3,
message: "invalid token".to_owned(),
@ -219,7 +250,8 @@ async fn handle_tcp_message(msg: SdlanTcp) {
}
NakMsgCode::NodeDisabled => {
edge.send_register_super_feedback(
msg._packet_id,
// msg._packet_id,
pkt_id,
RegisterSuperFeedback {
result: 4,
message: "Node is disabled".to_owned(),
@ -230,7 +262,8 @@ async fn handle_tcp_message(msg: SdlanTcp) {
}
_other => {
edge.send_register_super_feedback(
msg._packet_id,
// msg._packet_id,
pkt_id,
RegisterSuperFeedback {
result: 0,
message: "".to_owned(),
@ -454,28 +487,45 @@ impl ReadWriteActor {
let state = ConnectionInfo::ConnState(ConnectionState::Connecting);
let _ = connecting_chan.send(state).await;
}
debug!("try connecting...");
let Ok(conn) = edge.quic_endpoint.connect(self.remote.parse().unwrap(), "") else {
self.connected.store(false, Ordering::Relaxed);
if keep_reconnect {
tokio::time::sleep(Duration::from_secs(3)).await;
continue;
let host = self.remote.split(":").next().unwrap();
debug!("try connecting to {}, host = {}", self.remote, host);
let conn = match edge.quic_endpoint.connect(self.remote.parse().unwrap(), host) {
Ok(conn) => conn,
Err(e) => {
error!("failed to connect: {}", e);
println!("failed to connect: {}", e);
self.connected.store(false, Ordering::Relaxed);
if keep_reconnect {
tokio::time::sleep(Duration::from_secs(3)).await;
continue;
}
return;
}
return;
};
let Ok(conn) = conn.await else {
self.connected.store(false, Ordering::Relaxed);
if keep_reconnect {
tokio::time::sleep(Duration::from_secs(3)).await;
continue;
let conn = match conn.await {
Err(e) => {
println!("failed to connect await: {}", e);
error!("failed to connect await: {}", e);
self.connected.store(false, Ordering::Relaxed);
if keep_reconnect {
tokio::time::sleep(Duration::from_secs(3)).await;
continue;
}
return;
}
return;
Ok(conn) => conn,
};
let local_ip = conn.local_ip();
let Ok((mut send, mut recv)) = conn.open_bi().await else {
println!("failed to open-bi");
self.connected.store(false, Ordering::Relaxed);
if keep_reconnect {
tokio::time::sleep(Duration::from_secs(3)).await;
@ -499,9 +549,9 @@ impl ReadWriteActor {
// let (reader, mut write) = stream.into_split();
let read_from_tcp = async move {
let mut buffed_reader = BufReader::new(recv);
// let mut buffed_reader = BufReader::new(recv);
loop {
match read_a_packet(&mut buffed_reader).await {
match read_a_packet(&mut recv).await {
Ok(packet) => {
debug!("got packet: {:?}", packet);
if let Err(_e) = self.from_tcp.send(packet).await {
@ -533,7 +583,7 @@ impl ReadWriteActor {
let check_pong = async {
loop {
tokio::time::sleep(Duration::from_secs(10)).await;
tokio::time::sleep(Duration::from_secs(3600)).await;
let connected = self.connected.load(Ordering::Relaxed);
let now = get_current_timestamp();
@ -573,6 +623,7 @@ impl ReadWriteActor {
_ = check_stop => {},
}
on_disconnected_callback().await;
conn.close(0u32.into(), "close".as_bytes());
debug!("connect retrying");
tokio::time::sleep(Duration::from_secs(1)).await;
debug!("disconnected");
@ -619,6 +670,8 @@ async fn on_connected_callback(local_ip: Option<IpAddr>, stream: &mut SendStream
pub_key: edge.rsa_pubkey.clone(),
hostname: edge.hostname.read().unwrap().clone(),
};
println!("register super: {:?}", register_super);
// debug!("send register super: {:?}", register_super);
let packet_id = match pkt_id {
Some(id) => id,

View File

@ -13,7 +13,7 @@ use crate::{network::Node, pb::{SdlStunRequest, Sdlv6Info, encode_to_udp_message
#[derive(Debug)]
pub struct SdlanTcp {
pub _packet_id: u32,
// pub _packet_id: u32,
pub packet_type: PacketType,
pub current_packet: Vec<u8>,
}
@ -94,6 +94,11 @@ pub enum PacketType {
StunProbe = 0x32,
StunProbeReply = 0x33,
Welcome = 0x4f,
PolicyRequest = 0xb0,
PolicyReply = 0xb1,
Data = 0xff,
}
@ -129,24 +134,22 @@ pub async fn send_stun_request(eee: &Node) {
}
pub async fn read_a_packet(
reader: &mut BufReader<RecvStream>,
reader: &mut RecvStream,
) -> Result<SdlanTcp, std::io::Error> {
debug!("read a packet");
let size = reader.read_u16().await?;
let payload_size = reader.read_u16().await?;
debug!("1");
let packet_id = reader.read_u32().await?;
debug!("2");
let packet_type = reader.read_u8().await?;
debug!("3");
if size < 5 {
if payload_size < 1 {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"size less than five",
));
}
let bufsize = (size - 5) as usize;
let bufsize = (payload_size - 1) as usize;
let mut binary = vec![0; bufsize];
let mut to_read = bufsize;
@ -155,6 +158,10 @@ pub async fn read_a_packet(
break;
}
let size_got = reader.read(&mut binary[(bufsize - to_read)..]).await?;
if size_got.is_none() {
break;
}
let size_got = size_got.unwrap();
if size_got == 0 {
return Err(std::io::Error::new(
@ -167,11 +174,11 @@ pub async fn read_a_packet(
let Ok(packet_type) = packet_type.try_into() else {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"packet type error",
format!("packet type error: 0x{:02x}", packet_type),
));
};
let result = SdlanTcp {
_packet_id: packet_id,
// _packet_id: packet_id,
packet_type,
current_packet: binary,
};

View File

@ -40,7 +40,7 @@ pub fn ip_string_to_u32(ip: &str) -> Result<u32> {
pub fn get_access_token() -> Option<CachedLoginInfo> {
let path = format!("{}/.access_token", get_base_dir());
if let Ok(content) = std::fs::read(&path) {
let data = serde_json::from_slice((&content)).unwrap();
let data = serde_json::from_slice(&content).unwrap();
return Some(data);
}
None