fix session
This commit is contained in:
parent
ca148acc87
commit
122c60f96c
@ -15,7 +15,7 @@ enum TunnelError: Error {
|
|||||||
class PacketTunnelProvider: NEPacketTunnelProvider {
|
class PacketTunnelProvider: NEPacketTunnelProvider {
|
||||||
var contextActor: SDLContextActor?
|
var contextActor: SDLContextActor?
|
||||||
private var rootTask: Task<Void, Error>?
|
private var rootTask: Task<Void, Error>?
|
||||||
|
|
||||||
override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) {
|
override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) {
|
||||||
// 重置通知中心
|
// 重置通知中心
|
||||||
SDLTunnelAppNotifier.shared.clear()
|
SDLTunnelAppNotifier.shared.clear()
|
||||||
@ -25,7 +25,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
|
|||||||
completionHandler(TunnelError.invalidContext)
|
completionHandler(TunnelError.invalidContext)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 加密算法
|
// 加密算法
|
||||||
let rsaCipher = try! CCRSACipher(keySize: 1024)
|
let rsaCipher = try! CCRSACipher(keySize: 1024)
|
||||||
self.rootTask = Task {
|
self.rootTask = Task {
|
||||||
@ -34,28 +34,28 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
|
|||||||
completionHandler(TunnelError.invalidConfiguration)
|
completionHandler(TunnelError.invalidConfiguration)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
self.contextActor = SDLContextActor(provider: self, config: config, rsaCipher: rsaCipher)
|
self.contextActor = SDLContextActor(provider: self, config: config, rsaCipher: rsaCipher)
|
||||||
await self.contextActor?.start()
|
await self.contextActor?.start()
|
||||||
try await self.contextActor?.waitForReady()
|
try await self.contextActor?.waitForReady()
|
||||||
|
|
||||||
completionHandler(nil)
|
completionHandler(nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) {
|
override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) {
|
||||||
// Add code here to start the process of stopping the tunnel.
|
// Add code here to start the process of stopping the tunnel.
|
||||||
Task {
|
Task {
|
||||||
await self.contextActor?.stop()
|
await self.contextActor?.stop()
|
||||||
self.contextActor = nil
|
self.contextActor = nil
|
||||||
|
|
||||||
self.rootTask?.cancel()
|
self.rootTask?.cancel()
|
||||||
self.rootTask = nil
|
self.rootTask = nil
|
||||||
|
|
||||||
completionHandler()
|
completionHandler()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override func handleAppMessage(_ messageData: Data, completionHandler: ((Data?) -> Void)?) {
|
override func handleAppMessage(_ messageData: Data, completionHandler: ((Data?) -> Void)?) {
|
||||||
// Add code here to handle the message.
|
// Add code here to handle the message.
|
||||||
Task {
|
Task {
|
||||||
@ -67,22 +67,22 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
|
|||||||
var reply = TunnelResponse()
|
var reply = TunnelResponse()
|
||||||
reply.code = 1
|
reply.code = 1
|
||||||
reply.message = err.localizedDescription
|
reply.message = err.localizedDescription
|
||||||
|
|
||||||
let errorReplyData = try? reply.serializedData()
|
let errorReplyData = try? reply.serializedData()
|
||||||
completionHandler?(errorReplyData)
|
completionHandler?(errorReplyData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override func sleep(completionHandler: @escaping () -> Void) {
|
override func sleep(completionHandler: @escaping () -> Void) {
|
||||||
// Add code here to get ready to sleep.
|
// Add code here to get ready to sleep.
|
||||||
completionHandler()
|
completionHandler()
|
||||||
}
|
}
|
||||||
|
|
||||||
override func wake() {
|
override func wake() {
|
||||||
// Add code here to wake up.
|
// Add code here to wake up.
|
||||||
}
|
}
|
||||||
|
|
||||||
private func handleAppRequest(message: AppRequest) async throws -> Data? {
|
private func handleAppRequest(message: AppRequest) async throws -> Data? {
|
||||||
guard let contextActor = self.contextActor else {
|
guard let contextActor = self.contextActor else {
|
||||||
throw TunnelError.invalidContext
|
throw TunnelError.invalidContext
|
||||||
@ -97,12 +97,12 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
|
|||||||
reply.code = 0
|
reply.code = 0
|
||||||
reply.message = "操作成功"
|
reply.message = "操作成功"
|
||||||
return try reply.serializedData()
|
return try reply.serializedData()
|
||||||
|
|
||||||
} catch let err {
|
} catch let err {
|
||||||
var reply = TunnelResponse()
|
var reply = TunnelResponse()
|
||||||
reply.code = 1
|
reply.code = 1
|
||||||
reply.message = err.localizedDescription
|
reply.message = err.localizedDescription
|
||||||
|
|
||||||
return try reply.serializedData()
|
return try reply.serializedData()
|
||||||
}
|
}
|
||||||
case .none:
|
case .none:
|
||||||
@ -112,15 +112,5 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
|
|||||||
return try reply.serializedData()
|
return try reply.serializedData()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取物理网卡ip地址
|
|
||||||
extension PacketTunnelProvider {
|
|
||||||
|
|
||||||
public static var viaInterface: NetworkInterface? = {
|
|
||||||
let interfaces = NetworkInterfaceManager.getInterfaces()
|
|
||||||
|
|
||||||
return interfaces.first {$0.name == "en0"}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -777,7 +777,7 @@ extension SDLContextActor {
|
|||||||
|
|
||||||
switch plan.action {
|
switch plan.action {
|
||||||
case .removeSession(let dstMac):
|
case .removeSession(let dstMac):
|
||||||
self.sessionManager.removeSession(dstMac: dstMac)
|
await self.sessionManager.removeSession(dstMac: dstMac)
|
||||||
case .sendRegister(let registerData, let remoteAddresses):
|
case .sendRegister(let registerData, let remoteAddresses):
|
||||||
remoteAddresses.forEach { remoteAddress in
|
remoteAddresses.forEach { remoteAddress in
|
||||||
self.sendPeerPacket(type: .register, data: registerData, remoteAddress: remoteAddress)
|
self.sendPeerPacket(type: .register, data: registerData, remoteAddress: remoteAddress)
|
||||||
@ -845,13 +845,13 @@ extension SDLContextActor {
|
|||||||
}
|
}
|
||||||
await self.proberActor.handleProbeReply(localAddress: localAddress, reply: probeReply)
|
await self.proberActor.handleProbeReply(localAddress: localAddress, reply: probeReply)
|
||||||
case .register(let register):
|
case .register(let register):
|
||||||
try? self.handleRegister(remoteAddress: remoteAddress, register: register)
|
try? await self.handleRegister(remoteAddress: remoteAddress, register: register)
|
||||||
case .registerAck(let registerAck):
|
case .registerAck(let registerAck):
|
||||||
self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
|
await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws {
|
private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) async throws {
|
||||||
let networkAddr = config.networkAddress
|
let networkAddr = config.networkAddress
|
||||||
SDLLogger.log("[SDLContext] register packet: \(register), network_address: \(networkAddr)")
|
SDLLogger.log("[SDLContext] register packet: \(register), network_address: \(networkAddr)")
|
||||||
|
|
||||||
@ -865,19 +865,25 @@ extension SDLContextActor {
|
|||||||
|
|
||||||
self.sendPeerPacket(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)
|
self.sendPeerPacket(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)
|
||||||
// 这里需要建立到来源的会话, 在复杂网络下,通过super-node查询到的nat地址不一定靠谱,需要通过udp包的来源地址作为nat地址
|
// 这里需要建立到来源的会话, 在复杂网络下,通过super-node查询到的nat地址不一定靠谱,需要通过udp包的来源地址作为nat地址
|
||||||
let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
|
if let session = Session(dstMac: register.srcMac, natAddress: remoteAddress) {
|
||||||
self.sessionManager.addSession(session: session)
|
await self.sessionManager.addSession(session: session)
|
||||||
|
} else {
|
||||||
|
SDLLogger.log("[SDLContext] didReadRegister get unsupported remoteAddress: \(remoteAddress)", for: .debug)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
SDLLogger.log("[SDLContext] didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)")
|
SDLLogger.log("[SDLContext] didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) {
|
private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) async {
|
||||||
// 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下
|
// 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下
|
||||||
let networkAddr = config.networkAddress
|
let networkAddr = config.networkAddress
|
||||||
if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId {
|
if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId {
|
||||||
let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
|
if let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress) {
|
||||||
self.sessionManager.addSession(session: session)
|
await self.sessionManager.addSession(session: session)
|
||||||
|
} else {
|
||||||
|
SDLLogger.log("[SDLContext] didReadRegisterAck get unsupported remoteAddress: \(remoteAddress)", for: .debug)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
SDLLogger.log("[SDLContext] didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)")
|
SDLLogger.log("[SDLContext] didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)")
|
||||||
}
|
}
|
||||||
@ -1003,7 +1009,7 @@ extension SDLContextActor {
|
|||||||
// 将数据封装层2层的数据包
|
// 将数据封装层2层的数据包
|
||||||
// 构造数据包
|
// 构造数据包
|
||||||
let forwarder = self.makeLayerPacketForwarder()
|
let forwarder = self.makeLayerPacketForwarder()
|
||||||
guard let plan = try? forwarder.makeDeliveryPlan(dstMac: dstMac, type: type, data: data) else {
|
guard let plan = try? await forwarder.makeDeliveryPlan(dstMac: dstMac, type: type, data: data) else {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,7 @@ struct SDLLayerPacketForwarder {
|
|||||||
let dataCipher: CCDataCipher?
|
let dataCipher: CCDataCipher?
|
||||||
let sessionManager: SessionManager
|
let sessionManager: SessionManager
|
||||||
|
|
||||||
func makeDeliveryPlan(dstMac: Data, type: LayerPacket.PacketType, data: Data) throws -> DeliveryPlan? {
|
func makeDeliveryPlan(dstMac: Data, type: LayerPacket.PacketType, data: Data) async throws -> DeliveryPlan? {
|
||||||
guard let payload = try self.makePayload(dstMac: dstMac, type: type, data: data) else {
|
guard let payload = try self.makePayload(dstMac: dstMac, type: type, data: data) else {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -28,7 +28,8 @@ struct SDLLayerPacketForwarder {
|
|||||||
return .superNode(payload: payload)
|
return .superNode(payload: payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
if let session = self.sessionManager.getSession(toAddress: dstMac) {
|
let preferredSessionType = Session.AddressType(packetType: type)
|
||||||
|
if let session = await self.sessionManager.getSession(toAddress: dstMac, preferredType: preferredSessionType) {
|
||||||
return .peer(payload: payload, session: session)
|
return .peer(payload: payload, session: session)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -8,18 +8,52 @@ import Foundation
|
|||||||
import NIOCore
|
import NIOCore
|
||||||
import Darwin
|
import Darwin
|
||||||
|
|
||||||
struct Session {
|
struct Session: @unchecked Sendable {
|
||||||
|
enum AddressType: String, Hashable {
|
||||||
|
case v4
|
||||||
|
case v6
|
||||||
|
|
||||||
|
init?(socketAddress: SocketAddress) {
|
||||||
|
switch socketAddress {
|
||||||
|
case .v4:
|
||||||
|
self = .v4
|
||||||
|
case .v6:
|
||||||
|
self = .v6
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
init?(packetType: LayerPacket.PacketType) {
|
||||||
|
switch packetType {
|
||||||
|
case .arp, .ipv4:
|
||||||
|
self = .v4
|
||||||
|
case .ipv6:
|
||||||
|
self = .v6
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 在内部的通讯的ip地址, 整数格式
|
// 在内部的通讯的ip地址, 整数格式
|
||||||
let dstMac: Data
|
let dstMac: Data
|
||||||
// 对端的主机在nat上映射的端口信息
|
// 对端的主机在nat上映射的端口信息
|
||||||
let natAddress: SocketAddress
|
let natAddress: SocketAddress
|
||||||
|
// 当前会话对应的外层地址族
|
||||||
|
let addressType: AddressType
|
||||||
|
|
||||||
// 最后使用时间
|
// 最后使用时间
|
||||||
var lastTimestamp: Int32
|
var lastTimestamp: Int32
|
||||||
|
|
||||||
init(dstMac: Data, natAddress: SocketAddress) {
|
init?(dstMac: Data, natAddress: SocketAddress) {
|
||||||
|
guard let addressType = AddressType(socketAddress: natAddress) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
self.dstMac = dstMac
|
self.dstMac = dstMac
|
||||||
self.natAddress = natAddress
|
self.natAddress = natAddress
|
||||||
|
self.addressType = addressType
|
||||||
self.lastTimestamp = Int32(Date().timeIntervalSince1970)
|
self.lastTimestamp = Int32(Date().timeIntervalSince1970)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -28,48 +62,61 @@ struct Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class SessionManager {
|
actor SessionManager {
|
||||||
private let locker = NSLock()
|
private var sessions: [Data: [Session.AddressType: Session]] = [:]
|
||||||
private var sessions: [Data:Session] = [:]
|
|
||||||
|
|
||||||
// session的有效时间
|
// session的有效时间
|
||||||
private let ttl: Int32 = 10
|
private let ttl: Int32 = 10
|
||||||
|
|
||||||
func getSession(toAddress: Data) -> Session? {
|
func getSession(toAddress: Data, preferredType: Session.AddressType? = nil) -> Session? {
|
||||||
let timestamp = Int32(Date().timeIntervalSince1970)
|
let timestamp = Int32(Date().timeIntervalSince1970)
|
||||||
locker.lock()
|
|
||||||
defer {
|
guard var sessions = self.sessions[toAddress] else {
|
||||||
locker.unlock()
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if var session = self.sessions[toAddress] {
|
sessions = sessions.filter { $0.value.lastTimestamp + ttl >= timestamp }
|
||||||
if session.lastTimestamp + ttl >= timestamp {
|
guard !sessions.isEmpty else {
|
||||||
session.updateLastTimestamp(timestamp)
|
self.sessions.removeValue(forKey: toAddress)
|
||||||
self.sessions[toAddress] = session
|
return nil
|
||||||
|
|
||||||
return session
|
|
||||||
} else {
|
|
||||||
self.sessions.removeValue(forKey: toAddress)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
guard var session = self.selectSession(in: sessions, preferredType: preferredType) else {
|
||||||
|
self.sessions[toAddress] = sessions
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
session.updateLastTimestamp(timestamp)
|
||||||
|
sessions[session.addressType] = session
|
||||||
|
self.sessions[toAddress] = sessions
|
||||||
|
|
||||||
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
func addSession(session: Session) {
|
func addSession(session: Session) {
|
||||||
locker.lock()
|
let timestamp = Int32(Date().timeIntervalSince1970)
|
||||||
defer {
|
|
||||||
locker.unlock()
|
var sessions = self.sessions[session.dstMac, default: [:]]
|
||||||
|
sessions = sessions.filter {
|
||||||
|
$0.value.lastTimestamp + ttl >= timestamp && $0.key != session.addressType
|
||||||
}
|
}
|
||||||
self.sessions[session.dstMac] = session
|
sessions[session.addressType] = session
|
||||||
|
|
||||||
|
self.sessions[session.dstMac] = sessions
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeSession(dstMac: Data) {
|
func removeSession(dstMac: Data) {
|
||||||
locker.lock()
|
|
||||||
defer {
|
|
||||||
locker.unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
self.sessions.removeValue(forKey: dstMac)
|
self.sessions.removeValue(forKey: dstMac)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private func selectSession(in sessions: [Session.AddressType: Session], preferredType: Session.AddressType?) -> Session? {
|
||||||
|
if let preferredType {
|
||||||
|
if let preferred = sessions[preferredType] {
|
||||||
|
return preferred
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sessions.values.max(by: { $0.lastTimestamp < $1.lastTimestamp })
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user