diff --git a/Tun/Punchnet/Actors/SDLContextActor.swift b/Tun/Punchnet/Actors/SDLContextActor.swift index 2914efd..6301236 100644 --- a/Tun/Punchnet/Actors/SDLContextActor.swift +++ b/Tun/Punchnet/Actors/SDLContextActor.swift @@ -328,30 +328,9 @@ actor SDLContextActor { } // 处理消息流 + let messageStream = udpHole.messageStream let messageTask = Task.detached { - for await (remoteAddress, message) in udpHole.messageStream { - if Task.isCancelled { - break - } - - switch message { - case .stunProbeReply(let probeReply): - await self.proberActor.handleProbeReply(localAddress: localAddress, reply: probeReply) - case .register(let register): - try? await self.handleRegister(remoteAddress: remoteAddress, register: register) - case .registerAck(let registerAck): - await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) - case .data(let data): - do { - try await self.handleHoleData(data: data) - } catch let err { - SDLLogger.log("[SDLContext] handleHoleData get err: \(err)") - } - case .stunReply(_): - //SDLLogger.shared.log("[SDLContext] get a stunReply: \(stunReply)") - () - } - } + await self.consumeUDPHoleMessages(stream: messageStream, localAddress: localAddress) } self.udpHole = udpHole @@ -597,133 +576,6 @@ actor SDLContextActor { } } - private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws { - let networkAddr = config.networkAddress - SDLLogger.log("[SDLContext] register packet: \(register), network_address: \(networkAddr)") - - // 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下 - if register.dstMac == networkAddr.mac && register.networkID == networkAddr.networkId { - // 回复ack包 - var registerAck = SDLRegisterAck() - registerAck.networkID = networkAddr.networkId - registerAck.srcMac = networkAddr.mac - registerAck.dstMac = register.srcMac - - self.sendPeerPacket(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress) - // 这里需要建立到来源的会话, 在复杂网络下,通过super-node查询到的nat地址不一定靠谱,需要通过udp包的来源地址作为nat地址 - let session = Session(dstMac: register.srcMac, natAddress: remoteAddress) - self.sessionManager.addSession(session: session) - } else { - SDLLogger.log("[SDLContext] didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)") - } - } - - private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) { - // 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下 - let networkAddr = config.networkAddress - if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId { - let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress) - self.sessionManager.addSession(session: session) - } else { - SDLLogger.log("[SDLContext] didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)") - } - } - - private func handleHoleData(data: SDLData) async throws { - guard let dataCipher = self.dataCipher else { - return - } - - let mac = LayerPacket.MacAddress(data: data.dstMac) - let networkAddr = config.networkAddress - guard (data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast()) else { - return - } - - let decyptedData = try dataCipher.decrypt(cipherText: Data(data.data)) - let layerPacket = try LayerPacket(layerData: decyptedData) - - self.flowTracer.inc(num: decyptedData.count, type: .inbound) - // 处理arp请求 - switch layerPacket.type { - case .arp: - // 判断如果收到的是arp请求 - if let arpPacket = ARPPacket(data: layerPacket.data) { - if arpPacket.targetIP == networkAddr.ip { - switch arpPacket.opcode { - case .request: - SDLLogger.log("[SDLContext] get arp request packet") - let response = ARPPacket.arpResponse(for: arpPacket, mac: networkAddr.mac, ip: networkAddr.ip) - await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal()) - case .response: - SDLLogger.log("[SDLContext] get arp response packet") - await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC) - } - } else { - SDLLogger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))") - } - } else { - SDLLogger.log("[SDLContext] get invalid arp packet") - } - case .ipv4: - // 有数据是通过出口网关转发的,所有只判断是合法的ip包 - guard let ipPacket = IPPacket(layerPacket.data) else { - return - } - - // 检查权限逻辑 - let identitySnapshot = self.snapshotPublisher.current() - let ruleMap = identitySnapshot.lookup(data.identityID) - - if true || self.checkPolicy(ipPacket: ipPacket, ruleMap: ruleMap) { - // 用来做debug - if ipPacket.header.source == 168428037 { - SDLLogger.log("[SDLContext] hole data: \(Array(ipPacket.data)), len: \(ipPacket.data.count)", for: .trace) - } - - let packet = NEPacket(data: ipPacket.data, protocolFamily: 2) - self.provider.packetFlow.writePacketObjects([packet]) - SDLLogger.log("[SDLContext] hole identity: \(data.identityID), allow, data count: \(ipPacket.data.count)", for: .trace) - } - else { - SDLLogger.log("[SDLContext] not found identity: \(data.identityID) ruleMap", for: .debug) - // 向服务器请求权限逻辑 - await self.identifyStore.policyRequest(srcIdentityId: data.identityID, dstIdentityId: self.config.identityId, using: self.quicClient) - } - default: - SDLLogger.log("[SDLContext] get invalid packet", for: .debug) - } - } - - private func checkPolicy(ipPacket: IPPacket, ruleMap: IdentityRuleMap?) -> Bool { - // 进来的数据反转一下,然后再处理 - if let reverseFlowSession = ipPacket.flowSession()?.reverse(), - self.flowSessionManager.hasSession(reverseFlowSession) { - self.flowSessionManager.updateSession(reverseFlowSession) - return true - } - - // 检查权限逻辑 - let proto = ipPacket.header.proto - // 优先判断访问规则 - switch ipPacket.transportPacket { - case .tcp(let tcpPacket): - if let ruleMap, ruleMap.isAllow(proto: proto, port: tcpPacket.header.dstPort) { - return true - } - case .udp(let udpPacket): - if let ruleMap, ruleMap.isAllow(proto: proto, port: udpPacket.dstPort) { - return true - } - case .icmp(_): - return true - default: - return false - } - - return false - } - // 开始读取数据, 用单独的线程处理packetFlow private func startReader() { // 停止之前的任务 @@ -911,6 +763,166 @@ actor SDLContextActor { } } +// 处理从Hole收到的数据 +extension SDLContextActor { + + private func consumeUDPHoleMessages(stream: AsyncStream<(SocketAddress, SDLHoleMessage)>, localAddress: SocketAddress) async { + for await (remoteAddress, message) in stream { + if Task.isCancelled { + break + } + + switch message.inboundMessage { + case .control(let controlMessage): + await self.handleHoleControlMessage(controlMessage, localAddress: localAddress, remoteAddress: remoteAddress) + case .data(let data): + try? await self.handleHoleData(data: data) + } + } + } + + private func handleHoleControlMessage(_ message: SDLHoleControlMessage, localAddress: SocketAddress, remoteAddress: SocketAddress) async { + switch message { + case .stunProbeReply(let probeReply): + await self.proberActor.handleProbeReply(localAddress: localAddress, reply: probeReply) + case .register(let register): + try? self.handleRegister(remoteAddress: remoteAddress, register: register) + case .registerAck(let registerAck): + self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) + case .stunReply(_): + //SDLLogger.shared.log("[SDLContext] get a stunReply: \(stunReply)") + () + } + } + + private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws { + let networkAddr = config.networkAddress + SDLLogger.log("[SDLContext] register packet: \(register), network_address: \(networkAddr)") + + // 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下 + if register.dstMac == networkAddr.mac && register.networkID == networkAddr.networkId { + // 回复ack包 + var registerAck = SDLRegisterAck() + registerAck.networkID = networkAddr.networkId + registerAck.srcMac = networkAddr.mac + registerAck.dstMac = register.srcMac + + self.sendPeerPacket(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress) + // 这里需要建立到来源的会话, 在复杂网络下,通过super-node查询到的nat地址不一定靠谱,需要通过udp包的来源地址作为nat地址 + let session = Session(dstMac: register.srcMac, natAddress: remoteAddress) + self.sessionManager.addSession(session: session) + } else { + SDLLogger.log("[SDLContext] didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)") + } + } + + private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) { + // 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下 + let networkAddr = config.networkAddress + if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId { + let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress) + self.sessionManager.addSession(session: session) + } else { + SDLLogger.log("[SDLContext] didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)") + } + } + + private func handleHoleData(data: SDLData) async throws { + guard let dataCipher = self.dataCipher else { + return + } + + let mac = LayerPacket.MacAddress(data: data.dstMac) + let networkAddr = config.networkAddress + guard (data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast()) else { + return + } + + let decyptedData = try dataCipher.decrypt(cipherText: Data(data.data)) + let layerPacket = try LayerPacket(layerData: decyptedData) + + self.flowTracer.inc(num: decyptedData.count, type: .inbound) + // 处理arp请求 + switch layerPacket.type { + case .arp: + // 判断如果收到的是arp请求 + if let arpPacket = ARPPacket(data: layerPacket.data) { + if arpPacket.targetIP == networkAddr.ip { + switch arpPacket.opcode { + case .request: + SDLLogger.log("[SDLContext] get arp request packet") + let response = ARPPacket.arpResponse(for: arpPacket, mac: networkAddr.mac, ip: networkAddr.ip) + await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal()) + case .response: + SDLLogger.log("[SDLContext] get arp response packet") + await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC) + } + } else { + SDLLogger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))") + } + } else { + SDLLogger.log("[SDLContext] get invalid arp packet") + } + case .ipv4: + // 有数据是通过出口网关转发的,所有只判断是合法的ip包 + guard let ipPacket = IPPacket(layerPacket.data) else { + return + } + + // 检查权限逻辑 + let identitySnapshot = self.snapshotPublisher.current() + let ruleMap = identitySnapshot.lookup(data.identityID) + + if true || self.checkPolicy(ipPacket: ipPacket, ruleMap: ruleMap) { + // 用来做debug + if ipPacket.header.source == 168428037 { + SDLLogger.log("[SDLContext] hole data: \(Array(ipPacket.data)), len: \(ipPacket.data.count)", for: .trace) + } + + let packet = NEPacket(data: ipPacket.data, protocolFamily: 2) + self.provider.packetFlow.writePacketObjects([packet]) + SDLLogger.log("[SDLContext] hole identity: \(data.identityID), allow, data count: \(ipPacket.data.count)", for: .trace) + } + else { + SDLLogger.log("[SDLContext] not found identity: \(data.identityID) ruleMap", for: .debug) + // 向服务器请求权限逻辑 + await self.identifyStore.policyRequest(srcIdentityId: data.identityID, dstIdentityId: self.config.identityId, using: self.quicClient) + } + default: + SDLLogger.log("[SDLContext] get invalid packet", for: .debug) + } + } + + private func checkPolicy(ipPacket: IPPacket, ruleMap: IdentityRuleMap?) -> Bool { + // 进来的数据反转一下,然后再处理 + if let reverseFlowSession = ipPacket.flowSession()?.reverse(), + self.flowSessionManager.hasSession(reverseFlowSession) { + self.flowSessionManager.updateSession(reverseFlowSession) + return true + } + + // 检查权限逻辑 + let proto = ipPacket.header.proto + // 优先判断访问规则 + switch ipPacket.transportPacket { + case .tcp(let tcpPacket): + if let ruleMap, ruleMap.isAllow(proto: proto, port: tcpPacket.header.dstPort) { + return true + } + case .udp(let udpPacket): + if let ruleMap, ruleMap.isAllow(proto: proto, port: udpPacket.dstPort) { + return true + } + case .icmp(_): + return true + default: + return false + } + + return false + } +} + private extension UInt32 { // 转换成ip地址 func asIpAddress() -> String { diff --git a/Tun/Punchnet/SDLMessage.swift b/Tun/Punchnet/SDLMessage.swift index 06b9ab2..0b436e3 100644 --- a/Tun/Punchnet/SDLMessage.swift +++ b/Tun/Punchnet/SDLMessage.swift @@ -124,6 +124,35 @@ enum SDLHoleMessage { case stunReply(SDLStunReply) } +enum SDLHoleControlMessage { + case register(SDLRegister) + case registerAck(SDLRegisterAck) + case stunProbeReply(SDLStunProbeReply) + case stunReply(SDLStunReply) +} + +enum SDLHoleInboundMessage { + case control(SDLHoleControlMessage) + case data(SDLData) +} + +extension SDLHoleMessage { + var inboundMessage: SDLHoleInboundMessage { + switch self { + case .data(let data): + return .data(data) + case .register(let register): + return .control(.register(register)) + case .registerAck(let registerAck): + return .control(.registerAck(registerAck)) + case .stunProbeReply(let stunProbeReply): + return .control(.stunProbeReply(stunProbeReply)) + case .stunReply(let stunReply): + return .control(.stunReply(stunReply)) + } + } +} + enum SDLQUICInboundMessage { // 欢迎消息 case welcome(SDLWelcome)