//
//  SDLContext.swift
//  Tun
//
//  Created by 安礼成 on 2024/2/29.
//

import Foundation
import NetworkExtension
import NIOCore

/// Context shared by the tunnel.
///
/// It owns the notice client, DNS client, UDP hole, network monitor and the tun
/// packet reader. It also handles the RSA/AES key exchange with the super node:
/// the AES session key arrives RSA-encrypted in `registerSuperAck`.
actor SDLContextActor {
    nonisolated let config: SDLConfiguration

    // NAT type of the current network; updated by the NAT prober.
    var natType: SDLNATProberActor.NatType = .blocked

    // AES cipher used together with `aesKey` for peer data packets.
    nonisolated let aesCipher: AESCipher
    // AES session key; nil until the super node accepts our registration.
    private var aesKey: Data?

    // RSA cipher; its public key is sent to the super node in registerSuper.
    nonisolated let rsaCipher: RSACipher

    // Dependencies
    private var udpHole: SDLUDPHole?
    private var udpHoleWorkers: [Task<Void, Never>]?

    nonisolated let providerAdapter: SDLTunnelProviderAdapter
    var puncherActor: SDLPuncherActor?

    // DNS client and the task forwarding its replies into the tun interface
    private var dnsClient: SDLDNSClient?
    private var dnsWorker: Task<Void, Never>?

    // NAT type prober
    var proberActor: SDLNATProberActor?

    // Task reading packets from the tun interface
    private var readTask: Task<(), Never>?

    private var sessionManager: SessionManager
    private var arpServer: ArpServer

    // Network status monitor and the task consuming its events
    private var monitor: SDLNetworkMonitor?
    private var monitorWorker: Task<Void, Never>?

    // Internal socket communication with the host app
    private var noticeClient: SDLNoticeClient?

    // Traffic statistics
    nonisolated private let flowTracer = SDLFlowTracer()

    nonisolated private let logger: SDLLogger

    // Long-running internal loops started by `start()`
    private var loopChildWorkers: [Task<Void, Never>] = []

    public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) {
        self.logger = logger
        self.config = config
        self.rsaCipher = rsaCipher
        self.aesCipher = aesCipher
        self.sessionManager = SessionManager()
        self.arpServer = ArpServer(known_macs: [:])
        self.providerAdapter = SDLTunnelProviderAdapter(provider: provider, logger: logger)
    }

    /// Starts the network monitor and one restarting loop each for the notice
    /// client, the DNS client and the UDP hole.
    public func start() {
        self.startMonitor()

        self.loopChildWorkers.append(spawnLoop {
            let noticeClient = try self.startNoticeClient()
            self.logger.log("[SDLContext] noticeClient running!!!!")
            try await noticeClient.waitClose()
            self.logger.log("[SDLContext] noticeClient closed!!!!")
        })

        self.loopChildWorkers.append(spawnLoop {
            let dnsClient = try await self.startDnsClient()
            self.logger.log("[SDLContext] dns running!!!!")
            try await dnsClient.waitClose()
            self.logger.log("[SDLContext] dns closed!!!!")
        })

        self.loopChildWorkers.append(spawnLoop {
            let udpHole = try await self.startUDPHole()
            self.logger.log("[SDLContext] udp running!!!!")
            try await udpHole.waitClose()
            self.logger.log("[SDLContext] udp closed!!!!")
        })
    }

    /// Creates and starts the notice client on `config.noticePort`, then stores it.
    private func startNoticeClient() throws -> SDLNoticeClient {
        let noticeClient = try SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger)
        noticeClient.start()
        self.logger.log("[SDLContext] noticeClient started")
        self.noticeClient = noticeClient

        return noticeClient
    }

    /// Starts a network monitor and a task that logs its path-change events.
    /// Any previous monitor task is cancelled first.
    private func startMonitor() {
        self.monitorWorker?.cancel()
        self.monitorWorker = nil

        let monitor = SDLNetworkMonitor()
        monitor.start()
        self.logger.log("[SDLContext] monitor started")
        self.monitor = monitor

        self.monitorWorker = Task.detached {
            for await event in monitor.eventStream {
                switch event {
                case .changed:
                    // TODO: the NAT type should be re-probed after a path change;
                    // for now only the currently known value is reported.
                    let natType = await self.natType
                    self.logger.log("didNetworkPathChanged, nat type is: \(natType)", level: .info)
                case .unreachable:
                    self.logger.log("didNetworkPathUnreachable", level: .warning)
                }
            }
        }
    }

    /// Starts a DNS client for `config.remoteDnsServer` (port 15353) and a task
    /// that writes every packet it yields into the tun interface as IPv4.
    private func startDnsClient() async throws -> SDLDNSClient {
        self.dnsWorker?.cancel()
        self.dnsWorker = nil

        let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(self.config.remoteDnsServer, port: 15353)
        let dnsClient = try await SDLDNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger)
        try dnsClient.start()
        self.logger.log("[SDLContext] dnsClient started")
        self.dnsClient = dnsClient

        self.dnsWorker = Task.detached {
            for await packet in dnsClient.packetFlow {
                if Task.isCancelled {
                    break
                }
                let nePacket = NEPacket(data: packet, protocolFamily: 2)
                self.providerAdapter.writePackets(packets: [nePacket])
            }
        }

        return dnsClient
    }

    /// Starts the UDP hole, registers with the super node once the channel is
    /// active, and spawns the heartbeat, data and signal worker tasks.
    /// Any workers from a previous hole are cancelled first.
    private func startUDPHole() async throws -> SDLUDPHole {
        self.udpHoleWorkers?.forEach { $0.cancel() }
        self.udpHoleWorkers = nil

        let udpHole = try SDLUDPHole(logger: self.logger)
        try udpHole.start()
        self.logger.log("[SDLContext] udpHole started")
        self.udpHole = udpHole

        await udpHole.channelIsActived()
        await self.handleUDPHoleReady()

        // Heartbeat: send a stun request on every timer tick
        let pingTask = Task.detached {
            let (stream, cont) = AsyncStream.makeStream(of: Void.self)
            let timerStream = SDLAsyncTimerStream()
            timerStream.start(cont)

            for await _ in stream {
                if Task.isCancelled {
                    break
                }
                self.logger.log("[SDLContext] will do stunRequest22")
                await self.sendStunRequest()
                self.logger.log("[SDLContext] will do stunRequest44")
            }
            self.logger.log("[SDLContext] will do stunRequest55")
        }

        // Data packets from peers or the super node
        let dataTask = Task.detached {
            for await data in udpHole.dataStream {
                if Task.isCancelled {
                    break
                }
                try? await self.handleData(data: data)
            }
        }

        // Control signals
        let signalTask = Task.detached {
            for await (remoteAddress, signal) in udpHole.signalStream {
                if Task.isCancelled {
                    break
                }

                switch signal {
                case .registerSuperAck(let registerSuperAck):
                    await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
                case .registerSuperNak(let registerSuperNak):
                    await self.handleRegisterSuperNak(nakPacket: registerSuperNak)
                case .peerInfo(let peerInfo):
                    await self.puncherActor?.handlePeerInfo(peerInfo: peerInfo)
                case .event(let event):
                    try? await self.handleEvent(event: event)
                case .stunProbeReply(let probeReply):
                    await self.proberActor?.handleProbeReply(reply: probeReply)
                case .register(let register):
                    try? await self.handleRegister(remoteAddress: remoteAddress, register: register)
                case .registerAck(let registerAck):
                    await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
                }
            }
        }

        self.udpHoleWorkers = [pingTask, dataTask, signalTask]

        return udpHole
    }

    /// Cancels every task owned by the context: the restart loops, UDP hole
    /// workers, DNS worker, monitor worker and tun reader.
    public func stop() async {
        self.loopChildWorkers.forEach { $0.cancel() }
        self.loopChildWorkers.removeAll()

        self.udpHoleWorkers?.forEach { $0.cancel() }
        self.udpHoleWorkers = nil

        self.dnsWorker?.cancel()
        self.dnsWorker = nil

        self.monitorWorker?.cancel()
        self.monitorWorker = nil

        self.readTask?.cancel()
        self.readTask = nil
    }

    private func setNatType(natType: SDLNATProberActor.NatType) {
        self.natType = natType
    }

    /// Negotiation with the super node once the UDP hole is active.
    ///
    /// Creates the puncher and prober actors and starts NAT probing in the
    /// background. It then sends a `registerSuper` packet carrying our identity,
    /// our network address, our RSA public key and our access token.
    private func handleUDPHoleReady() async {
        guard let udpHole = self.udpHole else {
            return
        }

        self.puncherActor = SDLPuncherActor(udpHole: udpHole, querySocketAddress: config.stunSocketAddress, logger: logger)
        self.proberActor = SDLNATProberActor(udpHole: udpHole, addressArray: self.config.stunProbeSocketAddressArray, logger: self.logger)

        // Probe the NAT type without blocking registration
        Task.detached {
            if let natType = await self.proberActor?.probeNatType() {
                await self.setNatType(natType: natType)
                self.logger.log("[SDLContext] nat_type is: \(natType)")
            }
        }

        // Register with the super node
        var registerSuper = SDLRegisterSuper()
        registerSuper.pktID = 0
        registerSuper.clientID = self.config.clientId
        registerSuper.networkID = self.config.networkAddress.networkId
        registerSuper.mac = self.config.networkAddress.mac
        registerSuper.ip = self.config.networkAddress.ip
        registerSuper.maskLen = UInt32(self.config.networkAddress.maskLen)
        registerSuper.hostname = self.config.hostname
        registerSuper.pubKey = self.rsaCipher.pubKey
        registerSuper.accessToken = self.config.accessToken
        if let registerSuperData = try? registerSuper.serializedData() {
            self.logger.log("[SDLContext] will send register super")
            self.udpHole?.send(type: .registerSuper, data: registerSuperData, remoteAddress: self.config.stunSocketAddress)
        }
    }

    /// Sends a stun request with our identity and current NAT type to the stun/super address.
    private func sendStunRequest() {
        var stunRequest = SDLStunRequest()
        stunRequest.clientID = self.config.clientId
        stunRequest.networkID = self.config.networkAddress.networkId
        stunRequest.ip = self.config.networkAddress.ip
        stunRequest.mac = self.config.networkAddress.mac
        stunRequest.natType = UInt32(self.natType.rawValue)

        self.logger.log("[SDLContext] will send stun request")
        if let stunData = try? stunRequest.serializedData() {
            let remoteAddress = self.config.stunSocketAddress
            self.udpHole?.send(type: .stunRequest, data: stunData, remoteAddress: remoteAddress)
        }
    }

    /// Handles the super node's acceptance of our registration.
    ///
    /// Decrypts the AES session key with our RSA cipher, applies the tun network
    /// settings, reports the assigned IP to the host app and starts the tun
    /// reader. If the key cannot be decoded, the ack is logged and ignored
    /// instead of crashing. If the network settings cannot be applied, the
    /// process exits.
    private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async {
        // The AES key arrives encrypted for our RSA key pair; it is data from the
        // network, so a decoding failure must not crash the extension.
        let aesKey: Data
        do {
            aesKey = try self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))
        } catch {
            self.logger.log("[SDLContext] failed to decode aes_key from registerSuperAck: \(error)", level: .error)
            return
        }
        self.logger.log("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count)", level: .info)

        // Apply the tun interface settings assigned by the server
        do {
            let ipAddress = try await self.providerAdapter.setNetworkSettings(networkAddress: self.config.networkAddress, dnsServer: SDLDNSClient.Helper.dnsServer)
            self.logger.log("[SDLContext] setNetworkSettings successed")

            // Install the key before the reader starts; otherwise routeLayerPacket
            // drops every packet read while aesKey is still nil.
            self.aesKey = aesKey

            self.noticeClient?.send(data: NoticeMessage.ipAdress(ip: ipAddress))
            self.logger.log("[SDLContext] send ip successed")

            self.startReader()
            self.logger.log("[SDLContext] reader started")
        } catch let err {
            self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error)
            exit(-1)
        }
    }

    /// Handles a rejected registration.
    ///
    /// The error message is forwarded to the host app as an alert. For
    /// `.invalidToken` and `.nodeDisabled` the process then exits. Unknown error
    /// codes are ignored.
    private func handleRegisterSuperNak(nakPacket: SDLRegisterSuperNak) {
        let errorMessage = nakPacket.errorMessage
        guard let errorCode = SDLNAKErrorCode(rawValue: UInt8(nakPacket.errorCode)) else {
            return
        }

        switch errorCode {
        case .invalidToken, .nodeDisabled:
            let alertNotice = NoticeMessage.alert(alert: errorMessage)
            self.noticeClient?.send(data: alertNotice)
            exit(-1)
        case .noIpAddress, .networkFault, .internalFault:
            let alertNotice = NoticeMessage.alert(alert: errorMessage)
            self.noticeClient?.send(data: alertNotice)
        }
        self.logger.log("[SDLContext] Get a SuperNak message exit", level: .warning)
    }

    /// Handles events pushed by the super node.
    ///
    /// - `natChanged` drops the session to the peer.
    /// - `sendRegister` sends a register packet to the peer's NAT address.
    /// - `networkShutdown` alerts the host app and exits.
    private func handleEvent(event: SDLEvent) throws {
        switch event {
        case .natChanged(let natChangedEvent):
            let dstMac = natChangedEvent.mac
            self.logger.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info)
            sessionManager.removeSession(dstMac: dstMac)
        case .sendRegister(let sendRegisterEvent):
            self.logger.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)", level: .debug)
            let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
            if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) {
                // Send the register packet to the peer
                var register = SDLRegister()
                register.networkID = self.config.networkAddress.networkId
                register.srcMac = self.config.networkAddress.mac
                register.dstMac = sendRegisterEvent.dstMac

                self.udpHole?.send(type: .register, data: try register.serializedData(), remoteAddress: remoteAddress)
            }
        case .networkShutdown(let shutdownEvent):
            let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message)
            self.noticeClient?.send(data: alertNotice)
            exit(-1)
        }
    }

    /// Handles a register packet from a peer.
    ///
    /// If the packet targets our MAC in our network, this replies with a
    /// registerAck and creates a session to the sender.
    private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws {
        let networkAddr = config.networkAddress
        self.logger.log("register packet: \(register), network_address: \(networkAddr)", level: .debug)

        // Only accept packets addressed to our tun interface within our network
        if register.dstMac == networkAddr.mac && register.networkID == networkAddr.networkId {
            // Reply with an ack
            var registerAck = SDLRegisterAck()
            registerAck.networkID = networkAddr.networkId
            registerAck.srcMac = networkAddr.mac
            registerAck.dstMac = register.srcMac

            self.udpHole?.send(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)
            // Create a session to the sender. Behind complex NATs the address
            // learned from the super node is unreliable, so the UDP source
            // address of this packet is used as the NAT address instead.
            let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
            self.sessionManager.addSession(session: session)
        } else {
            self.logger.log("SDLContext didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)", level: .warning)
        }
    }

    /// Handles a registerAck from a peer.
    ///
    /// If the ack targets our MAC in our network, this creates a session to the
    /// sender's UDP source address.
    private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) {
        // Only accept packets addressed to our tun interface within our network
        let networkAddr = config.networkAddress
        if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId {
            let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
            self.sessionManager.addSession(session: session)
        } else {
            self.logger.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning)
        }
    }

    /// Handles an encrypted layer-2 data packet received over UDP.
    ///
    /// The packet is ignored until an AES key exists, unless it is addressed to
    /// us, broadcast or multicast, and whenever it cannot be decrypted.
    /// - ARP requests for our IP are answered.
    /// - ARP responses populate the ARP cache.
    /// - IPv4 packets addressed to our IP are written into the tun interface.
    private func handleData(data: SDLData) throws {
        guard let aesKey = self.aesKey else {
            return
        }

        let mac = LayerPacket.MacAddress(data: data.dstMac)
        let networkAddr = config.networkAddress
        guard (data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast()) else {
            return
        }

        guard let decyptedData = try? self.aesCipher.decypt(aesKey: aesKey, data: Data(data.data)) else {
            return
        }

        let layerPacket = try LayerPacket(layerData: decyptedData)

        self.flowTracer.inc(num: decyptedData.count, type: .inbound)

        switch layerPacket.type {
        case .arp:
            if let arpPacket = ARPPacket(data: layerPacket.data) {
                if arpPacket.targetIP == networkAddr.ip {
                    switch arpPacket.opcode {
                    case .request:
                        self.logger.log("[SDLContext] get arp request packet", level: .debug)
                        let response = ARPPacket.arpResponse(for: arpPacket, mac: networkAddr.mac, ip: networkAddr.ip)
                        self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
                    case .response:
                        self.logger.log("[SDLContext] get arp response packet", level: .debug)
                        self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
                    }
                } else {
                    self.logger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))", level: .debug)
                }
            } else {
                self.logger.log("[SDLContext] get invalid arp packet", level: .debug)
            }
        case .ipv4:
            guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == networkAddr.ip else {
                return
            }
            let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
            self.providerAdapter.writePackets(packets: [packet])
        default:
            self.logger.log("[SDLContext] get invalid packet", level: .debug)
        }
    }

    // Traffic statistics
//    public func flowReportTask() {
//        Task {
//            // report once per minute
//            self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect()
//                .sink { _ in
//                    Task {
//                        let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset()
//                        await self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
//                    }
//                }
//        }
//    }

    /// Starts reading packets from the tun interface in a separate detached task.
    /// Any previous reader is cancelled first.
    private func startReader() {
        self.readTask?.cancel()

        self.readTask = Task.detached(priority: .high) {
            while true {
                if Task.isCancelled {
                    return
                }

                let packets = await self.providerAdapter.readPackets()
                let ipPackets = packets.compactMap { IPPacket($0) }
                for ipPacket in ipPackets {
                    await self.dealPacket(packet: ipPacket)
                }
            }
        }
    }

    /// Dispatches one IP packet read from the tun interface.
    ///
    /// - DNS requests go to the DNS client.
    /// - Packets addressed to our own IP are looped back into the tun interface.
    /// - Packets with a known destination MAC are routed to that MAC.
    /// - Otherwise an ARP request for the destination IP is broadcast.
    private func dealPacket(packet: IPPacket) {
        let networkAddr = self.config.networkAddress
        if SDLDNSClient.Helper.isDnsRequestPacket(ipPacket: packet) {
            let destIp = packet.header.destination_ip
            self.logger.log("[DNSQuery] destIp: \(destIp), int: \(packet.header.destination.asIpAddress())", level: .debug)
            self.dnsClient?.forward(ipPacket: packet)
            return
        }

        let dstIp = packet.header.destination
        // Local traffic: the destination is our own tun address
        if dstIp == networkAddr.ip {
            let nePacket = NEPacket(data: packet.data, protocolFamily: 2)
            self.providerAdapter.writePackets(packets: [nePacket])
            return
        }

        // Look up the destination MAC in the ARP cache
        if let dstMac = self.arpServer.query(ip: dstIp) {
            self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data)
        } else {
            self.logger.log("[SDLContext] dstIp: \(dstIp.asIpAddress()) arp query not found, broadcast", level: .debug)
            // Broadcast an ARP request for the destination
            let arpReqeust = ARPPacket.arpRequest(senderIP: networkAddr.ip, senderMAC: networkAddr.mac, targetIP: dstIp)
            self.routeLayerPacket(dstMac: ARPPacket.broadcastMac, type: .arp, data: arpReqeust.marshal())
        }
    }

    /// Wraps `data` in an AES-encrypted layer-2 frame and sends it to `dstMac`.
    ///
    /// - Broadcast frames are relayed through the super node.
    /// - Unicast frames use an existing P2P session when one exists.
    /// - Otherwise the frame is relayed through the super node and a
    ///   hole-punching request is submitted.
    ///
    /// The frame is dropped if no AES key exists yet or if encryption or
    /// serialization fails.
    private func routeLayerPacket(dstMac: Data, type: LayerPacket.PacketType, data: Data) {
        let networkAddr = self.config.networkAddress
        // Wrap the payload in a layer-2 frame
        let layerPacket = LayerPacket(dstMac: dstMac, srcMac: networkAddr.mac, type: type, data: data)

        guard let aesKey = self.aesKey, let encodedPacket = try? self.aesCipher.encrypt(aesKey: aesKey, data: layerPacket.marshal()) else {
            return
        }

        var dataPacket = SDLData()
        dataPacket.networkID = networkAddr.networkId
        dataPacket.srcMac = networkAddr.mac
        dataPacket.dstMac = dstMac
        dataPacket.ttl = 255
        dataPacket.data = encodedPacket

        // A serialization failure must not crash the packet tunnel provider
        guard let payload = try? dataPacket.serializedData() else {
            self.logger.log("[SDLContext] failed to serialize data packet", level: .error)
            return
        }

        // Never try hole punching for broadcast addresses
        if ARPPacket.isBroadcastMac(dstMac) {
            // Relay through the super node
            self.udpHole?.send(type: .data, data: payload, remoteAddress: self.config.stunSocketAddress)
        } else {
            if let session = self.sessionManager.getSession(toAddress: dstMac) {
                // Send directly to the peer via the P2P session
                self.logger.log("[SDLContext] send packet by session: \(session)", level: .debug)
                self.udpHole?.send(type: .data, data: payload, remoteAddress: session.natAddress)
                self.flowTracer.inc(num: payload.count, type: .p2p)
            } else {
                // Relay through the super node
                self.udpHole?.send(type: .data, data: payload, remoteAddress: self.config.stunSocketAddress)
                self.flowTracer.inc(num: payload.count, type: .forward)
                // Try to punch a hole so later packets can go P2P
                self.puncherActor?.submitRegisterRequest(request: .init(srcMac: networkAddr.mac, dstMac: dstMac, networkId: networkAddr.networkId))
            }
        }
    }

    /// Runs `body` repeatedly in a detached task until the task is cancelled.
    ///
    /// A `CancellationError` from `body` ends the loop. Any other error is
    /// logged. After every other iteration, whether it failed or the service
    /// simply closed, the loop waits 2 seconds before restarting. This avoids a
    /// hot restart loop when the underlying channel closes immediately.
    private func spawnLoop(_ body: @escaping () async throws -> Void) -> Task<Void, Never> {
        let logger = self.logger
        return Task.detached {
            while !Task.isCancelled {
                do {
                    try await body()
                } catch is CancellationError {
                    break
                } catch {
                    logger.log("[SDLContext] child worker failed: \(error), will retry", level: .warning)
                }
                // A cancellation during the sleep is swallowed here and then
                // observed by the loop condition.
                try? await Task.sleep(nanoseconds: 2_000_000_000)
            }
        }
    }

    deinit {
        self.udpHole = nil
        self.dnsClient = nil
    }
}

private extension UInt32 {
    /// Formats this integer as a dotted IP address string via `SDLUtil.int32ToIp`.
    func asIpAddress() -> String {
        return SDLUtil.int32ToIp(self)
    }
}