//
//  SDLWebsocketClient.swift
//  Tun
//
//  Created by 安礼成 on 2024/3/28.
//

import Foundation
import NIOCore
import NIOPosix

// MARK: - Client for the SuperNode

/// TCP client for the SuperNode, speaking 2-byte length-prefixed frames.
///
/// `init` opens the connection. `start()` then runs the reader, writer and heartbeat
/// loops until the connection ends or the calling task is cancelled. Server-pushed
/// events and commands are delivered on `eventFlow`; request/response exchanges go
/// through `request(type:data:timeout:)`.
actor SDLSuperClientActor {
    // An outbound message queued for the writer task: (packet id, packet type, body).
    private typealias TcpMessage = (packetId: UInt32, type: SDLPacketType, data: Data)

    private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
    private let asyncChannel: NIOAsyncChannel<ByteBuffer, ByteBuffer>
    private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: TcpMessage.self, bufferingPolicy: .unbounded)

    // Pending `request(...)` calls keyed by packet id, waiting for a response or a timeout.
    private var continuations: [UInt32: CheckedContinuation<SDLSuperInboundMessage, Error>] = [:]

    public let eventFlow: AsyncStream<SuperEvent>
    private let inboundContinuation: AsyncStream<SuperEvent>.Continuation

    // Packet id generator used by `request(...)`.
    var idGenerator = SDLIdGenerator(seed: 1)

    private let logger: SDLLogger

    /// Items published on `eventFlow`.
    enum SuperEvent {
        /// The channel is open and the I/O loops are starting.
        case ready
        /// A server-pushed event.
        case event(SDLEvent)
        /// A server-pushed command together with its packet id.
        case command(UInt32, SDLCommand)
    }

    enum SuperClientError: Error {
        case timeout
        case connectionClosed
        case cancelled
    }

    /// Connects to the SuperNode at `host:port`.
    /// - Throws: any connection error from the bootstrap; the event loop group is shut down first.
    init(host: String, port: Int, logger: SDLLogger) async throws {
        self.logger = logger
        (self.eventFlow, self.inboundContinuation) = AsyncStream.makeStream(of: SuperEvent.self, bufferingPolicy: .unbounded)

        let bootstrap = ClientBootstrap(group: self.group)
            .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
            .channelInitializer { channel in
                return channel.pipeline.addHandlers([
                    ByteToMessageHandler(FixedHeaderDecoder()),
                    MessageToByteHandler(FixedHeaderEncoder())
                ])
            }

        do {
            self.asyncChannel = try await bootstrap.connect(host: host, port: port)
                .flatMapThrowing { channel in
                    return try NIOAsyncChannel(wrappingChannelSynchronously: channel, configuration: .init(
                        inboundType: ByteBuffer.self,
                        outboundType: ByteBuffer.self
                    ))
                }
                .get()
        } catch {
            // deinit never runs for a partially initialized actor, so without this the
            // event loop thread would leak on every failed connection attempt.
            self.group.shutdownGracefully(queue: .global()) { _ in }
            throw error
        }
    }

    /// Runs the read, write and heartbeat loops until the connection ends or the task is cancelled.
    ///
    /// On exit `eventFlow` is finished, the write queue is closed, and every pending request
    /// is failed with `.cancelled` (task cancelled) or `.connectionClosed` (anything else).
    func start() async throws {
        defer {
            // Runs on peer close, I/O error and cancellation alike, so consumers of
            // eventFlow never hang and no request has to wait for its timeout.
            self.inboundContinuation.finish()
            self.writeContinuation.finish()
            self.failAllContinuations(Task.isCancelled ? SuperClientError.cancelled : SuperClientError.connectionClosed)
        }

        try await withTaskCancellationHandler {
            try await self.asyncChannel.executeThenClose { inbound, outbound in
                self.inboundContinuation.yield(.ready)

                try await withThrowingTaskGroup(of: Void.self) { group in
                    // Reader: decode frames and route them to eventFlow or to pending requests.
                    group.addTask {
                        defer {
                            self.logger.log("[SDLSuperClient] inbound closed", level: .warning)
                        }

                        for try await var packet in inbound {
                            try Task.checkCancellation()

                            guard let message = SDLSuperClientDecoder.decode(buffer: &packet) else {
                                self.logger.log("[SDLSuperClient] dropped undecodable frame", level: .warning)
                                continue
                            }

                            if !message.isPong() {
                                self.logger.log("[SDLSuperClient] read message: \(message)", level: .debug)
                            }

                            switch message.packet {
                            case .event(let event):
                                self.inboundContinuation.yield(.event(event))
                            case .command(let command):
                                self.inboundContinuation.yield(.command(message.msgId, command))
                            default:
                                await self.fireCallback(message: message)
                            }
                        }
                    }

                    // Writer: serialize queued messages as <packetId:UInt32><type:UInt8><body>.
                    group.addTask {
                        defer {
                            self.logger.log("[SDLSuperClient] outbound closed", level: .warning)
                        }

                        for await (packetId, type, data) in self.writeStream {
                            try Task.checkCancellation()

                            var buffer = self.asyncChannel.channel.allocator.buffer(capacity: data.count + 5)
                            buffer.writeInteger(packetId, as: UInt32.self)
                            buffer.writeBytes([type.rawValue])
                            buffer.writeBytes(data)
                            try await outbound.write(buffer)
                        }
                    }

                    // MARK: Heartbeat
                    group.addTask {
                        defer {
                            self.logger.log("[SDLSuperClient] ping task closed", level: .warning)
                        }

                        while true {
                            try Task.checkCancellation()
                            await self.ping()
                            try await Task.sleep(for: .seconds(5))
                        }
                    }

                    // Wait for the first child to finish. A thrown error propagates (and the
                    // group cancels the rest); a normal return means we cancel the rest ourselves.
                    if try await group.next() != nil {
                        group.cancelAll()
                    }
                }
            }
        } onCancel: {
            self.inboundContinuation.finish()
            self.writeContinuation.finish()
            self.logger.log("[SDLSuperClient] withTaskCancellationHandler cancel")
            Task {
                await self.failAllContinuations(SuperClientError.cancelled)
            }
        }
    }

    // MARK: - APIs

    /// Queues an `unregisterSuper` packet (packet id 0, empty body).
    func unregister() throws {
        self.send(type: .unregisterSuper, packetId: 0, data: Data())
    }

    // Queues a heartbeat `ping` packet (packet id 0, empty body).
    private func ping() {
        self.send(type: .ping, packetId: 0, data: Data())
    }

    /// Sends a packet and waits for the response carrying the same packet id.
    /// - Parameters:
    ///   - type: packet type to send.
    ///   - data: packet body.
    ///   - timeout: how long to wait for the response.
    /// - Returns: the matching inbound message.
    /// - Throws: `SuperClientError.timeout` if no response arrives in time.
    ///   Throws `.connectionClosed` if the write queue is already closed.
    ///   Throws `.cancelled` or `.connectionClosed` if the session ends while waiting.
    func request(type: SDLPacketType, data: Data, timeout: Duration = .seconds(5)) async throws -> SDLSuperInboundMessage {
        let packetId = idGenerator.nextId()

        return try await withCheckedThrowingContinuation { cont in
            self.continuations[packetId] = cont

            let result = self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data))
            if case .terminated = result {
                // The session is over; fail now instead of waiting out the timeout.
                self.continuations.removeValue(forKey: packetId)
                cont.resume(throwing: SuperClientError.connectionClosed)
                return
            }

            Task {
                try? await Task.sleep(for: timeout)
                self.timeout(packetId: packetId)
            }
        }
    }

    /// Queues a packet for the writer task without waiting for any response.
    func send(type: SDLPacketType, packetId: UInt32, data: Data) {
        self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data))
    }

    // Resumes the pending request whose packet id matches `message.msgId`, if any.
    private func fireCallback(message: SDLSuperInboundMessage) {
        guard let cont = self.continuations.removeValue(forKey: message.msgId) else {
            return
        }
        cont.resume(returning: message)
    }

    // Fails and clears every pending request; each continuation is resumed exactly once.
    private func failAllContinuations(_ error: Error) {
        let all = continuations
        continuations.removeAll()
        for (_, cont) in all {
            cont.resume(throwing: error)
        }
    }

    // Fails the request for `packetId` with `.timeout` if it is still pending.
    private func timeout(packetId: UInt32) {
        guard let cont = self.continuations.removeValue(forKey: packetId) else {
            return
        }
        cont.resume(throwing: SuperClientError.timeout)
    }

    deinit {
        // Non-blocking shutdown. syncShutdownGracefully() blocks the releasing thread and
        // traps if that thread is one of the group's own event loops.
        self.group.shutdownGracefully(queue: .global()) { _ in }
    }
}

// MARK: - Codecs

/// Decodes one de-framed SuperNode message.
private struct SDLSuperClientDecoder {
    // Message layout: <msgId: UInt32 big-endian><type: UInt8 SDLPacketType raw value><body: remaining bytes>.
    // For `.command` and `.event` the body starts with one more UInt8 sub-type tag.
    // Returns nil for unknown types or bodies that fail to parse.
    static func decode(buffer: inout ByteBuffer) -> SDLSuperInboundMessage? {
        guard let msgId = buffer.readInteger(as: UInt32.self),
              let type = buffer.readInteger(as: UInt8.self),
              let messageType = SDLPacketType(rawValue: type) else {
            return nil
        }

        switch messageType {
        case .empty:
            return .init(msgId: msgId, packet: .empty)

        case .registerSuperAck:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else {
                return nil
            }
            return .init(msgId: msgId, packet: .registerSuperAck(registerSuperAck))

        case .registerSuperNak:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else {
                return nil
            }
            return .init(msgId: msgId, packet: .registerSuperNak(registerSuperNak))

        case .peerInfo:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else {
                return nil
            }
            return .init(msgId: msgId, packet: .peerInfo(peerInfo))

        case .pong:
            return .init(msgId: msgId, packet: .pong)

        case .command:
            guard let commandVal = buffer.readInteger(as: UInt8.self),
                  let command = SDLCommandType(rawValue: commandVal),
                  let bytes = buffer.readBytes(length: buffer.readableBytes) else {
                return nil
            }

            switch command {
            case .changeNetwork:
                guard let changeNetworkCommand = try? SDLChangeNetworkCommand(serializedBytes: bytes) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .command(.changeNetwork(changeNetworkCommand)))
            }

        case .event:
            guard let eventVal = buffer.readInteger(as: UInt8.self),
                  let event = SDLEventType(rawValue: eventVal),
                  let bytes = buffer.readBytes(length: buffer.readableBytes) else {
                return nil
            }

            switch event {
            case .natChanged:
                guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .event(.natChanged(natChangedEvent)))
            case .sendRegister:
                guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .event(.sendRegister(sendRegisterEvent)))
            case .networkShutdown:
                guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .event(.networkShutdown(networkShutdownEvent)))
            }

        default:
            return nil
        }
    }
}

/// Thrown when an outbound payload does not fit in the 2-byte frame length prefix.
private struct FrameTooLargeError: Error {
    let size: Int
}

/// Prefixes each outbound buffer with its length as a big-endian UInt16.
private final class FixedHeaderEncoder: MessageToByteEncoder, Sendable {
    typealias OutboundIn = ByteBuffer

    func encode(data: ByteBuffer, out: inout ByteBuffer) throws {
        // UInt16(_:) would trap on payloads of 65536 bytes or more; fail the write instead.
        guard let len = UInt16(exactly: data.readableBytes) else {
            throw FrameTooLargeError(size: data.readableBytes)
        }
        out.writeInteger(len)
        out.writeBytes(data.readableBytesView)
    }
}

/// Splits the inbound byte stream into frames of <len: UInt16 big-endian><payload: len bytes>.
private final class FixedHeaderDecoder: ByteToMessageDecoder, Sendable {
    typealias InboundIn = ByteBuffer
    typealias InboundOut = ByteBuffer

    func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
        guard let len = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) else {
            return .needMoreData
        }

        // Widen before adding: in UInt16, `len + 2` overflows (and traps) for 65534 and 65535.
        let frameLength = Int(len) + 2
        guard buffer.readableBytes >= frameLength else {
            return .needMoreData
        }

        buffer.moveReaderIndex(forwardBy: 2)
        if let payload = buffer.readSlice(length: Int(len)) {
            context.fireChannelRead(self.wrapInboundOut(payload))
        }
        return .continue
    }
}