tcp actor
This commit is contained in:
parent
aae0b333de
commit
3a0c21280c
316
Sources/Punchnet/SDLSuperClientActor.swift
Normal file
316
Sources/Punchnet/SDLSuperClientActor.swift
Normal file
@ -0,0 +1,316 @@
|
||||
//
|
||||
// SDLWebsocketClient.swift
|
||||
// Tun
|
||||
//
|
||||
// Created by 安礼成 on 2024/3/28.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import NIOCore
|
||||
import NIOPosix
|
||||
|
||||
struct TcpMessage {
|
||||
let packetId: UInt32
|
||||
let type: SDLPacketType
|
||||
let data: Data
|
||||
}
|
||||
|
||||
// --MARK: 和SuperNode的客户端
|
||||
actor SDLSuperClientActor {
|
||||
private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||
private let asyncChannel: NIOAsyncChannel<ByteBuffer,ByteBuffer>
|
||||
private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: TcpMessage.self, bufferingPolicy: .unbounded)
|
||||
private var callbackPromises: [UInt32:EventLoopPromise<SDLSuperInboundMessage>] = [:]
|
||||
|
||||
public let (eventFlow, inboundContinuation) = AsyncStream.makeStream(of: SuperEvent.self, bufferingPolicy: .unbounded)
|
||||
|
||||
// id生成器
|
||||
var idGenerator = SDLIdGenerator(seed: 1)
|
||||
|
||||
let host: String
|
||||
let port: Int
|
||||
|
||||
private var pingCancel: AnyCancellable?
|
||||
|
||||
// 定义事件类型
|
||||
enum SuperEvent {
|
||||
case ready
|
||||
case closed
|
||||
case event(SDLEvent)
|
||||
case command(UInt32, SDLCommand)
|
||||
}
|
||||
|
||||
init(host: String, port: Int) {
|
||||
self.host = host
|
||||
self.port = port
|
||||
}
|
||||
|
||||
init() async throws {
|
||||
let bootstrap = ClientBootstrap(group: self.group)
|
||||
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
|
||||
.channelInitializer { channel in
|
||||
return channel.pipeline.addHandlers([
|
||||
ByteToMessageHandler(FixedHeaderDecoder()),
|
||||
MessageToByteHandler(FixedHeaderEncoder())
|
||||
])
|
||||
}
|
||||
|
||||
self.asyncChannel = try await bootstrap.connect(host: host, port: port)
|
||||
.flatMapThrowing { channel in
|
||||
return try NIOAsyncChannel(wrappingChannelSynchronously: channel, configuration: .init(
|
||||
inboundType: ByteBuffer.self,
|
||||
outboundType: ByteBuffer.self
|
||||
))
|
||||
}
|
||||
.get()
|
||||
|
||||
try await self.asyncChannel.executeThenClose { inbound, outbound in
|
||||
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||
group.addTask {
|
||||
defer {
|
||||
self.inboundContinuation.finish()
|
||||
}
|
||||
|
||||
for try await var packet in inbound {
|
||||
if let message = SDLSuperClientDecoder.decode(buffer: &packet) {
|
||||
SDLLogger.log("[SDLSuperTransport] read message: \(message)", level: .warning)
|
||||
switch message.packet {
|
||||
case .event(let event):
|
||||
self.inboundContinuation.yield(.event(event))
|
||||
case .command(let command):
|
||||
self.inboundContinuation.yield(.command(message.msgId, command))
|
||||
default:
|
||||
await self.fireCallback(message: message)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
group.addTask {
|
||||
defer {
|
||||
self.writeContinuation.finish()
|
||||
}
|
||||
|
||||
for try await message in self.writeStream {
|
||||
var buffer = self.asyncChannel.channel.allocator.buffer(capacity: message.data.count + 5)
|
||||
buffer.writeInteger(message.packetId, as: UInt32.self)
|
||||
buffer.writeBytes([message.type.rawValue])
|
||||
buffer.writeBytes(message.data)
|
||||
try await outbound.write(buffer)
|
||||
}
|
||||
}
|
||||
|
||||
try await group.waitForAll()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func fireCallback(message: SDLSuperInboundMessage) {
|
||||
if let promise = self.callbackPromises[message.msgId] {
|
||||
self.asyncChannel.channel.eventLoop.execute {
|
||||
promise.succeed(message)
|
||||
}
|
||||
self.callbackPromises.removeValue(forKey: message.msgId)
|
||||
}
|
||||
}
|
||||
|
||||
// -- MARK: apis
|
||||
|
||||
func commandAck(packetId: UInt32, ack: SDLCommandAck) {
|
||||
guard let data = try? ack.serializedData() else {
|
||||
return
|
||||
}
|
||||
|
||||
self.send(type: .commandAck, packetId: packetId, data: data)
|
||||
}
|
||||
|
||||
func registerSuper(context ctx: SDLContext) async throws -> SDLSuperInboundMessage {
|
||||
var registerSuper = SDLRegisterSuper()
|
||||
registerSuper.version = UInt32(ctx.config.version)
|
||||
registerSuper.clientID = ctx.config.clientId
|
||||
registerSuper.devAddr = ctx.devAddr
|
||||
registerSuper.pubKey = ctx.rsaCipher.pubKey
|
||||
registerSuper.token = ctx.config.token
|
||||
|
||||
let data = try! registerSuper.serializedData()
|
||||
|
||||
return try await self.write(type: .registerSuper, data: data).get()
|
||||
}
|
||||
|
||||
// 查询目标服务器的相关信息
|
||||
func queryInfo(context ctx: SDLContext, dst_mac: Data) async throws -> SDLSuperInboundMessage {
|
||||
var queryInfo = SDLQueryInfo()
|
||||
queryInfo.dstMac = dst_mac
|
||||
|
||||
return try await self.write(type: .queryInfo, data: try! queryInfo.serializedData()).get()
|
||||
}
|
||||
|
||||
func unregister(context ctx: SDLContext) throws {
|
||||
self.send(type: .unregisterSuper, packetId: 0, data: Data())
|
||||
}
|
||||
|
||||
func ping() {
|
||||
self.send(type: .ping, packetId: 0, data: Data())
|
||||
}
|
||||
|
||||
func flowReport(forwardNum: UInt32, p2pNum: UInt32, inboundNum: UInt32) {
|
||||
var flow = SDLFlows()
|
||||
flow.forwardNum = forwardNum
|
||||
flow.p2PNum = p2pNum
|
||||
flow.inboundNum = inboundNum
|
||||
|
||||
self.send(type: .flowTracer, packetId: 0, data: try! flow.serializedData())
|
||||
}
|
||||
|
||||
// --MARK: ChannelInboundHandler
|
||||
|
||||
public func channelActive(context: ChannelHandlerContext) {
|
||||
self.startPingTicker()
|
||||
}
|
||||
|
||||
func write(type: SDLPacketType, data: Data) -> EventLoopFuture<SDLSuperInboundMessage> {
|
||||
SDLLogger.log("[SDLSuperTransport] will write data: \(data)", level: .debug)
|
||||
let packetId = idGenerator.nextId()
|
||||
let promise = self.asyncChannel.channel.eventLoop.makePromise(of: SDLSuperInboundMessage.self)
|
||||
self.callbackPromises[packetId] = promise
|
||||
self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data))
|
||||
|
||||
return promise.futureResult
|
||||
}
|
||||
|
||||
func send(type: SDLPacketType, packetId: UInt32, data: Data) {
|
||||
self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data))
|
||||
}
|
||||
|
||||
// --MARK: 心跳机制
|
||||
|
||||
private func startPingTicker() {
|
||||
self.pingCancel = Timer.publish(every: 5.0, on: .main, in: .common).autoconnect()
|
||||
.sink { _ in
|
||||
// 保持和super-node的心跳机制
|
||||
self.ping()
|
||||
}
|
||||
}
|
||||
|
||||
deinit {
|
||||
self.pingCancel?.cancel()
|
||||
try! group.syncShutdownGracefully()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// --MARK: 编解码器
|
||||
struct SDLSuperClientDecoder {
|
||||
// 消息格式为: <<MsgId:32, Type:8, Body/binary>>
|
||||
static func decode(buffer: inout ByteBuffer) -> SDLSuperInboundMessage? {
|
||||
guard let msgId = buffer.readInteger(as: UInt32.self),
|
||||
let type = buffer.readInteger(as: UInt8.self),
|
||||
let messageType = SDLPacketType(rawValue: type) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch messageType {
|
||||
case .empty:
|
||||
return .init(msgId: msgId, packet: .empty)
|
||||
case .registerSuperAck:
|
||||
guard let bytes = buffer.readBytes(length: buffer.readableBytes),
|
||||
let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else {
|
||||
return nil
|
||||
}
|
||||
return .init(msgId: msgId, packet: .registerSuperAck(registerSuperAck))
|
||||
|
||||
case .registerSuperNak:
|
||||
guard let bytes = buffer.readBytes(length: buffer.readableBytes),
|
||||
let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else {
|
||||
return nil
|
||||
}
|
||||
return .init(msgId: msgId, packet: .registerSuperNak(registerSuperNak))
|
||||
|
||||
case .peerInfo:
|
||||
guard let bytes = buffer.readBytes(length: buffer.readableBytes),
|
||||
let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return .init(msgId: msgId, packet: .peerInfo(peerInfo))
|
||||
case .pong:
|
||||
return .init(msgId: msgId, packet: .pong)
|
||||
|
||||
case .command:
|
||||
guard let commandVal = buffer.readInteger(as: UInt8.self),
|
||||
let command = SDLCommandType(rawValue: commandVal),
|
||||
let bytes = buffer.readBytes(length: buffer.readableBytes) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch command {
|
||||
case .changeNetwork:
|
||||
guard let changeNetworkCommand = try? SDLChangeNetworkCommand(serializedBytes: bytes) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return .init(msgId: msgId, packet: .command(.changeNetwork(changeNetworkCommand)))
|
||||
}
|
||||
|
||||
case .event:
|
||||
guard let eventVal = buffer.readInteger(as: UInt8.self),
|
||||
let event = SDLEventType(rawValue: eventVal),
|
||||
let bytes = buffer.readBytes(length: buffer.readableBytes) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch event {
|
||||
case .natChanged:
|
||||
guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else {
|
||||
return nil
|
||||
}
|
||||
return .init(msgId: msgId, packet: .event(.natChanged(natChangedEvent)))
|
||||
case .sendRegister:
|
||||
guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else {
|
||||
return nil
|
||||
}
|
||||
return .init(msgId: msgId, packet: .event(.sendRegister(sendRegisterEvent)))
|
||||
case .networkShutdown:
|
||||
guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else {
|
||||
return nil
|
||||
}
|
||||
return .init(msgId: msgId, packet: .event(.networkShutdown(networkShutdownEvent)))
|
||||
}
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private final class FixedHeaderEncoder: MessageToByteEncoder, @unchecked Sendable {
|
||||
typealias InboundIn = ByteBuffer
|
||||
typealias InboundOut = ByteBuffer
|
||||
|
||||
func encode(data: ByteBuffer, out: inout ByteBuffer) throws {
|
||||
let len = data.readableBytes
|
||||
out.writeInteger(UInt16(len))
|
||||
out.writeBytes(data.readableBytesView)
|
||||
}
|
||||
}
|
||||
|
||||
private final class FixedHeaderDecoder: ByteToMessageDecoder, @unchecked Sendable {
|
||||
typealias InboundIn = ByteBuffer
|
||||
typealias InboundOut = ByteBuffer
|
||||
|
||||
func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
|
||||
guard let len = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) else {
|
||||
return .needMoreData
|
||||
}
|
||||
|
||||
if buffer.readableBytes >= len + 2 {
|
||||
buffer.moveReaderIndex(forwardBy: 2)
|
||||
if let bytes = buffer.readBytes(length: Int(len)) {
|
||||
context.fireChannelRead(self.wrapInboundOut(ByteBuffer(bytes: bytes)))
|
||||
}
|
||||
return .continue
|
||||
} else {
|
||||
return .needMoreData
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user