diff --git a/Tun/Punchnet/Actors/SDLContextActor.swift b/Tun/Punchnet/Actors/SDLContextActor.swift index c367ff9..13a65b1 100644 --- a/Tun/Punchnet/Actors/SDLContextActor.swift +++ b/Tun/Punchnet/Actors/SDLContextActor.swift @@ -144,8 +144,7 @@ actor SDLContextActor { SDLLogger.shared.log("[SDLContext] start quic client ready") self.quicWorker = Task.detached { - let stream = await quicClient.messageStream() - for await message in stream { + for await message in quicClient.messageStream { switch message { case .welcome(let welcome): SDLLogger.shared.log("[SDLContext] quic welcome: \(welcome)") diff --git a/Tun/Punchnet/Actors/SDLQuicClient.swift b/Tun/Punchnet/Actors/SDLQuicClient.swift index 44d4672..f53ffc4 100644 --- a/Tun/Punchnet/Actors/SDLQuicClient.swift +++ b/Tun/Punchnet/Actors/SDLQuicClient.swift @@ -19,15 +19,29 @@ enum SDLQUICError: Error { } final class SDLQUICClient { + private let allocator = ByteBufferAllocator() + // 单个包最大64K + private let maxPacketSize: Int + // 最大缓冲区区为2M + private let maxBufferSize: Int + + public var messageStream: AsyncStream + private let messageCont: AsyncStream.Continuation + private var readTask: Task? + private let connection: NWConnection private let queue = DispatchQueue(label: "com.sdl.QUICClient.queue") // 专用队列保证线程安全 private let (closeStream, closeCont) = AsyncStream.makeStream(of: Void.self) private let (readyStream, readyCont) = AsyncStream.makeStream(of: Void.self) - init(host: String, port: UInt16) { + init(host: String, port: UInt16, maxPacketSize: Int = 64 * 1024, maxBufferSize: Int = 2 * 1024 * 1024) { let options = NWProtocolQUIC.Options(alpn: ["punchnet/1.0"]) + self.maxBufferSize = maxBufferSize + self.maxPacketSize = maxPacketSize + (self.messageStream, self.messageCont) = AsyncStream.makeStream(of: SDLQUICInboundMessage.self) + // TODO 这里设置证书的校验逻辑 sec_protocol_options_set_verify_block( options.securityProtocolOptions, @@ -57,13 +71,37 @@ final class SDLQUICClient { } } connection.start(queue: self.queue) - } - - func messageStream() async -> AsyncStream { - let reader = SDLQUICReader(connection: self.connection) - await reader.start() - return await reader.messageStream + // 启动数据读取任务 + self.readTask = Task { + var buffer = allocator.buffer(capacity: self.maxBufferSize) + let threshold = self.maxBufferSize / 10 * 6 + do { + while !Task.isCancelled { + let (isComplete, data) = try await self.readOnce() + if let data, !data.isEmpty { + buffer.writeBytes(data) + let frames = try parseFrames(buffer: &buffer) + if buffer.readerIndex > threshold { + buffer.discardReadBytes() + } + + for frame in frames { + if let message = decode(frame: frame) { + self.messageCont.yield(message) + } + } + } + + if isComplete { + break + } + } + self.messageCont.finish() + } catch { + self.messageCont.finish() + } + } } func send(type: SDLPacketType, data: Data) { @@ -92,60 +130,6 @@ final class SDLQUICClient { self.connection.cancel() } -} - -actor SDLQUICReader { - private let allocator = ByteBufferAllocator() - // 单个包最大64K - private let maxPacketSize: Int - // 最大缓冲区区为2M - private let maxBufferSize: Int - - public var messageStream: AsyncStream - private let messageCont: AsyncStream.Continuation - - private var readTask: Task? - private let connection: NWConnection - - init(connection: NWConnection, maxPacketSize: Int = 64 * 1024, maxBufferSize: Int = 2 * 1024 * 1024) { - self.connection = connection - self.maxBufferSize = maxBufferSize - self.maxPacketSize = maxPacketSize - (self.messageStream, self.messageCont) = AsyncStream.makeStream(of: SDLQUICInboundMessage.self) - } - - func start() { - self.readTask = Task { - var buffer = allocator.buffer(capacity: self.maxBufferSize) - let threshold = self.maxBufferSize / 10 * 6 - do { - while !Task.isCancelled { - let (isComplete, data) = try await self.readOnce() - if let data, !data.isEmpty { - buffer.writeBytes(data) - let frames = try parseFrames(buffer: &buffer) - if buffer.readerIndex > threshold { - buffer.discardReadBytes() - } - - for frame in frames { - if let message = SDLQUICCodec.decode(frame: frame) { - self.messageCont.yield(message) - } - } - } - - if isComplete { - break - } - } - self.messageCont.finish() - } catch { - self.messageCont.finish() - } - } - } - // 尝试解析数据 private func parseFrames(buffer: inout ByteBuffer) throws -> [ByteBuffer] { guard buffer.readableBytes >= 2 else { @@ -178,7 +162,7 @@ actor SDLQUICReader { // 读取一次数据 private func readOnce() async throws -> (Bool, Data?) { return try await withCheckedThrowingContinuation { cont in - connection.receive(minimumIncompleteLength: 1, maximumLength: maxPacketSize) { data, _, isComplete, error in + self.connection.receive(minimumIncompleteLength: 1, maximumLength: maxPacketSize) { data, _, isComplete, error in if let error { cont.resume(throwing: error) return @@ -188,16 +172,8 @@ actor SDLQUICReader { } } - deinit { - self.readTask?.cancel() - self.messageCont.finish() - } - -} - -struct SDLQUICCodec { // --MARK: 编解码器 - public static func decode(frame: ByteBuffer) -> SDLQUICInboundMessage? { + private func decode(frame: ByteBuffer) -> SDLQUICInboundMessage? { var buffer = frame guard let type = buffer.readInteger(as: UInt8.self), let packetType = SDLPacketType(rawValue: type) else { @@ -270,4 +246,9 @@ struct SDLQUICCodec { return nil } } + + deinit { + self.readTask?.cancel() + self.messageCont.finish() + } }