From 3a0c21280cff641e8b2d40afdd0ab48f4bb8e54b Mon Sep 17 00:00:00 2001 From: anlicheng <244108715@qq.com> Date: Fri, 1 Aug 2025 10:31:23 +0800 Subject: [PATCH] tcp actor --- Sources/Punchnet/SDLSuperClientActor.swift | 316 +++++++++++++++++++++ 1 file changed, 316 insertions(+) create mode 100644 Sources/Punchnet/SDLSuperClientActor.swift diff --git a/Sources/Punchnet/SDLSuperClientActor.swift b/Sources/Punchnet/SDLSuperClientActor.swift new file mode 100644 index 0000000..f57441f --- /dev/null +++ b/Sources/Punchnet/SDLSuperClientActor.swift @@ -0,0 +1,316 @@ +// +// SDLWebsocketClient.swift +// Tun +// +// Created by 安礼成 on 2024/3/28. +// + +import Foundation +import NIOCore +import NIOPosix + +struct TcpMessage { + let packetId: UInt32 + let type: SDLPacketType + let data: Data +} + +// --MARK: 和SuperNode的客户端 +actor SDLSuperClientActor { + private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + private let asyncChannel: NIOAsyncChannel + private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: TcpMessage.self, bufferingPolicy: .unbounded) + private var callbackPromises: [UInt32:EventLoopPromise] = [:] + + public let (eventFlow, inboundContinuation) = AsyncStream.makeStream(of: SuperEvent.self, bufferingPolicy: .unbounded) + + // id生成器 + var idGenerator = SDLIdGenerator(seed: 1) + + let host: String + let port: Int + + private var pingCancel: AnyCancellable? + + // 定义事件类型 + enum SuperEvent { + case ready + case closed + case event(SDLEvent) + case command(UInt32, SDLCommand) + } + + init(host: String, port: Int) { + self.host = host + self.port = port + } + + init() async throws { + let bootstrap = ClientBootstrap(group: self.group) + .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .channelInitializer { channel in + return channel.pipeline.addHandlers([ + ByteToMessageHandler(FixedHeaderDecoder()), + MessageToByteHandler(FixedHeaderEncoder()) + ]) + } + + self.asyncChannel = try await bootstrap.connect(host: host, port: port) + .flatMapThrowing { channel in + return try NIOAsyncChannel(wrappingChannelSynchronously: channel, configuration: .init( + inboundType: ByteBuffer.self, + outboundType: ByteBuffer.self + )) + } + .get() + + try await self.asyncChannel.executeThenClose { inbound, outbound in + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + defer { + self.inboundContinuation.finish() + } + + for try await var packet in inbound { + if let message = SDLSuperClientDecoder.decode(buffer: &packet) { + SDLLogger.log("[SDLSuperTransport] read message: \(message)", level: .warning) + switch message.packet { + case .event(let event): + self.inboundContinuation.yield(.event(event)) + case .command(let command): + self.inboundContinuation.yield(.command(message.msgId, command)) + default: + await self.fireCallback(message: message) + } + } + } + } + + group.addTask { + defer { + self.writeContinuation.finish() + } + + for try await message in self.writeStream { + var buffer = self.asyncChannel.channel.allocator.buffer(capacity: message.data.count + 5) + buffer.writeInteger(message.packetId, as: UInt32.self) + buffer.writeBytes([message.type.rawValue]) + buffer.writeBytes(message.data) + try await outbound.write(buffer) + } + } + + try await group.waitForAll() + } + } + } + + private func fireCallback(message: SDLSuperInboundMessage) { + if let promise = self.callbackPromises[message.msgId] { + self.asyncChannel.channel.eventLoop.execute { + promise.succeed(message) + } + self.callbackPromises.removeValue(forKey: message.msgId) + } + } + + // -- MARK: apis + + func commandAck(packetId: UInt32, ack: SDLCommandAck) { + guard let data = try? ack.serializedData() else { + return + } + + self.send(type: .commandAck, packetId: packetId, data: data) + } + + func registerSuper(context ctx: SDLContext) async throws -> SDLSuperInboundMessage { + var registerSuper = SDLRegisterSuper() + registerSuper.version = UInt32(ctx.config.version) + registerSuper.clientID = ctx.config.clientId + registerSuper.devAddr = ctx.devAddr + registerSuper.pubKey = ctx.rsaCipher.pubKey + registerSuper.token = ctx.config.token + + let data = try! registerSuper.serializedData() + + return try await self.write(type: .registerSuper, data: data).get() + } + + // 查询目标服务器的相关信息 + func queryInfo(context ctx: SDLContext, dst_mac: Data) async throws -> SDLSuperInboundMessage { + var queryInfo = SDLQueryInfo() + queryInfo.dstMac = dst_mac + + return try await self.write(type: .queryInfo, data: try! queryInfo.serializedData()).get() + } + + func unregister(context ctx: SDLContext) throws { + self.send(type: .unregisterSuper, packetId: 0, data: Data()) + } + + func ping() { + self.send(type: .ping, packetId: 0, data: Data()) + } + + func flowReport(forwardNum: UInt32, p2pNum: UInt32, inboundNum: UInt32) { + var flow = SDLFlows() + flow.forwardNum = forwardNum + flow.p2PNum = p2pNum + flow.inboundNum = inboundNum + + self.send(type: .flowTracer, packetId: 0, data: try! flow.serializedData()) + } + + // --MARK: ChannelInboundHandler + + public func channelActive(context: ChannelHandlerContext) { + self.startPingTicker() + } + + func write(type: SDLPacketType, data: Data) -> EventLoopFuture { + SDLLogger.log("[SDLSuperTransport] will write data: \(data)", level: .debug) + let packetId = idGenerator.nextId() + let promise = self.asyncChannel.channel.eventLoop.makePromise(of: SDLSuperInboundMessage.self) + self.callbackPromises[packetId] = promise + self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data)) + + return promise.futureResult + } + + func send(type: SDLPacketType, packetId: UInt32, data: Data) { + self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data)) + } + + // --MARK: 心跳机制 + + private func startPingTicker() { + self.pingCancel = Timer.publish(every: 5.0, on: .main, in: .common).autoconnect() + .sink { _ in + // 保持和super-node的心跳机制 + self.ping() + } + } + + deinit { + self.pingCancel?.cancel() + try! group.syncShutdownGracefully() + } + +} + +// --MARK: 编解码器 +struct SDLSuperClientDecoder { + // 消息格式为: <> + static func decode(buffer: inout ByteBuffer) -> SDLSuperInboundMessage? { + guard let msgId = buffer.readInteger(as: UInt32.self), + let type = buffer.readInteger(as: UInt8.self), + let messageType = SDLPacketType(rawValue: type) else { + return nil + } + + switch messageType { + case .empty: + return .init(msgId: msgId, packet: .empty) + case .registerSuperAck: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else { + return nil + } + return .init(msgId: msgId, packet: .registerSuperAck(registerSuperAck)) + + case .registerSuperNak: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else { + return nil + } + return .init(msgId: msgId, packet: .registerSuperNak(registerSuperNak)) + + case .peerInfo: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else { + return nil + } + + return .init(msgId: msgId, packet: .peerInfo(peerInfo)) + case .pong: + return .init(msgId: msgId, packet: .pong) + + case .command: + guard let commandVal = buffer.readInteger(as: UInt8.self), + let command = SDLCommandType(rawValue: commandVal), + let bytes = buffer.readBytes(length: buffer.readableBytes) else { + return nil + } + + switch command { + case .changeNetwork: + guard let changeNetworkCommand = try? SDLChangeNetworkCommand(serializedBytes: bytes) else { + return nil + } + + return .init(msgId: msgId, packet: .command(.changeNetwork(changeNetworkCommand))) + } + + case .event: + guard let eventVal = buffer.readInteger(as: UInt8.self), + let event = SDLEventType(rawValue: eventVal), + let bytes = buffer.readBytes(length: buffer.readableBytes) else { + return nil + } + + switch event { + case .natChanged: + guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else { + return nil + } + return .init(msgId: msgId, packet: .event(.natChanged(natChangedEvent))) + case .sendRegister: + guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else { + return nil + } + return .init(msgId: msgId, packet: .event(.sendRegister(sendRegisterEvent))) + case .networkShutdown: + guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else { + return nil + } + return .init(msgId: msgId, packet: .event(.networkShutdown(networkShutdownEvent))) + } + + default: + return nil + } + } +} + +private final class FixedHeaderEncoder: MessageToByteEncoder, @unchecked Sendable { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + func encode(data: ByteBuffer, out: inout ByteBuffer) throws { + let len = data.readableBytes + out.writeInteger(UInt16(len)) + out.writeBytes(data.readableBytesView) + } +} + +private final class FixedHeaderDecoder: ByteToMessageDecoder, @unchecked Sendable { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState { + guard let len = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) else { + return .needMoreData + } + + if buffer.readableBytes >= len + 2 { + buffer.moveReaderIndex(forwardBy: 2) + if let bytes = buffer.readBytes(length: Int(len)) { + context.fireChannelRead(self.wrapInboundOut(ByteBuffer(bytes: bytes))) + } + return .continue + } else { + return .needMoreData + } + } +}