import Foundation
import Network

/// Relays DNS queries to a fixed set of upstream resolvers over UDP and
/// publishes the answers as ready-to-inject IPv4/UDP packets on `packetFlow`.
///
/// Every query is sent to all upstream connections that are currently ready.
/// The DNS transaction ID is rewritten to a locally allocated one before the
/// query is sent and restored when a response is matched, so the first
/// response for a given query is delivered and any later one is dropped
/// (its pending entry has already been removed).
actor DNSLocalClient {
    /// Describes the client request a response has to be routed back to.
    struct DNSTracker {
        /// Transaction ID the client originally used; restored in the response.
        let transactionID: UInt16
        /// Client address, used as the destination address of the response packet.
        let clientIP: UInt32
        /// Client port, used as the destination port of the response packet.
        let clientPort: UInt16
        /// Creation time of the request; entries older than the timeout are purged.
        let createdAt: Date
    }

    private struct PendingRequest {
        let tracker: DNSTracker
    }

    private enum State {
        case idle
        case running
        case stopped
    }

    private var state: State = .idle
    private var connections: [NWConnection] = []
    private var receiveTasks: [ObjectIdentifier: Task<Void, Never>] = [:]
    private let dnsServers = ["223.5.5.5", "119.29.29.29"]

    /// Stream of complete response packets (IPv4 header + UDP header + DNS payload).
    /// Finished once `stop()` runs.
    let packetFlow: AsyncStream<Data>
    private let packetContinuation: AsyncStream<Data>.Continuation

    /// Outstanding queries keyed by the locally allocated transaction ID.
    private var pendingRequests: [UInt16: PendingRequest] = [:]
    private var nextTransactionID: UInt16 = 1
    private var cleanupTask: Task<Void, Never>?
    private let timeoutInterval: TimeInterval = 3.0
    private var didFinishPacketFlow = false

    init() {
        let (stream, continuation) = AsyncStream.makeStream(of: Data.self, bufferingPolicy: .bufferingNewest(256))
        self.packetFlow = stream
        self.packetContinuation = continuation
    }

    /// Opens one UDP connection per upstream server and starts the periodic
    /// purge of timed-out requests. Does nothing unless the client is idle.
    func start() {
        guard case .idle = self.state else { return }
        self.state = .running

        for server in self.dnsServers {
            let endpoint = NWEndpoint.hostPort(host: NWEndpoint.Host(server), port: 53)
            let parameters = NWParameters.udp
            parameters.prohibitedInterfaceTypes = [.other]

            let conn = NWConnection(to: endpoint, using: parameters)
            // The handler is stored on `conn`, so `conn` is captured weakly to
            // keep the connection from retaining itself through its own handler.
            conn.stateUpdateHandler = { [weak self, weak conn] state in
                guard let conn else { return }
                Task {
                    await self?.handleConnectionStateUpdate(state, for: conn)
                }
            }
            conn.start(queue: .global())
            self.connections.append(conn)
        }

        // Sweep at the same period as the request timeout.
        let sweepIntervalNanoseconds = UInt64(self.timeoutInterval * 1_000_000_000)
        self.cleanupTask = Task { [weak self] in
            while !Task.isCancelled {
                do {
                    try await Task.sleep(nanoseconds: sweepIntervalNanoseconds)
                } catch {
                    // Sleep only throws on cancellation.
                    return
                }
                // Stop looping once the client has been deallocated; otherwise
                // this task would keep waking up forever with nothing to do.
                guard let self else { return }
                await self.performCleanup()
            }
        }
    }

    /// Sends `dnsPayload` to every ready upstream connection.
    ///
    /// The query is silently dropped when the client is not running, the
    /// payload is shorter than the two-byte transaction ID, no connection is
    /// ready, or no transaction ID is free.
    ///
    /// - Parameters:
    ///   - tracker: Routing information used to build the response packet.
    ///   - dnsPayload: Raw DNS message; its first two bytes are replaced with
    ///     a locally allocated transaction ID before sending.
    func query(tracker: DNSTracker, dnsPayload: Data) {
        guard case .running = self.state, dnsPayload.count >= 2 else {
            return
        }

        // Check for a usable connection first so no ID is allocated (and no
        // pending entry created) for a query that cannot be sent.
        let readyConnections = self.connections.filter { $0.state == .ready }
        guard !readyConnections.isEmpty else {
            return
        }

        guard let transactionID = self.allocateTransactionID() else {
            SDLLogger.log("[DNSLocalClient] no available transaction id", for: .debug)
            return
        }

        self.pendingRequests[transactionID] = PendingRequest(tracker: tracker)
        let rewrittenPayload = Self.rewriteTransactionID(in: dnsPayload, to: transactionID)

        for conn in readyConnections {
            conn.send(content: rewrittenPayload, completion: .contentProcessed({ error in
                if let error {
                    SDLLogger.log("[DNSLocalClient] send error: \(error.localizedDescription)", for: .debug)
                }
            }))
        }
    }

    /// Cancels all connections and tasks, forgets pending requests and
    /// finishes `packetFlow`. Safe to call more than once.
    func stop() {
        guard self.state != .stopped else { return }
        self.state = .stopped

        self.receiveTasks.values.forEach { $0.cancel() }
        self.receiveTasks.removeAll()

        self.connections.forEach { $0.cancel() }
        self.connections.removeAll()

        self.cleanupTask?.cancel()
        self.cleanupTask = nil

        self.pendingRequests.removeAll()
        self.nextTransactionID = 1
        self.finishPacketFlowIfNeeded()
    }

    /// Reacts to a connection state change; ignored unless the client is running.
    private func handleConnectionStateUpdate(_ state: NWConnection.State, for conn: NWConnection) {
        guard case .running = self.state else { return }

        switch state {
        case .ready:
            self.startReceiveTask(for: conn)
        case .failed(let error):
            // A failure on any single connection shuts the whole client down.
            SDLLogger.log("[DNSLocalClient] failed with error: \(error.localizedDescription)", for: .debug)
            self.stop()
        case .cancelled:
            let key = ObjectIdentifier(conn)
            self.receiveTasks.removeValue(forKey: key)?.cancel()
            self.connections.removeAll { $0 === conn }
            if self.connections.isEmpty {
                self.stop()
            }
        case .setup, .waiting, .preparing:
            break
        @unknown default:
            break
        }
    }

    /// Starts the task that drains datagrams from `conn`, at most once per connection.
    private func startReceiveTask(for conn: NWConnection) {
        let key = ObjectIdentifier(conn)
        guard self.receiveTasks[key] == nil else { return }

        let stream = Self.makeReceiveStream(for: conn)
        self.receiveTasks[key] = Task { [weak self] in
            for await data in stream {
                // Take a strong reference per datagram only, so a long-lived
                // receive loop does not keep the client alive on its own.
                guard let self else { return }
                await self.handleResponse(data: data)
            }
            await self?.didFinishReceiving(for: conn)
        }
    }

    /// Drops the bookkeeping entry for a receive task whose stream has ended.
    private func didFinishReceiving(for conn: NWConnection) {
        let key = ObjectIdentifier(conn)
        self.receiveTasks.removeValue(forKey: key)
    }

    /// Matches an upstream response to its pending request, restores the
    /// client's transaction ID and publishes the wrapped packet.
    private func handleResponse(data: Data) {
        guard case .running = self.state,
              let rewrittenTransactionID = Self.readTransactionID(from: data),
              let pendingRequest = self.pendingRequests.removeValue(forKey: rewrittenTransactionID) else {
            return
        }

        let restoredPayload = Self.rewriteTransactionID(in: data, to: pendingRequest.tracker.transactionID)
        let packet = Self.createDNSResponse(
            payload: restoredPayload,
            srcIP: DNSHelper.dnsDestIpAddr,
            srcPort: 53,
            destIP: pendingRequest.tracker.clientIP,
            destPort: pendingRequest.tracker.clientPort
        )
        self.packetContinuation.yield(packet)
    }

    /// Removes pending requests whose tracker is at least `timeoutInterval` old.
    private func performCleanup() {
        guard case .running = self.state else { return }

        let now = Date()
        let timeout = self.timeoutInterval
        self.pendingRequests = self.pendingRequests.filter { _, request in
            now.timeIntervalSince(request.tracker.createdAt) < timeout
        }
    }

    /// Returns the next transaction ID (1...UInt16.max, never 0) that has no
    /// pending request, or `nil` when every ID is in use.
    private func allocateTransactionID() -> UInt16? {
        var candidate = self.nextTransactionID == 0 ? 1 : self.nextTransactionID
        let start = candidate

        repeat {
            if self.pendingRequests[candidate] == nil {
                self.nextTransactionID = Self.nextTransactionID(after: candidate)
                return candidate
            }
            candidate = Self.nextTransactionID(after: candidate)
        } while candidate != start

        return nil
    }

    /// Finishes `packetFlow` exactly once.
    private func finishPacketFlowIfNeeded() {
        guard !self.didFinishPacketFlow else { return }
        self.didFinishPacketFlow = true
        self.packetContinuation.finish()
    }

    /// Successor of `id`, wrapping from `UInt16.max` to 1 so that 0 is skipped.
    private static func nextTransactionID(after id: UInt16) -> UInt16 {
        return id == UInt16.max ? 1 : id &+ 1
    }

    /// Reads the big-endian transaction ID from the first two bytes of `payload`.
    ///
    /// - Returns: `nil` when `payload` has fewer than two bytes.
    private static func readTransactionID(from payload: Data) -> UInt16? {
        guard payload.count >= 2 else { return nil }
        // Index relative to startIndex: a Data slice does not start at 0.
        let first = payload.startIndex
        return (UInt16(payload[first]) << 8) | UInt16(payload[first + 1])
    }

    /// Returns a copy of `payload` whose first two bytes hold `transactionID`
    /// in big-endian order; payloads shorter than two bytes are returned unchanged.
    private static func rewriteTransactionID(in payload: Data, to transactionID: UInt16) -> Data {
        guard payload.count >= 2 else { return payload }

        var rewrittenPayload = payload
        // Index relative to startIndex: a Data slice does not start at 0.
        let first = rewrittenPayload.startIndex
        rewrittenPayload[first] = UInt8((transactionID >> 8) & 0xFF)
        rewrittenPayload[first + 1] = UInt8(transactionID & 0xFF)
        return rewrittenPayload
    }

    /// Wraps repeated `receiveMessage` calls on `conn` in an `AsyncStream`.
    ///
    /// Non-empty datagrams are yielded; the stream finishes after the first
    /// receive that reports an error or finds the connection no longer ready.
    private static func makeReceiveStream(for conn: NWConnection) -> AsyncStream<Data> {
        return AsyncStream<Data>(bufferingPolicy: .bufferingNewest(256)) { continuation in
            func receiveNext() {
                conn.receiveMessage { content, _, _, error in
                    if let data = content, !data.isEmpty {
                        continuation.yield(data)
                    }

                    if error == nil && conn.state == .ready {
                        receiveNext()
                    } else {
                        continuation.finish()
                    }
                }
            }

            receiveNext()
        }
    }
}

extension DNSLocalClient {
    /// Builds an IPv4 + UDP packet carrying `payload`.
    ///
    /// The IPv4 header is 20 bytes (no options) with TTL 64 and protocol 17
    /// (UDP). The UDP checksum field is left at zero, which for UDP over IPv4
    /// marks the checksum as not computed.
    ///
    /// - Parameters:
    ///   - payload: UDP payload. The total packet length must fit in 16 bits
    ///     (`payload.count` at most 65507); the `UInt16` conversions below trap otherwise.
    ///   - srcIP: Source IPv4 address as a host-order integer.
    ///   - srcPort: Source UDP port.
    ///   - destIP: Destination IPv4 address as a host-order integer.
    ///   - destPort: Destination UDP port.
    /// - Returns: The serialized packet.
    static func createDNSResponse(payload: Data, srcIP: UInt32, srcPort: UInt16, destIP: UInt32, destPort: UInt16) -> Data {
        let udpLen = 8 + payload.count
        let ipLen = 20 + udpLen

        var ipHeader = Data(count: 20)
        ipHeader[0] = 0x45 // version 4, header length 5 * 4 bytes
        ipHeader[2...3] = withUnsafeBytes(of: UInt16(ipLen).bigEndian) { Data($0) }
        ipHeader[8] = 64 // TTL
        ipHeader[9] = 17 // protocol: UDP
        ipHeader[12...15] = withUnsafeBytes(of: srcIP.bigEndian) { Data($0) }
        ipHeader[16...19] = withUnsafeBytes(of: destIP.bigEndian) { Data($0) }

        // Checksum is computed while bytes 10...11 are still zero.
        let ipChecksum = calculateChecksum(data: ipHeader)
        ipHeader[10...11] = withUnsafeBytes(of: ipChecksum.bigEndian) { Data($0) }

        var udpHeader = Data(count: 8)
        udpHeader[0...1] = withUnsafeBytes(of: srcPort.bigEndian) { Data($0) }
        udpHeader[2...3] = withUnsafeBytes(of: destPort.bigEndian) { Data($0) }
        udpHeader[4...5] = withUnsafeBytes(of: UInt16(udpLen).bigEndian) { Data($0) }
        udpHeader[6...7] = Data([0, 0])

        var packet = Data(capacity: ipLen)
        packet.append(ipHeader)
        packet.append(udpHeader)
        packet.append(payload)
        return packet
    }

    /// Computes the 16-bit one's-complement Internet checksum of `data`.
    ///
    /// Bytes are combined into big-endian 16-bit words; a trailing odd byte is
    /// padded with a zero low byte. The result is a host-order value, which
    /// the caller serializes in big-endian order.
    ///
    /// NOTE(review): the summation loop was rebuilt from a damaged source
    /// listing — confirm it matches the intended word order.
    static func calculateChecksum(data: Data) -> UInt16 {
        var sum: UInt32 = 0

        // Byte-wise iteration avoids binding the buffer to UInt16, so it has
        // no alignment requirement and works on odd-length input.
        var bytes = data.makeIterator()
        while let high = bytes.next() {
            let low = bytes.next() ?? 0
            sum += (UInt32(high) << 8) | UInt32(low)
            // Fold the carry immediately so the accumulator cannot overflow.
            if sum > 0xffff {
                sum = (sum & 0xffff) + (sum >> 16)
            }
        }

        return UInt16(~sum & 0xffff)
    }
}