//
//  SDLContext.swift
//  Tun
//
//  Created by 安礼成 on 2024/2/29.
//

import Foundation
import NetworkExtension
import NIOCore

// Context state for the tunnel, shared globally.
/*
 1. Handles the RSA encryption/decryption logic.
 */
/// Actor that owns the tunnel's runtime state: the UDP hole, DNS client, NAT prober,
/// peer sessions and the packet read task. It dispatches the signals and data packets
/// received from the UDP hole to the handlers below.
actor SDLContextActor {
    nonisolated let config: SDLConfiguration

    // NAT type of the current network; starts as `.blocked` until a probe result is stored.
    var natType: SDLNATProberActor.NatType = .blocked

    // AES cipher used to decrypt peer data; supplied through `init`.
    nonisolated let aesCipher: AESCipher

    // AES key; empty until a registerSuperAck has been processed.
    var aesKey: Data = Data()

    // RSA configuration; the public key is generated locally.
    nonisolated let rsaCipher: RSACipher

    // Dependencies.
    var udpHole: SDLUDPHole?
    nonisolated let providerAdapter: SDLTunnelProviderAdapter
    var puncherActor: SDLPuncherActor?

    // DNS client object.
    var dnsClient: SDLDNSClient?

    // Network (NAT) prober object.
    var proberActor: SDLNATProberActor?

    // Task that reads packets from the provider adapter.
    private var readTask: Task<(), Never>?

    private var sessionManager: SessionManager
    private var arpServer: ArpServer

    // Monitor for network state changes.
    private var monitor: SDLNetworkMonitor?

    // Internal socket communication (notices).
    private var noticeClient: SDLNoticeClient?

    // Traffic statistics.
    nonisolated private let flowTracer = SDLFlowTracer()

    nonisolated private let logger: SDLLogger

    /// Creates the context.
    ///
    /// - Parameters:
    ///   - provider: Packet tunnel provider, wrapped in an `SDLTunnelProviderAdapter`.
    ///   - config: Tunnel configuration.
    ///   - rsaCipher: RSA cipher used to decode the AES key delivered in registerSuperAck.
    ///   - aesCipher: AES cipher used to decrypt incoming data packets.
    ///   - logger: Logger shared with the components created by this actor.
    public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) {
        self.logger = logger
        self.config = config
        self.rsaCipher = rsaCipher
        self.aesCipher = aesCipher
        self.sessionManager = SessionManager()
        self.arpServer = ArpServer(known_macs: [:])
        self.providerAdapter = SDLTunnelProviderAdapter(provider: provider, logger: logger)
    }

    /// Creates the notice client on `config.noticePort` and then awaits its `waitClose()`.
    ///
    /// - Throws: Whatever `SDLNoticeClient.init` or `waitClose()` throws.
    public func startNoticeClient() async throws {
        // Start the noticeClient.
        self.noticeClient = try SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger)
        self.logger.log("[SDLContext] noticeClient started")
        try await self.noticeClient?.waitClose()
    }

    /// Starts the network monitor and logs every event from its event stream.
    /// Returns when the event stream finishes.
    public func startMonitor() async throws {
        // Start the monitor.
        let monitor = SDLNetworkMonitor()
        monitor.start()
        self.logger.log("[SDLContext] monitor started")
        self.monitor = monitor

        for await event in monitor.eventStream {
            switch event {
            case .changed:
                // The NAT type of the network needs to be probed again.
                //self.natType = await self.getNatType()
                self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info)
            case .unreachable:
                self.logger.log("didNetworkPathUnreachable", level: .warning)
            }
        }
    }

    /// Starts the DNS client against `config.remoteDnsServer` (port 15353) and writes every
    /// packet from its `packetFlow` to the tunnel.
    ///
    /// Two child tasks run: one forwards packets, one waits for the channel to close.
    /// As soon as either finishes (or throws), the other is cancelled.
    ///
    /// - Throws: Address resolution, client start-up, or child task errors.
    public func startDnsClient() async throws {
        // Start the DNS service.
        let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(self.config.remoteDnsServer, port: 15353)
        let dnsClient = try await SDLDNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger)
        let channel = try dnsClient.start()
        self.logger.log("[SDLContext] dnsClient started")
        self.dnsClient = dnsClient

        try await withThrowingTaskGroup(of: Void.self) { group in
            group.addTask {
                // Consume the packet stream.
                for await packet in dnsClient.packetFlow {
                    try Task.checkCancellation()
                    // Protocol family 2 corresponds to AF_INET (IPv4).
                    let nePacket = NEPacket(data: packet, protocolFamily: 2)
                    self.providerAdapter.writePackets(packets: [nePacket])
                }
            }

            group.addTask {
                try await channel.closeFuture.get()
            }

            // Wait for the first child to finish, then cancel the remaining one.
            try await group.next()
            self.logger.log("[SDLContext] taskGroup cancel")
            group.cancelAll()
        }
    }

    /// Starts the UDP hole and runs its processing tasks until the first of them finishes.
    ///
    /// Child tasks: wait for channel close, send a STUN request every 5 seconds, and consume
    /// the event, data and signal streams of the UDP hole.
    ///
    /// - Throws: Start-up errors, or the error of the first child task that fails.
    public func startUDPHole() async throws {
        // Start the UDP server.
        let udpHole = try SDLUDPHole(logger: self.logger)
        let channel = try udpHole.start()
        self.logger.log("[SDLContext] udpHole started")
        self.udpHole = udpHole

        try await withThrowingTaskGroup(of: Void.self) { group in
            group.addTask {
                try await channel.closeFuture.get()
            }

            // Periodic STUN request, once every 5 seconds.
            group.addTask {
                while true {
                    try Task.checkCancellation()
                    try await Task.sleep(nanoseconds: 5 * 1_000_000_000)
                    try Task.checkCancellation()
                    await self.sendStunRequest()
                }
            }

            // Consume the event stream.
            group.addTask {
                for try await event in udpHole.eventStream {
                    try Task.checkCancellation()
                    switch event {
                    case .ready:
                        await self.handleUDPHoleReady()
                    case .closed:
                        ()
                    }
                }
            }

            // Consume the data stream.
            group.addTask {
                for try await data in udpHole.dataStream {
                    try Task.checkCancellation()
                    try await self.handleData(data: data)
                }
            }

            // Consume the signal stream.
            group.addTask {
                for try await (remoteAddress, signal) in udpHole.signalStream {
                    try Task.checkCancellation()
                    switch signal {
                    case .registerSuperAck(let registerSuperAck):
                        await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
                    case .registerSuperNak(let registerSuperNak):
                        await self.handleRegisterSuperNak(nakPacket: registerSuperNak)
                    case .peerInfo(let peerInfo):
                        await self.puncherActor?.handlePeerInfo(peerInfo: peerInfo)
                    case .event(let event):
                        try await self.handleEvent(event: event)
                    case .stunProbeReply(let probeReply):
                        await self.proberActor?.handleProbeReply(reply: probeReply)
                    case .register(let register):
                        try await self.handleRegister(remoteAddress: remoteAddress, register: register)
                    case .registerAck(let registerAck):
                        await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
                    }
                }
            }

            // Wait for the first child to finish, then cancel the rest.
            try await group.next()
            group.cancelAll()
            self.logger.log("[SDLContext] taskGroup cancel")
        }
    }

    /// Drops the references to the UDP hole and notice client and cancels the read task.
    public func stop() async {
        self.udpHole = nil
        self.noticeClient = nil
        self.readTask?.cancel()
    }

    /// Stores the probed NAT type; lets child tasks update `natType` on the actor.
    private func setNatType(natType: SDLNATProberActor.NatType) {
        self.natType = natType
    }

    /// Called when the UDP hole reports `.ready`: creates the puncher and prober actors,
    /// then concurrently probes the NAT type and sends the registerSuper packet.
    private func handleUDPHoleReady() async {
        if let udpHole = self.udpHole {
            self.puncherActor = SDLPuncherActor(udpHole: udpHole, querySocketAddress: config.stunSocketAddress, logger: logger)
            self.proberActor = SDLNATProberActor(udpHole: udpHole, addressArray: self.config.stunProbeSocketAddressArray, logger: self.logger)
        }

        await withDiscardingTaskGroup { group in
            group.addTask {
                // Start probing the NAT type.
                if let natType = await self.proberActor?.probeNatType() {
                    await self.setNatType(natType: natType)
                    self.logger.log("[SDLContext] nat_type is: \(natType)")
                }
            }

            group.addTask {
                var registerSuper = SDLRegisterSuper()
                registerSuper.pktID = 0
                registerSuper.clientID = self.config.clientId
                registerSuper.networkID = self.config.networkAddress.networkId
                registerSuper.mac = self.config.networkAddress.mac
                registerSuper.ip = self.config.networkAddress.ip
                registerSuper.maskLen = UInt32(self.config.networkAddress.maskLen)
                registerSuper.hostname = self.config.hostname
                registerSuper.pubKey = self.rsaCipher.pubKey
                registerSuper.accessToken = self.config.accessToken

                // Serialization failure is ignored: the packet is simply not sent.
                if let registerSuperData = try? registerSuper.serializedData() {
                    self.logger.log("[SDLContext] will send register super")
                    await self.udpHole?.send(type: .registerSuper, data: registerSuperData, remoteAddress: self.config.stunSocketAddress)
                }
            }
        }
    }

    /// Builds a STUN request from the local network address and current NAT type and sends it
    /// to `config.stunSocketAddress`. Nothing is sent if serialization fails or no UDP hole exists.
    private func sendStunRequest() async {
        var stunRequest = SDLStunRequest()
        stunRequest.clientID = self.config.clientId
        stunRequest.networkID = self.config.networkAddress.networkId
        stunRequest.ip = self.config.networkAddress.ip
        stunRequest.mac = self.config.networkAddress.mac
        stunRequest.natType = UInt32(self.natType.rawValue)

        if let stunData = try? stunRequest.serializedData() {
            let remoteAddress = self.config.stunSocketAddress
            self.udpHole?.send(type: .stunRequest, data: stunData, remoteAddress: remoteAddress)
        }
    }

    /// Handles registerSuperAck: decodes the AES key with the RSA cipher, applies the tunnel
    /// network settings, notifies the notice client of the IP address, starts the packet reader
    /// and finally stores the AES key.
    ///
    /// An ack whose key cannot be decoded is logged and ignored. A failure to apply the network
    /// settings terminates the process with `exit(-1)`.
    private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async {
        // The key has to be decoded with the RSA private key. The payload comes off the
        // network, so a decode failure must not crash the extension.
        let aesKey: Data
        do {
            aesKey = try self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))
        } catch let err {
            self.logger.log("[SDLContext] decode aes_key from registerSuperAck get error: \(err)", level: .error)
            return
        }
        self.logger.log("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count)", level: .info)

        // Apply the tun interface information.
        do {
            let ipAddress = try await self.providerAdapter.setNetworkSettings(networkAddress: self.config.networkAddress, dnsServer: SDLDNSClient.Helper.dnsServer)
            self.noticeClient?.send(data: NoticeMessage.ipAdress(ip: ipAddress))
            self.startReader()
        } catch let err {
            self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error)
            exit(-1)
        }

        self.aesKey = aesKey
    }

    /// Handles registerSuperNak: forwards the error message as an alert notice.
    /// `.invalidToken` and `.nodeDisabled` additionally terminate the process with `exit(-1)`.
    /// Packets whose error code is not a known `SDLNAKErrorCode` are ignored.
    private func handleRegisterSuperNak(nakPacket: SDLRegisterSuperNak) {
        let errorMessage = nakPacket.errorMessage
        // `UInt8(exactly:)` returns nil instead of trapping when the wire value does not fit.
        guard let rawCode = UInt8(exactly: nakPacket.errorCode),
              let errorCode = SDLNAKErrorCode(rawValue: rawCode) else {
            return
        }

        switch errorCode {
        case .invalidToken, .nodeDisabled:
            let alertNotice = NoticeMessage.alert(alert: errorMessage)
            self.noticeClient?.send(data: alertNotice)
            exit(-1)
        case .noIpAddress, .networkFault, .internalFault:
            let alertNotice = NoticeMessage.alert(alert: errorMessage)
            self.noticeClient?.send(data: alertNotice)
        }
        self.logger.log("[SDLContext] Get a SuperNak message exit", level: .warning)
    }

    /// Handles an event signal.
    ///
    /// - `.natChanged`: removes the session for the affected MAC.
    /// - `.sendRegister`: sends a register packet to the NAT address carried by the event
    ///   (skipped when the address cannot be resolved).
    /// - `.networkShutdown`: sends an alert notice and terminates the process with `exit(-1)`.
    ///
    /// - Throws: Serialization errors of the register packet.
    private func handleEvent(event: SDLEvent) throws {
        switch event {
        case .natChanged(let natChangedEvent):
            let dstMac = natChangedEvent.mac
            self.logger.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info)
            sessionManager.removeSession(dstMac: dstMac)

        case .sendRegister(let sendRegisterEvent):
            self.logger.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)", level: .debug)
            let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
            if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) {
                // Send the register packet.
                var register = SDLRegister()
                register.networkID = self.config.networkAddress.networkId
                register.srcMac = self.config.networkAddress.mac
                register.dstMac = sendRegisterEvent.dstMac
                self.udpHole?.send(type: .register, data: try register.serializedData(), remoteAddress: remoteAddress)
            }

        case .networkShutdown(let shutdownEvent):
            let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message)
            self.noticeClient?.send(data: alertNotice)
            exit(-1)
        }
    }

    /// Handles a register packet from a peer: when it targets this node's MAC within the same
    /// network, replies with a registerAck and records a session to the sender.
    ///
    /// - Throws: Serialization errors of the registerAck packet.
    private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws {
        let networkAddr = config.networkAddress
        self.logger.log("register packet: \(register), network_address: \(networkAddr)", level: .debug)

        // Check that the destination is the tun interface address and on the same network.
        if register.dstMac == networkAddr.mac && register.networkID == networkAddr.networkId {
            // Reply with an ack packet.
            var registerAck = SDLRegisterAck()
            registerAck.networkID = networkAddr.networkId
            registerAck.srcMac = networkAddr.mac
            registerAck.dstMac = register.srcMac
            self.udpHole?.send(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)

            // A session to the sender is established here. In complex networks the NAT address
            // obtained from the super-node is not always reliable, so the source address of the
            // UDP packet is used as the NAT address.
            let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
            self.sessionManager.addSession(session: session)
        } else {
            // The value printed is the destination MAC, so the message says dst_mac.
            self.logger.log("SDLContext didReadRegister get a invalid packet, because dst_mac not matched: \(register.dstMac)", level: .warning)
        }
    }

    /// Handles a registerAck packet: when it targets this node's MAC within the same network,
    /// records a session to the sender using the packet's source address.
    private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) {
        // Check that the destination is the tun interface address and on the same network.
        let networkAddr = config.networkAddress
        if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId {
            let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
            self.sessionManager.addSession(session: session)
        } else {
            self.logger.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning)
        }
    }

    /// Handles a data packet received from the UDP hole.
    ///
    /// Packets not addressed to this node (and not broadcast/multicast), packets that fail to
    /// decrypt, and packets that fail to parse are dropped. ARP requests for the local IP are
    /// answered, ARP responses are recorded, and IPv4 packets destined to the local IP are
    /// written to the tunnel.
    private func handleData(data: SDLData) throws {
        let mac = LayerPacket.MacAddress(data: data.dstMac)
        let networkAddr = config.networkAddress

        guard (data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast()) else {
            return
        }

        guard let decyptedData = try? self.aesCipher.decypt(aesKey: self.aesKey, data: Data(data.data)) else {
            return
        }

        // A malformed payload is dropped like an undecryptable one; rethrowing here would end
        // the data-stream task and with it the whole UDP hole task group.
        let layerPacket: LayerPacket
        do {
            layerPacket = try LayerPacket(layerData: decyptedData)
        } catch {
            self.logger.log("[SDLContext] drop malformed layer packet: \(error)", level: .warning)
            return
        }

        self.flowTracer.inc(num: decyptedData.count, type: .inbound)

        // Dispatch on the layer packet type.
        switch layerPacket.type {
        case .arp:
            // Check whether an ARP packet was received.
            if let arpPacket = ARPPacket(data: layerPacket.data) {
                if arpPacket.targetIP == networkAddr.ip {
                    switch arpPacket.opcode {
                    case .request:
                        self.logger.log("[SDLContext] get arp request packet", level: .debug)
                        let response = ARPPacket.arpResponse(for: arpPacket, mac: networkAddr.mac, ip: networkAddr.ip)
                        self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
                    case .response:
                        self.logger.log("[SDLContext] get arp response packet", level: .debug)
                        self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
                    }
                } else {
                    self.logger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))", level: .debug)
                }
            } else {
                self.logger.log("[SDLContext] get invalid arp packet", level: .debug)
            }

        case .ipv4:
            guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == networkAddr.ip else {
                return
            }
            // Protocol family 2 corresponds to AF_INET (IPv4).
            let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
            self.providerAdapter.writePackets(packets: [packet])

        default:
            self.logger.log("[SDLContext] get invalid packet", level: .debug)
        }
    }

    // Traffic statistics (disabled).
    //    public func flowReportTask() {
    //        Task {
    //            // Report once per minute.
    //            self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect()
    //                .sink { _ in
    //                    Task {
    //                        let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset()
    //                        await self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
    //                    }
    //                }
    //        }
    //    }

    /// Starts reading packets from the provider adapter in a dedicated high-priority task.
    /// Any previously running read task is cancelled first.
    private func startReader() {
        // Stop the previous task.
        self.readTask?.cancel()

        // Start a new task.
        self.readTask = Task(priority: .high) {
            while true {
                if Task.isCancelled {
                    return
                }

                let packets = await self.providerAdapter.readPackets()
                // Entries that do not parse as an IP packet are skipped.
                let ipPackets = packets.compactMap { IPPacket($0) }
                await self.batchProcessPackets(batchSize: 20, packets: ipPackets)
            }
        }
    }

    // Dispatch IP packets in batches.
    private func batchProcessPackets(batchSize: Int, packets: [IPPacket]) async {
        for startIndex in stride(from: 0, to: packets.count, by: batchSize) {
            let endIndex = Swift.min(startIndex + batchSize, packets.count)
            // NOTE(review): the reviewed text is garbled from the next line to the end of the
            // file (content appears to be missing after `startIndex..`). It is reproduced
            // verbatim and must be restored from the original file.
            let chunkPackets = packets[startIndex.. String { return SDLUtil.int32ToIp(self) } }