264 lines
9.2 KiB
Swift
264 lines
9.2 KiB
Swift
//
|
||
// SDLQuicClient.swift
|
||
// Tun
|
||
//
|
||
// Created by 安礼成 on 2026/2/13.
|
||
//
|
||
|
||
import Foundation
|
||
import NIOCore
|
||
import Network
|
||
|
||
/// Error type for the QUIC client, defined so that upper layers can handle failures.
enum SDLQUICError: Error {
    /// The connection failed; the associated value is the underlying error.
    case connectionFailed(Error)

    /// Conditions that carry no payload. `packetTooLarge` is raised when an inbound
    /// frame's length prefix exceeds the client's configured packet-size limit.
    case connectionCancelled, timeout, packetTooLarge

    /// A payload could not be decoded; the string describes the problem.
    case decodeError(String)
}
|
||
|
||
/// QUIC client that exchanges length-prefixed packets over a single `NWConnection`.
///
/// Wire format used by `send` and by the inbound parser: a 2-byte big-endian length `L`,
/// followed by `L` bytes consisting of a 1-byte `SDLPacketType` raw value and the body.
/// Decoded inbound packets are published on `messageStream`.
final class SDLQUICClient {
    private let allocator = ByteBufferAllocator()

    /// Largest accepted inbound frame length (the length-prefix value); 64 KiB by default.
    private let maxPacketSize: Int

    /// Initial capacity of the receive buffer; 2 MiB by default. Consumed bytes are
    /// compacted once the reader index passes 60% of this value.
    private let maxBufferSize: Int

    /// Inbound messages decoded from the connection. Finishes when the read loop ends.
    public var messageStream: AsyncStream<SDLQUICInboundMessage>

    private let messageCont: AsyncStream<SDLQUICInboundMessage>.Continuation

    private var readTask: Task<Void, Never>?

    private let connection: NWConnection

    // Dedicated serial queue for connection callbacks and TLS certificate verification.
    private let queue = DispatchQueue(label: "com.sdl.QUICClient.queue")

    // Yields and finishes once the connection fails or is cancelled.
    private let (closeStream, closeCont) = AsyncStream.makeStream(of: Void.self)

    // Yields once when the connection becomes ready; also finished (without a value)
    // on failure/cancellation so waitReady() callers are never left hanging.
    private let (readyStream, readyCont) = AsyncStream.makeStream(of: Void.self)

    /// Creates a client for `host:port`. Nothing is opened until `start()` is called.
    ///
    /// - Parameters:
    ///   - host: Server host name or IP address.
    ///   - port: Server port.
    ///   - maxPacketSize: Largest accepted inbound frame length, in bytes.
    ///   - maxBufferSize: Initial receive-buffer capacity, in bytes.
    init(host: String, port: UInt16, maxPacketSize: Int = 64 * 1024, maxBufferSize: Int = 2 * 1024 * 1024) {
        let options = NWProtocolQUIC.Options(alpn: ["punchnet/1.0"])

        self.maxBufferSize = maxBufferSize
        self.maxPacketSize = maxPacketSize
        (self.messageStream, self.messageCont) = AsyncStream.makeStream(of: SDLQUICInboundMessage.self)

        // SECURITY TODO: this verify block accepts every server certificate, so TLS gives
        // no server authentication (man-in-the-middle is possible). Replace with real
        // trust evaluation (e.g. certificate pinning) before relying on it.
        sec_protocol_options_set_verify_block(
            options.securityProtocolOptions,
            { _, _, complete in
                complete(true) // true = accept the certificate
            },
            self.queue
        )

        let params = NWParameters(quic: options)
        self.connection = NWConnection(host: .init(host), port: .init(rawValue: port)!, using: params)
    }

    /// Starts the connection and the background read loop. Call once.
    ///
    /// The read loop publishes decoded packets on `messageStream` and finishes that
    /// stream when the peer completes the stream, when the task is cancelled, or on a
    /// read/framing error. On error it also cancels the connection, because a
    /// desynchronized byte stream cannot be recovered.
    func start() {
        connection.stateUpdateHandler = { [weak self] state in
            // Weak capture: the connection (owned by self) stores this handler.
            guard let self else { return }
            SDLLogger.shared.log("[SDLQUICTransport] new state: \(state)")

            switch state {
            case .ready:
                self.readyCont.yield()
                self.readyCont.finish()
            case .failed, .cancelled:
                // Release waitReady() callers too; it inspects connection.state to report why.
                self.readyCont.finish()
                self.closeCont.yield()
                self.closeCont.finish()
            default:
                break
            }
        }
        connection.start(queue: self.queue)

        // The task captures self strongly, keeping the client alive until the loop ends.
        self.readTask = Task {
            var buffer = self.allocator.buffer(capacity: self.maxBufferSize)
            let threshold = self.maxBufferSize / 10 * 6

            do {
                while !Task.isCancelled {
                    let (isComplete, data) = try await self.readOnce()

                    if let data, !data.isEmpty {
                        buffer.writeBytes(data)
                        let frames = try self.parseFrames(buffer: &buffer)

                        // Reclaim already-consumed bytes once they exceed the threshold.
                        if buffer.readerIndex > threshold {
                            buffer.discardReadBytes()
                        }

                        for frame in frames {
                            if let message = self.decode(frame: frame) {
                                self.messageCont.yield(message)
                            }
                        }
                    }

                    if isComplete {
                        break
                    }
                }
            } catch {
                SDLLogger.shared.log("[SDLQUICClient] read loop stopped: \(error)")
                // Close the connection so waitClose() callers are released instead of
                // leaving an open but unread (and possibly desynchronized) connection.
                self.connection.cancel()
            }

            self.messageCont.finish()
        }
    }

    /// Sends one framed packet: a 2-byte big-endian length covering the type byte plus
    /// `data`, then the type byte, then `data`.
    ///
    /// Packets whose framed length does not fit the 16-bit prefix are dropped and logged
    /// rather than trapping. Send errors are logged, not propagated.
    func send(type: SDLPacketType, data: Data) {
        let frameLength = data.count + 1
        guard let length = UInt16(exactly: frameLength) else {
            SDLLogger.shared.log("[SDLQUICClient] drop packet: payload of \(data.count) bytes exceeds frame limit")
            return
        }

        var packet = Data(capacity: 2 + frameLength)
        packet.append(UInt8(truncatingIfNeeded: length >> 8))
        packet.append(UInt8(truncatingIfNeeded: length))
        packet.append(type.rawValue)
        packet.append(data)

        connection.send(content: packet, completion: .contentProcessed { error in
            if let error {
                SDLLogger.shared.log("[SDLQUICClient] send data get error: \(error)")
            }
        })
    }

    /// Suspends until the connection becomes ready.
    ///
    /// - Throws: `SDLQUICError.connectionFailed` if the connection failed, or
    ///   `SDLQUICError.connectionCancelled` if it is otherwise not ready once the ready
    ///   signal has been released (e.g. it was cancelled). A connection that stays in
    ///   the waiting state keeps this call suspended.
    func waitReady() async throws {
        for await _ in readyStream {}

        switch connection.state {
        case .ready:
            return
        case .failed(let error):
            throw SDLQUICError.connectionFailed(error)
        default:
            throw SDLQUICError.connectionCancelled
        }
    }

    /// Suspends until the connection has failed or been cancelled.
    func waitClose() async {
        for await _ in closeStream {}
    }

    /// Cancels the connection.
    func stop() {
        self.connection.cancel()
    }

    /// Extracts every complete frame from `buffer`, advancing its reader index past them.
    ///
    /// A frame is a 2-byte big-endian length `L` followed by `L` bytes; each returned
    /// slice holds those `L` bytes. A trailing partial frame is left unread.
    /// - Throws: `SDLQUICError.packetTooLarge` if a length prefix exceeds `maxPacketSize`.
    private func parseFrames(buffer: inout ByteBuffer) throws -> [ByteBuffer] {
        var frames: [ByteBuffer] = []

        // getInteger returns nil when fewer than 2 readable bytes remain.
        while let length = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) {
            let frameLength = Int(length)

            if frameLength > self.maxPacketSize {
                throw SDLQUICError.packetTooLarge
            }

            // Widen to Int before adding: `length + 2` in UInt16 traps for 0xFFFE/0xFFFF.
            guard buffer.readableBytes >= frameLength + 2 else {
                break
            }

            buffer.moveReaderIndex(forwardBy: 2)
            if let frame = buffer.readSlice(length: frameLength) {
                frames.append(frame)
            }
        }

        return frames
    }

    /// Performs a single receive on the connection (at most `maxPacketSize` bytes).
    ///
    /// - Returns: The receive's completion flag and the received bytes, if any.
    /// - Throws: The connection's receive error.
    private func readOnce() async throws -> (Bool, Data?) {
        return try await withCheckedThrowingContinuation { cont in
            self.connection.receive(minimumIncompleteLength: 1, maximumLength: maxPacketSize) { data, _, isComplete, error in
                if let error {
                    cont.resume(throwing: error)
                    return
                }
                cont.resume(returning: (isComplete, data))
            }
        }
    }

    // MARK: - Codec

    /// Decodes one frame (type byte followed by the body) into an inbound message.
    ///
    /// Returns `nil` for an empty frame, an unknown or unhandled packet type, or a body
    /// that fails to deserialize.
    private func decode(frame: ByteBuffer) -> SDLQUICInboundMessage? {
        var buffer = frame
        guard let type = buffer.readInteger(as: UInt8.self),
              let packetType = SDLPacketType(rawValue: type) else {
            return nil
        }

        switch packetType {
        case .welcome:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let welcome = try? SDLWelcome(serializedBytes: bytes) else {
                return nil
            }
            return .welcome(welcome)

        case .registerSuperAck:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else {
                return nil
            }
            return .registerSuperAck(registerSuperAck)

        case .registerSuperNak:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else {
                return nil
            }
            return .registerSuperNak(registerSuperNak)

        case .peerInfo:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else {
                return nil
            }
            return .peerInfo(peerInfo)

        case .policyResponse:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let policyResponse = try? SDLPolicyResponse(serializedBytes: bytes) else {
                return nil
            }
            return .policyReponse(policyResponse)

        case .arpResponse:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let arpResponse = try? SDLArpResponse(serializedBytes: bytes) else {
                return nil
            }
            return .arpResponse(arpResponse)

        case .event:
            // Event frames carry an extra event-type byte before the serialized body.
            guard let eventVal = buffer.readInteger(as: UInt8.self),
                  let event = SDLEventType(rawValue: eventVal),
                  let bytes = buffer.readBytes(length: buffer.readableBytes) else {
                SDLLogger.shared.log("[SDLQUICClient] decode error 15")
                return nil
            }

            switch event {
            case .natChanged:
                guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else {
                    SDLLogger.shared.log("[SDLQUICClient] decode error 16")
                    return nil
                }
                return .event(.natChanged(natChangedEvent))
            case .sendRegister:
                guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else {
                    SDLLogger.shared.log("[SDLQUICClient] decode error 17")
                    return nil
                }
                return .event(.sendRegister(sendRegisterEvent))
            case .networkShutdown:
                guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else {
                    SDLLogger.shared.log("[SDLQUICClient] decode error 18")
                    return nil
                }
                return .event(.networkShutdown(networkShutdownEvent))
            }

        case .pong:
            return .pong

        default:
            SDLLogger.shared.log("[SDLQUICClient] decode miss type: \(type)")
            return nil
        }
    }

    deinit {
        self.readTask?.cancel()
        // Tear down the connection if the client goes away while it is still open.
        self.connection.cancel()
        self.messageCont.finish()
    }
}
|