diff --git a/Tun/Punchnet/SDLUDPHole.swift b/Tun/Punchnet/SDLUDPHole.swift index de88c71..bec0271 100644 --- a/Tun/Punchnet/SDLUDPHole.swift +++ b/Tun/Punchnet/SDLUDPHole.swift @@ -14,22 +14,49 @@ import SwiftProtobuf final class SDLUDPHole: ChannelInboundHandler { typealias InboundIn = AddressedEnvelope + private enum State { + case idle + case ready + case stopping + case stopped + } + private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) private var channel: Channel? + private var state: State = .idle + private var didFinishMessageStream: Bool = false + + private let closeStream: AsyncStream + private let closeContinuation: AsyncStream.Continuation public let messageStream: AsyncStream<(SocketAddress, SDLHoleMessage)> private let messageContinuation: AsyncStream<(SocketAddress, SDLHoleMessage)>.Continuation - private var isStopped: Bool = false - // 启动函数 init() throws { + let (closeStream, closeContinuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1)) + self.closeStream = closeStream + self.closeContinuation = closeContinuation + let (stream, continuation) = AsyncStream.makeStream(of: (SocketAddress, SDLHoleMessage).self, bufferingPolicy: .bufferingNewest(2048)) self.messageStream = stream self.messageContinuation = continuation } func start() throws -> SocketAddress { + switch self.state { + case .ready: + guard let channel = self.channel else { + preconditionFailure("SDLUDPHole is ready but channel is nil") + } + precondition(channel.localAddress != nil, "UDP channel has no localAddress after bind") + return channel.localAddress! + case .stopping, .stopped: + preconditionFailure("SDLUDPHole cannot be restarted after stop") + case .idle: + break + } + let bootstrap = DatagramBootstrap(group: group) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .channelInitializer { channel in @@ -38,34 +65,41 @@ final class SDLUDPHole: ChannelInboundHandler { let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() self.channel = channel + self.state = .ready precondition(channel.localAddress != nil, "UDP channel has no localAddress after bind") return channel.localAddress! } func waitClose() async throws { - try await self.channel?.closeFuture.get() + for await _ in self.closeStream { } } func stop() { - guard !self.isStopped else { + switch self.state { + case .stopping, .stopped: return + case .idle: + self.state = .stopped + self.finishMessageStream() + self.closeContinuation.finish() + return + case .ready: + self.state = .stopping } - self.isStopped = true - self.messageContinuation.finish() + self.finishMessageStream() self.channel?.close(promise: nil) self.channel = nil - try? self.group.syncShutdownGracefully() } // --MARK: ChannelInboundHandler delegate - - func channelActive(context: ChannelHandlerContext) { - - } - + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + guard case .ready = self.state else { + return + } + let envelope = unwrapInboundIn(data) var buffer = envelope.data @@ -87,18 +121,26 @@ final class SDLUDPHole: ChannelInboundHandler { } func channelInactive(context: ChannelHandlerContext) { + self.finishMessageStream() self.channel = nil + self.state = .stopped + self.closeContinuation.yield(()) + self.closeContinuation.finish() context.close(promise: nil) } func errorCaught(context: ChannelHandlerContext, error: any Error) { + self.finishMessageStream() self.channel = nil + self.state = .stopped + self.closeContinuation.yield(()) + self.closeContinuation.finish() context.close(promise: nil) } // MARK: 处理写入逻辑 func send(type: SDLPacketType, data: Data, remoteAddress: SocketAddress) { - guard let channel = self.channel, !self.isStopped else { + guard case .ready = self.state, let channel = self.channel else { return } @@ -157,8 +199,18 @@ final class SDLUDPHole: ChannelInboundHandler { } } + private func finishMessageStream() { + guard !self.didFinishMessageStream else { + return + } + + self.didFinishMessageStream = true + self.messageContinuation.finish() + } + deinit { self.stop() + try? self.group.syncShutdownGracefully() } }