diff --git a/Tun/Punchnet/Actors/SDLContextActor.swift b/Tun/Punchnet/Actors/SDLContextActor.swift index 66525db..bc6439a 100644 --- a/Tun/Punchnet/Actors/SDLContextActor.swift +++ b/Tun/Punchnet/Actors/SDLContextActor.swift @@ -78,6 +78,9 @@ actor SDLContextActor { private var updatePolicyTask: Task? private let snapshotPublisher: SnapshotPublisher + // Flow流会话管理, 过期时间为: 180秒 + private let flowSessionManager = SDLFlowSessionManager(sessionTimeout: 180) + // 注册任务 private var registerTask: Task? @@ -524,7 +527,6 @@ actor SDLContextActor { } let mac = LayerPacket.MacAddress(data: data.dstMac) - let networkAddr = config.networkAddress guard (data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast()) else { return @@ -565,30 +567,14 @@ actor SDLContextActor { // 检查权限逻辑 let identitySnapshot = self.snapshotPublisher.current() - if let ruleMap = identitySnapshot.lookup(data.identityID) { - let proto = ipPacket.header.proto - switch ipPacket.transportPacket() { - case .tcp(let tcpPacket): - let dstPort = tcpPacket.header.dstPort - if ruleMap.isAllow(proto: proto, port: dstPort) { - let packet = NEPacket(data: ipPacket.data, protocolFamily: 2) - self.provider.packetFlow.writePacketObjects([packet]) - SDLLogger.shared.log("[SDLContext] identity: \(data.identityID), ruleMap: \(ruleMap), dstPort: \(dstPort) allow", level: .debug) - } - case .udp(let udpPacket): - let dstPort = udpPacket.dstPort - if ruleMap.isAllow(proto: proto, port: dstPort) { - let packet = NEPacket(data: ipPacket.data, protocolFamily: 2) - self.provider.packetFlow.writePacketObjects([packet]) - SDLLogger.shared.log("[SDLContext] identity: \(data.identityID), ruleMap: \(ruleMap), dstPort: \(dstPort) allow", level: .debug) - } - case .icmp(_): - let packet = NEPacket(data: ipPacket.data, protocolFamily: 2) - self.provider.packetFlow.writePacketObjects([packet]) - default: - () - } - } else { + let ruleMap = identitySnapshot.lookup(data.identityID) + + if self.authIPPacket(ipPacket: ipPacket, ruleMap: ruleMap) { + let packet = NEPacket(data: ipPacket.data, protocolFamily: 2) + self.provider.packetFlow.writePacketObjects([packet]) + SDLLogger.shared.log("[SDLContext] identity: \(data.identityID), allow", level: .debug) + } + else { SDLLogger.shared.log("[SDLContext] not found identity: \(data.identityID) ruleMap", level: .debug) // 向服务器请求权限逻辑 await self.identifyStore.policyRequest(srcIdentityId: data.identityID, dstIdentityId: self.config.identityId, using: self.quicClient) @@ -597,6 +583,35 @@ actor SDLContextActor { SDLLogger.shared.log("[SDLContext] get invalid packet", level: .debug) } } + + private func authIPPacket(ipPacket: IPPacket, ruleMap: IdentityRuleMap?) -> Bool { + // 进来的数据反转一下,然后再处理 + if let reverseFlowSession = ipPacket.flowSession()?.reverse(), + self.flowSessionManager.hasSession(reverseFlowSession) { + self.flowSessionManager.updateSession(reverseFlowSession) + return true + } + + // 检查权限逻辑 + let proto = ipPacket.header.proto + // 优先判断访问规则 + switch ipPacket.transportPacket { + case .tcp(let tcpPacket): + if let ruleMap, ruleMap.isAllow(proto: proto, port: tcpPacket.header.dstPort) { + return true + } + case .udp(let udpPacket): + if let ruleMap, ruleMap.isAllow(proto: proto, port: udpPacket.dstPort) { + return true + } + case .icmp(_): + return true + default: + return false + } + + return false + } // 流量统计 // public func flowReportTask() { @@ -651,6 +666,12 @@ actor SDLContextActor { return } + // 外部出去的数据,需要建立FlowSession + // 外部数据进来的时候需要查找 + if let flowSession = packet.flowSession() { + self.flowSessionManager.updateSession(flowSession) + } + // 查找arp缓存中是否有目标mac地址 if let dstMac = await self.arpServer.query(ip: dstIp) { await self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data) diff --git a/Tun/Punchnet/NetworkStack/IPPacket.swift b/Tun/Punchnet/NetworkStack/IPPacket.swift index 78d8e7c..20cc63d 100644 --- a/Tun/Punchnet/NetworkStack/IPPacket.swift +++ b/Tun/Punchnet/NetworkStack/IPPacket.swift @@ -44,6 +44,40 @@ struct IPPacket { let header: IPHeader let data: Data + enum TransportPacket { + case tcp(TCPPacket) + case udp(UDPPacket) + case icmp(ICMPPacket) + case unsupported(UInt8) + case malformed + } + + var transportPacket: TransportPacket { + guard let proto = TransportProtocol(rawValue: header.proto) else { + return .unsupported(header.proto) + } + + switch proto { + case .tcp: + guard let tcp = TCPPacket(payload) else { + return .malformed + } + return .tcp(tcp) + + case .udp: + guard let udp = UDPPacket(payload) else { + return .malformed + } + return .udp(udp) + + case .icmp: + guard let icmp = ICMPPacket(payload) else { + return .malformed + } + return .icmp(icmp) + } + } + var payload: Data.SubSequence { let offset = Int(header.headerLength) @@ -161,7 +195,7 @@ struct TCPPacket { } self.header = header - self.payload = data.subdata(in: headerLen.. TransportPacket { - guard let proto = TransportProtocol(rawValue: header.proto) else { - return .unsupported(header.proto) - } - - switch proto { - case .tcp: - guard let tcp = TCPPacket(payload) else { - return .malformed - } - return .tcp(tcp) - - case .udp: - guard let udp = UDPPacket(payload) else { - return .malformed - } - return .udp(udp) - - case .icmp: - guard let icmp = ICMPPacket(payload) else { - return .malformed - } - return .icmp(icmp) - } - } - -} diff --git a/Tun/Punchnet/SDLFlowSessionManager.swift b/Tun/Punchnet/SDLFlowSessionManager.swift new file mode 100644 index 0000000..9fbca75 --- /dev/null +++ b/Tun/Punchnet/SDLFlowSessionManager.swift @@ -0,0 +1,126 @@ +// +// FiveTuple.swift +// punchnet +// tcp/udp Flow流管理 +// Created by 安礼成 on 2026/3/10. +// +import Foundation + +// MARK: - 五元组 key +struct FlowSession: Hashable { + let srcIP: UInt32 + let dstIP: UInt32 + let srcPort: UInt16 + let dstPort: UInt16 + let proto: UInt8 + + func hash(into hasher: inout Hasher) { + // 高效组合 hash + hasher.combine(srcIP) + hasher.combine(dstIP) + hasher.combine(UInt32(srcPort) << 16 | UInt32(dstPort)) + hasher.combine(proto) + } + + static func ==(lhs: Self, rhs: Self) -> Bool { + return lhs.srcIP == rhs.srcIP && + lhs.dstIP == rhs.dstIP && + lhs.srcPort == rhs.srcPort && + lhs.dstPort == rhs.dstPort && + lhs.proto == rhs.proto + } + + func reverse() -> FlowSession { + return FlowSession( + srcIP: dstIP, + dstIP: srcIP, + srcPort: dstPort, + dstPort: srcPort, + proto: proto + ) + } + +} + +// MARK: - 会话管理器 +final class SDLFlowSessionManager { + private var sessions: [FlowSession: TimeInterval] = [:] + private let lock = NSLock() + private let sessionTimeout: TimeInterval + + /// - Parameter sessionTimeout: 会话闲置多久(秒)被清理 + init(sessionTimeout: TimeInterval = 300) { + self.sessionTimeout = sessionTimeout + } + + // 插入或更新会话 + func updateSession(_ key: FlowSession) { + lock.lock() + defer { + lock.unlock() + } + sessions[key] = Date().timeIntervalSince1970 + sessionTimeout + } + + // 查找会话 + func hasSession(_ key: FlowSession) -> Bool { + lock.lock() + defer { + lock.unlock() + } + + if let expireTs = sessions[key] { + if expireTs >= Date().timeIntervalSince1970 { + return true + } + self.sessions.removeValue(forKey: key) + } + + return false + } + + // 删除会话 + func removeSession(_ key: FlowSession) { + lock.lock() + defer { + lock.unlock() + } + + sessions.removeValue(forKey: key) + } + + // 清理过期会话 + func cleanupExpiredSessions() { + lock.lock() + defer { + lock.unlock() + } + + let now = Date().timeIntervalSince1970 + self.sessions = self.sessions.filter { $0.value >= now } + } + + // 返回当前会话数(调试/统计用) + var count: Int { + lock.lock() + defer { + lock.unlock() + } + return sessions.count + } + +} + +extension IPPacket { + + func flowSession() -> FlowSession? { + switch self.transportPacket { + case .tcp(let tcpPacket): + return FlowSession(srcIP: header.source, dstIP: header.destination, srcPort: tcpPacket.header.srcPort, dstPort: tcpPacket.header.dstPort, proto: header.proto) + case .udp(let udpPacket): + return FlowSession(srcIP: header.source, dstIP: header.destination, srcPort: udpPacket.srcPort, dstPort: udpPacket.dstPort, proto: header.proto) + default: + return nil + } + } +}