diff --git a/Tun/Punchnet/Actors/SDLPuncherActor.swift b/Tun/Punchnet/Actors/SDLPuncherActor.swift index 7118f2a..7141f5d 100644 --- a/Tun/Punchnet/Actors/SDLPuncherActor.swift +++ b/Tun/Punchnet/Actors/SDLPuncherActor.swift @@ -61,7 +61,7 @@ actor SDLPuncherActor { private func tryHole(request: RegisterRequest) async { var queryInfo = SDLQueryInfo() queryInfo.dstMac = request.dstMac - guard let message = try? await self.superClientActor?.request(type: .queryInfo, data: try queryInfo.serializedData()).get() else { + guard let message = try? await self.superClientActor?.request(type: .queryInfo, data: try queryInfo.serializedData()) else { return } diff --git a/Tun/Punchnet/Actors/SDLSuperClientActor.swift b/Tun/Punchnet/Actors/SDLSuperClientActor.swift index 692c66e..1f9ec44 100644 --- a/Tun/Punchnet/Actors/SDLSuperClientActor.swift +++ b/Tun/Punchnet/Actors/SDLSuperClientActor.swift @@ -17,8 +17,8 @@ actor SDLSuperClientActor { private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) private let asyncChannel: NIOAsyncChannel private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: TcpMessage.self, bufferingPolicy: .unbounded) - private var callbackPromises: [UInt32:EventLoopPromise] = [:] - + private var continuations: [UInt32:CheckedContinuation] = [:] + public let eventFlow: AsyncStream private let inboundContinuation: AsyncStream.Continuation @@ -34,6 +34,12 @@ actor SDLSuperClientActor { case command(UInt32, SDLCommand) } + enum SDLSuperClientError: Error { + case timeout + case connectionClosed + case cancelled + } + init(host: String, port: Int, logger: SDLLogger) async throws { self.logger = logger @@ -139,14 +145,17 @@ actor SDLSuperClientActor { self.send(type: .ping, packetId: 0, data: Data()) } - func request(type: SDLPacketType, data: Data) -> EventLoopFuture { + func request(type: SDLPacketType, data: Data, timeout: Duration = .seconds(5)) async throws -> SDLSuperInboundMessage { let packetId = idGenerator.nextId() - let promise = self.asyncChannel.channel.eventLoop.makePromise(of: SDLSuperInboundMessage.self) - self.callbackPromises[packetId] = promise - self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data)) - - return promise.futureResult + return try await withCheckedThrowingContinuation { cont in + self.continuations[packetId] = cont + self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data)) + Task { + try? await Task.sleep(for: timeout) + self.timeout(packetId: packetId) + } + } } func send(type: SDLPacketType, packetId: UInt32, data: Data) { @@ -155,12 +164,17 @@ actor SDLSuperClientActor { // 处理回调函数 private func fireCallback(message: SDLSuperInboundMessage) { - if let promise = self.callbackPromises[message.msgId] { - self.asyncChannel.channel.eventLoop.execute { - promise.succeed(message) - } - self.callbackPromises.removeValue(forKey: message.msgId) + guard let cont = self.continuations.removeValue(forKey: message.msgId) else { + return } + cont.resume(returning: message) + } + + private func timeout(packetId: UInt32) { + guard let cont = self.continuations.removeValue(forKey: packetId) else { + return + } + cont.resume(throwing: SDLSuperClientError.timeout) } deinit { diff --git a/Tun/Punchnet/SDLContext.swift b/Tun/Punchnet/SDLContext.swift index 1ffbeae..398378c 100644 --- a/Tun/Punchnet/SDLContext.swift +++ b/Tun/Punchnet/SDLContext.swift @@ -298,7 +298,7 @@ public class SDLContext { registerSuper.token = self.config.token registerSuper.networkCode = self.config.networkCode registerSuper.hostname = self.config.hostname - guard let message = try await self.superClientActor?.request(type: .registerSuper, data: try registerSuper.serializedData()).get() else { + guard let message = try await self.superClientActor?.request(type: .registerSuper, data: try registerSuper.serializedData()) else { return }