punchnet-macos/Tun/Punchnet/SDLUDPHole.swift
2026-04-14 10:31:52 +08:00

217 lines
7.1 KiB
Swift

//
// SDLUDPHole.swift
// Tun
//
// Created on 2024/1/31.
//
import Foundation
import NIOCore
import NIOPosix
import SwiftProtobuf
// sn-server
/// UDP endpoint built on SwiftNIO: binds an ephemeral local port, decodes incoming
/// datagrams into `SDLHoleMessage` values published on `messageStream`, and sends
/// type-prefixed packets to remote addresses.
///
/// Lifecycle: `idle` --`start()`--> `ready` --`stop()` / channel close--> `stopped`.
/// An instance cannot be restarted once it has been stopped.
///
/// NOTE(review): `state` and `channel` are touched both from the NIO event loop
/// (`channelRead`, `channelInactive`, `errorCaught`) and from whoever calls
/// `start()` / `stop()` / `send(...)`; no synchronization is visible in this file —
/// confirm callers serialize access.
final class SDLUDPHole: ChannelInboundHandler {
    typealias InboundIn = AddressedEnvelope<ByteBuffer>

    /// Lifecycle phases of the UDP channel.
    private enum State {
        case idle
        case ready
        case stopping
        case stopped
    }

    /// Dedicated single-threaded event loop group owned by this instance.
    private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
    private var channel: Channel?
    private var state: State = .idle
    /// Guards `messageContinuation.finish()` so it is invoked at most once.
    private var didFinishMessageStream: Bool = false

    /// Signals channel shutdown to `waitClose()`.
    private let closeStream: AsyncStream<Void>
    private let closeContinuation: AsyncStream<Void>.Continuation

    /// Decoded inbound messages paired with the sender's address.
    /// Buffers the newest 2048 elements when the consumer falls behind.
    public let messageStream: AsyncStream<(SocketAddress, SDLHoleMessage)>
    private let messageContinuation: AsyncStream<(SocketAddress, SDLHoleMessage)>.Continuation

    // MARK: - Lifecycle

    /// Creates the close and message streams. No socket is opened until `start()`.
    init() throws {
        let (closeStream, closeContinuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1))
        self.closeStream = closeStream
        self.closeContinuation = closeContinuation

        let (stream, continuation) = AsyncStream.makeStream(of: (SocketAddress, SDLHoleMessage).self, bufferingPolicy: .bufferingNewest(2048))
        self.messageStream = stream
        self.messageContinuation = continuation
    }

    /// Binds a UDP socket on `0.0.0.0` with an OS-assigned port and installs this
    /// object as the channel's inbound handler.
    ///
    /// Calling `start()` again while already `ready` returns the existing local address.
    /// Calling it after `stop()` is a programmer error and traps.
    ///
    /// - Returns: The local address the socket is bound to.
    /// - Throws: Any error produced by the bind operation.
    func start() throws -> SocketAddress {
        switch self.state {
        case .ready:
            guard let channel = self.channel else {
                preconditionFailure("SDLUDPHole is ready but channel is nil")
            }
            return self.localAddress(of: channel)
        case .stopping, .stopped:
            preconditionFailure("SDLUDPHole cannot be restarted after stop")
        case .idle:
            break
        }

        let bootstrap = DatagramBootstrap(group: group)
            .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
            .channelInitializer { channel in
                channel.pipeline.addHandler(self)
            }

        // Blocks the calling thread until the bind completes.
        let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait()
        self.channel = channel
        self.state = .ready
        return self.localAddress(of: channel)
    }

    /// Suspends until the close stream ends, i.e. after the channel went inactive,
    /// an error was caught, or `stop()` was called before `start()`.
    func waitClose() async throws {
        for await _ in self.closeStream { }
    }

    /// Stops the endpoint. Safe to call repeatedly; later calls are no-ops.
    ///
    /// When called before `start()`, both streams are finished immediately. When the
    /// channel is running, the message stream is finished and the channel is asked to
    /// close; the close stream is finished later from `channelInactive`.
    func stop() {
        switch self.state {
        case .stopping, .stopped:
            return
        case .idle:
            self.state = .stopped
            self.finishMessageStream()
            self.closeContinuation.finish()
            return
        case .ready:
            self.state = .stopping
        }

        self.finishMessageStream()
        self.channel?.close(promise: nil)
        self.channel = nil
    }

    // MARK: - ChannelInboundHandler

    /// Decodes one inbound datagram and publishes it on `messageStream`.
    /// Datagrams received while not `ready`, and datagrams that fail to decode, are dropped.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        guard case .ready = self.state else {
            return
        }

        let envelope = unwrapInboundIn(data)
        var buffer = envelope.data
        let remoteAddress = envelope.remoteAddress

        // Log only the size; there is no need to copy the payload out of the buffer for that.
        SDLLogger.log("[SDLUDPHole] get raw bytes: \(buffer.readableBytes), from: \(remoteAddress)", for: .debug)

        do {
            if let message = try decode(buffer: &buffer) {
                self.messageContinuation.yield((remoteAddress, message))
            } else {
                SDLLogger.log("[SDLUDPHole] decode message, get null", for: .debug)
            }
        } catch let err {
            SDLLogger.log("[SDLUDPHole] decode message, get error: \(err)", for: .debug)
        }
    }

    /// The channel became inactive: mark this instance stopped and wake `waitClose()`.
    func channelInactive(context: ChannelHandlerContext) {
        self.handleChannelTermination(context: context)
    }

    /// A pipeline error occurred: log it, mark this instance stopped and close the channel.
    func errorCaught(context: ChannelHandlerContext, error: any Error) {
        SDLLogger.log("[SDLUDPHole] channel error, closing: \(error)", for: .debug)
        self.handleChannelTermination(context: context)
    }

    // MARK: - Sending

    /// Sends `data` to `remoteAddress`, prefixed with a single byte holding `type.rawValue`.
    /// Silently does nothing unless the endpoint is `ready`.
    ///
    /// - Parameters:
    ///   - type: Packet type written as the first byte of the datagram.
    ///   - data: Payload appended after the type byte.
    ///   - remoteAddress: Destination of the datagram.
    func send(type: SDLPacketType, data: Data, remoteAddress: SocketAddress) {
        guard case .ready = self.state, let channel = self.channel else {
            return
        }

        var buffer = channel.allocator.buffer(capacity: data.count + 1)
        buffer.writeInteger(type.rawValue)
        buffer.writeBytes(data)

        let envelope = AddressedEnvelope<ByteBuffer>(remoteAddress: remoteAddress, data: buffer)
        // Hop onto the channel's event loop for the write; no result is needed, so no future is created.
        channel.eventLoop.execute {
            channel.writeAndFlush(envelope, promise: nil)
        }
    }

    // MARK: - Private helpers

    /// Returns the bound local address of `channel`, trapping if it has none.
    private func localAddress(of channel: Channel) -> SocketAddress {
        guard let address = channel.localAddress else {
            preconditionFailure("UDP channel has no localAddress after bind")
        }
        return address
    }

    /// Shared teardown for `channelInactive` and `errorCaught`: finishes both streams,
    /// drops the channel reference, marks the instance stopped and closes the context.
    private func handleChannelTermination(context: ChannelHandlerContext) {
        self.finishMessageStream()
        self.channel = nil
        self.state = .stopped
        // Yield once before finishing so the close stream carries an explicit close event.
        self.closeContinuation.yield(())
        self.closeContinuation.finish()
        context.close(promise: nil)
    }

    /// Parses a datagram: the first byte selects the packet type, the remaining bytes
    /// are the serialized payload for that type.
    ///
    /// - Returns: The decoded message, or `nil` when the buffer is empty, the type byte
    ///   is unknown or unhandled, or the payload fails to deserialize.
    private func decode(buffer: inout ByteBuffer) throws -> SDLHoleMessage? {
        guard let type = buffer.readInteger(as: UInt8.self),
              let packetType = SDLPacketType(rawValue: type) else {
            return nil
        }

        switch packetType {
        case .data:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let dataPacket = try? SDLData(serializedBytes: bytes) else {
                return nil
            }
            return .data(dataPacket)
        case .register:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let registerPacket = try? SDLRegister(serializedBytes: bytes) else {
                return nil
            }
            return .register(registerPacket)
        case .registerAck:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let registerAck = try? SDLRegisterAck(serializedBytes: bytes) else {
                return nil
            }
            return .registerAck(registerAck)
        case .stunProbeReply:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let stunProbeReply = try? SDLStunProbeReply(serializedBytes: bytes) else {
                return nil
            }
            return .stunProbeReply(stunProbeReply)
        case .stunReply:
            guard let bytes = buffer.readBytes(length: buffer.readableBytes),
                  let stunReply = try? SDLStunReply(serializedBytes: bytes) else {
                return nil
            }
            return .stunReply(stunReply)
        default:
            // Packet types that exist in SDLPacketType but are not handled by this endpoint.
            SDLLogger.log("SDLUDPHole decode miss type: \(type)", for: .debug)
            return nil
        }
    }

    /// Finishes `messageStream` exactly once, regardless of how many shutdown paths run.
    private func finishMessageStream() {
        guard !self.didFinishMessageStream else {
            return
        }
        self.didFinishMessageStream = true
        self.messageContinuation.finish()
    }

    // NOTE(review): `start()` adds `self` to the channel pipeline while `self` also holds
    // the channel, so this deinit presumably cannot run until `stop()` or a channel close
    // has released that reference — confirm owners always call `stop()`.
    deinit {
        self.stop()
        try? self.group.syncShutdownGracefully()
    }
}