diff --git a/Tun/Punchnet/Actors/SDLContextActor.swift b/Tun/Punchnet/Actors/SDLContextActor.swift index d1a6ab1..f7564e1 100644 --- a/Tun/Punchnet/Actors/SDLContextActor.swift +++ b/Tun/Punchnet/Actors/SDLContextActor.swift @@ -832,94 +832,41 @@ extension SDLContextActor { } } - private func handleHoleData(data: SDLData) async throws { - guard let dataCipher = self.dataCipher else { - return - } - - let mac = LayerPacket.MacAddress(data: data.dstMac) - let networkAddr = config.networkAddress - guard (data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast()) else { - return - } - - let decyptedData = try dataCipher.decrypt(cipherText: Data(data.data)) - let layerPacket = try LayerPacket(layerData: decyptedData) - - self.flowTracer.inc(num: decyptedData.count, type: .inbound) - // 处理arp请求 - switch layerPacket.type { - case .arp: - // 判断如果收到的是arp请求 - if let arpPacket = ARPPacket(data: layerPacket.data) { - if arpPacket.targetIP == networkAddr.ip { - switch arpPacket.opcode { - case .request: - SDLLogger.log("[SDLContext] get arp request packet") - let response = ARPPacket.arpResponse(for: arpPacket, mac: networkAddr.mac, ip: networkAddr.ip) - await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal()) - case .response: - SDLLogger.log("[SDLContext] get arp response packet") - await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC) - } - } else { - SDLLogger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))") - } - } else { - SDLLogger.log("[SDLContext] get invalid arp packet") - } - case .ipv4: - // 有数据是通过出口网关转发的,所有只判断是合法的ip包 - guard let ipPacket = IPPacket(layerPacket.data) else { - return - } - - // 检查权限逻辑 - let identitySnapshot = self.snapshotPublisher.current() - let ruleMap = identitySnapshot.lookup(data.identityID) - - if true || self.checkPolicy(ipPacket: ipPacket, ruleMap: ruleMap) { - let packet = NEPacket(data: ipPacket.data, protocolFamily: 2) - self.provider.packetFlow.writePacketObjects([packet]) - SDLLogger.log("[SDLContext] hole identity: \(data.identityID), allow, data count: \(ipPacket.data.count)", for: .trace) - } - else { - SDLLogger.log("[SDLContext] not found identity: \(data.identityID) ruleMap", for: .debug) - // 向服务器请求权限逻辑 - await self.identifyStore.policyRequest(srcIdentityId: data.identityID, dstIdentityId: self.config.identityId, using: self.quicClient) - } - default: - SDLLogger.log("[SDLContext] get invalid packet", for: .debug) - } + private func makeHoleDataProcessor() -> SDLHoleDataProcessor { + return .init( + networkAddress: self.config.networkAddress, + dataCipher: self.dataCipher, + snapshotPublisher: self.snapshotPublisher, + flowSessionManager: self.flowSessionManager + ) } - private func checkPolicy(ipPacket: IPPacket, ruleMap: IdentityRuleMap?) -> Bool { - // 进来的数据反转一下,然后再处理 - if let reverseFlowSession = ipPacket.flowSession()?.reverse(), - self.flowSessionManager.hasSession(reverseFlowSession) { - self.flowSessionManager.updateSession(reverseFlowSession) - return true + private func handleHoleData(data: SDLData) async throws { + let processor = self.makeHoleDataProcessor() + guard let plan = try processor.makeProcessingPlan(data: data) else { + return } - // 检查权限逻辑 - let proto = ipPacket.header.proto - // 优先判断访问规则 - switch ipPacket.transportPacket { - case .tcp(let tcpPacket): - if let ruleMap, ruleMap.isAllow(proto: proto, port: tcpPacket.header.dstPort) { - return true - } - case .udp(let udpPacket): - if let ruleMap, ruleMap.isAllow(proto: proto, port: udpPacket.dstPort) { - return true - } - case .icmp(_): - return true - default: - return false - } + self.flowTracer.inc(num: plan.inboundBytes, type: .inbound) - return false + switch plan.action { + case .sendARPReply(let dstMac, let responseData): + SDLLogger.log("[SDLContext] get arp request packet") + await self.routeLayerPacket(dstMac: dstMac, type: .arp, data: responseData) + case .appendARP(let ip, let mac): + SDLLogger.log("[SDLContext] get arp response packet") + await self.arpServer.append(ip: ip, mac: mac) + case .writeToTun(let packetData, let identityID): + let packet = NEPacket(data: packetData, protocolFamily: 2) + self.provider.packetFlow.writePacketObjects([packet]) + SDLLogger.log("[SDLContext] hole identity: \(identityID), allow, data count: \(packetData.count)", for: .trace) + case .requestPolicy(let srcIdentityID): + SDLLogger.log("[SDLContext] not found identity: \(srcIdentityID) ruleMap", for: .debug) + // 向服务器请求权限逻辑 + await self.identifyStore.policyRequest(srcIdentityId: srcIdentityID, dstIdentityId: self.config.identityId, using: self.quicClient) + case .none: + () + } } } diff --git a/Tun/Punchnet/Actors/SDLHoleDataProcessor.swift b/Tun/Punchnet/Actors/SDLHoleDataProcessor.swift new file mode 100644 index 0000000..2b0b71a --- /dev/null +++ b/Tun/Punchnet/Actors/SDLHoleDataProcessor.swift @@ -0,0 +1,143 @@ +// +// SDLHoleDataProcessor.swift +// Tun +// +// Created by 安礼成 on 2026/4/14. +// + +import Foundation + +final class SDLHoleDataProcessor { + enum ProcessingAction { + case sendARPReply(dstMac: Data, data: Data) + case appendARP(ip: UInt32, mac: Data) + case writeToTun(packetData: Data, identityID: UInt32) + case requestPolicy(srcIdentityID: UInt32) + case none + } + + struct ProcessingPlan { + let inboundBytes: Int + let action: ProcessingAction + } + + private let networkAddress: SDLConfiguration.NetworkAddress + private let dataCipher: CCDataCipher? + private let snapshotPublisher: SnapshotPublisher + private let flowSessionManager: SDLFlowSessionManager + + init(networkAddress: SDLConfiguration.NetworkAddress, + dataCipher: CCDataCipher?, + snapshotPublisher: SnapshotPublisher, + flowSessionManager: SDLFlowSessionManager) { + self.networkAddress = networkAddress + self.dataCipher = dataCipher + self.snapshotPublisher = snapshotPublisher + self.flowSessionManager = flowSessionManager + } + + func makeProcessingPlan(data: SDLData) throws -> ProcessingPlan? { + guard let dataCipher = self.dataCipher else { + return nil + } + + let mac = LayerPacket.MacAddress(data: data.dstMac) + guard (data.dstMac == self.networkAddress.mac || mac.isBroadcast() || mac.isMulticast()) else { + return nil + } + + let decryptedData = try dataCipher.decrypt(cipherText: Data(data.data)) + let layerPacket = try LayerPacket(layerData: decryptedData) + let inboundBytes = decryptedData.count + + // 处理arp请求 + switch layerPacket.type { + case .arp: + return self.makeARPPlan(layerData: layerPacket.data, inboundBytes: inboundBytes) + case .ipv4: + return self.makeIPv4Plan(layerData: layerPacket.data, identityID: data.identityID, inboundBytes: inboundBytes) + default: + SDLLogger.log("[SDLContext] get invalid packet", for: .debug) + return .init(inboundBytes: inboundBytes, action: .none) + } + } + + private func makeARPPlan(layerData: Data, inboundBytes: Int) -> ProcessingPlan { + // 判断如果收到的是arp请求 + if let arpPacket = ARPPacket(data: layerData) { + if arpPacket.targetIP == self.networkAddress.ip { + switch arpPacket.opcode { + case .request: + let response = ARPPacket.arpResponse(for: arpPacket, mac: self.networkAddress.mac, ip: self.networkAddress.ip) + return .init( + inboundBytes: inboundBytes, + action: .sendARPReply(dstMac: arpPacket.senderMAC, data: response.marshal()) + ) + case .response: + return .init( + inboundBytes: inboundBytes, + action: .appendARP(ip: arpPacket.senderIP, mac: arpPacket.senderMAC) + ) + } + } else { + SDLLogger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(self.networkAddress.ip))") + } + } else { + SDLLogger.log("[SDLContext] get invalid arp packet") + } + + return .init(inboundBytes: inboundBytes, action: .none) + } + + private func makeIPv4Plan(layerData: Data, identityID: UInt32, inboundBytes: Int) -> ProcessingPlan { + // 有数据是通过出口网关转发的,所有只判断是合法的ip包 + guard let ipPacket = IPPacket(layerData) else { + return .init(inboundBytes: inboundBytes, action: .none) + } + + // 检查权限逻辑 + let identitySnapshot = self.snapshotPublisher.current() + let ruleMap = identitySnapshot.lookup(identityID) + + if true || self.checkPolicy(ipPacket: ipPacket, ruleMap: ruleMap) { + return .init( + inboundBytes: inboundBytes, + action: .writeToTun(packetData: ipPacket.data, identityID: identityID) + ) + } + + return .init( + inboundBytes: inboundBytes, + action: .requestPolicy(srcIdentityID: identityID) + ) + } + + private func checkPolicy(ipPacket: IPPacket, ruleMap: IdentityRuleMap?) -> Bool { + // 进来的数据反转一下,然后再处理 + if let reverseFlowSession = ipPacket.flowSession()?.reverse(), + self.flowSessionManager.hasSession(reverseFlowSession) { + self.flowSessionManager.updateSession(reverseFlowSession) + return true + } + + // 检查权限逻辑 + let proto = ipPacket.header.proto + // 优先判断访问规则 + switch ipPacket.transportPacket { + case .tcp(let tcpPacket): + if let ruleMap, ruleMap.isAllow(proto: proto, port: tcpPacket.header.dstPort) { + return true + } + case .udp(let udpPacket): + if let ruleMap, ruleMap.isAllow(proto: proto, port: udpPacket.dstPort) { + return true + } + case .icmp(_): + return true + default: + return false + } + + return false + } +}