diff --git a/Tun/Punchnet/SDLUDPHole.swift b/Tun/Punchnet/SDLUDPHole.swift index bec0271..ae1dd85 100644 --- a/Tun/Punchnet/SDLUDPHole.swift +++ b/Tun/Punchnet/SDLUDPHole.swift @@ -14,7 +14,7 @@ import SwiftProtobuf final class SDLUDPHole: ChannelInboundHandler { typealias InboundIn = AddressedEnvelope - private enum State { + private enum State: Equatable { case idle case ready case stopping @@ -23,21 +23,15 @@ final class SDLUDPHole: ChannelInboundHandler { private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) private var channel: Channel? + private var closeFuture: EventLoopFuture? private var state: State = .idle private var didFinishMessageStream: Bool = false - private let closeStream: AsyncStream - private let closeContinuation: AsyncStream.Continuation - public let messageStream: AsyncStream<(SocketAddress, SDLHoleMessage)> private let messageContinuation: AsyncStream<(SocketAddress, SDLHoleMessage)>.Continuation // 启动函数 init() throws { - let (closeStream, closeContinuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1)) - self.closeStream = closeStream - self.closeContinuation = closeContinuation - let (stream, continuation) = AsyncStream.makeStream(of: (SocketAddress, SDLHoleMessage).self, bufferingPolicy: .bufferingNewest(2048)) self.messageStream = stream self.messageContinuation = continuation @@ -65,6 +59,7 @@ final class SDLUDPHole: ChannelInboundHandler { let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() self.channel = channel + self.closeFuture = channel.closeFuture self.state = .ready precondition(channel.localAddress != nil, "UDP channel has no localAddress after bind") @@ -72,7 +67,15 @@ final class SDLUDPHole: ChannelInboundHandler { } func waitClose() async throws { - for await _ in self.closeStream { } + switch self.state { + case .idle: + return + case .ready, .stopping, .stopped: + guard let closeFuture = self.closeFuture else { + return + } + try await closeFuture.get() + } } func stop() { @@ -82,7 +85,6 @@ final class SDLUDPHole: ChannelInboundHandler { case .idle: self.state = .stopped self.finishMessageStream() - self.closeContinuation.finish() return case .ready: self.state = .stopping @@ -90,7 +92,6 @@ final class SDLUDPHole: ChannelInboundHandler { self.finishMessageStream() self.channel?.close(promise: nil) - self.channel = nil } // --MARK: ChannelInboundHandler delegate @@ -124,17 +125,14 @@ final class SDLUDPHole: ChannelInboundHandler { self.finishMessageStream() self.channel = nil self.state = .stopped - self.closeContinuation.yield(()) - self.closeContinuation.finish() - context.close(promise: nil) } func errorCaught(context: ChannelHandlerContext, error: any Error) { + SDLLogger.log("[SDLUDPHole] channel error: \(error)", for: .debug) self.finishMessageStream() - self.channel = nil - self.state = .stopped - self.closeContinuation.yield(()) - self.closeContinuation.finish() + if self.state != .stopped { + self.state = .stopping + } context.close(promise: nil) }