185 lines
5.8 KiB
Swift
185 lines
5.8 KiB
Swift
//
|
|
// SDLUDPHoleActor 2.swift
|
|
// punchnet
|
|
//
|
|
// Created by 安礼成 on 2026/1/28.
|
|
//
|
|
|
|
|
|
//
|
|
// SDLanServer.swift
|
|
// Tun
|
|
//
|
|
// Created by 安礼成 on 2024/1/31.
|
|
//
|
|
|
|
import Foundation
|
|
import NIOCore
|
|
import NIOPosix
|
|
import SwiftProtobuf
|
|
|
|
/// Handles UDP communication with the sn-server.
///
/// Owns a private single-threaded `MultiThreadedEventLoopGroup`. `start()` binds one
/// datagram socket on an OS-assigned port. Every inbound datagram is decoded into an
/// `SDLHoleMessage` and published on `messageStream` together with the sender's address.
///
/// Wire format (as implemented by `send` / `decode`): one leading `SDLPacketType` raw byte,
/// followed by a serialized protobuf payload.
///
/// Lifetime note: the instance installs itself as the handler of its own channel's pipeline,
/// so the channel keeps it alive until the channel is closed.
final class SDLUDPHole: ChannelInboundHandler {
    typealias InboundIn = AddressedEnvelope<ByteBuffer>

    private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
    private var channel: Channel?

    /// Decoded inbound messages paired with the address they came from.
    /// The stream is finished when the channel goes inactive or this instance is deallocated.
    public let messageStream: AsyncStream<(SocketAddress, SDLHoleMessage)>
    private let messageContinuation: AsyncStream<(SocketAddress, SDLHoleMessage)>.Continuation

    // Readiness handshake ("wait until channelActive has fired").
    // `channelActive` runs on the event-loop thread while `channelIsActived` runs on the
    // caller's thread, so this state is guarded by `readyLock`. An unguarded
    // check-then-store can lose the wakeup and leave a waiter suspended forever.
    private let readyLock = NSLock()
    private var readyWaiters: [CheckedContinuation<Void, Never>] = []
    private var isReady: Bool = false

    enum HoleEvent {
        case ready
        case closed
    }

    /// Creates the handler and its message stream. No socket is opened until `start()`.
    ///
    /// Declared `throws` only so existing `try SDLUDPHole()` call sites keep compiling;
    /// nothing in here currently throws.
    init() throws {
        (self.messageStream, self.messageContinuation) = AsyncStream.makeStream(
            of: (SocketAddress, SDLHoleMessage).self,
            bufferingPolicy: .unbounded
        )
    }

    /// Binds a UDP socket on `0.0.0.0` with an OS-assigned port and installs `self` as its handler.
    ///
    /// Blocks the calling thread until the bind completes (`wait()`), so it must not be
    /// called from an event-loop thread.
    /// - Throws: Any error produced by `DatagramBootstrap.bind`.
    func start() throws {
        let bootstrap = DatagramBootstrap(group: group)
            .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
            .channelInitializer { channel in
                channel.pipeline.addHandler(self)
            }

        let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait()
        SDLLogger.shared.log("[UDPHole] started", level: .debug)
        self.channel = channel
    }

    /// Suspends until the channel has fired `channelActive`, or returns immediately if it already has.
    ///
    /// Any number of callers may wait concurrently; all of them are resumed on activation.
    func channelIsActived() async {
        await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
            self.readyLock.lock()
            if self.isReady {
                self.readyLock.unlock()
                continuation.resume()
            } else {
                self.readyWaiters.append(continuation)
                self.readyLock.unlock()
            }
        }
    }

    /// Waits until the channel is closed. Returns immediately if `start()` has not set a channel.
    func waitClose() async throws {
        try await self.channel?.closeFuture.get()
    }

    // MARK: - ChannelInboundHandler

    /// Marks the handler ready and resumes every pending `channelIsActived()` waiter.
    func channelActive(context: ChannelHandlerContext) {
        readyLock.lock()
        isReady = true
        // Take ownership of the waiters under the lock; each one is resumed exactly once below.
        let waiters = readyWaiters
        readyWaiters = []
        readyLock.unlock()

        // Resume outside the lock so resumed tasks never contend with it.
        for waiter in waiters {
            waiter.resume()
        }
        context.fireChannelActive()
    }

    /// Logs the raw datagram, decodes it and yields `(sender, message)` on `messageStream`.
    /// Undecodable datagrams are logged and dropped.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let envelope = unwrapInboundIn(data)
        var buffer = envelope.data
        let remoteAddress = envelope.remoteAddress

        if let rawBytes = buffer.getBytes(at: buffer.readerIndex, length: buffer.readableBytes) {
            SDLLogger.shared.log("[SDLUDPHole] get raw bytes: \(rawBytes), from: \(remoteAddress)")
        }

        do {
            if let message = try decode(buffer: &buffer) {
                self.messageContinuation.yield((remoteAddress, message))
            } else {
                SDLLogger.shared.log("[SDLUDPHole] decode message, get null", level: .warning)
            }
        } catch let err {
            SDLLogger.shared.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning)
        }
    }

    /// Finishes `messageStream` so `for await` consumers terminate.
    func channelInactive(context: ChannelHandlerContext) {
        self.messageContinuation.finish()
        // The channel is already inactive, so closing it again would be a no-op; just propagate.
        context.fireChannelInactive()
    }

    /// Logs the error and closes the channel.
    func errorCaught(context: ChannelHandlerContext, error: any Error) {
        SDLLogger.shared.log("[SDLUDPHole] channel error: \(error), closing channel", level: .warning)
        context.close(promise: nil)
    }

    /// Local address of the bound socket, or `nil` before `start()` has succeeded.
    func getLocalAddress() -> SocketAddress? {
        return self.channel?.localAddress
    }

    // MARK: - Writing

    /// Sends `[type.rawValue] + data` as a single datagram to `remoteAddress`.
    ///
    /// Silently does nothing if `start()` has not set a channel. Write failures are logged.
    /// `Channel.writeAndFlush` is thread-safe, so this may be called from any thread.
    func send(type: SDLPacketType, data: Data, remoteAddress: SocketAddress) {
        guard let channel = self.channel else {
            return
        }

        var buffer = channel.allocator.buffer(capacity: data.count + 1)
        buffer.writeInteger(type.rawValue)
        buffer.writeBytes(data)

        let envelope = AddressedEnvelope<ByteBuffer>(remoteAddress: remoteAddress, data: buffer)
        let promise = channel.eventLoop.makePromise(of: Void.self)
        promise.futureResult.whenFailure { error in
            SDLLogger.shared.log("[SDLUDPHole] send to \(remoteAddress) failed: \(error)", level: .warning)
        }
        channel.writeAndFlush(envelope, promise: promise)
    }

    // MARK: - Codec

    /// Decodes one datagram: a leading `SDLPacketType` byte followed by a protobuf payload.
    ///
    /// - Returns: The decoded message, or `nil` when the type byte is missing, unknown, or not
    ///   handled here.
    /// - Throws: The protobuf decoding error when the payload of a known type is malformed.
    private func decode(buffer: inout ByteBuffer) throws -> SDLHoleMessage? {
        guard let type = buffer.readInteger(as: UInt8.self),
              let packetType = SDLPacketType(rawValue: type) else {
            SDLLogger.shared.log("[SDLUDPHole] decode error 11")
            return nil
        }

        switch packetType {
        case .data:
            return .data(try Self.parsePayload(SDLData.self, from: &buffer))
        case .register:
            return .register(try Self.parsePayload(SDLRegister.self, from: &buffer))
        case .registerAck:
            return .registerAck(try Self.parsePayload(SDLRegisterAck.self, from: &buffer))
        case .stunProbeReply:
            return .stunProbeReply(try Self.parsePayload(SDLStunProbeReply.self, from: &buffer))
        case .stunReply:
            return .stunReply(try Self.parsePayload(SDLStunReply.self, from: &buffer))
        default:
            SDLLogger.shared.log("SDLUDPHole decode miss type: \(type)")
            return nil
        }
    }

    /// Consumes all remaining readable bytes of `buffer` and parses them as protobuf message `M`.
    private static func parsePayload<M: SwiftProtobuf.Message>(
        _ type: M.Type,
        from buffer: inout ByteBuffer
    ) throws -> M {
        // readBytes(length: readableBytes) cannot fail; `?? []` only satisfies the optional.
        let bytes = buffer.readBytes(length: buffer.readableBytes) ?? []
        return try M(serializedBytes: bytes)
    }

    deinit {
        // Terminate stream consumers even if the channel was never started.
        self.messageContinuation.finish()
        // Close first: scheduling work on an already shut-down event loop is an error in NIO.
        self.channel?.close(promise: nil)
        // Non-blocking shutdown. deinit may run on this group's own loop thread (the pipeline
        // can hold the last reference), where syncShutdownGracefully() would trap.
        self.group.shutdownGracefully { _ in }
    }
}
|