diff --git a/Tun/Punchnet/Actors/SDLContextNew.swift b/Tun/Punchnet/Actors/SDLContextNew.swift deleted file mode 100644 index eaa9f3a..0000000 --- a/Tun/Punchnet/Actors/SDLContextNew.swift +++ /dev/null @@ -1,695 +0,0 @@ -// -// SDLContext.swift -// Tun -// -// Created by 安礼成 on 2024/2/29. -// - -import Foundation -import NetworkExtension -import NIOCore -import Combine - -// 上下文环境变量,全局共享 -/* -1. 处理rsa的加解密逻辑 - */ - -@available(macOS 14, *) -public class SDLContextNew { - - // 路由信息 - struct Route { - let dstAddress: String - let subnetMask: String - - var debugInfo: String { - return "\(dstAddress):\(subnetMask)" - } - } - - let config: SDLConfiguration - - // tun网络地址信息 - var devAddr: SDLDevAddr - - // nat映射的相关信息, 暂时没有用处 - //var natAddress: SDLNatAddress? - // nat的网络类型 - var natType: SDLNatProber.NatType = .blocked - - // AES加密,授权通过后,对象才会被创建 - var aesCipher: AESCipher - - // aes - var aesKey: Data = Data() - - // rsa的相关配置, public_key是本地生成的 - let rsaCipher: RSACipher - - // 依赖的变量 - var udpHoleActor: SDLUDPHoleActor? - var superClientActor: SDLSuperClientActor? - var providerActor: SDLTunnelProviderActor - - // dns的client对象 - var dnsClient: DNSClient? - - // 数据包读取任务 - private var readTask: Task<(), Never>? - - private var sessionManager: SessionManager - private var arpServer: ArpServer - - // 记录最后发送的stunRequest的cookie - private var lastCookie: UInt32? = 0 - - // 网络状态变化的健康 - private var monitor: SDLNetworkMonitor? - - // 内部socket通讯 - private var noticeClient: SDLNoticeClient? - - // 流量统计 - private var flowTracer = SDLFlowTracerActor() - private var flowTracerCancel: AnyCancellable? - - // 处理holer - private var holerPublishers: [Data:PassthroughSubject] = [:] - private var bag = Set() - private var locker = NSLock() - - private let logger: SDLLogger - private var rootTask: Task? - - struct RegisterRequest { - let srcMac: Data - let dstMac: Data - let networkId: UInt32 - } - - public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) { - self.logger = logger - - self.config = config - self.rsaCipher = rsaCipher - self.aesCipher = aesCipher - - // 生成mac地址 - var devAddr = SDLDevAddr() - devAddr.mac = Self.getMacAddress() - self.devAddr = devAddr - - self.sessionManager = SessionManager() - self.arpServer = ArpServer(known_macs: [:]) - self.providerActor = SDLTunnelProviderActor(provider: provider, logger: logger) - } - - public func start() async throws { - self.rootTask = Task { - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - while !Task.isCancelled { - do { - try await self.startDnsClient() - } catch let err { - self.logger.log("[SDLContext] UDPHole get err: \(err)", level: .warning) - try await Task.sleep(for: .seconds(2)) - } - } - } - - group.addTask { - while !Task.isCancelled { - do { - try await self.startUDPHole() - } catch let err { - self.logger.log("[SDLContext] UDPHole get err: \(err)", level: .warning) - try await Task.sleep(for: .seconds(2)) - } - } - } - - group.addTask { - while !Task.isCancelled { - do { - try await self.startSuperClient() - } catch let err { - self.logger.log("[SDLContext] SuperClient get error: \(err), will restart", level: .warning) - await self.arpServer.clear() - try await Task.sleep(for: .seconds(2)) - } - } - } - - group.addTask { - await self.startMonitor() - } - - group.addTask { - while !Task.isCancelled { - do { - try await self.startNoticeClient() - } catch let err { - self.logger.log("[SDLContext] noticeClient get err: \(err)", level: .warning) - try await Task.sleep(for: .seconds(2)) - } - } - } - - try await group.waitForAll() - } - } - - try await self.rootTask?.value - } - - public func stop() async { - self.rootTask?.cancel() - self.superClientActor = nil - self.udpHoleActor = nil - self.noticeClient = nil - - self.readTask?.cancel() - } - - private func startNoticeClient() async throws { - self.noticeClient = try await SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger) - try await self.noticeClient?.start() - self.logger.log("[SDLContext] notice_client task cancel", level: .warning) - } - - private func startUDPHole() async throws { - self.udpHoleActor = try await SDLUDPHoleActor(logger: self.logger) - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - try await self.udpHoleActor?.start() - } - - group.addTask { - while !Task.isCancelled { - try Task.checkCancellation() - try await Task.sleep(nanoseconds: 5 * 1_000_000_000) - try Task.checkCancellation() - - if let udpHoleActor = self.udpHoleActor { - let cookie = await udpHoleActor.getCookieId() - var stunRequest = SDLStunRequest() - stunRequest.cookie = cookie - stunRequest.clientID = self.config.clientId - stunRequest.networkID = self.devAddr.networkID - stunRequest.ip = self.devAddr.netAddr - stunRequest.mac = self.devAddr.mac - stunRequest.natType = UInt32(self.natType.rawValue) - - let remoteAddress = self.config.stunSocketAddress - await udpHoleActor.send(type: .stunRequest, data: try stunRequest.serializedData(), remoteAddress: remoteAddress) - self.lastCookie = cookie - } - } - } - - group.addTask { - if let eventFlow = self.udpHoleActor?.eventFlow { - for try await event in eventFlow { - try Task.checkCancellation() - try await self.handleUDPEvent(event: event) - } - } - } - - if let _ = try await group.next() { - group.cancelAll() - } - } - - } - - private func startSuperClient() async throws { - self.superClientActor = try await SDLSuperClientActor(host: self.config.superHost, port: self.config.superPort, logger: self.logger) - try await withThrowingTaskGroup(of: Void.self) { group in - defer { - self.logger.log("[SDLContext] super client task cancel", level: .warning) - } - - group.addTask { - try await self.superClientActor?.start() - } - - group.addTask { - if let eventFlow = self.superClientActor?.eventFlow { - for try await event in eventFlow { - try await self.handleSuperEvent(event: event) - } - } - } - - if let _ = try await group.next() { - group.cancelAll() - } - } - } - - private func startMonitor() async { - self.monitor = SDLNetworkMonitor() - for await event in self.monitor!.eventStream { - switch event { - case .changed: - // TODO 需要重新探测网络的nat类型 - //self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config, logger: self.logger) - self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info) - case .unreachable: - self.logger.log("didNetworkPathUnreachable", level: .warning) - } - } - } - - private func startDnsClient() async throws { - let remoteDnsServer = config.remoteDnsServer - let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(remoteDnsServer, port: 15353) - self.dnsClient = try await DNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger) - - try await withThrowingTaskGroup(of: Void.self) { group in - defer { - self.logger.log("[SDLContext] dns client task cancel", level: .warning) - } - - group.addTask { - try await self.dnsClient?.start() - } - - group.addTask { - if let packetFlow = self.dnsClient?.packetFlow { - for await packet in packetFlow { - let nePacket = NEPacket(data: packet, protocolFamily: 2) - await self.providerActor.writePackets(packets: [nePacket]) - } - } - - } - - if let _ = try await group.next() { - group.cancelAll() - } - } - } - - private func handleSuperEvent(event: SDLSuperClientActor.SuperEvent) async throws { - switch event { - case .ready: - self.logger.log("[SDLContext] get registerSuper, mac address: \(SDLUtil.formatMacAddress(mac: self.devAddr.mac))", level: .debug) - var registerSuper = SDLRegisterSuper() - registerSuper.version = UInt32(self.config.version) - registerSuper.clientID = self.config.clientId - registerSuper.devAddr = self.devAddr - registerSuper.pubKey = self.rsaCipher.pubKey - registerSuper.token = self.config.token - registerSuper.networkCode = self.config.networkCode - registerSuper.hostname = self.config.hostname - guard let message = try await self.superClientActor?.request(type: .registerSuper, data: try registerSuper.serializedData()).get() else { - return - } - - switch message.packet { - case .registerSuperAck(let registerSuperAck): - // 需要对数据通过rsa的私钥解码 - let aesKey = try! self.rsaCipher.decode(data: Data(registerSuperAck.aesKey)) - let upgradeType = SDLUpgradeType(rawValue: registerSuperAck.upgradeType) - - self.logger.log("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count), network_id:\(registerSuperAck.devAddr.networkID)", level: .info) - self.devAddr = registerSuperAck.devAddr - - if upgradeType == .force { - let forceUpgrade = NoticeMessage.upgrade(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress) - await self.noticeClient?.send(data: forceUpgrade) - exit(-1) - } - - // 服务器分配的tun网卡信息 - do { - try await self.providerActor.setNetworkSettings(devAddr: self.devAddr, dnsServer: DNSClient.Helper.dnsServer) - self.startReader() - } catch let err { - self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error) - exit(-1) - } - - self.aesKey = aesKey - if upgradeType == .normal { - let normalUpgrade = NoticeMessage.upgrade(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress) - await self.noticeClient?.send(data: normalUpgrade) - } - - case .registerSuperNak(let nakPacket): - let errorMessage = nakPacket.errorMessage - guard let errorCode = SDLNAKErrorCode(rawValue: UInt8(nakPacket.errorCode)) else { - return - } - - switch errorCode { - case .invalidToken, .nodeDisabled: - let alertNotice = NoticeMessage.alert(alert: errorMessage) - await self.noticeClient?.send(data: alertNotice) - exit(-1) - case .noIpAddress, .networkFault, .internalFault: - let alertNotice = NoticeMessage.alert(alert: errorMessage) - await self.noticeClient?.send(data: alertNotice) - } - self.logger.log("[SDLContext] Get a SuperNak message exit", level: .warning) - default: - () - } - - case .event(let evt): - switch evt { - case .natChanged(let natChangedEvent): - let dstMac = natChangedEvent.mac - self.logger.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info) - await sessionManager.removeSession(dstMac: dstMac) - case .sendRegister(let sendRegisterEvent): - self.logger.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)", level: .debug) - let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp) - if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) { - // 发送register包 - var register = SDLRegister() - register.networkID = self.devAddr.networkID - register.srcMac = self.devAddr.mac - register.dstMac = sendRegisterEvent.dstMac - await self.udpHoleActor?.send(type: .register, data: try register.serializedData(), remoteAddress: remoteAddress) - } - - case .networkShutdown(let shutdownEvent): - let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message) - await self.noticeClient?.send(data: alertNotice) - exit(-1) - } - case .command(let packetId, let command): - switch command { - case .changeNetwork(let changeNetworkCommand): - // 需要对数据通过rsa的私钥解码 - let aesKey = try! self.rsaCipher.decode(data: Data(changeNetworkCommand.aesKey)) - self.logger.log("[SDLContext] change network command get aes_key len: \(aesKey.count)", level: .info) - self.devAddr = changeNetworkCommand.devAddr - - // 服务器分配的tun网卡信息 - do { - try await self.providerActor.setNetworkSettings(devAddr: self.devAddr, dnsServer: DNSClient.Helper.dnsServer) - self.startReader() - } catch let err { - self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error) - exit(-1) - } - - self.aesKey = aesKey - - var commandAck = SDLCommandAck() - commandAck.status = true - await self.superClientActor?.send(type: .commandAck, packetId: packetId, data: try commandAck.serializedData()) - } - } - - } - - private func handleUDPEvent(event: SDLUDPHoleActor.UDPEvent) async throws { - switch event { - case .ready: - // 获取当前网络的类型 - //self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config) - self.logger.log("[SDLContext] nat type is: \(self.natType)", level: .debug) - - case .message(let remoteAddress, let message): - switch message { - case .register(let register): - self.logger.log("register packet: \(register), dev_addr: \(self.devAddr)", level: .debug) - // 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下 - if register.dstMac == self.devAddr.mac && register.networkID == self.devAddr.networkID { - // 回复ack包 - var registerAck = SDLRegisterAck() - registerAck.networkID = self.devAddr.networkID - registerAck.srcMac = self.devAddr.mac - registerAck.dstMac = register.srcMac - - await self.udpHoleActor?.send(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress) - // 这里需要建立到来源的会话, 在复杂网络下,通过super-node查询到的nat地址不一定靠谱,需要通过udp包的来源地址作为nat地址 - let session = Session(dstMac: register.srcMac, natAddress: remoteAddress) - await self.sessionManager.addSession(session: session) - } else { - self.logger.log("SDLContext didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)", level: .warning) - } - case .registerAck(let registerAck): - // 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下 - if registerAck.dstMac == self.devAddr.mac && registerAck.networkID == self.devAddr.networkID { - let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress) - await self.sessionManager.addSession(session: session) - } else { - self.logger.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning) - } - case .stunReply(let stunReply): - let cookie = stunReply.cookie - if cookie == self.lastCookie { - // 记录下当前在nat上的映射信息,暂时没有用;后续会用来判断网络类型 - //self.natAddress = stunReply.natAddress - self.logger.log("[SDLContext] get a stunReply: \(try! stunReply.jsonString())", level: .debug) - } - default: - () - } - - case .data(let data): - let mac = LayerPacket.MacAddress(data: data.dstMac) - guard (data.dstMac == self.devAddr.mac || mac.isBroadcast() || mac.isMulticast()) else { - return - } - - guard let decyptedData = try? self.aesCipher.decypt(aesKey: self.aesKey, data: Data(data.data)) else { - return - } - - do { - let layerPacket = try LayerPacket(layerData: decyptedData) - - await self.flowTracer.inc(num: decyptedData.count, type: .inbound) - // 处理arp请求 - switch layerPacket.type { - case .arp: - // 判断如果收到的是arp请求 - if let arpPacket = ARPPacket(data: layerPacket.data) { - if arpPacket.targetIP == self.devAddr.netAddr { - switch arpPacket.opcode { - case .request: - self.logger.log("[SDLContext] get arp request packet", level: .debug) - let response = ARPPacket.arpResponse(for: arpPacket, mac: self.devAddr.mac, ip: self.devAddr.netAddr) - await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal()) - case .response: - self.logger.log("[SDLContext] get arp response packet", level: .debug) - await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC) - } - } else { - self.logger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(self.devAddr.netAddr))", level: .debug) - } - } else { - self.logger.log("[SDLContext] get invalid arp packet", level: .debug) - } - case .ipv4: - guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == self.devAddr.netAddr else { - return - } - let packet = NEPacket(data: ipPacket.data, protocolFamily: 2) - await self.providerActor.writePackets(packets: [packet]) - default: - self.logger.log("[SDLContext] get invalid packet", level: .debug) - } - } catch let err { - self.logger.log("[SDLContext] didReadData err: \(err)", level: .warning) - } - } - - } - - // 流量统计 -// public func flowReportTask() { -// Task { -// // 每分钟汇报一次 -// self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect() -// .sink { _ in -// Task { -// let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset() -// await self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum) -// } -// } -// } -// } - - // 开始读取数据, 用单独的线程处理packetFlow - private func startReader() { - // 停止之前的任务 - self.readTask?.cancel() - - // 开启新的任务 - self.readTask = Task(priority: .high) { - repeat { - let packets = await self.providerActor.readPackets() - for packet in packets { - await self.dealPacket(data: packet) - } - } while true - } - } - - // 处理读取的每个数据包 - private func dealPacket(data: Data) async { - guard let packet = IPPacket(data) else { - return - } - - if DNSClient.Helper.isDnsRequestPacket(ipPacket: packet) { - let destIp = packet.header.destination_ip - NSLog("destIp: \(destIp), int: \(packet.header.destination)") - await self.dnsClient?.forward(ipPacket: packet) - } - else { - Task.detached { - let dstIp = packet.header.destination - // 本地通讯, 目标地址是本地服务器的ip地址 - if dstIp == self.devAddr.netAddr { - let nePacket = NEPacket(data: packet.data, protocolFamily: 2) - await self.providerActor.writePackets(packets: [nePacket]) - return - } - - // 查找arp缓存中是否有目标mac地址 - if let dstMac = await self.arpServer.query(ip: dstIp) { - await self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data) - } - else { - // 构造arp请求 - let broadcastMac = Data([0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF]) - let arpReqeust = ARPPacket.arpRequest(senderIP: self.devAddr.netAddr, senderMAC: self.devAddr.mac, targetIP: dstIp) - await self.routeLayerPacket(dstMac: broadcastMac, type: .arp, data: arpReqeust.marshal()) - - self.logger.log("[SDLContext] dstIp: \(dstIp) arp query not found", level: .debug) - } - } - } - } - - private func routeLayerPacket(dstMac: Data, type: LayerPacket.PacketType, data: Data) async { - // 将数据封装层2层的数据包 - let layerPacket = LayerPacket(dstMac: dstMac, srcMac: self.devAddr.mac, type: type, data: data) - guard let encodedPacket = try? self.aesCipher.encrypt(aesKey: self.aesKey, data: layerPacket.marshal()) else { - return - } - - // 构造数据包 - var dataPacket = SDLData() - dataPacket.networkID = self.devAddr.networkID - dataPacket.srcMac = self.devAddr.mac - dataPacket.dstMac = dstMac - dataPacket.ttl = 255 - dataPacket.data = encodedPacket - - // 通过session发送到对端 - if let session = await self.sessionManager.getSession(toAddress: dstMac) { - self.logger.log("[SDLContext] send packet by session: \(session)", level: .debug) - await self.udpHoleActor?.send(type: .data, data: try! dataPacket.serializedData(), remoteAddress: session.natAddress) - await self.flowTracer.inc(num: data.count, type: .p2p) - } - else { - // 通过super_node进行转发 - let superAddress = self.config.stunSocketAddress - await self.udpHoleActor?.send(type: .data, data: try! dataPacket.serializedData(), remoteAddress: superAddress) - // 流量统计 - await self.flowTracer.inc(num: data.count, type: .forward) - - // 尝试打洞 - let registerRequest = RegisterRequest(srcMac: self.devAddr.mac, dstMac: dstMac, networkId: self.devAddr.networkID) - self.submitRegisterRequest(request: registerRequest) - } - } - - private func submitRegisterRequest(request: RegisterRequest) { - self.locker.lock() - defer { - self.locker.unlock() - } - - let dstMac = request.dstMac - if let publisher = self.holerPublishers[dstMac] { - publisher.send(request) - } else { - let publisher = PassthroughSubject() - publisher.throttle(for: .seconds(5), scheduler: DispatchQueue.global(), latest: true) - .sink { request in - Task { - await self.tryHole(request: request) - } - } - .store(in: &self.bag) - - self.holerPublishers[dstMac] = publisher - } - } - - private func tryHole(request: RegisterRequest) async { - var queryInfo = SDLQueryInfo() - queryInfo.dstMac = request.dstMac - - guard let message = try? await self.superClientActor?.request(type: .queryInfo, data: try queryInfo.serializedData()).get() else { - return - } - - switch message.packet { - case .empty: - self.logger.log("[SDLContext] hole query_info get empty: \(message)", level: .debug) - case .peerInfo(let peerInfo): - if let remoteAddress = peerInfo.v4Info.socketAddress() { - self.logger.log("[SDLContext] hole sock address: \(remoteAddress)", level: .debug) - // 发送register包 - var register = SDLRegister() - register.networkID = request.networkId - register.srcMac = request.srcMac - register.dstMac = request.dstMac - - await self.udpHoleActor?.send(type: .register, data: try! register.serializedData(), remoteAddress: remoteAddress) - } else { - self.logger.log("[SDLContext] hole sock address is invalid: \(peerInfo.v4Info)", level: .warning) - } - default: - self.logger.log("[SDLContext] hole query_info is packet: \(message)", level: .warning) - } - } - - deinit { - self.rootTask?.cancel() - self.udpHoleActor = nil - self.superClientActor = nil - self.dnsClient = nil - } - - // 获取mac地址 - public static func getMacAddress() -> Data { - let key = "gMacAddress2" - - let userDefaults = UserDefaults.standard - if let mac = userDefaults.value(forKey: key) as? Data { - return mac - } - else { - let mac = generateMacAddress() - userDefaults.setValue(mac, forKey: key) - - return mac - } - } - - // 随机生成mac地址 - private static func generateMacAddress() -> Data { - var macAddress = [UInt8](repeating: 0, count: 6) - for i in 0..<6 { - macAddress[i] = UInt8.random(in: 0...255) - } - - return Data(macAddress) - } - -} diff --git a/Tun/Punchnet/Actors/SDLUDPHoleActor.swift b/Tun/Punchnet/Actors/SDLUDPHoleActor.swift index 667c313..497d8ad 100644 --- a/Tun/Punchnet/Actors/SDLUDPHoleActor.swift +++ b/Tun/Punchnet/Actors/SDLUDPHoleActor.swift @@ -135,24 +135,24 @@ actor SDLUDPHoleActor { return self.cookieGenerator.nextId() } -// // 探测tun信息 -// func stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int = 5) async throws -> SDLStunProbeReply { -// return try await self._stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: timeout).get() -// } -// -// private func _stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int) -> EventLoopFuture { -// let cookie = self.cookieGenerator.nextId() -// var stunProbe = SDLStunProbe() -// stunProbe.cookie = cookie -// stunProbe.attr = UInt32(attr.rawValue) -// self.send( type: .stunProbe, data: try! stunProbe.serializedData(), remoteAddress: remoteAddress) -// self.logger.log("[SDLUDPHole] stunProbe: \(remoteAddress)", level: .debug) -// -// let promise = self.asyncChannel.channel.eventLoop.makePromise(of: SDLStunProbeReply.self) -// self.promises[cookie] = promise -// -// return promise.futureResult -// } + // 探测tun信息 + func stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int = 5) async throws -> SDLStunProbeReply { + return try await self._stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: timeout).get() + } + + private func _stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int) -> EventLoopFuture { + let cookie = self.cookieGenerator.nextId() + var stunProbe = SDLStunProbe() + stunProbe.cookie = cookie + stunProbe.attr = UInt32(attr.rawValue) + self.send( type: .stunProbe, data: try! stunProbe.serializedData(), remoteAddress: remoteAddress) + self.logger.log("[SDLUDPHole] stunProbe: \(remoteAddress)", level: .debug) + + let promise = self.asyncChannel.channel.eventLoop.makePromise(of: SDLStunProbeReply.self) + self.promises[cookie] = promise + + return promise.futureResult + } private func trigger(probeReply: SDLStunProbeReply) { let id = probeReply.cookie diff --git a/Tun/Punchnet/SDLContext.swift b/Tun/Punchnet/SDLContext.swift index 70a9f26..10b37e6 100644 --- a/Tun/Punchnet/SDLContext.swift +++ b/Tun/Punchnet/SDLContext.swift @@ -16,7 +16,7 @@ import Combine */ @available(macOS 14, *) -public class SDLContext: @unchecked Sendable { +public class SDLContext { // 路由信息 struct Route { @@ -36,7 +36,7 @@ public class SDLContext: @unchecked Sendable { // nat映射的相关信息, 暂时没有用处 //var natAddress: SDLNatAddress? // nat的网络类型 - var natType: SDLNatProber.NatType = .blocked + var natType: NatType = .blocked // AES加密,授权通过后,对象才会被创建 var aesCipher: AESCipher @@ -48,8 +48,9 @@ public class SDLContext: @unchecked Sendable { let rsaCipher: RSACipher // 依赖的变量 - var udpHole: SDLUDPHole? - var superClient: SDLSuperClient? + var udpHoleActor: SDLUDPHoleActor? + var superClientActor: SDLSuperClientActor? + var providerActor: SDLTunnelProviderActor // dns的client对象 var dnsClient: DNSClient? @@ -57,8 +58,6 @@ public class SDLContext: @unchecked Sendable { // 数据包读取任务 private var readTask: Task<(), Never>? - let provider: NEPacketTunnelProvider - private var sessionManager: SessionManager private var arpServer: ArpServer @@ -81,7 +80,6 @@ public class SDLContext: @unchecked Sendable { private var locker = NSLock() private let logger: SDLLogger - private var rootTask: Task? struct RegisterRequest { @@ -102,15 +100,14 @@ public class SDLContext: @unchecked Sendable { devAddr.mac = Self.getMacAddress() self.devAddr = devAddr - self.provider = provider self.sessionManager = SessionManager() self.arpServer = ArpServer(known_macs: [:]) + self.providerActor = SDLTunnelProviderActor(provider: provider, logger: logger) } public func start() async throws { self.rootTask = Task { try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { while !Task.isCancelled { do { @@ -169,8 +166,8 @@ public class SDLContext: @unchecked Sendable { public func stop() async { self.rootTask?.cancel() - self.superClient = nil - self.udpHole = nil + self.superClientActor = nil + self.udpHoleActor = nil self.noticeClient = nil self.readTask?.cancel() @@ -183,22 +180,39 @@ public class SDLContext: @unchecked Sendable { } private func startUDPHole() async throws { - self.udpHole = try await SDLUDPHole(logger: self.logger) + self.udpHoleActor = try await SDLUDPHoleActor(logger: self.logger) try await withThrowingTaskGroup(of: Void.self) { group in group.addTask { - try await self.udpHole?.start() + try await self.udpHoleActor?.start() } group.addTask { while !Task.isCancelled { + try Task.checkCancellation() try await Task.sleep(nanoseconds: 5 * 1_000_000_000) - self.lastCookie = await self.udpHole?.stunRequest(context: self) + try Task.checkCancellation() + + if let udpHoleActor = self.udpHoleActor { + let cookie = await udpHoleActor.getCookieId() + var stunRequest = SDLStunRequest() + stunRequest.cookie = cookie + stunRequest.clientID = self.config.clientId + stunRequest.networkID = self.devAddr.networkID + stunRequest.ip = self.devAddr.netAddr + stunRequest.mac = self.devAddr.mac + stunRequest.natType = UInt32(self.natType.rawValue) + + let remoteAddress = self.config.stunSocketAddress + await udpHoleActor.send(type: .stunRequest, data: try stunRequest.serializedData(), remoteAddress: remoteAddress) + self.lastCookie = cookie + } } } group.addTask { - if let eventFlow = self.udpHole?.eventFlow { + if let eventFlow = self.udpHoleActor?.eventFlow { for try await event in eventFlow { + try Task.checkCancellation() try await self.handleUDPEvent(event: event) } } @@ -212,18 +226,18 @@ public class SDLContext: @unchecked Sendable { } private func startSuperClient() async throws { - self.superClient = try await SDLSuperClient(host: self.config.superHost, port: self.config.superPort, logger: self.logger) + self.superClientActor = try await SDLSuperClientActor(host: self.config.superHost, port: self.config.superPort, logger: self.logger) try await withThrowingTaskGroup(of: Void.self) { group in defer { self.logger.log("[SDLContext] super client task cancel", level: .warning) } group.addTask { - try await self.superClient?.start() + try await self.superClientActor?.start() } group.addTask { - if let eventFlow = self.superClient?.eventFlow { + if let eventFlow = self.superClientActor?.eventFlow { for try await event in eventFlow { try await self.handleSuperEvent(event: event) } @@ -241,8 +255,8 @@ public class SDLContext: @unchecked Sendable { for await event in self.monitor!.eventStream { switch event { case .changed: - // 需要重新探测网络的nat类型 - self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config, logger: self.logger) + // TODO 需要重新探测网络的nat类型 + //self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config, logger: self.logger) self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info) case .unreachable: self.logger.log("didNetworkPathUnreachable", level: .warning) @@ -268,7 +282,7 @@ public class SDLContext: @unchecked Sendable { if let packetFlow = self.dnsClient?.packetFlow { for await packet in packetFlow { let nePacket = NEPacket(data: packet, protocolFamily: 2) - self.provider.packetFlow.writePacketObjects([nePacket]) + await self.providerActor.writePackets(packets: [nePacket]) } } @@ -280,11 +294,19 @@ public class SDLContext: @unchecked Sendable { } } - private func handleSuperEvent(event: SDLSuperClient.SuperEvent) async throws { + private func handleSuperEvent(event: SDLSuperClientActor.SuperEvent) async throws { switch event { case .ready: self.logger.log("[SDLContext] get registerSuper, mac address: \(SDLUtil.formatMacAddress(mac: self.devAddr.mac))", level: .debug) - guard let message = try await self.superClient?.registerSuper(context: self).get() else { + var registerSuper = SDLRegisterSuper() + registerSuper.version = UInt32(self.config.version) + registerSuper.clientID = self.config.clientId + registerSuper.devAddr = self.devAddr + registerSuper.pubKey = self.rsaCipher.pubKey + registerSuper.token = self.config.token + registerSuper.networkCode = self.config.networkCode + registerSuper.hostname = self.config.hostname + guard let message = try await self.superClientActor?.request(type: .registerSuper, data: try registerSuper.serializedData()).get() else { return } @@ -304,9 +326,15 @@ public class SDLContext: @unchecked Sendable { } // 服务器分配的tun网卡信息 - await self.didNetworkConfigChanged(devAddr: self.devAddr, dnsServer: DNSClient.Helper.dnsServer) - self.aesKey = aesKey + do { + try await self.providerActor.setNetworkSettings(devAddr: self.devAddr, dnsServer: DNSClient.Helper.dnsServer) + self.startReader() + } catch let err { + self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error) + exit(-1) + } + self.aesKey = aesKey if upgradeType == .normal { let normalUpgrade = NoticeMessage.upgrade(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress) await self.noticeClient?.send(data: normalUpgrade) @@ -343,7 +371,11 @@ public class SDLContext: @unchecked Sendable { let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp) if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) { // 发送register包 - await self.udpHole?.sendRegister(remoteAddress: remoteAddress, networkId: self.devAddr.networkID, srcMac: self.devAddr.mac, dst_mac: sendRegisterEvent.dstMac) + var register = SDLRegister() + register.networkID = self.devAddr.networkID + register.srcMac = self.devAddr.mac + register.dstMac = sendRegisterEvent.dstMac + await self.udpHoleActor?.send(type: .register, data: try register.serializedData(), remoteAddress: remoteAddress) } case .networkShutdown(let shutdownEvent): @@ -360,19 +392,25 @@ public class SDLContext: @unchecked Sendable { self.devAddr = changeNetworkCommand.devAddr // 服务器分配的tun网卡信息 - await self.didNetworkConfigChanged(devAddr: self.devAddr, dnsServer: DNSClient.Helper.dnsServer) + do { + try await self.providerActor.setNetworkSettings(devAddr: self.devAddr, dnsServer: DNSClient.Helper.dnsServer) + self.startReader() + } catch let err { + self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error) + exit(-1) + } + self.aesKey = aesKey var commandAck = SDLCommandAck() commandAck.status = true - - await self.superClient?.commandAck(packetId: packetId, ack: commandAck) + await self.superClientActor?.send(type: .commandAck, packetId: packetId, data: try commandAck.serializedData()) } } } - private func handleUDPEvent(event: SDLUDPHole.UDPEvent) async throws { + private func handleUDPEvent(event: SDLUDPHoleActor.UDPEvent) async throws { switch event { case .ready: // 获取当前网络的类型 @@ -386,7 +424,12 @@ public class SDLContext: @unchecked Sendable { // 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下 if register.dstMac == self.devAddr.mac && register.networkID == self.devAddr.networkID { // 回复ack包 - await self.udpHole?.sendRegisterAck(context: self, remoteAddress: remoteAddress, dst_mac: register.srcMac) + var registerAck = SDLRegisterAck() + registerAck.networkID = self.devAddr.networkID + registerAck.srcMac = self.devAddr.mac + registerAck.dstMac = register.srcMac + + await self.udpHoleActor?.send(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress) // 这里需要建立到来源的会话, 在复杂网络下,通过super-node查询到的nat地址不一定靠谱,需要通过udp包的来源地址作为nat地址 let session = Session(dstMac: register.srcMac, natAddress: remoteAddress) await self.sessionManager.addSession(session: session) @@ -452,7 +495,7 @@ public class SDLContext: @unchecked Sendable { return } let packet = NEPacket(data: ipPacket.data, protocolFamily: 2) - self.provider.packetFlow.writePacketObjects([packet]) + await self.providerActor.writePackets(packets: [packet]) default: self.logger.log("[SDLContext] get invalid packet", level: .debug) } @@ -477,50 +520,6 @@ public class SDLContext: @unchecked Sendable { // } // } - // 网络改变时需要重新配置网络信息 - private func didNetworkConfigChanged(devAddr: SDLDevAddr, dnsServer: String) async { - let netAddress = SDLNetAddress(ip: devAddr.netAddr, maskLen: UInt8(devAddr.netBitLen)) - let routes = [ - Route(dstAddress: netAddress.networkAddress, subnetMask: netAddress.maskAddress), - Route(dstAddress: dnsServer, subnetMask: "255.255.255.255") - ] - - // Add code here to start the process of connecting the tunnel. - let networkSettings = NEPacketTunnelNetworkSettings(tunnelRemoteAddress: "8.8.8.8") - networkSettings.mtu = 1460 - - // 设置网卡的DNS解析 - - let networkDomain = devAddr.networkDomain - let dnsSettings = NEDNSSettings(servers: [dnsServer]) - dnsSettings.searchDomains = [networkDomain] - dnsSettings.matchDomains = [networkDomain] - dnsSettings.matchDomainsNoSearch = false - networkSettings.dnsSettings = dnsSettings - self.logger.log("[SDLContext] Tun started at network ip: \(netAddress.ipAddress), mask: \(netAddress.maskAddress)", level: .info) - - let ipv4Settings = NEIPv4Settings(addresses: [netAddress.ipAddress], subnetMasks: [netAddress.maskAddress]) - // 设置路由表 - //NEIPv4Route.default() - ipv4Settings.includedRoutes = routes.map { route in - NEIPv4Route(destinationAddress: route.dstAddress, subnetMask: route.subnetMask) - } - networkSettings.ipv4Settings = ipv4Settings - // 网卡配置设置必须成功 - do { - try await self.provider.setTunnelNetworkSettings(networkSettings) - self.startReader() - - let ipMessage = NoticeMessage.ipAdress(ip: netAddress.ipAddress) - await self.noticeClient?.send(data: ipMessage) - - self.logger.log("[SDLContext] setTunnelNetworkSettings success, start read packet", level: .info) - } catch let err { - self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error) - exit(-1) - } - } - // 开始读取数据, 用单独的线程处理packetFlow private func startReader() { // 停止之前的任务 @@ -529,9 +528,9 @@ public class SDLContext: @unchecked Sendable { // 开启新的任务 self.readTask = Task(priority: .high) { repeat { - let (packets, numbers) = await self.provider.packetFlow.readPackets() - for (data, number) in zip(packets, numbers) where number == 2 { - await self.dealPacket(data: data) + let packets = await self.providerActor.readPackets() + for packet in packets { + await self.dealPacket(data: packet) } } while true } @@ -554,7 +553,7 @@ public class SDLContext: @unchecked Sendable { // 本地通讯, 目标地址是本地服务器的ip地址 if dstIp == self.devAddr.netAddr { let nePacket = NEPacket(data: packet.data, protocolFamily: 2) - self.provider.packetFlow.writePacketObjects([nePacket]) + await self.providerActor.writePackets(packets: [nePacket]) return } @@ -581,16 +580,24 @@ public class SDLContext: @unchecked Sendable { return } + // 构造数据包 + var dataPacket = SDLData() + dataPacket.networkID = self.devAddr.networkID + dataPacket.srcMac = self.devAddr.mac + dataPacket.dstMac = dstMac + dataPacket.ttl = 255 + dataPacket.data = encodedPacket + // 通过session发送到对端 if let session = await self.sessionManager.getSession(toAddress: dstMac) { self.logger.log("[SDLContext] send packet by session: \(session)", level: .debug) - await self.udpHole?.sendPacket(context: self, session: session, data: encodedPacket) - + await self.udpHoleActor?.send(type: .data, data: try! dataPacket.serializedData(), remoteAddress: session.natAddress) await self.flowTracer.inc(num: data.count, type: .p2p) } else { // 通过super_node进行转发 - await self.udpHole?.forwardPacket(context: self, dst_mac: dstMac, data: encodedPacket) + let superAddress = self.config.stunSocketAddress + await self.udpHoleActor?.send(type: .data, data: try! dataPacket.serializedData(), remoteAddress: superAddress) // 流量统计 await self.flowTracer.inc(num: data.count, type: .forward) @@ -624,7 +631,10 @@ public class SDLContext: @unchecked Sendable { } private func tryHole(request: RegisterRequest) async { - guard let message = try? await self.superClient?.queryInfo(dst_mac: request.dstMac).get() else { + var queryInfo = SDLQueryInfo() + queryInfo.dstMac = request.dstMac + + guard let message = try? await self.superClientActor?.request(type: .queryInfo, data: try queryInfo.serializedData()).get() else { return } @@ -635,7 +645,12 @@ public class SDLContext: @unchecked Sendable { if let remoteAddress = peerInfo.v4Info.socketAddress() { self.logger.log("[SDLContext] hole sock address: \(remoteAddress)", level: .debug) // 发送register包 - await self.udpHole?.sendRegister(remoteAddress: remoteAddress, networkId: request.networkId, srcMac: request.srcMac, dst_mac: request.dstMac) + var register = SDLRegister() + register.networkID = request.networkId + register.srcMac = request.srcMac + register.dstMac = request.dstMac + + await self.udpHoleActor?.send(type: .register, data: try! register.serializedData(), remoteAddress: remoteAddress) } else { self.logger.log("[SDLContext] hole sock address is invalid: \(peerInfo.v4Info)", level: .warning) } @@ -646,8 +661,8 @@ public class SDLContext: @unchecked Sendable { deinit { self.rootTask?.cancel() - self.udpHole = nil - self.superClient = nil + self.udpHoleActor = nil + self.superClientActor = nil self.dnsClient = nil } @@ -678,3 +693,68 @@ public class SDLContext: @unchecked Sendable { } } + +// 网络类型探测 +extension SDLContext { + // 定义nat类型 + enum NatType: UInt8, Encodable { + case blocked = 0 + case noNat = 1 + case fullCone = 2 + case portRestricted = 3 + case coneRestricted = 4 + case symmetric = 5 + } + + // 获取当前所处的网络的nat类型 + func getNatType() async -> NatType { + guard let udpHole = self.udpHoleActor else { + return .blocked + } + + let addressArray = config.stunProbeSocketAddressArray + // step1: ip1:port1 <---- ip1:port1 + guard let natAddress1 = await getNatAddress(udpHole, remoteAddress: addressArray[0][0], attr: .none) else { + return .blocked + } + + // 网络没有在nat下 + if await natAddress1 == udpHole.localAddress { + return .noNat + } + + // step2: ip2:port2 <---- ip2:port2 + guard let natAddress2 = await getNatAddress(udpHole, remoteAddress: addressArray[1][1], attr: .none) else { + return .blocked + } + + // 如果natAddress2 的IP地址与上次回来的IP是不一样的,它就是对称型NAT; 这次的包也一定能发成功并收到 + // 如果ip地址变了,这说明{dstIp, dstPort, srcIp, srcPort}, 其中有一个变了;则用新的ip地址 + logger.log("[SDLNatProber] nat_address1: \(natAddress1), nat_address2: \(natAddress2)", level: .debug) + if let ipAddress1 = natAddress1.ipAddress, let ipAddress2 = natAddress2.ipAddress, ipAddress1 != ipAddress2 { + return .symmetric + } + + // step3: ip1:port1 <---- ip2:port2 (ip地址和port都变的情况) + // 如果能收到的,说明是完全锥形 说明是IP地址限制锥型NAT,如果不能收到说明是端口限制锥型。 + if let natAddress3 = await getNatAddress(udpHole, remoteAddress: addressArray[0][0], attr: .peer) { + logger.log("[SDLNatProber] nat_address1: \(natAddress1), nat_address2: \(natAddress2), nat_address3: \(natAddress3)", level: .debug) + return .fullCone + } + + // step3: ip1:port1 <---- ip1:port2 (port改变情况) + // 如果能收到的说明是IP地址限制锥型NAT,如果不能收到说明是端口限制锥型。 + if let natAddress4 = await getNatAddress(udpHole, remoteAddress: addressArray[0][0], attr: .port) { + logger.log("[SDLNatProber] nat_address1: \(natAddress1), nat_address2: \(natAddress2), nat_address4: \(natAddress4)", level: .debug) + return .coneRestricted + } else { + return .portRestricted + } + } + + private func getNatAddress(_ udpHole: SDLUDPHoleActor, remoteAddress: SocketAddress, attr: SDLProbeAttr) async -> SocketAddress? { + let stunProbeReply = try? await udpHole.stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: 5) + return stunProbeReply?.socketAddress() + } + +} diff --git a/Tun/Punchnet/SDLNatProber.swift b/Tun/Punchnet/SDLNatProber.swift deleted file mode 100644 index b61fe93..0000000 --- a/Tun/Punchnet/SDLNatProber.swift +++ /dev/null @@ -1,77 +0,0 @@ -// -// File.swift -// sdlan -// -// Created by 安礼成 on 2025/7/14. -// - -import Foundation -import NIOCore - -// 网络类型探测器 -@available(macOS 14, *) -struct SDLNatProber { - - // 定义nat类型 - enum NatType: UInt8, Encodable { - case blocked = 0 - case noNat = 1 - case fullCone = 2 - case portRestricted = 3 - case coneRestricted = 4 - case symmetric = 5 - } - - // 获取当前所处的网络的nat类型 - static func getNatType(udpHole: SDLUDPHole?, config: SDLConfiguration, logger: SDLLogger) async -> NatType { - guard let udpHole else { - return .blocked - } - - let addressArray = config.stunProbeSocketAddressArray - // step1: ip1:port1 <---- ip1:port1 - guard let natAddress1 = await getNatAddress(udpHole, remoteAddress: addressArray[0][0], attr: .none) else { - return .blocked - } - - // 网络没有在nat下 - if await natAddress1 == udpHole.localAddress { - return .noNat - } - - // step2: ip2:port2 <---- ip2:port2 - guard let natAddress2 = await getNatAddress(udpHole, remoteAddress: addressArray[1][1], attr: .none) else { - return .blocked - } - - // 如果natAddress2 的IP地址与上次回来的IP是不一样的,它就是对称型NAT; 这次的包也一定能发成功并收到 - // 如果ip地址变了,这说明{dstIp, dstPort, srcIp, srcPort}, 其中有一个变了;则用新的ip地址 - logger.log("[SDLNatProber] nat_address1: \(natAddress1), nat_address2: \(natAddress2)", level: .debug) - if let ipAddress1 = natAddress1.ipAddress, let ipAddress2 = natAddress2.ipAddress, ipAddress1 != ipAddress2 { - return .symmetric - } - - // step3: ip1:port1 <---- ip2:port2 (ip地址和port都变的情况) - // 如果能收到的,说明是完全锥形 说明是IP地址限制锥型NAT,如果不能收到说明是端口限制锥型。 - if let natAddress3 = await getNatAddress(udpHole, remoteAddress: addressArray[0][0], attr: .peer) { - logger.log("[SDLNatProber] nat_address1: \(natAddress1), nat_address2: \(natAddress2), nat_address3: \(natAddress3)", level: .debug) - return .fullCone - } - - // step3: ip1:port1 <---- ip1:port2 (port改变情况) - // 如果能收到的说明是IP地址限制锥型NAT,如果不能收到说明是端口限制锥型。 - if let natAddress4 = await getNatAddress(udpHole, remoteAddress: addressArray[0][0], attr: .port) { - logger.log("[SDLNatProber] nat_address1: \(natAddress1), nat_address2: \(natAddress2), nat_address4: \(natAddress4)", level: .debug) - return .coneRestricted - } else { - return .portRestricted - } - } - - private static func getNatAddress(_ udpHole: SDLUDPHole, remoteAddress: SocketAddress, attr: SDLProbeAttr) async -> SocketAddress? { - let stunProbeReply = try? await udpHole.stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: 5) - - return stunProbeReply?.socketAddress() - } - -} diff --git a/Tun/Punchnet/SDLSuperClient.swift b/Tun/Punchnet/SDLSuperClient.swift deleted file mode 100644 index ef10ff9..0000000 --- a/Tun/Punchnet/SDLSuperClient.swift +++ /dev/null @@ -1,330 +0,0 @@ -// -// SDLWebsocketClient.swift -// Tun -// -// Created by 安礼成 on 2024/3/28. -// - -import Foundation -import NIOCore -import NIOPosix - -// --MARK: 和SuperNode的客户端 -@available(macOS 14, *) -actor SDLSuperClient { - private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - private let asyncChannel: NIOAsyncChannel - private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: TcpMessage.self, bufferingPolicy: .unbounded) - private var callbackPromises: [UInt32:EventLoopPromise] = [:] - - public let eventFlow: AsyncStream - private let inboundContinuation: AsyncStream.Continuation - - // id生成器 - var idGenerator = SDLIdGenerator(seed: 1) - - private let logger: SDLLogger - // 发送的消息格式 - struct TcpMessage { - let packetId: UInt32 - let type: SDLPacketType - let data: Data - } - - // 定义事件类型 - enum SuperEvent { - case ready - case event(SDLEvent) - case command(UInt32, SDLCommand) - } - - init(host: String, port: Int, logger: SDLLogger) async throws { - self.logger = logger - - (self.eventFlow, self.inboundContinuation) = AsyncStream.makeStream(of: SuperEvent.self, bufferingPolicy: .unbounded) - let bootstrap = ClientBootstrap(group: self.group) - .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .channelInitializer { channel in - return channel.pipeline.addHandlers([ - ByteToMessageHandler(FixedHeaderDecoder()), - MessageToByteHandler(FixedHeaderEncoder()) - ]) - } - - self.asyncChannel = try await bootstrap.connect(host: host, port: port) - .flatMapThrowing { channel in - return try NIOAsyncChannel(wrappingChannelSynchronously: channel, configuration: .init( - inboundType: ByteBuffer.self, - outboundType: ByteBuffer.self - )) - } - .get() - } - - func start() async throws { - try await withTaskCancellationHandler { - try await self.asyncChannel.executeThenClose { inbound, outbound in - self.inboundContinuation.yield(.ready) - - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - defer { - self.logger.log("[SDLSuperClient] inbound closed", level: .warning) - } - - for try await var packet in inbound { - try Task.checkCancellation() - - if let message = SDLSuperClientDecoder.decode(buffer: &packet) { - self.logger.log("[SDLSuperTransport] read message: \(message)", level: .debug) - switch message.packet { - case .event(let event): - self.inboundContinuation.yield(.event(event)) - case .command(let command): - self.inboundContinuation.yield(.command(message.msgId, command)) - default: - await self.fireCallback(message: message) - } - } - } - } - - group.addTask { - defer { - self.logger.log("[SDLSuperClient] outbound closed", level: .warning) - } - - for await message in self.writeStream { - try Task.checkCancellation() - - var buffer = self.asyncChannel.channel.allocator.buffer(capacity: message.data.count + 5) - buffer.writeInteger(message.packetId, as: UInt32.self) - buffer.writeBytes([message.type.rawValue]) - buffer.writeBytes(message.data) - try await outbound.write(buffer) - } - } - - // --MARK: 心跳机制 - group.addTask { - defer { - self.logger.log("[SDLSuperClient] ping task closed", level: .warning) - } - - while true { - try Task.checkCancellation() - await self.ping() - try await Task.sleep(nanoseconds: 5 * 1_000_000_000) - } - } - - // 迭代等待所有任务的退出, 第一个异常会被抛出 - if let _ = try await group.next() { - group.cancelAll() - } - } - } - } onCancel: { - self.inboundContinuation.finish() - self.writeContinuation.finish() - self.logger.log("[SDLSuperClient] withTaskCancellationHandler cancel") - } - } - - // -- MARK: apis - - func commandAck(packetId: UInt32, ack: SDLCommandAck) { - guard let data = try? ack.serializedData() else { - return - } - - self.send(type: .commandAck, packetId: packetId, data: data) - } - - func registerSuper(context ctx: SDLContext) throws -> EventLoopFuture { - var registerSuper = SDLRegisterSuper() - registerSuper.version = UInt32(ctx.config.version) - registerSuper.clientID = ctx.config.clientId - registerSuper.devAddr = ctx.devAddr - registerSuper.pubKey = ctx.rsaCipher.pubKey - registerSuper.token = ctx.config.token - registerSuper.networkCode = ctx.config.networkCode - registerSuper.hostname = ctx.config.hostname - - NSLog("[SuperClient] register super request: \(registerSuper)") - - let data = try! registerSuper.serializedData() - - return self.write(type: .registerSuper, data: data) - } - - // 查询目标服务器的相关信息 - func queryInfo(dst_mac: Data) async throws -> EventLoopFuture { - var queryInfo = SDLQueryInfo() - queryInfo.dstMac = dst_mac - - return self.write(type: .queryInfo, data: try! queryInfo.serializedData()) - } - - func unregister(context ctx: SDLContext) throws { - self.send(type: .unregisterSuper, packetId: 0, data: Data()) - } - - func ping() { - self.send(type: .ping, packetId: 0, data: Data()) - } - - func flowReport(forwardNum: UInt32, p2pNum: UInt32, inboundNum: UInt32) { - var flow = SDLFlows() - flow.forwardNum = forwardNum - flow.p2PNum = p2pNum - flow.inboundNum = inboundNum - - self.send(type: .flowTracer, packetId: 0, data: try! flow.serializedData()) - } - - private func write(type: SDLPacketType, data: Data) -> EventLoopFuture { - let packetId = idGenerator.nextId() - let promise = self.asyncChannel.channel.eventLoop.makePromise(of: SDLSuperInboundMessage.self) - self.callbackPromises[packetId] = promise - - self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data)) - - return promise.futureResult - } - - private func send(type: SDLPacketType, packetId: UInt32, data: Data) { - self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data)) - } - - // 处理回调函数 - private func fireCallback(message: SDLSuperInboundMessage) { - if let promise = self.callbackPromises[message.msgId] { - self.asyncChannel.channel.eventLoop.execute { - promise.succeed(message) - } - self.callbackPromises.removeValue(forKey: message.msgId) - } - } - - deinit { - try! group.syncShutdownGracefully() - } - -} - -// --MARK: 编解码器 -private struct SDLSuperClientDecoder { - // 消息格式为: <> - static func decode(buffer: inout ByteBuffer) -> SDLSuperInboundMessage? { - guard let msgId = buffer.readInteger(as: UInt32.self), - let type = buffer.readInteger(as: UInt8.self), - let messageType = SDLPacketType(rawValue: type) else { - return nil - } - - switch messageType { - case .empty: - return .init(msgId: msgId, packet: .empty) - case .registerSuperAck: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else { - return nil - } - return .init(msgId: msgId, packet: .registerSuperAck(registerSuperAck)) - - case .registerSuperNak: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else { - return nil - } - return .init(msgId: msgId, packet: .registerSuperNak(registerSuperNak)) - - case .peerInfo: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else { - return nil - } - - return .init(msgId: msgId, packet: .peerInfo(peerInfo)) - case .pong: - return .init(msgId: msgId, packet: .pong) - - case .command: - guard let commandVal = buffer.readInteger(as: UInt8.self), - let command = SDLCommandType(rawValue: commandVal), - let bytes = buffer.readBytes(length: buffer.readableBytes) else { - return nil - } - - switch command { - case .changeNetwork: - guard let changeNetworkCommand = try? SDLChangeNetworkCommand(serializedBytes: bytes) else { - return nil - } - - return .init(msgId: msgId, packet: .command(.changeNetwork(changeNetworkCommand))) - } - - case .event: - guard let eventVal = buffer.readInteger(as: UInt8.self), - let event = SDLEventType(rawValue: eventVal), - let bytes = buffer.readBytes(length: buffer.readableBytes) else { - return nil - } - - switch event { - case .natChanged: - guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else { - return nil - } - return .init(msgId: msgId, packet: .event(.natChanged(natChangedEvent))) - case .sendRegister: - guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else { - return nil - } - return .init(msgId: msgId, packet: .event(.sendRegister(sendRegisterEvent))) - case .networkShutdown: - guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else { - return nil - } - return .init(msgId: msgId, packet: .event(.networkShutdown(networkShutdownEvent))) - } - - default: - return nil - } - } -} - -private final class FixedHeaderEncoder: MessageToByteEncoder, @unchecked Sendable { - typealias InboundIn = ByteBuffer - typealias InboundOut = ByteBuffer - - func encode(data: ByteBuffer, out: inout ByteBuffer) throws { - let len = data.readableBytes - out.writeInteger(UInt16(len)) - out.writeBytes(data.readableBytesView) - } -} - -private final class FixedHeaderDecoder: ByteToMessageDecoder, @unchecked Sendable { - typealias InboundIn = ByteBuffer - typealias InboundOut = ByteBuffer - - func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState { - guard let len = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) else { - return .needMoreData - } - - if buffer.readableBytes >= len + 2 { - buffer.moveReaderIndex(forwardBy: 2) - if let bytes = buffer.readBytes(length: Int(len)) { - context.fireChannelRead(self.wrapInboundOut(ByteBuffer(bytes: bytes))) - } - return .continue - } else { - return .needMoreData - } - } -} diff --git a/Tun/Punchnet/SDLUDPHole.swift b/Tun/Punchnet/SDLUDPHole.swift deleted file mode 100644 index e8b7aba..0000000 --- a/Tun/Punchnet/SDLUDPHole.swift +++ /dev/null @@ -1,286 +0,0 @@ -// -// SDLanServer.swift -// Tun -// -// Created by 安礼成 on 2024/1/31. -// - -import Foundation -import NIOCore -import NIOPosix - -// 处理和sn-server服务器之间的通讯 -@available(macOS 14, *) -actor SDLUDPHole { - private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - private let asyncChannel: NIOAsyncChannel, AddressedEnvelope> - private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: UDPMessage.self, bufferingPolicy: .unbounded) - - private var cookieGenerator = SDLIdGenerator(seed: 1) - private var promises: [UInt32:EventLoopPromise] = [:] - public var localAddress: SocketAddress? - - public let eventFlow: AsyncStream - private let eventContinuation: AsyncStream.Continuation - - private let logger: SDLLogger - - struct UDPMessage { - let remoteAddress: SocketAddress - let type: SDLPacketType - let data: Data - } - - // 定义事件类型 - enum UDPEvent { - case ready - case message(SocketAddress, SDLHoleInboundMessage) - case data(SDLData) - } - - // 启动函数 - init(logger: SDLLogger) async throws { - self.logger = logger - - (self.eventFlow, self.eventContinuation) = AsyncStream.makeStream(of: UDPEvent.self, bufferingPolicy: .unbounded) - - let bootstrap = DatagramBootstrap(group: group) - .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - - self.asyncChannel = try await bootstrap.bind(host: "0.0.0.0", port: 0) - .flatMapThrowing { channel in - return try NIOAsyncChannel(wrappingChannelSynchronously: channel, configuration: .init( - inboundType: AddressedEnvelope.self, - outboundType: AddressedEnvelope.self - )) - } - .get() - - self.localAddress = self.asyncChannel.channel.localAddress - self.logger.log("[UDPHole] started and listening on: \(self.localAddress!)", level: .debug) - } - - func start() async throws { - try await withTaskCancellationHandler { - try await self.asyncChannel.executeThenClose {inbound, outbound in - self.eventContinuation.yield(.ready) - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - defer { - self.logger.log("[SDLUDPHole] inbound closed", level: .warning) - } - - for try await envelope in inbound { - try Task.checkCancellation() - - var buffer = envelope.data - let remoteAddress = envelope.remoteAddress - do { - if let message = try Self.decode(buffer: &buffer) { - switch message { - case .data(let data): - self.logger.log("[SDLUDPHole] read data: \(data.format()), from: \(remoteAddress)", level: .debug) - self.eventContinuation.yield(.data(data)) - case .stunProbeReply(let probeReply): - // 执行并移除回调 - await self.trigger(probeReply: probeReply) - default: - self.eventContinuation.yield(.message(remoteAddress, message)) - } - } else { - self.logger.log("[SDLUDPHole] decode message, get null", level: .warning) - } - } catch let err { - self.logger.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning) - throw err - } - } - } - - group.addTask { - defer { - self.logger.log("[SDLUDPHole] outbound closed", level: .warning) - } - - for await message in self.writeStream { - try Task.checkCancellation() - - var buffer = self.asyncChannel.channel.allocator.buffer(capacity: message.data.count + 1) - buffer.writeBytes([message.type.rawValue]) - buffer.writeBytes(message.data) - - let envelope = AddressedEnvelope(remoteAddress: message.remoteAddress, data: buffer) - try await outbound.write(envelope) - } - } - - if let _ = try await group.next() { - group.cancelAll() - } - } - } - } onCancel: { - self.writeContinuation.finish() - self.eventContinuation.finish() - self.logger.log("[SDLUDPHole] withTaskCancellationHandler cancel") - } - } - - // MARK: super_node apis - - func stunRequest(context ctx: SDLContext) -> UInt32? { - guard ctx.devAddr.networkID > 0 else { - return nil - } - - let cookie = self.cookieGenerator.nextId() - let remoteAddress = ctx.config.stunSocketAddress - - var stunRequest = SDLStunRequest() - stunRequest.cookie = cookie - stunRequest.clientID = ctx.config.clientId - stunRequest.networkID = ctx.devAddr.networkID - stunRequest.ip = ctx.devAddr.netAddr - stunRequest.mac = ctx.devAddr.mac - stunRequest.natType = UInt32(ctx.natType.rawValue) - - self.logger.log("[SDLUDPHole] stunRequest: \(remoteAddress), host: \(ctx.config.stunServers[0].host):\(ctx.config.stunServers[0].ports[0]), network_id: \(ctx.devAddr.networkID)", level: .debug) - - self.send(remoteAddress: remoteAddress, type: .stunRequest, data: try! stunRequest.serializedData()) - - return cookie - } - - // 探测tun信息 - func stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int = 5) async throws -> SDLStunProbeReply { - return try await self._stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: timeout).get() - } - - private func _stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int) -> EventLoopFuture { - let cookie = self.cookieGenerator.nextId() - var stunProbe = SDLStunProbe() - stunProbe.cookie = cookie - stunProbe.attr = UInt32(attr.rawValue) - self.send(remoteAddress: remoteAddress, type: .stunProbe, data: try! stunProbe.serializedData()) - self.logger.log("[SDLUDPHole] stunProbe: \(remoteAddress)", level: .debug) - - let promise = self.asyncChannel.channel.eventLoop.makePromise(of: SDLStunProbeReply.self) - self.promises[cookie] = promise - - return promise.futureResult - } - - private func trigger(probeReply: SDLStunProbeReply) { - let id = probeReply.cookie - // 执行并移除回调 - if let promise = self.promises[id] { - self.asyncChannel.channel.eventLoop.execute { - promise.succeed(probeReply) - } - self.promises.removeValue(forKey: id) - } - } - - // MARK: client-client apis - - // 发送数据包到其他session - func sendPacket(context ctx: SDLContext, session: Session, data: Data) { - let remoteAddress = session.natAddress - - var dataPacket = SDLData() - dataPacket.networkID = ctx.devAddr.networkID - dataPacket.srcMac = ctx.devAddr.mac - dataPacket.dstMac = session.dstMac - dataPacket.ttl = 255 - dataPacket.data = data - if let packet = try? dataPacket.serializedData() { - self.logger.log("[SDLUDPHole] sendPacket: \(remoteAddress), count: \(packet.count)", level: .debug) - self.send(remoteAddress: remoteAddress, type: .data, data: packet) - } - } - - // 通过sn服务器转发数据包, data已经是加密过后的数据 - func forwardPacket(context ctx: SDLContext, dst_mac: Data, data: Data) { - let remoteAddress = ctx.config.stunSocketAddress - - var dataPacket = SDLData() - dataPacket.networkID = ctx.devAddr.networkID - dataPacket.srcMac = ctx.devAddr.mac - dataPacket.dstMac = dst_mac - dataPacket.ttl = 255 - dataPacket.data = data - - if let packet = try? dataPacket.serializedData() { - self.logger.log("[SDLContext] forward packet, remoteAddress: \(remoteAddress), data size: \(packet.count)", level: .debug) - self.send(remoteAddress: remoteAddress, type: .data, data: packet) - } - } - - // 发送register包 - func sendRegister(remoteAddress: SocketAddress, networkId: UInt32, srcMac: Data, dst_mac: Data) { - var register = SDLRegister() - register.networkID = networkId - register.srcMac = srcMac - register.dstMac = dst_mac - - if let packet = try? register.serializedData() { - self.logger.log("[SDLUDPHole] SendRegister: \(remoteAddress), src_mac: \(LayerPacket.MacAddress.description(data: srcMac)), dst_mac: \(LayerPacket.MacAddress.description(data: dst_mac))", level: .debug) - self.send(remoteAddress: remoteAddress, type: .register, data: packet) - } - } - - // 回复registerAck - func sendRegisterAck(context ctx: SDLContext, remoteAddress: SocketAddress, dst_mac: Data) { - var registerAck = SDLRegisterAck() - registerAck.networkID = ctx.devAddr.networkID - registerAck.srcMac = ctx.devAddr.mac - registerAck.dstMac = dst_mac - - if let packet = try? registerAck.serializedData() { - self.logger.log("[SDLUDPHole] SendRegisterAck: \(remoteAddress), \(registerAck)", level: .debug) - self.send(remoteAddress: remoteAddress, type: .registerAck, data: packet) - } - } - - // 处理写入逻辑 - private func send(remoteAddress: SocketAddress, type: SDLPacketType, data: Data) { - let message = UDPMessage(remoteAddress: remoteAddress, type: type, data: data) - self.writeContinuation.yield(message) - } - - //--MARK: 编解码器 - private static func decode(buffer: inout ByteBuffer) throws -> SDLHoleInboundMessage? { - guard let type = buffer.readInteger(as: UInt8.self), - let packetType = SDLPacketType(rawValue: type), - let bytes = buffer.readBytes(length: buffer.readableBytes) else { - return nil - } - - switch packetType { - case .data: - let dataPacket = try SDLData(serializedBytes: bytes) - return .data(dataPacket) - case .register: - let registerPacket = try SDLRegister(serializedBytes: bytes) - return .register(registerPacket) - case .registerAck: - let registerAck = try SDLRegisterAck(serializedBytes: bytes) - return .registerAck(registerAck) - case .stunReply: - let stunReply = try SDLStunReply(serializedBytes: bytes) - return .stunReply(stunReply) - case .stunProbeReply: - let stunProbeReply = try SDLStunProbeReply(serializedBytes: bytes) - return .stunProbeReply(stunProbeReply) - default: - return nil - } - } - - deinit { - try? self.group.syncShutdownGracefully() - self.writeContinuation.finish() - self.eventContinuation.finish() - } - -}