diff --git a/Tun/Punchnet/SDLIPV6AssistClient.swift b/Tun/Punchnet/SDLIPV6AssistClient.swift index 345feaa..5ab0637 100644 --- a/Tun/Punchnet/SDLIPV6AssistClient.swift +++ b/Tun/Punchnet/SDLIPV6AssistClient.swift @@ -8,19 +8,81 @@ import Foundation import Network actor SDLIPV6AssistClient { + + struct Packet: Sendable { + enum IPVersion: UInt8, Sendable { + case ipv4 = 4 + case ipv6 = 6 + + var protocolFamily: Int32 { + switch self { + case .ipv4: + return 2 + case .ipv6: + return 30 + } + } + } + + let packetId: UInt32 + let ipPacketData: Data + let ipVersion: IPVersion + + var protocolFamily: Int32 { + return self.ipVersion.protocolFamily + } + } + private enum State { case idle case running case stopped } + private struct PendingPacket: Sendable { + let packetId: UInt32 + let ipVersion: Packet.IPVersion + } + + private enum PacketParseError: Error { + case packetTooShort + case unmatchedPacketId(UInt32) + case invalidIPVersion(UInt8) + case unsupportedIPVersion(UInt8) + } + private var state: State = .idle private var connection: NWConnection? private var receiveTask: Task? private let assistServerAddress: NWEndpoint + private var idGenerator: SDLIdGenerator + private var pendingPackets: [UInt32: PendingPacket] = [:] - init(host: String, port: UInt16 ) { - self.assistServerAddress = .hostPort(host: NWEndpoint.Host(host), port: NWEndpoint.Port(integerLiteral: port)) + // 用于对外输出收到的原始 IP 响应包 + let packetFlow: AsyncStream + private let packetContinuation: AsyncStream.Continuation + private var didFinishPacketFlow = false + + // 用来处理关闭事件 + private let closeStream: AsyncStream + private let closeContinuation: AsyncStream.Continuation + private var didFinishCloseStream = false + + init?(assistServerInfo: SDLV6Info) { + guard assistServerInfo.port <= UInt32(UInt16.max), let host = SDLUtil.ipv6DataToString(assistServerInfo.v6) else { + return nil + } + + let (packetStream, packetContinuation) = AsyncStream.makeStream(of: Packet.self, bufferingPolicy: .bufferingNewest(256)) + self.packetFlow = packetStream + self.packetContinuation = packetContinuation + + let (closeStream, closeContinuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1)) + self.closeStream = closeStream + self.closeContinuation = closeContinuation + + self.assistServerAddress = .hostPort(host: NWEndpoint.Host(host), port: NWEndpoint.Port(integerLiteral: UInt16(assistServerInfo.port))) + self.idGenerator = SDLIdGenerator(seed: UInt32.random(in: 1.. AsyncStream { return AsyncStream(bufferingPolicy: .bufferingNewest(256)) { continuation in @@ -75,16 +146,32 @@ actor SDLIPV6AssistClient { } /// 发送 DNS 查询包(由 TUN 拦截到的原始 IP 包数据) - func forward(ipPacketData: Data) { + /// 当前实现会在原始 IP 包前增加 4 字节的 packetId,用于建立请求和响应的对应关系。 + @discardableResult + func forward(ipPacketData: Data) -> UInt32? { guard case .running = self.state, let connection = self.connection, connection.state == .ready else { - return + return nil } - connection.send(content: ipPacketData, completion: .contentProcessed { error in - if let error = error { - SDLLogger.log("[SDLIPV6AssistClient] Send error: \(error)", for: .debug) + let ipVersion: Packet.IPVersion + do { + ipVersion = try self.parseIPVersion(packetData: ipPacketData) + } catch { + SDLLogger.log("[SDLIPV6AssistClient] Invalid outbound packet: \(error)", for: .debug) + return nil + } + + let packetId = self.idGenerator.nextId() + self.pendingPackets[packetId] = .init(packetId: packetId, ipVersion: ipVersion) + + let outboundPacket = Data(components: Data(uint32: packetId), ipPacketData) + connection.send(content: outboundPacket, completion: .contentProcessed { [weak self] error in + Task { + await self?.handleSendCompletion(packetId: packetId, error: error) } }) + + return packetId } func stop() { @@ -95,8 +182,11 @@ actor SDLIPV6AssistClient { self.state = .stopped self.receiveTask?.cancel() self.receiveTask = nil + self.pendingPackets.removeAll() self.connection?.cancel() self.connection = nil + self.finishPacketFlowIfNeeded() + self.finishCloseStreamIfNeeded() } private func handleConnectionStateUpdate(_ state: NWConnection.State, for connection: NWConnection) { @@ -131,6 +221,8 @@ actor SDLIPV6AssistClient { } await self.handleReceivedPacket(data) } + + await self?.didFinishReceiving(for: connection) } } @@ -139,7 +231,89 @@ actor SDLIPV6AssistClient { return } - NSLog("data: \(data)") + do { + let packet = try self.parseInboundPacket(data) + self.packetContinuation.yield(packet) + } catch { + SDLLogger.log("[SDLIPV6AssistClient] Receive error: \(error)", for: .debug) + } + } + + private func handleSendCompletion(packetId: UInt32, error: Error?) { + guard case .running = self.state else { + return + } + + if let error { + self.pendingPackets.removeValue(forKey: packetId) + SDLLogger.log("[SDLIPV6AssistClient] Send error: \(error), packetId: \(packetId)", for: .debug) + } + } + + private func didFinishReceiving(for connection: NWConnection) { + guard case .running = self.state else { + return + } + + if self.connection === connection, connection.state != .ready { + self.stop() + } else { + self.receiveTask = nil + } + } + + private func parseInboundPacket(_ data: Data) throws -> Packet { + guard data.count > 4 else { + throw PacketParseError.packetTooShort + } + + let packetId = UInt32(data: Data(data.prefix(4))) + guard let pendingPacket = self.pendingPackets.removeValue(forKey: packetId) else { + throw PacketParseError.unmatchedPacketId(packetId) + } + + let ipPacketData = Data(data.dropFirst(4)) + let ipVersion = try self.parseIPVersion(packetData: ipPacketData) + + if ipVersion != pendingPacket.ipVersion { + SDLLogger.log("[SDLIPV6AssistClient] packet version mismatch, packetId: \(packetId), request: \(pendingPacket.ipVersion.rawValue), response: \(ipVersion.rawValue)", for: .debug) + } + + return .init(packetId: pendingPacket.packetId, ipPacketData: ipPacketData, ipVersion: ipVersion) + } + + private func parseIPVersion(packetData: Data) throws -> Packet.IPVersion { + guard let firstByte = packetData.first else { + throw PacketParseError.packetTooShort + } + + let rawVersion = firstByte >> 4 + guard let ipVersion = Packet.IPVersion(rawValue: rawVersion) else { + throw PacketParseError.invalidIPVersion(rawVersion) + } + guard ipVersion == .ipv6 else { + throw PacketParseError.unsupportedIPVersion(rawVersion) + } + + return ipVersion + } + + private func finishPacketFlowIfNeeded() { + guard !self.didFinishPacketFlow else { + return + } + + self.didFinishPacketFlow = true + self.packetContinuation.finish() + } + + private func finishCloseStreamIfNeeded() { + guard !self.didFinishCloseStream else { + return + } + + self.didFinishCloseStream = true + self.closeContinuation.finish() } deinit {