swiftlib_sdlan/Sources/Punchnet/SDLSuperClient.swift
2025-08-03 12:26:54 +08:00

322 lines
12 KiB
Swift

//
// SDLSuperClient.swift
// Tun
//
// Created by on 2024/3/28.
//
import Foundation
import NIOCore
import NIOPosix
// MARK: - SuperNode
/// TCP client for the super node.
///
/// One connection carries three kinds of traffic:
/// - request/response calls (`registerSuper`, `queryInfo`), matched to their reply by packet id
///   through `callbackPromises`;
/// - fire-and-forget packets (`ping`, `flowReport`, `commandAck`, `unregister`);
/// - server-pushed events and commands, published on `eventFlow`.
///
/// Outbound packets are queued on an internal stream and are only written while `start()` is running.
actor SDLSuperClient {
    private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
    private let asyncChannel: NIOAsyncChannel<ByteBuffer,ByteBuffer>
    // Queue of packets waiting to be written by the writer task in `start()`.
    private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: TcpMessage.self, bufferingPolicy: .unbounded)
    // Requests still waiting for a reply, keyed by the packet id they were sent with.
    private var callbackPromises: [UInt32:EventLoopPromise<SDLSuperInboundMessage>] = [:]
    // Events and commands received from the server; finished when the inbound side ends.
    public let (eventFlow, inboundContinuation) = AsyncStream.makeStream(of: SuperEvent.self, bufferingPolicy: .unbounded)

    // Source of packet ids for request/response calls.
    var idGenerator = SDLIdGenerator(seed: 1)

    /// One outbound packet; written to the socket as `packetId` (UInt32), `type` (one byte), `data`.
    struct TcpMessage {
        let packetId: UInt32
        let type: SDLPacketType
        let data: Data
    }

    /// What consumers of `eventFlow` receive.
    enum SuperEvent {
        /// Emitted once when `start()` begins servicing the connection.
        case ready
        /// An event packet pushed by the server.
        case event(SDLEvent)
        /// A command packet pushed by the server, together with its message id
        /// (the id to pass to `commandAck(packetId:ack:)`).
        case command(UInt32, SDLCommand)
    }

    /// Connects to `host:port` and installs the length-prefixed framing handlers.
    ///
    /// - Throws: Any connection or channel-setup error. The event loop group is shut down
    ///   before the error is rethrown.
    init(host: String, port: Int) async throws {
        let bootstrap = ClientBootstrap(group: self.group)
            .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
            .channelInitializer { channel in
                return channel.pipeline.addHandlers([
                    ByteToMessageHandler(FixedHeaderDecoder()),
                    MessageToByteHandler(FixedHeaderEncoder())
                ])
            }

        do {
            self.asyncChannel = try await bootstrap.connect(host: host, port: port)
                .flatMapThrowing { channel in
                    return try NIOAsyncChannel(wrappingChannelSynchronously: channel, configuration: .init(
                        inboundType: ByteBuffer.self,
                        outboundType: ByteBuffer.self
                    ))
                }
                .get()
        } catch {
            // `deinit` does not run when `init` throws before every stored property is set,
            // so the group's thread has to be released here.
            self.group.shutdownGracefully { _ in }
            throw error
        }
    }

    /// Services the connection until it ends: reads inbound packets, drains the write queue
    /// and sends a ping every 5 seconds.
    ///
    /// - Throws: `SDLError.socketClosed` when the socket closes, or the first error raised by
    ///   the reader or writer. Requests still awaiting a reply are failed before this returns.
    func start() async throws {
        do {
            try await self.asyncChannel.executeThenClose { inbound, outbound in
                self.inboundContinuation.yield(.ready)

                try await withThrowingTaskGroup(of: Void.self) { group in
                    // Turns socket closure into an error so the whole group is torn down.
                    group.addTask {
                        try await self.asyncChannel.channel.closeFuture.get()
                        NSLog("[SDLSuperClient] socket closed")
                        throw SDLError.socketClosed
                    }

                    // Reader: routes events/commands to `eventFlow`, everything else to the waiting request.
                    group.addTask {
                        defer {
                            self.inboundContinuation.finish()
                        }

                        for try await var packet in inbound {
                            if let message = SDLSuperClientDecoder.decode(buffer: &packet) {
                                SDLLogger.log("[SDLSuperTransport] read message: \(message)", level: .warning)
                                switch message.packet {
                                case .event(let event):
                                    self.inboundContinuation.yield(.event(event))
                                case .command(let command):
                                    self.inboundContinuation.yield(.command(message.msgId, command))
                                default:
                                    await self.fireCallback(message: message)
                                }
                            }
                        }
                        NSLog("[SDLSuperClient] inbound closed")
                    }

                    // Writer: serialises queued packets as <<PacketId:32, Type:8, Body>>.
                    group.addTask {
                        defer {
                            self.writeContinuation.finish()
                        }

                        for try await message in self.writeStream {
                            var buffer = self.asyncChannel.channel.allocator.buffer(capacity: message.data.count + 5)
                            buffer.writeInteger(message.packetId, as: UInt32.self)
                            buffer.writeInteger(message.type.rawValue, as: UInt8.self)
                            buffer.writeBytes(message.data)
                            try await outbound.write(buffer)
                        }
                        NSLog("[SDLSuperClient] outbound closed")
                    }

                    // Heartbeat: ping, sleep 5 seconds, repeat until the task is cancelled.
                    group.addTask {
                        while true {
                            do {
                                await self.ping()
                                try await Task.sleep(nanoseconds: 5 * 1_000_000_000)
                            } catch let err {
                                NSLog("[SDLSuperClient] heartbeat cancelled with error: \(err)")
                                break
                            }
                        }
                    }

                    // Wait for the children; the first one to throw ends the loop, and leaving
                    // the group with an error cancels the remaining tasks.
                    for try await _ in group {
                    }
                    NSLog("[SDLSuperClient] group closed")
                }
            }
        } catch {
            // No reply can arrive any more: release callers instead of leaving them waiting.
            self.failPendingCallbacks(with: error)
            throw error
        }
        self.failPendingCallbacks(with: SDLError.socketClosed)
    }

    // MARK: - APIs

    /// Acknowledges a server command; `packetId` is the id the command arrived with.
    /// The acknowledgement is dropped if it cannot be serialised.
    func commandAck(packetId: UInt32, ack: SDLCommandAck) {
        guard let data = try? ack.serializedData() else {
            return
        }
        self.send(type: .commandAck, packetId: packetId, data: data)
    }

    /// Queues a register request built from `ctx`.
    ///
    /// - Returns: A future completed with the server's reply, or failed when the connection ends first.
    /// - Throws: The serialisation error if the request cannot be encoded.
    func registerSuper(context ctx: SDLContext) throws -> EventLoopFuture<SDLSuperInboundMessage> {
        var registerSuper = SDLRegisterSuper()
        registerSuper.version = UInt32(ctx.config.version)
        registerSuper.clientID = ctx.config.clientId
        registerSuper.devAddr = ctx.devAddr
        registerSuper.pubKey = ctx.rsaCipher.pubKey
        registerSuper.token = ctx.config.token

        // The function already throws, so propagate instead of crashing with `try!`.
        let data = try registerSuper.serializedData()
        return self.write(type: .registerSuper, data: data)
    }

    /// Queues a query for the peer identified by `dst_mac`.
    ///
    /// - Returns: A future completed with the server's reply, or failed when the connection ends first.
    /// - Throws: The serialisation error if the request cannot be encoded.
    func queryInfo(dst_mac: Data) async throws -> EventLoopFuture<SDLSuperInboundMessage> {
        var queryInfo = SDLQueryInfo()
        queryInfo.dstMac = dst_mac
        return self.write(type: .queryInfo, data: try queryInfo.serializedData())
    }

    /// Queues an unregister packet (empty body, packet id 0). `ctx` is currently unused.
    func unregister(context ctx: SDLContext) throws {
        self.send(type: .unregisterSuper, packetId: 0, data: Data())
    }

    /// Queues a ping packet (empty body, packet id 0).
    func ping() {
        self.send(type: .ping, packetId: 0, data: Data())
    }

    /// Queues a flow report carrying the three counters. The report is dropped, with a log
    /// line, if it cannot be serialised.
    func flowReport(forwardNum: UInt32, p2pNum: UInt32, inboundNum: UInt32) {
        var flow = SDLFlows()
        flow.forwardNum = forwardNum
        flow.p2PNum = p2pNum
        flow.inboundNum = inboundNum

        guard let data = try? flow.serializedData() else {
            NSLog("[SDLSuperClient] failed to serialize flow report")
            return
        }
        self.send(type: .flowTracer, packetId: 0, data: data)
    }

    /// Queues a request under a fresh packet id and registers a promise for its reply.
    private func write(type: SDLPacketType, data: Data) -> EventLoopFuture<SDLSuperInboundMessage> {
        SDLLogger.log("[SDLSuperTransport] will write data: \(data)", level: .debug)

        let packetId = idGenerator.nextId()
        let promise = self.asyncChannel.channel.eventLoop.makePromise(of: SDLSuperInboundMessage.self)
        self.callbackPromises[packetId] = promise
        self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data))
        return promise.futureResult
    }

    /// Queues a packet for which no reply is tracked.
    private func send(type: SDLPacketType, packetId: UInt32, data: Data) {
        self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data))
    }

    /// Completes the request waiting on `message.msgId`, if there is one.
    private func fireCallback(message: SDLSuperInboundMessage) {
        guard let promise = self.callbackPromises.removeValue(forKey: message.msgId) else {
            return
        }
        // A promise may be completed from any thread, so no hop onto the event loop is needed.
        promise.succeed(message)
    }

    /// Fails every request still waiting for a reply and clears the table.
    private func failPendingCallbacks(with error: Error) {
        let pending = self.callbackPromises
        self.callbackPromises.removeAll()
        for promise in pending.values {
            promise.fail(error)
        }
    }

    deinit {
        // Asynchronous shutdown: it neither blocks the deallocating thread nor crashes when
        // shutdown reports an error, unlike `try! syncShutdownGracefully()`.
        group.shutdownGracefully { error in
            if let error {
                NSLog("[SDLSuperClient] event loop group shutdown failed: \(error)")
            }
        }
    }
}
// MARK: - Packet decoding and framing
/// Turns one de-framed packet into an `SDLSuperInboundMessage`.
private struct SDLSuperClientDecoder {

    /// Decodes a packet laid out as `<<MsgId:32, Type:8, Body/binary>>`.
    ///
    /// Command and event packets carry one extra byte, the sub-type, ahead of the body.
    ///
    /// - Parameter buffer: The packet; bytes are consumed as they are read.
    /// - Returns: The decoded message, or `nil` when the header is truncated, the type or
    ///   sub-type is not recognised, or the body cannot be parsed.
    static func decode(buffer: inout ByteBuffer) -> SDLSuperInboundMessage? {
        guard let msgId = buffer.readInteger(as: UInt32.self),
              let rawType = buffer.readInteger(as: UInt8.self),
              let packetType = SDLPacketType(rawValue: rawType) else {
            return nil
        }

        switch packetType {
        case .empty:
            return SDLSuperInboundMessage(msgId: msgId, packet: .empty)

        case .pong:
            return SDLSuperInboundMessage(msgId: msgId, packet: .pong)

        case .registerSuperAck:
            guard let body = takeRemaining(&buffer),
                  let ack = try? SDLRegisterSuperAck(serializedBytes: body) else {
                return nil
            }
            return SDLSuperInboundMessage(msgId: msgId, packet: .registerSuperAck(ack))

        case .registerSuperNak:
            guard let body = takeRemaining(&buffer),
                  let nak = try? SDLRegisterSuperNak(serializedBytes: body) else {
                return nil
            }
            return SDLSuperInboundMessage(msgId: msgId, packet: .registerSuperNak(nak))

        case .peerInfo:
            guard let body = takeRemaining(&buffer),
                  let info = try? SDLPeerInfo(serializedBytes: body) else {
                return nil
            }
            return SDLSuperInboundMessage(msgId: msgId, packet: .peerInfo(info))

        case .command:
            // Sub-type byte first, then the body.
            guard let rawCommand = buffer.readInteger(as: UInt8.self),
                  let commandType = SDLCommandType(rawValue: rawCommand),
                  let body = takeRemaining(&buffer) else {
                return nil
            }
            switch commandType {
            case .changeNetwork:
                guard let payload = try? SDLChangeNetworkCommand(serializedBytes: body) else {
                    return nil
                }
                return SDLSuperInboundMessage(msgId: msgId, packet: .command(.changeNetwork(payload)))
            }

        case .event:
            // Sub-type byte first, then the body.
            guard let rawEvent = buffer.readInteger(as: UInt8.self),
                  let eventType = SDLEventType(rawValue: rawEvent),
                  let body = takeRemaining(&buffer) else {
                return nil
            }
            switch eventType {
            case .natChanged:
                guard let payload = try? SDLNatChangedEvent(serializedBytes: body) else {
                    return nil
                }
                return SDLSuperInboundMessage(msgId: msgId, packet: .event(.natChanged(payload)))
            case .sendRegister:
                guard let payload = try? SDLSendRegisterEvent(serializedBytes: body) else {
                    return nil
                }
                return SDLSuperInboundMessage(msgId: msgId, packet: .event(.sendRegister(payload)))
            case .networkShutdown:
                guard let payload = try? SDLNetworkShutdownEvent(serializedBytes: body) else {
                    return nil
                }
                return SDLSuperInboundMessage(msgId: msgId, packet: .event(.networkShutdown(payload)))
            }

        default:
            return nil
        }
    }

    /// Reads every byte still unread in `buffer`.
    private static func takeRemaining(_ buffer: inout ByteBuffer) -> [UInt8]? {
        return buffer.readBytes(length: buffer.readableBytes)
    }
}
/// Outbound framing: prefixes every packet with its length as a big-endian `UInt16`.
private final class FixedHeaderEncoder: MessageToByteEncoder, @unchecked Sendable {
    // `MessageToByteEncoder` is keyed on `OutboundIn`; the former `InboundIn`/`InboundOut`
    // aliases were not part of the protocol and were unused.
    typealias OutboundIn = ByteBuffer

    /// Thrown when a packet does not fit the 16-bit length header.
    struct FrameTooLarge: Error {
        /// Size in bytes of the rejected packet.
        let length: Int
    }

    /// Writes `data` to `out` behind a two-byte length header.
    ///
    /// - Throws: `FrameTooLarge` when `data` has more readable bytes than `UInt16.max`
    ///   (previously this trapped in the `UInt16(len)` conversion).
    func encode(data: ByteBuffer, out: inout ByteBuffer) throws {
        guard let length = UInt16(exactly: data.readableBytes) else {
            throw FrameTooLarge(length: data.readableBytes)
        }
        out.writeInteger(length)
        out.writeBytes(data.readableBytesView)
    }
}
/// Inbound framing: splits the byte stream into packets using a two-byte big-endian length header.
private final class FixedHeaderDecoder: ByteToMessageDecoder, @unchecked Sendable {
    typealias InboundIn = ByteBuffer
    typealias InboundOut = ByteBuffer

    /// Size in bytes of the length header that precedes every packet.
    private static let headerLength = 2

    /// Emits one packet (header stripped) once it has fully arrived.
    ///
    /// - Returns: `.needMoreData` until the header and the whole payload are buffered,
    ///   `.continue` after a packet has been consumed.
    func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
        // Peek at the header without consuming it, in case the payload is still incomplete.
        guard let declaredLength = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) else {
            return .needMoreData
        }

        // Widen before adding: in UInt16 arithmetic `len + 2` traps for lengths 65534 and
        // 65535, which would let a peer crash the process with a single header.
        let payloadLength = Int(declaredLength)
        guard buffer.readableBytes >= Self.headerLength + payloadLength else {
            return .needMoreData
        }

        buffer.moveReaderIndex(forwardBy: Self.headerLength)
        // A slice shares storage with `buffer`, avoiding the round trip through `[UInt8]`.
        if let payload = buffer.readSlice(length: payloadLength) {
            context.fireChannelRead(self.wrapInboundOut(payload))
        }
        return .continue
    }
}