318 lines
10 KiB
Swift
318 lines
10 KiB
Swift
//
// SDLQuicClient.swift
// Tun
//
// Created by 安礼成 on 2026/2/13.
//
|
||
|
||
import Foundation
|
||
import NIOCore
|
||
import Network
|
||
|
||
// Error type for the QUIC client stack, so upper layers can handle failures explicitly.
enum SDLQUICError: Error {
    /// The underlying connection failed; the associated value is the original error.
    case connectionFailed(Error)

    /// The connection was cancelled.
    case connectionCancelled

    /// An operation timed out. NOTE(review): not thrown by anything visible in this file.
    case timeout

    /// A payload could not be decoded; the string describes the problem.
    case decodeError(String)

    /// A received frame declared a length above the reader's configured maximum.
    case packetTooLarge
}
|
||
|
||
/// High-level QUIC client: owns an `SDLQUICTransport` and exposes its lifecycle
/// (ready / closed) as async waits.
final class SDLQUICClient {

    private let transport: SDLQUICTransport

    // Serial queue the underlying connection is started on; its state callbacks run here.
    private let queue = DispatchQueue(label: "com.sdl.QUICClient.queue")

    // Yields once and finishes when the connection fails or is cancelled.
    private let (closeStream, closeCont) = AsyncStream.makeStream(of: Void.self)

    // Yields once and finishes when the connection becomes ready. If the connection
    // fails or is cancelled first, it finishes by throwing, so `waitReady()` cannot hang.
    private let (readyStream, readyCont) = AsyncThrowingStream.makeStream(of: Void.self, throwing: Error.self)

    /// Creates a client for `host:port`. No network activity happens until `start()`.
    init(host: String, port: UInt16) {
        self.transport = SDLQUICTransport(host: host, port: port)
    }

    /// Starts the connection and routes its lifecycle events into the ready/close streams.
    func start() {
        // Capture only the continuations, not `self`. The handler is stored on the
        // connection, which the transport (and therefore `self`) owns, so capturing
        // `self` would form a retain cycle.
        let readyCont = self.readyCont
        let closeCont = self.closeCont

        transport.start(queue: queue) { event in
            switch event {
            case .ready:
                readyCont.yield()
                readyCont.finish()
            case .failed(let error):
                // No-op if `.ready` already finished the stream.
                readyCont.finish(throwing: SDLQUICError.connectionFailed(error))
                closeCont.yield()
                closeCont.finish()
            case .cancelled:
                readyCont.finish(throwing: SDLQUICError.connectionCancelled)
                closeCont.yield()
                closeCont.finish()
            }
        }
    }

    /// Returns a reader that frames inbound data from the underlying connection.
    func getReader() -> SDLQUICReader {
        return transport.getReader()
    }

    /// Sends one length-prefixed packet of the given type.
    func send(type: SDLPacketType, data: Data) {
        transport.send(type: type, data: data)
    }

    /// Suspends until the connection is ready.
    ///
    /// - Throws: `SDLQUICError.connectionFailed` or `SDLQUICError.connectionCancelled`
    ///   if the connection ends before becoming ready. Once the stream has been
    ///   consumed, later calls return immediately.
    func waitReady() async throws {
        for try await _ in readyStream {
            return
        }
    }

    /// Suspends until the connection fails or is cancelled.
    func waitClose() async {
        for await _ in closeStream {}
    }

    /// Cancels the underlying connection.
    func stop() {
        transport.stop()
    }
}
|
||
|
||
/// Thin wrapper around a QUIC `NWConnection` (ALPN `punchnet/1.0`).
///
/// Outbound wire format per packet: `[length: UInt16 big-endian][type: UInt8][payload]`,
/// where `length` counts the type byte plus the payload.
final class SDLQUICTransport {

    /// Connection lifecycle events forwarded to the owner.
    enum Event {
        case ready
        case failed(Error)
        case cancelled
    }

    private let connection: NWConnection

    /// Builds (but does not start) a QUIC connection to `host:port`.
    init(host: String, port: UInt16) {
        let options = NWProtocolQUIC.Options(alpn: ["punchnet/1.0"])

        // TODO: implement certificate validation here.
        // SECURITY(review): this verify block accepts every server certificate, so TLS gives
        // no server authentication (a man-in-the-middle can impersonate the server). Evaluate
        // `trust` (e.g. SecTrustEvaluateWithError) or pin the expected key before shipping.
        sec_protocol_options_set_verify_block(
            options.securityProtocolOptions,
            { _, _, complete in
                complete(true) // true = accept the certificate
            },
            DispatchQueue.global()
        )

        let params = NWParameters(quic: options)
        self.connection = NWConnection(host: .init(host), port: .init(rawValue: port)!, using: params)
    }

    /// Installs the state handler and starts the connection on `queue`.
    ///
    /// - Parameters:
    ///   - queue: Queue passed to `NWConnection.start(queue:)`.
    ///   - onEvent: Called for `.ready`, `.failed` and `.cancelled`. Other states are only logged.
    func start(queue: DispatchQueue, onEvent: @escaping (Event) -> Void) {
        SDLLogger.shared.log("[SDLQUICTransport] call start")

        connection.stateUpdateHandler = { state in
            SDLLogger.shared.log("[SDLQUICTransport] new state: \(state)")
            switch state {
            case .ready: onEvent(.ready)
            case .failed(let e): onEvent(.failed(e))
            case .cancelled: onEvent(.cancelled)
            default: break
            }
        }
        connection.start(queue: queue)
    }

    /// Returns a new framing reader over this transport's connection.
    func getReader() -> SDLQUICReader {
        return SDLQUICReader(connection: self.connection)
    }

    /// Frames and sends one packet.
    ///
    /// Packets whose type byte plus payload does not fit in the 16-bit length prefix are
    /// logged and dropped instead of crashing. Send failures are logged.
    func send(type: SDLPacketType, data: Data) {
        // Without this guard, `UInt16(_:)` would trap for payloads of 65535 bytes or more.
        let bodyLength = data.count + 1
        guard let length = UInt16(exactly: bodyLength) else {
            SDLLogger.shared.log("[SDLQUICTransport] drop packet: body length \(bodyLength) exceeds \(UInt16.max)")
            return
        }

        var packet = Data(capacity: 2 + bodyLength)
        packet.append(UInt8(truncatingIfNeeded: length >> 8)) // big-endian high byte
        packet.append(UInt8(truncatingIfNeeded: length))      // big-endian low byte
        packet.append(type.rawValue)
        packet.append(data)

        connection.send(content: packet, completion: .contentProcessed { error in
            if let error {
                SDLLogger.shared.log("[SDLQUICTransport] send error: \(error)")
            }
        })
    }

    /// Cancels the connection. The state handler will observe `.cancelled`.
    func stop() {
        connection.cancel()
    }
}
|
||
|
||
/// Reassembles length-prefixed frames (`[length: UInt16 big-endian][body]`) from an
/// `NWConnection` byte stream. Each element is one frame body with the prefix stripped.
actor SDLQUICReader: AsyncIteratorProtocol {

    typealias Element = ByteBuffer

    private let allocator = ByteBufferAllocator()

    // Accumulates received bytes until at least one complete frame is available.
    private var buffer: ByteBuffer

    // Frames parsed but not yet returned; one read can contain several frames.
    private var packets: [ByteBuffer] = []

    // Largest accepted frame length (default 64 KiB). Because the prefix is a UInt16,
    // declared lengths never exceed 65535, so the default limit is never hit.
    private let maxPacketSize: Int

    // Initial buffer capacity (default 2 MiB); also drives the compaction threshold below.
    private let maxBufferSize: Int

    // Set once the connection reports that the receive side is complete.
    private var isComplete: Bool = false

    private let connection: NWConnection

    init(connection: NWConnection, maxPacketSize: Int = 64 * 1024, maxBufferSize: Int = 2 * 1024 * 1024) {
        self.connection = connection
        self.maxBufferSize = maxBufferSize
        self.maxPacketSize = maxPacketSize
        self.buffer = allocator.buffer(capacity: maxBufferSize)
    }

    /// Returns the next complete frame, or `nil` once the stream is complete with no
    /// further complete frames.
    ///
    /// - Throws: Connection receive errors, or `SDLQUICError.packetTooLarge`.
    func next() async throws -> ByteBuffer? {
        if packets.isEmpty {
            // Append rather than assign: this actor can be re-entered across the `await`,
            // and assigning would discard frames queued by another call.
            packets.append(contentsOf: try await readPacket())
        }
        return packets.isEmpty ? nil : packets.removeFirst()
    }

    /// Reads from the connection until at least one frame parses or the stream completes.
    private func readPacket() async throws -> [ByteBuffer] {
        while true {
            if isComplete {
                return try parseFrames()
            }

            let (finished, data) = try await readOnce()
            isComplete = finished

            if !data.isEmpty {
                buffer.writeBytes(data)
                let frames = try parseFrames()
                if !frames.isEmpty {
                    return frames
                }
            }
        }
    }

    /// Extracts every complete frame currently in `buffer`.
    ///
    /// A trailing partial frame is left in place for the next read.
    /// - Throws: `SDLQUICError.packetTooLarge` if a declared length exceeds `maxPacketSize`.
    private func parseFrames() throws -> [ByteBuffer] {
        var frames: [ByteBuffer] = []

        while let declared = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) {
            // Widen before arithmetic: `declared + 2` in UInt16 traps for 0xFFFE/0xFFFF,
            // which a remote peer could send to crash the process.
            let frameLength = Int(declared)
            if frameLength > maxPacketSize {
                throw SDLQUICError.packetTooLarge
            }
            guard buffer.readableBytes >= frameLength + 2 else {
                break
            }

            buffer.moveReaderIndex(forwardBy: 2)
            if let frame = buffer.readSlice(length: frameLength) {
                frames.append(frame)
            }
        }

        // Reclaim consumed space once more than 60% of maxBufferSize has been read past.
        if buffer.readerIndex > maxBufferSize / 10 * 6 {
            buffer.discardReadBytes()
        }
        return frames
    }

    /// Issues a single `receive` on the connection and bridges it to async/await.
    ///
    /// - Returns: `(isComplete, data)`; `data` is empty when the callback carried no bytes.
    /// - Throws: The error reported by the connection.
    private func readOnce() async throws -> (Bool, Data) {
        return try await withCheckedThrowingContinuation { cont in
            connection.receive(minimumIncompleteLength: 1, maximumLength: maxPacketSize) { data, _, isComplete, error in
                if let error {
                    cont.resume(throwing: error)
                    return
                }

                if let data, !data.isEmpty {
                    SDLLogger.shared.log("[SDLQUICTransport] read bytes: \(data.count)")
                    cont.resume(returning: (isComplete, data))
                } else {
                    cont.resume(returning: (isComplete, Data()))
                }
            }
        }
    }
}
|
||
|
||
/// Decoder for frames produced by `SDLQUICReader`.
struct SDLQUICCodec {

    // MARK: - Decoding

    /// Decodes one inbound frame into a typed message.
    ///
    /// The frame must start with a one-byte `SDLPacketType` raw value followed by the payload.
    /// For `.event` frames, the payload starts with a one-byte `SDLEventType` raw value
    /// followed by the serialized event.
    ///
    /// - Parameter frame: A single frame. It is copied, so the caller's reader index is untouched.
    /// - Returns: The decoded message, or `nil` when the type byte is missing or unknown,
    ///   the packet type is not handled here, or the payload fails to deserialize.
    public static func decode(frame: ByteBuffer) -> SDLQUICInboundMessage? {
        var buffer = frame

        guard let type = buffer.readInteger(as: UInt8.self),
              let packetType = SDLPacketType(rawValue: type) else {
            return nil
        }

        switch packetType {
        case .welcome:
            guard let welcome = parsePayload(
                of: &buffer,
                failureLog: "[SDLQUICCodec] decode welcome error",
                { try SDLWelcome(serializedBytes: $0) }
            ) else {
                return nil
            }
            return .welcome(welcome)

        case .registerSuperAck:
            guard let registerSuperAck = parsePayload(
                of: &buffer,
                failureLog: "[SDLQUICCodec] decode registerSuperAck error",
                { try SDLRegisterSuperAck(serializedBytes: $0) }
            ) else {
                return nil
            }
            return .registerSuperAck(registerSuperAck)

        case .registerSuperNak:
            guard let registerSuperNak = parsePayload(
                of: &buffer,
                failureLog: "[SDLQUICCodec] decode registerSuperNak error",
                { try SDLRegisterSuperNak(serializedBytes: $0) }
            ) else {
                return nil
            }
            return .registerSuperNak(registerSuperNak)

        case .peerInfo:
            guard let peerInfo = parsePayload(
                of: &buffer,
                failureLog: "[SDLQUICCodec] decode peerInfo error",
                { try SDLPeerInfo(serializedBytes: $0) }
            ) else {
                return nil
            }
            return .peerInfo(peerInfo)

        case .policyResponse:
            guard let policyResponse = parsePayload(
                of: &buffer,
                failureLog: "[SDLQUICCodec] decode policyResponse error",
                { try SDLPolicyResponse(serializedBytes: $0) }
            ) else {
                return nil
            }
            // NOTE: `policyReponse` is the case name as declared on SDLQUICInboundMessage.
            return .policyReponse(policyResponse)

        case .event:
            guard let eventVal = buffer.readInteger(as: UInt8.self),
                  let event = SDLEventType(rawValue: eventVal) else {
                SDLLogger.shared.log("[SDLQUICCodec] decode error 15")
                return nil
            }

            switch event {
            case .natChanged:
                guard let natChangedEvent = parsePayload(
                    of: &buffer,
                    failureLog: "[SDLQUICCodec] decode error 16",
                    { try SDLNatChangedEvent(serializedBytes: $0) }
                ) else {
                    return nil
                }
                return .event(.natChanged(natChangedEvent))

            case .sendRegister:
                guard let sendRegisterEvent = parsePayload(
                    of: &buffer,
                    failureLog: "[SDLQUICCodec] decode error 17",
                    { try SDLSendRegisterEvent(serializedBytes: $0) }
                ) else {
                    return nil
                }
                return .event(.sendRegister(sendRegisterEvent))

            case .networkShutdown:
                guard let networkShutdownEvent = parsePayload(
                    of: &buffer,
                    failureLog: "[SDLQUICCodec] decode error 18",
                    { try SDLNetworkShutdownEvent(serializedBytes: $0) }
                ) else {
                    return nil
                }
                return .event(.networkShutdown(networkShutdownEvent))
            }

        default:
            SDLLogger.shared.log("[SDLQUICCodec] decode miss type: \(type)")
            return nil
        }
    }

    /// Consumes every remaining byte of `buffer` and deserializes it with `parse`.
    ///
    /// - Returns: The parsed value, or `nil` when `parse` throws. In that case `failureLog`
    ///   is logged together with the error instead of being swallowed silently.
    private static func parsePayload<T>(
        of buffer: inout ByteBuffer,
        failureLog: String,
        _ parse: ([UInt8]) throws -> T
    ) -> T? {
        // Reading exactly `readableBytes` always succeeds; `?? []` only satisfies the optional.
        let bytes = buffer.readBytes(length: buffer.readableBytes) ?? []
        do {
            return try parse(bytes)
        } catch {
            SDLLogger.shared.log("\(failureLog): \(error)")
            return nil
        }
    }
}
|
||
|