diff --git a/Tun/Punchnet/Actors/SDLContextActor.swift b/Tun/Punchnet/Actors/SDLContextActor.swift index ef1569d..320d1e0 100644 --- a/Tun/Punchnet/Actors/SDLContextActor.swift +++ b/Tun/Punchnet/Actors/SDLContextActor.swift @@ -121,6 +121,7 @@ actor SDLContextActor { self.startMonitor() // 启动arp的定时清理任务 + await self.puncherActor.start() await self.arpServer.start() await self.startDnsClient() await self.startDnsLocalClient() @@ -369,6 +370,7 @@ actor SDLContextActor { self.state = .unregistered await self.supervisor.stop() + await self.puncherActor.stop() self.udpHoleWorkers?.forEach { $0.cancel() } self.udpHoleWorkers = nil diff --git a/Tun/Punchnet/Actors/SDLPuncherActor.swift b/Tun/Punchnet/Actors/SDLPuncherActor.swift index 3e1367e..9a8c3c4 100644 --- a/Tun/Punchnet/Actors/SDLPuncherActor.swift +++ b/Tun/Punchnet/Actors/SDLPuncherActor.swift @@ -11,18 +11,8 @@ import NIOCore actor SDLPuncherActor { // 10秒内只需要提交一次查询 nonisolated private let cooldownInterval: TimeInterval = 10 - - struct RequestContext { - let expireAt: Date - let request: RegisterRequest - - func isExpired() -> Bool { - return expireAt < Date() - } - } - - // dstMac - private var pendingRequests: [Data: RequestContext] = [:] + // 等待peerInfo返回的超时时间 + nonisolated private let peerInfoTimeout: TimeInterval = 3 struct RegisterRequest { let srcMac: Data @@ -30,57 +20,137 @@ actor SDLPuncherActor { let networkId: UInt32 } + private enum RequestPhase { + case waitingPeerInfo(deadline: Date) + case coolingDown + } + + private struct RequestEntry { + let request: RegisterRequest + let cooldownUntil: Date + var phase: RequestPhase + + func canSubmit(at now: Date) -> Bool { + return cooldownUntil <= now + } + + func isWaitingPeerInfo(at now: Date) -> Bool { + guard case .waitingPeerInfo(let deadline) = self.phase else { + return false + } + + return deadline > now + } + + mutating func markCoolingDown() { + self.phase = .coolingDown + } + } + + // dstMac + private var requestEntries: [Data: RequestEntry] = [:] + private var cleanupTask: Task? + + func start() { + guard self.cleanupTask == nil else { + return + } + + self.cleanupTask = Task { [weak self] in + while !Task.isCancelled { + try? await Task.sleep(for: .seconds(1)) + await self?.cleanupExpiredEntries() + } + } + } + func submitRegisterRequest(quicClient: SDLQUICClient?, request: RegisterRequest) { guard let quicClient else { return } - // 数据不存在,或者已经过期;才能提交 - let dstMac = request.dstMac - guard self.isRequestExpired(dstMac: dstMac) else { + let now = Date() + self.cleanupExpiredEntries(now: now) + + if let entry = self.requestEntries[request.dstMac], !entry.canSubmit(at: now) { return } - self.pendingRequests[dstMac] = .init(expireAt: Date().addingTimeInterval(cooldownInterval), request: request) - // 触发一次打洞 var queryInfo = SDLQueryInfo() queryInfo.dstMac = request.dstMac - if let queryData = try? queryInfo.serializedData() { - quicClient.send(type: .queryInfo, data: queryData) + + guard let queryData = try? queryInfo.serializedData() else { + SDLLogger.log("[SDLPuncherActor] failed to encode queryInfo", for: .debug) + return } + + self.requestEntries[request.dstMac] = RequestEntry( + request: request, + cooldownUntil: now.addingTimeInterval(self.cooldownInterval), + phase: .waitingPeerInfo(deadline: now.addingTimeInterval(self.peerInfoTimeout)) + ) + + quicClient.send(type: .queryInfo, data: queryData) } func handlePeerInfo(using udpHole: SDLUDPHole?, peerInfo: SDLPeerInfo) async { - // 如果服务器返回了值,优先删除掉; 避免数据堆积 - guard let requestContext = pendingRequests.removeValue(forKey: peerInfo.dstMac) else { + let now = Date() + self.cleanupExpiredEntries(now: now) + + guard var entry = self.requestEntries[peerInfo.dstMac] else { return } - // 判断必要的值是否存在 - guard let udpHole, peerInfo.hasV4Info else { + guard entry.isWaitingPeerInfo(at: now) else { return } - if let remoteAddress = try? await peerInfo.v4Info.socketAddress() { - SDLLogger.log("[SDLContext] hole sock address: \(remoteAddress)", for: .punchnet) - // 发送register包 - var register = SDLRegister() - register.networkID = requestContext.request.networkId - register.srcMac = requestContext.request.srcMac - register.dstMac = requestContext.request.dstMac - - if let registerData = try? register.serializedData() { - udpHole.send(type: .register, data: registerData, remoteAddress: remoteAddress) - } + entry.markCoolingDown() + self.requestEntries[peerInfo.dstMac] = entry + + guard let udpHole else { + SDLLogger.log("[SDLPuncherActor] udpHole is nil when peerInfo arrived", for: .debug) + return + } + + guard peerInfo.hasV4Info else { + SDLLogger.log("[SDLPuncherActor] peerInfo missing v4Info", for: .debug) + return + } + + guard let remoteAddress = try? await peerInfo.v4Info.socketAddress() else { + SDLLogger.log("[SDLPuncherActor] failed to resolve peerInfo.v4Info", for: .debug) + return + } + + SDLLogger.log("[SDLContext] hole sock address: \(remoteAddress)", for: .punchnet) + + var register = SDLRegister() + register.networkID = entry.request.networkId + register.srcMac = entry.request.srcMac + register.dstMac = entry.request.dstMac + + guard let registerData = try? register.serializedData() else { + SDLLogger.log("[SDLPuncherActor] failed to encode register", for: .debug) + return + } + + udpHole.send(type: .register, data: registerData, remoteAddress: remoteAddress) + } + + func stop() { + self.cleanupTask?.cancel() + self.cleanupTask = nil + self.requestEntries.removeAll() + } + + private func cleanupExpiredEntries(now: Date = Date()) { + self.requestEntries = self.requestEntries.filter { _, entry in + !entry.canSubmit(at: now) } } - // 判断是否需要提交 - func isRequestExpired(dstMac: Data) -> Bool { - if let context = pendingRequests[dstMac] { - return context.isExpired() - } - return true + deinit { + self.cleanupTask?.cancel() } - } diff --git a/Tun/Punchnet/DNS/DNSLocalClient.swift b/Tun/Punchnet/DNS/DNSLocalClient.swift index a0e052c..6aef9b6 100644 --- a/Tun/Punchnet/DNS/DNSLocalClient.swift +++ b/Tun/Punchnet/DNS/DNSLocalClient.swift @@ -149,15 +149,14 @@ actor DNSLocalClient { let stream = Self.makeReceiveStream(for: conn) self.receiveTasks[key] = Task { [weak self] in - guard let self else { - return - } - for await data in stream { + guard let self else { + break + } await self.handleResponse(data: data) } - await self.didFinishReceiving(for: conn) + await self?.didFinishReceiving(for: conn) } }