From 55ea1cd09d741ed02650280a914a88046ac463aa Mon Sep 17 00:00:00 2001 From: anlicheng <244108715@qq.com> Date: Tue, 3 Feb 2026 16:07:30 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A7=A3=E5=86=B3=E5=AE=9A=E6=97=B6=E5=99=A8?= =?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Tun/PacketTunnelProvider.swift | 21 +- .../Actors/SDLContextSupervisor.swift | 44 --- Tun/Punchnet/SDLAsyncTimerStream.swift | 29 ++ Tun/Punchnet/SDLContextActor.swift | 255 +++++++++++------- Tun/Punchnet/SDLDNSClient.swift | 8 +- Tun/Punchnet/SDLUDPHole.swift | 8 +- 6 files changed, 205 insertions(+), 160 deletions(-) delete mode 100644 Tun/Punchnet/Actors/SDLContextSupervisor.swift create mode 100644 Tun/Punchnet/SDLAsyncTimerStream.swift diff --git a/Tun/PacketTunnelProvider.swift b/Tun/PacketTunnelProvider.swift index 7f959e4..ceeaca1 100644 --- a/Tun/PacketTunnelProvider.swift +++ b/Tun/PacketTunnelProvider.swift @@ -13,7 +13,7 @@ enum TunnelError: Error { } class PacketTunnelProvider: NEPacketTunnelProvider { - var contextSupervisor: SDLContextSupervisor? + var contextActor: SDLContextActor? private var rootTask: Task? override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) { @@ -24,7 +24,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider { } // 如果当前在运行状态,不允许重复请求 - guard self.contextSupervisor == nil else { + guard self.contextActor == nil else { completionHandler(TunnelError.invalidContext) return } @@ -35,22 +35,23 @@ class PacketTunnelProvider: NEPacketTunnelProvider { let logger = SDLLogger(level: .debug) self.rootTask = Task { - self.contextSupervisor = SDLContextSupervisor() - await self.contextSupervisor?.start(provider: self, config: config, rsaCipher: rsaCipher, aesCipher: aesChiper, logger: logger) + self.contextActor = SDLContextActor(provider: self, config: config, rsaCipher: rsaCipher, aesCipher: aesChiper, logger: logger) + await self.contextActor?.start() completionHandler(nil) } } override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) { // Add code here to start the process of stopping the tunnel. - self.rootTask?.cancel() Task { - await self.contextSupervisor?.stop() + await self.contextActor?.stop() + self.contextActor = nil + + self.rootTask?.cancel() + self.rootTask = nil + + completionHandler() } - self.contextSupervisor = nil - self.rootTask = nil - - completionHandler() } override func handleAppMessage(_ messageData: Data, completionHandler: ((Data?) -> Void)?) { diff --git a/Tun/Punchnet/Actors/SDLContextSupervisor.swift b/Tun/Punchnet/Actors/SDLContextSupervisor.swift deleted file mode 100644 index cf0c648..0000000 --- a/Tun/Punchnet/Actors/SDLContextSupervisor.swift +++ /dev/null @@ -1,44 +0,0 @@ -// -// SDLContextSupervisor.swift -// Tun -// -// Created by 安礼成 on 2026/2/2. -// - -import Foundation -import NetworkExtension - -actor SDLContextSupervisor { - private var context: SDLContextActor? - private var tasks: [Task] = [] - - public func start(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) async { - let context = SDLContextActor(provider: provider, config: config, rsaCipher: rsaCipher, aesCipher: aesCipher, logger: logger) - self.context = context - - tasks.append(spawnLoop { try await context.startNoticeClient()}) - tasks.append(spawnLoop { try await context.startUDPHole()}) - tasks.append(spawnLoop { try await context.startDnsClient()}) - tasks.append(spawnLoop { try await context.startMonitor()}) - } - - func stop() { - tasks.forEach {$0.cancel()} - tasks.removeAll() - } - - private func spawnLoop(_ body: @escaping () async throws -> Void) -> Task { - return Task.detached { - while !Task.isCancelled { - do { - try await body() - } catch is CancellationError { - break - } catch { - try? await Task.sleep(nanoseconds: 2_000_000_000) - } - } - } - } - -} diff --git a/Tun/Punchnet/SDLAsyncTimerStream.swift b/Tun/Punchnet/SDLAsyncTimerStream.swift new file mode 100644 index 0000000..1bec606 --- /dev/null +++ b/Tun/Punchnet/SDLAsyncTimerStream.swift @@ -0,0 +1,29 @@ +// +// SDLAsyncTimerStream.swift +// Tun +// +// Created by 安礼成 on 2026/2/3. +// + +import Foundation + +class SDLAsyncTimerStream { + let timer: DispatchSourceTimer + + init() { + self.timer = DispatchSource.makeTimerSource(queue: .global()) + } + + func start(_ cont: AsyncStream.Continuation) { + timer.schedule(deadline: .now(), repeating: .seconds(5)) + timer.setEventHandler { + cont.yield() + } + timer.resume() + } + + deinit { + self.timer.cancel() + } + +} diff --git a/Tun/Punchnet/SDLContextActor.swift b/Tun/Punchnet/SDLContextActor.swift index 8054fc8..939a9cd 100644 --- a/Tun/Punchnet/SDLContextActor.swift +++ b/Tun/Punchnet/SDLContextActor.swift @@ -29,11 +29,14 @@ actor SDLContextActor { nonisolated let rsaCipher: RSACipher // 依赖的变量 - var udpHole: SDLUDPHole? + private var udpHole: SDLUDPHole? + private var udpHoleWorkers: [Task]? + nonisolated let providerAdapter: SDLTunnelProviderAdapter var puncherActor: SDLPuncherActor? // dns的client对象 - var dnsClient: SDLDNSClient? + private var dnsClient: SDLDNSClient? + private var dnsWorker: Task? // 网络探测对象 var proberActor: SDLNATProberActor? @@ -46,6 +49,7 @@ actor SDLContextActor { // 网络状态变化的健康 private var monitor: SDLNetworkMonitor? + private var monitorWorker: Task? // 内部socket通讯 private var noticeClient: SDLNoticeClient? @@ -55,6 +59,9 @@ actor SDLContextActor { nonisolated private let logger: SDLLogger + // 处理内部的需要长时间运行的任务 + private var loopChildWorkers: [Task] = [] + public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) { self.logger = logger self.config = config @@ -66,126 +73,173 @@ actor SDLContextActor { self.providerAdapter = SDLTunnelProviderAdapter(provider: provider, logger: logger) } - public func startNoticeClient() async throws { - // 启动noticeClient - self.noticeClient = try SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger) - self.logger.log("[SDLContext] noticeClient started") - try await self.noticeClient?.waitClose() + public func start() { + self.startMonitor() + + self.loopChildWorkers.append(spawnLoop { + let noticeClient = try self.startNoticeClient() + try await noticeClient.waitClose() + self.logger.log("[SDLContext] noticeClient closed!!!!") + }) + + self.loopChildWorkers.append(spawnLoop { + let dnsClient = try await self.startDnsClient() + try await dnsClient.waitClose() + self.logger.log("[SDLContext] dns closed!!!!") + }) + + self.loopChildWorkers.append(spawnLoop { + let udpHole = try await self.startUDPHole() + try await udpHole.waitClose() + self.logger.log("[SDLContext] udp closed!!!!") + }) } - public func startMonitor() async throws { + private func startNoticeClient() throws -> SDLNoticeClient { + // 启动noticeClient + let noticeClient = try SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger) + self.logger.log("[SDLContext] noticeClient started") + self.noticeClient = noticeClient + + return noticeClient + } + + private func startMonitor() { + self.monitorWorker?.cancel() + self.monitorWorker = nil + // 启动monitor let monitor = SDLNetworkMonitor() monitor.start() self.logger.log("[SDLContext] monitor started") self.monitor = monitor - for await event in monitor.eventStream { - switch event { - case .changed: - // 需要重新探测网络的nat类型 - //self.natType = await self.getNatType() - self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info) - case .unreachable: - self.logger.log("didNetworkPathUnreachable", level: .warning) + self.monitorWorker = Task { + for await event in monitor.eventStream { + switch event { + case .changed: + // 需要重新探测网络的nat类型 + //self.natType = await self.getNatType() + self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info) + case .unreachable: + self.logger.log("didNetworkPathUnreachable", level: .warning) + } } } } - public func startDnsClient() async throws { + private func startDnsClient() async throws -> SDLDNSClient { + self.dnsWorker?.cancel() + self.dnsWorker = nil + // 启动dns服务 let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(self.config.remoteDnsServer, port: 15353) let dnsClient = try await SDLDNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger) - let channel = try dnsClient.start() + try dnsClient.start() self.logger.log("[SDLContext] dnsClient started") self.dnsClient = dnsClient - - try await withThrowingTaskGroup(of: Void.self) {group in - group.addTask { - // 处理事件流 - for await packet in dnsClient.packetFlow { - try Task.checkCancellation() - let nePacket = NEPacket(data: packet, protocolFamily: 2) - self.providerAdapter.writePackets(packets: [nePacket]) + self.dnsWorker = Task { + // 处理事件流 + for await packet in dnsClient.packetFlow { + if Task.isCancelled { + break } + let nePacket = NEPacket(data: packet, protocolFamily: 2) + self.providerAdapter.writePackets(packets: [nePacket]) } - - group.addTask { - try await channel.closeFuture.get() - } - - try await group.next() - self.logger.log("[SDLContext] taskGroup cancel") - group.cancelAll() } + + return dnsClient } - public func startUDPHole() async throws { + private func startUDPHole() async throws -> SDLUDPHole { + self.udpHoleWorkers?.forEach {$0.cancel()} + self.udpHoleWorkers = nil + // 启动udp服务器 let udpHole = try SDLUDPHole(logger: self.logger) - let channel = try udpHole.start() + try udpHole.start() self.logger.log("[SDLContext] udpHole started") self.udpHole = udpHole await udpHole.channelIsActived() await self.handleUDPHoleReady() - - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - try await channel.closeFuture.get() - } + + // 处理心跳逻辑 + let pingTask = Task.detached { + let (stream, cont) = AsyncStream.makeStream(of: Void.self) + let timerStream = SDLAsyncTimerStream() + timerStream.start(cont) - // 处理UDP的事件流 - group.addTask { - while true { - try Task.checkCancellation() - try await Task.sleep(for: .seconds(5)) - await self.sendStunRequest() + for await _ in stream { + if Task.isCancelled { + break } + self.logger.log("[SDLContext] will do stunRequest22") + await self.sendStunRequest() + self.logger.log("[SDLContext] will do stunRequest44") } - // 处理数据流 - group.addTask { - for try await data in udpHole.dataStream { - try Task.checkCancellation() - try await self.handleData(data: data) - } - } - - // 处理signal信号流 - group.addTask { - for try await(remoteAddress, signal) in udpHole.signalStream { - try Task.checkCancellation() - switch signal { - case .registerSuperAck(let registerSuperAck): - await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck) - case .registerSuperNak(let registerSuperNak): - await self.handleRegisterSuperNak(nakPacket: registerSuperNak) - case .peerInfo(let peerInfo): - await self.puncherActor?.handlePeerInfo(peerInfo: peerInfo) - case .event(let event): - try await self.handleEvent(event: event) - case .stunProbeReply(let probeReply): - await self.proberActor?.handleProbeReply(reply: probeReply) - case .register(let register): - try await self.handleRegister(remoteAddress: remoteAddress, register: register) - case .registerAck(let registerAck): - await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) - } - } - } - - try await group.next() - group.cancelAll() - self.logger.log("[SDLContext] taskGroup cancel") + self.logger.log("[SDLContext] will do stunRequest55") } + + // 处理数据流 + let dataTask = Task { + for await data in udpHole.dataStream { + if Task.isCancelled { + break + } + try? self.handleData(data: data) + } + } + + // 处理控制信号 + let signalTask = Task { + for await(remoteAddress, signal) in udpHole.signalStream { + if Task.isCancelled { + break + } + + switch signal { + case .registerSuperAck(let registerSuperAck): + await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck) + case .registerSuperNak(let registerSuperNak): + self.handleRegisterSuperNak(nakPacket: registerSuperNak) + case .peerInfo(let peerInfo): + await self.puncherActor?.handlePeerInfo(peerInfo: peerInfo) + case .event(let event): + try? self.handleEvent(event: event) + case .stunProbeReply(let probeReply): + await self.proberActor?.handleProbeReply(reply: probeReply) + case .register(let register): + try? self.handleRegister(remoteAddress: remoteAddress, register: register) + case .registerAck(let registerAck): + self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) + } + } + } + + self.udpHoleWorkers = [pingTask, dataTask, signalTask] + + return udpHole } + // 处理context的停止问题 public func stop() async { - self.udpHole = nil - self.noticeClient = nil - + self.loopChildWorkers.forEach { $0.cancel() } + self.loopChildWorkers.removeAll() + + self.udpHoleWorkers?.forEach { $0.cancel() } + self.udpHoleWorkers = nil + + self.dnsWorker?.cancel() + self.dnsWorker = nil + + self.monitorWorker?.cancel() + self.monitorWorker = nil + self.readTask?.cancel() + self.readTask = nil } private func setNatType(natType: SDLNATProberActor.NatType) { @@ -227,7 +281,7 @@ actor SDLContextActor { } } - private func sendStunRequest() async { + private func sendStunRequest() { var stunRequest = SDLStunRequest() stunRequest.clientID = self.config.clientId stunRequest.networkID = self.config.networkAddress.networkId @@ -418,21 +472,8 @@ actor SDLContextActor { let packets = await self.providerAdapter.readPackets() let ipPackets = packets.compactMap { IPPacket($0) } - await self.batchProcessPackets(batchSize: 20, packets: ipPackets) - } - } - } - - // 批量分发ip数据包 - private func batchProcessPackets(batchSize: Int, packets: [IPPacket]) async { - for startIndex in stride(from: 0, to: packets.count, by: batchSize) { - let endIndex = Swift.min(startIndex + batchSize, packets.count) - let chunkPackets = packets[startIndex.. Void) -> Task { + return Task.detached { + while !Task.isCancelled { + do { + try await body() + } catch is CancellationError { + break + } catch { + try? await Task.sleep(nanoseconds: 2_000_000_000) + } + } + } + } + deinit { self.udpHole = nil self.dnsClient = nil diff --git a/Tun/Punchnet/SDLDNSClient.swift b/Tun/Punchnet/SDLDNSClient.swift index c5dff28..e001d14 100644 --- a/Tun/Punchnet/SDLDNSClient.swift +++ b/Tun/Punchnet/SDLDNSClient.swift @@ -29,7 +29,7 @@ final class SDLDNSClient: ChannelInboundHandler { (self.packetFlow, self.packetContinuation) = AsyncStream.makeStream(of: Data.self, bufferingPolicy: .unbounded) } - func start() throws -> Channel { + func start() throws { let bootstrap = DatagramBootstrap(group: group) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .channelInitializer { channel in @@ -39,8 +39,10 @@ final class SDLDNSClient: ChannelInboundHandler { let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() self.logger.log("[DNSClient] started", level: .debug) self.channel = channel - - return channel + } + + func waitClose() async throws { + try await self.channel?.closeFuture.get() } // --MARK: ChannelInboundHandler delegate diff --git a/Tun/Punchnet/SDLUDPHole.swift b/Tun/Punchnet/SDLUDPHole.swift index ccc5c77..31c6433 100644 --- a/Tun/Punchnet/SDLUDPHole.swift +++ b/Tun/Punchnet/SDLUDPHole.swift @@ -49,7 +49,7 @@ final class SDLUDPHole: ChannelInboundHandler { (self.dataStream, self.dataContinuation) = AsyncStream.makeStream(of: SDLData.self, bufferingPolicy: .unbounded) } - func start() throws -> Channel { + func start() throws { let bootstrap = DatagramBootstrap(group: group) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .channelInitializer { channel in @@ -59,8 +59,6 @@ final class SDLUDPHole: ChannelInboundHandler { let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() self.logger.log("[UDPHole] started", level: .debug) self.channel = channel - - return channel } func channelIsActived() async { @@ -73,6 +71,10 @@ final class SDLUDPHole: ChannelInboundHandler { } } + func waitClose() async throws { + try await self.channel?.closeFuture.get() + } + // --MARK: ChannelInboundHandler delegate func channelActive(context: ChannelHandlerContext) {