//
//  SDLQuicClient.swift
//  Tun
//
//  Created by 安礼成 on 2026/2/13.
//
|
||
import Foundation
|
||
import NIOCore
|
||
import Network
|
||
|
||
// MARK: - Errors

/// Errors raised by the QUIC client so that upper layers can react to them.
enum SDLQUICError: Error {
    /// The connection failed; the associated value is the underlying error.
    case connectionFailed(Error)

    /// The connection was cancelled.
    case connectionCancelled

    /// An operation timed out.
    /// NOTE(review): not raised anywhere in this file — confirm where it is produced.
    case timeout

    /// A payload could not be decoded; the associated value describes the problem.
    /// NOTE(review): not raised anywhere in this file — confirm where it is produced.
    case decodeError(String)
}
|
||
|
||
/// QUIC client: frames outbound payloads through `SDLQUICTransport`, decodes inbound
/// frames with `QUICCodec`, and exposes the connection lifecycle as `async` waits.
///
/// Thread-safety: `state`, `readyWaiters` and `closeWaiters` are only touched on `queue`
/// (the serial queue the connection is started on, so its state updates arrive there too)
/// or in `deinit`, when no other reference exists. That confinement is what the
/// `@unchecked Sendable` conformance relies on.
final class SDLQUICClient: @unchecked Sendable {

    /// Connection lifecycle as observed by this client.
    private enum LifecycleState {
        /// Waiting for the transport to report `.ready` (also the state before `start()`).
        case connecting
        /// The transport reported `.ready`.
        case ready
        /// The transport failed or was cancelled; `waitReady()` throws this error from now on.
        case closed(SDLQUICError)
    }

    private let transport: SDLQUICTransport

    private let allocator = ByteBufferAllocator()

    // Dedicated serial queue: connection state updates and all waiter bookkeeping run here.
    private let queue = DispatchQueue(label: "com.sdl.QUICClient.queue")

    // Only accessed on `queue` (or in `deinit`).
    private var state: LifecycleState = .connecting
    private var readyWaiters: [CheckedContinuation<Void, Error>] = []
    private var closeWaiters: [CheckedContinuation<Void, Never>] = []

    /// Creates a client for `host:port`. The connection is not started until `start()`.
    init(host: String, port: UInt16) {
        self.transport = SDLQUICTransport(host: host, port: port)
    }

    /// Starts the underlying connection; lifecycle events are handled on `queue`.
    func start() {
        // Weak capture: the transport stores this handler inside its connection, so a strong
        // `self` would form the cycle client -> transport -> connection -> handler -> client
        // and `deinit` (which stops the transport) would never run.
        transport.start(queue: queue) { [weak self] event in
            self?.handle(event)
        }
    }

    /// Returns the inbound messages decoded from the transport's frames.
    ///
    /// Frames that `QUICCodec.decode` rejects (returns `nil` or throws) are skipped.
    /// The sequence ends when the transport's frame stream finishes.
    /// - Parameter maxLen: Maximum number of bytes requested per transport read.
    func receiveStream(maxLen: Int) -> AsyncCompactMapSequence<AsyncStream<Data>, SDLQUICInboundMessage> {
        // Capture the allocator value rather than `self`, so the stream does not retain the client.
        let allocator = self.allocator
        return transport.receiveMessageStream(maxLen: maxLen).compactMap { data in
            var buffer = allocator.buffer(bytes: data)
            return try? QUICCodec.decode(buffer: &buffer)
        }
    }

    /// Sends `data` as one length-prefixed frame.
    func send(data: Data) {
        transport.send(data)
    }

    /// Suspends until the connection is ready.
    ///
    /// Returns immediately if the connection is already ready. Any number of callers may wait.
    /// - Throws: `SDLQUICError.connectionFailed` or `SDLQUICError.connectionCancelled` if the
    ///   connection closes before becoming ready, or had already closed when this is called.
    func waitReady() async throws {
        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
            queue.async {
                switch self.state {
                case .connecting:
                    self.readyWaiters.append(continuation)
                case .ready:
                    continuation.resume()
                case .closed(let error):
                    continuation.resume(throwing: error)
                }
            }
        }
    }

    /// Suspends until the connection has failed or been cancelled.
    ///
    /// Returns immediately if that has already happened. Any number of callers may wait.
    func waitClose() async {
        await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
            queue.async {
                if case .closed = self.state {
                    continuation.resume()
                } else {
                    self.closeWaiters.append(continuation)
                }
            }
        }
    }

    deinit {
        // No other reference exists here, so reading the waiter lists off-queue is safe.
        // Every continuation must be resumed exactly once; release any still pending.
        readyWaiters.forEach { $0.resume(throwing: SDLQUICError.connectionCancelled) }
        closeWaiters.forEach { $0.resume() }
        transport.stop()
    }

    // MARK: - Lifecycle (must run on `queue`)

    /// Applies one transport event to `state` and resumes the affected waiters.
    private func handle(_ event: SDLQUICTransport.Event) {
        switch event {
        case .ready:
            guard case .connecting = state else { return }
            state = .ready
            let waiters = readyWaiters
            readyWaiters.removeAll()
            waiters.forEach { $0.resume() }
        case .failed(let error):
            markClosed(.connectionFailed(error))
            // A failed NWConnection should still be cancelled to release its resources;
            // the resulting `.cancelled` event is ignored because the state is already closed.
            transport.stop()
        case .cancelled:
            markClosed(.connectionCancelled)
        }
    }

    /// Moves to `.closed` once: pending `waitReady()` callers throw `error`,
    /// and pending `waitClose()` callers return. Later calls are ignored.
    private func markClosed(_ error: SDLQUICError) {
        if case .closed = state { return }
        state = .closed(error)
        let pendingReady = readyWaiters
        let pendingClose = closeWaiters
        readyWaiters.removeAll()
        closeWaiters.removeAll()
        pendingReady.forEach { $0.resume(throwing: error) }
        pendingClose.forEach { $0.resume() }
    }
}
|
||
|
||
/// An `NWConnection` configured for QUIC.
///
/// Every message is framed as a 2-byte big-endian length followed by that many payload bytes.
final class SDLQUICTransport {

    /// Connection lifecycle events forwarded to the owner.
    enum Event {
        case ready
        case failed(Error)
        case cancelled
    }

    /// Size of the big-endian `UInt16` length prefix in front of every frame.
    private static let lengthPrefixSize = 2

    private let connection: NWConnection

    /// Creates (but does not start) a QUIC connection to `host:port`.
    init(host: String, port: UInt16) {
        // NOTE(review): the QUIC options set no ALPN. QUIC's TLS handshake normally requires
        // one, so confirm the server accepts this configuration.
        let params = NWParameters(quic: .init())
        // NOTE(review): `NWEndpoint.Port(rawValue:)` is failable; this force-unwrap crashes if it returns nil.
        self.connection = NWConnection(host: .init(host), port: .init(rawValue: port)!, using: params)
    }

    /// Starts the connection on `queue`.
    ///
    /// `onEvent` is called from the connection's state handler (on `queue`) for the
    /// `.ready`, `.failed` and `.cancelled` states. Other states are not forwarded.
    func start(queue: DispatchQueue, onEvent: @escaping (Event) -> Void) {
        connection.stateUpdateHandler = { state in
            switch state {
            case .ready:
                onEvent(.ready)
            case .failed(let error):
                onEvent(.failed(error))
            case .cancelled:
                onEvent(.cancelled)
            default:
                break
            }
        }
        connection.start(queue: queue)
    }

    /// Returns a stream that yields the payload of every complete frame received.
    ///
    /// The stream finishes when a receive reports an error or when the peer completes the
    /// stream (`isComplete`). A trailing partial frame at that point is discarded.
    /// - Parameter maxLen: `maximumLength` passed to each `NWConnection.receive` call.
    func receiveMessageStream(maxLen: Int) -> AsyncStream<Data> {
        let connection = self.connection

        return AsyncStream { continuation in
            // Received bytes not yet consumed as complete frames. Only one receive is
            // outstanding at a time (the next one is issued from inside the previous
            // completion), so this is never accessed concurrently.
            var pending = Data()

            func receiveNext() {
                connection.receive(minimumIncompleteLength: 1, maximumLength: maxLen) { data, _, isComplete, error in
                    if let data, !data.isEmpty {
                        pending.append(data)
                        for frame in SDLQUICTransport.extractFrames(from: &pending) {
                            continuation.yield(frame)
                        }
                    }

                    // Stop on error, or once the peer has finished the stream. Re-arming a
                    // receive after completion would only spin on an exhausted stream.
                    if error != nil || isComplete {
                        continuation.finish()
                    } else {
                        receiveNext()
                    }
                }
            }

            receiveNext()
        }
    }

    /// Sends `data` as a single frame: a 2-byte big-endian length followed by the payload.
    ///
    /// Payloads longer than `UInt16.max` bytes cannot be framed. They are dropped (with an
    /// assertion failure in debug builds) instead of trapping on the length conversion.
    /// Send errors are ignored.
    func send(_ data: Data) {
        guard let length = UInt16(exactly: data.count) else {
            assertionFailure("SDLQUICTransport.send: payload of \(data.count) bytes exceeds the UInt16 frame limit")
            return
        }

        var packet = Data(capacity: Self.lengthPrefixSize + data.count)
        packet.append(UInt8(truncatingIfNeeded: length >> 8))
        packet.append(UInt8(truncatingIfNeeded: length))
        packet.append(data)

        connection.send(content: packet, completion: .contentProcessed { _ in })
    }

    /// Cancels the connection.
    func stop() {
        connection.cancel()
    }

    /// Removes every complete frame from the front of `buffer` and returns their payloads
    /// in arrival order. A trailing partial frame is left in `buffer`.
    private static func extractFrames(from buffer: inout Data) -> [Data] {
        var frames: [Data] = []
        var cursor = buffer.startIndex

        while buffer.endIndex - cursor >= lengthPrefixSize {
            // Build the big-endian length byte by byte. Unlike loading a UInt16 straight
            // from the buffer's memory, this has no alignment requirement.
            let length = (Int(buffer[cursor]) << 8) | Int(buffer[cursor + 1])
            let bodyStart = cursor + lengthPrefixSize
            guard buffer.endIndex - bodyStart >= length else { break }

            frames.append(buffer.subdata(in: bodyStart ..< bodyStart + length))
            cursor = bodyStart + length
        }

        // Drop the consumed bytes once, instead of shifting the buffer after every frame.
        if cursor > buffer.startIndex {
            buffer.removeSubrange(buffer.startIndex ..< cursor)
        }
        return frames
    }
}
|
||
|
||
extension SDLQUICClient {

    /// Decoder for inbound QUIC messages.
    struct QUICCodec {

        // MARK: - Decoding

        /// Decodes one inbound message.
        ///
        /// Layout: a `UInt8` `SDLPacketType` tag, then the payload, which is handed to the
        /// matching type's `init(serializedBytes:)`. For `.event`, a `UInt8` `SDLEventType`
        /// tag comes before the payload.
        ///
        /// - Parameter buffer: The bytes of one message. Its reader index is advanced.
        /// - Returns: The decoded message, or `nil` if the buffer is empty, the tag is unknown
        ///   or not handled here, or the payload fails to decode. Every failure except an
        ///   empty buffer is logged.
        /// - Throws: Nothing at present; all failures are reported as `nil`.
        public static func decode(buffer: inout ByteBuffer) throws -> SDLQUICInboundMessage? {
            guard let type = buffer.readInteger(as: UInt8.self) else {
                return nil
            }
            guard let packetType = SDLPacketType(rawValue: type) else {
                SDLLogger.shared.log("[SDLQUICClient] decode unknown packet type: \(type)")
                return nil
            }

            switch packetType {
            case .registerSuperAck:
                guard let registerSuperAck = decodePayload(&buffer, { try SDLRegisterSuperAck(serializedBytes: $0) }) else {
                    SDLLogger.shared.log("[SDLQUICClient] decode registerSuperAck failed")
                    return nil
                }
                return .registerSuperAck(registerSuperAck)

            case .registerSuperNak:
                guard let registerSuperNak = decodePayload(&buffer, { try SDLRegisterSuperNak(serializedBytes: $0) }) else {
                    SDLLogger.shared.log("[SDLQUICClient] decode registerSuperNak failed")
                    return nil
                }
                return .registerSuperNak(registerSuperNak)

            case .peerInfo:
                guard let peerInfo = decodePayload(&buffer, { try SDLPeerInfo(serializedBytes: $0) }) else {
                    SDLLogger.shared.log("[SDLQUICClient] decode peerInfo failed")
                    return nil
                }
                return .peerInfo(peerInfo)

            case .policyResponse:
                guard let policyResponse = decodePayload(&buffer, { try SDLPolicyResponse(serializedBytes: $0) }) else {
                    SDLLogger.shared.log("[SDLQUICClient] decode policyResponse failed")
                    return nil
                }
                return .policyReponse(policyResponse)

            case .event:
                guard let eventVal = buffer.readInteger(as: UInt8.self),
                      let event = SDLEventType(rawValue: eventVal) else {
                    SDLLogger.shared.log("[SDLQUICClient] decode error 15")
                    return nil
                }

                switch event {
                case .natChanged:
                    guard let natChangedEvent = decodePayload(&buffer, { try SDLNatChangedEvent(serializedBytes: $0) }) else {
                        SDLLogger.shared.log("[SDLQUICClient] decode error 16")
                        return nil
                    }
                    return .event(.natChanged(natChangedEvent))

                case .sendRegister:
                    guard let sendRegisterEvent = decodePayload(&buffer, { try SDLSendRegisterEvent(serializedBytes: $0) }) else {
                        SDLLogger.shared.log("[SDLQUICClient] decode error 17")
                        return nil
                    }
                    return .event(.sendRegister(sendRegisterEvent))

                case .networkShutdown:
                    guard let networkShutdownEvent = decodePayload(&buffer, { try SDLNetworkShutdownEvent(serializedBytes: $0) }) else {
                        SDLLogger.shared.log("[SDLQUICClient] decode error 18")
                        return nil
                    }
                    return .event(.networkShutdown(networkShutdownEvent))
                }

            default:
                SDLLogger.shared.log("[SDLQUICClient] decode miss type: \(type)")
                return nil
            }
        }

        /// Reads all remaining bytes of `buffer` and passes them to `decoder`.
        /// Returns `nil` if the bytes cannot be read or `decoder` throws.
        private static func decodePayload<Message>(
            _ buffer: inout ByteBuffer,
            _ decoder: ([UInt8]) throws -> Message
        ) -> Message? {
            guard let bytes = buffer.readBytes(length: buffer.readableBytes) else {
                return nil
            }
            return try? decoder(bytes)
        }
    }

}
|