fix Local DNS

This commit is contained in:
anlicheng 2026-04-14 11:33:44 +08:00
parent 0669938f73
commit fcfb2042ca
2 changed files with 198 additions and 118 deletions

View File

@ -282,12 +282,13 @@ actor SDLContextActor {
// dns
let dnsLocalClient = DNSLocalClient()
dnsLocalClient.start()
await dnsLocalClient.start()
SDLLogger.log("[SDLContext] dnsClient started")
self.dnsLocalClient = dnsLocalClient
let packetFlow = await dnsLocalClient.packetFlow
self.dnsLocalWorker = Task.detached {
//
for await packet in dnsLocalClient.packetFlow {
for await packet in packetFlow {
if Task.isCancelled {
break
}
@ -306,7 +307,7 @@ actor SDLContextActor {
// udp
let udpHole = try SDLUDPHole()
let localAddress = try udpHole.start()
SDLLogger.log("[SDLContext] udpHole started, on address: \(localAddress.debugDescription)")
SDLLogger.log("[SDLContext] udpHole started, on address: \(localAddress)")
//
let pingTask = Task.detached {
@ -383,7 +384,9 @@ actor SDLContextActor {
self.dnsWorker = nil
self.dnsClient = nil
self.dnsLocalClient?.stop()
if let dnsLocalClient = self.dnsLocalClient {
await dnsLocalClient.stop()
}
self.dnsLocalWorker?.cancel()
self.dnsLocalWorker = nil
self.dnsLocalClient = nil
@ -784,7 +787,9 @@ actor SDLContextActor {
clientIP: packet.header.source,
clientPort: udpPacket.srcPort,
createdAt: Date())
self.dnsLocalClient?.query(tracker: tracker, dnsPayload: dnsPayload)
if let dnsLocalClient = self.dnsLocalClient {
await dnsLocalClient.query(tracker: tracker, dnsPayload: dnsPayload)
}
}
}
}

View File

@ -1,182 +1,265 @@
import Foundation
import Network
final class DNSLocalClient {
// DNS
actor DNSLocalClient {
struct DNSTracker {
let transactionID: UInt16
let clientIP: UInt32 // IP ()
let clientPort: UInt16 // ()
let createdAt: Date //
let clientIP: UInt32
let clientPort: UInt16
let createdAt: Date
}
private var connections: [NWConnection] = []
private struct PendingRequest {
let tracker: DNSTracker
}
// +
private enum State {
case idle
case running
case stopped
}
private var state: State = .idle
private var connections: [NWConnection] = []
private let dnsServers = ["223.5.5.5", "119.29.29.29"]
public let packetFlow: AsyncStream<Data>
let packetFlow: AsyncStream<Data>
private let packetContinuation: AsyncStream<Data>.Continuation
private let locker = NSLock()
private var trackers: [UInt16: [DNSTracker]] = [:]
private var pendingRequests: [UInt16: PendingRequest] = [:]
private var nextTransactionID: UInt16 = 1
//
private var cleanupTask: Task<Void, Never>?
private let timeoutInterval: TimeInterval = 10.0 // 10
private let timeoutInterval: TimeInterval = 3.0
private var didFinishPacketFlow = false
init() {
let (stream, continuation) = AsyncStream.makeStream(of: Data.self, bufferingPolicy: .unbounded)
self.packetFlow = stream
self.packetContinuation = continuation
}
func start() {
for server in dnsServers {
guard case .idle = self.state else {
return
}
self.state = .running
for server in self.dnsServers {
let endpoint = NWEndpoint.hostPort(host: NWEndpoint.Host(server), port: 53)
let parameters = NWParameters.udp
parameters.prohibitedInterfaceTypes = [.other]
let conn = NWConnection(to: endpoint, using: parameters)
conn.stateUpdateHandler = { [weak self] state in
switch state {
case .ready:
self?.receiveLoop(for: conn)
case .failed(let error):
SDLLogger.log("[DNSLocalClient] failed with error: \(error.localizedDescription)", for: .debug)
self?.stop()
case .cancelled:
self?.packetContinuation.finish()
default:
()
Task {
await self?.handleConnectionStateUpdate(state, for: conn)
}
}
conn.start(queue: .global())
connections.append(conn)
self.connections.append(conn)
}
//
self.cleanupTask = Task { [weak self] in
while !Task.isCancelled {
// cleanupTick
try? await Task.sleep(nanoseconds: 5 * 1_000_000_000)
self?.performCleanup()
}
}
}
/// 广
func query(tracker: DNSTracker, dnsPayload: Data) {
locker.lock()
self.trackers[tracker.transactionID, default: []].append(tracker)
locker.unlock()
for conn in connections where conn.state == .ready {
conn.send(content: dnsPayload, completion: .contentProcessed({ _ in }))
}
}
private func receiveLoop(for conn: NWConnection) {
conn.receiveMessage { [weak self] content, _, _, error in
if let data = content {
// AsyncStream
// yield
//
self?.handleResponse(data: data)
}
if error == nil && conn.state == .ready {
self?.receiveLoop(for: conn)
try? await Task.sleep(nanoseconds: 3 * 1_000_000_000)
await self?.performCleanup()
}
}
}
private func handleResponse(data: Data) {
guard data.count > 2 else {
func query(tracker: DNSTracker, dnsPayload: Data) {
guard case .running = self.state, dnsPayload.count >= 2 else {
return
}
let tranId = UInt16(data[0]) << 8 | UInt16(data[1])
locker.lock()
let items = self.trackers.removeValue(forKey: tranId)
locker.unlock()
items?.forEach { tracker in
let packet = Self.createDNSResponse(
payload: data,
srcIP: DNSHelper.dnsDestIpAddr,
srcPort: 53,
destIP: tracker.clientIP,
destPort: tracker.clientPort
)
self.packetContinuation.yield(packet)
}
}
private func performCleanup() {
locker.lock()
defer {
locker.unlock()
guard let transactionID = self.allocateTransactionID() else {
SDLLogger.log("[DNSLocalClient] no available transaction id", for: .debug)
return
}
// ID tracker
let now = Date()
for (id, list) in trackers {
let validItems = list.filter { now.timeIntervalSince($0.createdAt) < timeoutInterval }
if validItems.isEmpty {
trackers.removeValue(forKey: id)
} else {
trackers[id] = validItems
}
self.pendingRequests[transactionID] = PendingRequest(tracker: tracker)
let rewrittenPayload = Self.rewriteTransactionID(in: dnsPayload, to: transactionID)
var hasReadyConnection = false
for conn in self.connections where conn.state == .ready {
hasReadyConnection = true
conn.send(content: rewrittenPayload, completion: .contentProcessed({ error in
if let error {
SDLLogger.log("[DNSLocalClient] send error: \(error.localizedDescription)", for: .debug)
}
}))
}
if !hasReadyConnection {
self.pendingRequests.removeValue(forKey: transactionID)
}
}
func stop() {
connections.forEach { conn in
conn.cancel()
guard self.state != .stopped else {
return
}
self.state = .stopped
self.connections.forEach { $0.cancel() }
self.connections.removeAll()
self.cleanupTask?.cancel()
self.cleanupTask = nil
self.pendingRequests.removeAll()
self.nextTransactionID = 1
self.finishPacketFlowIfNeeded()
}
private func handleConnectionStateUpdate(_ state: NWConnection.State, for conn: NWConnection) {
guard case .running = self.state else {
return
}
switch state {
case .ready:
self.receiveLoop(for: conn)
case .failed(let error):
SDLLogger.log("[DNSLocalClient] failed with error: \(error.localizedDescription)", for: .debug)
self.stop()
case .cancelled:
self.connections.removeAll { $0 === conn }
if self.connections.isEmpty {
self.stop()
}
default:
()
}
}
private func receiveLoop(for conn: NWConnection) {
conn.receiveMessage { [weak self] content, _, _, error in
Task {
await self?.handleReceive(content: content, error: error, for: conn)
}
}
}
private func handleReceive(content: Data?, error: NWError?, for conn: NWConnection) {
guard case .running = self.state else {
return
}
if let data = content {
self.handleResponse(data: data)
}
if error == nil && conn.state == .ready {
self.receiveLoop(for: conn)
}
}
private func handleResponse(data: Data) {
guard case .running = self.state,
let rewrittenTransactionID = Self.readTransactionID(from: data),
let pendingRequest = self.pendingRequests.removeValue(forKey: rewrittenTransactionID) else {
return
}
let restoredPayload = Self.rewriteTransactionID(in: data, to: pendingRequest.tracker.transactionID)
let packet = Self.createDNSResponse(
payload: restoredPayload,
srcIP: DNSHelper.dnsDestIpAddr,
srcPort: 53,
destIP: pendingRequest.tracker.clientIP,
destPort: pendingRequest.tracker.clientPort
)
self.packetContinuation.yield(packet)
}
private func performCleanup() {
guard case .running = self.state else {
return
}
let now = Date()
self.pendingRequests = self.pendingRequests.filter { _, request in
now.timeIntervalSince(request.tracker.createdAt) < self.timeoutInterval
}
}
private func allocateTransactionID() -> UInt16? {
var candidate = self.nextTransactionID == 0 ? 1 : self.nextTransactionID
let start = candidate
repeat {
if self.pendingRequests[candidate] == nil {
self.nextTransactionID = Self.nextTransactionID(after: candidate)
return candidate
}
candidate = Self.nextTransactionID(after: candidate)
} while candidate != start
return nil
}
private func finishPacketFlowIfNeeded() {
guard !self.didFinishPacketFlow else {
return
}
self.didFinishPacketFlow = true
self.packetContinuation.finish()
}
private static func nextTransactionID(after id: UInt16) -> UInt16 {
return id == UInt16.max ? 1 : id &+ 1
}
private static func readTransactionID(from payload: Data) -> UInt16? {
guard payload.count >= 2 else {
return nil
}
return UInt16(payload[0]) << 8 | UInt16(payload[1])
}
private static func rewriteTransactionID(in payload: Data, to transactionID: UInt16) -> Data {
guard payload.count >= 2 else {
return payload
}
var rewrittenPayload = payload
rewrittenPayload[0] = UInt8((transactionID >> 8) & 0xFF)
rewrittenPayload[1] = UInt8(transactionID & 0xFF)
return rewrittenPayload
}
}
extension DNSLocalClient {
/// TUN UDP/IPv4
static func createDNSResponse(payload: Data, srcIP: UInt32, srcPort: UInt16, destIP: UInt32, destPort: UInt16) -> Data {
let udpLen = 8 + payload.count
let ipLen = 20 + udpLen
// --- 1. IPv4 Header (20 ) ---
var ipHeader = Data(count: 20)
ipHeader[0] = 0x45 // Version 4, IHL 5
ipHeader[0] = 0x45
ipHeader[2...3] = withUnsafeBytes(of: UInt16(ipLen).bigEndian) { Data($0) }
ipHeader[8] = 64 // TTL
ipHeader[9] = 17 // Protocol UDP
ipHeader[8] = 64
ipHeader[9] = 17
// IP
ipHeader[12...15] = withUnsafeBytes(of: srcIP.bigEndian) { Data($0) }
ipHeader[16...19] = withUnsafeBytes(of: destIP.bigEndian) { Data($0) }
// IP Checksum
let ipChecksum = calculateChecksum(data: ipHeader)
ipHeader[10...11] = withUnsafeBytes(of: ipChecksum.bigEndian) { Data($0) }
// --- 2. UDP Header (8 ) ---
var udpHeader = Data(count: 8)
udpHeader[0...1] = withUnsafeBytes(of: srcPort.bigEndian) { Data($0) }
udpHeader[2...3] = withUnsafeBytes(of: destPort.bigEndian) { Data($0) }
udpHeader[4...5] = withUnsafeBytes(of: UInt16(udpLen).bigEndian) { Data($0) }
// UDP Checksum IPv4 0
udpHeader[6...7] = Data([0, 0])
// --- 3. ---
var packet = Data(capacity: ipLen)
packet.append(ipHeader)
packet.append(udpHeader)
@ -185,7 +268,6 @@ extension DNSLocalClient {
return packet
}
/// Internet Checksum
static func calculateChecksum(data: Data) -> UInt16 {
var sum: UInt32 = 0
let count = data.count
@ -193,30 +275,23 @@ extension DNSLocalClient {
data.withUnsafeBytes { (ptr: UnsafeRawBufferPointer) in
guard let baseAddress = ptr.baseAddress else { return }
// 1. 16-bit
let wordCount = count / 2
let words = baseAddress.bindMemory(to: UInt16.self, capacity: wordCount)
for i in 0..<wordCount {
// 使 bigEndian: words[i]
sum += UInt32(UInt16(bigEndian: words[i]))
}
// 2.
if count % 2 != 0 {
// 8 16-bit
let lastByte = ptr[count - 1]
sum += UInt32(lastByte) << 8
}
}
// 3. 16
while (sum >> 16) != 0 {
sum = (sum & 0xffff) + (sum >> 16)
}
return UInt16(~sum & 0xffff)
}
}