diff --git a/Sources/Punchnet/SDLContext.swift b/Sources/Punchnet/SDLContext.swift index 1aff3ac..cc9fe7e 100644 --- a/Sources/Punchnet/SDLContext.swift +++ b/Sources/Punchnet/SDLContext.swift @@ -62,12 +62,8 @@ public class SDLContext: @unchecked Sendable { // 记录最后发送的stunRequest的cookie private var lastCookie: UInt32? = 0 - // 定时器 - private var stunCancel: AnyCancellable? - // 网络状态变化的健康 - private var monitor = SDLNetworkMonitor() - private var monitorCancel: AnyCancellable? + private var monitor: SDLNetworkMonitor? // 内部socket通讯 private var noticeClient: SDLNoticeClient? @@ -83,6 +79,8 @@ public class SDLContext: @unchecked Sendable { private let logger: SDLLogger + private var rootTask: Task? + struct RegisterRequest { let srcMac: Data let dstMac: Data @@ -107,59 +105,68 @@ public class SDLContext: @unchecked Sendable { } public func start() async throws { - self.noticeClient = try await SDLNoticeClient(logger: self.logger) - - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - while !Task.isCancelled { - do { - try await self.startUDPHole() - } catch let err { - self.logger.log("[SDLContext] UDPHole get err: \(err)", level: .warning) + self.rootTask = Task { + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + while !Task.isCancelled { + do { + try await self.startUDPHole() + } catch let err { + self.logger.log("[SDLContext] UDPHole get err: \(err)", level: .warning) + } } } - } - - group.addTask { - while !Task.isCancelled { - do { - try await self.startSuperClient() - } catch let err { - self.logger.log("[SDLContext] SuperClient get error: \(err), will restart", level: .warning) - await self.arpServer.clear() - try? await Task.sleep(for: .seconds(2)) + + group.addTask { + while !Task.isCancelled { + do { + try await self.startSuperClient() + } catch let err { + self.logger.log("[SDLContext] SuperClient get error: \(err), will restart", level: .warning) + await self.arpServer.clear() + try? await Task.sleep(for: .seconds(2)) + } } } - } - - group.addTask { - try await self.startMonitor() - } - - group.addTask { - while !Task.isCancelled { - do { - try await self.noticeClient?.start() - } catch let err { - self.logger.log("[SDLContext] noticeClient get err: \(err)", level: .warning) + + group.addTask { + await self.startMonitor() + } + + group.addTask { + while !Task.isCancelled { + do { + try await self.startNoticeClient() + } catch let err { + self.logger.log("[SDLContext] noticeClient get err: \(err)", level: .warning) + } } } + + try await group.waitForAll() } - - try await group.waitForAll() } + + try await self.rootTask?.value } public func stop() async { + self.rootTask?.cancel() self.superClient = nil self.udpHole = nil - + self.noticeClient = nil + self.readTask?.cancel() } + private func startNoticeClient() async throws { + self.noticeClient = try await SDLNoticeClient(logger: self.logger) + try await self.noticeClient?.start() + self.logger.log("[SDLContext] notice_client task cancel", level: .warning) + } + private func startUDPHole() async throws { self.udpHole = try await SDLUDPHole(logger: self.logger) - try await withThrowingTaskGroup(of: Void.self) { group in group.addTask { try await self.udpHole?.start() @@ -180,12 +187,13 @@ public class SDLContext: @unchecked Sendable { } } try await group.waitForAll() + self.logger.log("[SDLContext] udp_hole task cancel", level: .warning) } + } private func startSuperClient() async throws { self.superClient = try await SDLSuperClient(host: self.config.superHost, port: self.config.superPort, logger: self.logger) - try await withThrowingTaskGroup(of: Void.self) { group in group.addTask { try await self.superClient?.start() @@ -199,33 +207,21 @@ public class SDLContext: @unchecked Sendable { } } try await group.waitForAll() + self.logger.log("[SDLContext] super client task cancel", level: .warning) } } - private func startMonitor() async throws { - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - try await self.noticeClient?.start() + private func startMonitor() async { + self.monitor = SDLNetworkMonitor() + for await event in self.monitor!.eventStream { + switch event { + case .changed: + // 需要重新探测网络的nat类型 + self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config, logger: self.logger) + self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info) + case .unreachable: + self.logger.log("didNetworkPathUnreachable", level: .warning) } - - group.addTask { - // 启动网络监控 - self.monitorCancel = self.monitor.eventFlow.sink { event in - switch event { - case .changed: - // 需要重新探测网络的nat类型 - Task { - self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config, logger: self.logger) - self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info) - } - case .unreachable: - self.logger.log("didNetworkPathUnreachable", level: .warning) - } - } - self.monitor.start() - } - - try await group.waitForAll() } } @@ -579,7 +575,7 @@ public class SDLContext: @unchecked Sendable { } deinit { - self.stunCancel?.cancel() + self.rootTask?.cancel() self.udpHole = nil self.superClient = nil } diff --git a/Sources/Punchnet/SDLNetworkMonitor.swift b/Sources/Punchnet/SDLNetworkMonitor.swift index 840429e..5a76bee 100644 --- a/Sources/Punchnet/SDLNetworkMonitor.swift +++ b/Sources/Punchnet/SDLNetworkMonitor.swift @@ -15,9 +15,9 @@ class SDLNetworkMonitor: @unchecked Sendable { private var interfaceType: NWInterface.InterfaceType? private let publisher = PassthroughSubject() private var cancel: AnyCancellable? - private let queue = DispatchQueue(label: "networkMonitorQueue") - public let eventFlow = PassthroughSubject() + public let eventStream: AsyncStream + private let eventContinuation: AsyncStream.Continuation enum MonitorEvent { case changed @@ -26,6 +26,7 @@ class SDLNetworkMonitor: @unchecked Sendable { init() { self.monitor = NWPathMonitor() + (self.eventStream , self.eventContinuation) = AsyncStream.makeStream(of: MonitorEvent.self, bufferingPolicy: .unbounded) } func start() { @@ -39,16 +40,16 @@ class SDLNetworkMonitor: @unchecked Sendable { self.publisher.send(.wiredEthernet) } } else { - self.eventFlow.send(.unreachable) + self.eventContinuation.yield(.unreachable) self.interfaceType = nil } } - self.monitor.start(queue: self.queue) + self.monitor.start(queue: DispatchQueue.global()) - self.cancel = publisher.throttle(for: 5.0, scheduler: self.queue, latest: true) + self.cancel = publisher.throttle(for: 5.0, scheduler: DispatchQueue.global(), latest: true) .sink { type in if self.interfaceType != nil && self.interfaceType != type { - self.eventFlow.send(.changed) + self.eventContinuation.yield(.changed) } self.interfaceType = type } @@ -57,6 +58,7 @@ class SDLNetworkMonitor: @unchecked Sendable { deinit { self.monitor.cancel() self.cancel?.cancel() + self.eventContinuation.finish() } } diff --git a/Sources/Punchnet/SDLSuperClient.swift b/Sources/Punchnet/SDLSuperClient.swift index fa48c27..63ad651 100644 --- a/Sources/Punchnet/SDLSuperClient.swift +++ b/Sources/Punchnet/SDLSuperClient.swift @@ -75,6 +75,10 @@ actor SDLSuperClient { } for try await var packet in inbound { + if Task.isCancelled { + break + } + if let message = SDLSuperClientDecoder.decode(buffer: &packet) { self.logger.log("[SDLSuperTransport] read message: \(message)", level: .debug) switch message.packet { @@ -96,6 +100,10 @@ actor SDLSuperClient { } for try await message in self.writeStream { + if Task.isCancelled { + break + } + var buffer = self.asyncChannel.channel.allocator.buffer(capacity: message.data.count + 5) buffer.writeInteger(message.packetId, as: UInt32.self) buffer.writeBytes([message.type.rawValue]) @@ -107,7 +115,7 @@ actor SDLSuperClient { // --MARK: 心跳机制 group.addTask { - while true { + while !Task.isCancelled { do { await self.ping() try await Task.sleep(nanoseconds: 5 * 1_000_000_000) @@ -126,6 +134,7 @@ actor SDLSuperClient { self.logger.log("[SDLSuperClient] group closed", level: .warning) } } + } // -- MARK: apis diff --git a/Sources/Punchnet/SDLUDPHole.swift b/Sources/Punchnet/SDLUDPHole.swift index 10426fc..143a660 100644 --- a/Sources/Punchnet/SDLUDPHole.swift +++ b/Sources/Punchnet/SDLUDPHole.swift @@ -58,9 +58,8 @@ actor SDLUDPHole { } func start() async throws { - try await self.asyncChannel.executeThenClose { inbound, outbound in + try await self.asyncChannel.executeThenClose {inbound, outbound in self.eventContinuation.yield(.ready) - try await withThrowingTaskGroup(of: Void.self) { group in group.addTask { try await self.asyncChannel.channel.closeFuture.get() @@ -96,6 +95,10 @@ actor SDLUDPHole { group.addTask { for try await message in self.writeStream { + if Task.isCancelled { + break + } + var buffer = self.asyncChannel.channel.allocator.buffer(capacity: message.data.count + 1) buffer.writeBytes([message.type.rawValue]) buffer.writeBytes(message.data) @@ -266,4 +269,5 @@ actor SDLUDPHole { self.writeContinuation.finish() self.eventContinuation.finish() } + }