From cb33d8142875eefc260ff66898759cfbc6a732b4 Mon Sep 17 00:00:00 2001 From: anlicheng <244108715@qq.com> Date: Thu, 12 Mar 2026 17:00:07 +0800 Subject: [PATCH] fix arpServer --- Tun/Punchnet/Actors/SDLContextActor.swift | 15 ++++-- Tun/Punchnet/ArpServer.swift | 61 ++++++++++++----------- Tun/Punchnet/SDLMessage.swift | 7 ++- 3 files changed, 47 insertions(+), 36 deletions(-) diff --git a/Tun/Punchnet/Actors/SDLContextActor.swift b/Tun/Punchnet/Actors/SDLContextActor.swift index 41c4b25..4b53add 100644 --- a/Tun/Punchnet/Actors/SDLContextActor.swift +++ b/Tun/Punchnet/Actors/SDLContextActor.swift @@ -56,7 +56,7 @@ actor SDLContextActor { private var readTask: Task<(), Never>? nonisolated private let sessionManager = SessionManager() - nonisolated private let arpServer = ArpServer() + nonisolated private let arpServer: ArpServer // 网络状态变化的健康 private var monitor: SDLNetworkMonitor? @@ -93,6 +93,8 @@ actor SDLContextActor { self.puncherActor = SDLPuncherActor() self.proberActor = SDLNATProberActor(addressArray: config.stunProbeSocketAddressArray) + self.arpServer = ArpServer() + // 权限控制 let snapshotPublisher = SnapshotPublisher(initial: IdentitySnapshot.empty()) self.identifyStore = IdentityStore(publisher: snapshotPublisher) @@ -102,6 +104,9 @@ actor SDLContextActor { public func start() async { self.startMonitor() + // 启动arp的定时清理任务 + await self.arpServer.start() + // 启动puncher的定期扫描任务 await self.puncherActor.start() @@ -172,7 +177,7 @@ actor SDLContextActor { await self.identifyStore.applyPolicyResponse(policyResponse) case .arpResponse(let arpResponse): SDLLogger.shared.log("[SDLContext] get arp response: \(arpResponse)") - self.arpServer.handleArpResponse(arpResponse: arpResponse) + await self.arpServer.handleArpResponse(arpResponse: arpResponse) } } } @@ -536,7 +541,7 @@ actor SDLContextActor { await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal()) case .response: SDLLogger.shared.log("[SDLContext] get arp response packet", level: .debug) - self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC) + await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC) } } else { SDLLogger.shared.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))", level: .debug) @@ -660,7 +665,7 @@ actor SDLContextActor { } // 查找arp缓存中是否有目标mac地址 - if let dstMac = self.arpServer.query(ip: dstIp) { + if let dstMac = await self.arpServer.query(ip: dstIp) { SDLLogger.shared.log("[SDLContext] dstIp: \(dstIp.asIpAddress()), dst_mac is: \(SDLUtil.formatMacAddress(mac: dstMac))", level: .debug) await self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data) } @@ -670,7 +675,7 @@ actor SDLContextActor { // let arpReqeust = ARPPacket.arpRequest(senderIP: networkAddr.ip, senderMAC: networkAddr.mac, targetIP: dstIp) // await self.routeLayerPacket(dstMac: ARPPacket.broadcastMac , type: .arp, data: arpReqeust.marshal()) - try? self.arpServer.arpRequest(targetIp: dstIp, use: self.quicClient) + try? await self.arpServer.arpRequest(targetIp: dstIp, use: self.quicClient) } } diff --git a/Tun/Punchnet/ArpServer.swift b/Tun/Punchnet/ArpServer.swift index 2334f06..e0a41ee 100644 --- a/Tun/Punchnet/ArpServer.swift +++ b/Tun/Punchnet/ArpServer.swift @@ -1,35 +1,46 @@ // // ArpServer.swift // sdlan -// +// 1. 通过ip地址查找mac地址 +// 2. 要限制单位时间内,同一个ip的查询 // Created by 安礼成 on 2025/7/14. // import Foundation import Darwin -final class ArpServer { +actor ArpServer { // 增加缓存时间逻辑 struct ArpEntry { var mac: Data var expireTime: TimeInterval } - private let locker = NSLock() + private var coolingDown: [UInt32: Date] = [:] private var packetId: UInt32 = 1 private var known_macs: [UInt32: ArpEntry] = [:] private let arpTTL: TimeInterval + private var cleanupTask: Task? + init(arpTTL: TimeInterval = 300) { self.arpTTL = arpTTL } - func query(ip: UInt32) -> Data? { - locker.lock() - defer { - locker.unlock() + func start() { + guard self.cleanupTask == nil else { + return } + self.cleanupTask = Task { + while !Task.isCancelled { + try? await Task.sleep(for: .seconds(1)) + self.cleanup() + } + } + } + + func query(ip: UInt32) -> Data? { guard let entry = known_macs[ip] else { return nil } @@ -43,47 +54,32 @@ final class ArpServer { } func append(ip: UInt32, mac: Data) { - locker.lock() - defer { - locker.unlock() - } - let expireAt = Date().timeIntervalSince1970 + arpTTL self.known_macs[ip] = ArpEntry(mac: mac, expireTime: expireAt) } func remove(ip: UInt32) { - locker.lock() - defer { - locker.unlock() - } - self.known_macs.removeValue(forKey: ip) } func dropMacs(macs: [Data]) { - locker.lock() - defer { - locker.unlock() - } self.known_macs = self.known_macs.filter { !macs.contains($0.value.mac) } } func clear() { - locker.lock() self.known_macs = [:] - locker.unlock() } func arpRequest(targetIp: UInt32, use quicClient: SDLQUICClient?) throws { - guard let quicClient else { + guard let quicClient, self.coolingDown[targetIp] == nil else { return } - locker.lock() let pktID = self.packetId - self.packetId += 1 - locker.unlock() + self.packetId &+= 1 + + // 单位时间内指允许提交一次 + self.coolingDown[targetIp] = Date().addingTimeInterval(3) // 进行arp查询 var arpRequest = SDLArpRequest() @@ -98,10 +94,17 @@ final class ArpServer { let targetMac = arpResponse.targetMac if !targetMac.isEmpty { let expireAt = Date().timeIntervalSince1970 + arpTTL - locker.lock() self.known_macs[targetIp] = ArpEntry(mac: targetMac, expireTime: expireAt) - locker.unlock() } } + private func cleanup() { + let now = Date() + self.coolingDown = self.coolingDown.filter { $0.value > now } + } + + deinit { + self.cleanupTask?.cancel() + } + } diff --git a/Tun/Punchnet/SDLMessage.swift b/Tun/Punchnet/SDLMessage.swift index 416ddfb..ec4a3b9 100644 --- a/Tun/Punchnet/SDLMessage.swift +++ b/Tun/Punchnet/SDLMessage.swift @@ -87,9 +87,12 @@ enum SDLNAKErrorCode: UInt8 { } extension SDLV4Info { - func socketAddress() async throws -> SocketAddress { - let address = "\(v4[0]).\(v4[1]).\(v4[2]).\(v4[3])" + func socketAddress() async throws -> SocketAddress? { + guard self.v4.count == 4 else { + return nil + } + let address = "\(v4[0]).\(v4[1]).\(v4[2]).\(v4[3])" return try await SDLAddressResolver.shared.resolve(host: address, port: Int(port)) } }