punchnet-macos/Tun/Punchnet/SDLIPV6AssistClient.swift
2026-04-16 16:16:29 +08:00

323 lines
10 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//
// SDLIPV6AssistClient.swift
// punchnet
//
// Created by on 2026/4/9.
//
import Foundation
import Network
/// UDP client that relays IP packets through an assist server over an IPv6-only path.
///
/// Each outbound packet is prefixed with a freshly generated 4-byte `packetId` before it is sent.
/// An inbound datagram is accepted only if its leading 4 bytes match an outstanding `packetId`;
/// accepted packets are published on `packetFlow`. `waitClose()` returns once the client stops.
actor SDLIPV6AssistClient {
    /// An inbound IP packet returned by the assist server, paired with the id of the request it answers.
    struct Packet: Sendable {
        enum IPVersion: UInt8, Sendable {
            case ipv4 = 4
            case ipv6 = 6

            /// Address family for this IP version (`AF_INET` / `AF_INET6`, i.e. 2 / 30 on Darwin).
            var protocolFamily: Int32 {
                switch self {
                case .ipv4:
                    return Int32(AF_INET)
                case .ipv6:
                    return Int32(AF_INET6)
                }
            }
        }

        let packetId: UInt32
        let ipPacketData: Data
        let ipVersion: IPVersion

        /// Address family of `ipPacketData`, derived from `ipVersion`.
        var protocolFamily: Int32 {
            return self.ipVersion.protocolFamily
        }
    }

    private enum State {
        case idle
        case running
        case stopped
    }

    /// Bookkeeping for a request that has been sent but not yet answered.
    private struct PendingPacket: Sendable {
        let packetId: UInt32
        let ipVersion: Packet.IPVersion
        /// `ProcessInfo.systemUptime` (seconds) when the request was sent; used to expire lost requests.
        let sentAt: TimeInterval
    }

    private enum PacketParseError: Error {
        case packetTooShort
        case unmatchedPacketId(UInt32)
        case invalidIPVersion(UInt8)
        case unsupportedIPVersion(UInt8)
    }

    /// Maximum number of outstanding requests tracked in `pendingPackets`. UDP responses can be
    /// lost, so without a bound the table would grow for the whole lifetime of the client.
    private static let maxPendingPackets = 4096
    /// Age (seconds) after which an unanswered request is treated as lost and may be pruned.
    private static let pendingPacketTimeout: TimeInterval = 30

    private var state: State = .idle
    private var connection: NWConnection?
    private var receiveTask: Task<Void, Never>?
    private let assistServerAddress: NWEndpoint
    private var idGenerator: SDLIdGenerator
    /// Outstanding requests keyed by `packetId`; bounded by `maxPendingPackets`.
    private var pendingPackets: [UInt32: PendingPacket] = [:]

    /// Inbound IP packets that matched an outstanding request. Finishes when the client stops.
    let packetFlow: AsyncStream<Packet>
    private let packetContinuation: AsyncStream<Packet>.Continuation
    private var didFinishPacketFlow = false

    /// Never yields a value; it only finishes when the client stops, which is what `waitClose()` awaits.
    private let closeStream: AsyncStream<Void>
    private let closeContinuation: AsyncStream<Void>.Continuation
    private var didFinishCloseStream = false

    /// Creates a client for the given assist server.
    ///
    /// Returns `nil` if the port does not fit in a `UInt16`, or if `SDLUtil.ipv6DataToString`
    /// cannot turn the address bytes into a host string.
    init?(assistServerInfo: SDLV6Info) {
        guard assistServerInfo.port <= UInt32(UInt16.max), let host = SDLUtil.ipv6DataToString(assistServerInfo.v6) else {
            return nil
        }

        let (packetStream, packetContinuation) = AsyncStream.makeStream(of: Packet.self, bufferingPolicy: .bufferingNewest(256))
        self.packetFlow = packetStream
        self.packetContinuation = packetContinuation

        let (closeStream, closeContinuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1))
        self.closeStream = closeStream
        self.closeContinuation = closeContinuation

        self.assistServerAddress = .hostPort(host: NWEndpoint.Host(host), port: NWEndpoint.Port(integerLiteral: UInt16(assistServerInfo.port)))
        self.idGenerator = SDLIdGenerator(seed: UInt32.random(in: 1..<UInt32.max))
    }

    /// Creates the UDP connection to the assist server and starts it.
    /// Has an effect only on an idle client; later calls are ignored.
    func start() {
        guard case .idle = self.state else {
            return
        }
        self.state = .running

        // 1. UDP transport parameters.
        let parameters = NWParameters.udp
        // Exclude `.other` interfaces so this traffic is not routed back into the tunnel
        // (assumption: the NE TUN interface is reported as `.other` — TODO confirm).
        parameters.prohibitedInterfaceTypes = [.other]
        // NOTE(review): `multipathServiceType` configures Multipath TCP; confirm it has any effect for UDP.
        parameters.multipathServiceType = .handover
        // Force IPv6: the assist path must not fall back to IPv4.
        if let ipOptions = parameters.defaultProtocolStack.internetProtocol as? NWProtocolIP.Options {
            ipOptions.version = .v6
        }

        // 2. Create the connection and hop its state changes back onto the actor.
        let connection = NWConnection(to: self.assistServerAddress, using: parameters)
        self.connection = connection
        connection.stateUpdateHandler = { [weak self] state in
            Task {
                await self?.handleConnectionStateUpdate(state, for: connection)
            }
        }

        // 3. Begin delivering connection events.
        connection.start(queue: .global())
    }

    /// Suspends until the client has stopped (i.e. until `closeStream` finishes).
    public func waitClose() async {
        for await _ in self.closeStream { }
    }

    /// Wraps repeated `receiveMessage` calls on `connection` in an `AsyncStream` of datagrams.
    ///
    /// Non-empty payloads are yielded; the stream finishes after the first receive that reports an
    /// error, or when the connection is no longer `.ready`.
    private static func makeReceiveStream(for connection: NWConnection) -> AsyncStream<Data> {
        return AsyncStream(bufferingPolicy: .bufferingNewest(256)) { continuation in
            func receiveNext() {
                connection.receiveMessage { content, _, _, error in
                    if let data = content, !data.isEmpty {
                        continuation.yield(data)
                    }

                    if error == nil && connection.state == .ready {
                        // Re-arm for the next datagram.
                        receiveNext()
                    } else {
                        continuation.finish()
                    }
                }
            }
            receiveNext()
        }
    }

    /// Sends one outbound IP packet to the assist server, prefixed with a new 4-byte `packetId`.
    ///
    /// - Parameter ipPacketData: A raw IP packet; only IPv6 packets are accepted.
    /// - Returns: The `packetId` used for the request, or `nil` if the client is not running, the
    ///   connection is not `.ready`, or the packet is not a valid IPv6 packet.
    @discardableResult
    func forward(ipPacketData: Data) -> UInt32? {
        guard case .running = self.state, let connection = self.connection, connection.state == .ready else {
            return nil
        }

        let ipVersion: Packet.IPVersion
        do {
            ipVersion = try self.parseIPVersion(packetData: ipPacketData)
        } catch {
            SDLLogger.log("[SDLIPV6AssistClient] Invalid outbound packet: \(error)", for: .debug)
            return nil
        }

        // Keep the pending table bounded before adding another entry.
        let now = ProcessInfo.processInfo.systemUptime
        self.prunePendingPacketsIfNeeded(now: now)

        let packetId = self.idGenerator.nextId()
        self.pendingPackets[packetId] = .init(packetId: packetId, ipVersion: ipVersion, sentAt: now)

        let outboundPacket = Data(components: Data(uint32: packetId), ipPacketData)
        connection.send(content: outboundPacket, completion: .contentProcessed { [weak self] error in
            Task {
                await self?.handleSendCompletion(packetId: packetId, error: error)
            }
        })
        return packetId
    }

    /// Keeps `pendingPackets` below `maxPendingPackets`.
    ///
    /// Does nothing until the table is full. It then drops requests older than
    /// `pendingPacketTimeout`. If the table is still full after that, it evicts the single
    /// oldest request so that a new one can be tracked.
    private func prunePendingPacketsIfNeeded(now: TimeInterval) {
        guard self.pendingPackets.count >= Self.maxPendingPackets else {
            return
        }

        let countBefore = self.pendingPackets.count
        let cutoff = now - Self.pendingPacketTimeout
        self.pendingPackets = self.pendingPackets.filter { $0.value.sentAt >= cutoff }

        if self.pendingPackets.count >= Self.maxPendingPackets,
           let oldest = self.pendingPackets.min(by: { $0.value.sentAt < $1.value.sentAt }) {
            self.pendingPackets.removeValue(forKey: oldest.key)
        }

        SDLLogger.log("[SDLIPV6AssistClient] Pruned \(countBefore - self.pendingPackets.count) pending packets", for: .debug)
    }

    /// Stops the client: cancels receiving and the connection, drops pending requests, and
    /// finishes `packetFlow` and the close stream. Calling it again is a no-op.
    func stop() {
        guard self.state != .stopped else {
            return
        }
        self.state = .stopped

        self.receiveTask?.cancel()
        self.receiveTask = nil
        self.pendingPackets.removeAll()
        self.connection?.cancel()
        self.connection = nil

        self.finishPacketFlowIfNeeded()
        self.finishCloseStreamIfNeeded()
    }

    /// Reacts to connection state changes while running: starts receiving on `.ready`, stops on `.failed` / `.cancelled`.
    private func handleConnectionStateUpdate(_ state: NWConnection.State, for connection: NWConnection) {
        guard case .running = self.state else {
            return
        }

        switch state {
        case .ready:
            SDLLogger.log("[SDLIPV6AssistClient] Connection ready", for: .debug)
            self.startReceiveTask(for: connection)
        case .failed(let error):
            SDLLogger.log("[SDLIPV6AssistClient] Connection failed: \(error)", for: .debug)
            self.stop()
        case .cancelled:
            self.stop()
        default:
            break
        }
    }

    /// Starts the receive loop for `connection` unless one is already running.
    private func startReceiveTask(for connection: NWConnection) {
        guard self.receiveTask == nil else {
            return
        }

        let stream = Self.makeReceiveStream(for: connection)
        self.receiveTask = Task { [weak self] in
            for await data in stream {
                guard let self else {
                    break
                }
                await self.handleReceivedPacket(data)
            }
            await self?.didFinishReceiving(for: connection)
        }
    }

    /// Parses one inbound datagram and publishes it on `packetFlow`; parse failures are logged and dropped.
    private func handleReceivedPacket(_ data: Data) {
        guard case .running = self.state else {
            return
        }

        do {
            let packet = try self.parseInboundPacket(data)
            self.packetContinuation.yield(packet)
        } catch {
            SDLLogger.log("[SDLIPV6AssistClient] Receive error: \(error)", for: .debug)
        }
    }

    /// Forgets a pending request whose send failed, so it cannot be matched later.
    private func handleSendCompletion(packetId: UInt32, error: Error?) {
        guard case .running = self.state else {
            return
        }

        if let error {
            self.pendingPackets.removeValue(forKey: packetId)
            SDLLogger.log("[SDLIPV6AssistClient] Send error: \(error), packetId: \(packetId)", for: .debug)
        }
    }

    /// Called when the receive loop for `connection` ends.
    private func didFinishReceiving(for connection: NWConnection) {
        guard case .running = self.state else {
            return
        }

        if self.connection === connection, connection.state != .ready {
            self.stop()
        } else {
            // Either `connection` is stale, or the loop ended (presumably on a receive error) while the
            // connection still reports `.ready`. Clearing the task lets a later `.ready` update restart
            // it; until then nothing is received, so make that visible in the logs.
            SDLLogger.log("[SDLIPV6AssistClient] Receive loop ended while connection is ready", for: .debug)
            self.receiveTask = nil
        }
    }

    /// Splits an inbound datagram into its 4-byte `packetId` and IP payload.
    ///
    /// The matching pending request is consumed.
    ///
    /// - Throws: `PacketParseError.packetTooShort` if there is no payload after the id.
    ///   `.unmatchedPacketId` if no request is outstanding for that id.
    ///   An IP version error from `parseIPVersion`.
    private func parseInboundPacket(_ data: Data) throws -> Packet {
        guard data.count > 4 else {
            throw PacketParseError.packetTooShort
        }

        let packetId = UInt32(data: Data(data.prefix(4)))
        guard let pendingPacket = self.pendingPackets.removeValue(forKey: packetId) else {
            throw PacketParseError.unmatchedPacketId(packetId)
        }

        let ipPacketData = Data(data.dropFirst(4))
        let ipVersion = try self.parseIPVersion(packetData: ipPacketData)
        if ipVersion != pendingPacket.ipVersion {
            SDLLogger.log("[SDLIPV6AssistClient] packet version mismatch, packetId: \(packetId), request: \(pendingPacket.ipVersion.rawValue), response: \(ipVersion.rawValue)", for: .debug)
        }

        return .init(packetId: pendingPacket.packetId, ipPacketData: ipPacketData, ipVersion: ipVersion)
    }

    /// Reads the IP version from the high nibble of the first byte.
    ///
    /// - Throws: `packetTooShort` when the packet is empty.
    ///   `invalidIPVersion` when the nibble is neither 4 nor 6.
    ///   `unsupportedIPVersion` when the version is 4, since only IPv6 is accepted.
    private func parseIPVersion(packetData: Data) throws -> Packet.IPVersion {
        guard let firstByte = packetData.first else {
            throw PacketParseError.packetTooShort
        }

        let rawVersion = firstByte >> 4
        guard let ipVersion = Packet.IPVersion(rawValue: rawVersion) else {
            throw PacketParseError.invalidIPVersion(rawVersion)
        }
        guard ipVersion == .ipv6 else {
            throw PacketParseError.unsupportedIPVersion(rawVersion)
        }
        return ipVersion
    }

    /// Finishes `packetFlow` exactly once.
    private func finishPacketFlowIfNeeded() {
        guard !self.didFinishPacketFlow else {
            return
        }
        self.didFinishPacketFlow = true
        self.packetContinuation.finish()
    }

    /// Finishes the close stream exactly once, releasing any `waitClose()` callers.
    private func finishCloseStreamIfNeeded() {
        guard !self.didFinishCloseStream else {
            return
        }
        self.didFinishCloseStream = true
        self.closeContinuation.finish()
    }

    deinit {
        self.receiveTask?.cancel()
        self.connection?.cancel()
        // `packetFlow` consumers do not retain the actor, so finish the streams here in case
        // `stop()` was never called. `finish()` is a no-op if the stream is already finished.
        self.packetContinuation.finish()
        self.closeContinuation.finish()
    }
}