fix context

This commit is contained in:
anlicheng 2026-01-30 16:41:10 +08:00
parent b5d574ea31
commit 352dff8e19

View File

@ -13,24 +13,24 @@ import NIOCore
/* /*
1. rsa的加解密逻辑 1. rsa的加解密逻辑
*/ */
public class SDLContext { actor SDLContext {
let config: SDLConfiguration nonisolated let config: SDLConfiguration
// nat // nat
var natType: SDLNATProberActor.NatType = .blocked var natType: SDLNATProberActor.NatType = .blocked
// AES // AES
let aesCipher: AESCipher nonisolated let aesCipher: AESCipher
// aes // aes
var aesKey: Data = Data() var aesKey: Data = Data()
// rsa, public_key // rsa, public_key
let rsaCipher: RSACipher nonisolated let rsaCipher: RSACipher
// //
var udpHole: SDLUDPHole? var udpHole: SDLUDPHole?
var providerAdapter: SDLTunnelProviderAdapter nonisolated let providerAdapter: SDLTunnelProviderAdapter
var puncherActor: SDLPuncherActor? var puncherActor: SDLPuncherActor?
// dnsclient // dnsclient
var dnsClient: SDLDNSClient? var dnsClient: SDLDNSClient?
@ -44,9 +44,6 @@ public class SDLContext {
private var sessionManager: SessionManager private var sessionManager: SessionManager
private var arpServer: ArpServer private var arpServer: ArpServer
// stunRequestcookie
private var lastCookie: UInt32? = 0
// //
private var monitor: SDLNetworkMonitor? private var monitor: SDLNetworkMonitor?
@ -54,9 +51,9 @@ public class SDLContext {
private var noticeClient: SDLNoticeClient? private var noticeClient: SDLNoticeClient?
// //
private var flowTracer = SDLFlowTracer() nonisolated private let flowTracer = SDLFlowTracer()
private let logger: SDLLogger nonisolated private let logger: SDLLogger
public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) { public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) {
self.logger = logger self.logger = logger
@ -105,24 +102,22 @@ public class SDLContext {
// event // event
group.addTask { group.addTask {
if let eventStream = self.udpHole?.eventStream { if let eventStream = await self.udpHole?.eventStream {
for try await event in eventStream { for try await event in eventStream {
try Task.checkCancellation() try Task.checkCancellation()
Task {
switch event { switch event {
case .ready: case .ready:
try await self.handleUDPHoleReady() await self.handleUDPHoleReady()
case .closed: case .closed:
() ()
} }
} }
} }
} }
}
// //
group.addTask { group.addTask {
if let dataStream = self.udpHole?.dataStream { if let dataStream = await self.udpHole?.dataStream {
for try await data in dataStream { for try await data in dataStream {
try Task.checkCancellation() try Task.checkCancellation()
Task { Task {
@ -134,7 +129,7 @@ public class SDLContext {
// signal // signal
group.addTask { group.addTask {
if let signalStream = self.udpHole?.signalStream { if let signalStream = await self.udpHole?.signalStream {
for try await(remoteAddress, signal) in signalStream { for try await(remoteAddress, signal) in signalStream {
try Task.checkCancellation() try Task.checkCancellation()
Task { Task {
@ -161,7 +156,7 @@ public class SDLContext {
// DNS // DNS
group.addTask { group.addTask {
if let packetFlow = self.dnsClient?.packetFlow { if let packetFlow = await self.dnsClient?.packetFlow {
for await packet in packetFlow { for await packet in packetFlow {
let nePacket = NEPacket(data: packet, protocolFamily: 2) let nePacket = NEPacket(data: packet, protocolFamily: 2)
self.providerAdapter.writePackets(packets: [nePacket]) self.providerAdapter.writePackets(packets: [nePacket])
@ -171,12 +166,12 @@ public class SDLContext {
// Monitor // Monitor
group.addTask { group.addTask {
for await event in self.monitor!.eventStream { for await event in await self.monitor!.eventStream {
switch event { switch event {
case .changed: case .changed:
// nat // nat
//self.natType = await self.getNatType() //self.natType = await self.getNatType()
self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info) self.logger.log("didNetworkPathChanged, nat type is: \(await self.natType)", level: .info)
case .unreachable: case .unreachable:
self.logger.log("didNetworkPathUnreachable", level: .warning) self.logger.log("didNetworkPathUnreachable", level: .warning)
} }
@ -197,18 +192,22 @@ public class SDLContext {
self.readTask?.cancel() self.readTask?.cancel()
} }
private func handleUDPHoleReady() async throws { private func setNatType(natType: SDLNATProberActor.NatType) {
if let udpHole = self.udpHole { self.natType = natType
self.puncherActor = SDLPuncherActor(udpHole: udpHole, querySocketAddress: config.stunSocketAddress, logger: logger)
} }
await withTaskGroup(of: Void.self) { group in private func handleUDPHoleReady() async {
if let udpHole = self.udpHole {
self.puncherActor = SDLPuncherActor(udpHole: udpHole, querySocketAddress: config.stunSocketAddress, logger: logger)
self.proberActor = SDLNATProberActor(udpHole: udpHole, addressArray: self.config.stunProbeSocketAddressArray, logger: self.logger)
}
await withDiscardingTaskGroup { group in
group.addTask { group.addTask {
// nat // nat
if let udpHoleActor = self.udpHole { if let natType = await self.proberActor?.probeNatType() {
self.proberActor = SDLNATProberActor(udpHole: udpHoleActor, addressArray: self.config.stunProbeSocketAddressArray, logger: self.logger) await self.setNatType(natType: natType)
self.natType = await self.proberActor!.probeNatType() self.logger.log("[SDLContext] nat_type is: \(natType)")
self.logger.log("[SDLContext] nat_type is: \(self.natType)")
} }
} }
@ -226,7 +225,7 @@ public class SDLContext {
if let registerSuperData = try? registerSuper.serializedData() { if let registerSuperData = try? registerSuper.serializedData() {
self.logger.log("[SDLContext] will send register super") self.logger.log("[SDLContext] will send register super")
self.udpHole?.send(type: .registerSuper, data: registerSuperData, remoteAddress: self.config.stunSocketAddress) await self.udpHole?.send(type: .registerSuper, data: registerSuperData, remoteAddress: self.config.stunSocketAddress)
} }
} }
} }
@ -265,7 +264,7 @@ public class SDLContext {
self.aesKey = aesKey self.aesKey = aesKey
} }
private func handleRegisterSuperNak(nakPacket: SDLRegisterSuperNak) async { private func handleRegisterSuperNak(nakPacket: SDLRegisterSuperNak) {
let errorMessage = nakPacket.errorMessage let errorMessage = nakPacket.errorMessage
guard let errorCode = SDLNAKErrorCode(rawValue: UInt8(nakPacket.errorCode)) else { guard let errorCode = SDLNAKErrorCode(rawValue: UInt8(nakPacket.errorCode)) else {
return return
@ -284,7 +283,7 @@ public class SDLContext {
} }
private func handleEvent(event: SDLEvent) async throws { private func handleEvent(event: SDLEvent) throws {
switch event { switch event {
case .natChanged(let natChangedEvent): case .natChanged(let natChangedEvent):
let dstMac = natChangedEvent.mac let dstMac = natChangedEvent.mac
@ -309,7 +308,7 @@ public class SDLContext {
} }
} }
private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) async throws { private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws {
let networkAddr = config.networkAddress let networkAddr = config.networkAddress
self.logger.log("register packet: \(register), network_address: \(networkAddr)", level: .debug) self.logger.log("register packet: \(register), network_address: \(networkAddr)", level: .debug)
@ -330,7 +329,7 @@ public class SDLContext {
} }
} }
private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) async { private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) {
// tun, // tun,
let networkAddr = config.networkAddress let networkAddr = config.networkAddress
if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId { if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId {
@ -341,7 +340,7 @@ public class SDLContext {
} }
} }
private func handleData(data: SDLData) async throws { private func handleData(data: SDLData) throws {
let mac = LayerPacket.MacAddress(data: data.dstMac) let mac = LayerPacket.MacAddress(data: data.dstMac)
let networkAddr = config.networkAddress let networkAddr = config.networkAddress
@ -426,7 +425,7 @@ public class SDLContext {
await withDiscardingTaskGroup() { group in await withDiscardingTaskGroup() { group in
for packet in chunkPackets { for packet in chunkPackets {
group.addTask { group.addTask {
self.dealPacket(packet: packet) await self.dealPacket(packet: packet)
} }
} }
} }