解决定时器的问题

This commit is contained in:
anlicheng 2026-02-03 16:07:30 +08:00
parent d964eb6e27
commit 55ea1cd09d
6 changed files with 205 additions and 160 deletions

View File

@ -13,7 +13,7 @@ enum TunnelError: Error {
} }
class PacketTunnelProvider: NEPacketTunnelProvider { class PacketTunnelProvider: NEPacketTunnelProvider {
var contextSupervisor: SDLContextSupervisor? var contextActor: SDLContextActor?
private var rootTask: Task<Void, Error>? private var rootTask: Task<Void, Error>?
override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) { override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) {
@ -24,7 +24,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
} }
// //
guard self.contextSupervisor == nil else { guard self.contextActor == nil else {
completionHandler(TunnelError.invalidContext) completionHandler(TunnelError.invalidContext)
return return
} }
@ -35,23 +35,24 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
let logger = SDLLogger(level: .debug) let logger = SDLLogger(level: .debug)
self.rootTask = Task { self.rootTask = Task {
self.contextSupervisor = SDLContextSupervisor() self.contextActor = SDLContextActor(provider: self, config: config, rsaCipher: rsaCipher, aesCipher: aesChiper, logger: logger)
await self.contextSupervisor?.start(provider: self, config: config, rsaCipher: rsaCipher, aesCipher: aesChiper, logger: logger) await self.contextActor?.start()
completionHandler(nil) completionHandler(nil)
} }
} }
override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) { override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) {
// Add code here to start the process of stopping the tunnel. // Add code here to start the process of stopping the tunnel.
self.rootTask?.cancel()
Task { Task {
await self.contextSupervisor?.stop() await self.contextActor?.stop()
} self.contextActor = nil
self.contextSupervisor = nil
self.rootTask?.cancel()
self.rootTask = nil self.rootTask = nil
completionHandler() completionHandler()
} }
}
override func handleAppMessage(_ messageData: Data, completionHandler: ((Data?) -> Void)?) { override func handleAppMessage(_ messageData: Data, completionHandler: ((Data?) -> Void)?) {
// Add code here to handle the message. // Add code here to handle the message.

View File

@ -1,44 +0,0 @@
//
// SDLContextSupervisor.swift
// Tun
//
// Created by on 2026/2/2.
//
import Foundation
import NetworkExtension
actor SDLContextSupervisor {
private var context: SDLContextActor?
private var tasks: [Task<Void, Never>] = []
public func start(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) async {
let context = SDLContextActor(provider: provider, config: config, rsaCipher: rsaCipher, aesCipher: aesCipher, logger: logger)
self.context = context
tasks.append(spawnLoop { try await context.startNoticeClient()})
tasks.append(spawnLoop { try await context.startUDPHole()})
tasks.append(spawnLoop { try await context.startDnsClient()})
tasks.append(spawnLoop { try await context.startMonitor()})
}
func stop() {
tasks.forEach {$0.cancel()}
tasks.removeAll()
}
private func spawnLoop(_ body: @escaping () async throws -> Void) -> Task<Void, Never> {
return Task.detached {
while !Task.isCancelled {
do {
try await body()
} catch is CancellationError {
break
} catch {
try? await Task.sleep(nanoseconds: 2_000_000_000)
}
}
}
}
}

View File

@ -0,0 +1,29 @@
//
// SDLAsyncTimerStream.swift
// Tun
//
// Created by on 2026/2/3.
//
import Foundation
class SDLAsyncTimerStream {
let timer: DispatchSourceTimer
init() {
self.timer = DispatchSource.makeTimerSource(queue: .global())
}
func start(_ cont: AsyncStream<Void>.Continuation) {
timer.schedule(deadline: .now(), repeating: .seconds(5))
timer.setEventHandler {
cont.yield()
}
timer.resume()
}
deinit {
self.timer.cancel()
}
}

View File

@ -29,11 +29,14 @@ actor SDLContextActor {
nonisolated let rsaCipher: RSACipher nonisolated let rsaCipher: RSACipher
// //
var udpHole: SDLUDPHole? private var udpHole: SDLUDPHole?
private var udpHoleWorkers: [Task<Void, Never>]?
nonisolated let providerAdapter: SDLTunnelProviderAdapter nonisolated let providerAdapter: SDLTunnelProviderAdapter
var puncherActor: SDLPuncherActor? var puncherActor: SDLPuncherActor?
// dnsclient // dnsclient
var dnsClient: SDLDNSClient? private var dnsClient: SDLDNSClient?
private var dnsWorker: Task<Void, Never>?
// //
var proberActor: SDLNATProberActor? var proberActor: SDLNATProberActor?
@ -46,6 +49,7 @@ actor SDLContextActor {
// //
private var monitor: SDLNetworkMonitor? private var monitor: SDLNetworkMonitor?
private var monitorWorker: Task<Void, Never>?
// socket // socket
private var noticeClient: SDLNoticeClient? private var noticeClient: SDLNoticeClient?
@ -55,6 +59,9 @@ actor SDLContextActor {
nonisolated private let logger: SDLLogger nonisolated private let logger: SDLLogger
//
private var loopChildWorkers: [Task<Void, Never>] = []
public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) { public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) {
self.logger = logger self.logger = logger
self.config = config self.config = config
@ -66,20 +73,48 @@ actor SDLContextActor {
self.providerAdapter = SDLTunnelProviderAdapter(provider: provider, logger: logger) self.providerAdapter = SDLTunnelProviderAdapter(provider: provider, logger: logger)
} }
public func startNoticeClient() async throws { public func start() {
// noticeClient self.startMonitor()
self.noticeClient = try SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger)
self.logger.log("[SDLContext] noticeClient started") self.loopChildWorkers.append(spawnLoop {
try await self.noticeClient?.waitClose() let noticeClient = try self.startNoticeClient()
try await noticeClient.waitClose()
self.logger.log("[SDLContext] noticeClient closed!!!!")
})
self.loopChildWorkers.append(spawnLoop {
let dnsClient = try await self.startDnsClient()
try await dnsClient.waitClose()
self.logger.log("[SDLContext] dns closed!!!!")
})
self.loopChildWorkers.append(spawnLoop {
let udpHole = try await self.startUDPHole()
try await udpHole.waitClose()
self.logger.log("[SDLContext] udp closed!!!!")
})
} }
public func startMonitor() async throws { private func startNoticeClient() throws -> SDLNoticeClient {
// noticeClient
let noticeClient = try SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger)
self.logger.log("[SDLContext] noticeClient started")
self.noticeClient = noticeClient
return noticeClient
}
private func startMonitor() {
self.monitorWorker?.cancel()
self.monitorWorker = nil
// monitor // monitor
let monitor = SDLNetworkMonitor() let monitor = SDLNetworkMonitor()
monitor.start() monitor.start()
self.logger.log("[SDLContext] monitor started") self.logger.log("[SDLContext] monitor started")
self.monitor = monitor self.monitor = monitor
self.monitorWorker = Task {
for await event in monitor.eventStream { for await event in monitor.eventStream {
switch event { switch event {
case .changed: case .changed:
@ -91,101 +126,120 @@ actor SDLContextActor {
} }
} }
} }
}
private func startDnsClient() async throws -> SDLDNSClient {
self.dnsWorker?.cancel()
self.dnsWorker = nil
public func startDnsClient() async throws {
// dns // dns
let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(self.config.remoteDnsServer, port: 15353) let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(self.config.remoteDnsServer, port: 15353)
let dnsClient = try await SDLDNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger) let dnsClient = try await SDLDNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger)
let channel = try dnsClient.start() try dnsClient.start()
self.logger.log("[SDLContext] dnsClient started") self.logger.log("[SDLContext] dnsClient started")
self.dnsClient = dnsClient self.dnsClient = dnsClient
self.dnsWorker = Task {
try await withThrowingTaskGroup(of: Void.self) {group in
group.addTask {
// //
for await packet in dnsClient.packetFlow { for await packet in dnsClient.packetFlow {
try Task.checkCancellation() if Task.isCancelled {
break
}
let nePacket = NEPacket(data: packet, protocolFamily: 2) let nePacket = NEPacket(data: packet, protocolFamily: 2)
self.providerAdapter.writePackets(packets: [nePacket]) self.providerAdapter.writePackets(packets: [nePacket])
} }
} }
group.addTask { return dnsClient
try await channel.closeFuture.get()
} }
try await group.next() private func startUDPHole() async throws -> SDLUDPHole {
self.logger.log("[SDLContext] taskGroup cancel") self.udpHoleWorkers?.forEach {$0.cancel()}
group.cancelAll() self.udpHoleWorkers = nil
}
}
public func startUDPHole() async throws {
// udp // udp
let udpHole = try SDLUDPHole(logger: self.logger) let udpHole = try SDLUDPHole(logger: self.logger)
let channel = try udpHole.start() try udpHole.start()
self.logger.log("[SDLContext] udpHole started") self.logger.log("[SDLContext] udpHole started")
self.udpHole = udpHole self.udpHole = udpHole
await udpHole.channelIsActived() await udpHole.channelIsActived()
await self.handleUDPHoleReady() await self.handleUDPHoleReady()
try await withThrowingTaskGroup(of: Void.self) { group in //
group.addTask { let pingTask = Task.detached {
try await channel.closeFuture.get() let (stream, cont) = AsyncStream.makeStream(of: Void.self)
let timerStream = SDLAsyncTimerStream()
timerStream.start(cont)
for await _ in stream {
if Task.isCancelled {
break
}
self.logger.log("[SDLContext] will do stunRequest22")
await self.sendStunRequest()
self.logger.log("[SDLContext] will do stunRequest44")
} }
// UDP self.logger.log("[SDLContext] will do stunRequest55")
group.addTask {
while true {
try Task.checkCancellation()
try await Task.sleep(for: .seconds(5))
await self.sendStunRequest()
}
} }
// //
group.addTask { let dataTask = Task {
for try await data in udpHole.dataStream { for await data in udpHole.dataStream {
try Task.checkCancellation() if Task.isCancelled {
try await self.handleData(data: data) break
}
try? self.handleData(data: data)
} }
} }
// signal //
group.addTask { let signalTask = Task {
for try await(remoteAddress, signal) in udpHole.signalStream { for await(remoteAddress, signal) in udpHole.signalStream {
try Task.checkCancellation() if Task.isCancelled {
break
}
switch signal { switch signal {
case .registerSuperAck(let registerSuperAck): case .registerSuperAck(let registerSuperAck):
await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck) await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
case .registerSuperNak(let registerSuperNak): case .registerSuperNak(let registerSuperNak):
await self.handleRegisterSuperNak(nakPacket: registerSuperNak) self.handleRegisterSuperNak(nakPacket: registerSuperNak)
case .peerInfo(let peerInfo): case .peerInfo(let peerInfo):
await self.puncherActor?.handlePeerInfo(peerInfo: peerInfo) await self.puncherActor?.handlePeerInfo(peerInfo: peerInfo)
case .event(let event): case .event(let event):
try await self.handleEvent(event: event) try? self.handleEvent(event: event)
case .stunProbeReply(let probeReply): case .stunProbeReply(let probeReply):
await self.proberActor?.handleProbeReply(reply: probeReply) await self.proberActor?.handleProbeReply(reply: probeReply)
case .register(let register): case .register(let register):
try await self.handleRegister(remoteAddress: remoteAddress, register: register) try? self.handleRegister(remoteAddress: remoteAddress, register: register)
case .registerAck(let registerAck): case .registerAck(let registerAck):
await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
} }
} }
} }
try await group.next() self.udpHoleWorkers = [pingTask, dataTask, signalTask]
group.cancelAll()
self.logger.log("[SDLContext] taskGroup cancel") return udpHole
}
} }
// context
public func stop() async { public func stop() async {
self.udpHole = nil self.loopChildWorkers.forEach { $0.cancel() }
self.noticeClient = nil self.loopChildWorkers.removeAll()
self.udpHoleWorkers?.forEach { $0.cancel() }
self.udpHoleWorkers = nil
self.dnsWorker?.cancel()
self.dnsWorker = nil
self.monitorWorker?.cancel()
self.monitorWorker = nil
self.readTask?.cancel() self.readTask?.cancel()
self.readTask = nil
} }
private func setNatType(natType: SDLNATProberActor.NatType) { private func setNatType(natType: SDLNATProberActor.NatType) {
@ -227,7 +281,7 @@ actor SDLContextActor {
} }
} }
private func sendStunRequest() async { private func sendStunRequest() {
var stunRequest = SDLStunRequest() var stunRequest = SDLStunRequest()
stunRequest.clientID = self.config.clientId stunRequest.clientID = self.config.clientId
stunRequest.networkID = self.config.networkAddress.networkId stunRequest.networkID = self.config.networkAddress.networkId
@ -418,21 +472,8 @@ actor SDLContextActor {
let packets = await self.providerAdapter.readPackets() let packets = await self.providerAdapter.readPackets()
let ipPackets = packets.compactMap { IPPacket($0) } let ipPackets = packets.compactMap { IPPacket($0) }
await self.batchProcessPackets(batchSize: 20, packets: ipPackets) for ipPacket in ipPackets {
} self.dealPacket(packet: ipPacket)
}
}
// ip
private func batchProcessPackets(batchSize: Int, packets: [IPPacket]) async {
for startIndex in stride(from: 0, to: packets.count, by: batchSize) {
let endIndex = Swift.min(startIndex + batchSize, packets.count)
let chunkPackets = packets[startIndex..<endIndex]
await withDiscardingTaskGroup() { group in
for packet in chunkPackets {
group.addTask {
await self.dealPacket(packet: packet)
}
} }
} }
} }
@ -508,6 +549,20 @@ actor SDLContextActor {
} }
} }
private func spawnLoop(_ body: @escaping () async throws -> Void) -> Task<Void, Never> {
return Task.detached {
while !Task.isCancelled {
do {
try await body()
} catch is CancellationError {
break
} catch {
try? await Task.sleep(nanoseconds: 2_000_000_000)
}
}
}
}
deinit { deinit {
self.udpHole = nil self.udpHole = nil
self.dnsClient = nil self.dnsClient = nil

View File

@ -29,7 +29,7 @@ final class SDLDNSClient: ChannelInboundHandler {
(self.packetFlow, self.packetContinuation) = AsyncStream.makeStream(of: Data.self, bufferingPolicy: .unbounded) (self.packetFlow, self.packetContinuation) = AsyncStream.makeStream(of: Data.self, bufferingPolicy: .unbounded)
} }
func start() throws -> Channel { func start() throws {
let bootstrap = DatagramBootstrap(group: group) let bootstrap = DatagramBootstrap(group: group)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.channelInitializer { channel in .channelInitializer { channel in
@ -39,8 +39,10 @@ final class SDLDNSClient: ChannelInboundHandler {
let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait()
self.logger.log("[DNSClient] started", level: .debug) self.logger.log("[DNSClient] started", level: .debug)
self.channel = channel self.channel = channel
}
return channel func waitClose() async throws {
try await self.channel?.closeFuture.get()
} }
// --MARK: ChannelInboundHandler delegate // --MARK: ChannelInboundHandler delegate

View File

@ -49,7 +49,7 @@ final class SDLUDPHole: ChannelInboundHandler {
(self.dataStream, self.dataContinuation) = AsyncStream.makeStream(of: SDLData.self, bufferingPolicy: .unbounded) (self.dataStream, self.dataContinuation) = AsyncStream.makeStream(of: SDLData.self, bufferingPolicy: .unbounded)
} }
func start() throws -> Channel { func start() throws {
let bootstrap = DatagramBootstrap(group: group) let bootstrap = DatagramBootstrap(group: group)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.channelInitializer { channel in .channelInitializer { channel in
@ -59,8 +59,6 @@ final class SDLUDPHole: ChannelInboundHandler {
let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait()
self.logger.log("[UDPHole] started", level: .debug) self.logger.log("[UDPHole] started", level: .debug)
self.channel = channel self.channel = channel
return channel
} }
func channelIsActived() async { func channelIsActived() async {
@ -73,6 +71,10 @@ final class SDLUDPHole: ChannelInboundHandler {
} }
} }
func waitClose() async throws {
try await self.channel?.closeFuture.get()
}
// --MARK: ChannelInboundHandler delegate // --MARK: ChannelInboundHandler delegate
func channelActive(context: ChannelHandlerContext) { func channelActive(context: ChannelHandlerContext) {