diff --git a/Tun/PacketTunnelProvider.swift b/Tun/PacketTunnelProvider.swift index 06360e7..b1cf467 100644 --- a/Tun/PacketTunnelProvider.swift +++ b/Tun/PacketTunnelProvider.swift @@ -15,7 +15,7 @@ enum TunnelError: Error { class PacketTunnelProvider: NEPacketTunnelProvider { var contextActor: SDLContextActor? private var rootTask: Task? - + override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) { // 重置通知中心 SDLTunnelAppNotifier.shared.clear() @@ -25,7 +25,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider { completionHandler(TunnelError.invalidContext) return } - + // 加密算法 let rsaCipher = try! CCRSACipher(keySize: 1024) self.rootTask = Task { @@ -34,28 +34,28 @@ class PacketTunnelProvider: NEPacketTunnelProvider { completionHandler(TunnelError.invalidConfiguration) return } - + self.contextActor = SDLContextActor(provider: self, config: config, rsaCipher: rsaCipher) await self.contextActor?.start() try await self.contextActor?.waitForReady() - + completionHandler(nil) } } - + override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) { // Add code here to start the process of stopping the tunnel. Task { await self.contextActor?.stop() self.contextActor = nil - + self.rootTask?.cancel() self.rootTask = nil - + completionHandler() } } - + override func handleAppMessage(_ messageData: Data, completionHandler: ((Data?) -> Void)?) { // Add code here to handle the message. Task { @@ -67,22 +67,22 @@ class PacketTunnelProvider: NEPacketTunnelProvider { var reply = TunnelResponse() reply.code = 1 reply.message = err.localizedDescription - + let errorReplyData = try? reply.serializedData() completionHandler?(errorReplyData) } } } - + override func sleep(completionHandler: @escaping () -> Void) { // Add code here to get ready to sleep. completionHandler() } - + override func wake() { // Add code here to wake up. } - + private func handleAppRequest(message: AppRequest) async throws -> Data? { guard let contextActor = self.contextActor else { throw TunnelError.invalidContext @@ -97,12 +97,12 @@ class PacketTunnelProvider: NEPacketTunnelProvider { reply.code = 0 reply.message = "操作成功" return try reply.serializedData() - + } catch let err { var reply = TunnelResponse() reply.code = 1 reply.message = err.localizedDescription - + return try reply.serializedData() } case .none: @@ -112,15 +112,5 @@ class PacketTunnelProvider: NEPacketTunnelProvider { return try reply.serializedData() } } - -} -// 获取物理网卡ip地址 -extension PacketTunnelProvider { - - public static var viaInterface: NetworkInterface? = { - let interfaces = NetworkInterfaceManager.getInterfaces() - - return interfaces.first {$0.name == "en0"} - }() } diff --git a/Tun/Punchnet/Actors/SDLContextActor.swift b/Tun/Punchnet/Actors/SDLContextActor.swift index 998e680..9c82b7a 100644 --- a/Tun/Punchnet/Actors/SDLContextActor.swift +++ b/Tun/Punchnet/Actors/SDLContextActor.swift @@ -777,7 +777,7 @@ extension SDLContextActor { switch plan.action { case .removeSession(let dstMac): - self.sessionManager.removeSession(dstMac: dstMac) + await self.sessionManager.removeSession(dstMac: dstMac) case .sendRegister(let registerData, let remoteAddresses): remoteAddresses.forEach { remoteAddress in self.sendPeerPacket(type: .register, data: registerData, remoteAddress: remoteAddress) @@ -845,13 +845,13 @@ extension SDLContextActor { } await self.proberActor.handleProbeReply(localAddress: localAddress, reply: probeReply) case .register(let register): - try? self.handleRegister(remoteAddress: remoteAddress, register: register) + try? await self.handleRegister(remoteAddress: remoteAddress, register: register) case .registerAck(let registerAck): - self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) + await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) } } - private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws { + private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) async throws { let networkAddr = config.networkAddress SDLLogger.log("[SDLContext] register packet: \(register), network_address: \(networkAddr)") @@ -865,19 +865,25 @@ extension SDLContextActor { self.sendPeerPacket(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress) // 这里需要建立到来源的会话, 在复杂网络下,通过super-node查询到的nat地址不一定靠谱,需要通过udp包的来源地址作为nat地址 - let session = Session(dstMac: register.srcMac, natAddress: remoteAddress) - self.sessionManager.addSession(session: session) + if let session = Session(dstMac: register.srcMac, natAddress: remoteAddress) { + await self.sessionManager.addSession(session: session) + } else { + SDLLogger.log("[SDLContext] didReadRegister get unsupported remoteAddress: \(remoteAddress)", for: .debug) + } } else { SDLLogger.log("[SDLContext] didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)") } } - private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) { + private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) async { // 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下 let networkAddr = config.networkAddress if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId { - let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress) - self.sessionManager.addSession(session: session) + if let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress) { + await self.sessionManager.addSession(session: session) + } else { + SDLLogger.log("[SDLContext] didReadRegisterAck get unsupported remoteAddress: \(remoteAddress)", for: .debug) + } } else { SDLLogger.log("[SDLContext] didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)") } @@ -1003,7 +1009,7 @@ extension SDLContextActor { // 将数据封装层2层的数据包 // 构造数据包 let forwarder = self.makeLayerPacketForwarder() - guard let plan = try? forwarder.makeDeliveryPlan(dstMac: dstMac, type: type, data: data) else { + guard let plan = try? await forwarder.makeDeliveryPlan(dstMac: dstMac, type: type, data: data) else { return } diff --git a/Tun/Punchnet/Actors/SDLLayerPacketForwarder.swift b/Tun/Punchnet/Actors/SDLLayerPacketForwarder.swift index 1876cd5..ddf0a4e 100644 --- a/Tun/Punchnet/Actors/SDLLayerPacketForwarder.swift +++ b/Tun/Punchnet/Actors/SDLLayerPacketForwarder.swift @@ -19,7 +19,7 @@ struct SDLLayerPacketForwarder { let dataCipher: CCDataCipher? let sessionManager: SessionManager - func makeDeliveryPlan(dstMac: Data, type: LayerPacket.PacketType, data: Data) throws -> DeliveryPlan? { + func makeDeliveryPlan(dstMac: Data, type: LayerPacket.PacketType, data: Data) async throws -> DeliveryPlan? { guard let payload = try self.makePayload(dstMac: dstMac, type: type, data: data) else { return nil } @@ -28,7 +28,8 @@ struct SDLLayerPacketForwarder { return .superNode(payload: payload) } - if let session = self.sessionManager.getSession(toAddress: dstMac) { + let preferredSessionType = Session.AddressType(packetType: type) + if let session = await self.sessionManager.getSession(toAddress: dstMac, preferredType: preferredSessionType) { return .peer(payload: payload, session: session) } diff --git a/Tun/Punchnet/SessionManager.swift b/Tun/Punchnet/SessionManager.swift index 08e2c87..613d850 100644 --- a/Tun/Punchnet/SessionManager.swift +++ b/Tun/Punchnet/SessionManager.swift @@ -8,18 +8,52 @@ import Foundation import NIOCore import Darwin -struct Session { +struct Session: @unchecked Sendable { + enum AddressType: String, Hashable { + case v4 + case v6 + + init?(socketAddress: SocketAddress) { + switch socketAddress { + case .v4: + self = .v4 + case .v6: + self = .v6 + default: + return nil + } + } + + init?(packetType: LayerPacket.PacketType) { + switch packetType { + case .arp, .ipv4: + self = .v4 + case .ipv6: + self = .v6 + default: + return nil + } + } + } + // 在内部的通讯的ip地址, 整数格式 let dstMac: Data // 对端的主机在nat上映射的端口信息 let natAddress: SocketAddress + // 当前会话对应的外层地址族 + let addressType: AddressType // 最后使用时间 var lastTimestamp: Int32 - init(dstMac: Data, natAddress: SocketAddress) { + init?(dstMac: Data, natAddress: SocketAddress) { + guard let addressType = AddressType(socketAddress: natAddress) else { + return nil + } + self.dstMac = dstMac self.natAddress = natAddress + self.addressType = addressType self.lastTimestamp = Int32(Date().timeIntervalSince1970) } @@ -28,48 +62,61 @@ struct Session { } } -class SessionManager { - private let locker = NSLock() - private var sessions: [Data:Session] = [:] +actor SessionManager { + private var sessions: [Data: [Session.AddressType: Session]] = [:] // session的有效时间 private let ttl: Int32 = 10 - func getSession(toAddress: Data) -> Session? { + func getSession(toAddress: Data, preferredType: Session.AddressType? = nil) -> Session? { let timestamp = Int32(Date().timeIntervalSince1970) - locker.lock() - defer { - locker.unlock() + + guard var sessions = self.sessions[toAddress] else { + return nil } - if var session = self.sessions[toAddress] { - if session.lastTimestamp + ttl >= timestamp { - session.updateLastTimestamp(timestamp) - self.sessions[toAddress] = session - - return session - } else { - self.sessions.removeValue(forKey: toAddress) - } + sessions = sessions.filter { $0.value.lastTimestamp + ttl >= timestamp } + guard !sessions.isEmpty else { + self.sessions.removeValue(forKey: toAddress) + return nil } - return nil + + guard var session = self.selectSession(in: sessions, preferredType: preferredType) else { + self.sessions[toAddress] = sessions + return nil + } + + session.updateLastTimestamp(timestamp) + sessions[session.addressType] = session + self.sessions[toAddress] = sessions + + return session } func addSession(session: Session) { - locker.lock() - defer { - locker.unlock() + let timestamp = Int32(Date().timeIntervalSince1970) + + var sessions = self.sessions[session.dstMac, default: [:]] + sessions = sessions.filter { + $0.value.lastTimestamp + ttl >= timestamp && $0.key != session.addressType } - self.sessions[session.dstMac] = session + sessions[session.addressType] = session + + self.sessions[session.dstMac] = sessions } func removeSession(dstMac: Data) { - locker.lock() - defer { - locker.unlock() - } - self.sessions.removeValue(forKey: dstMac) } + private func selectSession(in sessions: [Session.AddressType: Session], preferredType: Session.AddressType?) -> Session? { + if let preferredType { + if let preferred = sessions[preferredType] { + return preferred + } + } + + return sessions.values.max(by: { $0.lastTimestamp < $1.lastTimestamp }) + } + }