diff --git a/Tun/Punchnet/ArpServer.swift b/Tun/Punchnet/ArpServer.swift index 1963f97..d26d1b1 100644 --- a/Tun/Punchnet/ArpServer.swift +++ b/Tun/Punchnet/ArpServer.swift @@ -5,27 +5,47 @@ // Created by 安礼成 on 2025/7/14. // import Foundation +import Darwin -actor ArpServer { +final class ArpServer { private var known_macs: [UInt32:Data] = [:] + private var lock = os_unfair_lock() init(known_macs: [UInt32:Data]) { self.known_macs = known_macs } func query(ip: UInt32) -> Data? { - return self.known_macs[ip] + return withLock { + return self.known_macs[ip] + } } func append(ip: UInt32, mac: Data) { - self.known_macs[ip] = mac + withLock { + self.known_macs[ip] = mac + } } func remove(ip: UInt32) { - self.known_macs.removeValue(forKey: ip) + withLock { + _ = self.known_macs.removeValue(forKey: ip) + } } func clear() { - self.known_macs = [:] + withLock { + self.known_macs = [:] + } } + + private func withLock(_ body: () throws -> T) rethrows -> T { + os_unfair_lock_lock(&lock) + defer{ + os_unfair_lock_unlock(&lock) + } + + return try body() + } + } diff --git a/Tun/Punchnet/SDLContext.swift b/Tun/Punchnet/SDLContext.swift index 9168eb3..9e42b52 100644 --- a/Tun/Punchnet/SDLContext.swift +++ b/Tun/Punchnet/SDLContext.swift @@ -303,7 +303,7 @@ public class SDLContext { case .natChanged(let natChangedEvent): let dstMac = natChangedEvent.mac self.logger.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info) - await sessionManager.removeSession(dstMac: dstMac) + sessionManager.removeSession(dstMac: dstMac) case .sendRegister(let sendRegisterEvent): self.logger.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)", level: .debug) let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp) @@ -338,7 +338,7 @@ public class SDLContext { self.udpHole?.send(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress) // 这里需要建立到来源的会话, 在复杂网络下,通过super-node查询到的nat地址不一定靠谱,需要通过udp包的来源地址作为nat地址 let session = Session(dstMac: register.srcMac, natAddress: remoteAddress) - await self.sessionManager.addSession(session: session) + self.sessionManager.addSession(session: session) } else { self.logger.log("SDLContext didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)", level: .warning) } @@ -349,7 +349,7 @@ public class SDLContext { let networkAddr = config.networkAddress if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId { let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress) - await self.sessionManager.addSession(session: session) + self.sessionManager.addSession(session: session) } else { self.logger.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning) } @@ -383,7 +383,7 @@ public class SDLContext { await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal()) case .response: self.logger.log("[SDLContext] get arp response packet", level: .debug) - await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC) + self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC) } } else { self.logger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))", level: .debug) @@ -469,7 +469,7 @@ public class SDLContext { } // 查找arp缓存中是否有目标mac地址 - if let dstMac = await self.arpServer.query(ip: dstIp) { + if let dstMac = self.arpServer.query(ip: dstIp) { await self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data) } else { @@ -504,7 +504,7 @@ public class SDLContext { } else { // 通过session发送到对端 - if let session = await self.sessionManager.getSession(toAddress: dstMac) { + if let session = self.sessionManager.getSession(toAddress: dstMac) { self.logger.log("[SDLContext] send packet by session: \(session)", level: .debug) self.udpHole?.send(type: .data, data: data, remoteAddress: session.natAddress) await self.flowTracer.inc(num: data.count, type: .p2p) diff --git a/Tun/Punchnet/SessionManager.swift b/Tun/Punchnet/SessionManager.swift index 8fe1505..81fbc63 100644 --- a/Tun/Punchnet/SessionManager.swift +++ b/Tun/Punchnet/SessionManager.swift @@ -6,6 +6,7 @@ // import Foundation import NIOCore +import Darwin struct Session { // 在内部的通讯的ip地址, 整数格式 @@ -27,13 +28,19 @@ struct Session { } } -actor SessionManager { +final class SessionManager { private var sessions: [Data:Session] = [:] + private var lock = os_unfair_lock() // session的有效时间 private let ttl: Int32 = 10 func getSession(toAddress: Data) -> Session? { + os_unfair_lock_lock(&lock) + defer{ + os_unfair_lock_unlock(&lock) + } + let timestamp = Int32(Date().timeIntervalSince1970) if let session = self.sessions[toAddress] { if session.lastTimestamp >= timestamp + ttl { @@ -47,10 +54,20 @@ actor SessionManager { } func addSession(session: Session) { + os_unfair_lock_lock(&lock) + defer{ + os_unfair_lock_unlock(&lock) + } + self.sessions[session.dstMac] = session } func removeSession(dstMac: Data) { + os_unfair_lock_lock(&lock) + defer{ + os_unfair_lock_unlock(&lock) + } + self.sessions.removeValue(forKey: dstMac) }