punchnet-macos/Tun/Punchnet/DNS/DNSLocalClient.swift

188 lines
6.4 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import Foundation
import Network
final class DNSLocalClient {
    /// Bookkeeping for one in-flight DNS query forwarded on behalf of a client.
    ///
    /// Trackers are grouped by `transactionID` so an upstream response can be routed
    /// back to the client address/port that asked for it.
    struct DNSTracker {
        /// DNS transaction ID copied from the client's query header.
        let transactionID: UInt16
        /// Client IPv4 address as a numeric value (it is written with `.bigEndian` when the response packet is built).
        let clientIP: UInt32
        /// UDP source port of the querying client.
        let clientPort: UInt16
        /// When the query was registered; trackers older than `trackerTimeout` are discarded.
        let createdAt: Date
    }

    /// Fixed size of a DNS message header (RFC 1035 §4.1.1); anything shorter cannot be a valid response.
    private static let dnsHeaderLength = 12

    /// Upstream resolvers that every query is broadcast to.
    private let dnsServers = ["114.114.114.114", "223.5.5.5", "8.8.8.8"]

    /// IPv4/UDP response packets, built by `createDNSResponse`, ready to be written back to the tunnel.
    public let packetFlow: AsyncStream<Data>
    private let packetContinuation: AsyncStream<Data>.Continuation

    /// Seconds an unanswered query is remembered before its tracker is dropped.
    private let trackerTimeout: TimeInterval

    /// Guards `connections` and `trackers`, which are accessed both from callers and from
    /// Network callbacks running on the global queue.
    private let locker = NSLock()
    private var connections: [NWConnection] = []
    private var trackers: [UInt16: [DNSTracker]] = [:]

    /// - Parameter trackerTimeout: How long (in seconds) to wait for an upstream answer before
    ///   forgetting a query. Defaults to 10.
    init(trackerTimeout: TimeInterval = 10) {
        let (stream, continuation) = AsyncStream.makeStream(of: Data.self, bufferingPolicy: .unbounded)
        self.packetFlow = stream
        self.packetContinuation = continuation
        self.trackerTimeout = trackerTimeout
    }

    /// Opens one UDP connection per upstream resolver and starts receiving on each once it is ready.
    func start() {
        for server in dnsServers {
            let endpoint = NWEndpoint.hostPort(host: NWEndpoint.Host(server), port: 53)
            let parameters = NWParameters.udp
            // NOTE(review): excluding `.other` presumably keeps these sockets off the utun/TUN
            // interface so forwarded queries do not loop back into the tunnel — confirm.
            parameters.prohibitedInterfaceTypes = [.other]

            let conn = NWConnection(to: endpoint, using: parameters)
            conn.stateUpdateHandler = { [weak self] state in
                switch state {
                case .ready:
                    self?.receiveLoop(for: conn)
                case .failed(let error):
                    SDLLogger.shared.log("[DNSLocalClient] failed with error: \(error.localizedDescription)")
                    // A failure on any one upstream tears down every connection.
                    self?.stop()
                case .cancelled:
                    self?.packetContinuation.finish()
                default:
                    break
                }
            }

            // Register before starting so that an immediate `.failed` -> `stop()` also cancels this one.
            locker.lock()
            connections.append(conn)
            locker.unlock()
            conn.start(queue: .global())
        }
    }

    /// Registers `tracker` and broadcasts `dnsPayload` to every upstream connection currently `.ready`.
    ///
    /// The first response that carries the tracker's transaction ID is delivered to every tracker
    /// registered under that ID. Later duplicates, from the other resolvers, find no trackers and are dropped.
    func query(tracker: DNSTracker, dnsPayload: Data) {
        locker.lock()
        pruneExpiredTrackers(now: Date())
        trackers[tracker.transactionID, default: []].append(tracker)
        let snapshot = connections
        locker.unlock()

        for conn in snapshot where conn.state == .ready {
            conn.send(content: dnsPayload, completion: .contentProcessed({ _ in }))
        }
    }

    /// Removes trackers older than `trackerTimeout` so unanswered queries do not accumulate forever.
    /// Must be called with `locker` held.
    private func pruneExpiredTrackers(now: Date) {
        guard !trackers.isEmpty else { return }
        let cutoff = now.addingTimeInterval(-trackerTimeout)
        trackers = trackers.compactMapValues { pending in
            let live = pending.filter { $0.createdAt >= cutoff }
            return live.isEmpty ? nil : live
        }
    }

    /// Receives one datagram at a time from `conn`, re-arming itself while there is no error and the
    /// connection is still `.ready`.
    private func receiveLoop(for conn: NWConnection) {
        conn.receiveMessage { [weak self] content, _, _, error in
            guard let self else { return }
            if let data = content {
                self.handleResponse(data: data)
            }
            if error == nil && conn.state == .ready {
                self.receiveLoop(for: conn)
            }
        }
    }

    /// Wraps an upstream DNS response in an IPv4/UDP packet for each client waiting on its transaction
    /// ID and yields those packets to `packetFlow`.
    private func handleResponse(data: Data) {
        guard data.count >= Self.dnsHeaderLength else {
            return
        }
        // Index from startIndex: a `Data` value may be a slice whose indices do not begin at 0.
        let base = data.startIndex
        let tranId = UInt16(data[base]) << 8 | UInt16(data[base + 1])

        locker.lock()
        let items = trackers.removeValue(forKey: tranId)
        locker.unlock()

        items?.forEach { tracker in
            let packet = Self.createDNSResponse(
                payload: data,
                srcIP: DNSHelper.dnsDestIpAddr,
                srcPort: 53,
                destIP: tracker.clientIP,
                destPort: tracker.clientPort
            )
            packetContinuation.yield(packet)
        }
    }

    /// Cancels every upstream connection. Through the `.cancelled` case installed in `start()`, each
    /// cancellation finishes `packetFlow`.
    func stop() {
        locker.lock()
        let active = connections
        connections.removeAll()
        locker.unlock()
        // Cancel outside the lock so state callbacks can never contend with us.
        active.forEach { $0.cancel() }
    }
}
extension DNSLocalClient {
    /// Builds an IPv4/UDP packet carrying `payload`, for writing back to the TUN interface.
    ///
    /// - Parameters:
    ///   - payload: The raw DNS message to carry.
    ///   - srcIP: Source IPv4 address as a numeric value (first octet in the most significant byte).
    ///   - srcPort: UDP source port.
    ///   - destIP: Destination IPv4 address, same representation as `srcIP`.
    ///   - destPort: UDP destination port.
    /// - Returns: A 20-byte IPv4 header (no options), an 8-byte UDP header, then `payload`.
    /// - Precondition: The whole datagram must fit the 16-bit IPv4 total-length field.
    static func createDNSResponse(payload: Data, srcIP: UInt32, srcPort: UInt16, destIP: UInt32, destPort: UInt16) -> Data {
        let udpLen = 8 + payload.count
        let ipLen = 20 + udpLen
        precondition(ipLen <= Int(UInt16.max), "DNS payload too large for a single IPv4 datagram")

        // --- 1. IPv4 header (20 bytes; identification/flags/fragment offset left zero) ---
        var ipHeader = Data(count: 20)
        ipHeader[0] = 0x45                                  // Version 4, IHL 5 (no options)
        ipHeader[2...3] = networkBytes(UInt16(ipLen))       // Total length
        ipHeader[8] = 64                                    // TTL
        ipHeader[9] = 17                                    // Protocol: UDP
        ipHeader[12...15] = networkBytes(srcIP)
        ipHeader[16...19] = networkBytes(destIP)
        // The checksum is computed while its own field (bytes 10...11) is still zero.
        ipHeader[10...11] = networkBytes(calculateChecksum(data: ipHeader))

        // --- 2. UDP header (8 bytes) ---
        var udpHeader = Data(count: 8)
        udpHeader[0...1] = networkBytes(srcPort)
        udpHeader[2...3] = networkBytes(destPort)
        udpHeader[4...5] = networkBytes(UInt16(udpLen))
        // Bytes 6...7 (checksum) stay zero: the UDP checksum is optional over IPv4 (RFC 768).

        // --- 3. Assemble ---
        var packet = Data(capacity: ipLen)
        packet.append(ipHeader)
        packet.append(udpHeader)
        packet.append(payload)
        return packet
    }

    /// Computes the RFC 1071 Internet checksum: the one's complement of the one's-complement sum of
    /// the data read as 16-bit big-endian words.
    ///
    /// An odd trailing byte is treated as the high byte of a final word padded with zero.
    /// Reads the bytes one at a time, so there are no alignment assumptions and slices of any offset work.
    static func calculateChecksum(data: Data) -> UInt16 {
        // A 64-bit accumulator cannot overflow for any realistic input (a 32-bit one traps past ~128 KiB).
        var sum: UInt64 = 0
        var bytes = data.makeIterator()
        while let high = bytes.next() {
            let low = bytes.next() ?? 0
            sum += UInt64(high) << 8 | UInt64(low)
        }
        // Fold the carries back into the low 16 bits.
        while sum >> 16 != 0 {
            sum = (sum & 0xffff) + (sum >> 16)
        }
        return ~UInt16(sum)
    }

    /// Returns `value` as big-endian (network byte order) bytes.
    private static func networkBytes<T: FixedWidthInteger>(_ value: T) -> Data {
        withUnsafeBytes(of: value.bigEndian) { Data($0) }
    }
}