//
//  SDLContext.swift
//  Tun
//
//  Created by 安礼成 on 2024/2/29.
//

import Foundation
import NetworkExtension
import NIOCore

// Context environment for the tunnel, shared globally.
/*
 1. Handles the RSA encryption/decryption logic.
 */
/// Actor that owns the tunnel's runtime state and background workers.
///
/// It owns the following components:
/// - the QUIC control connection to `config.serverHost`;
/// - the UDP socket (`SDLUDPHole`) used for STUN and peer traffic;
/// - the cloud and local DNS clients;
/// - the network-path monitor;
/// - the TUN read loop;
/// - the identity/policy store.
///
/// `rsaCipher` decrypts the data key delivered in the register-super ack.
///
/// NOTE(review): this copy of the file appears damaged by a tool that stripped `<...>`
/// sequences as if they were HTML tags. Several `Task` declarations below are missing
/// their generic arguments. A large span near the end (inside `dealTunPacket`) is also
/// missing. Restore from version control before building.
actor SDLContextActor {
    /// Registration state with the server.
    /// Set to `.registered` in `handleRegisterSuperAck` once network settings are applied;
    /// nothing in the visible code sets it back to `.unregistered`.
    enum State {
        case unregistered
        case registered
    }

    private var state: State = .unregistered
    var config: SDLConfiguration
    // NAT type of the local network; included in every STUN request.
    var natType: SDLNATProberActor.NatType = .blocked
    // Data-plane cipher (AES or ChaCha20); only created after the server accepts our registration.
    private var dataCipher: CCDataCipher?
    // Session token from the register-super ack; STUN requests are skipped until it is set.
    private var sessionToken: Data?
    // RSA configuration; the public key is generated locally.
    // Encryption-algorithm related: used to decrypt the data key in the register-super ack.
    nonisolated let rsaCipher: RSACipher
    // Dependencies
    private var udpHole: SDLUDPHole?
    // NOTE(review): `Task` requires generic arguments. Every task stored in these `Task`
    // properties is a non-throwing Void closure, so presumably `Task<Void, Never>` was
    // intended (it looks stripped). Confirm against version control.
    private var udpHoleWorkers: [Task]?
    // Cloud DNS client (dealTunPacket forwards queries for `networkDomain` to it)
    private var dnsClient: DNSCloudClient?
    private var dnsWorker: Task?
    // Local DNS client
    private var dnsLocalClient: DNSLocalClient?
    private var dnsLocalWorker: Task?
    private var quicClient: SDLQUICClient?
    private var quicWorker: Task?
    nonisolated private let puncherActor: SDLPuncherActor
    // NAT prober
    nonisolated private let proberActor: SDLNATProberActor
    // TUN packet read task
    private var readTask: Task<(), Never>?
    nonisolated private let sessionManager = SessionManager()
    nonisolated private let arpServer: ArpServer
    // Network-path change monitor
    private var monitor: SDLNetworkMonitor?
    private var monitorWorker: Task?
    // Internal socket communication (used to push alert notices)
    private var noticeClient: SDLNoticeClient?
    // Traffic statistics
    nonisolated private let flowTracer = SDLFlowTracer()
    // Runs the internal long-running workers (QUIC, notice client, UDP socket)
    private var supervisor = SDLSupervisor()
    nonisolated private let provider: NEPacketTunnelProvider
    // Access control
    private let identifyStore: IdentityStore
    private var updatePolicyTask: Task?
    private let snapshotPublisher: SnapshotPublisher
    // Flow session manager; sessions expire after 180 seconds.
    private let flowSessionManager = SDLFlowSessionManager(sessionTimeout: 180)
    // Register-super retry task; nil when no loop is running.
    private var registerTask: Task?

    /// Creates the context and its long-lived collaborators.
    ///
    /// The clients and worker tasks are created later, by `start()`.
    /// - Parameters:
    ///   - provider: the tunnel provider whose `packetFlow` is read from and written to.
    ///   - config: tunnel configuration (server, network address, exit node, ...).
    ///   - rsaCipher: local RSA key pair used during registration.
    public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher) {
        self.provider = provider
        self.config = config
        self.rsaCipher = rsaCipher
        self.puncherActor = SDLPuncherActor()
        self.proberActor = SDLNATProberActor(addressArray: config.stunProbeSocketAddressArray)
        self.arpServer = ArpServer()
        // Access control: the store and this actor share one snapshot publisher.
        let snapshotPublisher = SnapshotPublisher(initial: IdentitySnapshot.empty())
        self.identifyStore = IdentityStore(publisher: snapshotPublisher)
        self.snapshotPublisher = snapshotPublisher
    }

    /// Starts all long-running components.
    ///
    /// The network monitor, the ARP periodic cleanup task and both DNS clients are started directly.
    /// The QUIC client, notice client and UDP socket are registered with `supervisor` as
    /// named workers. Each worker closure starts its component and then suspends until that
    /// component closes. Presumably `SDLSupervisor` restarts a worker whose closure returns
    /// or throws; confirm in SDLSupervisor.
    public func start() async {
        self.startMonitor()
        // Start the ARP periodic cleanup task
        await self.arpServer.start()
        await self.startDnsClient()
        await self.startDnsLocalClient()
        await self.supervisor.addWorker(name: "quicClient") {
            SDLLogger.log("[SDLContext] try start quicClient", for: .debug)
            let quicClient = try await self.startQUICClient()
            SDLLogger.log("[SDLContext] quicClient running!!!!")
            await quicClient.waitClose()
            SDLLogger.log("[SDLContext] quicClient closed!!!!")
        }
        await self.supervisor.addWorker(name: "noticeClient") {
            let noticeClient = try self.startNoticeClient()
            SDLLogger.log("[SDLContext] noticeClient running!!!!")
            try await noticeClient.waitClose()
            SDLLogger.log("[SDLContext] noticeClient closed!!!!")
        }
        await self.supervisor.addWorker(name: "udpHole") {
            let udpHole = try await self.startUDPHole()
            SDLLogger.log("[SDLContext] udp running!!!!")
            try await udpHole.waitClose()
            SDLLogger.log("[SDLContext] udp closed!!!!")
        }
    }

    /// Sets or clears the exit node, then reapplies the tunnel network settings.
    ///
    /// When the exit node is cancelled, the IP address sent is "0.0.0.0".
    /// - Parameter exitNodeIp: dotted-quad IPv4 string. The exit node is cleared if the string
    ///   fails to parse or parses to 0.
    /// - Throws: whatever `setNetworkSettings(config:dnsServer:)` throws.
    public func updateExitNode(exitNodeIp: String) async throws {
        if let ip = SDLUtil.ipv4StrToInt32(exitNodeIp), ip > 0 {
            self.config.exitNode = .init(exitNodeIp: ip)
        } else {
            self.config.exitNode = nil
        }
        try await self.setNetworkSettings(config: config, dnsServer: DNSHelper.dnsServer)
    }

    /// (Re)creates the QUIC control connection and its message-dispatch task.
    ///
    /// Steps:
    /// 1. Cancel the previous dispatch task and stop the previous client.
    /// 2. Connect to `config.serverHost` on port 443.
    /// 3. Wait for `waitReady()` plus a fixed 0.2 s.
    /// 4. Spawn a detached task that routes every control message to its handler.
    ///
    /// A `.welcome` message starts the register-super loop.
    /// - Returns: the new client, so the supervisor worker can wait for it to close.
    /// - Throws: errors from `waitReady()` or from the sleep (e.g. cancellation).
    private func startQUICClient() async throws -> SDLQUICClient {
        self.quicWorker?.cancel()
        self.quicClient?.stop()
        // Start the QUIC client
        let quicClient = SDLQUICClient(host: self.config.serverHost, port: 443)
        quicClient.start()
        // Wait for QUIC to become ready
        try await quicClient.waitReady()
        // The QUIC negotiation must be complete before we continue.
        // NOTE(review): a fixed 0.2 s delay is a timing heuristic. If `waitReady()` already
        // implies the handshake is done it is redundant; otherwise it is still racy.
        try await Task.sleep(for: .seconds(0.2))
        SDLLogger.log("[SDLContext] start quic client: \(self.config.serverHost)")
        self.quicWorker = Task.detached {
            for await message in quicClient.messageStream {
                switch message {
                case .welcome(let welcome):
                    SDLLogger.log("[SDLContext] quic welcome: \(welcome)")
                    // Register with the server
                    await self.startRegisterLoop()
                case .pong:
                    //SDLLogger.shared.log("[SDLContext] quic pong")
                    ()
                case .registerSuperAck(let registerSuperAck):
                    await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
                case .registerSuperNak(let registerSuperNak):
                    await self.handleRegisterSuperNak(nakPacket: registerSuperNak)
                case .peerInfo(let peerInfo):
                    //SDLLogger.shared.log("[SDLContext] peer message: \(peerInfo)")
                    await self.puncherActor.handlePeerInfo(using: self.udpHole, peerInfo: peerInfo)
                case .event(let event):
                    await self.handleEvent(event: event)
                case .policyReponse(let policyResponse):
                    // Apply the policy response from the server
                    await self.identifyStore.applyPolicyResponse(policyResponse)
                case .arpResponse(let arpResponse):
                    //SDLLogger.shared.log("[SDLContext] get arp response: \(arpResponse)")
                    await self.arpServer.handleArpResponse(arpResponse: arpResponse)
                }
            }
        }
        self.quicClient = quicClient
        return quicClient
    }

    /// Creates and starts the notice client on `config.noticePort` and stores it.
    ///
    /// The notice client is used to push alert messages (see `handleRegisterSuperNak` and
    /// `handleEvent`).
    /// - Returns: the started client, so the supervisor worker can wait for it to close.
    /// - Throws: errors from `SDLNoticeClient(noticePort:)`.
    private func startNoticeClient() throws -> SDLNoticeClient {
        // Start the notice client
        let noticeClient = try SDLNoticeClient(noticePort: self.config.noticePort)
        noticeClient.start()
        SDLLogger.log("[SDLContext] noticeClient started")
        self.noticeClient = noticeClient
        return noticeClient
    }

    /// (Re)starts the network-path monitor and a detached task that consumes its events.
    ///
    /// On `.changed` the NAT type is re-probed. `.unreachable` is only logged.
    private func startMonitor() {
        self.monitorWorker?.cancel()
        self.monitorWorker = nil
        // Start the monitor
        let monitor = SDLNetworkMonitor()
        monitor.start()
        SDLLogger.log("[SDLContext] monitor started")
        self.monitor = monitor
        self.monitorWorker = Task.detached {
            for await event in monitor.eventStream {
                switch event {
                case .changed:
                    // The network changed, so the NAT type must be re-probed.
                    await self.probeNatType()
                    // NOTE(review): this message announces the NAT type but does not print it.
                    SDLLogger.log("didNetworkPathChanged, nat type is:")
                case .unreachable:
                    SDLLogger.log("didNetworkPathUnreachable")
                }
            }
        }
    }

    /// (Re)starts the cloud DNS client and its output task.
    ///
    /// The client connects to `config.serverIp` on port 15353. A detached task writes every
    /// packet it yields to the TUN interface as protocol family 2 (IPv4).
    private func startDnsClient() async {
        self.dnsWorker?.cancel()
        self.dnsWorker = nil
        // Start the DNS service
        let dnsClient = DNSCloudClient(host: self.config.serverIp, port: 15353)
        dnsClient.start()
        SDLLogger.log("[SDLContext] dnsClient started")
        self.dnsClient = dnsClient
        self.dnsWorker = Task.detached {
            // Consume the packet stream
            for await packet in dnsClient.packetFlow {
                if Task.isCancelled {
                    break
                }
                let nePacket = NEPacket(data: packet, protocolFamily: 2)
                self.provider.packetFlow.writePacketObjects([nePacket])
            }
        }
    }

    /// (Re)starts the local DNS client and its output task.
    ///
    /// Every packet the client yields is written to the TUN interface as protocol family 2.
    private func startDnsLocalClient() async {
        self.dnsLocalWorker?.cancel()
        self.dnsLocalWorker = nil
        // Start the DNS service
        let dnsLocalClient = DNSLocalClient()
        dnsLocalClient.start()
        // NOTE(review): same log text as the cloud DNS client; the two cannot be told apart in logs.
        SDLLogger.log("[SDLContext] dnsClient started")
        self.dnsLocalClient = dnsLocalClient
        self.dnsLocalWorker = Task.detached {
            // Consume the packet stream
            for await packet in dnsLocalClient.packetFlow {
                if Task.isCancelled {
                    break
                }
                // TODO (original author): need to find a way to build a complete IP packet here.
                // NOTE(review): dealTunPacket sends this client only the UDP payload. If the
                // client also yields bare payloads, this NEPacket is not a valid IPv4
                // packet. Confirm what `packetFlow` yields.
                let nePacket = NEPacket(data: packet, protocolFamily: 2)
                self.provider.packetFlow.writePacketObjects([nePacket])
            }
        }
    }

    /// (Re)creates the UDP socket and its two worker tasks, then probes the NAT type.
    ///
    /// Worker tasks:
    /// - pingTask: every 5 s calls `sendStunRequest()`, which is a no-op until a session
    ///   token exists.
    /// - messageTask: dispatches incoming UDP messages (STUN probe replies, peer register
    ///   and register-ack, encrypted data).
    /// - Returns: the socket, so the supervisor worker can wait for it to close.
    /// - Throws: errors from creating or starting `SDLUDPHole`.
    private func startUDPHole() async throws -> SDLUDPHole {
        self.udpHoleWorkers?.forEach {$0.cancel()}
        self.udpHoleWorkers = nil
        // Start the UDP server
        let udpHole = try SDLUDPHole()
        try udpHole.start()
        SDLLogger.log("[SDLContext] udpHole started")
        // Local address the UDP socket is bound to
        let localAddress = udpHole.getLocalAddress()
        // Suspend until the UDP channel is active
        await udpHole.channelIsActived()
        // Heartbeat
        let pingTask = Task.detached {
            let timerStream = SDLAsyncTimerStream()
            timerStream.start(interval: .seconds(5))
            for await _ in timerStream.stream {
                if Task.isCancelled {
                    break
                }
                await self.sendStunRequest()
            }
            SDLLogger.log("[SDLContext] udp pingTask cancel")
        }
        // Incoming message stream
        let messageTask = Task.detached {
            for await (remoteAddress, message) in udpHole.messageStream {
                if Task.isCancelled {
                    break
                }
                switch message {
                case .stunProbeReply(let probeReply):
                    await self.proberActor.handleProbeReply(localAddress: localAddress, reply: probeReply)
                case .register(let register):
                    // NOTE(review): errors (ack serialization) are discarded silently here, unlike `.data` below.
                    try? await self.handleRegister(remoteAddress: remoteAddress, register: register)
                case .registerAck(let registerAck):
                    await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
                case .data(let data):
                    do {
                        try await self.handleHoleData(data: data)
                    } catch let err {
                        SDLLogger.log("[SDLContext] handleHoleData get err: \(err)")
                    }
                case .stunReply(_):
                    //SDLLogger.shared.log("[SDLContext] get a stunReply: \(stunReply)")
                    ()
                }
            }
            SDLLogger.log("[SDLContext] udp signalTask cancel")
        }
        self.udpHole = udpHole
        self.udpHoleWorkers = [pingTask, messageTask]
        // Start probing the NAT type
        await self.probeNatType()
        return udpHole
    }

    /// Stops the context.
    ///
    /// Stops the supervisor and cancels every worker task stored on this actor. Only the DNS
    /// clients are stopped explicitly.
    ///
    /// NOTE(review): the QUIC client, notice client, UDP socket and network monitor objects
    /// are not stopped here. Presumably `supervisor.stop()` ends the first three; confirm.
    /// `state`, `sessionToken` and `dataCipher` are not reset.
    public func stop() async {
        await self.supervisor.stop()
        self.udpHoleWorkers?.forEach { $0.cancel() }
        self.udpHoleWorkers = nil
        self.quicWorker?.cancel()
        self.quicWorker = nil
        self.dnsClient?.stop()
        self.dnsWorker?.cancel()
        self.dnsWorker = nil
        self.dnsLocalClient?.stop()
        self.dnsLocalWorker?.cancel()
        self.dnsLocalWorker = nil
        self.monitorWorker?.cancel()
        self.monitorWorker = nil
        self.readTask?.cancel()
        self.readTask = nil
        self.registerTask?.cancel()
        self.registerTask = nil
        self.updatePolicyTask?.cancel()
        self.updatePolicyTask = nil
    }

    /// Setter for `natType`.
    /// NOTE(review): no caller is visible in this copy. Presumably `probeNatType()` uses it;
    /// that function's definition is not visible.
    private func setNatType(natType: SDLNATProberActor.NatType) {
        self.natType = natType
    }

    /// Starts the register-super retry loop, unless one is already running.
    ///
    /// The loop sends a register-super message, waits 5 s, and repeats until `state` is
    /// `.registered`. It then calls `whenRegistedSuper()` and exits. The task is created
    /// with `Task {}` inside an actor method, so it runs on this actor and touches `state`
    /// and `registerTask` directly.
    ///
    /// NOTE(review): issues to check:
    /// - `state` is never reset to `.unregistered` in the visible code. After a QUIC
    ///   reconnect (a new `.welcome`), this loop sends one register-super and exits 5 s
    ///   later without waiting for a fresh ack.
    /// - A cancellation during the sleep is swallowed by `try?`. If `state` is already
    ///   `.registered` at that point, `whenRegistedSuper()` still runs and starts a new
    ///   policy task that nothing cancels.
    private func startRegisterLoop() {
        guard self.registerTask == nil else {
            return
        }
        self.registerTask = Task {
            while !Task.isCancelled {
                self.doRegisterSuper()
                try? await Task.sleep(for: .seconds(5))
                if self.state == .registered {
                    await self.whenRegistedSuper()
                    break
                }
                SDLLogger.log("[SDLContext] register super failed, retry")
            }
            self.registerTask = nil
        }
    }

    /// Callback run once registration with the server succeeds.
    ///
    /// (Re)starts a task that, every 300 s, asks `identifyStore` to batch-refresh policies for
    /// `config.identityId` over the QUIC client. The first refresh happens 300 s after
    /// registration, not immediately.
    /// NOTE(review): because the sleep uses `try?`, cancelling the task wakes it early. It then
    /// performs one more refresh before the loop condition ends it.
    private func whenRegistedSuper() async {
        self.updatePolicyTask?.cancel()
        self.updatePolicyTask = Task {
            while !Task.isCancelled {
                try? await Task.sleep(for: .seconds(300))
                SDLLogger.log("[SDLContext] updatePolicyTask execute")
                await self.identifyStore.batUpdatePolicy(using: self.quicClient, dstIdentityID: self.config.identityId)
            }
        }
    }

    /// Sends one STUN request to `config.stunSocketAddress` over the UDP socket.
    ///
    /// The request carries client id, network id, IP, MAC, NAT type and session token.
    /// Does nothing until `sessionToken` is set, or while the socket is nil. Serialization
    /// failures are dropped silently.
    private func sendStunRequest() {
        guard let sessionToken = self.sessionToken else {
            return
        }
        var stunRequest = SDLStunRequest()
        stunRequest.clientID = self.config.clientId
        stunRequest.networkID = self.config.networkAddress.networkId
        stunRequest.ip = self.config.networkAddress.ip
        stunRequest.mac = self.config.networkAddress.mac
        stunRequest.natType = UInt32(self.natType.rawValue)
        stunRequest.sessionToken = sessionToken
        if let stunData = try? stunRequest.serializedData() {
            let remoteAddress = self.config.stunSocketAddress
            self.udpHole?.send(type: .stunRequest, data: stunData, remoteAddress: remoteAddress)
        }
    }

    /// Handles the server's acceptance of our registration.
    ///
    /// Steps:
    /// 1. Decrypt the data key with the RSA cipher.
    /// 2. Store the session token.
    /// 3. Build the data cipher for the negotiated algorithm ("aes" or "chacha20",
    ///    case-insensitive).
    /// 4. Apply the tunnel network settings.
    /// 5. Mark the context `.registered` and start the TUN reader.
    ///
    /// An undecryptable key, an unknown algorithm or a settings error cancels the tunnel.
    /// NOTE(review): `sessionToken` is stored before the algorithm is validated, so it stays
    /// set even when the ack is rejected.
    private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async {
        // The key must be decrypted with the RSA private key.
        guard let key = try? self.rsaCipher.decode(data: Data(registerSuperAck.key)) else {
            SDLLogger.log("[SDLContext] registerSuperAck invalid key")
            self.provider.cancelTunnelWithError(SDLError.invalidKey)
            return
        }
        let algorithm = registerSuperAck.algorithm.lowercased()
        let regionId = registerSuperAck.regionID
        self.sessionToken = registerSuperAck.sessionToken
        switch algorithm {
        case "aes":
            self.dataCipher = CCAESChiper(key: key)
        case "chacha20":
            self.dataCipher = CCChaCha20Cipher(regionId: regionId, keyData: key)
        default:
            SDLLogger.log("[SDLContext] registerSuperAck invalid algorithm \(algorithm)")
            self.provider.cancelTunnelWithError(SDLError.unsupportedAlgorithm(algorithm: algorithm))
            return
        }
        SDLLogger.log("[SDLContext] registerSuperAck, use algorithm \(algorithm), key len: \(key.count)")
        // Apply the TUN interface settings assigned by the server.
        do {
            try await self.setNetworkSettings(config: self.config, dnsServer: DNSHelper.dnsServer)
            SDLLogger.log("[SDLContext] setNetworkSettings successed")
            self.state = .registered
            self.startReader()
        } catch let err {
            SDLLogger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)")
            self.provider.cancelTunnelWithError(err)
        }
    }

    /// Handles the server rejecting our registration.
    ///
    /// For every known error code, the server's message is forwarded as an alert via the
    /// notice client. `.invalidToken` and `.nodeDisabled` also cancel the tunnel; the other
    /// codes leave it running. Unknown codes are ignored without logging.
    /// NOTE(review): the final log line says "exit" even for the non-fatal codes.
    private func handleRegisterSuperNak(nakPacket: SDLRegisterSuperNak) {
        let errorMessage = nakPacket.errorMessage
        guard let errorCode = SDLNAKErrorCode(rawValue: UInt8(nakPacket.errorCode)) else {
            return
        }
        switch errorCode {
        case .invalidToken, .nodeDisabled:
            let alertNotice = NoticeMessage.alert(alert: errorMessage)
            self.noticeClient?.send(data: alertNotice)
            // Report the error and exit
            let error = NSError(domain: "com.jihe.punchnet.tun", code: -1)
            self.provider.cancelTunnelWithError(error)
        case .noIpAddress, .networkFault, .internalFault:
            let alertNotice = NoticeMessage.alert(alert: errorMessage)
            self.noticeClient?.send(data: alertNotice)
        }
        SDLLogger.log("[SDLContext] Get a SuperNak message exit")
    }

    /// Dispatches an event pushed by the server.
    ///
    /// - natChanged: removes the cached peer session for that MAC.
    /// - sendRegister: sends a register message to the peer's NAT address over UDP.
    /// - shutdown: forwards the server's message as an alert and cancels the tunnel.
    private func handleEvent(event: SDLEvent) async {
        switch event.event {
        case .natChanged(let natChangedEvent):
            let dstMac = natChangedEvent.mac
            SDLLogger.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)")
            sessionManager.removeSession(dstMac: dstMac)
        case .sendRegister(let sendRegisterEvent):
            SDLLogger.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)")
            let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
            if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) {
                // Send the register packet
                var register = SDLRegister()
                register.networkID = self.config.networkAddress.networkId
                register.srcMac = self.config.networkAddress.mac
                register.dstMac = sendRegisterEvent.dstMac
                // NOTE(review): `try!` crashes the extension if serialization ever fails;
                // the other senders in this actor use `try?`.
                self.udpHole?.send(type: .register, data: try! register.serializedData(), remoteAddress: remoteAddress)
            }
        case .shutdown(let shutdownEvent):
            let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message)
            self.noticeClient?.send(data: alertNotice)
            // Report the error and exit
            let error = NSError(domain: "com.jihe.punchnet.tun", code: -2)
            self.provider.cancelTunnelWithError(error)
        case .none:
            ()
        }
    }

    /// Builds and sends one register-super message over QUIC.
    ///
    /// The message carries client and network identity, MAC, IP, mask length, hostname, the
    /// RSA public key and the access token. Nothing is sent if serialization fails or there
    /// is no QUIC client.
    private func doRegisterSuper() {
        // Register
        var registerSuper = SDLRegisterSuper()
        registerSuper.clientID = self.config.clientId
        registerSuper.networkID = self.config.networkAddress.networkId
        registerSuper.mac = self.config.networkAddress.mac
        registerSuper.ip = self.config.networkAddress.ip
        registerSuper.maskLen = UInt32(self.config.networkAddress.maskLen)
        registerSuper.hostname = self.config.hostname
        registerSuper.pubKey = self.rsaCipher.pubKey
        registerSuper.accessToken = self.config.accessToken
        if let registerSuperData = try? registerSuper.serializedData() {
            SDLLogger.log("[SDLContext] will send register super")
            self.quicClient?.send(type: .registerSuper, data: registerSuperData)
        }
    }

    /// Handles a peer's register message received over UDP.
    ///
    /// Only messages addressed to our MAC within our network are accepted. For those, a
    /// register-ack is sent back and a session to the sender's MAC is recorded. The session
    /// uses the UDP source address as the NAT address.
    /// - Throws: serialization errors from building the ack.
    private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws {
        let networkAddr = config.networkAddress
        SDLLogger.log("[SDLContext] register packet: \(register), network_address: \(networkAddr)")
        // Accept only if addressed to our TUN interface's MAC and within the same network.
        if register.dstMac == networkAddr.mac && register.networkID == networkAddr.networkId {
            // Reply with an ack
            var registerAck = SDLRegisterAck()
            registerAck.networkID = networkAddr.networkId
            registerAck.srcMac = networkAddr.mac
            registerAck.dstMac = register.srcMac
            self.udpHole?.send(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)
            // Create a session back to the sender. In complex networks the NAT address learned
            // from the super node is not always reliable, so the UDP packet's source address is
            // used as the NAT address instead.
            let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
            self.sessionManager.addSession(session: session)
        } else {
            // NOTE(review): the message says "dst_ip" but the value logged is the destination MAC.
            SDLLogger.log("[SDLContext] didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)")
        }
    }

    /// Handles a peer's register-ack received over UDP.
    ///
    /// If the ack is addressed to our MAC within our network, a session to the sender's MAC
    /// is recorded. The session uses the UDP source address as the NAT address.
    private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) {
        // Accept only if addressed to our TUN interface's MAC and within the same network.
        let networkAddr = config.networkAddress
        if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId {
            let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
            self.sessionManager.addSession(session: session)
        } else {
            SDLLogger.log("[SDLContext] didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)")
        }
    }

    /// Handles an encrypted data frame received from a peer over UDP.
    ///
    /// Frames are dropped if they arrive before a data cipher exists, or if they are
    /// addressed to a unicast MAC other than ours. Otherwise:
    /// - The payload is decrypted into a layer-2 frame and counted as inbound traffic.
    /// - An ARP request targeting our IP is answered.
    /// - An ARP reply targeting our IP is recorded via `arpServer.append`.
    /// - An IPv4 packet is written to the TUN interface.
    /// - Throws: decryption or frame-parsing errors.
    private func handleHoleData(data: SDLData) async throws {
        guard let dataCipher = self.dataCipher else {
            return
        }
        let mac = LayerPacket.MacAddress(data: data.dstMac)
        let networkAddr = config.networkAddress
        guard (data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast()) else {
            return
        }
        let decyptedData = try dataCipher.decrypt(cipherText: Data(data.data))
        let layerPacket = try LayerPacket(layerData: decyptedData)
        self.flowTracer.inc(num: decyptedData.count, type: .inbound)
        // Dispatch on the frame type (ARP handling first)
        switch layerPacket.type {
        case .arp:
            // Check whether the received frame is a valid ARP packet
            if let arpPacket = ARPPacket(data: layerPacket.data) {
                if arpPacket.targetIP == networkAddr.ip {
                    switch arpPacket.opcode {
                    case .request:
                        SDLLogger.log("[SDLContext] get arp request packet")
                        let response = ARPPacket.arpResponse(for: arpPacket, mac: networkAddr.mac, ip: networkAddr.ip)
                        await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
                    case .response:
                        SDLLogger.log("[SDLContext] get arp response packet")
                        await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
                    }
                } else {
                    SDLLogger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))")
                }
            } else {
                SDLLogger.log("[SDLContext] get invalid arp packet")
            }
        case .ipv4:
            // Some traffic is forwarded through the exit gateway, so only check that this is a
            // valid IP packet (the destination IP is deliberately not checked).
            guard let ipPacket = IPPacket(layerPacket.data) else {
                return
            }
            // Policy check
            let identitySnapshot = self.snapshotPublisher.current()
            let ruleMap = identitySnapshot.lookup(data.identityID)
            // NOTE(review): `true ||` short-circuits, so `checkPolicy` is never evaluated.
            // Every IPv4 packet is accepted, and the policy-request branch below is unreachable.
            if true || self.checkPolicy(ipPacket: ipPacket, ruleMap: ruleMap) {
                // Debug aid: a hard-coded source address.
                // 168428037 == 0x0A0A0205, i.e. 10.10.2.5 if stored in host order; confirm.
                if ipPacket.header.source == 168428037 {
                    SDLLogger.log("[SDLContext] hole data: \(Array(ipPacket.data)), len: \(ipPacket.data.count)", for: .trace)
                }
                let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
                self.provider.packetFlow.writePacketObjects([packet])
                SDLLogger.log("[SDLContext] hole identity: \(data.identityID), allow, data count: \(ipPacket.data.count)", for: .trace)
            } else {
                SDLLogger.log("[SDLContext] not found identity: \(data.identityID) ruleMap", for: .debug)
                // Ask the server for this identity's policy
                await self.identifyStore.policyRequest(srcIdentityId: data.identityID, dstIdentityId: self.config.identityId, using: self.quicClient)
            }
        default:
            SDLLogger.log("[SDLContext] get invalid packet", for: .debug)
        }
    }

    /// Decides whether an inbound IPv4 packet is allowed.
    ///
    /// Returns true in these cases:
    /// - The reversed flow already has a session; that session is refreshed.
    /// - A TCP or UDP packet whose destination port `ruleMap` allows for its protocol.
    /// - Any ICMP packet.
    ///
    /// Returns false for everything else, including when `ruleMap` is nil.
    /// NOTE(review): this function never creates flow sessions. Presumably outbound traffic
    /// does that, in code not visible here.
    private func checkPolicy(ipPacket: IPPacket, ruleMap: IdentityRuleMap?) -> Bool {
        // Reverse the inbound flow and look it up. Presumably this allows replies to flows
        // that this node opened.
        if let reverseFlowSession = ipPacket.flowSession()?.reverse(), self.flowSessionManager.hasSession(reverseFlowSession) {
            self.flowSessionManager.updateSession(reverseFlowSession)
            return true
        }
        // Policy check
        let proto = ipPacket.header.proto
        // Then check the access rules
        switch ipPacket.transportPacket {
        case .tcp(let tcpPacket):
            if let ruleMap, ruleMap.isAllow(proto: proto, port: tcpPacket.header.dstPort) {
                return true
            }
        case .udp(let udpPacket):
            if let ruleMap, ruleMap.isAllow(proto: proto, port: udpPacket.dstPort) {
                return true
            }
        case .icmp(_):
            return true
        default:
            return false
        }
        return false
    }

    // Traffic statistics (disabled).
    // NOTE(review): this dead code references `flowTracerCancel` and `superClient`, which are
    // not declared in this actor.
    // public func flowReportTask() {
    //     Task {
    //         // Report once per minute
    //         self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect()
    //             .sink { _ in
    //                 Task {
    //                     let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset()
    //                     await self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
    //                 }
    //             }
    //     }
    // }

    /// Starts reading packets from the TUN interface on a separate detached, high-priority task.
    ///
    /// Only packets tagged with protocol family 2 (the same value used when writing IPv4
    /// packets above) that parse as `IPPacket` are passed to `dealTunPacket`; all other
    /// packets are dropped. Cancellation is checked once per batch read.
    private func startReader() {
        // Stop any previous read task
        self.readTask?.cancel()
        // Start a new read task
        self.readTask = Task.detached(priority: .high) {
            while true {
                if Task.isCancelled {
                    return
                }
                let (packets, numbers) = await self.provider.packetFlow.readPackets()
                for (data, number) in zip(packets, numbers) where number == 2 {
                    SDLLogger.log("[SDLContext] read Tun packet step 1, data count: \(data.count)", for: .trace)
                    if let ipPacket = IPPacket(data) {
                        SDLLogger.log("[SDLContext] read Tun packet step 2, data count: \(ipPacket.data.count)", for: .trace)
                        await self.dealTunPacket(packet: ipPacket)
                    }
                }
            }
        }
    }

    /// First-level routing for each packet read from the TUN interface.
    ///
    /// - Destination is our own IP: written straight back to the TUN interface.
    /// - DNS request for `networkDomain`: the whole IP packet goes to the cloud DNS client.
    /// - Other DNS request with an exit node configured: routed to the exit node's MAC. If
    ///   that MAC is unknown, an ARP request is sent and this packet is dropped.
    /// - Other DNS request otherwise: per the original comment, the UDP payload goes to the
    ///   local DNS client.
    ///
    /// NOTE(review): the rest of this function is missing in this copy (see the note inside),
    /// so the handling of non-DNS packets cannot be documented here.
    private func dealTunPacket(packet: IPPacket) async {
        let networkAddr = self.config.networkAddress
        let dstIp = packet.header.destination
        // Local traffic: the destination is this node's own IP.
        if dstIp == networkAddr.ip {
            let nePacket = NEPacket(data: packet.data, protocolFamily: 2)
            self.provider.packetFlow.writePacketObjects([nePacket])
        }
        // DNS resolution
        else if DNSHelper.isDnsRequestPacket(ipPacket: packet) {
            if case .udp(let udpPacket) = packet.transportPacket {
                // The DNS message is parsed at the UDP payload offset; DNS queries here are always UDP.
                let payloadOffset = udpPacket.payloadOffset
                let dnsParser = DNSParser(data: packet.data, offset: payloadOffset)
                if let dnsMessage = dnsParser.parse(), let name = dnsMessage.questions.first?.name {
                    // Internal domain: forward the whole IP packet to the cloud DNS server.
                    if name.contains(self.config.networkAddress.networkDomain) {
                        SDLLogger.log("[SDLContext] get cloud dns request: \(name)")
                        self.dnsClient?.forward(ipPacketData: packet.data)
                    }
                    // If an exit node is enabled, forward to the exit node.
                    else if let exitNode = config.exitNode {
                        let exitNodeIp = exitNode.exitNodeIp
                        SDLLogger.log("[SDLContext] dstIp: \(packet.header.destination.asIpAddress()), use exit_node: \(exitNodeIp.asIpAddress())")
                        // Look up the exit node's MAC in the ARP cache.
                        if let dstMac = await self.arpServer.query(ip: exitNodeIp) {
                            await self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data)
                        } else {
                            try? await self.arpServer.arpRequest(targetIp: exitNodeIp, use: self.quicClient)
                        }
                    }
                    // Resolve via local DNS; only the UDP payload is sent.
                    else {
                        SDLLogger.log("[SDLContext] get local dns request: \(name)")
                        // NOTE(review): THE SOURCE IS CORRUPTED ON THE NEXT LINE.
                        // Everything between `payloadOffset..` and `Void) -> Task {` was lost. It
                        // looks like a `<...>`-stripping artifact starting at `..<`.
                        // The missing span presumably held:
                        //   - the rest of dealTunPacket;
                        //   - the definitions of routeLayerPacket(dstMac:type:data:),
                        //     setNetworkSettings(config:dnsServer:) and probeNatType(), which are
                        //     called above but not defined in the visible text;
                        //   - the signature of the retry helper whose body follows.
                        // Restore this region from version control; it does not compile as-is.
                        let dnsPayload = Data(packet.data[payloadOffset.. Void) -> Task {
        // Body of a helper whose signature was lost above; it presumably takes an
        // `@escaping () async throws -> Void` closure named `body`.
        // Re-runs `body` until the task is cancelled. A CancellationError ends the loop; any
        // other error is retried after 2 seconds.
        // NOTE(review): when `body` returns normally it is re-run immediately, with no delay.
        return Task.detached {
            while !Task.isCancelled {
                do {
                    try await body()
                } catch is CancellationError {
                    break
                } catch {
                    try? await Task.sleep(nanoseconds: 2_000_000_000)
                }
            }
        }
    }

    /// Builds /32 routes for IPv4 DNS servers, to be excluded from the tunnel.
    ///
    /// The list contains the system's DNS servers (entries containing ":" are treated as IPv6
    /// and skipped), then a fixed set of public resolvers. Duplicates are removed and order
    /// is preserved.
    private func getIpv4ExcludeRoutes() -> [NEIPv4Route] {
        // Routes to exclude
        let dnsServers = SDLUtil.getMacOSSystemDnsServers()
        var ipv4DnsServers = dnsServers.filter {!$0.contains(":")}
        // Add common public DNS servers
        let commonDnsServers = [
            "8.8.8.8",
            "8.8.4.4",
            "223.5.5.5",
            "223.6.6.6",
            "114.114.114.114"
        ]
        for ip in commonDnsServers {
            if !ipv4DnsServers.contains(ip) {
                ipv4DnsServers.append(ip)
            }
        }
        return ipv4DnsServers.map {
            NEIPv4Route(destinationAddress: $0, subnetMask: "255.255.255.255")
        }
    }

    // NOTE(review): assigning nil here is redundant, because stored properties are released
    // on deinit anyway, and no tasks are cancelled here. The detached worker closures above
    // capture `self` strongly, so deinit presumably cannot run while those tasks are alive.
    // Call `stop()` to tear down.
    deinit {
        self.udpHole = nil
        self.dnsClient = nil
    }
}

private extension UInt32 {
    /// Converts this value to an IP address string via `SDLUtil.int32ToIp`.
    func asIpAddress() -> String {
        return SDLUtil.int32ToIp(self)
    }
}