//
//  SDLContext.swift
//  Tun
//
//  Created by 安礼成 on 2024/2/29.
//

import Foundation
import NetworkExtension
import NIOCore
import Combine

/// Context / environment state, shared globally.
///
/// 1. Handles the RSA encryption/decryption logic.
@available(macOS 14, *)
public class SDLContext {

    /// Route information.
    struct Route {
        let dstAddress: String
        let subnetMask: String

        /// `"<dstAddress>:<subnetMask>"`, intended for log output.
        var debugInfo: String {
            return "\(dstAddress):\(subnetMask)"
        }
    }

    let config: SDLConfiguration

    /// NAT type of the current network; starts as `.blocked` until a probe overwrites it.
    var natType: SDLNATProberActor.NatType = .blocked

    /// AES cipher (original note: the object is only created once authorization has passed).
    var aesCipher: AESCipher

    /// AES key; empty until one is stored by the register-super-ack handler.
    var aesKey: Data = Data()

    /// RSA configuration; the public key is generated locally.
    let rsaCipher: RSACipher

    // Dependencies
    var udpHoleActor: SDLUDPHoleActor?
    var providerActor: SDLTunnelProviderActor
    var puncherActor: SDLPuncherActor

    /// DNS client.
    var dnsClientActor: SDLDNSClientActor?

    /// NAT prober.
    var proberActor: SDLNATProberActor?

    /// Task that reads packets from the tunnel.
    private var readTask: Task<(), Never>?

    private var sessionManager: SessionManager
    private var arpServer: ArpServer

    /// Cookie of the most recently sent stunRequest.
    private var lastCookie: UInt32? = 0

    /// Monitor for network state changes.
    private var monitor: SDLNetworkMonitor?

    /// Internal socket communication.
    private var noticeClientActor: SDLNoticeClientActor?

    // Traffic statistics
    private var flowTracer = SDLFlowTracerActor()
    private var flowTracerCancel: AnyCancellable?

    private let logger: SDLLogger

    /// Root task owning every long-running service started by `start()`.
    ///
    /// `Task` is generic and cannot be named without arguments. `start()` assigns
    /// a `Task { try await ... }` whose body returns nothing and may throw, and
    /// awaits its `value` with `try`, hence `Task<Void, Error>`.
    private var rootTask: Task<Void, Error>?
// MARK: - Initialization & lifecycle

    /// Creates the context and its always-present collaborators.
    ///
    /// - Parameters:
    ///   - provider: Packet tunnel provider wrapped by `SDLTunnelProviderActor`.
    ///   - config: Static configuration (addresses, ports, identifiers, token).
    ///   - rsaCipher: RSA cipher used to decode the AES key sent by the server.
    ///   - aesCipher: AES cipher used to decrypt inbound data packets.
    ///   - logger: Logger shared with the actors created here.
    public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) {
        self.logger = logger
        self.config = config
        self.rsaCipher = rsaCipher
        self.aesCipher = aesCipher
        self.sessionManager = SessionManager()
        self.arpServer = ArpServer(known_macs: [:])
        self.providerActor = SDLTunnelProviderActor(provider: provider, logger: logger)
        self.puncherActor = SDLPuncherActor(querySocketAddress: config.stunSocketAddress, logger: logger)
    }

    /// Starts the DNS client, the UDP hole, the network monitor and the notice
    /// client as children of one root task, then waits for that task.
    ///
    /// - Throws: Whatever the root task throws, e.g. the error raised by
    ///   `Task.sleep` inside a retry loop once the task has been cancelled.
    public func start() async throws {
        let task = Task {
            try await withThrowingTaskGroup(of: Void.self) { group in
                group.addTask {
                    try await self.runRestarting("dnsClient") { try await self.startDnsClient() }
                }
                group.addTask {
                    try await self.runRestarting("UDPHole") { try await self.startUDPHole() }
                }
                group.addTask {
                    await self.startMonitor()
                }
                group.addTask {
                    try await self.runRestarting("noticeClient") { try await self.startNoticeClient() }
                }

                try await group.waitForAll()
            }
        }
        self.rootTask = task

        try await task.value
    }

    /// Cancels the root task and the packet reader and drops the UDP hole and
    /// notice client references.
    public func stop() async {
        self.rootTask?.cancel()
        self.udpHoleActor = nil
        self.noticeClientActor = nil
        self.readTask?.cancel()
    }

    /// Runs `operation` repeatedly until the current task is cancelled.
    ///
    /// A thrown error is logged under `name` and followed by a 2 second pause
    /// before the next attempt; a normal return restarts immediately.
    ///
    /// - Throws: Only the error thrown by `Task.sleep` (cancellation).
    private func runRestarting(_ name: String, _ operation: () async throws -> Void) async throws {
        while !Task.isCancelled {
            do {
                try await operation()
            } catch {
                self.logger.log("[SDLContext] \(name) get err: \(error)", level: .warning)
                try await Task.sleep(for: .seconds(2))
            }
        }
    }

    // MARK: - Services

    /// Creates the notice client, stores it and runs it until `start()` returns.
    private func startNoticeClient() async throws {
        let client = try SDLNoticeClientActor(noticePort: self.config.noticePort, logger: self.logger)
        self.noticeClientActor = client
        try await client.start()
        self.logger.log("[SDLContext] notice_client task cancel", level: .warning)
    }

    /// Creates and starts the UDP hole, then runs two child tasks: a stun request
    /// sent every 5 seconds, and the dispatch loop for the hole's events.
    /// When either child finishes, the other one is cancelled.
    private func startUDPHole() async throws {
        let udpHole = try SDLUDPHoleActor(logger: self.logger)
        self.udpHoleActor = udpHole
        try await udpHole.start()

        try await withThrowingTaskGroup(of: Void.self) { group in
            group.addTask {
                while !Task.isCancelled {
                    // Task.sleep itself throws when the task is cancelled.
                    try await Task.sleep(for: .seconds(5))
                    try Task.checkCancellation()
                    await self.sendStunRequest()
                }
            }

            group.addTask {
                if let eventStream = self.udpHoleActor?.eventStream {
                    for try await event in eventStream {
                        try Task.checkCancellation()
                        do {
                            try await self.dispatchEvent(event: event)
                        } catch {
                            // One bad packet must not end the loop, but it should be visible.
                            self.logger.log("[SDLContext] dispatchEvent get err: \(error)", level: .warning)
                        }
                    }
                }
            }

            if try await group.next() != nil {
                group.cancelAll()
            }
        }
    }

    /// Creates the network monitor and logs every event it emits.
    private func startMonitor() async {
        let monitor = SDLNetworkMonitor()
        self.monitor = monitor

        for await event in monitor.eventStream {
            switch event {
            case .changed:
                // The NAT type needs to be probed again.
                //self.natType = await self.getNatType()
                self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info)
            case .unreachable:
                self.logger.log("didNetworkPathUnreachable", level: .warning)
            }
        }
    }

    /// Resolves the remote DNS server (port 15353), creates the DNS client and
    /// runs two child tasks: the client itself, and a loop that writes every
    /// packet from the client's `packetFlow` into the tunnel.
    /// When either child finishes, the other one is cancelled.
    private func startDnsClient() async throws {
        let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(config.remoteDnsServer, port: 15353)
        let dnsClient = try await SDLDNSClientActor(dnsServerAddress: dnsSocketAddress, logger: self.logger)
        self.dnsClientActor = dnsClient

        try await withThrowingTaskGroup(of: Void.self) { group in
            defer {
                self.logger.log("[SDLContext] dns client task cancel", level: .warning)
            }

            group.addTask {
                try await dnsClient.start()
            }

            group.addTask {
                if let packetFlow = self.dnsClientActor?.packetFlow {
                    for await packet in packetFlow {
                        // NOTE(review): 2 is presumably AF_INET — confirm.
                        let nePacket = NEPacket(data: packet, protocolFamily: 2)
                        await self.providerActor.writePackets(packets: [nePacket])
                    }
                }
            }

            if try await group.next() != nil {
                group.cancelAll()
            }
        }
    }

    // MARK: - Outbound control packets

    /// Handles the UDP hole's `.ready` event: hands the hole to the puncher,
    /// probes the NAT type and sends a `registerSuper` packet to the stun address.
    ///
    /// - Throws: If the `registerSuper` message cannot be serialized.
    private func handleUDPHoleReady() async throws {
        await self.puncherActor.setUDPHoleActor(udpHoleActor: self.udpHoleActor)

        // Start probing the NAT type.
        // NOTE(review): this awaits probeNatType() from inside the event dispatch
        // loop, which is also what delivers `.stunProbeReply` to the prober —
        // confirm the prober cannot stall waiting for replies.
        if let udpHoleActor = self.udpHoleActor {
            let prober = SDLNATProberActor(udpHole: udpHoleActor, addressArray: self.config.stunProbeSocketAddressArray, logger: self.logger)
            self.proberActor = prober
            self.natType = await prober.probeNatType()
        }

        let networkAddr = self.config.networkAddress
        var registerSuper = SDLRegisterSuper()
        registerSuper.pktID = 0
        registerSuper.clientID = self.config.clientId
        registerSuper.networkID = networkAddr.networkId
        registerSuper.mac = networkAddr.mac
        registerSuper.ip = networkAddr.ip
        registerSuper.maskLen = UInt32(networkAddr.maskLen)
        registerSuper.hostname = self.config.hostname
        registerSuper.pubKey = self.rsaCipher.pubKey
        registerSuper.accessToken = self.config.accessToken

        let payload = try registerSuper.serializedData()
        await self.udpHoleActor?.send(type: .registerSuper, data: payload, remoteAddress: self.config.stunSocketAddress)
    }

    /// Sends a `stunRequest` carrying this client's identity and current NAT
    /// type to the stun address. A serialization failure is logged and skipped.
    private func sendStunRequest() async {
        let networkAddr = self.config.networkAddress
        var stunRequest = SDLStunRequest()
        stunRequest.clientID = self.config.clientId
        stunRequest.networkID = networkAddr.networkId
        stunRequest.ip = networkAddr.ip
        stunRequest.mac = networkAddr.mac
        stunRequest.natType = UInt32(self.natType.rawValue)

        let stunData: Data
        do {
            stunData = try stunRequest.serializedData()
        } catch {
            self.logger.log("[SDLContext] stunRequest serialize get err: \(error)", level: .warning)
            return
        }

        await self.udpHoleActor?.send(type: .stunRequest, data: stunData, remoteAddress: self.config.stunSocketAddress)
    }

    // MARK: - Inbound dispatch

    /// Routes one UDP hole event to its handler.
    private func dispatchEvent(event: SDLUDPHoleActor.UDPHoleEvent) async throws {
        switch event {
        case .ready:
            try await self.handleUDPHoleReady()

        case .message(let remoteAddress, let message):
            switch message {
            case .registerSuperAck(let registerSuperAck):
                await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
            case .registerSuperNak(let registerSuperNak):
                await self.handleRegisterSuperNak(nakPacket: registerSuperNak)
            case .peerInfo(let peerInfo):
                await self.puncherActor.handlePeerInfo(peerInfo: peerInfo)
            case .event(let event):
                try await self.handleEvent(event: event)
            case .stunProbeReply(let probeReply):
                await self.proberActor?.handleProbeReply(reply: probeReply)
            case .data(let data):
                try await self.handleData(data: data)
            case .register(let register):
                try await self.handleRegister(remoteAddress: remoteAddress, register: register)
            case .registerAck(let registerAck):
                await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
            }
        }
    }

    /// Handles `registerSuperAck`: decodes the AES key with the RSA cipher,
    /// stores it, applies the tunnel network settings, reports the assigned IP
    /// through the notice client and starts the packet reader.
    ///
    /// An ack whose key cannot be decoded is logged and ignored. A failure to
    /// apply the network settings terminates the process.
    private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async {
        // The key is decoded with the RSA private key. The bytes come from the
        // network, so a decode failure must not crash the process.
        let aesKey: Data
        do {
            aesKey = try self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))
        } catch {
            self.logger.log("[SDLContext] registerSuperAck aes_key decode get error: \(error)", level: .error)
            return
        }
        self.logger.log("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count)", level: .info)

        // Store the key before the reader task is started further down.
        self.aesKey = aesKey

        // Tun interface information assigned by the server.
        do {
            let ipAddress = try await self.providerActor.setNetworkSettings(networkAddress: self.config.networkAddress, dnsServer: SDLDNSClientActor.Helper.dnsServer)
            await self.noticeClientActor?.send(data: NoticeMessage.ipAdress(ip: ipAddress))

            self.startReader()
        } catch let err {
            self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error)
            exit(-1)
        }
    }

    /// Handles `registerSuperNak`: forwards the error message as an alert and
    /// terminates the process for `.invalidToken` / `.nodeDisabled`.
    /// Unknown or out-of-range error codes are logged and ignored.
    private func handleRegisterSuperNak(nakPacket: SDLRegisterSuperNak) async {
        let errorMessage = nakPacket.errorMessage
        // `UInt8(exactly:)` instead of `UInt8(_:)`: the latter traps on values above 255.
        guard let rawCode = UInt8(exactly: nakPacket.errorCode),
              let errorCode = SDLNAKErrorCode(rawValue: rawCode) else {
            self.logger.log("[SDLContext] Get a SuperNak message with unknown error code: \(nakPacket.errorCode)", level: .warning)
            return
        }

        await self.noticeClientActor?.send(data: NoticeMessage.alert(alert: errorMessage))

        switch errorCode {
        case .invalidToken, .nodeDisabled:
            exit(-1)
        case .noIpAddress, .networkFault, .internalFault:
            break
        }

        self.logger.log("[SDLContext] Get a SuperNak message exit", level: .warning)
    }

    /// Handles a server-pushed event.
    ///
    /// - `natChanged`: removes the session for the affected MAC.
    /// - `sendRegister`: sends a `register` packet to the peer's NAT address.
    /// - `networkShutdown`: forwards the message as an alert and terminates the process.
    private func handleEvent(event: SDLEvent) async throws {
        switch event {
        case .natChanged(let natChangedEvent):
            let dstMac = natChangedEvent.mac
            self.logger.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info)
            await sessionManager.removeSession(dstMac: dstMac)

        case .sendRegister(let sendRegisterEvent):
            self.logger.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)", level: .debug)
            let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
            guard let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) else {
                self.logger.log("[SDLContext] sendRegisterEvent, invalid nat address: \(address)", level: .warning)
                return
            }

            // Send the register packet.
            var register = SDLRegister()
            register.networkID = self.config.networkAddress.networkId
            register.srcMac = self.config.networkAddress.mac
            register.dstMac = sendRegisterEvent.dstMac

            let payload = try register.serializedData()
            await self.udpHoleActor?.send(type: .register, data: payload, remoteAddress: remoteAddress)

        case .networkShutdown(let shutdownEvent):
            let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message)
            await self.noticeClientActor?.send(data: alertNotice)
            exit(-1)
        }
    }

    /// Handles a peer `register`: if it targets this node's MAC in this
    /// network, replies with `registerAck` and records a session to the sender.
    private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) async throws {
        let networkAddr = config.networkAddress
        self.logger.log("register packet: \(register), network_address: \(networkAddr)", level: .debug)

        // The destination must be the tun interface's address, in the same network.
        guard register.dstMac == networkAddr.mac && register.networkID == networkAddr.networkId else {
            self.logger.log("SDLContext didReadRegister get a invalid packet, because dst_mac not matched: \(register.dstMac)", level: .warning)
            return
        }

        // Reply with the ack packet.
        var registerAck = SDLRegisterAck()
        registerAck.networkID = networkAddr.networkId
        registerAck.srcMac = networkAddr.mac
        registerAck.dstMac = register.srcMac

        let payload = try registerAck.serializedData()
        await self.udpHoleActor?.send(type: .registerAck, data: payload, remoteAddress: remoteAddress)

        // A session to the sender is established here. In complex networks the
        // NAT address looked up through the super-node is not always reliable,
        // so the UDP packet's source address is used as the NAT address.
        let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
        await self.sessionManager.addSession(session: session)
    }

    /// Handles a peer `registerAck`: if it targets this node's MAC in this
    /// network, records a session to the sender.
    private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) async {
        // The destination must be the tun interface's address, in the same network.
        let networkAddr = config.networkAddress
        guard registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId else {
            self.logger.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning)
            return
        }

        let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
        await self.sessionManager.addSession(session: session)
    }

    /// Handles an inbound data packet.
    ///
    /// Packets not addressed to this node's MAC (and neither broadcast nor
    /// multicast) and packets that fail AES decryption are dropped silently.
    /// ARP requests for this node's IP are answered, ARP responses are recorded
    /// in the ARP server, and IPv4 packets destined to this node's IP are
    /// written to the tunnel.
    ///
    /// - Throws: If the decrypted payload cannot be parsed as a `LayerPacket`.
    private func handleData(data: SDLData) async throws {
        let mac = LayerPacket.MacAddress(data: data.dstMac)
        let networkAddr = config.networkAddress

        guard data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast() else {
            return
        }

        guard let decyptedData = try? self.aesCipher.decypt(aesKey: self.aesKey, data: Data(data.data)) else {
            return
        }

        let layerPacket = try LayerPacket(layerData: decyptedData)
        await self.flowTracer.inc(num: decyptedData.count, type: .inbound)

        // The switch is the last statement, so `return` inside a case only
        // leaves the function early.
        switch layerPacket.type {
        case .arp:
            guard let arpPacket = ARPPacket(data: layerPacket.data) else {
                self.logger.log("[SDLContext] get invalid arp packet", level: .debug)
                return
            }

            guard arpPacket.targetIP == networkAddr.ip else {
                self.logger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))", level: .debug)
                return
            }

            switch arpPacket.opcode {
            case .request:
                self.logger.log("[SDLContext] get arp request packet", level: .debug)
                let response = ARPPacket.arpResponse(for: arpPacket, mac: networkAddr.mac, ip: networkAddr.ip)
                await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
            case .response:
                self.logger.log("[SDLContext] get arp response packet", level: .debug)
                await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
            }

        case .ipv4:
            guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == networkAddr.ip else {
                return
            }
            // NOTE(review): 2 is presumably AF_INET — confirm.
            let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
            await self.providerActor.writePackets(packets: [packet])

        default:
            self.logger.log("[SDLContext] get invalid packet", level: .debug)
        }
    }

    // Traffic statistics
//    public func flowReportTask() {
//        Task {
//            // Report once per minute
//            self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect()
//                .sink { _ in
//
Task { // let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset() // await self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum) // } // } // } // } // 开始读取数据, 用单独的线程处理packetFlow private func startReader() { // 停止之前的任务 self.readTask?.cancel() // 开启新的任务 self.readTask = Task(priority: .high) { repeat { let packets = await self.providerActor.readPackets() let ipPackets = packets.compactMap { IPPacket($0) } await self.batchProcessPackets(batchSize: 20, packets: ipPackets) } while true } } // 批量分发ip数据包 private func batchProcessPackets(batchSize: Int, packets: [IPPacket]) async { for startIndex in stride(from: 0, to: packets.count, by: batchSize) { let endIndex = Swift.min(startIndex + batchSize, packets.count) let chunkPackets = packets[startIndex.. String { return SDLUtil.int32ToIp(self) } }