From 4e0247f6489452bfa5447dca4da037087d49fc85 Mon Sep 17 00:00:00 2001 From: anlicheng <244108715@qq.com> Date: Wed, 7 Jan 2026 16:10:44 +0800 Subject: [PATCH] add actors --- Tun/Punchnet/Actors/SDLContextNew.swift | 695 ++++++++++++++++++ Tun/Punchnet/Actors/SDLSuperClientActor.swift | 283 +++++++ .../Actors/SDLTunnelProviderActor.swift | 87 +++ Tun/Punchnet/Actors/SDLUDPHoleActor.swift | 242 ++++++ 4 files changed, 1307 insertions(+) create mode 100644 Tun/Punchnet/Actors/SDLContextNew.swift create mode 100644 Tun/Punchnet/Actors/SDLSuperClientActor.swift create mode 100644 Tun/Punchnet/Actors/SDLTunnelProviderActor.swift create mode 100644 Tun/Punchnet/Actors/SDLUDPHoleActor.swift diff --git a/Tun/Punchnet/Actors/SDLContextNew.swift b/Tun/Punchnet/Actors/SDLContextNew.swift new file mode 100644 index 0000000..d9b809e --- /dev/null +++ b/Tun/Punchnet/Actors/SDLContextNew.swift @@ -0,0 +1,695 @@ +// +// SDLContext.swift +// Tun +// +// Created by 安礼成 on 2024/2/29. +// + +import Foundation +import NetworkExtension +import NIOCore +import Combine + +// 上下文环境变量,全局共享 +/* +1. 处理rsa的加解密逻辑 + */ + +@available(macOS 14, *) +public class SDLContextNew { + + // 路由信息 + struct Route { + let dstAddress: String + let subnetMask: String + + var debugInfo: String { + return "\(dstAddress):\(subnetMask)" + } + } + + let config: SDLConfiguration + + // tun网络地址信息 + var devAddr: SDLDevAddr + + // nat映射的相关信息, 暂时没有用处 + //var natAddress: SDLNatAddress? + // nat的网络类型 + var natType: SDLNatProber.NatType = .blocked + + // AES加密,授权通过后,对象才会被创建 + var aesCipher: AESCipher + + // aes + var aesKey: Data = Data() + + // rsa的相关配置, public_key是本地生成的 + let rsaCipher: RSACipher + + // 依赖的变量 + var udpHoleActor: SDLUDPHoleActor? + var superClientActor: SDLSuperClientActor? + + var providerActor: SDLTunnelProviderActor + + // dns的client对象 + var dnsClient: DNSClient? + + // 数据包读取任务 + private var readTask: Task<(), Never>? + + private var sessionManager: SessionManager + private var arpServer: ArpServer + + // 记录最后发送的stunRequest的cookie + private var lastCookie: UInt32? = 0 + + // 网络状态变化的健康 + private var monitor: SDLNetworkMonitor? + + // 内部socket通讯 + private var noticeClient: SDLNoticeClient? + + // 流量统计 + private var flowTracer = SDLFlowTracerActor() + private var flowTracerCancel: AnyCancellable? + + // 处理holer + private var holerPublishers: [Data:PassthroughSubject] = [:] + private var bag = Set() + private var locker = NSLock() + + private let logger: SDLLogger + private var rootTask: Task? + + struct RegisterRequest { + let srcMac: Data + let dstMac: Data + let networkId: UInt32 + } + + public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) { + self.logger = logger + + self.config = config + self.rsaCipher = rsaCipher + self.aesCipher = aesCipher + + // 生成mac地址 + var devAddr = SDLDevAddr() + devAddr.mac = Self.getMacAddress() + self.devAddr = devAddr + + self.sessionManager = SessionManager() + self.arpServer = ArpServer(known_macs: [:]) + self.providerActor = SDLTunnelProviderActor(provider: provider, logger: logger) + } + + public func start() async throws { + self.rootTask = Task { + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + while !Task.isCancelled { + do { + try await self.startDnsClient() + } catch let err { + self.logger.log("[SDLContext] UDPHole get err: \(err)", level: .warning) + try await Task.sleep(for: .seconds(2)) + } + } + } + + group.addTask { + while !Task.isCancelled { + do { + try await self.startUDPHole() + } catch let err { + self.logger.log("[SDLContext] UDPHole get err: \(err)", level: .warning) + try await Task.sleep(for: .seconds(2)) + } + } + } + + group.addTask { + while !Task.isCancelled { + do { + try await self.startSuperClient() + } catch let err { + self.logger.log("[SDLContext] SuperClient get error: \(err), will restart", level: .warning) + await self.arpServer.clear() + try await Task.sleep(for: .seconds(2)) + } + } + } + + group.addTask { + await self.startMonitor() + } + + group.addTask { + while !Task.isCancelled { + do { + try await self.startNoticeClient() + } catch let err { + self.logger.log("[SDLContext] noticeClient get err: \(err)", level: .warning) + try await Task.sleep(for: .seconds(2)) + } + } + } + + try await group.waitForAll() + } + } + + try await self.rootTask?.value + } + + public func stop() async { + self.rootTask?.cancel() + self.superClientActor = nil + self.udpHoleActor = nil + self.noticeClient = nil + + self.readTask?.cancel() + } + + private func startNoticeClient() async throws { + self.noticeClient = try await SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger) + try await self.noticeClient?.start() + self.logger.log("[SDLContext] notice_client task cancel", level: .warning) + } + + private func startUDPHole() async throws { + self.udpHoleActor = try await SDLUDPHoleActor(logger: self.logger) + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await self.udpHoleActor?.start() + } + + group.addTask { + while !Task.isCancelled { + try await Task.sleep(nanoseconds: 5 * 1_000_000_000) + + if let udpHoleActor = self.udpHoleActor { + let cookie = await udpHoleActor.getCookieId() + var stunRequest = SDLStunRequest() + stunRequest.cookie = cookie + stunRequest.clientID = self.config.clientId + stunRequest.networkID = self.devAddr.networkID + stunRequest.ip = self.devAddr.netAddr + stunRequest.mac = self.devAddr.mac + stunRequest.natType = UInt32(self.natType.rawValue) + + let remoteAddress = self.config.stunSocketAddress + await udpHoleActor.stunRequest(request: stunRequest, remoteAddress: remoteAddress) + self.lastCookie = cookie + } + } + } + + group.addTask { + if let eventFlow = self.udpHoleActor?.eventFlow { + for try await event in eventFlow { + try await self.handleUDPEvent(event: event) + } + } + } + + if let _ = try await group.next() { + group.cancelAll() + } + } + + } + + private func startSuperClient() async throws { + self.superClientActor = try await SDLSuperClientActor(host: self.config.superHost, port: self.config.superPort, logger: self.logger) + try await withThrowingTaskGroup(of: Void.self) { group in + defer { + self.logger.log("[SDLContext] super client task cancel", level: .warning) + } + + group.addTask { + try await self.superClientActor?.start() + } + + group.addTask { + if let eventFlow = self.superClientActor?.eventFlow { + for try await event in eventFlow { + try await self.handleSuperEvent(event: event) + } + } + } + + if let _ = try await group.next() { + group.cancelAll() + } + } + } + + private func startMonitor() async { + self.monitor = SDLNetworkMonitor() + for await event in self.monitor!.eventStream { + switch event { + case .changed: + // TODO 需要重新探测网络的nat类型 + //self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config, logger: self.logger) + self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info) + case .unreachable: + self.logger.log("didNetworkPathUnreachable", level: .warning) + } + } + } + + private func startDnsClient() async throws { + let remoteDnsServer = config.remoteDnsServer + let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(remoteDnsServer, port: 15353) + self.dnsClient = try await DNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger) + + try await withThrowingTaskGroup(of: Void.self) { group in + defer { + self.logger.log("[SDLContext] dns client task cancel", level: .warning) + } + + group.addTask { + try await self.dnsClient?.start() + } + + group.addTask { + if let packetFlow = self.dnsClient?.packetFlow { + for await packet in packetFlow { + let nePacket = NEPacket(data: packet, protocolFamily: 2) + await self.providerActor.writePackets(packets: [nePacket]) + } + } + + } + + if let _ = try await group.next() { + group.cancelAll() + } + } + } + + private func handleSuperEvent(event: SDLSuperClientActor.SuperEvent) async throws { + switch event { + case .ready: + self.logger.log("[SDLContext] get registerSuper, mac address: \(SDLUtil.formatMacAddress(mac: self.devAddr.mac))", level: .debug) + var registerSuper = SDLRegisterSuper() + registerSuper.version = UInt32(self.config.version) + registerSuper.clientID = self.config.clientId + registerSuper.devAddr = self.devAddr + registerSuper.pubKey = self.rsaCipher.pubKey + registerSuper.token = self.config.token + registerSuper.networkCode = self.config.networkCode + registerSuper.hostname = self.config.hostname + guard let message = try await self.superClientActor?.registerSuper(register: registerSuper).get() else { + return + } + + switch message.packet { + case .registerSuperAck(let registerSuperAck): + // 需要对数据通过rsa的私钥解码 + let aesKey = try! self.rsaCipher.decode(data: Data(registerSuperAck.aesKey)) + let upgradeType = SDLUpgradeType(rawValue: registerSuperAck.upgradeType) + + self.logger.log("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count), network_id:\(registerSuperAck.devAddr.networkID)", level: .info) + self.devAddr = registerSuperAck.devAddr + + if upgradeType == .force { + let forceUpgrade = NoticeMessage.upgrade(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress) + await self.noticeClient?.send(data: forceUpgrade) + exit(-1) + } + + // 服务器分配的tun网卡信息 + do { + try await self.providerActor.setNetworkSettings(devAddr: self.devAddr, dnsServer: DNSClient.Helper.dnsServer) + self.startReader() + } catch let err { + self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error) + exit(-1) + } + + self.aesKey = aesKey + if upgradeType == .normal { + let normalUpgrade = NoticeMessage.upgrade(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress) + await self.noticeClient?.send(data: normalUpgrade) + } + + case .registerSuperNak(let nakPacket): + let errorMessage = nakPacket.errorMessage + guard let errorCode = SDLNAKErrorCode(rawValue: UInt8(nakPacket.errorCode)) else { + return + } + + switch errorCode { + case .invalidToken, .nodeDisabled: + let alertNotice = NoticeMessage.alert(alert: errorMessage) + await self.noticeClient?.send(data: alertNotice) + exit(-1) + case .noIpAddress, .networkFault, .internalFault: + let alertNotice = NoticeMessage.alert(alert: errorMessage) + await self.noticeClient?.send(data: alertNotice) + } + self.logger.log("[SDLContext] Get a SuperNak message exit", level: .warning) + default: + () + } + + case .event(let evt): + switch evt { + case .natChanged(let natChangedEvent): + let dstMac = natChangedEvent.mac + self.logger.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info) + await sessionManager.removeSession(dstMac: dstMac) + case .sendRegister(let sendRegisterEvent): + self.logger.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)", level: .debug) + let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp) + if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) { + // 发送register包 + var register = SDLRegister() + register.networkID = self.devAddr.networkID + register.srcMac = self.devAddr.mac + register.dstMac = sendRegisterEvent.dstMac + await self.udpHoleActor?.sendRegister(register: register, remoteAddress: remoteAddress) + } + + case .networkShutdown(let shutdownEvent): + let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message) + await self.noticeClient?.send(data: alertNotice) + exit(-1) + } + case .command(let packetId, let command): + switch command { + case .changeNetwork(let changeNetworkCommand): + // 需要对数据通过rsa的私钥解码 + let aesKey = try! self.rsaCipher.decode(data: Data(changeNetworkCommand.aesKey)) + self.logger.log("[SDLContext] change network command get aes_key len: \(aesKey.count)", level: .info) + self.devAddr = changeNetworkCommand.devAddr + + // 服务器分配的tun网卡信息 + do { + try await self.providerActor.setNetworkSettings(devAddr: self.devAddr, dnsServer: DNSClient.Helper.dnsServer) + self.startReader() + } catch let err { + self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error) + exit(-1) + } + + self.aesKey = aesKey + + var commandAck = SDLCommandAck() + commandAck.status = true + + await self.superClientActor?.commandAck(packetId: packetId, ack: commandAck) + } + } + + } + + private func handleUDPEvent(event: SDLUDPHoleActor.UDPEvent) async throws { + switch event { + case .ready: + // 获取当前网络的类型 + //self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config) + self.logger.log("[SDLContext] nat type is: \(self.natType)", level: .debug) + + case .message(let remoteAddress, let message): + switch message { + case .register(let register): + self.logger.log("register packet: \(register), dev_addr: \(self.devAddr)", level: .debug) + // 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下 + if register.dstMac == self.devAddr.mac && register.networkID == self.devAddr.networkID { + // 回复ack包 + var registerAck = SDLRegisterAck() + registerAck.networkID = self.devAddr.networkID + registerAck.srcMac = self.devAddr.mac + registerAck.dstMac = register.srcMac + + await self.udpHoleActor?.sendRegisterAck(registerAck: registerAck, remoteAddress: remoteAddress) + + // 这里需要建立到来源的会话, 在复杂网络下,通过super-node查询到的nat地址不一定靠谱,需要通过udp包的来源地址作为nat地址 + let session = Session(dstMac: register.srcMac, natAddress: remoteAddress) + await self.sessionManager.addSession(session: session) + } else { + self.logger.log("SDLContext didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)", level: .warning) + } + case .registerAck(let registerAck): + // 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下 + if registerAck.dstMac == self.devAddr.mac && registerAck.networkID == self.devAddr.networkID { + let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress) + await self.sessionManager.addSession(session: session) + } else { + self.logger.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning) + } + case .stunReply(let stunReply): + let cookie = stunReply.cookie + if cookie == self.lastCookie { + // 记录下当前在nat上的映射信息,暂时没有用;后续会用来判断网络类型 + //self.natAddress = stunReply.natAddress + self.logger.log("[SDLContext] get a stunReply: \(try! stunReply.jsonString())", level: .debug) + } + default: + () + } + + case .data(let data): + let mac = LayerPacket.MacAddress(data: data.dstMac) + guard (data.dstMac == self.devAddr.mac || mac.isBroadcast() || mac.isMulticast()) else { + return + } + + guard let decyptedData = try? self.aesCipher.decypt(aesKey: self.aesKey, data: Data(data.data)) else { + return + } + + do { + let layerPacket = try LayerPacket(layerData: decyptedData) + + await self.flowTracer.inc(num: decyptedData.count, type: .inbound) + // 处理arp请求 + switch layerPacket.type { + case .arp: + // 判断如果收到的是arp请求 + if let arpPacket = ARPPacket(data: layerPacket.data) { + if arpPacket.targetIP == self.devAddr.netAddr { + switch arpPacket.opcode { + case .request: + self.logger.log("[SDLContext] get arp request packet", level: .debug) + let response = ARPPacket.arpResponse(for: arpPacket, mac: self.devAddr.mac, ip: self.devAddr.netAddr) + await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal()) + case .response: + self.logger.log("[SDLContext] get arp response packet", level: .debug) + await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC) + } + } else { + self.logger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(self.devAddr.netAddr))", level: .debug) + } + } else { + self.logger.log("[SDLContext] get invalid arp packet", level: .debug) + } + case .ipv4: + guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == self.devAddr.netAddr else { + return + } + let packet = NEPacket(data: ipPacket.data, protocolFamily: 2) + await self.providerActor.writePackets(packets: [packet]) + default: + self.logger.log("[SDLContext] get invalid packet", level: .debug) + } + } catch let err { + self.logger.log("[SDLContext] didReadData err: \(err)", level: .warning) + } + } + + } + + // 流量统计 +// public func flowReportTask() { +// Task { +// // 每分钟汇报一次 +// self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect() +// .sink { _ in +// Task { +// let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset() +// await self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum) +// } +// } +// } +// } + + // 开始读取数据, 用单独的线程处理packetFlow + private func startReader() { + // 停止之前的任务 + self.readTask?.cancel() + + // 开启新的任务 + self.readTask = Task(priority: .high) { + repeat { + let packets = await self.providerActor.readPackets() + for packet in packets { + await self.dealPacket(data: packet) + } + } while true + } + } + + // 处理读取的每个数据包 + private func dealPacket(data: Data) async { + guard let packet = IPPacket(data) else { + return + } + + if DNSClient.Helper.isDnsRequestPacket(ipPacket: packet) { + let destIp = packet.header.destination_ip + NSLog("destIp: \(destIp), int: \(packet.header.destination)") + await self.dnsClient?.forward(ipPacket: packet) + } + else { + Task.detached { + let dstIp = packet.header.destination + // 本地通讯, 目标地址是本地服务器的ip地址 + if dstIp == self.devAddr.netAddr { + let nePacket = NEPacket(data: packet.data, protocolFamily: 2) + await self.providerActor.writePackets(packets: [nePacket]) + return + } + + // 查找arp缓存中是否有目标mac地址 + if let dstMac = await self.arpServer.query(ip: dstIp) { + await self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data) + } + else { + // 构造arp请求 + let broadcastMac = Data([0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]) + let arpReqeust = ARPPacket.arpRequest(senderIP: self.devAddr.netAddr, senderMAC: self.devAddr.mac, targetIP: dstIp) + await self.routeLayerPacket(dstMac: broadcastMac, type: .arp, data: arpReqeust.marshal()) + + self.logger.log("[SDLContext] dstIp: \(dstIp) arp query not found", level: .debug) + } + } + } + } + + private func routeLayerPacket(dstMac: Data, type: LayerPacket.PacketType, data: Data) async { + // 将数据封装层2层的数据包 + let layerPacket = LayerPacket(dstMac: dstMac, srcMac: self.devAddr.mac, type: type, data: data) + guard let encodedPacket = try? self.aesCipher.encrypt(aesKey: self.aesKey, data: layerPacket.marshal()) else { + return + } + + // 构造数据包 + var dataPacket = SDLData() + dataPacket.networkID = self.devAddr.networkID + dataPacket.srcMac = self.devAddr.mac + dataPacket.dstMac = dstMac + dataPacket.ttl = 255 + dataPacket.data = encodedPacket + + // 通过session发送到对端 + if let session = await self.sessionManager.getSession(toAddress: dstMac) { + self.logger.log("[SDLContext] send packet by session: \(session)", level: .debug) + await self.udpHoleActor?.sendPacket(data: dataPacket, remoteAddress: session.natAddress) + await self.flowTracer.inc(num: data.count, type: .p2p) + } + else { + // 通过super_node进行转发 + let superAddress = self.config.stunSocketAddress + await self.udpHoleActor?.sendPacket(data: dataPacket, remoteAddress: superAddress) + // 流量统计 + await self.flowTracer.inc(num: data.count, type: .forward) + + // 尝试打洞 + let registerRequest = RegisterRequest(srcMac: self.devAddr.mac, dstMac: dstMac, networkId: self.devAddr.networkID) + self.submitRegisterRequest(request: registerRequest) + } + } + + private func submitRegisterRequest(request: RegisterRequest) { + self.locker.lock() + defer { + self.locker.unlock() + } + + let dstMac = request.dstMac + if let publisher = self.holerPublishers[dstMac] { + publisher.send(request) + } else { + let publisher = PassthroughSubject() + publisher.throttle(for: .seconds(5), scheduler: DispatchQueue.global(), latest: true) + .sink { request in + Task { + await self.tryHole(request: request) + } + } + .store(in: &self.bag) + + self.holerPublishers[dstMac] = publisher + } + } + + private func tryHole(request: RegisterRequest) async { + var queryInfo = SDLQueryInfo() + queryInfo.dstMac = request.dstMac + + guard let message = try? await self.superClientActor?.queryInfo(query: queryInfo).get() else { + return + } + + switch message.packet { + case .empty: + self.logger.log("[SDLContext] hole query_info get empty: \(message)", level: .debug) + case .peerInfo(let peerInfo): + if let remoteAddress = peerInfo.v4Info.socketAddress() { + self.logger.log("[SDLContext] hole sock address: \(remoteAddress)", level: .debug) + // 发送register包 + var register = SDLRegister() + register.networkID = request.networkId + register.srcMac = request.srcMac + register.dstMac = request.dstMac + + await self.udpHoleActor?.sendRegister(register: register, remoteAddress: remoteAddress) + } else { + self.logger.log("[SDLContext] hole sock address is invalid: \(peerInfo.v4Info)", level: .warning) + } + default: + self.logger.log("[SDLContext] hole query_info is packet: \(message)", level: .warning) + } + } + + deinit { + self.rootTask?.cancel() + self.udpHoleActor = nil + self.superClientActor = nil + self.dnsClient = nil + } + + // 获取mac地址 + public static func getMacAddress() -> Data { + let key = "gMacAddress2" + + let userDefaults = UserDefaults.standard + if let mac = userDefaults.value(forKey: key) as? Data { + return mac + } + else { + let mac = generateMacAddress() + userDefaults.setValue(mac, forKey: key) + + return mac + } + } + + // 随机生成mac地址 + private static func generateMacAddress() -> Data { + var macAddress = [UInt8](repeating: 0, count: 6) + for i in 0..<6 { + macAddress[i] = UInt8.random(in: 0...255) + } + + return Data(macAddress) + } + +} diff --git a/Tun/Punchnet/Actors/SDLSuperClientActor.swift b/Tun/Punchnet/Actors/SDLSuperClientActor.swift new file mode 100644 index 0000000..003fe54 --- /dev/null +++ b/Tun/Punchnet/Actors/SDLSuperClientActor.swift @@ -0,0 +1,283 @@ +// +// SDLWebsocketClient.swift +// Tun +// +// Created by 安礼成 on 2024/3/28. +// + +import Foundation +import NIOCore +import NIOPosix + +// --MARK: 和SuperNode的客户端 +actor SDLSuperClientActor { + // 发送的消息格式 + private typealias TcpMessage = (packetId: UInt32, type: SDLPacketType, data: Data) + + private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + private let asyncChannel: NIOAsyncChannel + private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: TcpMessage.self, bufferingPolicy: .unbounded) + private var callbackPromises: [UInt32:EventLoopPromise] = [:] + + public let eventFlow: AsyncStream + private let inboundContinuation: AsyncStream.Continuation + + // id生成器 + var idGenerator = SDLIdGenerator(seed: 1) + + private let logger: SDLLogger + + // 定义事件类型 + enum SuperEvent { + case ready + case event(SDLEvent) + case command(UInt32, SDLCommand) + } + + init(host: String, port: Int, logger: SDLLogger) async throws { + self.logger = logger + + (self.eventFlow, self.inboundContinuation) = AsyncStream.makeStream(of: SuperEvent.self, bufferingPolicy: .unbounded) + let bootstrap = ClientBootstrap(group: self.group) + .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .channelInitializer { channel in + return channel.pipeline.addHandlers([ + ByteToMessageHandler(FixedHeaderDecoder()), + MessageToByteHandler(FixedHeaderEncoder()) + ]) + } + + self.asyncChannel = try await bootstrap.connect(host: host, port: port) + .flatMapThrowing { channel in + return try NIOAsyncChannel(wrappingChannelSynchronously: channel, configuration: .init( + inboundType: ByteBuffer.self, + outboundType: ByteBuffer.self + )) + } + .get() + } + + func start() async throws { + try await withTaskCancellationHandler { + try await self.asyncChannel.executeThenClose { inbound, outbound in + self.inboundContinuation.yield(.ready) + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + defer { + self.logger.log("[SDLSuperClient] inbound closed", level: .warning) + } + + for try await var packet in inbound { + try Task.checkCancellation() + + if let message = SDLSuperClientDecoder.decode(buffer: &packet) { + self.logger.log("[SDLSuperTransport] read message: \(message)", level: .debug) + switch message.packet { + case .event(let event): + self.inboundContinuation.yield(.event(event)) + case .command(let command): + self.inboundContinuation.yield(.command(message.msgId, command)) + default: + await self.fireCallback(message: message) + } + } + } + } + + group.addTask { + defer { + self.logger.log("[SDLSuperClient] outbound closed", level: .warning) + } + + for await (packetId, type, data) in self.writeStream { + try Task.checkCancellation() + + var buffer = self.asyncChannel.channel.allocator.buffer(capacity: data.count + 5) + buffer.writeInteger(packetId, as: UInt32.self) + buffer.writeBytes([type.rawValue]) + buffer.writeBytes(data) + try await outbound.write(buffer) + } + } + + // --MARK: 心跳机制 + group.addTask { + defer { + self.logger.log("[SDLSuperClient] ping task closed", level: .warning) + } + + while true { + try Task.checkCancellation() + await self.ping() + try await Task.sleep(nanoseconds: 5 * 1_000_000_000) + } + } + + // 迭代等待所有任务的退出, 第一个异常会被抛出 + if let _ = try await group.next() { + group.cancelAll() + } + } + } + } onCancel: { + self.inboundContinuation.finish() + self.writeContinuation.finish() + self.logger.log("[SDLSuperClient] withTaskCancellationHandler cancel") + } + } + + // -- MARK: apis + func unregister() throws { + self.send(type: .unregisterSuper, packetId: 0, data: Data()) + } + + private func ping() { + self.send(type: .ping, packetId: 0, data: Data()) + } + + func request(type: SDLPacketType, data: Data) -> EventLoopFuture { + let packetId = idGenerator.nextId() + let promise = self.asyncChannel.channel.eventLoop.makePromise(of: SDLSuperInboundMessage.self) + self.callbackPromises[packetId] = promise + + self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data)) + + return promise.futureResult + } + + func send(type: SDLPacketType, packetId: UInt32, data: Data) { + self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data)) + } + + // 处理回调函数 + private func fireCallback(message: SDLSuperInboundMessage) { + if let promise = self.callbackPromises[message.msgId] { + self.asyncChannel.channel.eventLoop.execute { + promise.succeed(message) + } + self.callbackPromises.removeValue(forKey: message.msgId) + } + } + + deinit { + try! group.syncShutdownGracefully() + } + +} + +// --MARK: 编解码器 +private struct SDLSuperClientDecoder { + // 消息格式为: <> + static func decode(buffer: inout ByteBuffer) -> SDLSuperInboundMessage? { + guard let msgId = buffer.readInteger(as: UInt32.self), + let type = buffer.readInteger(as: UInt8.self), + let messageType = SDLPacketType(rawValue: type) else { + return nil + } + + switch messageType { + case .empty: + return .init(msgId: msgId, packet: .empty) + case .registerSuperAck: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else { + return nil + } + return .init(msgId: msgId, packet: .registerSuperAck(registerSuperAck)) + + case .registerSuperNak: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else { + return nil + } + return .init(msgId: msgId, packet: .registerSuperNak(registerSuperNak)) + + case .peerInfo: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else { + return nil + } + + return .init(msgId: msgId, packet: .peerInfo(peerInfo)) + case .pong: + return .init(msgId: msgId, packet: .pong) + + case .command: + guard let commandVal = buffer.readInteger(as: UInt8.self), + let command = SDLCommandType(rawValue: commandVal), + let bytes = buffer.readBytes(length: buffer.readableBytes) else { + return nil + } + + switch command { + case .changeNetwork: + guard let changeNetworkCommand = try? SDLChangeNetworkCommand(serializedBytes: bytes) else { + return nil + } + + return .init(msgId: msgId, packet: .command(.changeNetwork(changeNetworkCommand))) + } + + case .event: + guard let eventVal = buffer.readInteger(as: UInt8.self), + let event = SDLEventType(rawValue: eventVal), + let bytes = buffer.readBytes(length: buffer.readableBytes) else { + return nil + } + + switch event { + case .natChanged: + guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else { + return nil + } + return .init(msgId: msgId, packet: .event(.natChanged(natChangedEvent))) + case .sendRegister: + guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else { + return nil + } + return .init(msgId: msgId, packet: .event(.sendRegister(sendRegisterEvent))) + case .networkShutdown: + guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else { + return nil + } + return .init(msgId: msgId, packet: .event(.networkShutdown(networkShutdownEvent))) + } + + default: + return nil + } + } +} + +private final class FixedHeaderEncoder: MessageToByteEncoder, @unchecked Sendable { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + func encode(data: ByteBuffer, out: inout ByteBuffer) throws { + let len = data.readableBytes + out.writeInteger(UInt16(len)) + out.writeBytes(data.readableBytesView) + } +} + +private final class FixedHeaderDecoder: ByteToMessageDecoder, @unchecked Sendable { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState { + guard let len = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) else { + return .needMoreData + } + + if buffer.readableBytes >= len + 2 { + buffer.moveReaderIndex(forwardBy: 2) + if let bytes = buffer.readBytes(length: Int(len)) { + context.fireChannelRead(self.wrapInboundOut(ByteBuffer(bytes: bytes))) + } + return .continue + } else { + return .needMoreData + } + } +} diff --git a/Tun/Punchnet/Actors/SDLTunnelProviderActor.swift b/Tun/Punchnet/Actors/SDLTunnelProviderActor.swift new file mode 100644 index 0000000..fa03a88 --- /dev/null +++ b/Tun/Punchnet/Actors/SDLTunnelProviderActor.swift @@ -0,0 +1,87 @@ +// +// SDLContext.swift +// Tun +// +// Created by 安礼成 on 2024/2/29. +// + +import Foundation +import NetworkExtension +import NIOCore +import Combine + +// 上下文环境变量,全局共享 +/* +1. 处理rsa的加解密逻辑 + */ + +actor SDLTunnelProviderActor { + + // 路由信息 + struct Route { + let dstAddress: String + let subnetMask: String + + var debugInfo: String { + return "\(dstAddress):\(subnetMask)" + } + } + + // 数据包读取任务 + private var readTask: Task<(), Never>? + + let provider: NEPacketTunnelProvider + let logger: SDLLogger + + public init(provider: NEPacketTunnelProvider, logger: SDLLogger) { + self.logger = logger + self.provider = provider + } + + func writePackets(packets: [NEPacket]) { + //let packet = NEPacket(data: ipPacket.data, protocolFamily: 2) + self.provider.packetFlow.writePacketObjects(packets) + } + + // 网络改变时需要重新配置网络信息 + func setNetworkSettings(devAddr: SDLDevAddr, dnsServer: String) async throws { + let netAddress = SDLNetAddress(ip: devAddr.netAddr, maskLen: UInt8(devAddr.netBitLen)) + let routes = [ + Route(dstAddress: netAddress.networkAddress, subnetMask: netAddress.maskAddress), + Route(dstAddress: dnsServer, subnetMask: "255.255.255.255") + ] + + // Add code here to start the process of connecting the tunnel. + let networkSettings = NEPacketTunnelNetworkSettings(tunnelRemoteAddress: "8.8.8.8") + networkSettings.mtu = 1460 + + // 设置网卡的DNS解析 + + let networkDomain = devAddr.networkDomain + let dnsSettings = NEDNSSettings(servers: [dnsServer]) + dnsSettings.searchDomains = [networkDomain] + dnsSettings.matchDomains = [networkDomain] + dnsSettings.matchDomainsNoSearch = false + networkSettings.dnsSettings = dnsSettings + self.logger.log("[SDLContext] Tun started at network ip: \(netAddress.ipAddress), mask: \(netAddress.maskAddress)", level: .info) + + let ipv4Settings = NEIPv4Settings(addresses: [netAddress.ipAddress], subnetMasks: [netAddress.maskAddress]) + // 设置路由表 + //NEIPv4Route.default() + ipv4Settings.includedRoutes = routes.map { route in + NEIPv4Route(destinationAddress: route.dstAddress, subnetMask: route.subnetMask) + } + networkSettings.ipv4Settings = ipv4Settings + // 网卡配置设置必须成功 + try await self.provider.setTunnelNetworkSettings(networkSettings) + } + + // 开始读取数据, 用单独的线程处理packetFlow + func readPackets() async -> [Data] { + let (packets, numbers) = await self.provider.packetFlow.readPackets() + return zip(packets, numbers).compactMap { (data, number) in + return number == 2 ? data : nil + } + } + +} diff --git a/Tun/Punchnet/Actors/SDLUDPHoleActor.swift b/Tun/Punchnet/Actors/SDLUDPHoleActor.swift new file mode 100644 index 0000000..d08bad6 --- /dev/null +++ b/Tun/Punchnet/Actors/SDLUDPHoleActor.swift @@ -0,0 +1,242 @@ +// +// SDLanServer.swift +// Tun +// +// Created by 安礼成 on 2024/1/31. +// + +import Foundation +import NIOCore +import NIOPosix + +// 处理和sn-server服务器之间的通讯 +@available(macOS 14, *) +actor SDLUDPHoleActor { + private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + private let asyncChannel: NIOAsyncChannel, AddressedEnvelope> + private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: UDPMessage.self, bufferingPolicy: .unbounded) + + private var cookieGenerator = SDLIdGenerator(seed: 1) + private var promises: [UInt32:EventLoopPromise] = [:] + public var localAddress: SocketAddress? + + public let eventFlow: AsyncStream + private let eventContinuation: AsyncStream.Continuation + + private let logger: SDLLogger + + // 依赖的外表能力 + struct Capabilities { + let logger: @Sendable (String) async -> Void + + } + + struct UDPMessage { + let remoteAddress: SocketAddress + let type: SDLPacketType + let data: Data + } + + // 定义事件类型 + enum UDPEvent { + case ready + case message(SocketAddress, SDLHoleInboundMessage) + case data(SDLData) + } + + // 启动函数 + init(logger: SDLLogger) async throws { + self.logger = logger + + (self.eventFlow, self.eventContinuation) = AsyncStream.makeStream(of: UDPEvent.self, bufferingPolicy: .unbounded) + + let bootstrap = DatagramBootstrap(group: group) + .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + + self.asyncChannel = try await bootstrap.bind(host: "0.0.0.0", port: 0) + .flatMapThrowing { channel in + return try NIOAsyncChannel(wrappingChannelSynchronously: channel, configuration: .init( + inboundType: AddressedEnvelope.self, + outboundType: AddressedEnvelope.self + )) + } + .get() + + self.localAddress = self.asyncChannel.channel.localAddress + self.logger.log("[UDPHole] started and listening on: \(self.localAddress!)", level: .debug) + } + + func start() async throws { + try await withTaskCancellationHandler { + try await self.asyncChannel.executeThenClose {inbound, outbound in + self.eventContinuation.yield(.ready) + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + defer { + self.logger.log("[SDLUDPHole] inbound closed", level: .warning) + } + + for try await envelope in inbound { + try Task.checkCancellation() + + var buffer = envelope.data + let remoteAddress = envelope.remoteAddress + do { + if let message = try Self.decode(buffer: &buffer) { + switch message { + case .data(let data): + self.logger.log("[SDLUDPHole] read data: \(data.format()), from: \(remoteAddress)", level: .debug) + self.eventContinuation.yield(.data(data)) + case .stunProbeReply(let probeReply): + // 执行并移除回调 + await self.trigger(probeReply: probeReply) + default: + self.eventContinuation.yield(.message(remoteAddress, message)) + } + } else { + self.logger.log("[SDLUDPHole] decode message, get null", level: .warning) + } + } catch let err { + self.logger.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning) + throw err + } + } + } + + group.addTask { + defer { + self.logger.log("[SDLUDPHole] outbound closed", level: .warning) + } + + for await message in self.writeStream { + try Task.checkCancellation() + + var buffer = self.asyncChannel.channel.allocator.buffer(capacity: message.data.count + 1) + buffer.writeBytes([message.type.rawValue]) + buffer.writeBytes(message.data) + + let envelope = AddressedEnvelope(remoteAddress: message.remoteAddress, data: buffer) + try await outbound.write(envelope) + } + } + + if let _ = try await group.next() { + group.cancelAll() + } + } + } + } onCancel: { + self.writeContinuation.finish() + self.eventContinuation.finish() + self.logger.log("[SDLUDPHole] withTaskCancellationHandler cancel") + } + } + + func getCookieId() -> UInt32 { + return self.cookieGenerator.nextId() + } + + // MARK: super_node apis + + func stunRequest(request: SDLStunRequest, remoteAddress: SocketAddress) { + self.send(remoteAddress: remoteAddress, type: .stunRequest, data: try! request.serializedData()) + } + + // 探测tun信息 + func stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int = 5) async throws -> SDLStunProbeReply { + return try await self._stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: timeout).get() + } + + private func _stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int) -> EventLoopFuture { + let cookie = self.cookieGenerator.nextId() + var stunProbe = SDLStunProbe() + stunProbe.cookie = cookie + stunProbe.attr = UInt32(attr.rawValue) + self.send(remoteAddress: remoteAddress, type: .stunProbe, data: try! stunProbe.serializedData()) + self.logger.log("[SDLUDPHole] stunProbe: \(remoteAddress)", level: .debug) + + let promise = self.asyncChannel.channel.eventLoop.makePromise(of: SDLStunProbeReply.self) + self.promises[cookie] = promise + + return promise.futureResult + } + + private func trigger(probeReply: SDLStunProbeReply) { + let id = probeReply.cookie + // 执行并移除回调 + if let promise = self.promises[id] { + self.asyncChannel.channel.eventLoop.execute { + promise.succeed(probeReply) + } + self.promises.removeValue(forKey: id) + } + } + + // MARK: client-client apis + + // 发送数据包到其他session + func sendPacket(data: SDLData, remoteAddress: SocketAddress) { + if let packet = try? data.serializedData() { + self.logger.log("[SDLUDPHole] sendPacket: \(remoteAddress), count: \(packet.count)", level: .debug) + self.send(remoteAddress: remoteAddress, type: .data, data: packet) + } + } + + // 发送register包 + func sendRegister(register: SDLRegister, remoteAddress: SocketAddress) { + if let packet = try? register.serializedData() { + self.logger.log("[SDLUDPHole] SendRegister: \(remoteAddress), src_mac: \(LayerPacket.MacAddress.description(data: register.srcMac)), dst_mac: \(LayerPacket.MacAddress.description(data: register.dstMac))", level: .debug) + self.send(remoteAddress: remoteAddress, type: .register, data: packet) + } + } + + // 回复registerAck + func sendRegisterAck(registerAck: SDLRegisterAck, remoteAddress: SocketAddress) { + if let packet = try? registerAck.serializedData() { + self.logger.log("[SDLUDPHole] SendRegisterAck: \(remoteAddress), \(registerAck)", level: .debug) + self.send(remoteAddress: remoteAddress, type: .registerAck, data: packet) + } + } + + // 处理写入逻辑 + private func send(remoteAddress: SocketAddress, type: SDLPacketType, data: Data) { + let message = UDPMessage(remoteAddress: remoteAddress, type: type, data: data) + self.writeContinuation.yield(message) + } + + //--MARK: 编解码器 + private static func decode(buffer: inout ByteBuffer) throws -> SDLHoleInboundMessage? { + guard let type = buffer.readInteger(as: UInt8.self), + let packetType = SDLPacketType(rawValue: type), + let bytes = buffer.readBytes(length: buffer.readableBytes) else { + return nil + } + + switch packetType { + case .data: + let dataPacket = try SDLData(serializedBytes: bytes) + return .data(dataPacket) + case .register: + let registerPacket = try SDLRegister(serializedBytes: bytes) + return .register(registerPacket) + case .registerAck: + let registerAck = try SDLRegisterAck(serializedBytes: bytes) + return .registerAck(registerAck) + case .stunReply: + let stunReply = try SDLStunReply(serializedBytes: bytes) + return .stunReply(stunReply) + case .stunProbeReply: + let stunProbeReply = try SDLStunProbeReply(serializedBytes: bytes) + return .stunProbeReply(stunProbeReply) + default: + return nil + } + } + + deinit { + try? self.group.syncShutdownGracefully() + self.writeContinuation.finish() + self.eventContinuation.finish() + } + +}