fix session

This commit is contained in:
anlicheng 2026-04-16 10:42:44 +08:00
parent ca148acc87
commit 122c60f96c
4 changed files with 108 additions and 64 deletions

View File

@ -15,7 +15,7 @@ enum TunnelError: Error {
class PacketTunnelProvider: NEPacketTunnelProvider { class PacketTunnelProvider: NEPacketTunnelProvider {
var contextActor: SDLContextActor? var contextActor: SDLContextActor?
private var rootTask: Task<Void, Error>? private var rootTask: Task<Void, Error>?
override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) { override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) {
// //
SDLTunnelAppNotifier.shared.clear() SDLTunnelAppNotifier.shared.clear()
@ -25,7 +25,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
completionHandler(TunnelError.invalidContext) completionHandler(TunnelError.invalidContext)
return return
} }
// //
let rsaCipher = try! CCRSACipher(keySize: 1024) let rsaCipher = try! CCRSACipher(keySize: 1024)
self.rootTask = Task { self.rootTask = Task {
@ -34,28 +34,28 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
completionHandler(TunnelError.invalidConfiguration) completionHandler(TunnelError.invalidConfiguration)
return return
} }
self.contextActor = SDLContextActor(provider: self, config: config, rsaCipher: rsaCipher) self.contextActor = SDLContextActor(provider: self, config: config, rsaCipher: rsaCipher)
await self.contextActor?.start() await self.contextActor?.start()
try await self.contextActor?.waitForReady() try await self.contextActor?.waitForReady()
completionHandler(nil) completionHandler(nil)
} }
} }
override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) { override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) {
// Add code here to start the process of stopping the tunnel. // Add code here to start the process of stopping the tunnel.
Task { Task {
await self.contextActor?.stop() await self.contextActor?.stop()
self.contextActor = nil self.contextActor = nil
self.rootTask?.cancel() self.rootTask?.cancel()
self.rootTask = nil self.rootTask = nil
completionHandler() completionHandler()
} }
} }
override func handleAppMessage(_ messageData: Data, completionHandler: ((Data?) -> Void)?) { override func handleAppMessage(_ messageData: Data, completionHandler: ((Data?) -> Void)?) {
// Add code here to handle the message. // Add code here to handle the message.
Task { Task {
@ -67,22 +67,22 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
var reply = TunnelResponse() var reply = TunnelResponse()
reply.code = 1 reply.code = 1
reply.message = err.localizedDescription reply.message = err.localizedDescription
let errorReplyData = try? reply.serializedData() let errorReplyData = try? reply.serializedData()
completionHandler?(errorReplyData) completionHandler?(errorReplyData)
} }
} }
} }
override func sleep(completionHandler: @escaping () -> Void) { override func sleep(completionHandler: @escaping () -> Void) {
// Add code here to get ready to sleep. // Add code here to get ready to sleep.
completionHandler() completionHandler()
} }
override func wake() { override func wake() {
// Add code here to wake up. // Add code here to wake up.
} }
private func handleAppRequest(message: AppRequest) async throws -> Data? { private func handleAppRequest(message: AppRequest) async throws -> Data? {
guard let contextActor = self.contextActor else { guard let contextActor = self.contextActor else {
throw TunnelError.invalidContext throw TunnelError.invalidContext
@ -97,12 +97,12 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
reply.code = 0 reply.code = 0
reply.message = "操作成功" reply.message = "操作成功"
return try reply.serializedData() return try reply.serializedData()
} catch let err { } catch let err {
var reply = TunnelResponse() var reply = TunnelResponse()
reply.code = 1 reply.code = 1
reply.message = err.localizedDescription reply.message = err.localizedDescription
return try reply.serializedData() return try reply.serializedData()
} }
case .none: case .none:
@ -112,15 +112,5 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
return try reply.serializedData() return try reply.serializedData()
} }
} }
}
// ip
extension PacketTunnelProvider {
public static var viaInterface: NetworkInterface? = {
let interfaces = NetworkInterfaceManager.getInterfaces()
return interfaces.first {$0.name == "en0"}
}()
} }

View File

@ -777,7 +777,7 @@ extension SDLContextActor {
switch plan.action { switch plan.action {
case .removeSession(let dstMac): case .removeSession(let dstMac):
self.sessionManager.removeSession(dstMac: dstMac) await self.sessionManager.removeSession(dstMac: dstMac)
case .sendRegister(let registerData, let remoteAddresses): case .sendRegister(let registerData, let remoteAddresses):
remoteAddresses.forEach { remoteAddress in remoteAddresses.forEach { remoteAddress in
self.sendPeerPacket(type: .register, data: registerData, remoteAddress: remoteAddress) self.sendPeerPacket(type: .register, data: registerData, remoteAddress: remoteAddress)
@ -845,13 +845,13 @@ extension SDLContextActor {
} }
await self.proberActor.handleProbeReply(localAddress: localAddress, reply: probeReply) await self.proberActor.handleProbeReply(localAddress: localAddress, reply: probeReply)
case .register(let register): case .register(let register):
try? self.handleRegister(remoteAddress: remoteAddress, register: register) try? await self.handleRegister(remoteAddress: remoteAddress, register: register)
case .registerAck(let registerAck): case .registerAck(let registerAck):
self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
} }
} }
private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws { private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) async throws {
let networkAddr = config.networkAddress let networkAddr = config.networkAddress
SDLLogger.log("[SDLContext] register packet: \(register), network_address: \(networkAddr)") SDLLogger.log("[SDLContext] register packet: \(register), network_address: \(networkAddr)")
@ -865,19 +865,25 @@ extension SDLContextActor {
self.sendPeerPacket(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress) self.sendPeerPacket(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)
// , super-nodenatudpnat // , super-nodenatudpnat
let session = Session(dstMac: register.srcMac, natAddress: remoteAddress) if let session = Session(dstMac: register.srcMac, natAddress: remoteAddress) {
self.sessionManager.addSession(session: session) await self.sessionManager.addSession(session: session)
} else {
SDLLogger.log("[SDLContext] didReadRegister get unsupported remoteAddress: \(remoteAddress)", for: .debug)
}
} else { } else {
SDLLogger.log("[SDLContext] didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)") SDLLogger.log("[SDLContext] didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)")
} }
} }
private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) { private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) async {
// tun, // tun,
let networkAddr = config.networkAddress let networkAddr = config.networkAddress
if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId { if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId {
let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress) if let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress) {
self.sessionManager.addSession(session: session) await self.sessionManager.addSession(session: session)
} else {
SDLLogger.log("[SDLContext] didReadRegisterAck get unsupported remoteAddress: \(remoteAddress)", for: .debug)
}
} else { } else {
SDLLogger.log("[SDLContext] didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)") SDLLogger.log("[SDLContext] didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)")
} }
@ -1003,7 +1009,7 @@ extension SDLContextActor {
// 2 // 2
// //
let forwarder = self.makeLayerPacketForwarder() let forwarder = self.makeLayerPacketForwarder()
guard let plan = try? forwarder.makeDeliveryPlan(dstMac: dstMac, type: type, data: data) else { guard let plan = try? await forwarder.makeDeliveryPlan(dstMac: dstMac, type: type, data: data) else {
return return
} }

View File

@ -19,7 +19,7 @@ struct SDLLayerPacketForwarder {
let dataCipher: CCDataCipher? let dataCipher: CCDataCipher?
let sessionManager: SessionManager let sessionManager: SessionManager
func makeDeliveryPlan(dstMac: Data, type: LayerPacket.PacketType, data: Data) throws -> DeliveryPlan? { func makeDeliveryPlan(dstMac: Data, type: LayerPacket.PacketType, data: Data) async throws -> DeliveryPlan? {
guard let payload = try self.makePayload(dstMac: dstMac, type: type, data: data) else { guard let payload = try self.makePayload(dstMac: dstMac, type: type, data: data) else {
return nil return nil
} }
@ -28,7 +28,8 @@ struct SDLLayerPacketForwarder {
return .superNode(payload: payload) return .superNode(payload: payload)
} }
if let session = self.sessionManager.getSession(toAddress: dstMac) { let preferredSessionType = Session.AddressType(packetType: type)
if let session = await self.sessionManager.getSession(toAddress: dstMac, preferredType: preferredSessionType) {
return .peer(payload: payload, session: session) return .peer(payload: payload, session: session)
} }

View File

@ -8,18 +8,52 @@ import Foundation
import NIOCore import NIOCore
import Darwin import Darwin
struct Session { struct Session: @unchecked Sendable {
enum AddressType: String, Hashable {
case v4
case v6
init?(socketAddress: SocketAddress) {
switch socketAddress {
case .v4:
self = .v4
case .v6:
self = .v6
default:
return nil
}
}
init?(packetType: LayerPacket.PacketType) {
switch packetType {
case .arp, .ipv4:
self = .v4
case .ipv6:
self = .v6
default:
return nil
}
}
}
// ip, // ip,
let dstMac: Data let dstMac: Data
// nat // nat
let natAddress: SocketAddress let natAddress: SocketAddress
//
let addressType: AddressType
// 使 // 使
var lastTimestamp: Int32 var lastTimestamp: Int32
init(dstMac: Data, natAddress: SocketAddress) { init?(dstMac: Data, natAddress: SocketAddress) {
guard let addressType = AddressType(socketAddress: natAddress) else {
return nil
}
self.dstMac = dstMac self.dstMac = dstMac
self.natAddress = natAddress self.natAddress = natAddress
self.addressType = addressType
self.lastTimestamp = Int32(Date().timeIntervalSince1970) self.lastTimestamp = Int32(Date().timeIntervalSince1970)
} }
@ -28,48 +62,61 @@ struct Session {
} }
} }
class SessionManager { actor SessionManager {
private let locker = NSLock() private var sessions: [Data: [Session.AddressType: Session]] = [:]
private var sessions: [Data:Session] = [:]
// session // session
private let ttl: Int32 = 10 private let ttl: Int32 = 10
func getSession(toAddress: Data) -> Session? { func getSession(toAddress: Data, preferredType: Session.AddressType? = nil) -> Session? {
let timestamp = Int32(Date().timeIntervalSince1970) let timestamp = Int32(Date().timeIntervalSince1970)
locker.lock()
defer { guard var sessions = self.sessions[toAddress] else {
locker.unlock() return nil
} }
if var session = self.sessions[toAddress] { sessions = sessions.filter { $0.value.lastTimestamp + ttl >= timestamp }
if session.lastTimestamp + ttl >= timestamp { guard !sessions.isEmpty else {
session.updateLastTimestamp(timestamp) self.sessions.removeValue(forKey: toAddress)
self.sessions[toAddress] = session return nil
return session
} else {
self.sessions.removeValue(forKey: toAddress)
}
} }
return nil
guard var session = self.selectSession(in: sessions, preferredType: preferredType) else {
self.sessions[toAddress] = sessions
return nil
}
session.updateLastTimestamp(timestamp)
sessions[session.addressType] = session
self.sessions[toAddress] = sessions
return session
} }
func addSession(session: Session) { func addSession(session: Session) {
locker.lock() let timestamp = Int32(Date().timeIntervalSince1970)
defer {
locker.unlock() var sessions = self.sessions[session.dstMac, default: [:]]
sessions = sessions.filter {
$0.value.lastTimestamp + ttl >= timestamp && $0.key != session.addressType
} }
self.sessions[session.dstMac] = session sessions[session.addressType] = session
self.sessions[session.dstMac] = sessions
} }
func removeSession(dstMac: Data) { func removeSession(dstMac: Data) {
locker.lock()
defer {
locker.unlock()
}
self.sessions.removeValue(forKey: dstMac) self.sessions.removeValue(forKey: dstMac)
} }
private func selectSession(in sessions: [Session.AddressType: Session], preferredType: Session.AddressType?) -> Session? {
if let preferredType {
if let preferred = sessions[preferredType] {
return preferred
}
}
return sessions.values.max(by: { $0.lastTimestamp < $1.lastTimestamp })
}
} }