punchnet-macos/Tun/Punchnet/SDLIPV6AssistClient.swift
2026-04-16 20:05:16 +08:00

272 lines
8.7 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//
// SDLIPV6AssistClient.swift
// punchnet
//
// Created by on 2026/4/9.
//
import Foundation
import Network
/// Errors surfaced to callers of `SDLIPV6AssistClient`.
///
/// - `lostConnection`: the client is not running or its connection is not ready;
///   also used to fail in-flight probes when the client is stopped.
/// - `requestTimeout`: no reply for a probe arrived before its timeout elapsed.
enum SDLIPV6AssistError: Error {
    case lostConnection, requestTimeout
}
/// Actor-isolated UDP client that sends `SDLV6AssistProbe` packets to an IPv6 assist server
/// and matches every `SDLV6AssistProbeReply` to the waiting caller by packet id.
///
/// Lifecycle: `idle` → `running` (after `start()`) → `stopped` (after `stop()`, a connection
/// failure, or a connection cancellation). A stopped client is never restarted;
/// `waitClose()` returns once the client has stopped.
actor SDLIPV6AssistClient {
    /// A probe that has been handed to the connection and is waiting for a reply,
    /// a send error, or its timeout — whichever comes first.
    private struct PendingRequest {
        let continuation: CheckedContinuation<SDLV6AssistProbeReply, Error>
        let timeoutTask: Task<Void, Never>
    }

    private enum State {
        case idle
        case running
        case stopped
    }

    private var state: State = .idle
    private var connection: NWConnection?
    private var receiveTask: Task<Void, Never>?
    private let assistServerAddress: NWEndpoint

    /// Id assigned to the next probe; wraps around on overflow.
    private var packetId: UInt32 = 1
    /// In-flight probes keyed by packet id.
    private var pendingRequests: [UInt32: PendingRequest] = [:]

    // The close stream never yields a value; it is only finished, which releases `waitClose()`.
    private let closeStream: AsyncStream<Void>
    private let closeContinuation: AsyncStream<Void>.Continuation
    private var didFinishCloseStream = false

    /// Creates a client for the given assist server.
    ///
    /// Returns `nil` when the port does not fit into 16 bits or the IPv6 address bytes
    /// cannot be converted to a host string.
    init?(assistServerInfo: SDLV6Info) {
        guard assistServerInfo.port <= UInt32(UInt16.max),
              let host = SDLUtil.ipv6DataToString(assistServerInfo.v6) else {
            return nil
        }
        let (closeStream, closeContinuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1))
        self.closeStream = closeStream
        self.closeContinuation = closeContinuation
        self.assistServerAddress = .hostPort(host: NWEndpoint.Host(host), port: NWEndpoint.Port(integerLiteral: UInt16(assistServerInfo.port)))
    }

    /// Creates and starts the UDP connection. Has no effect unless the client is still `idle`.
    func start() {
        guard case .idle = self.state else {
            return
        }
        self.state = .running

        let parameters = NWParameters.udp
        // NOTE(review): `.other` is prohibited presumably so that probes never travel through
        // the app's own TUN interface — confirm against the tunnel provider setup.
        parameters.prohibitedInterfaceTypes = [.other]
        parameters.multipathServiceType = .handover
        // The assist exchange must stay on IPv6; do not allow the stack to pick IPv4.
        if let ipOptions = parameters.defaultProtocolStack.internetProtocol as? NWProtocolIP.Options {
            ipOptions.version = .v6
        }

        let connection = NWConnection(to: self.assistServerAddress, using: parameters)
        self.connection = connection
        // Capture the connection weakly: a handler that strongly references the connection
        // it is stored on forms a self-reference cycle.
        connection.stateUpdateHandler = { [weak self, weak connection] state in
            guard let connection else {
                return
            }
            Task {
                await self?.handleConnectionStateUpdate(state, for: connection)
            }
        }
        connection.start(queue: .global())
    }

    /// Suspends until the client has stopped (returns immediately if it already has).
    public func waitClose() async {
        for await _ in self.closeStream { }
    }

    /// Wraps the callback-based `receiveMessage` loop of `connection` in an `AsyncStream`.
    ///
    /// The stream keeps at most the 256 newest datagrams and finishes as soon as a receive
    /// reports an error or the connection is no longer `.ready`.
    private static func makeReceiveStream(for connection: NWConnection) -> AsyncStream<Data> {
        return AsyncStream(bufferingPolicy: .bufferingNewest(256)) { continuation in
            func receiveNext() {
                connection.receiveMessage { content, _, _, error in
                    if let data = content, !data.isEmpty {
                        // Forward the raw datagram to the stream consumer.
                        continuation.yield(data)
                    }
                    if error == nil && connection.state == .ready {
                        receiveNext() // Re-arm for the next datagram.
                    } else {
                        continuation.finish()
                    }
                }
            }
            receiveNext()
        }
    }

    /// Sends one probe and waits for the matching reply.
    ///
    /// - Parameter requestTimeout: How long to wait for the reply before failing.
    /// - Returns: The reply whose `pktID` equals the id of the probe that was sent.
    /// - Throws: `SDLIPV6AssistError.lostConnection` when the client is not running or the
    ///   connection is not ready (also when the client stops while waiting),
    ///   `SDLIPV6AssistError.requestTimeout` when no reply arrives in time, the `NWError`
    ///   reported by the send, or any error thrown while serializing the probe.
    func probe(requestTimeout: Duration = .seconds(5)) async throws -> SDLV6AssistProbeReply {
        guard case .running = self.state, let connection = self.connection, connection.state == .ready else {
            throw SDLIPV6AssistError.lostConnection
        }

        let pktId = self.nextPacketId()
        var assistProbe = SDLV6AssistProbe()
        assistProbe.pktID = pktId
        let data = try assistProbe.serializedData()

        return try await withCheckedThrowingContinuation { cont in
            let timeoutTask = Task { [weak self] in
                do {
                    try await Task.sleep(for: requestTimeout)
                } catch {
                    // Cancelled: the request was already completed elsewhere.
                    return
                }
                await self?.handleRequestTimeout(packetId: pktId)
            }
            // Register before sending so that a reply or send error can always find the request.
            self.pendingRequests[pktId] = .init(continuation: cont, timeoutTask: timeoutTask)
            connection.send(content: data, completion: .contentProcessed { error in
                if let error {
                    Task {
                        await self.handleProcessError(packetId: pktId, error: error)
                    }
                }
            })
        }
    }

    /// Fails the pending request `packetId` (if still pending) with the send error.
    private func handleProcessError(packetId: UInt32, error: NWError) {
        if let request = self.takePendingRequest(packetId: packetId) {
            request.continuation.resume(throwing: error)
        }
    }

    /// Fails the pending request `packetId` (if still pending) with `requestTimeout`.
    private func handleRequestTimeout(packetId: UInt32) {
        if let request = self.takePendingRequest(packetId: packetId) {
            request.continuation.resume(throwing: SDLIPV6AssistError.requestTimeout)
        }
    }

    /// Stops the client and fails every in-flight probe with `lostConnection`.
    func stop() {
        self.stop(pendingError: SDLIPV6AssistError.lostConnection)
    }

    /// Moves to `stopped`, tears down the receive task and the connection, fails all
    /// in-flight probes with `pendingError`, and finishes the close stream.
    /// Does nothing when the client is already stopped.
    private func stop(pendingError: any Error) {
        guard self.state != .stopped else {
            return
        }
        self.state = .stopped

        self.receiveTask?.cancel()
        self.receiveTask = nil
        self.connection?.cancel()
        self.connection = nil

        self.failAllPendingRequests(error: pendingError)
        self.finishCloseStreamIfNeeded()
    }

    /// Reacts to a connection state change; ignored unless the client is `running`.
    private func handleConnectionStateUpdate(_ state: NWConnection.State, for connection: NWConnection) {
        guard case .running = self.state else {
            return
        }
        switch state {
        case .ready:
            SDLLogger.log("[SDLIPV6AssistClient] Connection ready", for: .debug)
            self.startReceiveTask(for: connection)
        case .failed(let error):
            SDLLogger.log("[SDLIPV6AssistClient] Connection failed: \(error)", for: .debug)
            self.stop(pendingError: error)
        case .cancelled:
            self.stop()
        case .setup, .preparing, .waiting:
            // Transitional states: nothing to do until the connection is ready or fails.
            break
        @unknown default:
            break
        }
    }

    /// Starts the task that drains received datagrams; at most one such task exists at a time.
    private func startReceiveTask(for connection: NWConnection) {
        guard self.receiveTask == nil else {
            return
        }
        let stream = Self.makeReceiveStream(for: connection)
        self.receiveTask = Task { [weak self] in
            for await data in stream {
                guard let self else {
                    break
                }
                await self.handleReceivedPacket(data)
            }
            await self?.didFinishReceiving(for: connection)
        }
    }

    /// Decodes a datagram as a probe reply and resumes the caller waiting for its packet id.
    /// Replies without a matching pending request are dropped; decode failures are logged.
    private func handleReceivedPacket(_ data: Data) {
        do {
            let packet = try SDLV6AssistProbeReply(serializedBytes: data)
            let pktId = packet.pktID
            if let request = self.takePendingRequest(packetId: pktId) {
                request.continuation.resume(returning: packet)
            }
        } catch {
            SDLLogger.log("[SDLIPV6AssistClient] Receive error: \(error)", for: .debug)
        }
    }

    /// Called when the receive stream ends. Stops the client if the current connection is no
    /// longer ready; otherwise only clears `receiveTask` so a later `.ready` update can
    /// start a new one.
    private func didFinishReceiving(for connection: NWConnection) {
        guard case .running = self.state else {
            return
        }
        if self.connection === connection, connection.state != .ready {
            self.stop()
        } else {
            // NOTE(review): if the stream ended while the connection is still `.ready`,
            // nothing restarts receiving until another `.ready` update arrives — confirm
            // whether an immediate restart (or a stop) is wanted here.
            self.receiveTask = nil
        }
    }

    /// Finishes the close stream exactly once.
    private func finishCloseStreamIfNeeded() {
        guard !self.didFinishCloseStream else {
            return
        }
        self.didFinishCloseStream = true
        self.closeContinuation.finish()
    }

    /// Returns the current packet id and advances the counter with wrapping arithmetic.
    private func nextPacketId() -> UInt32 {
        let packetId = self.packetId
        self.packetId &+= 1
        return packetId
    }

    /// Removes and returns the pending request for `packetId`, cancelling its timeout task.
    /// Returns `nil` when the request was already completed.
    private func takePendingRequest(packetId: UInt32) -> PendingRequest? {
        guard let request = self.pendingRequests.removeValue(forKey: packetId) else {
            return nil
        }
        request.timeoutTask.cancel()
        return request
    }

    /// Removes every pending request and resumes each of them by throwing `error`.
    private func failAllPendingRequests(error: any Error) {
        // Detach the requests first so each continuation is resumed exactly once.
        let pendingRequests = self.pendingRequests
        self.pendingRequests.removeAll()
        for request in pendingRequests.values {
            request.timeoutTask.cancel()
            request.continuation.resume(throwing: error)
        }
    }

    deinit {
        self.connection?.cancel()
    }
}