//
//  SDLWebsocketClient.swift
//  Tun
//
//  Created by 安礼成 on 2024/3/28.
//

import Foundation
import NIOCore
import NIOPosix
import Combine

// MARK: - Client for the SuperNode

/// TCP client that talks to the super node.
///
/// The client is itself the last inbound handler of the channel pipeline. Requests that
/// expect a reply are correlated with their response through a packet id; server-initiated
/// events and commands are published on `eventFlow`.
class SDLSuperClient: ChannelInboundHandler {
    public typealias InboundIn = ByteBuffer
    public typealias OutboundOut = ByteBuffer
    /// Invoked with the response message, or with `nil` when the request could not be
    /// sent or the connection went away before a response arrived.
    public typealias CallbackFun = (SDLSuperInboundMessage?) -> Void

    private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
    private var channel: Channel?

    // Generator for the packet ids used to correlate requests and responses.
    var idGenerator = SDLIdGenerator(seed: 1)
    private let callbackManager = SuperCallbackManager()

    let host: String
    let port: Int

    private var pingCancel: AnyCancellable?

    // Set once `.closed` has been published for the current connection so that an
    // `errorCaught` followed by `channelInactive` does not publish it twice.
    private var closedNotified = false

    /// Stream of connection life-cycle notifications and server-initiated messages.
    public var eventFlow = PassthroughSubject<SuperEvent, Never>()

    /// Events published on `eventFlow`.
    enum SuperEvent {
        case ready
        case closed
        case event(SDLEvent)
        case command(UInt32, SDLCommand)
    }

    init(host: String, port: Int) {
        self.host = host
        self.port = port
    }

    /// Connects to `host:port` and installs the framing handlers plus this client in the pipeline.
    ///
    /// Connection failures are logged and reported as `.closed` on `eventFlow`; they are not
    /// rethrown (the function stays `throws` for interface compatibility).
    func start() async throws {
        let bootstrap = ClientBootstrap(group: self.group)
            .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
            .channelInitializer { channel in
                return channel.pipeline.addHandlers([
                    ByteToMessageHandler(FixedHeaderDelimiterCoder()),
                    MessageToByteHandler(FixedHeaderDelimiterCoder()),
                    self
                ])
            }

        do {
            NSLog("super client connect: \(self.host):\(self.port)")
            self.channel = try await bootstrap.connect(host: self.host, port: self.port).get()
        } catch let err {
            NSLog("super client get error: \(err)")
            self.eventFlow.send(.closed)
        }
    }

    // MARK: - APIs

    /// Acknowledges a command previously received with the given packet id.
    /// Nothing is sent if the acknowledgement cannot be serialized.
    func commandAck(packetId: UInt32, ack: SDLCommandAck) {
        guard let data = try? ack.serializedData() else {
            return
        }
        self.send(type: .commandAck, packetId: packetId, data: data)
    }

    /// Async wrapper around `registerSuper(context:callback:)`.
    /// - Returns: The response message, or `nil` if the request failed.
    func registerSuper(context ctx: SDLContext) async -> SDLSuperInboundMessage? {
        return await withCheckedContinuation { c in
            self.registerSuper(context: ctx) { message in
                c.resume(returning: message)
            }
        }
    }

    /// Sends a `registerSuper` request built from the given context.
    /// - Parameter callback: Called with the response, or with `nil` when the request
    ///   cannot be serialized or sent, or the connection closes first.
    func registerSuper(context ctx: SDLContext, callback: @escaping CallbackFun) {
        var registerSuper = SDLRegisterSuper()
        registerSuper.version = UInt32(ctx.config.version)
        registerSuper.clientID = ctx.config.clientId
        registerSuper.devAddr = ctx.devAddr
        registerSuper.pubKey = ctx.rsaCipher.pubKey
        registerSuper.token = ctx.config.token

        // A serialization failure must not crash the process; report it to the caller instead.
        guard let data = try? registerSuper.serializedData() else {
            callback(nil)
            return
        }
        self.write(type: .registerSuper, data: data, callback: callback)
    }

    /// Async wrapper around `queryInfo(context:dst_mac:callback:)`.
    /// - Returns: The response message, or `nil` if the request failed.
    func queryInfo(context ctx: SDLContext, dst_mac: Data) async throws -> SDLSuperInboundMessage? {
        return await withCheckedContinuation { c in
            self.queryInfo(context: ctx, dst_mac: dst_mac) { message in
                c.resume(returning: message)
            }
        }
    }

    /// Sends a `queryInfo` request for the given destination MAC.
    /// - Parameter callback: Called with the response, or with `nil` when the request
    ///   cannot be serialized or sent, or the connection closes first.
    func queryInfo(context ctx: SDLContext, dst_mac: Data, callback: @escaping CallbackFun) {
        var queryInfo = SDLQueryInfo()
        queryInfo.dstMac = dst_mac

        guard let data = try? queryInfo.serializedData() else {
            callback(nil)
            return
        }
        self.write(type: .queryInfo, data: data, callback: callback)
    }

    /// Sends an `unregisterSuper` packet with an empty body.
    /// The current implementation never throws; `throws` is kept for interface compatibility.
    func unregister(context ctx: SDLContext) throws {
        self.send(type: .unregisterSuper, packetId: 0, data: Data())
    }

    /// Sends a `ping` packet with an empty body.
    func ping() {
        self.send(type: .ping, packetId: 0, data: Data())
    }

    /// Reports traffic counters to the super node.
    /// Nothing is sent if the report cannot be serialized.
    func flowReport(forwardNum: UInt32, p2pNum: UInt32, inboundNum: UInt32) {
        var flow = SDLFlows()
        flow.forwardNum = forwardNum
        flow.p2PNum = p2pNum
        flow.inboundNum = inboundNum

        guard let data = try? flow.serializedData() else {
            return
        }
        self.send(type: .flowTracer, packetId: 0, data: data)
    }

    // MARK: - ChannelInboundHandler

    public func channelActive(context: ChannelHandlerContext) {
        self.closedNotified = false
        self.startPingTicker()
        self.eventFlow.send(.ready)
    }

    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        var buffer = self.unwrapInboundIn(data)
        guard let message = decode(buffer: &buffer) else {
            return
        }
        SDLLogger.log("[SDLSuperTransport] read message: \(message)", level: .warning)

        switch message.packet {
        case .event(let event):
            self.eventFlow.send(.event(event))
        case .command(let command):
            self.eventFlow.send(.command(message.msgId, command))
        default:
            // Everything else is treated as the response to a pending request.
            self.callbackManager.fireCallback(message: message)
        }
    }

    public func errorCaught(context: ChannelHandlerContext, error: Error) {
        SDLLogger.log("[SDLSuperTransport] error: \(error)", level: .warning)
        self.handleDisconnect()
        context.close(promise: nil)
    }

    public func channelInactive(context: ChannelHandlerContext) {
        SDLLogger.log("[SDLSuperTransport] channelInactive", level: .warning)
        self.handleDisconnect()
        context.close(promise: nil)
    }

    /// Common teardown for `errorCaught` and `channelInactive`.
    ///
    /// Drops the channel reference, stops the heartbeat, completes every pending request
    /// with `nil` (so awaiting callers are not left suspended) and publishes `.closed`
    /// at most once per connection.
    private func handleDisconnect() {
        self.channel = nil
        self.pingCancel?.cancel()
        self.pingCancel = nil
        self.callbackManager.fireAllCallbacks()

        guard !self.closedNotified else {
            return
        }
        self.closedNotified = true
        self.eventFlow.send(.closed)
    }

    /// Sends a request that expects a response.
    ///
    /// A fresh packet id is allocated and `callback` is registered under it before the
    /// frame is written. When there is no channel, `callback` is invoked with `nil`
    /// immediately.
    func write(type: SDLPacketType, data: Data, callback: @escaping CallbackFun) {
        guard let channel = self.channel else {
            callback(nil)
            return
        }

        SDLLogger.log("[SDLSuperTransport] will write data: \(data)", level: .debug)

        let packetId = idGenerator.nextId()
        self.callbackManager.addCallback(id: packetId, callback: callback)
        self.writeFrame(on: channel, type: type, packetId: packetId, data: data)
    }

    /// Sends a packet that does not expect a response. Silently does nothing when there is no channel.
    func send(type: SDLPacketType, packetId: UInt32, data: Data) {
        guard let channel = self.channel else {
            return
        }
        self.writeFrame(on: channel, type: type, packetId: packetId, data: data)
    }

    /// Builds `<packetId: 4 bytes><type: 1 byte><data>` and writes it on the channel's event loop.
    private func writeFrame(on channel: Channel, type: SDLPacketType, packetId: UInt32, data: Data) {
        channel.eventLoop.execute {
            // 5 = 4-byte packet id + 1-byte packet type.
            var buffer = channel.allocator.buffer(capacity: data.count + 5)
            buffer.writeInteger(packetId, as: UInt32.self)
            buffer.writeBytes([type.rawValue])
            buffer.writeBytes(data)
            channel.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
        }
    }

    // MARK: - Heartbeat

    /// Starts a 5-second ticker that pings the super node to keep the connection alive.
    private func startPingTicker() {
        self.pingCancel = Timer.publish(every: 5.0, on: .main, in: .common).autoconnect()
            .sink { [weak self] _ in
                // Weak capture: the client owns the subscription, so a strong capture
                // would form a retain cycle and keep the client alive forever.
                self?.ping()
            }
    }

    deinit {
        self.pingCancel?.cancel()
        // Shut down asynchronously: a synchronous shutdown would block whichever
        // thread happens to run deinit, and `try!` would crash on failure.
        self.group.shutdownGracefully(queue: .global()) { error in
            if let error = error {
                NSLog("super client shutdown error: \(error)")
            }
        }
    }
}

// MARK: - Framing

/// Framing protocol based on a fixed 2-byte big-endian length header.
extension SDLSuperClient {

    private final class FixedHeaderDelimiterCoder: ByteToMessageDecoder, MessageToByteEncoder {
        typealias InboundIn = ByteBuffer
        typealias InboundOut = ByteBuffer

        /// Thrown by `encode` when a payload does not fit in the 2-byte length header.
        struct FrameTooLargeError: Error {
            let length: Int
        }

        private static let headerLength = 2

        /// Emits one frame body per call once the header and the whole body are buffered.
        func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
            guard let len = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) else {
                return .needMoreData
            }

            // Widen before adding: `len + 2` in UInt16 would trap for lengths >= 65534.
            let bodyLength = Int(len)
            guard buffer.readableBytes >= bodyLength + FixedHeaderDelimiterCoder.headerLength else {
                return .needMoreData
            }

            buffer.moveReaderIndex(forwardBy: FixedHeaderDelimiterCoder.headerLength)
            if let frame = buffer.readSlice(length: bodyLength) {
                context.fireChannelRead(self.wrapInboundOut(frame))
            }
            return .continue
        }

        /// Prefixes `data` with its length as a 2-byte big-endian integer.
        /// - Throws: `FrameTooLargeError` when `data` is longer than `UInt16.max` bytes.
        func encode(data: ByteBuffer, out: inout ByteBuffer) throws {
            guard let len = UInt16(exactly: data.readableBytes) else {
                throw FrameTooLargeError(length: data.readableBytes)
            }
            out.writeInteger(len)
            out.writeBytes(data.readableBytesView)
        }
    }

}

// MARK: - Callback manager

extension SDLSuperClient {

    /// Keeps the callbacks of in-flight requests, keyed by packet id.
    private final class SuperCallbackManager {
        // Maps a request's packet id to the callback awaiting its response.
        private var callbacks: [UInt32: CallbackFun] = [:]
        private let locker = NSLock()

        func addCallback(id: UInt32, callback: @escaping CallbackFun) {
            locker.lock()
            defer {
                locker.unlock()
            }
            self.callbacks[id] = callback
        }

        /// Invokes and removes the callback registered for `message.msgId`, if any.
        func fireCallback(message: SDLSuperInboundMessage) {
            locker.lock()
            let callback = self.callbacks.removeValue(forKey: message.msgId)
            locker.unlock()

            // Called outside the lock: NSLock is not recursive, so a callback that
            // issues another request would otherwise deadlock in addCallback.
            callback?(message)
        }

        /// Invokes every pending callback with `nil` and clears the table.
        func fireAllCallbacks() {
            locker.lock()
            let pending = self.callbacks
            self.callbacks.removeAll()
            locker.unlock()

            for callback in pending.values {
                callback(nil)
            }
        }
    }

}

// MARK: - Codec

extension SDLSuperClient {

    /// Decodes one frame body into a message.
    ///
    /// Layout: 4-byte message id, 1-byte packet type, then a type-specific payload.
    /// Commands and events carry one extra sub-type byte before their payload.
    /// - Returns: `nil` when the header is incomplete, the type or sub-type is unknown or
    ///   not handled here, or the payload fails to deserialize.
    func decode(buffer: inout ByteBuffer) -> SDLSuperInboundMessage? {
        guard let msgId = buffer.readInteger(as: UInt32.self),
              let type = buffer.readInteger(as: UInt8.self),
              let messageType = SDLPacketType(rawValue: type) else {
            return nil
        }

        switch messageType {
        case .empty:
            return .init(msgId: msgId, packet: .empty)

        case .registerSuperAck:
            let body = self.readRemainingData(from: &buffer)
            guard let registerSuperAck = try? SDLRegisterSuperAck(serializedData: body) else {
                return nil
            }
            return .init(msgId: msgId, packet: .registerSuperAck(registerSuperAck))

        case .registerSuperNak:
            let body = self.readRemainingData(from: &buffer)
            guard let registerSuperNak = try? SDLRegisterSuperNak(serializedData: body) else {
                return nil
            }
            return .init(msgId: msgId, packet: .registerSuperNak(registerSuperNak))

        case .peerInfo:
            let body = self.readRemainingData(from: &buffer)
            guard let peerInfo = try? SDLPeerInfo(serializedData: body) else {
                return nil
            }
            return .init(msgId: msgId, packet: .peerInfo(peerInfo))

        case .pong:
            return .init(msgId: msgId, packet: .pong)

        case .command:
            guard let commandVal = buffer.readInteger(as: UInt8.self),
                  let command = SDLCommandType(rawValue: commandVal) else {
                return nil
            }
            let body = self.readRemainingData(from: &buffer)

            switch command {
            case .changeNetwork:
                guard let changeNetworkCommand = try? SDLChangeNetworkCommand(serializedData: body) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .command(.changeNetwork(changeNetworkCommand)))
            }

        case .event:
            guard let eventVal = buffer.readInteger(as: UInt8.self),
                  let event = SDLEventType(rawValue: eventVal) else {
                return nil
            }
            let body = self.readRemainingData(from: &buffer)

            switch event {
            case .natChanged:
                guard let natChangedEvent = try? SDLNatChangedEvent(serializedData: body) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .event(.natChanged(natChangedEvent)))
            case .sendRegister:
                guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedData: body) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .event(.sendRegister(sendRegisterEvent)))
            case .networkShutdown:
                guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedData: body) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .event(.networkShutdown(networkShutdownEvent)))
            }

        default:
            return nil
        }
    }

    /// Consumes every remaining readable byte of `buffer` and returns it as `Data`
    /// (empty when nothing is left).
    private func readRemainingData(from buffer: inout ByteBuffer) -> Data {
        let bytes = buffer.readBytes(length: buffer.readableBytes) ?? []
        return Data(bytes)
    }

}