//
// SDLQuicClient.swift
// Tun
//
// Created by 安礼成 on 2026/2/13.
//

import Foundation
import CryptoKit
import Network
import Security

import NIOCore

// Error type for SDLQUICClient, so upper layers can tell failure causes apart.
enum SDLQUICError: Error {
    // Wraps the underlying error reported by the connection.
    case connectionFailed(Error)
    case connectionCancelled, timeout
    // Carries a human-readable description of what could not be decoded.
    case decodeError(String)
    // An inbound frame's length prefix exceeded the client's maxPacketSize.
    case packetTooLarge
}

/// QUIC client built on `NWConnection` that exchanges length-prefixed frames with the server.
///
/// Wire format, as written by `send(type:data:)` and read by `parseFrames(buffer:)`:
/// `[UInt16 big-endian length][UInt8 packet type][payload]`, where `length` counts the
/// type byte plus the payload.
///
/// Lifecycle: `start()`, then `waitReady()`, consume `messageStream`, and finally `stop()`.
/// The read task keeps `self` alive while a receive is pending, so call `stop()` to tear the
/// client down rather than relying on `deinit`.
final class SDLQUICClient {
    private let allocator = ByteBufferAllocator()

    // Maximum accepted size of a single inbound frame body (default 64 KiB).
    private let maxPacketSize: Int

    // Nominal receive-buffer size (default 2 MiB); consumed bytes are compacted once the
    // reader index passes 60% of this value.
    private let maxBufferSize: Int

    // Largest frame (type byte + payload) that the 2-byte length prefix can describe.
    private static let maxFrameLength = Int(UInt16.max)

    public var messageStream: AsyncStream<SDLQUICInboundMessage>
    private let messageCont: AsyncStream<SDLQUICInboundMessage>.Continuation

    private var readTask: Task<Void, Never>?
    private var pingTask: Task<Void, Never>?

    private let connection: NWConnection
    // Dedicated serial queue for connection callbacks and TLS verification.
    private let queue = DispatchQueue(label: "com.sdl.QUICClient.queue")

    private let (closeStream, closeCont) = AsyncStream.makeStream(of: Void.self)
    // Carries the single outcome of connection setup: success on `.ready`, or the failure
    // if the connection fails / is cancelled first (later yields are no-ops once finished).
    private let (readyStream, readyCont) = AsyncStream.makeStream(of: Result<Void, SDLQUICError>.self)

    /// - Parameters:
    ///   - host: Server host name or address; also used for TLS hostname verification.
    ///   - port: Server UDP port.
    ///   - maxPacketSize: Largest inbound frame body accepted before the read loop aborts.
    ///   - maxBufferSize: Nominal receive-buffer size.
    init(host: String, port: UInt16, maxPacketSize: Int = 64 * 1024, maxBufferSize: Int = 2 * 1024 * 1024) {
        let options = NWProtocolQUIC.Options(alpn: ["punchnet/1.0"])

        self.maxBufferSize = maxBufferSize
        self.maxPacketSize = maxPacketSize
        (self.messageStream, self.messageCont) = AsyncStream.makeStream(of: SDLQUICInboundMessage.self)

        // TODO: certificate verification policy. Currently system trust + hostname + public-key pinning.
        sec_protocol_options_set_verify_block(
            options.securityProtocolOptions,
            { _, trust, complete in
                complete(QUICVerifier.verify(trust: trust, host: host))
            },
            self.queue
        )

        let params = NWParameters(quic: options)
        self.connection = NWConnection(host: .init(host), port: .init(rawValue: port)!, using: params)
    }

    /// Starts the connection and launches the background read and heartbeat tasks.
    func start() {
        // Weak capture: the connection retains this handler and `self` retains the connection.
        connection.stateUpdateHandler = { [weak self] state in
            SDLLogger.shared.log("[SDLQUICClient] new state: \(state)")
            guard let self else { return }

            switch state {
            case .ready:
                self.readyCont.yield(.success(()))
                self.readyCont.finish()
            case .failed(let error):
                // Unblock waitReady() if setup never completed, then release the connection;
                // cancel() moves it to `.cancelled`, whose yields are no-ops by then.
                self.readyCont.yield(.failure(.connectionFailed(error)))
                self.readyCont.finish()
                self.closeCont.yield()
                self.closeCont.finish()
                self.connection.cancel()
            case .cancelled:
                self.readyCont.yield(.failure(.connectionCancelled))
                self.readyCont.finish()
                self.closeCont.yield()
                self.closeCont.finish()
            default:
                break
            }
        }
        connection.start(queue: self.queue)

        // Data read loop.
        readTask = Task { [weak self] in
            await self?.runReadLoop()
        }

        // Heartbeat: send a ping every 5 seconds until cancelled or the client is released.
        pingTask = Task { [weak self] in
            let timerStream = SDLAsyncTimerStream()
            timerStream.start(interval: .seconds(5))

            for await _ in timerStream.stream {
                guard !Task.isCancelled, let self else { break }
                self.send(type: .ping, data: Data())
            }
            SDLLogger.shared.log("[SDLQUICClient] udp pingTask cancel")
        }
    }

    /// Sends one frame: `[UInt16 BE length][type][data]`, with length = `data.count + 1`.
    /// Payloads too large for the 16-bit length prefix are dropped and logged instead of trapping.
    func send(type: SDLPacketType, data: Data) {
        let frameLength = data.count + 1
        guard frameLength <= Self.maxFrameLength else {
            SDLLogger.shared.log("[SDLQUICClient] send dropped: frame length \(frameLength) exceeds \(Self.maxFrameLength)")
            return
        }

        var packet = Data(capacity: 2 + frameLength)
        withUnsafeBytes(of: UInt16(frameLength).bigEndian) { packet.append(contentsOf: $0) }
        packet.append(type.rawValue)
        packet.append(data)

        connection.send(content: packet, completion: .contentProcessed { error in
            if let error {
                SDLLogger.shared.log("[SDLQUICClient] send data get error: \(error)")
            }
        })
    }

    /// Suspends until the connection reaches `.ready`.
    /// - Throws: `SDLQUICError.connectionFailed` / `.connectionCancelled` if the connection fails
    ///   or is cancelled before becoming ready; `CancellationError` if the waiting task is cancelled.
    func waitReady() async throws {
        var iterator = readyStream.makeAsyncIterator()
        guard let outcome = await iterator.next() else {
            // Stream already drained by an earlier call, or this task was cancelled.
            try Task.checkCancellation()
            return
        }
        try outcome.get()
    }

    /// Suspends until the connection has failed or been cancelled.
    func waitClose() async {
        for await _ in closeStream {}
    }

    /// Cancels the heartbeat, the read loop and the connection. Safe to call more than once.
    func stop() {
        pingTask?.cancel()
        readTask?.cancel()
        // Cancelling the connection also completes any pending receive, ending the read loop.
        connection.cancel()
    }

    /// Reads until the connection completes, errors, or the task is cancelled, yielding each
    /// decoded frame on `messageStream`. `messageStream` is always finished on exit.
    private func runReadLoop() async {
        defer { messageCont.finish() }

        var buffer = allocator.buffer(capacity: maxBufferSize)
        // Compact consumed bytes once the reader index passes 60% of the nominal size.
        let threshold = maxBufferSize / 10 * 6

        do {
            while !Task.isCancelled {
                let (isComplete, data) = try await readOnce()

                if let data, !data.isEmpty {
                    buffer.writeBytes(data)
                    let frames = try parseFrames(buffer: &buffer)
                    if buffer.readerIndex > threshold {
                        buffer.discardReadBytes()
                    }
                    for frame in frames {
                        if let message = decode(frame: frame) {
                            messageCont.yield(message)
                        }
                    }
                }

                if isComplete {
                    break
                }
            }
        } catch {
            SDLLogger.shared.log("[SDLQUICClient] read loop stopped with error: \(error)")
        }
    }

    /// Extracts every complete frame body currently in `buffer`, advancing its reader index
    /// past them. A trailing partial frame is left unread for the next call.
    /// - Throws: `SDLQUICError.packetTooLarge` if a length prefix exceeds `maxPacketSize`.
    private func parseFrames(buffer: inout ByteBuffer) throws -> [ByteBuffer] {
        var frames: [ByteBuffer] = []

        while let lengthPrefix = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) {
            // Widen before adding: `lengthPrefix + 2` in UInt16 traps for prefixes >= 65534.
            let bodyLength = Int(lengthPrefix)
            if bodyLength > maxPacketSize {
                throw SDLQUICError.packetTooLarge
            }
            guard buffer.readableBytes >= bodyLength + 2 else {
                break
            }

            buffer.moveReaderIndex(forwardBy: 2)
            if let frame = buffer.readSlice(length: bodyLength) {
                frames.append(frame)
            }
        }
        return frames
    }

    /// Performs a single `receive` on the connection.
    /// - Returns: `(isComplete, data)` as reported by the receive callback.
    /// - Throws: The receive error, if any.
    private func readOnce() async throws -> (Bool, Data?) {
        try await withCheckedThrowingContinuation { cont in
            self.connection.receive(minimumIncompleteLength: 1, maximumLength: maxPacketSize) { data, _, isComplete, error in
                if let error {
                    cont.resume(throwing: error)
                    return
                }
                cont.resume(returning: (isComplete, data))
            }
        }
    }

    // MARK: - Codec

    /// Decodes one frame body (`[type byte][payload]`) into an inbound message.
    /// Returns `nil` for unknown/unsupported types or payloads that fail to deserialize.
    private func decode(frame: ByteBuffer) -> SDLQUICInboundMessage? {
        var buffer = frame

        guard let type = buffer.readInteger(as: UInt8.self),
              let packetType = SDLPacketType(rawValue: type) else {
            return nil
        }

        switch packetType {
        case .welcome:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let welcome = try? SDLWelcome(serializedBytes: bytes) else {
                return nil
            }
            return .welcome(welcome)

        case .registerSuperAck:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else {
                return nil
            }
            return .registerSuperAck(registerSuperAck)

        case .registerSuperNak:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else {
                return nil
            }
            return .registerSuperNak(registerSuperNak)

        case .peerInfo:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else {
                return nil
            }
            return .peerInfo(peerInfo)

        case .policyResponse:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let policyResponse = try? SDLPolicyResponse(serializedBytes: bytes) else {
                return nil
            }
            // NOTE: `policyReponse` (sic) is the case name declared on SDLQUICInboundMessage.
            return .policyReponse(policyResponse)

        case .arpResponse:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let arpResponse = try? SDLArpResponse(serializedBytes: bytes) else {
                return nil
            }
            return .arpResponse(arpResponse)

        case .event:
            // Event frames carry an extra event-type byte before the payload.
            guard let eventVal = buffer.readInteger(as: UInt8.self),
                  let event = SDLEventType(rawValue: eventVal),
                  let bytes = buffer.readBytes(length: buffer.readableBytes) else {
                SDLLogger.shared.log("[SDLQUICClient] decode error 15")
                return nil
            }

            switch event {
            case .natChanged:
                guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else {
                    SDLLogger.shared.log("[SDLQUICClient] decode error 16")
                    return nil
                }
                return .event(.natChanged(natChangedEvent))
            case .sendRegister:
                guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else {
                    SDLLogger.shared.log("[SDLQUICClient] decode error 17")
                    return nil
                }
                return .event(.sendRegister(sendRegisterEvent))
            case .networkShutdown:
                guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else {
                    SDLLogger.shared.log("[SDLQUICClient] decode error 18")
                    return nil
                }
                return .event(.networkShutdown(networkShutdownEvent))
            }

        case .pong:
            return .pong

        default:
            SDLLogger.shared.log("SDLQUICClient decode miss type: \(type)")
            return nil
        }
    }

    deinit {
        readTask?.cancel()
        pingTask?.cancel()
        // An NWConnection must be cancelled to release its resources.
        connection.cancel()
        messageCont.finish()
    }
}

extension SDLQUICClient {

    /// TLS certificate verifier used by the QUIC connection's verify block.
    ///
    /// A peer is accepted only if every check passes, in this order:
    /// 1. the trust evaluates successfully with the policies already attached to it;
    /// 2. the trust evaluates successfully under an SSL server policy for `host`;
    /// 3. the leaf certificate can be read and its public key exported;
    /// 4. the Base64 SHA-256 digest of that exported key is in `pinnedPublicKeyHashes`.
    enum QUICVerifier {

        /// Base64-encoded SHA-256 digests of the accepted leaf public keys.
        /// NOTE(review): the digest is computed over `SecKeyCopyExternalRepresentation` output
        /// (raw key bytes), not over the DER SubjectPublicKeyInfo that tools like `openssl`
        /// usually hash. Derive new pins the same way; the mismatch log prints the received value.
        static let pinnedPublicKeyHashes = [
            "Q41r6hbMWEVyxo6heNAH4Wx/TH5NNOWlNif9bewcJ3E="
        ]

        /// Returns `true` when the server presented by `trust` passes all checks above.
        static func verify(trust: sec_trust_t, host: String) -> Bool {
            let serverTrust = sec_trust_copy_ref(trust).takeRetainedValue()
            var trustError: CFError?

            // System trust evaluation.
            if !SecTrustEvaluateWithError(serverTrust, &trustError) {
                SDLLogger.shared.log("❌ 系统证书验证失败: \(trustError?.localizedDescription ?? "未知错误")")
                return false
            }

            // Hostname validation: re-evaluate under an SSL server policy bound to `host`.
            SecTrustSetPolicies(serverTrust, SecPolicyCreateSSL(true, host as CFString))
            if !SecTrustEvaluateWithError(serverTrust, &trustError) {
                SDLLogger.shared.log("❌ 主机名校验失败: \(trustError?.localizedDescription ?? "未知错误")")
                return false
            }

            // Leaf certificate (first element of the evaluated chain).
            guard let leaf = (SecTrustCopyCertificateChain(serverTrust) as? [SecCertificate])?.first else {
                SDLLogger.shared.log("❌ 无法获取证书链或叶子证书")
                return false
            }

            // Exported public-key bytes of the leaf.
            guard let keyBytes = SecCertificateCopyKey(leaf).flatMap({ SecKeyCopyExternalRepresentation($0, nil) as Data? }) else {
                SDLLogger.shared.log("❌ 无法提取公钥")
                return false
            }

            // SHA-256 pin comparison.
            let digest = Data(SHA256.hash(data: keyBytes)).base64EncodedString()
            guard pinnedPublicKeyHashes.contains(digest) else {
                SDLLogger.shared.log("⚠️ 公钥不匹配! 收到: \(digest)")
                return false
            }

            SDLLogger.shared.log("✅ 公钥校验通过")
            return true
        }
    }
}