diff --git a/Tun/Punchnet/Actors/SDLContextActor.swift b/Tun/Punchnet/Actors/SDLContextActor.swift index c21eb38..dea948c 100644 --- a/Tun/Punchnet/Actors/SDLContextActor.swift +++ b/Tun/Punchnet/Actors/SDLContextActor.swift @@ -25,6 +25,15 @@ actor SDLContextActor { private enum UDPHoleKind: Equatable { case v4 case v6 + + func convertAddressType() -> Session.AddressType { + switch self { + case .v4: + return .v4 + case .v6: + return .v6 + } + } } private var readyState: ReadyState = .idle @@ -845,13 +854,13 @@ extension SDLContextActor { } await self.proberActor.handleProbeReply(localAddress: localAddress, reply: probeReply) case .register(let register): - try? await self.handleRegister(remoteAddress: remoteAddress, register: register) + try? await self.handleRegister(remoteAddress: remoteAddress, register: register, source: source) case .registerAck(let registerAck): - await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) + await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck, source: source) } } - private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) async throws { + private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister, source: UDPHoleKind) async throws { let networkAddr = config.networkAddress SDLLogger.log("[SDLContext] register packet: \(register), network_address: \(networkAddr)") @@ -865,7 +874,7 @@ extension SDLContextActor { self.sendPeerPacket(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress) // 这里需要建立到来源的会话, 在复杂网络下,通过super-node查询到的nat地址不一定靠谱,需要通过udp包的来源地址作为nat地址 - if let session = Session(dstMac: register.srcMac, natAddress: remoteAddress) { + if let session = Session(dstMac: register.srcMac, natAddress: remoteAddress, addressType: source.convertAddressType()) { await self.sessionManager.addSession(session: session) } else { SDLLogger.log("[SDLContext] didReadRegister get unsupported remoteAddress: \(remoteAddress)", for: .debug) @@ -875,11 +884,11 @@ extension SDLContextActor { } } - private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) async { + private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck, source: UDPHoleKind) async { // 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下 let networkAddr = config.networkAddress if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId { - if let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress) { + if let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress, addressType: source.convertAddressType()) { await self.sessionManager.addSession(session: session) } else { SDLLogger.log("[SDLContext] didReadRegisterAck get unsupported remoteAddress: \(remoteAddress)", for: .debug) diff --git a/Tun/Punchnet/Actors/SDLLayerPacketForwarder.swift b/Tun/Punchnet/Actors/SDLLayerPacketForwarder.swift index ddf0a4e..649f349 100644 --- a/Tun/Punchnet/Actors/SDLLayerPacketForwarder.swift +++ b/Tun/Punchnet/Actors/SDLLayerPacketForwarder.swift @@ -8,6 +8,7 @@ import Foundation struct SDLLayerPacketForwarder { + enum DeliveryPlan { case superNode(payload: Data) case peer(payload: Data, session: Session) @@ -28,8 +29,7 @@ struct SDLLayerPacketForwarder { return .superNode(payload: payload) } - let preferredSessionType = Session.AddressType(packetType: type) - if let session = await self.sessionManager.getSession(toAddress: dstMac, preferredType: preferredSessionType) { + if let session = await self.sessionManager.getSession(toAddress: dstMac) { return .peer(payload: payload, session: session) } diff --git a/Tun/Punchnet/SessionManager.swift b/Tun/Punchnet/SessionManager.swift index 613d850..e6e84d3 100644 --- a/Tun/Punchnet/SessionManager.swift +++ b/Tun/Punchnet/SessionManager.swift @@ -8,32 +8,10 @@ import Foundation import NIOCore import Darwin -struct Session: @unchecked Sendable { +struct Session { enum AddressType: String, Hashable { case v4 case v6 - - init?(socketAddress: SocketAddress) { - switch socketAddress { - case .v4: - self = .v4 - case .v6: - self = .v6 - default: - return nil - } - } - - init?(packetType: LayerPacket.PacketType) { - switch packetType { - case .arp, .ipv4: - self = .v4 - case .ipv6: - self = .v6 - default: - return nil - } - } } // 在内部的通讯的ip地址, 整数格式 @@ -46,11 +24,7 @@ struct Session: @unchecked Sendable { // 最后使用时间 var lastTimestamp: Int32 - init?(dstMac: Data, natAddress: SocketAddress) { - guard let addressType = AddressType(socketAddress: natAddress) else { - return nil - } - + init?(dstMac: Data, natAddress: SocketAddress, addressType: AddressType) { self.dstMac = dstMac self.natAddress = natAddress self.addressType = addressType @@ -68,7 +42,7 @@ actor SessionManager { // session的有效时间 private let ttl: Int32 = 10 - func getSession(toAddress: Data, preferredType: Session.AddressType? = nil) -> Session? { + func getSession(toAddress: Data) -> Session? { let timestamp = Int32(Date().timeIntervalSince1970) guard var sessions = self.sessions[toAddress] else { @@ -81,7 +55,7 @@ actor SessionManager { return nil } - guard var session = self.selectSession(in: sessions, preferredType: preferredType) else { + guard var session = self.selectSession(in: sessions) else { self.sessions[toAddress] = sessions return nil } @@ -109,13 +83,7 @@ actor SessionManager { self.sessions.removeValue(forKey: dstMac) } - private func selectSession(in sessions: [Session.AddressType: Session], preferredType: Session.AddressType?) -> Session? { - if let preferredType { - if let preferred = sessions[preferredType] { - return preferred - } - } - + private func selectSession(in sessions: [Session.AddressType: Session]) -> Session? { return sessions.values.max(by: { $0.lastTimestamp < $1.lastTimestamp }) }