From 57dd0d9538f840e5a202197f21756a3263456418 Mon Sep 17 00:00:00 2001 From: anlicheng <244108715@qq.com> Date: Mon, 2 Feb 2026 12:07:29 +0800 Subject: [PATCH] fix context --- Tun/PacketTunnelProvider.swift | 26 +-- .../Actors/SDLContextSupervisor.swift | 44 +++++ ...SDLContext.swift => SDLContextActor.swift} | 181 ++++++++++-------- Tun/Punchnet/SDLDNSClient.swift | 8 +- Tun/Punchnet/SDLNoticeClient.swift | 18 +- Tun/Punchnet/SDLUDPHole.swift | 7 +- 6 files changed, 168 insertions(+), 116 deletions(-) create mode 100644 Tun/Punchnet/Actors/SDLContextSupervisor.swift rename Tun/Punchnet/{SDLContext.swift => SDLContextActor.swift} (83%) diff --git a/Tun/PacketTunnelProvider.swift b/Tun/PacketTunnelProvider.swift index bc341c3..7f959e4 100644 --- a/Tun/PacketTunnelProvider.swift +++ b/Tun/PacketTunnelProvider.swift @@ -13,47 +13,41 @@ enum TunnelError: Error { } class PacketTunnelProvider: NEPacketTunnelProvider { - var context: SDLContext? + var contextSupervisor: SDLContextSupervisor? private var rootTask: Task? override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) { - let logger = SDLLogger(level: .debug) - // host: "192.168.0.101", port: 1265 guard let options, let config = SDLConfiguration.parse(options: options) else { completionHandler(TunnelError.invalidConfiguration) return } - + // 如果当前在运行状态,不允许重复请求 - guard self.context == nil else { + guard self.contextSupervisor == nil else { completionHandler(TunnelError.invalidContext) return } - + // 加密算法 let rsaCipher = try! CCRSACipher(keySize: 1024) let aesChiper = CCAESChiper() + let logger = SDLLogger(level: .debug) self.rootTask = Task { - do { - self.context = SDLContext(provider: self, config: config, rsaCipher: rsaCipher, aesCipher: aesChiper, logger: logger) - try await self.context?.start() - } catch let err { - logger.log("[PacketTunnelProvider] exit with error: \(err)") - exit(-1) - } + self.contextSupervisor = SDLContextSupervisor() + await self.contextSupervisor?.start(provider: self, config: config, rsaCipher: rsaCipher, aesCipher: aesChiper, logger: logger) + completionHandler(nil) } - completionHandler(nil) } override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) { // Add code here to start the process of stopping the tunnel. self.rootTask?.cancel() Task { - await self.context?.stop() + await self.contextSupervisor?.stop() } - self.context = nil + self.contextSupervisor = nil self.rootTask = nil completionHandler() diff --git a/Tun/Punchnet/Actors/SDLContextSupervisor.swift b/Tun/Punchnet/Actors/SDLContextSupervisor.swift new file mode 100644 index 0000000..cf0c648 --- /dev/null +++ b/Tun/Punchnet/Actors/SDLContextSupervisor.swift @@ -0,0 +1,44 @@ +// +// SDLContextSupervisor.swift +// Tun +// +// Created by 安礼成 on 2026/2/2. +// + +import Foundation +import NetworkExtension + +actor SDLContextSupervisor { + private var context: SDLContextActor? + private var tasks: [Task] = [] + + public func start(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) async { + let context = SDLContextActor(provider: provider, config: config, rsaCipher: rsaCipher, aesCipher: aesCipher, logger: logger) + self.context = context + + tasks.append(spawnLoop { try await context.startNoticeClient()}) + tasks.append(spawnLoop { try await context.startUDPHole()}) + tasks.append(spawnLoop { try await context.startDnsClient()}) + tasks.append(spawnLoop { try await context.startMonitor()}) + } + + func stop() { + tasks.forEach {$0.cancel()} + tasks.removeAll() + } + + private func spawnLoop(_ body: @escaping () async throws -> Void) -> Task { + return Task.detached { + while !Task.isCancelled { + do { + try await body() + } catch is CancellationError { + break + } catch { + try? await Task.sleep(nanoseconds: 2_000_000_000) + } + } + } + } + +} diff --git a/Tun/Punchnet/SDLContext.swift b/Tun/Punchnet/SDLContextActor.swift similarity index 83% rename from Tun/Punchnet/SDLContext.swift rename to Tun/Punchnet/SDLContextActor.swift index d87298d..7ea6d2a 100644 --- a/Tun/Punchnet/SDLContext.swift +++ b/Tun/Punchnet/SDLContextActor.swift @@ -13,7 +13,7 @@ import NIOCore /* 1. 处理rsa的加解密逻辑 */ -actor SDLContext { +actor SDLContextActor { nonisolated let config: SDLConfiguration // nat的网络类型 @@ -57,7 +57,6 @@ actor SDLContext { public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) { self.logger = logger - self.config = config self.rsaCipher = rsaCipher self.aesCipher = aesCipher @@ -67,29 +66,72 @@ actor SDLContext { self.providerAdapter = SDLTunnelProviderAdapter(provider: provider, logger: logger) } - public func start() async throws { - // 启动udp服务器 - self.udpHole = try SDLUDPHole(logger: self.logger) - try self.udpHole?.start() - self.logger.log("[SDLContext] udpHole started") - - // 启动dns服务 - let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(self.config.remoteDnsServer, port: 15353) - self.dnsClient = try await SDLDNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger) - try self.dnsClient?.start() - self.logger.log("[SDLContext] dnsClient started") - + public func startNoticeClient() async throws { // 启动noticeClient self.noticeClient = try SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger) - try self.noticeClient?.start() self.logger.log("[SDLContext] noticeClient started") - + try await self.noticeClient?.waitClose() + } + + public func startMonitor() async throws { // 启动monitor - self.monitor = SDLNetworkMonitor() - self.monitor?.start() + let monitor = SDLNetworkMonitor() + monitor.start() self.logger.log("[SDLContext] monitor started") + self.monitor = monitor + + for await event in monitor.eventStream { + switch event { + case .changed: + // 需要重新探测网络的nat类型 + //self.natType = await self.getNatType() + self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info) + case .unreachable: + self.logger.log("didNetworkPathUnreachable", level: .warning) + } + } + } + + public func startDnsClient() async throws { + // 启动dns服务 + let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(self.config.remoteDnsServer, port: 15353) + let dnsClient = try await SDLDNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger) + let channel = try dnsClient.start() + self.logger.log("[SDLContext] dnsClient started") + self.dnsClient = dnsClient + + try await withThrowingTaskGroup(of: Void.self) {group in + group.addTask { + // 处理事件流 + for await packet in dnsClient.packetFlow { + try Task.checkCancellation() + let nePacket = NEPacket(data: packet, protocolFamily: 2) + self.providerAdapter.writePackets(packets: [nePacket]) + } + } + + group.addTask { + try await channel.closeFuture.get() + } + + try await group.next() + self.logger.log("[SDLContext] taskGroup cancel") + group.cancelAll() + } + } + + public func startUDPHole() async throws { + // 启动udp服务器 + let udpHole = try SDLUDPHole(logger: self.logger) + let channel = try udpHole.start() + self.logger.log("[SDLContext] udpHole started") + self.udpHole = udpHole try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + try await channel.closeFuture.get() + } + // 处理UDP的事件流 group.addTask { while true { @@ -102,86 +144,51 @@ actor SDLContext { // 处理event事件流 group.addTask { - if let eventStream = await self.udpHole?.eventStream { - for try await event in eventStream { - try Task.checkCancellation() - switch event { - case .ready: - await self.handleUDPHoleReady() - case .closed: - () - } + for try await event in udpHole.eventStream { + try Task.checkCancellation() + switch event { + case .ready: + await self.handleUDPHoleReady() + case .closed: + () } } } // 处理数据流 group.addTask { - if let dataStream = await self.udpHole?.dataStream { - for try await data in dataStream { - try Task.checkCancellation() - Task { - try await self.handleData(data: data) - } - } + for try await data in udpHole.dataStream { + try Task.checkCancellation() + try await self.handleData(data: data) } } // 处理signal信号流 group.addTask { - if let signalStream = await self.udpHole?.signalStream { - for try await(remoteAddress, signal) in signalStream { - try Task.checkCancellation() - Task { - switch signal { - case .registerSuperAck(let registerSuperAck): - await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck) - case .registerSuperNak(let registerSuperNak): - await self.handleRegisterSuperNak(nakPacket: registerSuperNak) - case .peerInfo(let peerInfo): - await self.puncherActor?.handlePeerInfo(peerInfo: peerInfo) - case .event(let event): - try await self.handleEvent(event: event) - case .stunProbeReply(let probeReply): - await self.proberActor?.handleProbeReply(reply: probeReply) - case .register(let register): - try await self.handleRegister(remoteAddress: remoteAddress, register: register) - case .registerAck(let registerAck): - await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) - } - } + for try await(remoteAddress, signal) in udpHole.signalStream { + try Task.checkCancellation() + switch signal { + case .registerSuperAck(let registerSuperAck): + await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck) + case .registerSuperNak(let registerSuperNak): + await self.handleRegisterSuperNak(nakPacket: registerSuperNak) + case .peerInfo(let peerInfo): + await self.puncherActor?.handlePeerInfo(peerInfo: peerInfo) + case .event(let event): + try await self.handleEvent(event: event) + case .stunProbeReply(let probeReply): + await self.proberActor?.handleProbeReply(reply: probeReply) + case .register(let register): + try await self.handleRegister(remoteAddress: remoteAddress, register: register) + case .registerAck(let registerAck): + await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) } } } - // 处理DNS的事件流 - group.addTask { - if let packetFlow = await self.dnsClient?.packetFlow { - for await packet in packetFlow { - let nePacket = NEPacket(data: packet, protocolFamily: 2) - self.providerAdapter.writePackets(packets: [nePacket]) - } - } - } - - // 处理Monitor的事件流 - group.addTask { - for await event in await self.monitor!.eventStream { - switch event { - case .changed: - // 需要重新探测网络的nat类型 - //self.natType = await self.getNatType() - self.logger.log("didNetworkPathChanged, nat type is: \(await self.natType)", level: .info) - case .unreachable: - self.logger.log("didNetworkPathUnreachable", level: .warning) - } - } - } - - if let _ = try await group.next() { - self.logger.log("[SDLContext] taskGroup cancel") - group.cancelAll() - } + try await group.next() + group.cancelAll() + self.logger.log("[SDLContext] taskGroup cancel") } } @@ -409,11 +416,15 @@ actor SDLContext { // 开启新的任务 self.readTask = Task(priority: .high) { - repeat { + while true { + if Task.isCancelled { + return + } + let packets = await self.providerAdapter.readPackets() let ipPackets = packets.compactMap { IPPacket($0) } await self.batchProcessPackets(batchSize: 20, packets: ipPackets) - } while true + } } } diff --git a/Tun/Punchnet/SDLDNSClient.swift b/Tun/Punchnet/SDLDNSClient.swift index 1956fa9..c5dff28 100644 --- a/Tun/Punchnet/SDLDNSClient.swift +++ b/Tun/Punchnet/SDLDNSClient.swift @@ -29,15 +29,18 @@ final class SDLDNSClient: ChannelInboundHandler { (self.packetFlow, self.packetContinuation) = AsyncStream.makeStream(of: Data.self, bufferingPolicy: .unbounded) } - func start() throws { + func start() throws -> Channel { let bootstrap = DatagramBootstrap(group: group) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .channelInitializer { channel in channel.pipeline.addHandler(self) } - self.channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() + let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() self.logger.log("[DNSClient] started", level: .debug) + self.channel = channel + + return channel } // --MARK: ChannelInboundHandler delegate @@ -79,7 +82,6 @@ final class SDLDNSClient: ChannelInboundHandler { } extension SDLDNSClient { - struct Helper { static let dnsServer: String = "100.100.100.100" // dns请求包的目标地址 diff --git a/Tun/Punchnet/SDLNoticeClient.swift b/Tun/Punchnet/SDLNoticeClient.swift index eb1585a..06b991b 100644 --- a/Tun/Punchnet/SDLNoticeClient.swift +++ b/Tun/Punchnet/SDLNoticeClient.swift @@ -21,7 +21,7 @@ import NIOPosix // 处理和sn-server服务器之间的通讯 final class SDLNoticeClient { private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - private var channel: Channel? + private var channel: Channel private let logger: SDLLogger private let noticePort: Int @@ -29,9 +29,7 @@ final class SDLNoticeClient { init(noticePort: Int, logger: SDLLogger) throws { self.logger = logger self.noticePort = noticePort - } - - func start() throws { + let bootstrap = DatagramBootstrap(group: self.group) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .channelInitializer { channel in @@ -44,19 +42,19 @@ final class SDLNoticeClient { // 处理写入逻辑 func send(data: Data) { - guard let channel = self.channel else { - return - } - if let remoteAddress = try? SocketAddress(ipAddress: "127.0.0.1", port: noticePort) { let buf = channel.allocator.buffer(bytes: data) let envelope = AddressedEnvelope(remoteAddress: remoteAddress, data: buf) - channel.eventLoop.execute { - channel.writeAndFlush(envelope, promise: nil) + self.channel.eventLoop.execute { + self.channel.writeAndFlush(envelope, promise: nil) } } } + func waitClose() async throws { + try await self.channel.closeFuture.get() + } + deinit { try? self.group.syncShutdownGracefully() } diff --git a/Tun/Punchnet/SDLUDPHole.swift b/Tun/Punchnet/SDLUDPHole.swift index 34aeed3..cb649ac 100644 --- a/Tun/Punchnet/SDLUDPHole.swift +++ b/Tun/Punchnet/SDLUDPHole.swift @@ -49,15 +49,18 @@ final class SDLUDPHole: ChannelInboundHandler { (self.eventStream, self.eventContinuation) = AsyncStream.makeStream(of: HoleEvent.self, bufferingPolicy: .unbounded) } - func start() throws { + func start() throws -> Channel { let bootstrap = DatagramBootstrap(group: group) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .channelInitializer { channel in channel.pipeline.addHandler(self) } - self.channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() + let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() self.logger.log("[UDPHole] started", level: .debug) + self.channel = channel + + return channel } // --MARK: ChannelInboundHandler delegate