//
//  SDLQuicClient.swift
//  Tun
//
//  Created by 安礼成 on 2026/2/13.
//

import Foundation
import NIOCore
import Network
import CryptoKit
import Security

/// Errors surfaced by `SDLQUICClient` so that upper layers can react to them.
enum SDLQUICError: Error {
    case connectionFailed(Error)
    case connectionCancelled
    case timeout
    case decodeError(String)
    case packetTooLarge
}

/// QUIC client exchanging length-prefixed frames with the server.
///
/// Wire format (both directions): a 2-byte big-endian length `L`, followed by `L` bytes whose
/// first byte is the `SDLPacketType` raw value and whose remaining bytes are the payload.
///
/// Lifecycle: call `start()`, await `waitReady()`, consume `messageStream`, and call `stop()`
/// to tear down. The read task keeps the client alive until the connection delivers EOF or
/// an error, so `stop()` (not just dropping the reference) is the way to shut it down.
final class SDLQUICClient {
    private let allocator = ByteBufferAllocator()

    /// Largest inbound frame (bytes) accepted; defaults to 64 KiB.
    private let maxPacketSize: Int
    /// Initial capacity of the inbound reassembly buffer; defaults to 2 MiB.
    private let maxBufferSize: Int

    /// Decoded inbound messages. Finishes when the read loop ends.
    public var messageStream: AsyncStream<SDLQUICInboundMessage>
    private let messageCont: AsyncStream<SDLQUICInboundMessage>.Continuation

    private var readTask: Task<Void, Never>?
    private var pingTask: Task<Void, Never>?

    private let connection: NWConnection
    // Dedicated serial queue on which connection callbacks and TLS verification run.
    private let queue = DispatchQueue(label: "com.sdl.QUICClient.queue")

    // Signals once the connection reaches `.failed` or `.cancelled`.
    private let (closeStream, closeCont) = AsyncStream.makeStream(of: Void.self)
    // Delivers a single outcome: `.success` on `.ready`, `.failure` if the connection
    // fails or is cancelled first.
    private let (readyStream, readyCont) = AsyncStream.makeStream(of: Result<Void, SDLQUICError>.self)

    /// Creates a client for `host:port` using ALPN `punchnet/1.0`.
    /// - Parameters:
    ///   - maxPacketSize: Largest inbound frame accepted before reading fails.
    ///   - maxBufferSize: Initial capacity of the reassembly buffer.
    init(host: String, port: UInt16, maxPacketSize: Int = 64 * 1024, maxBufferSize: Int = 2 * 1024 * 1024) {
        let options = NWProtocolQUIC.Options(alpn: ["punchnet/1.0"])
        self.maxBufferSize = maxBufferSize
        self.maxPacketSize = maxPacketSize
        (self.messageStream, self.messageCont) = AsyncStream.makeStream(of: SDLQUICInboundMessage.self)

        // Custom certificate verification: system trust, hostname and public-key pinning.
        sec_protocol_options_set_verify_block(
            options.securityProtocolOptions,
            { _, trust, complete in
                complete(QUICVerifier.verify(trust: trust, host: host))
            },
            self.queue
        )

        let params = NWParameters(quic: options)
        self.connection = NWConnection(host: .init(host), port: .init(rawValue: port)!, using: params)
    }

    /// Starts the connection, the inbound read loop and a 5-second ping loop.
    func start() {
        // Capture only the continuations (not self) so the handler stored on `connection`
        // does not form a retain cycle with this client and still fires after deinit.
        connection.stateUpdateHandler = { [readyCont = self.readyCont, closeCont = self.closeCont] state in
            SDLLogger.shared.log("[SDLQUICClient] new state: \(state)")
            switch state {
            case .ready:
                readyCont.yield(.success(()))
                readyCont.finish()
            case .failed(let error):
                // Finishing readyCont unblocks a pending waitReady() (no-op if already ready).
                readyCont.yield(.failure(.connectionFailed(error)))
                readyCont.finish()
                closeCont.yield()
                closeCont.finish()
            case .cancelled:
                readyCont.yield(.failure(.connectionCancelled))
                readyCont.finish()
                closeCont.yield()
                closeCont.finish()
            default:
                ()
            }
        }
        connection.start(queue: self.queue)

        // Read loop: reassemble inbound bytes into frames and publish decoded messages.
        self.readTask = Task {
            var buffer = self.allocator.buffer(capacity: self.maxBufferSize)
            // Compact once more than ~60% of maxBufferSize has already been consumed.
            let compactThreshold = self.maxBufferSize / 10 * 6
            do {
                while !Task.isCancelled {
                    let (isComplete, data) = try await self.readOnce()
                    if let data, !data.isEmpty {
                        buffer.writeBytes(data)
                        let frames = try self.parseFrames(buffer: &buffer)
                        if buffer.readerIndex > compactThreshold {
                            buffer.discardReadBytes()
                        }
                        for frame in frames {
                            if let message = self.decode(frame: frame) {
                                self.messageCont.yield(message)
                            }
                        }
                    }
                    if isComplete {
                        break
                    }
                }
            } catch {
                if !Task.isCancelled {
                    SDLLogger.shared.log("[SDLQUICClient] read loop error: \(error)")
                    // The byte stream can no longer be framed reliably; tear the connection
                    // down so waitClose() observers are notified.
                    self.connection.cancel()
                }
            }
            self.messageCont.finish()
        }

        // Heartbeat: send a ping every 5 seconds until cancelled or the client is released.
        self.pingTask = Task { [weak self] in
            let timerStream = SDLAsyncTimerStream()
            timerStream.start(interval: .seconds(5))
            for await _ in timerStream.stream {
                guard !Task.isCancelled, let self else {
                    break
                }
                self.send(type: .ping, data: Data())
            }
            SDLLogger.shared.log("[SDLQUICClient] udp pingTask cancel")
        }
    }

    /// Sends one frame: 2-byte big-endian length (type byte + payload), type byte, payload.
    /// Payloads too large for the 16-bit length prefix are logged and dropped.
    func send(type: SDLPacketType, data: Data) {
        // The length prefix covers the type byte plus the payload and must fit in a UInt16;
        // UInt16(_:) would trap instead of failing gracefully.
        guard let length = UInt16(exactly: data.count + 1) else {
            SDLLogger.shared.log("[SDLQUICClient] send data too large: \(data.count) bytes")
            return
        }
        var packet = Data(capacity: 2 + Int(length))
        packet.append(UInt8(truncatingIfNeeded: length >> 8))
        packet.append(UInt8(truncatingIfNeeded: length))
        packet.append(type.rawValue)
        packet.append(data)

        connection.send(content: packet, completion: .contentProcessed { error in
            if let error {
                SDLLogger.shared.log("[SDLQUICClient] send data get error: \(error)")
            }
        })
    }

    /// Suspends until the connection becomes ready.
    /// - Throws: `SDLQUICError.connectionFailed` / `.connectionCancelled` if the connection
    ///   fails or is cancelled first, or `.connectionCancelled` if the waiting task is cancelled.
    func waitReady() async throws {
        for await result in readyStream {
            try result.get()
            return
        }
        throw SDLQUICError.connectionCancelled
    }

    /// Suspends until the connection reaches `.failed` or `.cancelled`.
    func waitClose() async {
        for await _ in closeStream {}
    }

    /// Stops the ping and read loops and cancels the connection.
    func stop() {
        self.pingTask?.cancel()
        self.readTask?.cancel()
        self.connection.cancel()
    }

    /// Extracts every complete frame in `buffer`, advancing its reader index past them.
    /// Incomplete trailing bytes are left in place for the next call.
    /// - Throws: `SDLQUICError.packetTooLarge` if a length prefix exceeds `maxPacketSize`.
    private func parseFrames(buffer: inout ByteBuffer) throws -> [ByteBuffer] {
        var frames: [ByteBuffer] = []
        while let rawLength = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) {
            let frameLength = Int(rawLength)
            if frameLength > self.maxPacketSize {
                throw SDLQUICError.packetTooLarge
            }
            // Compare in Int: `rawLength + 2` in UInt16 arithmetic traps for rawLength >= 65534.
            guard buffer.readableBytes >= frameLength + 2 else {
                break
            }
            buffer.moveReaderIndex(forwardBy: 2)
            if let frame = buffer.readSlice(length: frameLength) {
                frames.append(frame)
            }
        }
        return frames
    }

    /// Performs one `receive` on the connection.
    /// - Returns: `(isComplete, data)` exactly as reported by `NWConnection.receive`.
    /// - Throws: The connection's receive error.
    private func readOnce() async throws -> (Bool, Data?) {
        return try await withCheckedThrowingContinuation { cont in
            self.connection.receive(minimumIncompleteLength: 1, maximumLength: maxPacketSize) { data, _, isComplete, error in
                if let error {
                    cont.resume(throwing: error)
                    return
                }
                cont.resume(returning: (isComplete, data))
            }
        }
    }

    // MARK: - Codec

    /// Decodes one frame (type byte + serialized payload) into an inbound message.
    /// Returns nil for an unknown type, a payload that fails to parse, or a type that
    /// is not handled here (which is logged).
    private func decode(frame: ByteBuffer) -> SDLQUICInboundMessage? {
        var buffer = frame
        guard let type = buffer.readInteger(as: UInt8.self),
              let packetType = SDLPacketType(rawValue: type) else {
            return nil
        }

        switch packetType {
        case .welcome:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let welcome = try? SDLWelcome(serializedBytes: bytes) else {
                return nil
            }
            return .welcome(welcome)
        case .registerSuperAck:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else {
                return nil
            }
            return .registerSuperAck(registerSuperAck)
        case .registerSuperNak:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else {
                return nil
            }
            return .registerSuperNak(registerSuperNak)
        case .peerInfo:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else {
                return nil
            }
            return .peerInfo(peerInfo)
        case .policyResponse:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let policyResponse = try? SDLPolicyResponse(serializedBytes: bytes) else {
                return nil
            }
            return .policyReponse(policyResponse)
        case .arpResponse:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let arpResponse = try? SDLArpResponse(serializedBytes: bytes) else {
                return nil
            }
            return .arpResponse(arpResponse)
        case .event:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let event = try? SDLEvent(serializedBytes: bytes) else {
                return nil
            }
            return .event(event)
        case .pong:
            return .pong
        default:
            SDLLogger.shared.log("SDLQUICClient decode miss type: \(type)")
            return nil
        }
    }

    deinit {
        self.readTask?.cancel()
        self.pingTask?.cancel()
        // Release network resources if the owner never called stop().
        self.connection.cancel()
        self.messageCont.finish()
    }
}

extension SDLQUICClient {
    /// TLS peer verification: system trust, hostname check and public-key pinning.
    enum QUICVerifier {
        // Base64-encoded SHA-256 hashes of the accepted leaf public keys.
        // NOTE(review): the hash is taken over SecKeyCopyExternalRepresentation output (see
        // verify), which is not SPKI-wrapped — so it differs from a typical openssl SPKI pin.
        static let pinnedPublicKeyHashes = [
            "Q41r6hbMWEVyxo6heNAH4Wx/TH5NNOWlNif9bewcJ3E="
        ]

        /// Returns true only if the chain passes system evaluation, matches `host`, and
        /// the leaf certificate's public-key hash is in `pinnedPublicKeyHashes`.
        static func verify(trust: sec_trust_t, host: String) -> Bool {
            let secTrust = sec_trust_copy_ref(trust).takeRetainedValue()

            // Step 1: system trust evaluation.
            var error: CFError?
            guard SecTrustEvaluateWithError(secTrust, &error) else {
                SDLLogger.shared.log("❌ 系统证书验证失败: \(error?.localizedDescription ?? "未知错误")")
                return false
            }

            // Step 2: hostname validation with an SSL server policy for `host`.
            let policy = SecPolicyCreateSSL(true, host as CFString)
            SecTrustSetPolicies(secTrust, policy)
            guard SecTrustEvaluateWithError(secTrust, &error) else {
                SDLLogger.shared.log("❌ 主机名校验失败: \(error?.localizedDescription ?? "未知错误")")
                return false
            }

            // Step 3: obtain the leaf certificate.
            guard let chain = SecTrustCopyCertificateChain(secTrust) as? [SecCertificate],
                  let leafCertificate = chain.first else {
                SDLLogger.shared.log("❌ 无法获取证书链或叶子证书")
                return false
            }

            // Step 4: extract the public key bytes.
            guard let publicKey = SecCertificateCopyKey(leafCertificate),
                  let publicKeyData = SecKeyCopyExternalRepresentation(publicKey, nil) as Data? else {
                SDLLogger.shared.log("❌ 无法提取公钥")
                return false
            }

            // Step 5: compare the SHA-256 hash against the pinned set.
            let hash = SHA256.hash(data: publicKeyData)
            let hashBase64 = Data(hash).base64EncodedString()

            if pinnedPublicKeyHashes.contains(hashBase64) {
                SDLLogger.shared.log("✅ 公钥校验通过")
                return true
            } else {
                SDLLogger.shared.log("⚠️ 公钥不匹配! 收到: \(hashBase64)")
                return false
            }
        }
    }
}