swiftlib_sdlan/Sources/sdlan/SDLSuperClient.swift
2025-07-14 15:33:40 +08:00

374 lines
13 KiB
Swift

//
// SDLSuperClient.swift
// Tun
//
// Created by on 2024/3/28.
//
import Foundation
import NIOCore
import NIOPosix
import Combine
// MARK: - SuperNode
/// TCP client for the SDLAN super node.
///
/// Outgoing packets are framed as `<<PacketId:32, Type:8, Body>>` (see `writeFrame`); the
/// 2-byte length prefix is added/stripped by `FixedHeaderDelimiterCoder` in the pipeline.
/// Inbound events and commands are published on `eventFlow`; replies to requests sent with
/// `write(...)` are routed back to their callbacks by packet id.
class SDLSuperClient: ChannelInboundHandler {
    public typealias InboundIn = ByteBuffer
    public typealias OutboundOut = ByteBuffer
    public typealias CallbackFun = (SDLSuperInboundMessage?) -> Void

    private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
    private var channel: Channel?

    // Generates packet ids for requests that expect a reply (see `write`).
    var idGenerator = SDLIdGenerator(seed: 1)
    private let callbackManager = SuperCallbackManager()

    let host: String
    let port: Int

    private var pingCancel: AnyCancellable?
    public var eventFlow = PassthroughSubject<SuperEvent, Never>()

    /// Connection-level notifications published on `eventFlow`.
    enum SuperEvent {
        case ready
        case closed
        case event(SDLEvent)
        case command(UInt32, SDLCommand)
    }

    init(host: String, port: Int) {
        self.host = host
        self.port = port
    }

    /// Connects to `host:port` and installs the framing codec plus this handler.
    ///
    /// A connection failure is logged and reported as `.closed` on `eventFlow`; it is not thrown.
    func start() async throws {
        let bootstrap = ClientBootstrap(group: self.group)
            .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
            .channelInitializer { channel in
                return channel.pipeline.addHandlers([
                    ByteToMessageHandler(FixedHeaderDelimiterCoder()),
                    MessageToByteHandler(FixedHeaderDelimiterCoder()),
                    self
                ])
            }

        do {
            NSLog("super client connect: \(self.host):\(self.port)")
            self.channel = try await bootstrap.connect(host: self.host, port: self.port).get()
        } catch let err {
            NSLog("super client get error: \(err)")
            self.eventFlow.send(.closed)
        }
    }

    // MARK: - APIs

    /// Acknowledges the command received with `packetId`.
    func commandAck(packetId: UInt32, ack: SDLCommandAck) {
        guard let data = try? ack.serializedData() else {
            return
        }
        self.send(type: .commandAck, packetId: packetId, data: data)
    }

    /// Registers this client with the super node and awaits the reply.
    /// - Returns: the reply, or `nil` if the request could not be sent or the connection closed.
    func registerSuper(context ctx: SDLContext) async -> SDLSuperInboundMessage? {
        return await withCheckedContinuation { c in
            self.registerSuper(context: ctx) { message in
                c.resume(returning: message)
            }
        }
    }

    /// Callback variant of `registerSuper(context:)`; `callback` is invoked exactly once.
    func registerSuper(context ctx: SDLContext, callback: @escaping CallbackFun) {
        var registerSuper = SDLRegisterSuper()
        registerSuper.version = UInt32(ctx.config.version)
        registerSuper.clientID = ctx.config.clientId
        registerSuper.devAddr = ctx.devAddr
        registerSuper.pubKey = ctx.rsaCipher.pubKey
        registerSuper.token = ctx.config.token

        guard let data = try? registerSuper.serializedData() else {
            // Report the failure instead of crashing; async callers resume with nil.
            callback(nil)
            return
        }
        self.write(type: .registerSuper, data: data, callback: callback)
    }

    /// Asks the super node for information about the peer with MAC `dst_mac` and awaits the reply.
    func queryInfo(context ctx: SDLContext, dst_mac: Data) async throws -> SDLSuperInboundMessage? {
        return await withCheckedContinuation { c in
            self.queryInfo(context: ctx, dst_mac: dst_mac) { message in
                c.resume(returning: message)
            }
        }
    }

    /// Callback variant of `queryInfo(context:dst_mac:)`; `callback` is invoked exactly once.
    func queryInfo(context ctx: SDLContext, dst_mac: Data, callback: @escaping CallbackFun) {
        var queryInfo = SDLQueryInfo()
        queryInfo.dstMac = dst_mac

        guard let data = try? queryInfo.serializedData() else {
            callback(nil)
            return
        }
        self.write(type: .queryInfo, data: data, callback: callback)
    }

    func unregister(context ctx: SDLContext) throws {
        self.send(type: .unregisterSuper, packetId: 0, data: Data())
    }

    func ping() {
        self.send(type: .ping, packetId: 0, data: Data())
    }

    /// Reports traffic counters to the super node (fire-and-forget).
    func flowReport(forwardNum: UInt32, p2pNum: UInt32, inboundNum: UInt32) {
        var flow = SDLFlows()
        flow.forwardNum = forwardNum
        flow.p2PNum = p2pNum
        flow.inboundNum = inboundNum

        guard let data = try? flow.serializedData() else {
            return
        }
        self.send(type: .flowTracer, packetId: 0, data: data)
    }

    // MARK: - ChannelInboundHandler

    public func channelActive(context: ChannelHandlerContext) {
        self.startPingTicker()
        self.eventFlow.send(.ready)
    }

    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        var buffer = self.unwrapInboundIn(data)
        if let message = decode(buffer: &buffer) {
            SDLLogger.log("[SDLSuperTransport] read message: \(message)", level: .warning)
            switch message.packet {
            case .event(let event):
                self.eventFlow.send(.event(event))
            case .command(let command):
                self.eventFlow.send(.command(message.msgId, command))
            default:
                self.callbackManager.fireCallback(message: message)
            }
        }
    }

    public func errorCaught(context: ChannelHandlerContext, error: Error) {
        SDLLogger.log("[SDLSuperTransport] error: \(error)", level: .warning)
        self.channel = nil
        // Closing an active channel triggers `channelInactive`, which publishes `.closed`;
        // publishing it here as well would deliver the event twice.
        context.close(promise: nil)
    }

    public func channelInactive(context: ChannelHandlerContext) {
        SDLLogger.log("[SDLSuperTransport] channelInactive", level: .warning)
        self.channel = nil
        self.stopPingTicker()
        // No reply can arrive any more: resolve every pending request with nil so callers
        // suspended in `withCheckedContinuation` are not left waiting forever.
        // (The message argument is not delivered; callbacks receive nil.)
        self.callbackManager.fireAllCallbacks(message: .init(msgId: 0, packet: .empty))
        self.eventFlow.send(.closed)
        context.close(promise: nil)
    }

    // MARK: - Writing

    /// Sends a request whose reply will be delivered to `callback`.
    ///
    /// `callback` is invoked with nil immediately when there is no connection, so it is always
    /// called exactly once (the async wrappers depend on this).
    func write(type: SDLPacketType, data: Data, callback: @escaping CallbackFun) {
        guard let channel = self.channel else {
            callback(nil)
            return
        }
        SDLLogger.log("[SDLSuperTransport] will write data: \(data)", level: .debug)

        let packetId = idGenerator.nextId()
        // Register before writing so a fast reply cannot arrive ahead of its callback.
        self.callbackManager.addCallback(id: packetId, callback: callback)
        self.writeFrame(on: channel, type: type, packetId: packetId, data: data)
    }

    /// Sends a packet with an explicit `packetId` and no reply tracking; no-op when disconnected.
    func send(type: SDLPacketType, packetId: UInt32, data: Data) {
        guard let channel = self.channel else {
            return
        }
        self.writeFrame(on: channel, type: type, packetId: packetId, data: data)
    }

    /// Serializes `<<PacketId:32, Type:8, Body>>` and writes it on the channel's event loop.
    private func writeFrame(on channel: Channel, type: SDLPacketType, packetId: UInt32, data: Data) {
        channel.eventLoop.execute {
            var buffer = channel.allocator.buffer(capacity: data.count + 5)
            buffer.writeInteger(packetId, as: UInt32.self)
            buffer.writeBytes([type.rawValue])
            buffer.writeBytes(data)
            channel.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
        }
    }

    // MARK: - Heartbeat

    /// Starts (or restarts) a 5-second ping to the super node on the main run loop.
    private func startPingTicker() {
        self.pingCancel?.cancel()
        self.pingCancel = Timer.publish(every: 5.0, on: .main, in: .common)
            .autoconnect()
            .sink { [weak self] _ in
                // Weak capture: a strong one would form a cycle through `pingCancel`
                // and keep the client (and its event loop group) alive forever.
                self?.ping()
            }
    }

    private func stopPingTicker() {
        self.pingCancel?.cancel()
        self.pingCancel = nil
    }

    deinit {
        self.pingCancel?.cancel()
        // `syncShutdownGracefully()` traps when called on one of the group's own event-loop
        // threads, and that can happen here because the channel pipeline retains this handler
        // and may release it on the event loop. The asynchronous variant is safe from any thread.
        self.group.shutdownGracefully(queue: .global()) { _ in }
    }
}
/// Framing: each frame is a 2-byte big-endian length prefix followed by that many payload bytes.
extension SDLSuperClient {
    /// Length-prefixed framing codec: every frame on the wire is a 2-byte big-endian length
    /// followed by exactly that many payload bytes. Decoding strips the prefix; encoding adds it.
    private final class FixedHeaderDelimiterCoder: ByteToMessageDecoder, MessageToByteEncoder {
        typealias InboundIn = ByteBuffer
        typealias InboundOut = ByteBuffer

        /// Thrown by `encode` when a payload cannot be described by the 16-bit length prefix.
        enum FramingError: Error {
            case payloadTooLarge(Int)
        }

        /// Size in bytes of the length prefix.
        private static let lengthFieldSize = 2

        func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
            guard let declared = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) else {
                return .needMoreData
            }
            // Widen to Int before adding the prefix size: doing the sum in UInt16 traps
            // (crashes the process) for declared lengths of 65534 and 65535.
            let payloadLength = Int(declared)
            guard buffer.readableBytes >= Self.lengthFieldSize + payloadLength else {
                return .needMoreData
            }

            buffer.moveReaderIndex(forwardBy: Self.lengthFieldSize)
            // Non-nil because readability was checked above; slicing avoids a byte copy.
            if let payload = buffer.readSlice(length: payloadLength) {
                context.fireChannelRead(self.wrapInboundOut(payload))
            }
            return .continue
        }

        func encode(data: ByteBuffer, out: inout ByteBuffer) throws {
            let length = data.readableBytes
            // `UInt16(length)` would trap for oversized payloads; fail this write instead.
            guard let prefix = UInt16(exactly: length) else {
                throw FramingError.payloadTooLarge(length)
            }
            out.writeInteger(prefix, endianness: .big)
            out.writeBytes(data.readableBytesView)
        }
    }
}
// MARK: - Pending request callbacks
extension SDLSuperClient {
    /// Lock-protected registry of reply callbacks keyed by request packet id.
    ///
    /// Each callback is removed while the lock is held and invoked only after the lock is
    /// released. A callback can therefore issue another request (which calls `addCallback`)
    /// without deadlocking on the non-recursive `NSLock`.
    private final class SuperCallbackManager {
        // packet id -> callback waiting for the reply carrying that id
        private var callbacks: [UInt32: CallbackFun] = [:]
        private let locker = NSLock()

        /// Registers `callback` for the reply whose `msgId` equals `id`, replacing any previous one.
        func addCallback(id: UInt32, callback: @escaping CallbackFun) {
            locker.lock()
            defer { locker.unlock() }
            self.callbacks[id] = callback
        }

        /// Delivers `message` to the callback registered under `message.msgId`, if any.
        /// The callback is unregistered first, so it fires at most once.
        func fireCallback(message: SDLSuperInboundMessage) {
            locker.lock()
            let callback = self.callbacks.removeValue(forKey: message.msgId)
            locker.unlock()

            callback?(message)
        }

        /// Resolves every pending callback with nil and clears the registry.
        /// - Parameter message: unused; kept for source compatibility.
        func fireAllCallbacks(message: SDLSuperInboundMessage) {
            locker.lock()
            let pending = self.callbacks
            self.callbacks.removeAll()
            locker.unlock()

            for callback in pending.values {
                callback(nil)
            }
        }
    }
}
// MARK: - Packet decoding
extension SDLSuperClient {
    /// Decodes one de-framed packet laid out as `<<MsgId:32, Type:8, Body/binary>>`.
    ///
    /// `command` and `event` packets carry one extra sub-type byte ahead of their protobuf
    /// body. For every other handled type, the body is the rest of the buffer.
    ///
    /// - Parameter buffer: a single frame's payload (length prefix already removed); it is consumed.
    /// - Returns: the decoded message. Returns `nil` in these cases: the header is truncated,
    ///   the packet type or sub-type is unknown or not handled here, or the body fails to parse.
    func decode(buffer: inout ByteBuffer) -> SDLSuperInboundMessage? {
        guard let msgId = buffer.readInteger(as: UInt32.self),
              let typeByte = buffer.readInteger(as: UInt8.self),
              let packetType = SDLPacketType(rawValue: typeByte) else {
            return nil
        }

        // Drains whatever remains in `buffer` and returns it as the packet body.
        func drainBody() -> Data? {
            return buffer.readBytes(length: buffer.readableBytes).map { Data($0) }
        }

        switch packetType {
        case .empty:
            return .init(msgId: msgId, packet: .empty)

        case .pong:
            return .init(msgId: msgId, packet: .pong)

        case .registerSuperAck:
            guard let body = drainBody(),
                  let ack = try? SDLRegisterSuperAck(serializedData: body) else {
                return nil
            }
            return .init(msgId: msgId, packet: .registerSuperAck(ack))

        case .registerSuperNak:
            guard let body = drainBody(),
                  let nak = try? SDLRegisterSuperNak(serializedData: body) else {
                return nil
            }
            return .init(msgId: msgId, packet: .registerSuperNak(nak))

        case .peerInfo:
            guard let body = drainBody(),
                  let info = try? SDLPeerInfo(serializedData: body) else {
                return nil
            }
            return .init(msgId: msgId, packet: .peerInfo(info))

        case .command:
            // Sub-type byte first, then the protobuf body.
            guard let commandByte = buffer.readInteger(as: UInt8.self),
                  let commandType = SDLCommandType(rawValue: commandByte),
                  let body = drainBody() else {
                return nil
            }
            switch commandType {
            case .changeNetwork:
                guard let changeNetwork = try? SDLChangeNetworkCommand(serializedData: body) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .command(.changeNetwork(changeNetwork)))
            }

        case .event:
            // Sub-type byte first, then the protobuf body.
            guard let eventByte = buffer.readInteger(as: UInt8.self),
                  let eventType = SDLEventType(rawValue: eventByte),
                  let body = drainBody() else {
                return nil
            }
            switch eventType {
            case .natChanged:
                guard let natChanged = try? SDLNatChangedEvent(serializedData: body) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .event(.natChanged(natChanged)))
            case .sendRegister:
                guard let sendRegister = try? SDLSendRegisterEvent(serializedData: body) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .event(.sendRegister(sendRegister)))
            case .networkShutdown:
                guard let networkShutdown = try? SDLNetworkShutdownEvent(serializedData: body) else {
                    return nil
                }
                return .init(msgId: msgId, packet: .event(.networkShutdown(networkShutdown)))
            }

        default:
            // Any other packet type is not decoded by this client.
            return nil
        }
    }
}