diff --git a/Sources/Punchnet/SDLContext.swift b/Sources/Punchnet/SDLContext.swift index 72d0c83..e840170 100644 --- a/Sources/Punchnet/SDLContext.swift +++ b/Sources/Punchnet/SDLContext.swift @@ -219,7 +219,7 @@ public class SDLContext: @unchecked Sendable { switch event { case .ready: NSLog("[SDLContext] get registerSuper, mac address: \(Self.formatMacAddress(mac: self.devAddr.mac))") - guard let message = try await self.superClient?.registerSuper(context: self) else { + guard let message = try await self.superClient?.registerSuper(context: self).get() else { return } @@ -529,7 +529,7 @@ public class SDLContext: @unchecked Sendable { func holerTask(dstMac: Data) -> Task<(), Never> { return Task { - guard let message = try? await self.superClient?.queryInfo(context: self, dst_mac: dstMac) else { + guard let message = try? await self.superClient?.queryInfo(context: self, dst_mac: dstMac).get() else { return } diff --git a/Sources/Punchnet/SDLSuperClient.swift b/Sources/Punchnet/SDLSuperClient.swift index d779fc8..b06709e 100644 --- a/Sources/Punchnet/SDLSuperClient.swift +++ b/Sources/Punchnet/SDLSuperClient.swift @@ -117,7 +117,7 @@ actor SDLSuperClient { self.send(type: .commandAck, packetId: packetId, data: data) } - func registerSuper(context ctx: SDLContext) async throws -> SDLSuperInboundMessage { + func registerSuper(context ctx: SDLContext) throws -> EventLoopFuture { var registerSuper = SDLRegisterSuper() registerSuper.version = UInt32(ctx.config.version) registerSuper.clientID = ctx.config.clientId @@ -127,15 +127,15 @@ actor SDLSuperClient { let data = try! registerSuper.serializedData() - return try await self.write(type: .registerSuper, data: data).get() + return self.write(type: .registerSuper, data: data) } // 查询目标服务器的相关信息 - func queryInfo(context ctx: SDLContext, dst_mac: Data) async throws -> SDLSuperInboundMessage { + func queryInfo(context ctx: SDLContext, dst_mac: Data) async throws -> EventLoopFuture { var queryInfo = SDLQueryInfo() queryInfo.dstMac = dst_mac - return try await self.write(type: .queryInfo, data: try! queryInfo.serializedData()).get() + return self.write(type: .queryInfo, data: try! queryInfo.serializedData()) } func unregister(context ctx: SDLContext) throws { @@ -155,17 +155,18 @@ actor SDLSuperClient { self.send(type: .flowTracer, packetId: 0, data: try! flow.serializedData()) } - func write(type: SDLPacketType, data: Data) -> EventLoopFuture { + private func write(type: SDLPacketType, data: Data) -> EventLoopFuture { SDLLogger.log("[SDLSuperTransport] will write data: \(data)", level: .debug) let packetId = idGenerator.nextId() let promise = self.asyncChannel.channel.eventLoop.makePromise(of: SDLSuperInboundMessage.self) self.callbackPromises[packetId] = promise + self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data)) return promise.futureResult } - func send(type: SDLPacketType, packetId: UInt32, data: Data) { + private func send(type: SDLPacketType, packetId: UInt32, data: Data) { self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data)) } diff --git a/Sources/Punchnet/SDLUDPHole.swift b/Sources/Punchnet/SDLUDPHole.swift index 37e5fa0..4c8c143 100644 --- a/Sources/Punchnet/SDLUDPHole.swift +++ b/Sources/Punchnet/SDLUDPHole.swift @@ -9,12 +9,6 @@ import Foundation import NIOCore import NIOPosix -struct UDPMessage { - let remoteAddress: SocketAddress - let type: SDLPacketType - let data: Data -} - // 处理和sn-server服务器之间的通讯 actor SDLUDPHole { private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) @@ -26,6 +20,15 @@ actor SDLUDPHole { public var localAddress: SocketAddress? public let (eventFlow, eventContinuation) = AsyncStream.makeStream(of: UDPEvent.self, bufferingPolicy: .unbounded) + + // 标记当前通道是否关闭 + private var isClosed: Bool = true + + struct UDPMessage { + let remoteAddress: SocketAddress + let type: SDLPacketType + let data: Data + } // 定义事件类型 enum UDPEvent { @@ -55,10 +58,12 @@ actor SDLUDPHole { func start() async throws { try await self.asyncChannel.executeThenClose { inbound, outbound in self.eventContinuation.yield(.ready) + self.closeChannel(closed: false) try await withThrowingTaskGroup(of: Void.self) { group in group.addTask { try await self.asyncChannel.channel.closeFuture.get() + await self.closeChannel(closed: true) self.eventContinuation.finish() } @@ -101,6 +106,7 @@ actor SDLUDPHole { try await outbound.write(envelope) } } + try await group.waitForAll() } } @@ -157,6 +163,10 @@ actor SDLUDPHole { } } + private func closeChannel(closed: Bool) { + self.isClosed = closed + } + // MARK: client-client apis // 发送数据包到其他session @@ -220,6 +230,10 @@ actor SDLUDPHole { // 处理写入逻辑 private func send(remoteAddress: SocketAddress, type: SDLPacketType, data: Data) { + guard !self.isClosed else { + return + } + let message = UDPMessage(remoteAddress: remoteAddress, type: type, data: data) self.writeContinuation.yield(message) }