From 352dff8e19ba7edf562c3b1e5a7595552b498e5f Mon Sep 17 00:00:00 2001 From: anlicheng <244108715@qq.com> Date: Fri, 30 Jan 2026 16:41:10 +0800 Subject: [PATCH] fix context --- Tun/Punchnet/SDLContext.swift | 71 +++++++++++++++++------------------ 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/Tun/Punchnet/SDLContext.swift b/Tun/Punchnet/SDLContext.swift index 93e444f..d87298d 100644 --- a/Tun/Punchnet/SDLContext.swift +++ b/Tun/Punchnet/SDLContext.swift @@ -13,24 +13,24 @@ import NIOCore /* 1. 处理rsa的加解密逻辑 */ -public class SDLContext { - let config: SDLConfiguration +actor SDLContext { + nonisolated let config: SDLConfiguration // nat的网络类型 var natType: SDLNATProberActor.NatType = .blocked // AES加密,授权通过后,对象才会被创建 - let aesCipher: AESCipher + nonisolated let aesCipher: AESCipher // aes var aesKey: Data = Data() // rsa的相关配置, public_key是本地生成的 - let rsaCipher: RSACipher + nonisolated let rsaCipher: RSACipher // 依赖的变量 var udpHole: SDLUDPHole? - var providerAdapter: SDLTunnelProviderAdapter + nonisolated let providerAdapter: SDLTunnelProviderAdapter var puncherActor: SDLPuncherActor? // dns的client对象 var dnsClient: SDLDNSClient? @@ -44,9 +44,6 @@ public class SDLContext { private var sessionManager: SessionManager private var arpServer: ArpServer - // 记录最后发送的stunRequest的cookie - private var lastCookie: UInt32? = 0 - // 网络状态变化的健康 private var monitor: SDLNetworkMonitor? @@ -54,9 +51,9 @@ public class SDLContext { private var noticeClient: SDLNoticeClient? // 流量统计 - private var flowTracer = SDLFlowTracer() + nonisolated private let flowTracer = SDLFlowTracer() - private let logger: SDLLogger + nonisolated private let logger: SDLLogger public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) { self.logger = logger @@ -105,16 +102,14 @@ public class SDLContext { // 处理event事件流 group.addTask { - if let eventStream = self.udpHole?.eventStream { + if let eventStream = await self.udpHole?.eventStream { for try await event in eventStream { try Task.checkCancellation() - Task { - switch event { - case .ready: - try await self.handleUDPHoleReady() - case .closed: - () - } + switch event { + case .ready: + await self.handleUDPHoleReady() + case .closed: + () } } } @@ -122,7 +117,7 @@ public class SDLContext { // 处理数据流 group.addTask { - if let dataStream = self.udpHole?.dataStream { + if let dataStream = await self.udpHole?.dataStream { for try await data in dataStream { try Task.checkCancellation() Task { @@ -134,7 +129,7 @@ public class SDLContext { // 处理signal信号流 group.addTask { - if let signalStream = self.udpHole?.signalStream { + if let signalStream = await self.udpHole?.signalStream { for try await(remoteAddress, signal) in signalStream { try Task.checkCancellation() Task { @@ -161,7 +156,7 @@ public class SDLContext { // 处理DNS的事件流 group.addTask { - if let packetFlow = self.dnsClient?.packetFlow { + if let packetFlow = await self.dnsClient?.packetFlow { for await packet in packetFlow { let nePacket = NEPacket(data: packet, protocolFamily: 2) self.providerAdapter.writePackets(packets: [nePacket]) @@ -171,12 +166,12 @@ public class SDLContext { // 处理Monitor的事件流 group.addTask { - for await event in self.monitor!.eventStream { + for await event in await self.monitor!.eventStream { switch event { case .changed: // 需要重新探测网络的nat类型 //self.natType = await self.getNatType() - self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info) + self.logger.log("didNetworkPathChanged, nat type is: \(await self.natType)", level: .info) case .unreachable: self.logger.log("didNetworkPathUnreachable", level: .warning) } @@ -197,18 +192,22 @@ public class SDLContext { self.readTask?.cancel() } - private func handleUDPHoleReady() async throws { + private func setNatType(natType: SDLNATProberActor.NatType) { + self.natType = natType + } + + private func handleUDPHoleReady() async { if let udpHole = self.udpHole { self.puncherActor = SDLPuncherActor(udpHole: udpHole, querySocketAddress: config.stunSocketAddress, logger: logger) + self.proberActor = SDLNATProberActor(udpHole: udpHole, addressArray: self.config.stunProbeSocketAddressArray, logger: self.logger) } - await withTaskGroup(of: Void.self) { group in + await withDiscardingTaskGroup { group in group.addTask { // 开始探测nat的类型 - if let udpHoleActor = self.udpHole { - self.proberActor = SDLNATProberActor(udpHole: udpHoleActor, addressArray: self.config.stunProbeSocketAddressArray, logger: self.logger) - self.natType = await self.proberActor!.probeNatType() - self.logger.log("[SDLContext] nat_type is: \(self.natType)") + if let natType = await self.proberActor?.probeNatType() { + await self.setNatType(natType: natType) + self.logger.log("[SDLContext] nat_type is: \(natType)") } } @@ -226,7 +225,7 @@ public class SDLContext { if let registerSuperData = try? registerSuper.serializedData() { self.logger.log("[SDLContext] will send register super") - self.udpHole?.send(type: .registerSuper, data: registerSuperData, remoteAddress: self.config.stunSocketAddress) + await self.udpHole?.send(type: .registerSuper, data: registerSuperData, remoteAddress: self.config.stunSocketAddress) } } } @@ -265,7 +264,7 @@ public class SDLContext { self.aesKey = aesKey } - private func handleRegisterSuperNak(nakPacket: SDLRegisterSuperNak) async { + private func handleRegisterSuperNak(nakPacket: SDLRegisterSuperNak) { let errorMessage = nakPacket.errorMessage guard let errorCode = SDLNAKErrorCode(rawValue: UInt8(nakPacket.errorCode)) else { return @@ -284,7 +283,7 @@ public class SDLContext { } - private func handleEvent(event: SDLEvent) async throws { + private func handleEvent(event: SDLEvent) throws { switch event { case .natChanged(let natChangedEvent): let dstMac = natChangedEvent.mac @@ -309,7 +308,7 @@ public class SDLContext { } } - private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) async throws { + private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws { let networkAddr = config.networkAddress self.logger.log("register packet: \(register), network_address: \(networkAddr)", level: .debug) @@ -330,7 +329,7 @@ public class SDLContext { } } - private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) async { + private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) { // 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下 let networkAddr = config.networkAddress if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId { @@ -341,7 +340,7 @@ public class SDLContext { } } - private func handleData(data: SDLData) async throws { + private func handleData(data: SDLData) throws { let mac = LayerPacket.MacAddress(data: data.dstMac) let networkAddr = config.networkAddress @@ -426,7 +425,7 @@ public class SDLContext { await withDiscardingTaskGroup() { group in for packet in chunkPackets { group.addTask { - self.dealPacket(packet: packet) + await self.dealPacket(packet: packet) } } }