diff --git a/Tun/Punchnet/Actors/SDLContextActor.swift b/Tun/Punchnet/Actors/SDLContextActor.swift index 0eaed43..c367ff9 100644 --- a/Tun/Punchnet/Actors/SDLContextActor.swift +++ b/Tun/Punchnet/Actors/SDLContextActor.swift @@ -144,28 +144,24 @@ actor SDLContextActor { SDLLogger.shared.log("[SDLContext] start quic client ready") self.quicWorker = Task.detached { - let reader = quicClient.getReader() - while let frame = try? await reader.next() { - if let message = SDLQUICCodec.decode(frame: frame) { - switch message { - case .welcome(let welcome): - SDLLogger.shared.log("[SDLContext] quic welcome: \(welcome)") - case .registerSuperAck(let registerSuperAck): - await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck) - case .registerSuperNak(let registerSuperNak): - await self.handleRegisterSuperNak(nakPacket: registerSuperNak) - case .peerInfo(let peerInfo): - SDLLogger.shared.log("[SDLContext] peer message: \(peerInfo)") - case .event(let event): - await self.handleEvent(event: event) - case .policyReponse(let policyResponse): - // 处理权限的请求问题 - await self.identifyStore.apply(policyResponse: policyResponse) - } - + let stream = await quicClient.messageStream() + for await message in stream { + switch message { + case .welcome(let welcome): + SDLLogger.shared.log("[SDLContext] quic welcome: \(welcome)") + case .registerSuperAck(let registerSuperAck): + await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck) + case .registerSuperNak(let registerSuperNak): + await self.handleRegisterSuperNak(nakPacket: registerSuperNak) + case .peerInfo(let peerInfo): + SDLLogger.shared.log("[SDLContext] peer message: \(peerInfo)") + case .event(let event): + await self.handleEvent(event: event) + case .policyReponse(let policyResponse): + // 处理权限的请求问题 + await self.identifyStore.apply(policyResponse: policyResponse) } } - } self.quicClient = quicClient diff --git a/Tun/Punchnet/Actors/SDLQuicClient.swift b/Tun/Punchnet/Actors/SDLQuicClient.swift index 71778c7..4068542 100644 --- a/Tun/Punchnet/Actors/SDLQuicClient.swift +++ b/Tun/Punchnet/Actors/SDLQuicClient.swift @@ -19,60 +19,18 @@ enum SDLQUICError: Error { } final class SDLQUICClient { - private let transport: SDLQUICTransport + private let connection: NWConnection private let queue = DispatchQueue(label: "com.sdl.QUICClient.queue") // 专用队列保证线程安全 private let (closeStream, closeCont) = AsyncStream.makeStream(of: Void.self) private let (readyStream, readyCont) = AsyncStream.makeStream(of: Void.self) - init(host: String, port: UInt16) { - self.transport = SDLQUICTransport(host: host, port: port) - } - - func start() { - self.transport.start(queue: self.queue) { event in - switch event { - case .ready: - self.readyCont.yield() - self.readyCont.finish() - case .failed(_), .cancelled: - self.closeCont.yield() - self.closeCont.finish() - } - } - } - - func getReader() -> SDLQUICReader { - return transport.getReader() - } - - func send(type: SDLPacketType, data: Data) { - self.transport.send(type: type, data: data) - } - - func waitReady() async throws { - for await _ in readyStream {} - } - - func waitClose() async { - for await _ in closeStream {} - } - - func stop() { - self.transport.stop() - } - -} - -final class SDLQUICTransport { enum Event { case ready case failed(Error) case cancelled } - private let connection: NWConnection - init(host: String, port: UInt16) { let options = NWProtocolQUIC.Options(alpn: ["punchnet/1.0"]) @@ -90,24 +48,31 @@ final class SDLQUICTransport { self.connection = NWConnection(host: .init(host), port: .init(rawValue: port)!, using: params) } - func start(queue: DispatchQueue, onEvent: @escaping (Event) -> Void) { + func start() { SDLLogger.shared.log("[SDLQUICTransport] call start") connection.stateUpdateHandler = { state in SDLLogger.shared.log("[SDLQUICTransport] new state: \(state)") switch state { - case .ready: onEvent(.ready) - case .failed(let e): onEvent(.failed(e)) - case .cancelled: onEvent(.cancelled) - default: break + case .ready: + self.readyCont.yield() + self.readyCont.finish() + case .failed(_), .cancelled: + self.closeCont.yield() + self.closeCont.finish() + default: + () } } - connection.start(queue: queue) + connection.start(queue: self.queue) } - func getReader() -> SDLQUICReader { - return SDLQUICReader(connection: self.connection) + func messageStream() async -> AsyncStream { + let reader = SDLQUICReader(connection: self.connection) + await reader.start() + + return await reader.messageStream } - + func send(type: SDLPacketType, data: Data) { var len = UInt16(data.count + 1).bigEndian @@ -117,75 +82,70 @@ final class SDLQUICTransport { connection.send(content: packet, completion: .contentProcessed { _ in }) } - + + func waitReady() async throws { + for await _ in readyStream {} + } + + func waitClose() async { + for await _ in closeStream {} + } + func stop() { - connection.cancel() + self.connection.cancel() } } -actor SDLQUICReader: AsyncIteratorProtocol { - - typealias Element = ByteBuffer - +actor SDLQUICReader { private let allocator = ByteBufferAllocator() - private var buffer: ByteBuffer - // 用来缓存包,有可能一次读取到多个包 - private var packets: [ByteBuffer] = [] // 单个包最大64K private let maxPacketSize: Int // 最大缓冲区区为2M private let maxBufferSize: Int - // 是否已经读取完成 - private var isComplete: Bool = false + public var messageStream: AsyncStream + private let messageCont: AsyncStream.Continuation + private var readTask: Task? private let connection: NWConnection init(connection: NWConnection, maxPacketSize: Int = 64 * 1024, maxBufferSize: Int = 2 * 1024 * 1024) { self.connection = connection self.maxBufferSize = maxBufferSize self.maxPacketSize = maxPacketSize - self.buffer = allocator.buffer(capacity: maxBufferSize) + (self.messageStream, self.messageCont) = AsyncStream.makeStream(of: SDLQUICInboundMessage.self) } - func next() async throws -> ByteBuffer? { - // 如果还有包 - if !self.packets.isEmpty { - return self.packets.removeFirst() - } - - // 尝试读取,并返回 - self.packets = try await self.readPacket() - if !self.packets.isEmpty { - return self.packets.removeFirst() - } else { - return nil - } - } - - private func readPacket() async throws -> [ByteBuffer] { - while true { - if self.isComplete { - return try parseFrames() - } - - let (isComplete, data) = try await readOnce() - self.isComplete = isComplete - - if !data.isEmpty { - buffer.writeBytes(data) - // 尝试解析出完整的包 - let packets = try parseFrames() - if !packets.isEmpty { - return packets + func start() { + self.readTask = Task { + var buffer: ByteBuffer = allocator.buffer(capacity: self.maxBufferSize) + do { + while !Task.isCancelled { + let (isComplete, data) = try await self.readOnce() + if !data.isEmpty { + buffer.writeBytes(data) + let frames = try parseFrames(buffer: &buffer) + for frame in frames { + if let message = SDLQUICCodec.decode(frame: frame) { + self.messageCont.yield(message) + } + } + } + + if isComplete { + break + } } + self.messageCont.finish() + } catch { + self.messageCont.finish() } } } // 尝试解析数据 - private func parseFrames() throws -> [ByteBuffer] { + private func parseFrames(buffer: inout ByteBuffer) throws -> [ByteBuffer] { guard buffer.readableBytes >= 2 else { return [] } @@ -236,6 +196,11 @@ actor SDLQUICReader: AsyncIteratorProtocol { } } + deinit { + self.readTask?.cancel() + self.messageCont.finish() + } + } struct SDLQUICCodec { @@ -314,4 +279,3 @@ struct SDLQUICCodec { } } } -