// File: punchnet-macos/Tun/Punchnet/Actors/SDLSuperClientActor.swift
//
// (Source-viewer metadata: 301 lines, 12 KiB, Swift)

//
// SDLSuperClientActor.swift
// Tun
//
// Created by on 2024/3/28.
//
import Foundation
import NIOCore
import NIOPosix
// MARK: - SuperNode client
/// TCP client for the super node.
///
/// Frames are length-prefixed by `FixedHeaderEncoder` / `FixedHeaderDecoder`. Each frame body is
/// `<<packetId:32, type:8, payload>>`. Server-pushed events and commands are published on
/// `eventFlow`. Replies to `request(...)` are matched back to their caller by packet id.
actor SDLSuperClientActor {
    /// One outbound frame: packet id (0 for fire-and-forget sends), packet type and serialized body.
    private typealias TcpMessage = (packetId: UInt32, type: SDLPacketType, data: Data)

    /// Single-threaded event loop group that owns the connection.
    private let group: MultiThreadedEventLoopGroup
    /// The connected, length-prefix-framed TCP channel.
    private let asyncChannel: NIOAsyncChannel<ByteBuffer, ByteBuffer>
    /// Outbound queue; drained by the writer task started in `start()`.
    private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: TcpMessage.self, bufferingPolicy: .unbounded)
    /// In-flight `request(...)` calls keyed by packet id, resumed by a reply, a timeout, or connection shutdown.
    private var continuations: [UInt32: CheckedContinuation<SDLSuperInboundMessage, Error>] = [:]
    /// Readiness, server-pushed events and server commands. Finished when `start()` returns.
    public let eventFlow: AsyncStream<SuperEvent>
    private let inboundContinuation: AsyncStream<SuperEvent>.Continuation
    /// Source of packet ids for `request(...)`.
    var idGenerator = SDLIdGenerator(seed: 1)
    private let logger: SDLLogger

    /// Items published on `eventFlow`.
    enum SuperEvent {
        /// The channel is open and the reader/writer/ping tasks are about to start.
        case ready
        /// A server-pushed event.
        case event(SDLEvent)
        /// A server command, tagged with the message id it arrived with.
        case command(UInt32, SDLCommand)
    }

    enum SDLSuperClientError: Error {
        /// No reply arrived within the request's timeout.
        case timeout
        /// The connection ended while the request was pending, or before it could be queued.
        case connectionClosed
        /// The task running `start()` was cancelled while the request was pending.
        case cancelled
    }

    /// Connects to the super node at `host:port`.
    ///
    /// - Throws: Any error from connecting or from wrapping the channel. The event loop group
    ///   is shut down before the error is rethrown.
    init(host: String, port: Int, logger: SDLLogger) async throws {
        self.logger = logger
        (self.eventFlow, self.inboundContinuation) = AsyncStream.makeStream(of: SuperEvent.self, bufferingPolicy: .unbounded)

        let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
        self.group = group

        let bootstrap = ClientBootstrap(group: group)
            .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
            .channelInitializer { channel in
                return channel.pipeline.addHandlers([
                    ByteToMessageHandler(FixedHeaderDecoder()),
                    MessageToByteHandler(FixedHeaderEncoder())
                ])
            }

        do {
            self.asyncChannel = try await bootstrap.connect(host: host, port: port)
                .flatMapThrowing { channel in
                    return try NIOAsyncChannel(wrappingChannelSynchronously: channel, configuration: .init(
                        inboundType: ByteBuffer.self,
                        outboundType: ByteBuffer.self
                    ))
                }
                .get()
        } catch {
            // A failed init never reaches `deinit`, so nothing else would shut this group down.
            try? await group.shutdownGracefully()
            throw error
        }
    }

    /// Runs the connection until it closes, fails, or the calling task is cancelled.
    ///
    /// Publishes `.ready`, then runs three child tasks concurrently:
    /// - the inbound reader,
    /// - the outbound writer (draining `writeStream`),
    /// - a ping loop every 5 seconds.
    ///
    /// As soon as one child finishes, the rest are cancelled. On every exit path, `eventFlow` and
    /// the outbound queue are finished and all pending `request(...)` calls are failed, so no
    /// consumer or caller is left waiting.
    func start() async throws {
        defer {
            self.inboundContinuation.finish()
            self.writeContinuation.finish()
            self.failPendingRequests(with: Task.isCancelled ? .cancelled : .connectionClosed)
        }

        try await withTaskCancellationHandler {
            try await self.asyncChannel.executeThenClose { inbound, outbound in
                self.inboundContinuation.yield(.ready)
                try await withThrowingTaskGroup(of: Void.self) { group in
                    // Reader: route events/commands to `eventFlow`, everything else to waiting requests.
                    group.addTask {
                        defer {
                            self.logger.log("[SDLSuperClient] inbound closed", level: .warning)
                        }
                        for try await var packet in inbound {
                            try Task.checkCancellation()
                            if let message = SDLSuperClientDecoder.decode(buffer: &packet) {
                                if !message.isPong() {
                                    self.logger.log("[SDLSuperClient] read message: \(message)", level: .debug)
                                }
                                switch message.packet {
                                case .event(let event):
                                    self.inboundContinuation.yield(.event(event))
                                case .command(let command):
                                    self.inboundContinuation.yield(.command(message.msgId, command))
                                default:
                                    await self.fireCallback(message: message)
                                }
                            }
                        }
                    }
                    // Writer: serialize queued messages as <<packetId:32, type:8, payload>>.
                    group.addTask {
                        defer {
                            self.logger.log("[SDLSuperClient] outbound closed", level: .warning)
                        }
                        for await (packetId, type, data) in self.writeStream {
                            try Task.checkCancellation()
                            var buffer = self.asyncChannel.channel.allocator.buffer(capacity: data.count + 5)
                            buffer.writeInteger(packetId, as: UInt32.self)
                            buffer.writeBytes([type.rawValue])
                            buffer.writeBytes(data)
                            try await outbound.write(buffer)
                        }
                    }
                    // Keep-alive: ping every 5 seconds.
                    group.addTask {
                        defer {
                            self.logger.log("[SDLSuperClient] ping task closed", level: .warning)
                        }
                        while true {
                            try Task.checkCancellation()
                            await self.ping()
                            try await Task.sleep(nanoseconds: 5 * 1_000_000_000)
                        }
                    }
                    // Exit as soon as any child finishes, then cancel the others.
                    if let _ = try await group.next() {
                        group.cancelAll()
                    }
                }
            }
        } onCancel: {
            self.inboundContinuation.finish()
            self.writeContinuation.finish()
            self.logger.log("[SDLSuperClient] withTaskCancellationHandler cancel")
        }
    }

    // MARK: - APIs

    /// Queues an `unregisterSuper` packet (packet id 0, empty body). Declared `throws` for API compatibility.
    func unregister() throws {
        self.send(type: .unregisterSuper, packetId: 0, data: Data())
    }

    /// Queues a keep-alive `ping` packet (packet id 0, empty body).
    private func ping() {
        self.send(type: .ping, packetId: 0, data: Data())
    }

    /// Sends a packet and waits for the reply that carries the same packet id.
    ///
    /// - Parameters:
    ///   - type: Packet type to send.
    ///   - data: Serialized body.
    ///   - timeout: How long to wait for the reply.
    /// - Returns: The first inbound message whose `msgId` equals the generated packet id.
    /// - Throws: `SDLSuperClientError.timeout` if no reply arrives in time.
    ///   `.connectionClosed` if the outbound queue is already terminated or the connection ends.
    ///   `.cancelled` if `start()` is cancelled while waiting.
    func request(type: SDLPacketType, data: Data, timeout: Duration = .seconds(5)) async throws -> SDLSuperInboundMessage {
        let packetId = idGenerator.nextId()
        return try await withCheckedThrowingContinuation { cont in
            self.continuations[packetId] = cont
            let result = self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data))
            if case .terminated = result {
                // Nobody will ever write this packet; fail now instead of waiting for the timeout.
                self.continuations.removeValue(forKey: packetId)
                cont.resume(throwing: SDLSuperClientError.connectionClosed)
                return
            }
            Task {
                try? await Task.sleep(for: timeout)
                self.timeout(packetId: packetId)
            }
        }
    }

    /// Queues a packet without waiting for a reply.
    func send(type: SDLPacketType, packetId: UInt32, data: Data) {
        self.writeContinuation.yield(TcpMessage(packetId: packetId, type: type, data: data))
    }

    // MARK: - Reply bookkeeping

    /// Resumes the pending request whose packet id matches `message.msgId`, if any.
    private func fireCallback(message: SDLSuperInboundMessage) {
        guard let cont = self.continuations.removeValue(forKey: message.msgId) else {
            return
        }
        cont.resume(returning: message)
    }

    /// Fails the pending request `packetId` with `.timeout`, unless it was already resumed.
    private func timeout(packetId: UInt32) {
        guard let cont = self.continuations.removeValue(forKey: packetId) else {
            return
        }
        cont.resume(throwing: SDLSuperClientError.timeout)
    }

    /// Resumes every outstanding request with `error` and clears the table.
    ///
    /// Later timeouts for those ids find nothing and return, so each continuation resumes exactly once.
    private func failPendingRequests(with error: SDLSuperClientError) {
        let pending = self.continuations
        self.continuations.removeAll()
        for cont in pending.values {
            cont.resume(throwing: error)
        }
    }

    deinit {
        // Non-blocking shutdown. `syncShutdownGracefully()` would block whichever thread releases
        // the actor, and `try!` would crash the process if shutdown reported an error.
        self.group.shutdownGracefully { _ in }
    }
}
// MARK: - Inbound packet decoding
/// Parses one de-framed super-node packet.
///
/// Layout: `<<MsgId:32, Type:8, Body/binary>>`. The integers are read with NIO's default
/// (big-endian) byte order. For `.command` and `.event`, the body begins with one sub-type byte
/// followed by the protobuf payload.
private struct SDLSuperClientDecoder {
    /// Decodes `buffer` into an inbound message, consuming bytes as it reads.
    ///
    /// - Returns: `nil` if any of the following holds:
    ///   - the header is truncated;
    ///   - the type or sub-type byte is unknown;
    ///   - the packet type is not an inbound type;
    ///   - the protobuf payload fails to parse.
    static func decode(buffer: inout ByteBuffer) -> SDLSuperInboundMessage? {
        guard let msgId = buffer.readInteger(as: UInt32.self) else { return nil }
        guard let rawType = buffer.readInteger(as: UInt8.self) else { return nil }
        guard let packetType = SDLPacketType(rawValue: rawType) else { return nil }

        switch packetType {
        case .empty:
            return SDLSuperInboundMessage(msgId: msgId, packet: .empty)

        case .pong:
            return SDLSuperInboundMessage(msgId: msgId, packet: .pong)

        case .registerSuperAck:
            guard let ack = parseRemainder(of: &buffer, using: { try SDLRegisterSuperAck(serializedBytes: $0) }) else {
                return nil
            }
            return SDLSuperInboundMessage(msgId: msgId, packet: .registerSuperAck(ack))

        case .registerSuperNak:
            guard let nak = parseRemainder(of: &buffer, using: { try SDLRegisterSuperNak(serializedBytes: $0) }) else {
                return nil
            }
            return SDLSuperInboundMessage(msgId: msgId, packet: .registerSuperNak(nak))

        case .peerInfo:
            guard let info = parseRemainder(of: &buffer, using: { try SDLPeerInfo(serializedBytes: $0) }) else {
                return nil
            }
            return SDLSuperInboundMessage(msgId: msgId, packet: .peerInfo(info))

        case .command:
            guard let rawCommand = buffer.readInteger(as: UInt8.self),
                  let commandType = SDLCommandType(rawValue: rawCommand) else {
                return nil
            }
            switch commandType {
            case .changeNetwork:
                guard let change = parseRemainder(of: &buffer, using: { try SDLChangeNetworkCommand(serializedBytes: $0) }) else {
                    return nil
                }
                return SDLSuperInboundMessage(msgId: msgId, packet: .command(.changeNetwork(change)))
            }

        case .event:
            guard let rawEvent = buffer.readInteger(as: UInt8.self),
                  let eventType = SDLEventType(rawValue: rawEvent) else {
                return nil
            }
            switch eventType {
            case .natChanged:
                guard let natChanged = parseRemainder(of: &buffer, using: { try SDLNatChangedEvent(serializedBytes: $0) }) else {
                    return nil
                }
                return SDLSuperInboundMessage(msgId: msgId, packet: .event(.natChanged(natChanged)))
            case .sendRegister:
                guard let sendRegister = parseRemainder(of: &buffer, using: { try SDLSendRegisterEvent(serializedBytes: $0) }) else {
                    return nil
                }
                return SDLSuperInboundMessage(msgId: msgId, packet: .event(.sendRegister(sendRegister)))
            case .networkShutdown:
                guard let shutdown = parseRemainder(of: &buffer, using: { try SDLNetworkShutdownEvent(serializedBytes: $0) }) else {
                    return nil
                }
                return SDLSuperInboundMessage(msgId: msgId, packet: .event(.networkShutdown(shutdown)))
            }

        default:
            return nil
        }
    }

    /// Reads every remaining byte of `buffer` and hands it to `build`; returns `nil` if `build` throws.
    private static func parseRemainder<Payload>(
        of buffer: inout ByteBuffer,
        using build: ([UInt8]) throws -> Payload
    ) -> Payload? {
        guard let body = buffer.readBytes(length: buffer.readableBytes) else { return nil }
        return try? build(body)
    }
}
/// Frames each outbound payload with a 2-byte big-endian length prefix.
///
/// This is the counterpart of `FixedHeaderDecoder`.
private final class FixedHeaderEncoder: MessageToByteEncoder, @unchecked Sendable {
    typealias OutboundIn = ByteBuffer

    /// Errors raised while framing a payload.
    enum EncodingError: Error {
        /// The payload is longer than a 16-bit length prefix can describe.
        case frameTooLarge(length: Int)
    }

    /// Writes `UInt16(length)` followed by the payload bytes into `out`.
    ///
    /// - Throws: `EncodingError.frameTooLarge` if the payload exceeds `UInt16.max` bytes.
    func encode(data: ByteBuffer, out: inout ByteBuffer) throws {
        let length = data.readableBytes
        // `UInt16(_:)` traps on overflow; reject oversized frames with an error instead of crashing.
        guard let prefix = UInt16(exactly: length) else {
            throw EncodingError.frameTooLarge(length: length)
        }
        out.writeInteger(prefix, endianness: .big)
        out.writeBytes(data.readableBytesView)
    }
}
/// Splits the inbound byte stream into frames delimited by a 2-byte big-endian length prefix.
///
/// Each complete frame's payload, with the prefix stripped, is forwarded downstream as one `ByteBuffer`.
private final class FixedHeaderDecoder: ByteToMessageDecoder, @unchecked Sendable {
    typealias InboundIn = ByteBuffer
    typealias InboundOut = ByteBuffer

    /// Emits one frame if a complete one is buffered.
    ///
    /// - Returns: `.continue` after emitting a frame, or `.needMoreData` if the prefix or payload
    ///   is still incomplete.
    func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
        guard let length = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) else {
            return .needMoreData
        }
        // Widen before adding: `length + 2` in UInt16 arithmetic traps for peer-sent lengths >= 65534.
        let frameSize = Int(length) + 2
        guard buffer.readableBytes >= frameSize else {
            return .needMoreData
        }
        buffer.moveReaderIndex(forwardBy: 2)
        // `readSlice` shares storage with the cumulation buffer instead of copying into a new array.
        if let payload = buffer.readSlice(length: Int(length)) {
            context.fireChannelRead(self.wrapInboundOut(payload))
        }
        return .continue
    }
}