diff --git a/Tun/Punchnet/Actors/SDLContextActor.swift b/Tun/Punchnet/Actors/SDLContextActor.swift index c25d1e5..957e56c 100644 --- a/Tun/Punchnet/Actors/SDLContextActor.swift +++ b/Tun/Punchnet/Actors/SDLContextActor.swift @@ -282,12 +282,13 @@ actor SDLContextActor { // 启动dns服务 let dnsLocalClient = DNSLocalClient() - dnsLocalClient.start() + await dnsLocalClient.start() SDLLogger.log("[SDLContext] dnsClient started") self.dnsLocalClient = dnsLocalClient + let packetFlow = await dnsLocalClient.packetFlow self.dnsLocalWorker = Task.detached { // 处理事件流 - for await packet in dnsLocalClient.packetFlow { + for await packet in packetFlow { if Task.isCancelled { break } @@ -306,7 +307,7 @@ actor SDLContextActor { // 启动udp服务器 let udpHole = try SDLUDPHole() let localAddress = try udpHole.start() - SDLLogger.log("[SDLContext] udpHole started, on address: \(localAddress.debugDescription)") + SDLLogger.log("[SDLContext] udpHole started, on address: \(localAddress)") // 处理心跳逻辑 let pingTask = Task.detached { @@ -383,7 +384,9 @@ actor SDLContextActor { self.dnsWorker = nil self.dnsClient = nil - self.dnsLocalClient?.stop() + if let dnsLocalClient = self.dnsLocalClient { + await dnsLocalClient.stop() + } self.dnsLocalWorker?.cancel() self.dnsLocalWorker = nil self.dnsLocalClient = nil @@ -784,7 +787,9 @@ actor SDLContextActor { clientIP: packet.header.source, clientPort: udpPacket.srcPort, createdAt: Date()) - self.dnsLocalClient?.query(tracker: tracker, dnsPayload: dnsPayload) + if let dnsLocalClient = self.dnsLocalClient { + await dnsLocalClient.query(tracker: tracker, dnsPayload: dnsPayload) + } } } } diff --git a/Tun/Punchnet/DNS/DNSLocalClient.swift b/Tun/Punchnet/DNS/DNSLocalClient.swift index 1407ca0..b9b6aed 100644 --- a/Tun/Punchnet/DNS/DNSLocalClient.swift +++ b/Tun/Punchnet/DNS/DNSLocalClient.swift @@ -1,182 +1,265 @@ import Foundation import Network -final class DNSLocalClient { - // 需要保存DNS请求的追踪信息 +actor DNSLocalClient { + struct DNSTracker { let transactionID: UInt16 - let clientIP: UInt32 // 原始包的源 IP (大端序) - let clientPort: UInt16 // 原始包的源端口 (大端序) - let createdAt: Date // 用于超时清理 + let clientIP: UInt32 + let clientPort: UInt16 + let createdAt: Date } - private var connections: [NWConnection] = [] + private struct PendingRequest { + let tracker: DNSTracker + } - // 阿里云 + 腾讯云 + private enum State { + case idle + case running + case stopped + } + + private var state: State = .idle + private var connections: [NWConnection] = [] private let dnsServers = ["223.5.5.5", "119.29.29.29"] - - public let packetFlow: AsyncStream + + let packetFlow: AsyncStream private let packetContinuation: AsyncStream.Continuation - private let locker = NSLock() - private var trackers: [UInt16: [DNSTracker]] = [:] + private var pendingRequests: [UInt16: PendingRequest] = [:] + private var nextTransactionID: UInt16 = 1 - // 定期的任务清理 private var cleanupTask: Task? - private let timeoutInterval: TimeInterval = 10.0 // 超过10秒认为丢包 - + private let timeoutInterval: TimeInterval = 3.0 + private var didFinishPacketFlow = false + init() { let (stream, continuation) = AsyncStream.makeStream(of: Data.self, bufferingPolicy: .unbounded) self.packetFlow = stream self.packetContinuation = continuation } - + func start() { - for server in dnsServers { + guard case .idle = self.state else { + return + } + + self.state = .running + + for server in self.dnsServers { let endpoint = NWEndpoint.hostPort(host: NWEndpoint.Host(server), port: 53) let parameters = NWParameters.udp parameters.prohibitedInterfaceTypes = [.other] let conn = NWConnection(to: endpoint, using: parameters) - conn.stateUpdateHandler = { [weak self] state in - switch state { - case .ready: - self?.receiveLoop(for: conn) - case .failed(let error): - SDLLogger.log("[DNSLocalClient] failed with error: \(error.localizedDescription)", for: .debug) - self?.stop() - case .cancelled: - self?.packetContinuation.finish() - default: - () + Task { + await self?.handleConnectionStateUpdate(state, for: conn) } } conn.start(queue: .global()) - - connections.append(conn) + self.connections.append(conn) } - // 启动清理循环 self.cleanupTask = Task { [weak self] in while !Task.isCancelled { - // 每隔 cleanupTick 秒运行一次 - try? await Task.sleep(nanoseconds: 5 * 1_000_000_000) - self?.performCleanup() - } - } - } - - /// 并发查询:对所有服务器广播 - func query(tracker: DNSTracker, dnsPayload: Data) { - locker.lock() - self.trackers[tracker.transactionID, default: []].append(tracker) - locker.unlock() - - for conn in connections where conn.state == .ready { - conn.send(content: dnsPayload, completion: .contentProcessed({ _ in })) - } - } - - private func receiveLoop(for conn: NWConnection) { - conn.receiveMessage { [weak self] content, _, _, error in - if let data = content { - // !!!核心:由于 AsyncStream 是流式的 - // 谁先 yield,上层就先收到谁。 - // 只要上层收到了第一个有效响应并回填给系统, - self?.handleResponse(data: data) - } - - if error == nil && conn.state == .ready { - self?.receiveLoop(for: conn) + try? await Task.sleep(nanoseconds: 3 * 1_000_000_000) + await self?.performCleanup() } } } - private func handleResponse(data: Data) { - guard data.count > 2 else { + func query(tracker: DNSTracker, dnsPayload: Data) { + guard case .running = self.state, dnsPayload.count >= 2 else { return } - let tranId = UInt16(data[0]) << 8 | UInt16(data[1]) - - locker.lock() - let items = self.trackers.removeValue(forKey: tranId) - locker.unlock() - - items?.forEach { tracker in - let packet = Self.createDNSResponse( - payload: data, - srcIP: DNSHelper.dnsDestIpAddr, - srcPort: 53, - destIP: tracker.clientIP, - destPort: tracker.clientPort - ) - self.packetContinuation.yield(packet) - } - } - - private func performCleanup() { - locker.lock() - defer { - locker.unlock() + guard let transactionID = self.allocateTransactionID() else { + SDLLogger.log("[DNSLocalClient] no available transaction id", for: .debug) + return } - // 遍历所有 ID,过滤掉过期的 tracker - let now = Date() - for (id, list) in trackers { - let validItems = list.filter { now.timeIntervalSince($0.createdAt) < timeoutInterval } - - if validItems.isEmpty { - trackers.removeValue(forKey: id) - } else { - trackers[id] = validItems - } + self.pendingRequests[transactionID] = PendingRequest(tracker: tracker) + let rewrittenPayload = Self.rewriteTransactionID(in: dnsPayload, to: transactionID) + + var hasReadyConnection = false + for conn in self.connections where conn.state == .ready { + hasReadyConnection = true + conn.send(content: rewrittenPayload, completion: .contentProcessed({ error in + if let error { + SDLLogger.log("[DNSLocalClient] send error: \(error.localizedDescription)", for: .debug) + } + })) + } + + if !hasReadyConnection { + self.pendingRequests.removeValue(forKey: transactionID) } } func stop() { - connections.forEach { conn in - conn.cancel() + guard self.state != .stopped else { + return } + + self.state = .stopped + self.connections.forEach { $0.cancel() } self.connections.removeAll() self.cleanupTask?.cancel() self.cleanupTask = nil + + self.pendingRequests.removeAll() + self.nextTransactionID = 1 + self.finishPacketFlowIfNeeded() } + private func handleConnectionStateUpdate(_ state: NWConnection.State, for conn: NWConnection) { + guard case .running = self.state else { + return + } + + switch state { + case .ready: + self.receiveLoop(for: conn) + case .failed(let error): + SDLLogger.log("[DNSLocalClient] failed with error: \(error.localizedDescription)", for: .debug) + self.stop() + case .cancelled: + self.connections.removeAll { $0 === conn } + if self.connections.isEmpty { + self.stop() + } + default: + () + } + } + + private func receiveLoop(for conn: NWConnection) { + conn.receiveMessage { [weak self] content, _, _, error in + Task { + await self?.handleReceive(content: content, error: error, for: conn) + } + } + } + + private func handleReceive(content: Data?, error: NWError?, for conn: NWConnection) { + guard case .running = self.state else { + return + } + + if let data = content { + self.handleResponse(data: data) + } + + if error == nil && conn.state == .ready { + self.receiveLoop(for: conn) + } + } + + private func handleResponse(data: Data) { + guard case .running = self.state, + let rewrittenTransactionID = Self.readTransactionID(from: data), + let pendingRequest = self.pendingRequests.removeValue(forKey: rewrittenTransactionID) else { + return + } + + let restoredPayload = Self.rewriteTransactionID(in: data, to: pendingRequest.tracker.transactionID) + + let packet = Self.createDNSResponse( + payload: restoredPayload, + srcIP: DNSHelper.dnsDestIpAddr, + srcPort: 53, + destIP: pendingRequest.tracker.clientIP, + destPort: pendingRequest.tracker.clientPort + ) + self.packetContinuation.yield(packet) + } + + private func performCleanup() { + guard case .running = self.state else { + return + } + + let now = Date() + self.pendingRequests = self.pendingRequests.filter { _, request in + now.timeIntervalSince(request.tracker.createdAt) < self.timeoutInterval + } + } + + private func allocateTransactionID() -> UInt16? { + var candidate = self.nextTransactionID == 0 ? 1 : self.nextTransactionID + let start = candidate + + repeat { + if self.pendingRequests[candidate] == nil { + self.nextTransactionID = Self.nextTransactionID(after: candidate) + return candidate + } + + candidate = Self.nextTransactionID(after: candidate) + } while candidate != start + + return nil + } + + private func finishPacketFlowIfNeeded() { + guard !self.didFinishPacketFlow else { + return + } + + self.didFinishPacketFlow = true + self.packetContinuation.finish() + } + + private static func nextTransactionID(after id: UInt16) -> UInt16 { + return id == UInt16.max ? 1 : id &+ 1 + } + + private static func readTransactionID(from payload: Data) -> UInt16? { + guard payload.count >= 2 else { + return nil + } + + return UInt16(payload[0]) << 8 | UInt16(payload[1]) + } + + private static func rewriteTransactionID(in payload: Data, to transactionID: UInt16) -> Data { + guard payload.count >= 2 else { + return payload + } + + var rewrittenPayload = payload + rewrittenPayload[0] = UInt8((transactionID >> 8) & 0xFF) + rewrittenPayload[1] = UInt8(transactionID & 0xFF) + return rewrittenPayload + } } extension DNSLocalClient { - /// 构造发回 TUN 的完整 UDP/IPv4 数据包 static func createDNSResponse(payload: Data, srcIP: UInt32, srcPort: UInt16, destIP: UInt32, destPort: UInt16) -> Data { let udpLen = 8 + payload.count let ipLen = 20 + udpLen - // --- 1. IPv4 Header (20 字节) --- var ipHeader = Data(count: 20) - ipHeader[0] = 0x45 // Version 4, IHL 5 + ipHeader[0] = 0x45 ipHeader[2...3] = withUnsafeBytes(of: UInt16(ipLen).bigEndian) { Data($0) } - ipHeader[8] = 64 // TTL - ipHeader[9] = 17 // Protocol UDP + ipHeader[8] = 64 + ipHeader[9] = 17 - // 填充 IP 地址 ipHeader[12...15] = withUnsafeBytes(of: srcIP.bigEndian) { Data($0) } ipHeader[16...19] = withUnsafeBytes(of: destIP.bigEndian) { Data($0) } - // 计算 IP Checksum let ipChecksum = calculateChecksum(data: ipHeader) ipHeader[10...11] = withUnsafeBytes(of: ipChecksum.bigEndian) { Data($0) } - // --- 2. UDP Header (8 字节) --- var udpHeader = Data(count: 8) udpHeader[0...1] = withUnsafeBytes(of: srcPort.bigEndian) { Data($0) } udpHeader[2...3] = withUnsafeBytes(of: destPort.bigEndian) { Data($0) } udpHeader[4...5] = withUnsafeBytes(of: UInt16(udpLen).bigEndian) { Data($0) } - // UDP Checksum 在 IPv4 中可选,设为 0 可跳过计算(大部分系统接受) udpHeader[6...7] = Data([0, 0]) - // --- 3. 拼接 --- var packet = Data(capacity: ipLen) packet.append(ipHeader) packet.append(udpHeader) @@ -185,7 +268,6 @@ extension DNSLocalClient { return packet } - /// 经典的 Internet Checksum 算法 static func calculateChecksum(data: Data) -> UInt16 { var sum: UInt32 = 0 let count = data.count @@ -193,30 +275,23 @@ extension DNSLocalClient { data.withUnsafeBytes { (ptr: UnsafeRawBufferPointer) in guard let baseAddress = ptr.baseAddress else { return } - // 1. 处理成对的 16-bit 单词 let wordCount = count / 2 let words = baseAddress.bindMemory(to: UInt16.self, capacity: wordCount) for i in 0..> 16) != 0 { sum = (sum & 0xffff) + (sum >> 16) } return UInt16(~sum & 0xffff) } - - }