diff --git a/Tun/Punchnet/Actors/SDLContextActor.swift b/Tun/Punchnet/Actors/SDLContextActor.swift index 957e56c..ef1569d 100644 --- a/Tun/Punchnet/Actors/SDLContextActor.swift +++ b/Tun/Punchnet/Actors/SDLContextActor.swift @@ -261,12 +261,13 @@ actor SDLContextActor { // 启动dns服务 let dnsClient = DNSCloudClient(host: self.config.serverIp, port: 15353) - dnsClient.start() + await dnsClient.start() SDLLogger.log("[SDLContext] dnsClient started") self.dnsClient = dnsClient + let packetFlow = dnsClient.packetFlow self.dnsWorker = Task.detached { // 处理事件流 - for await packet in dnsClient.packetFlow { + for await packet in packetFlow { if Task.isCancelled { break } @@ -285,7 +286,7 @@ actor SDLContextActor { await dnsLocalClient.start() SDLLogger.log("[SDLContext] dnsClient started") self.dnsLocalClient = dnsLocalClient - let packetFlow = await dnsLocalClient.packetFlow + let packetFlow = dnsLocalClient.packetFlow self.dnsLocalWorker = Task.detached { // 处理事件流 for await packet in packetFlow { @@ -379,7 +380,9 @@ actor SDLContextActor { self.quicClient?.stop() self.quicClient = nil - self.dnsClient?.stop() + if let dnsClient = self.dnsClient { + await dnsClient.stop() + } self.dnsWorker?.cancel() self.dnsWorker = nil self.dnsClient = nil @@ -765,7 +768,9 @@ actor SDLContextActor { // 如果是内部域名,则转发整个ip包的内容到云端服务器 if name.contains(self.config.networkAddress.networkDomain) { SDLLogger.log("[SDLContext] get cloud dns request: \(name)") - self.dnsClient?.forward(ipPacketData: packet.data) + if let dnsClient = self.dnsClient { + await dnsClient.forward(ipPacketData: packet.data) + } } // 如果开启了出口节点,则转发给出口节点 else if let exitNode = config.exitNode { diff --git a/Tun/Punchnet/DNS/DNSCloudClient.swift b/Tun/Punchnet/DNS/DNSCloudClient.swift index f1e1867..e48cae3 100644 --- a/Tun/Punchnet/DNS/DNSCloudClient.swift +++ b/Tun/Punchnet/DNS/DNSCloudClient.swift @@ -7,28 +7,49 @@ import Foundation import Network -final class DNSCloudClient { +actor DNSCloudClient { + private enum State { + case idle + case running + case stopped + } + + private var state: State = .idle private var connection: NWConnection? + private var receiveTask: Task? private let dnsServerAddress: NWEndpoint // 用于对外输出收到的 DNS 响应包 public let packetFlow: AsyncStream private let packetContinuation: AsyncStream.Continuation + private var didFinishPacketFlow = false // 用来处理关闭事件 - private let (closeStream, closeContinuation) = AsyncStream.makeStream(of: Void.self) + private let closeStream: AsyncStream + private let closeContinuation: AsyncStream.Continuation + private var didFinishCloseStream = false /// - Parameter host: 你的 sn-server 地址 (如 "8.8.8.8") /// - Parameter port: 端口 (如 53) init(host: String, port: UInt16 ) { self.dnsServerAddress = .hostPort(host: NWEndpoint.Host(host), port: NWEndpoint.Port(integerLiteral: port)) - let (stream, continuation) = AsyncStream.makeStream(of: Data.self, bufferingPolicy: .unbounded) - self.packetFlow = stream - self.packetContinuation = continuation + let (packetStream, packetContinuation) = AsyncStream.makeStream(of: Data.self, bufferingPolicy: .bufferingNewest(256)) + self.packetFlow = packetStream + self.packetContinuation = packetContinuation + + let (closeStream, closeContinuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1)) + self.closeStream = closeStream + self.closeContinuation = closeContinuation } func start() { + guard case .idle = self.state else { + return + } + + self.state = .running + // 1. 配置参数:这是解决环路的关键 let parameters = NWParameters.udp @@ -42,18 +63,8 @@ final class DNSCloudClient { self.connection = connection connection.stateUpdateHandler = { [weak self] state in - switch state { - case .ready: - SDLLogger.log("[DNSClient] Connection ready", for: .debug) - self?.receiveLoop() // 开始循环接收数据 - case .failed(let error): - SDLLogger.log("[DNSClient] Connection failed: \(error)", for: .debug) - self?.stop() - case .cancelled: - self?.packetContinuation.finish() - self?.closeContinuation.finish() - default: - break + Task { + await self?.handleConnectionStateUpdate(state, for: connection) } } @@ -62,26 +73,34 @@ final class DNSCloudClient { } public func waitClose() async { - for await _ in closeStream { } + for await _ in self.closeStream { } } /// 接收数据的递归循环 - private func receiveLoop() { - connection?.receiveMessage { [weak self] content, _, isComplete, error in - if let data = content, !data.isEmpty { - // 将收到的 DNS 响应写回 AsyncStream - self?.packetContinuation.yield(data) + private static func makeReceiveStream(for connection: NWConnection) -> AsyncStream { + return AsyncStream(bufferingPolicy: .bufferingNewest(256)) { continuation in + func receiveNext() { + connection.receiveMessage { content, _, _, error in + if let data = content, !data.isEmpty { + // 将收到的 DNS 响应写回 AsyncStream + continuation.yield(data) + } + + if error == nil && connection.state == .ready { + receiveNext() // 继续监听下一个包 + } else { + continuation.finish() + } + } } - if error == nil && self?.connection?.state == .ready { - self?.receiveLoop() // 继续监听下一个包 - } + receiveNext() } } /// 发送 DNS 查询包(由 TUN 拦截到的原始 IP 包数据) func forward(ipPacketData: Data) { - guard let connection = self.connection, connection.state == .ready else { + guard case .running = self.state, let connection = self.connection, connection.state == .ready else { return } @@ -93,13 +112,96 @@ final class DNSCloudClient { } func stop() { - connection?.cancel() - connection = nil + guard self.state != .stopped else { + return + } + + self.state = .stopped + self.receiveTask?.cancel() + self.receiveTask = nil + self.connection?.cancel() + self.connection = nil + self.finishPacketFlowIfNeeded() + self.finishCloseStreamIfNeeded() + } + + private func handleConnectionStateUpdate(_ state: NWConnection.State, for connection: NWConnection) { + guard case .running = self.state else { + return + } + + switch state { + case .ready: + SDLLogger.log("[DNSClient] Connection ready", for: .debug) + self.startReceiveTask(for: connection) + case .failed(let error): + SDLLogger.log("[DNSClient] Connection failed: \(error)", for: .debug) + self.stop() + case .cancelled: + self.stop() + default: + break + } + } + + private func startReceiveTask(for connection: NWConnection) { + guard self.receiveTask == nil else { + return + } + + let stream = Self.makeReceiveStream(for: connection) + self.receiveTask = Task { [weak self] in + guard let self else { + return + } + + for await data in stream { + await self.handleReceivedPacket(data) + } + + await self.didFinishReceiving(for: connection) + } + } + + private func handleReceivedPacket(_ data: Data) { + guard case .running = self.state else { + return + } + + self.packetContinuation.yield(data) + } + + private func didFinishReceiving(for connection: NWConnection) { + guard case .running = self.state else { + return + } + + if self.connection === connection, connection.state != .ready { + self.stop() + } else { + self.receiveTask = nil + } + } + + private func finishPacketFlowIfNeeded() { + guard !self.didFinishPacketFlow else { + return + } + + self.didFinishPacketFlow = true + self.packetContinuation.finish() + } + + private func finishCloseStreamIfNeeded() { + guard !self.didFinishCloseStream else { + return + } + + self.didFinishCloseStream = true + self.closeContinuation.finish() } deinit { - stop() + self.connection?.cancel() } - } -