diff --git a/Sources/Punchnet/HolerManager.swift b/Sources/Punchnet/HolerManager.swift deleted file mode 100644 index e37ab86..0000000 --- a/Sources/Punchnet/HolerManager.swift +++ /dev/null @@ -1,31 +0,0 @@ -// -// HolerManager.swift -// sdlan -// -// Created by 安礼成 on 2025/7/14. -// -import Foundation - -actor HolerManager { - private var holers: [Data:Task<(), Never>] = [:] - - func addHoler(dstMac: Data, creator: @escaping () -> Task<(), Never>) { - if let task = self.holers[dstMac] { - if task.isCancelled { - self.holers[dstMac] = creator() - } - } else { - self.holers[dstMac] = creator() - } - } - - func cleanup() { - for holer in holers.values { - holer.cancel() - } - self.holers.removeAll() - } - -} - - diff --git a/Sources/Punchnet/SDLContext.swift b/Sources/Punchnet/SDLContext.swift index e840170..deb52f1 100644 --- a/Sources/Punchnet/SDLContext.swift +++ b/Sources/Punchnet/SDLContext.swift @@ -56,7 +56,6 @@ public class SDLContext: @unchecked Sendable { let provider: NEPacketTunnelProvider private var sessionManager: SessionManager - private var holerManager: HolerManager private var arpServer: ArpServer // 记录最后发送的stunRequest的cookie @@ -76,6 +75,16 @@ public class SDLContext: @unchecked Sendable { private var flowTracer = SDLFlowTracerActor() private var flowTracerCancel: AnyCancellable? + // 处理holer + private var holerPublishers: [Data:PassthroughSubject] = [:] + private var bag = Set() + + struct RegisterRequest { + let srcMac: Data + let dstMac: Data + let networkId: UInt32 + } + public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher) { self.config = config self.rsaCipher = rsaCipher @@ -88,7 +97,6 @@ public class SDLContext: @unchecked Sendable { self.provider = provider self.sessionManager = SessionManager() - self.holerManager = HolerManager() self.arpServer = ArpServer(known_macs: [:]) } @@ -278,7 +286,7 @@ public class SDLContext: @unchecked Sendable { let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp) if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) { // 发送register包 - await self.udpHole?.sendRegister(context: self, remoteAddress: remoteAddress, dst_mac: sendRegisterEvent.dstMac) + await self.udpHole?.sendRegister(remoteAddress: remoteAddress, networkId: self.devAddr.networkID, srcMac: self.devAddr.mac, dst_mac: sendRegisterEvent.dstMac) } case .networkShutdown(let shutdownEvent): @@ -444,8 +452,6 @@ public class SDLContext: @unchecked Sendable { // 网卡配置设置必须成功 do { try await self.provider.setTunnelNetworkSettings(networkSettings) - - await self.holerManager.cleanup() self.startReader() NSLog("[SDLContext] setTunnelNetworkSettings success, start read packet") @@ -521,15 +527,30 @@ public class SDLContext: @unchecked Sendable { await self.flowTracer.inc(num: data.count, type: .forward) // 尝试打洞 - await self.holerManager.addHoler(dstMac: dstMac) { - self.holerTask(dstMac: dstMac) - } + let registerRequest = RegisterRequest(srcMac: self.devAddr.mac, dstMac: dstMac, networkId: self.devAddr.networkID) + self.submitRegisterRequest(dstMac: dstMac, request: registerRequest) + } } - func holerTask(dstMac: Data) -> Task<(), Never> { - return Task { - guard let message = try? await self.superClient?.queryInfo(context: self, dst_mac: dstMac).get() else { + private func submitRegisterRequest(dstMac: Data, request: RegisterRequest) { + if let publisher = self.holerPublishers[dstMac] { + publisher.send(request) + } else { + let publisher = PassthroughSubject() + publisher.debounce(for: .seconds(5), scheduler: DispatchQueue.global()) + .sink { request in + self.tryHole(request: request) + } + .store(in: &self.bag) + + self.holerPublishers[dstMac] = publisher + } + } + + private func tryHole(request: RegisterRequest) { + Task { + guard let message = try? await self.superClient?.queryInfo(dst_mac: request.dstMac).get() else { return } @@ -540,7 +561,7 @@ public class SDLContext: @unchecked Sendable { if let remoteAddress = peerInfo.v4Info.socketAddress() { SDLLogger.log("[SDLContext] hole sock address: \(remoteAddress)", level: .warning) // 发送register包 - await self.udpHole?.sendRegister(context: self, remoteAddress: remoteAddress, dst_mac: dstMac) + await self.udpHole?.sendRegister(remoteAddress: remoteAddress, networkId: request.networkId, srcMac: request.srcMac, dst_mac: request.dstMac) } else { SDLLogger.log("[SDLContext] hole sock address is invalid: \(peerInfo.v4Info)", level: .warning) } diff --git a/Sources/Punchnet/SDLSuperClient.swift b/Sources/Punchnet/SDLSuperClient.swift index c0d5f40..26546d8 100644 --- a/Sources/Punchnet/SDLSuperClient.swift +++ b/Sources/Punchnet/SDLSuperClient.swift @@ -128,7 +128,7 @@ actor SDLSuperClient { } // 查询目标服务器的相关信息 - func queryInfo(context ctx: SDLContext, dst_mac: Data) async throws -> EventLoopFuture { + func queryInfo(dst_mac: Data) async throws -> EventLoopFuture { var queryInfo = SDLQueryInfo() queryInfo.dstMac = dst_mac diff --git a/Sources/Punchnet/SDLUDPHole.swift b/Sources/Punchnet/SDLUDPHole.swift index 72eaa5a..d58d250 100644 --- a/Sources/Punchnet/SDLUDPHole.swift +++ b/Sources/Punchnet/SDLUDPHole.swift @@ -190,14 +190,14 @@ actor SDLUDPHole { } // 发送register包 - func sendRegister(context ctx: SDLContext, remoteAddress: SocketAddress, dst_mac: Data) { + func sendRegister(remoteAddress: SocketAddress, networkId: UInt32, srcMac: Data, dst_mac: Data) { var register = SDLRegister() - register.networkID = ctx.devAddr.networkID - register.srcMac = ctx.devAddr.mac + register.networkID = networkId + register.srcMac = srcMac register.dstMac = dst_mac if let packet = try? register.serializedData() { - SDLLogger.log("[SDLUDPHole] SendRegister: \(remoteAddress), src_mac: \(LayerPacket.MacAddress.description(data: ctx.devAddr.mac)), dst_mac: \(LayerPacket.MacAddress.description(data: dst_mac))", level: .debug) + SDLLogger.log("[SDLUDPHole] SendRegister: \(remoteAddress), src_mac: \(LayerPacket.MacAddress.description(data: srcMac)), dst_mac: \(LayerPacket.MacAddress.description(data: dst_mac))", level: .debug) self.send(remoteAddress: remoteAddress, type: .register, data: packet) } }