diff --git a/Tun/Punchnet/SDLIPV6AssistClient.swift b/Tun/Punchnet/SDLIPV6AssistClient.swift index ddc8f3f..1d50cf5 100644 --- a/Tun/Punchnet/SDLIPV6AssistClient.swift +++ b/Tun/Punchnet/SDLIPV6AssistClient.swift @@ -9,9 +9,14 @@ import Network enum SDLIPV6AssistError: Error { case lostConnection + case requestTimeout } actor SDLIPV6AssistClient { + private struct PendingRequest { + let continuation: CheckedContinuation + let timeoutTask: Task + } private enum State { case idle @@ -25,8 +30,8 @@ actor SDLIPV6AssistClient { private let assistServerAddress: NWEndpoint private var packetId: UInt32 = 1 - private var pendingRequests: [UInt32: CheckedContinuation] = [:] - + private var pendingRequests: [UInt32: PendingRequest] = [:] + // 用来处理关闭事件 private let closeStream: AsyncStream private let closeContinuation: AsyncStream.Continuation @@ -104,40 +109,52 @@ actor SDLIPV6AssistClient { } } - func probe() async throws -> SDLV6AssistProbeReply { + func probe(requestTimeout: Duration = .seconds(5)) async throws -> SDLV6AssistProbeReply { guard case .running = self.state, let connection = self.connection, connection.state == .ready else { throw SDLIPV6AssistError.lostConnection } + let pktId = self.nextPacketId() + let requestTimeout = self.requestTimeout + var assistProbe = SDLV6AssistProbe() + assistProbe.pktID = pktId + let data = try assistProbe.serializedData() + return try await withCheckedThrowingContinuation { cont in - let pktId = self.nextPacketId() - var assistProbe = SDLV6AssistProbe() - assistProbe.pktID = pktId - - do { - let data = try assistProbe.serializedData() - connection.send(content: data, completion: .contentProcessed { error in - if let error { - Task { - await self.handleProcesseError(packetId: pktId, error: error) - } - } - }) - self.pendingRequests[pktId] = cont - } catch let err { - cont.resume(throwing: err) + let timeoutTask = Task { [weak self] in + try? await Task.sleep(for: requestTimeout) + await self?.handleRequestTimeout(packetId: pktId) } + + self.pendingRequests[pktId] = .init(continuation: cont, timeoutTask: timeoutTask) + connection.send(content: data, completion: .contentProcessed { error in + if let error { + Task { + await self.handleProcessError(packetId: pktId, error: error) + } + } + }) } } - private func handleProcesseError(packetId: UInt32, error: NWError) { - if let cont = self.pendingRequests.removeValue(forKey: packetId) { - cont.resume(throwing: error) + private func handleProcessError(packetId: UInt32, error: NWError) { + if let request = self.takePendingRequest(packetId: packetId) { + request.continuation.resume(throwing: error) + } + } + + private func handleRequestTimeout(packetId: UInt32) { + if let request = self.takePendingRequest(packetId: packetId) { + request.continuation.resume(throwing: SDLIPV6AssistError.requestTimeout) } } func stop() { + self.stop(pendingError: SDLIPV6AssistError.lostConnection) + } + + private func stop(pendingError: any Error) { guard self.state != .stopped else { return } @@ -147,6 +164,7 @@ actor SDLIPV6AssistClient { self.receiveTask = nil self.connection?.cancel() self.connection = nil + self.failAllPendingRequests(error: pendingError) self.finishCloseStreamIfNeeded() } @@ -161,7 +179,7 @@ actor SDLIPV6AssistClient { self.startReceiveTask(for: connection) case .failed(let error): SDLLogger.log("[SDLIPV6AssistClient] Connection failed: \(error)", for: .debug) - self.stop() + self.stop(pendingError: error) case .cancelled: self.stop() default: @@ -191,8 +209,8 @@ actor SDLIPV6AssistClient { do { let packet = try SDLV6AssistProbeReply(serializedBytes: data) let pktId = packet.pktID - if let cont = self.pendingRequests.removeValue(forKey: pktId) { - cont.resume(returning: packet) + if let request = self.takePendingRequest(packetId: pktId) { + request.continuation.resume(returning: packet) } } catch { SDLLogger.log("[SDLIPV6AssistClient] Receive error: \(error)", for: .debug) @@ -227,6 +245,25 @@ actor SDLIPV6AssistClient { return packetId } + private func takePendingRequest(packetId: UInt32) -> PendingRequest? { + guard let request = self.pendingRequests.removeValue(forKey: packetId) else { + return nil + } + + request.timeoutTask.cancel() + return request + } + + private func failAllPendingRequests(error: any Error) { + let pendingRequests = self.pendingRequests + self.pendingRequests.removeAll() + + pendingRequests.values.forEach { request in + request.timeoutTask.cancel() + request.continuation.resume(throwing: error) + } + } + deinit { self.connection?.cancel() }