272 lines
8.7 KiB
Swift
272 lines
8.7 KiB
Swift
//
|
||
// SDLDNSClient 2.swift
|
||
// punchnet
|
||
//
|
||
// Created by 安礼成 on 2026/4/9.
|
||
//
|
||
import Foundation
|
||
import Network
|
||
|
||
/// Errors surfaced by the IPv6 assist client.
enum SDLIPV6AssistError: Error {
    /// The client is not running / its connection is not ready, or it was stopped
    /// while a request was still pending.
    /// The request was sent but no matching reply arrived before the timeout elapsed.
    case lostConnection, requestTimeout
}
|
||
|
||
/// Client that sends IPv6 assist probes to one assist server over a UDP `NWConnection`
/// and matches incoming replies to outstanding requests by packet id.
///
/// Lifecycle: `idle` -> `running` (after `start()`) -> `stopped` (after `stop()`, a
/// connection failure, or a connection cancellation). Nothing in this actor moves a
/// stopped client back to `running`.
actor SDLIPV6AssistClient {
    /// A probe that has been sent and is still waiting for its reply.
    private struct PendingRequest {
        /// Resumed exactly once: with the reply, a send error, a timeout, or the stop error.
        let continuation: CheckedContinuation<SDLV6AssistProbeReply, Error>
        /// Fails the request with `requestTimeout` unless it is cancelled first.
        let timeoutTask: Task<Void, Never>
    }

    private enum State {
        case idle
        case running
        case stopped
    }

    private var state: State = .idle
    private var connection: NWConnection?
    private var receiveTask: Task<Void, Never>?
    private let assistServerAddress: NWEndpoint

    /// Next packet id to hand out; wraps around on overflow.
    private var packetId: UInt32 = 1
    /// Requests awaiting a reply, keyed by the packet id that was sent.
    private var pendingRequests: [UInt32: PendingRequest] = [:]

    // Used to signal the close event to `waitClose()`.
    private let closeStream: AsyncStream<Void>
    private let closeContinuation: AsyncStream<Void>.Continuation
    private var didFinishCloseStream = false

    /// Creates a client for the given assist server.
    ///
    /// Returns `nil` when the port does not fit in a `UInt16` or when
    /// `SDLUtil.ipv6DataToString` cannot turn the address bytes into a host string.
    init?(assistServerInfo: SDLV6Info) {
        guard assistServerInfo.port <= UInt32(UInt16.max), let host = SDLUtil.ipv6DataToString(assistServerInfo.v6) else {
            return nil
        }

        let (closeStream, closeContinuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1))
        self.closeStream = closeStream
        self.closeContinuation = closeContinuation

        self.assistServerAddress = .hostPort(host: NWEndpoint.Host(host), port: NWEndpoint.Port(integerLiteral: UInt16(assistServerInfo.port)))
    }

    /// Creates and starts the UDP connection. Does nothing unless the client is still `idle`.
    func start() {
        guard case .idle = self.state else {
            return
        }

        self.state = .running

        // 1. Configure parameters: this is the key to avoiding a routing loop.
        let parameters = NWParameters.udp

        // Forbid this connection from using the TUN interface
        // (inside a Network Extension, TUN is usually classified as `.other`).
        parameters.prohibitedInterfaceTypes = [.other]
        // Robustness: enable multipath handover
        // (stands in for what `pathSelectionOptions` was meant to achieve).
        parameters.multipathServiceType = .handover

        // IPv6 only, so the assist channel never falls back to IPv4 or dual-stack negotiation.
        if let ipOptions = parameters.defaultProtocolStack.internetProtocol as? NWProtocolIP.Options {
            ipOptions.version = .v6
        }

        // 2. Create the connection.
        let connection = NWConnection(to: self.assistServerAddress, using: parameters)
        self.connection = connection

        // State changes arrive on the connection's queue; hop back onto the actor to handle them.
        connection.stateUpdateHandler = { [weak self] state in
            Task {
                await self?.handleConnectionStateUpdate(state, for: connection)
            }
        }

        // Start the connection on a global queue.
        connection.start(queue: .global())
    }

    /// Suspends until the client's close stream is finished, which happens when the client stops.
    public func waitClose() async {
        for await _ in self.closeStream { }
    }

    /// Wraps the connection's message-by-message receive loop in an `AsyncStream`.
    ///
    /// Each callback re-arms the next `receiveMessage` while there is no error and the
    /// connection is still `.ready`; otherwise the stream is finished.
    private static func makeReceiveStream(for connection: NWConnection) -> AsyncStream<Data> {
        return AsyncStream(bufferingPolicy: .bufferingNewest(256)) { continuation in
            func receiveNext() {
                connection.receiveMessage { content, _, _, error in
                    if let data = content, !data.isEmpty {
                        // Forward the received datagram to the stream consumer.
                        continuation.yield(data)
                    }

                    if error == nil && connection.state == .ready {
                        receiveNext() // keep listening for the next packet
                    } else {
                        continuation.finish()
                    }
                }
            }

            receiveNext()
        }
    }

    /// Sends one assist probe and waits for the reply carrying the same packet id.
    ///
    /// - Parameter requestTimeout: How long to wait for the reply before failing the request.
    /// - Returns: The decoded reply whose `pktID` matches the probe that was sent.
    /// - Throws: `SDLIPV6AssistError.lostConnection` when the client is not running, the
    ///   connection is not ready, or the client stops while the request is pending;
    ///   `SDLIPV6AssistError.requestTimeout` when no reply arrives in time; the error
    ///   reported by the connection when the send fails; or the error thrown while
    ///   serializing the probe.
    func probe(requestTimeout: Duration = .seconds(5)) async throws -> SDLV6AssistProbeReply {
        guard case .running = self.state, let connection = self.connection, connection.state == .ready else {
            throw SDLIPV6AssistError.lostConnection
        }

        let pktId = self.nextPacketId()
        var assistProbe = SDLV6AssistProbe()
        assistProbe.pktID = pktId
        let data = try assistProbe.serializedData()

        return try await withCheckedThrowingContinuation { cont in
            let timeoutTask = Task { [weak self] in
                do {
                    try await Task.sleep(for: requestTimeout)
                } catch {
                    // Cancelled: the request has already been completed or failed elsewhere.
                    return
                }
                await self?.handleRequestTimeout(packetId: pktId)
            }

            // Register before sending so a reply can never arrive for an unknown packet id.
            self.pendingRequests[pktId] = .init(continuation: cont, timeoutTask: timeoutTask)
            connection.send(content: data, completion: .contentProcessed { error in
                if let error {
                    Task {
                        await self.handleProcessError(packetId: pktId, error: error)
                    }
                }
            })
        }
    }

    /// Fails the pending request for `packetId` (if still pending) with the send error.
    private func handleProcessError(packetId: UInt32, error: NWError) {
        if let request = self.takePendingRequest(packetId: packetId) {
            request.continuation.resume(throwing: error)
        }
    }

    /// Fails the pending request for `packetId` (if still pending) with `requestTimeout`.
    private func handleRequestTimeout(packetId: UInt32) {
        if let request = self.takePendingRequest(packetId: packetId) {
            request.continuation.resume(throwing: SDLIPV6AssistError.requestTimeout)
        }
    }

    /// Stops the client; every pending request fails with `SDLIPV6AssistError.lostConnection`.
    func stop() {
        self.stop(pendingError: SDLIPV6AssistError.lostConnection)
    }

    /// Tears everything down once: cancels the receive task and the connection, fails all
    /// pending requests with `pendingError`, and finishes the close stream.
    private func stop(pendingError: any Error) {
        guard self.state != .stopped else {
            return
        }

        self.state = .stopped
        self.receiveTask?.cancel()
        self.receiveTask = nil
        self.connection?.cancel()
        self.connection = nil
        self.failAllPendingRequests(error: pendingError)
        self.finishCloseStreamIfNeeded()
    }

    /// Reacts to connection state changes while running: starts receiving on `.ready`,
    /// stops the client on `.failed` / `.cancelled`, and ignores every other state.
    private func handleConnectionStateUpdate(_ state: NWConnection.State, for connection: NWConnection) {
        guard case .running = self.state else {
            return
        }

        switch state {
        case .ready:
            SDLLogger.log("[SDLIPV6AssistClient] Connection ready", for: .debug)
            self.startReceiveTask(for: connection)
        case .failed(let error):
            SDLLogger.log("[SDLIPV6AssistClient] Connection failed: \(error)", for: .debug)
            self.stop(pendingError: error)
        case .cancelled:
            self.stop()
        default:
            break
        }
    }

    /// Starts the task that drains the receive stream, unless one is already active.
    private func startReceiveTask(for connection: NWConnection) {
        guard self.receiveTask == nil else {
            return
        }

        let stream = Self.makeReceiveStream(for: connection)
        self.receiveTask = Task { [weak self] in
            for await data in stream {
                guard let self else {
                    break
                }
                await self.handleReceivedPacket(data)
            }

            await self?.didFinishReceiving(for: connection)
        }
    }

    /// Decodes a reply and resumes the matching pending request.
    /// Replies with an unknown packet id are dropped; undecodable data is only logged.
    private func handleReceivedPacket(_ data: Data) {
        do {
            let packet = try SDLV6AssistProbeReply(serializedBytes: data)
            let pktId = packet.pktID
            if let request = self.takePendingRequest(packetId: pktId) {
                request.continuation.resume(returning: packet)
            }
        } catch {
            SDLLogger.log("[SDLIPV6AssistClient] Receive error: \(error)", for: .debug)
        }
    }

    /// Called when the receive loop for `connection` has ended.
    private func didFinishReceiving(for connection: NWConnection) {
        guard case .running = self.state else {
            return
        }

        if self.connection === connection, connection.state != .ready {
            self.stop()
        } else {
            // NOTE(review): the loop ended while the connection is still `.ready` (or belongs
            // to another connection). Clearing the task allows a later `.ready` update to
            // start a new loop, but nothing restarts it here — confirm this is intended.
            self.receiveTask = nil
        }
    }

    /// Finishes the close stream at most once.
    private func finishCloseStreamIfNeeded() {
        guard !self.didFinishCloseStream else {
            return
        }

        self.didFinishCloseStream = true
        self.closeContinuation.finish()
    }

    /// Returns the current packet id and advances the counter (wrapping on overflow).
    private func nextPacketId() -> UInt32 {
        let packetId = self.packetId
        self.packetId &+= 1

        return packetId
    }

    /// Removes and returns the pending request for `packetId`, cancelling its timeout task.
    /// The caller is responsible for resuming the returned continuation.
    private func takePendingRequest(packetId: UInt32) -> PendingRequest? {
        guard let request = self.pendingRequests.removeValue(forKey: packetId) else {
            return nil
        }

        request.timeoutTask.cancel()
        return request
    }

    /// Fails every pending request with `error` and clears the table.
    private func failAllPendingRequests(error: any Error) {
        // Snapshot and clear first so each continuation is resumed from a detached copy.
        let pendingRequests = self.pendingRequests
        self.pendingRequests.removeAll()

        pendingRequests.values.forEach { request in
            request.timeoutTask.cancel()
            request.continuation.resume(throwing: error)
        }
    }

    deinit {
        self.connection?.cancel()
    }

}
|