From 182f6ffd17f79b0f19ab10a218cf58f75805da50 Mon Sep 17 00:00:00 2001 From: anlicheng <244108715@qq.com> Date: Tue, 10 Mar 2026 17:57:44 +0800 Subject: [PATCH] =?UTF-8?q?=E5=87=8F=E5=B0=91actor=E7=9A=84=E4=BD=BF?= =?UTF-8?q?=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Tun/Punchnet/Actors/SDLContextActor.swift | 26 +++++++++++------------ Tun/Punchnet/SessionManager.swift | 17 ++++++++++++++- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/Tun/Punchnet/Actors/SDLContextActor.swift b/Tun/Punchnet/Actors/SDLContextActor.swift index d1c09a2..16e926f 100644 --- a/Tun/Punchnet/Actors/SDLContextActor.swift +++ b/Tun/Punchnet/Actors/SDLContextActor.swift @@ -46,7 +46,7 @@ actor SDLContextActor { private var dnsWorker: Task? private var quicClient: SDLQUICClient? - private var quicWorkers: [Task]? + private var quicWorker: Task? nonisolated private let puncherActor: SDLPuncherActor // 网络探测对象 @@ -136,7 +136,7 @@ actor SDLContextActor { } private func startQUICClient() async throws -> SDLQUICClient { - self.quicWorkers?.forEach {$0.cancel()} + self.quicWorker?.cancel() self.quicClient?.stop() // 启动monitor @@ -149,7 +149,7 @@ actor SDLContextActor { try await Task.sleep(for: .seconds(0.2)) SDLLogger.shared.log("[SDLContext] start quic client ready") - let messageTask = Task.detached { + self.quicWorker = Task.detached { for await message in quicClient.messageStream { switch message { case .welcome(let welcome): @@ -174,9 +174,7 @@ actor SDLContextActor { } } } - self.quicClient = quicClient - self.quicWorkers = [messageTask] return quicClient } @@ -310,8 +308,8 @@ actor SDLContextActor { self.udpHoleWorkers?.forEach { $0.cancel() } self.udpHoleWorkers = nil - self.quicWorkers?.forEach { $0.cancel() } - self.quicWorkers = nil + self.quicWorker?.cancel() + self.quicWorker = nil self.dnsWorker?.cancel() self.dnsWorker = nil @@ -433,7 +431,7 @@ actor SDLContextActor { case .natChanged(let natChangedEvent): let dstMac = natChangedEvent.mac SDLLogger.shared.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info) - await sessionManager.removeSession(dstMac: dstMac) + sessionManager.removeSession(dstMac: dstMac) case .sendRegister(let sendRegisterEvent): SDLLogger.shared.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)", level: .debug) let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp) @@ -473,7 +471,7 @@ actor SDLContextActor { } } - private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) async throws { + private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws { let networkAddr = config.networkAddress SDLLogger.shared.log("register packet: \(register), network_address: \(networkAddr)", level: .debug) @@ -488,18 +486,18 @@ actor SDLContextActor { self.udpHole?.send(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress) // 这里需要建立到来源的会话, 在复杂网络下,通过super-node查询到的nat地址不一定靠谱,需要通过udp包的来源地址作为nat地址 let session = Session(dstMac: register.srcMac, natAddress: remoteAddress) - await self.sessionManager.addSession(session: session) + self.sessionManager.addSession(session: session) } else { SDLLogger.shared.log("SDLContext didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)", level: .warning) } } - private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) async { + private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) { // 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下 let networkAddr = config.networkAddress if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId { let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress) - await self.sessionManager.addSession(session: session) + self.sessionManager.addSession(session: session) } else { SDLLogger.shared.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning) } @@ -553,7 +551,7 @@ actor SDLContextActor { let identitySnapshot = self.snapshotPublisher.current() let ruleMap = identitySnapshot.lookup(data.identityID) - if self.authIPPacket(ipPacket: ipPacket, ruleMap: ruleMap) { + if self.checkPolicy(ipPacket: ipPacket, ruleMap: ruleMap) { let packet = NEPacket(data: ipPacket.data, protocolFamily: 2) self.provider.packetFlow.writePacketObjects([packet]) SDLLogger.shared.log("[SDLContext] identity: \(data.identityID), allow", level: .debug) @@ -568,7 +566,7 @@ actor SDLContextActor { } } - private func authIPPacket(ipPacket: IPPacket, ruleMap: IdentityRuleMap?) -> Bool { + private func checkPolicy(ipPacket: IPPacket, ruleMap: IdentityRuleMap?) -> Bool { // 进来的数据反转一下,然后再处理 if let reverseFlowSession = ipPacket.flowSession()?.reverse(), self.flowSessionManager.hasSession(reverseFlowSession) { diff --git a/Tun/Punchnet/SessionManager.swift b/Tun/Punchnet/SessionManager.swift index 06e5fa3..e0437e5 100644 --- a/Tun/Punchnet/SessionManager.swift +++ b/Tun/Punchnet/SessionManager.swift @@ -28,7 +28,8 @@ struct Session { } } -actor SessionManager { +class SessionManager { + private let locker = NSLock() private var sessions: [Data:Session] = [:] // session的有效时间 @@ -36,6 +37,11 @@ actor SessionManager { func getSession(toAddress: Data) -> Session? { let timestamp = Int32(Date().timeIntervalSince1970) + locker.lock() + defer { + locker.unlock() + } + if var session = self.sessions[toAddress] { if session.lastTimestamp >= timestamp + ttl { session.updateLastTimestamp(timestamp) @@ -50,10 +56,19 @@ actor SessionManager { } func addSession(session: Session) { + locker.lock() + defer { + locker.unlock() + } self.sessions[session.dstMac] = session } func removeSession(dstMac: Data) { + locker.lock() + defer { + locker.unlock() + } + self.sessions.removeValue(forKey: dstMac) }