59 lines
2.1 KiB
Swift
59 lines
2.1 KiB
Swift
import Foundation
|
||
import Network
|
||
|
||
/// Forwards raw DNS payloads to several public DNS servers over UDP and
/// publishes every response received from any of them on `payloadFlow`.
///
/// Lifecycle: call `start()` to open one UDP connection per server, then
/// `query(dnsPayload:)` to broadcast a request. When the client is
/// deallocated all connections are cancelled and `payloadFlow` finishes.
final class SDLLocalDNSClient {
    /// One connection per entry in `dnsServers`, created by `start()`.
    private var connections: [NWConnection] = []
    // NOTE(review): stored but not used anywhere in this class.
    private let logger: SDLLogger

    // Public DNS servers that every query is broadcast to.
    private let dnsServers = ["114.114.114.114", "223.5.5.5", "8.8.8.8"]

    /// Stream of raw DNS response payloads, in the order they arrive.
    public let payloadFlow: AsyncStream<Data>
    private let payloadContinuation: AsyncStream<Data>.Continuation

    /// Creates a client; no network activity happens until `start()`.
    /// - Parameter logger: Logger retained for the lifetime of the client.
    init(logger: SDLLogger) {
        self.logger = logger
        let (stream, continuation) = AsyncStream.makeStream(of: Data.self, bufferingPolicy: .unbounded)
        self.payloadFlow = stream
        self.payloadContinuation = continuation
    }

    deinit {
        // Release the UDP sockets and let `for await` consumers of
        // `payloadFlow` terminate instead of waiting forever.
        cancelAllConnections()
        payloadContinuation.finish()
    }

    /// Opens one UDP connection (port 53) to each configured DNS server.
    ///
    /// Calling `start()` again replaces the previous set of connections
    /// instead of adding a second set, so a query is never sent twice to the
    /// same server and old sockets are not leaked.
    func start() {
        cancelAllConnections()

        for server in dnsServers {
            let endpoint = NWEndpoint.hostPort(host: NWEndpoint.Host(server), port: 53)
            let parameters = NWParameters.udp
            // NOTE(review): presumably keeps queries off virtual/tunnel
            // interfaces so they do not loop back into this process — confirm.
            parameters.prohibitedInterfaceTypes = [.other]

            let conn = NWConnection(to: endpoint, using: parameters)

            // `conn` is captured weakly: the connection owns this handler, so
            // a strong capture would make the connection retain itself.
            conn.stateUpdateHandler = { [weak self, weak conn] state in
                guard let self, let conn else { return }
                switch state {
                case .ready:
                    self.receiveLoop(for: conn)
                case .failed:
                    // A failed connection will not recover; cancel it so its
                    // resources are released.
                    conn.cancel()
                default:
                    break
                }
            }
            conn.start(queue: .global())
            connections.append(conn)
        }
    }

    /// Concurrent query: broadcasts `dnsPayload` to every server whose
    /// connection is currently ready. Connections in any other state are
    /// skipped.
    /// - Parameter dnsPayload: Raw DNS request bytes to send as one datagram.
    func query(dnsPayload: Data) {
        for conn in connections where conn.state == .ready {
            // Best-effort send: failures are intentionally ignored because the
            // same payload is also sent to the other servers.
            conn.send(content: dnsPayload, completion: .contentProcessed({ _ in }))
        }
    }

    /// Receives datagrams from `conn` one at a time, yielding each to
    /// `payloadFlow`, until a receive reports an error, the connection leaves
    /// the ready state, or the client is deallocated.
    private func receiveLoop(for conn: NWConnection) {
        // `conn` is held strongly only while this receive is pending; the
        // completion is released once it has run.
        conn.receiveMessage { [weak self] content, _, _, error in
            guard let self else { return }

            if let content {
                // Core idea: AsyncStream delivers in yield order, so whichever
                // server answers first reaches the upper layer first. Once the
                // upper layer hands that first valid response back to the
                // system, later duplicate responses are ignored by the system
                // protocol stack because the transaction ID is no longer valid.
                self.payloadContinuation.yield(content)
            }

            guard error == nil, conn.state == .ready else { return }
            self.receiveLoop(for: conn)
        }
    }

    /// Cancels every open connection and forgets it.
    private func cancelAllConnections() {
        for conn in connections {
            // Drop the handler first so cancellation triggers no callbacks.
            conn.stateUpdateHandler = nil
            conn.cancel()
        }
        connections.removeAll()
    }
}
|