//
//  SDLWebsocketClient.swift
//  Tun
//
//  Created by 安礼成 on 2024/3/28.
//

import Foundation
import NIOCore
import NIOPosix

// MARK: - Client for the SuperNode

/// TCP client that talks to the SuperNode.
///
/// Outbound messages are queued on an internal `AsyncStream` and written to the socket by
/// `start()`. Inbound `.event` / `.command` packets are published on `eventFlow`; any other
/// inbound packet completes the future returned by the request (`registerSuper`, `queryInfo`)
/// whose packet id matches the packet's `msgId`.
@available(macOS 14, *)
actor SDLSuperClient {
    private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
    private let asyncChannel: NIOAsyncChannel<ByteBuffer, ByteBuffer>
    private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: TcpMessage.self, bufferingPolicy: .unbounded)
    /// Requests still waiting for a reply, keyed by the packet id they were sent with.
    private var callbackPromises: [UInt32: EventLoopPromise<SDLSuperInboundMessage>] = [:]

    public let (eventFlow, inboundContinuation) = AsyncStream.makeStream(of: SuperEvent.self, bufferingPolicy: .unbounded)

    // Packet id generator (seeded with 1).
    var idGenerator = SDLIdGenerator(seed: 1)

    private let logger: SDLLogger

    /// Format of a message queued for sending.
    struct TcpMessage {
        let packetId: UInt32
        let type: SDLPacketType
        let data: Data
    }

    /// Events published on `eventFlow`.
    enum SuperEvent {
        case ready
        case event(SDLEvent)
        case command(UInt32, SDLCommand)
    }

    /// Connects to `host:port` and wraps the channel for async use.
    ///
    /// The pipeline frames every message with a 2-byte big-endian length prefix
    /// (`FixedHeaderDecoder` / `FixedHeaderEncoder`).
    init(host: String, port: Int, logger: SDLLogger) async throws {
        self.logger = logger

        let bootstrap = ClientBootstrap(group: self.group)
            .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
            .channelInitializer { channel in
                return channel.pipeline.addHandlers([
                    ByteToMessageHandler(FixedHeaderDecoder()),
                    MessageToByteHandler(FixedHeaderEncoder())
                ])
            }

        self.asyncChannel = try await bootstrap.connect(host: host, port: port)
            .flatMapThrowing { channel in
                return try NIOAsyncChannel(wrappingChannelSynchronously: channel, configuration: .init(
                    inboundType: ByteBuffer.self,
                    outboundType: ByteBuffer.self
                ))
            }
            .get()
    }

    /// Runs the connection until it closes or one of its tasks fails.
    ///
    /// Yields `.ready` on `eventFlow`, then concurrently:
    /// - waits for the channel to close (throws `SDLError.socketClosed`);
    /// - reads, decodes and dispatches inbound packets;
    /// - drains the outbound queue to the socket;
    /// - sends a `ping` every 5 seconds.
    /// The first error thrown by any of these tasks is rethrown. Requests still waiting for a
    /// reply when this returns are failed with `SDLError.socketClosed` so their futures do not hang.
    func start() async throws {
        // Runs after the channel is closed, whether `start()` returns or throws.
        defer { self.failPendingCallbacks(with: SDLError.socketClosed) }

        try await self.asyncChannel.executeThenClose { inbound, outbound in
            self.inboundContinuation.yield(.ready)

            try await withThrowingTaskGroup(of: Void.self) { group in
                group.addTask {
                    try await self.asyncChannel.channel.closeFuture.get()
                    self.logger.log("[SDLSuperClient] socket closed", level: .warning)
                    throw SDLError.socketClosed
                }

                group.addTask {
                    defer {
                        self.inboundContinuation.finish()
                    }

                    for try await var packet in inbound {
                        if let message = SDLSuperClientDecoder.decode(buffer: &packet) {
                            self.logger.log("[SDLSuperTransport] read message: \(message)", level: .debug)
                            switch message.packet {
                            case .event(let event):
                                self.inboundContinuation.yield(.event(event))
                            case .command(let command):
                                self.inboundContinuation.yield(.command(message.msgId, command))
                            default:
                                // Anything else is treated as a reply to a pending request.
                                await self.fireCallback(message: message)
                            }
                        }
                    }
                    self.logger.log("[SDLSuperClient] inbound closed", level: .warning)
                }

                group.addTask {
                    defer {
                        self.writeContinuation.finish()
                    }

                    for try await message in self.writeStream {
                        // Payload layout: <packetId: UInt32><type: 1 byte><data>.
                        var buffer = self.asyncChannel.channel.allocator.buffer(capacity: message.data.count + 5)
                        buffer.writeInteger(message.packetId, as: UInt32.self)
                        buffer.writeBytes([message.type.rawValue])
                        buffer.writeBytes(message.data)
                        try await outbound.write(buffer)
                    }
                    self.logger.log("[SDLSuperClient] outbound closed", level: .warning)
                }

                // MARK: Heartbeat
                group.addTask {
                    while true {
                        do {
                            await self.ping()
                            try await Task.sleep(nanoseconds: 5 * 1_000_000_000)
                        } catch let err {
                            // `Task.sleep` throws when the group cancels this task.
                            self.logger.log("[SDLSuperClient] heartbeat cancelled with error: \(err)", level: .warning)
                            break
                        }
                    }
                }

                // Wait for every task to exit; the first error thrown is rethrown here.
                for try await _ in group {
                }
                self.logger.log("[SDLSuperClient] group closed", level: .warning)
            }
        }
    }

    // MARK: - APIs

    /// Sends a `commandAck` for the command received with `packetId`.
    /// Silently does nothing if `ack` cannot be serialized.
    func commandAck(packetId: UInt32, ack: SDLCommandAck) {
        guard let data = try? ack.serializedData() else {
            return
        }

        self.send(type: .commandAck, packetId: packetId, data: data)
    }

    /// Sends a `registerSuper` request built from `ctx`.
    ///
    /// - Returns: A future completed with the reply carrying the same packet id.
    /// - Throws: Any error from serializing the request.
    func registerSuper(context ctx: SDLContext) throws -> EventLoopFuture<SDLSuperInboundMessage> {
        var registerSuper = SDLRegisterSuper()
        registerSuper.version = UInt32(ctx.config.version)
        registerSuper.clientID = ctx.config.clientId
        registerSuper.devAddr = ctx.devAddr
        registerSuper.pubKey = ctx.rsaCipher.pubKey
        registerSuper.token = ctx.config.token

        let data = try registerSuper.serializedData()

        return self.write(type: .registerSuper, data: data)
    }

    /// Queries information about the peer identified by `dst_mac`.
    ///
    /// - Returns: A future completed with the reply carrying the same packet id.
    /// - Throws: Any error from serializing the query.
    func queryInfo(dst_mac: Data) async throws -> EventLoopFuture<SDLSuperInboundMessage> {
        var queryInfo = SDLQueryInfo()
        queryInfo.dstMac = dst_mac

        return self.write(type: .queryInfo, data: try queryInfo.serializedData())
    }

    /// Sends an `unregisterSuper` notification (no reply is awaited).
    func unregister(context ctx: SDLContext) throws {
        self.send(type: .unregisterSuper, packetId: 0, data: Data())
    }

    /// Sends a heartbeat `ping` (no reply is awaited).
    func ping() {
        self.send(type: .ping, packetId: 0, data: Data())
    }

    /// Reports traffic counters to the SuperNode (no reply is awaited).
    func flowReport(forwardNum: UInt32, p2pNum: UInt32, inboundNum: UInt32) {
        var flow = SDLFlows()
        flow.forwardNum = forwardNum
        flow.p2PNum = p2pNum
        flow.inboundNum = inboundNum

        do {
            self.send(type: .flowTracer, packetId: 0, data: try flow.serializedData())
        } catch {
            // Previously `try!`: a serialization failure should not take the process down.
            self.logger.log("[SDLSuperClient] failed to serialize flow report: \(error)", level: .warning)
        }
    }

    /// Queues a request and returns a future for its reply.
    ///
    /// The reply is matched by packet id in `fireCallback(message:)`. If the outbound queue has
    /// already been shut down, the future fails immediately with `SDLError.socketClosed`.
    private func write(type: SDLPacketType, data: Data) -> EventLoopFuture<SDLSuperInboundMessage> {
        let packetId = idGenerator.nextId()
        let promise = self.asyncChannel.channel.eventLoop.makePromise(of: SDLSuperInboundMessage.self)
        self.callbackPromises[packetId] = promise

        if case .terminated = self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data)) {
            // The writer task has exited; nothing will ever send or answer this request.
            self.callbackPromises.removeValue(forKey: packetId)
            promise.fail(SDLError.socketClosed)
        }

        return promise.futureResult
    }

    /// Queues a fire-and-forget message.
    private func send(type: SDLPacketType, packetId: UInt32, data: Data) {
        self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data))
    }

    /// Completes the pending request whose packet id equals `message.msgId`, if any.
    private func fireCallback(message: SDLSuperInboundMessage) {
        guard let promise = self.callbackPromises.removeValue(forKey: message.msgId) else {
            return
        }
        self.asyncChannel.channel.eventLoop.execute {
            promise.succeed(message)
        }
    }

    /// Fails and forgets every request still waiting for a reply.
    private func failPendingCallbacks(with error: Error) {
        let pending = self.callbackPromises
        self.callbackPromises.removeAll()
        for promise in pending.values {
            promise.fail(error)
        }
    }

    deinit {
        // `syncShutdownGracefully()` blocks, and SwiftNIO refuses to run it on one of the group's
        // own event-loop threads (where the last reference may be dropped); `try!` also turned any
        // shutdown error into a crash. Shut down asynchronously and ignore the result.
        self.group.shutdownGracefully(queue: .global()) { _ in }
    }
}

// MARK: - Codecs

/// Decodes one frame (length prefix already stripped by `FixedHeaderDecoder`) into an
/// `SDLSuperInboundMessage`.
private struct SDLSuperClientDecoder {
    /// Frame layout: `<msgId: UInt32><type: UInt8><body>`.
    /// For `.command` and `.event` the body starts with a `UInt8` sub-type followed by the
    /// serialized payload.
    ///
    /// - Returns: The decoded message, or `nil` for an unknown/unsupported type or a payload
    ///   that fails to parse.
    static func decode(buffer: inout ByteBuffer) -> SDLSuperInboundMessage? {
        guard let msgId = buffer.readInteger(as: UInt32.self),
              let type = buffer.readInteger(as: UInt8.self),
              let messageType = SDLPacketType(rawValue: type) else {
            return nil
        }

        switch messageType {
        case .empty:
            return .init(msgId: msgId, packet: .empty)
        case .registerSuperAck:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else {
                return nil
            }
            return .init(msgId: msgId, packet: .registerSuperAck(registerSuperAck))
        case .registerSuperNak:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else {
                return nil
            }
            return .init(msgId: msgId, packet: .registerSuperNak(registerSuperNak))
        case .peerInfo:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else {
                return nil
            }
            return .init(msgId: msgId, packet: .peerInfo(peerInfo))
        case .pong:
            return .init(msgId: msgId, packet: .pong)
        case .command:
            guard let commandVal = buffer.readInteger(as: UInt8.self),
                  let command = SDLCommandType(rawValue: commandVal),
                  let bytes = buffer.readBytes(length: buffer.readableBytes) else {
                return nil
            }

            switch command {
            case .changeNetwork:
                guard let changeNetworkCommand = try? SDLChangeNetworkCommand(serializedBytes: bytes) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .command(.changeNetwork(changeNetworkCommand)))
            }
        case .event:
            guard let eventVal = buffer.readInteger(as: UInt8.self),
                  let event = SDLEventType(rawValue: eventVal),
                  let bytes = buffer.readBytes(length: buffer.readableBytes) else {
                return nil
            }

            switch event {
            case .natChanged:
                guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .event(.natChanged(natChangedEvent)))
            case .sendRegister:
                guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .event(.sendRegister(sendRegisterEvent)))
            case .networkShutdown:
                guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .event(.networkShutdown(networkShutdownEvent)))
            }
        default:
            return nil
        }
    }
}

/// Thrown by `FixedHeaderEncoder` when a frame is too large for the 2-byte length header.
private struct FrameTooLargeError: Error {
    let length: Int
}

/// Prefixes each outbound frame with its length as a big-endian `UInt16`.
private final class FixedHeaderEncoder: MessageToByteEncoder, @unchecked Sendable {
    typealias OutboundIn = ByteBuffer

    func encode(data: ByteBuffer, out: inout ByteBuffer) throws {
        let len = data.readableBytes
        // `UInt16(len)` would trap for frames above 65535 bytes; fail the write instead.
        guard let header = UInt16(exactly: len) else {
            throw FrameTooLargeError(length: len)
        }
        out.writeInteger(header)
        out.writeBytes(data.readableBytesView)
    }
}

/// Splits the inbound byte stream into frames prefixed by a big-endian `UInt16` length,
/// emitting each frame body (without the prefix) downstream.
private final class FixedHeaderDecoder: ByteToMessageDecoder, @unchecked Sendable {
    typealias InboundIn = ByteBuffer
    typealias InboundOut = ByteBuffer

    func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
        guard let len = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) else {
            return .needMoreData
        }

        // Widen before adding: `len + 2` in UInt16 overflows (and traps) for
        // peer-supplied lengths of 0xFFFE and 0xFFFF.
        let frameLength = Int(len)
        guard buffer.readableBytes >= frameLength + 2 else {
            return .needMoreData
        }

        buffer.moveReaderIndex(forwardBy: 2)
        // `readSlice` shares storage instead of copying the frame through an array.
        if let frame = buffer.readSlice(length: frameLength) {
            context.fireChannelRead(self.wrapInboundOut(frame))
        }
        return .continue
    }
}