//
//  SDLContext.swift
//  Tun
//
//  Created by 安礼成 on 2024/2/29.
//
import Foundation
import NetworkExtension
import NIOCore
import Combine

// Context shared by the whole tunnel (global, shared state).
/*
 Responsibilities:
 1. Handles the RSA decryption logic (decoding the AES key delivered by the super node).
 */
@available(macOS 14, *)
public class SDLContextNew {

    // Route information
    struct Route {
        let dstAddress: String
        let subnetMask: String

        var debugInfo: String {
            return "\(dstAddress):\(subnetMask)"
        }
    }

    let config: SDLConfiguration

    // Network address information of the tun device (replaced by the super node's assignment).
    var devAddr: SDLDevAddr

    // NAT mapping information, currently unused
    //var natAddress: SDLNatAddress?

    // NAT type of the current network
    var natType: SDLNatProber.NatType = .blocked

    // AES cipher used for the data channel (injected through init).
    var aesCipher: AESCipher

    // AES key decoded from the super node's registerSuperAck / changeNetwork command.
    var aesKey: Data = Data()

    // RSA configuration; the public key is generated locally.
    let rsaCipher: RSACipher

    // Dependencies
    var udpHoleActor: SDLUDPHoleActor?
    var superClientActor: SDLSuperClientActor?
    var providerActor: SDLTunnelProviderActor

    // DNS client
    var dnsClient: DNSClient?

    // Task that reads packets from the tun interface
    private var readTask: Task<(), Never>?

    private var sessionManager: SessionManager
    private var arpServer: ArpServer

    // Cookie of the most recently sent stunRequest
    private var lastCookie: UInt32? = 0

    // Monitors network path changes
    private var monitor: SDLNetworkMonitor?

    // Internal socket used to notify the host app
    private var noticeClient: SDLNoticeClient?

    // Traffic statistics
    private var flowTracer = SDLFlowTracerActor()
    private var flowTracerCancel: AnyCancellable?

    // Hole punching: one throttled publisher per destination MAC, guarded by `locker`.
    private var holerPublishers: [Data: PassthroughSubject<RegisterRequest, Never>] = [:]
    private var bag = Set<AnyCancellable>()
    private var locker = NSLock()

    private let logger: SDLLogger

    private var rootTask: Task<Void, Error>?

    // A request to punch a hole from `srcMac` to `dstMac` inside network `networkId`.
    struct RegisterRequest {
        let srcMac: Data
        let dstMac: Data
        let networkId: UInt32
    }

    /// Creates the context.
    /// - Parameters:
    ///   - provider: the packet tunnel provider whose packet flow / settings are driven.
    ///   - config: static configuration (hosts, ports, credentials).
    ///   - rsaCipher: RSA cipher used to decode AES keys sent by the super node.
    ///   - aesCipher: AES cipher used for the data channel.
    ///   - logger: logger shared by every sub-component.
    public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) {
        self.logger = logger
        self.config = config
        self.rsaCipher = rsaCipher
        self.aesCipher = aesCipher

        // Generate (or load the persisted) MAC address
        var devAddr = SDLDevAddr()
        devAddr.mac = Self.getMacAddress()
        self.devAddr = devAddr

        self.sessionManager = SessionManager()
        self.arpServer = ArpServer(known_macs: [:])
        self.providerActor = SDLTunnelProviderActor(provider: provider, logger: logger)
    }

    /// Starts every long-running sub-component and suspends until the root task finishes.
    ///
    /// The DNS client, UDP hole, super client and notice client each run in their own child task
    /// and are restarted two seconds after they throw. The network monitor is started once.
    /// Cancelling the root task (see `stop()`) cancels all children.
    /// - Throws: any error escaping the root task group.
    public func start() async throws {
        self.rootTask = Task {
            try await withThrowingTaskGroup(of: Void.self) { group in
                group.addTask {
                    while !Task.isCancelled {
                        do {
                            try await self.startDnsClient()
                        } catch let err {
                            self.logger.log("[SDLContext] dnsClient get err: \(err)", level: .warning)
                            try await Task.sleep(for: .seconds(2))
                        }
                    }
                }

                group.addTask {
                    while !Task.isCancelled {
                        do {
                            try await self.startUDPHole()
                        } catch let err {
                            self.logger.log("[SDLContext] UDPHole get err: \(err)", level: .warning)
                            try await Task.sleep(for: .seconds(2))
                        }
                    }
                }

                group.addTask {
                    while !Task.isCancelled {
                        do {
                            try await self.startSuperClient()
                        } catch let err {
                            self.logger.log("[SDLContext] SuperClient get error: \(err), will restart", level: .warning)
                            // Learned MAC addresses may be stale after a reconnect.
                            await self.arpServer.clear()
                            try await Task.sleep(for: .seconds(2))
                        }
                    }
                }

                group.addTask {
                    await self.startMonitor()
                }

                group.addTask {
                    while !Task.isCancelled {
                        do {
                            try await self.startNoticeClient()
                        } catch let err {
                            self.logger.log("[SDLContext] noticeClient get err: \(err)", level: .warning)
                            try await Task.sleep(for: .seconds(2))
                        }
                    }
                }

                try await group.waitForAll()
            }
        }

        try await self.rootTask?.value
    }

    /// Cancels the root task and the packet reader and releases the network clients.
    public func stop() async {
        self.rootTask?.cancel()
        self.superClientActor = nil
        self.udpHoleActor = nil
        self.noticeClient = nil
        self.dnsClient = nil

        self.readTask?.cancel()
    }

    // Creates the notice client and runs it until it returns or throws.
    private func startNoticeClient() async throws {
        self.noticeClient = try await SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger)
        try await self.noticeClient?.start()
        self.logger.log("[SDLContext] notice_client task cancel", level: .warning)
    }

    // Creates the UDP hole actor and runs three children until the first one finishes:
    // the actor itself, a periodic (5 s) STUN request sender, and the UDP event consumer.
    private func startUDPHole() async throws {
        self.udpHoleActor = try await SDLUDPHoleActor(logger: self.logger)

        try await withThrowingTaskGroup(of: Void.self) { group in
            group.addTask {
                try await self.udpHoleActor?.start()
            }

            group.addTask {
                while !Task.isCancelled {
                    try Task.checkCancellation()
                    try await Task.sleep(nanoseconds: 5 * 1_000_000_000)
                    try Task.checkCancellation()

                    if let udpHoleActor = self.udpHoleActor {
                        let cookie = await udpHoleActor.getCookieId()
                        var stunRequest = SDLStunRequest()
                        stunRequest.cookie = cookie
                        stunRequest.clientID = self.config.clientId
                        stunRequest.networkID = self.devAddr.networkID
                        stunRequest.ip = self.devAddr.netAddr
                        stunRequest.mac = self.devAddr.mac
                        stunRequest.natType = UInt32(self.natType.rawValue)

                        let payload = try stunRequest.serializedData()
                        let remoteAddress = self.config.stunSocketAddress
                        // Record the cookie before sending so a fast stunReply is not compared
                        // against the previous request's cookie.
                        self.lastCookie = cookie
                        await udpHoleActor.send(type: .stunRequest, data: payload, remoteAddress: remoteAddress)
                    }
                }
            }

            group.addTask {
                if let eventFlow = self.udpHoleActor?.eventFlow {
                    for try await event in eventFlow {
                        try Task.checkCancellation()
                        try await self.handleUDPEvent(event: event)
                    }
                }
            }

            if let _ = try await group.next() {
                group.cancelAll()
            }
        }
    }

    // Connects to the super node and runs the client plus its event consumer until the first finishes.
    private func startSuperClient() async throws {
        self.superClientActor = try await SDLSuperClientActor(host: self.config.superHost, port: self.config.superPort, logger: self.logger)

        try await withThrowingTaskGroup(of: Void.self) { group in
            defer {
                self.logger.log("[SDLContext] super client task cancel", level: .warning)
            }

            group.addTask {
                try await self.superClientActor?.start()
            }

            group.addTask {
                if let eventFlow = self.superClientActor?.eventFlow {
                    for try await event in eventFlow {
                        try await self.handleSuperEvent(event: event)
                    }
                }
            }

            if let _ = try await group.next() {
                group.cancelAll()
            }
        }
    }

    // Logs network path changes reported by the monitor.
    private func startMonitor() async {
        self.monitor = SDLNetworkMonitor()
        for await event in self.monitor!.eventStream {
            switch event {
            case .changed:
                // TODO: re-probe the NAT type of the network
                //self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config, logger: self.logger)
                self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info)
            case .unreachable:
                self.logger.log("didNetworkPathUnreachable", level: .warning)
            }
        }
    }

    // Starts the DNS client and writes every reply it produces back into the tun interface.
    private func startDnsClient() async throws {
        let remoteDnsServer = config.remoteDnsServer
        let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(remoteDnsServer, port: 15353)
        self.dnsClient = try await DNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger)

        try await withThrowingTaskGroup(of: Void.self) { group in
            defer {
                self.logger.log("[SDLContext] dns client task cancel", level: .warning)
            }

            group.addTask {
                try await self.dnsClient?.start()
            }

            group.addTask {
                if let packetFlow = self.dnsClient?.packetFlow {
                    for await packet in packetFlow {
                        // protocolFamily 2 == AF_INET
                        let nePacket = NEPacket(data: packet, protocolFamily: 2)
                        await self.providerActor.writePackets(packets: [nePacket])
                    }
                }
            }

            if let _ = try await group.next() {
                group.cancelAll()
            }
        }
    }

    // Handles an event from the super node connection.
    // Throws (which restarts the super client) if a request, serialization or AES-key decode fails.
    private func handleSuperEvent(event: SDLSuperClientActor.SuperEvent) async throws {
        switch event {
        case .ready:
            self.logger.log("[SDLContext] get registerSuper, mac address: \(SDLUtil.formatMacAddress(mac: self.devAddr.mac))", level: .debug)
            var registerSuper = SDLRegisterSuper()
            registerSuper.version = UInt32(self.config.version)
            registerSuper.clientID = self.config.clientId
            registerSuper.devAddr = self.devAddr
            registerSuper.pubKey = self.rsaCipher.pubKey
            registerSuper.token = self.config.token
            registerSuper.networkCode = self.config.networkCode
            registerSuper.hostname = self.config.hostname

            guard let message = try await self.superClientActor?.request(type: .registerSuper, data: try registerSuper.serializedData()).get() else {
                return
            }

            switch message.packet {
            case .registerSuperAck(let registerSuperAck):
                // The AES key must be decoded with the RSA private key. A malformed key throws
                // (restarting the super client) instead of crashing the extension.
                let aesKey = try self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))
                let upgradeType = SDLUpgradeType(rawValue: registerSuperAck.upgradeType)

                self.logger.log("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count), network_id:\(registerSuperAck.devAddr.networkID)", level: .info)
                self.devAddr = registerSuperAck.devAddr
                // Install the key before the reader starts, so no packet is sealed/opened with the stale key.
                self.aesKey = aesKey

                if upgradeType == .force {
                    let forceUpgrade = NoticeMessage.upgrade(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress)
                    await self.noticeClient?.send(data: forceUpgrade)
                    exit(-1)
                }

                // Apply the tun interface settings assigned by the server
                do {
                    try await self.providerActor.setNetworkSettings(devAddr: self.devAddr, dnsServer: DNSClient.Helper.dnsServer)
                    self.startReader()
                } catch let err {
                    self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error)
                    exit(-1)
                }

                if upgradeType == .normal {
                    let normalUpgrade = NoticeMessage.upgrade(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress)
                    await self.noticeClient?.send(data: normalUpgrade)
                }

            case .registerSuperNak(let nakPacket):
                let errorMessage = nakPacket.errorMessage
                guard let errorCode = SDLNAKErrorCode(rawValue: UInt8(nakPacket.errorCode)) else {
                    return
                }

                switch errorCode {
                case .invalidToken, .nodeDisabled:
                    let alertNotice = NoticeMessage.alert(alert: errorMessage)
                    await self.noticeClient?.send(data: alertNotice)
                    exit(-1)
                case .noIpAddress, .networkFault, .internalFault:
                    let alertNotice = NoticeMessage.alert(alert: errorMessage)
                    await self.noticeClient?.send(data: alertNotice)
                }
                self.logger.log("[SDLContext] Get a SuperNak message exit", level: .warning)

            default:
                ()
            }

        case .event(let evt):
            switch evt {
            case .natChanged(let natChangedEvent):
                let dstMac = natChangedEvent.mac
                self.logger.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info)
                await sessionManager.removeSession(dstMac: dstMac)

            case .sendRegister(let sendRegisterEvent):
                self.logger.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)", level: .debug)
                let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
                if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) {
                    // Send the register packet
                    var register = SDLRegister()
                    register.networkID = self.devAddr.networkID
                    register.srcMac = self.devAddr.mac
                    register.dstMac = sendRegisterEvent.dstMac
                    await self.udpHoleActor?.send(type: .register, data: try register.serializedData(), remoteAddress: remoteAddress)
                }

            case .networkShutdown(let shutdownEvent):
                let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message)
                await self.noticeClient?.send(data: alertNotice)
                exit(-1)
            }

        case .command(let packetId, let command):
            switch command {
            case .changeNetwork(let changeNetworkCommand):
                // The AES key must be decoded with the RSA private key.
                let aesKey = try self.rsaCipher.decode(data: Data(changeNetworkCommand.aesKey))
                self.logger.log("[SDLContext] change network command get aes_key len: \(aesKey.count)", level: .info)
                self.devAddr = changeNetworkCommand.devAddr
                // Install the key before the reader restarts.
                self.aesKey = aesKey

                // Apply the tun interface settings assigned by the server
                do {
                    try await self.providerActor.setNetworkSettings(devAddr: self.devAddr, dnsServer: DNSClient.Helper.dnsServer)
                    self.startReader()
                } catch let err {
                    self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error)
                    exit(-1)
                }

                var commandAck = SDLCommandAck()
                commandAck.status = true

                await self.superClientActor?.send(type: .commandAck, packetId: packetId, data: try commandAck.serializedData())
            }
        }
    }

    // Handles an event from the UDP hole socket: register handshakes, STUN replies and data packets.
    private func handleUDPEvent(event: SDLUDPHoleActor.UDPEvent) async throws {
        switch event {
        case .ready:
            // Detect the NAT type of the current network
            //self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config)
            self.logger.log("[SDLContext] nat type is: \(self.natType)", level: .debug)

        case .message(let remoteAddress, let message):
            switch message {
            case .register(let register):
                self.logger.log("register packet: \(register), dev_addr: \(self.devAddr)", level: .debug)
                // Accept only if addressed to our tun MAC and within the same network
                if register.dstMac == self.devAddr.mac && register.networkID == self.devAddr.networkID {
                    // Reply with an ack
                    var registerAck = SDLRegisterAck()
                    registerAck.networkID = self.devAddr.networkID
                    registerAck.srcMac = self.devAddr.mac
                    registerAck.dstMac = register.srcMac

                    await self.udpHoleActor?.send(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)
                    // Create a session to the sender. On complex networks the NAT address obtained from the
                    // super node is not always reliable, so the UDP packet's source address is used instead.
                    let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
                    await self.sessionManager.addSession(session: session)
                } else {
                    self.logger.log("SDLContext didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)", level: .warning)
                }

            case .registerAck(let registerAck):
                // Accept only if addressed to our tun MAC and within the same network
                if registerAck.dstMac == self.devAddr.mac && registerAck.networkID == self.devAddr.networkID {
                    let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
                    await self.sessionManager.addSession(session: session)
                } else {
                    self.logger.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning)
                }

            case .stunReply(let stunReply):
                let cookie = stunReply.cookie
                if cookie == self.lastCookie {
                    // Record the current NAT mapping; unused for now, intended for NAT type detection later.
                    //self.natAddress = stunReply.natAddress
                    self.logger.log("[SDLContext] get a stunReply: \(try! stunReply.jsonString())", level: .debug)
                }

            default:
                ()
            }

        case .data(let data):
            let mac = LayerPacket.MacAddress(data: data.dstMac)
            guard (data.dstMac == self.devAddr.mac || mac.isBroadcast() || mac.isMulticast()) else {
                return
            }

            guard let decyptedData = try? self.aesCipher.decypt(aesKey: self.aesKey, data: Data(data.data)) else {
                return
            }

            do {
                let layerPacket = try LayerPacket(layerData: decyptedData)

                await self.flowTracer.inc(num: decyptedData.count, type: .inbound)

                switch layerPacket.type {
                case .arp:
                    // ARP packet addressed to our tun IP
                    if let arpPacket = ARPPacket(data: layerPacket.data) {
                        if arpPacket.targetIP == self.devAddr.netAddr {
                            switch arpPacket.opcode {
                            case .request:
                                self.logger.log("[SDLContext] get arp request packet", level: .debug)
                                let response = ARPPacket.arpResponse(for: arpPacket, mac: self.devAddr.mac, ip: self.devAddr.netAddr)
                                await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
                            case .response:
                                self.logger.log("[SDLContext] get arp response packet", level: .debug)
                                await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
                            }
                        } else {
                            self.logger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(self.devAddr.netAddr))", level: .debug)
                        }
                    } else {
                        self.logger.log("[SDLContext] get invalid arp packet", level: .debug)
                    }

                case .ipv4:
                    guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == self.devAddr.netAddr else {
                        return
                    }
                    let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
                    await self.providerActor.writePackets(packets: [packet])

                default:
                    self.logger.log("[SDLContext] get invalid packet", level: .debug)
                }
            } catch let err {
                self.logger.log("[SDLContext] didReadData err: \(err)", level: .warning)
            }
        }
    }

    // Traffic statistics
    //    public func flowReportTask() {
    //        Task {
    //            // Report once per minute
    //            self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect()
    //                .sink { _ in
    //                    Task {
    //                        let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset()
    //                        await self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
    //                    }
    //                }
    //        }
    //    }

    // Starts reading packets from the tun interface in a dedicated task, replacing any previous reader.
    private func startReader() {
        // Stop the previous reader. Its loop checks `Task.isCancelled`, so it exits after
        // its in-flight batch instead of running forever alongside the new reader.
        self.readTask?.cancel()

        // Start a new reader
        self.readTask = Task(priority: .high) {
            while !Task.isCancelled {
                let packets = await self.providerActor.readPackets()
                for packet in packets {
                    await self.dealPacket(data: packet)
                }
            }
        }
    }

    // Handles one packet read from the tun interface:
    // DNS requests go to the DNS client; packets to our own IP loop back into the tun;
    // everything else is routed by MAC, issuing an ARP broadcast when the MAC is unknown.
    private func dealPacket(data: Data) async {
        guard let packet = IPPacket(data) else {
            return
        }

        if DNSClient.Helper.isDnsRequestPacket(ipPacket: packet) {
            let destIp = packet.header.destination_ip
            self.logger.log("[SDLContext] destIp: \(destIp), int: \(packet.header.destination)", level: .debug)
            await self.dnsClient?.forward(ipPacket: packet)
        } else {
            Task.detached {
                let dstIp = packet.header.destination
                // Local traffic: the destination is this node's own IP address
                if dstIp == self.devAddr.netAddr {
                    let nePacket = NEPacket(data: packet.data, protocolFamily: 2)
                    await self.providerActor.writePackets(packets: [nePacket])
                    return
                }

                // Look up the destination MAC in the ARP cache
                if let dstMac = await self.arpServer.query(ip: dstIp) {
                    await self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data)
                } else {
                    // Build and broadcast an ARP request
                    let broadcastMac = Data([0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF])
                    let arpReqeust = ARPPacket.arpRequest(senderIP: self.devAddr.netAddr, senderMAC: self.devAddr.mac, targetIP: dstIp)
                    await self.routeLayerPacket(dstMac: broadcastMac, type: .arp, data: arpReqeust.marshal())

                    self.logger.log("[SDLContext] dstIp: \(dstIp) arp query not found", level: .debug)
                }
            }
        }
    }

    // Wraps `data` in a layer-2 frame, encrypts it, and sends it either directly over an established
    // session or via the super node (in which case a throttled hole-punch attempt is also submitted).
    private func routeLayerPacket(dstMac: Data, type: LayerPacket.PacketType, data: Data) async {
        // Wrap the data in a layer-2 packet
        let layerPacket = LayerPacket(dstMac: dstMac, srcMac: self.devAddr.mac, type: type, data: data)
        guard let encodedPacket = try? self.aesCipher.encrypt(aesKey: self.aesKey, data: layerPacket.marshal()) else {
            return
        }

        // Build the data packet
        var dataPacket = SDLData()
        dataPacket.networkID = self.devAddr.networkID
        dataPacket.srcMac = self.devAddr.mac
        dataPacket.dstMac = dstMac
        dataPacket.ttl = 255
        dataPacket.data = encodedPacket

        // Send to the peer through an existing session
        if let session = await self.sessionManager.getSession(toAddress: dstMac) {
            self.logger.log("[SDLContext] send packet by session: \(session)", level: .debug)
            await self.udpHoleActor?.send(type: .data, data: try! dataPacket.serializedData(), remoteAddress: session.natAddress)
            await self.flowTracer.inc(num: data.count, type: .p2p)
        } else {
            // Forward through the super node
            let superAddress = self.config.stunSocketAddress
            await self.udpHoleActor?.send(type: .data, data: try! dataPacket.serializedData(), remoteAddress: superAddress)
            // Traffic statistics
            await self.flowTracer.inc(num: data.count, type: .forward)

            // Try to punch a hole
            let registerRequest = RegisterRequest(srcMac: self.devAddr.mac, dstMac: dstMac, networkId: self.devAddr.networkID)
            self.submitRegisterRequest(request: registerRequest)
        }
    }

    // Feeds a hole-punch request into the per-destination publisher, which lets at most one request
    // through every 5 seconds (latest wins). The publisher is created lazily on first use.
    private func submitRegisterRequest(request: RegisterRequest) {
        self.locker.lock()
        defer {
            self.locker.unlock()
        }

        let dstMac = request.dstMac
        if let publisher = self.holerPublishers[dstMac] {
            publisher.send(request)
            return
        }

        let publisher = PassthroughSubject<RegisterRequest, Never>()
        publisher.throttle(for: .seconds(5), scheduler: DispatchQueue.global(), latest: true)
            // `weak self`: the subscription is stored in `self.bag`, so a strong capture would form a cycle.
            .sink { [weak self] request in
                guard let self else {
                    return
                }
                Task {
                    await self.tryHole(request: request)
                }
            }
            .store(in: &self.bag)
        self.holerPublishers[dstMac] = publisher

        // Deliver the request that triggered the publisher's creation (previously dropped).
        publisher.send(request)
    }

    // Asks the super node for the peer's address and, if valid, sends it a register packet.
    private func tryHole(request: RegisterRequest) async {
        var queryInfo = SDLQueryInfo()
        queryInfo.dstMac = request.dstMac

        guard let message = try? await self.superClientActor?.request(type: .queryInfo, data: try queryInfo.serializedData()).get() else {
            return
        }

        switch message.packet {
        case .empty:
            self.logger.log("[SDLContext] hole query_info get empty: \(message)", level: .debug)
        case .peerInfo(let peerInfo):
            if let remoteAddress = peerInfo.v4Info.socketAddress() {
                self.logger.log("[SDLContext] hole sock address: \(remoteAddress)", level: .debug)
                // Send the register packet
                var register = SDLRegister()
                register.networkID = request.networkId
                register.srcMac = request.srcMac
                register.dstMac = request.dstMac
                await self.udpHoleActor?.send(type: .register, data: try! register.serializedData(), remoteAddress: remoteAddress)
            } else {
                self.logger.log("[SDLContext] hole sock address is invalid: \(peerInfo.v4Info)", level: .warning)
            }
        default:
            self.logger.log("[SDLContext] hole query_info is packet: \(message)", level: .warning)
        }
    }

    deinit {
        self.rootTask?.cancel()
        self.udpHoleActor = nil
        self.superClientActor = nil
        self.dnsClient = nil
    }

    // Returns the MAC address persisted in UserDefaults, generating and storing one on first use.
    public static func getMacAddress() -> Data {
        let key = "gMacAddress2"

        let userDefaults = UserDefaults.standard
        if let mac = userDefaults.value(forKey: key) as? Data {
            return mac
        } else {
            let mac = generateMacAddress()
            userDefaults.setValue(mac, forKey: key)

            return mac
        }
    }

    // Generates a random 6-byte MAC address
    private static func generateMacAddress() -> Data {
        var macAddress = [UInt8](repeating: 0, count: 6)
        for i in 0..<6 {
            macAddress[i] = UInt8.random(in: 0...255)
        }

        return Data(macAddress)
    }
}