This commit is contained in:
anlicheng 2026-01-28 17:34:02 +08:00
parent 599a047f5c
commit e36ecd0c29
2 changed files with 198 additions and 197 deletions

View File

@ -1,3 +1,11 @@
//
// SDLUDPHoleActor 2.swift
// punchnet
//
// Created by on 2026/1/28.
//
// //
// SDLanServer.swift // SDLanServer.swift
// Tun // Tun
@ -13,113 +21,107 @@ import SwiftProtobuf
// sn-server // sn-server
actor SDLUDPHoleActor { actor SDLUDPHoleActor {
private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
private let asyncChannel: NIOAsyncChannel<AddressedEnvelope<ByteBuffer>, AddressedEnvelope<ByteBuffer>> private var channel: Channel?
private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: UDPHoleOutboundMessage.self, bufferingPolicy: .unbounded)
public var localAddress: SocketAddress?
public let eventStream: AsyncStream<UDPHoleEvent> public let eventStream: AsyncStream<UDPHoleEvent>
private let eventContinuation: AsyncStream<UDPHoleEvent>.Continuation private let eventContinuation: AsyncStream<UDPHoleEvent>.Continuation
private let logger: SDLLogger private let logger: SDLLogger
struct UDPHoleOutboundMessage {
let remoteAddress: SocketAddress
let type: SDLPacketType
let data: Data
}
enum UDPHoleEvent { enum UDPHoleEvent {
case ready case ready
case message(SocketAddress, SDLHoleInboundMessage) case message(SocketAddress, SDLHoleInboundMessage)
} }
// //
init(logger: SDLLogger) async throws { init(logger: SDLLogger) throws {
self.logger = logger self.logger = logger
(self.eventStream, self.eventContinuation) = AsyncStream.makeStream(of: UDPHoleEvent.self, bufferingPolicy: .unbounded) (self.eventStream, self.eventContinuation) = AsyncStream.makeStream(of: UDPHoleEvent.self, bufferingPolicy: .unbounded)
}
func start() throws {
let bootstrap = DatagramBootstrap(group: group) let bootstrap = DatagramBootstrap(group: group)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.channelInitializer { channel in
self.asyncChannel = try await bootstrap.bind(host: "0.0.0.0", port: 0) channel.pipeline.addHandler(SDLUDPHoleHandler(eventContinuation: self.eventContinuation, logger: self.logger))
.flatMapThrowing { channel in
return try NIOAsyncChannel(wrappingChannelSynchronously: channel, configuration: .init(
inboundType: AddressedEnvelope<ByteBuffer>.self,
outboundType: AddressedEnvelope<ByteBuffer>.self
))
}
.get()
self.localAddress = self.asyncChannel.channel.localAddress
self.logger.log("[UDPHole] started and listening on: \(self.localAddress!)", level: .debug)
} }
func start() async throws { self.channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait()
try await self.asyncChannel.executeThenClose {inbound, outbound in self.logger.log("[UDPHole] started", level: .debug)
}
func getLocalAddress() -> SocketAddress? {
return self.channel?.localAddress
}
// MARK:
func send(type: SDLPacketType, data: Data, remoteAddress: SocketAddress) {
guard let channel = self.channel else {
return
}
var buffer = channel.allocator.buffer(capacity: data.count + 1)
buffer.writeBytes([type.rawValue])
buffer.writeBytes(data)
let envelope = AddressedEnvelope<ByteBuffer>(remoteAddress: remoteAddress, data: buffer)
channel.eventLoop.execute {
channel.writeAndFlush(envelope, promise: nil)
}
}
deinit {
try? self.group.syncShutdownGracefully()
self.eventContinuation.finish()
self.channel?.close(promise: nil)
}
}
extension SDLUDPHoleActor {
final class SDLUDPHoleHandler: ChannelInboundHandler {
typealias InboundIn = AddressedEnvelope<ByteBuffer>
private var eventContinuation: AsyncStream<UDPHoleEvent>.Continuation
private var logger: SDLLogger
// --MARK: ChannelInboundHandler delegate
init(eventContinuation: AsyncStream<UDPHoleEvent>.Continuation, logger: SDLLogger) {
self.eventContinuation = eventContinuation
self.logger = logger
}
func channelActive(context: ChannelHandlerContext) {
self.eventContinuation.yield(.ready) self.eventContinuation.yield(.ready)
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
defer {
self.logger.log("[SDLUDPHole] inbound closed", level: .warning)
} }
for try await envelope in inbound { func channelRead(context: ChannelHandlerContext, data: NIOAny) {
try Task.checkCancellation() let envelope = unwrapInboundIn(data)
var buffer = envelope.data var buffer = envelope.data
let remoteAddress = envelope.remoteAddress let remoteAddress = envelope.remoteAddress
do { do {
if let message = try Self.decode(buffer: &buffer) { if let message = try decode(buffer: &buffer) {
self.eventContinuation.yield(.message(remoteAddress, message)) self.eventContinuation.yield(.message(remoteAddress, message))
} else { } else {
self.logger.log("[SDLUDPHole] decode message, get null", level: .warning) self.logger.log("[SDLUDPHole] decode message, get null", level: .warning)
} }
} catch let err { } catch let err {
self.logger.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning) self.logger.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning)
throw err
}
}
}
group.addTask {
defer {
self.logger.log("[SDLUDPHole] outbound closed", level: .warning)
}
for await message in self.writeStream {
try Task.checkCancellation()
var buffer = self.asyncChannel.channel.allocator.buffer(capacity: message.data.count + 1)
buffer.writeBytes([message.type.rawValue])
buffer.writeBytes(message.data)
let envelope = AddressedEnvelope<ByteBuffer>(remoteAddress: message.remoteAddress, data: buffer)
try await outbound.write(envelope)
} }
} }
if let _ = try await group.next() { func channelInactive(context: ChannelHandlerContext) {
group.cancelAll() self.eventContinuation.finish()
}
}
}
} }
func getLocalAddress() -> SocketAddress? { func errorCaught(context: ChannelHandlerContext, error: any Error) {
return self.localAddress context.close(promise: nil)
}
// MARK: client-client apis
//
func send(type: SDLPacketType, data: Data, remoteAddress: SocketAddress) {
let message = UDPHoleOutboundMessage(remoteAddress: remoteAddress, type: type, data: data)
self.writeContinuation.yield(message)
} }
// --MARK: // --MARK:
private static func decode(buffer: inout ByteBuffer) throws -> SDLHoleInboundMessage? { private func decode(buffer: inout ByteBuffer) throws -> SDLHoleInboundMessage? {
guard let type = buffer.readInteger(as: UInt8.self), guard let type = buffer.readInteger(as: UInt8.self),
let packetType = SDLPacketType(rawValue: type), let packetType = SDLPacketType(rawValue: type),
let bytes = buffer.readBytes(length: buffer.readableBytes) else { let bytes = buffer.readBytes(length: buffer.readableBytes) else {
@ -192,10 +194,6 @@ actor SDLUDPHoleActor {
} }
} }
deinit {
try? self.group.syncShutdownGracefully()
self.writeContinuation.finish()
self.eventContinuation.finish()
} }
} }

View File

@ -149,30 +149,16 @@ public class SDLContext {
} }
private func startUDPHole() async throws { private func startUDPHole() async throws {
self.udpHoleActor = try await SDLUDPHoleActor(logger: self.logger) self.udpHoleActor = try SDLUDPHoleActor(logger: self.logger)
try await self.udpHoleActor?.start()
try await withThrowingTaskGroup(of: Void.self) { group in try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await self.udpHoleActor?.start()
}
group.addTask { group.addTask {
while !Task.isCancelled { while !Task.isCancelled {
try Task.checkCancellation() try Task.checkCancellation()
try await Task.sleep(nanoseconds: 5 * 1_000_000_000) try await Task.sleep(nanoseconds: 5 * 1_000_000_000)
try Task.checkCancellation() try Task.checkCancellation()
await self.sendStunRequest()
if let udpHoleActor = self.udpHoleActor {
var stunRequest = SDLStunRequest()
stunRequest.clientID = self.config.clientId
stunRequest.networkID = self.config.networkAddress.networkId
stunRequest.ip = self.config.networkAddress.ip
stunRequest.mac = self.config.networkAddress.mac
stunRequest.natType = UInt32(self.natType.rawValue)
let remoteAddress = self.config.stunSocketAddress
await udpHoleActor.send(type: .stunRequest, data: try stunRequest.serializedData(), remoteAddress: remoteAddress)
}
} }
} }
@ -180,32 +166,7 @@ public class SDLContext {
if let eventStream = self.udpHoleActor?.eventStream { if let eventStream = self.udpHoleActor?.eventStream {
for try await event in eventStream { for try await event in eventStream {
try Task.checkCancellation() try Task.checkCancellation()
try? await self.dispatchEvent(event: event)
switch event {
case .ready:
try await self.handleUDPHoleReady()
case .message(let remoteAddress, let message):
switch message {
case .registerSuperAck(let registerSuperAck):
await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
case .registerSuperNak(let registerSuperNak):
await self.handleRegisterSuperNak(nakPacket: registerSuperNak)
case .peerInfo(let peerInfo):
await self.puncherActor.handlePeerInfo(peerInfo: peerInfo)
case .event(let event):
try await self.handleEvent(event: event)
case .stunReply(let stunReply):
await self.handleStunReply(stunReply: stunReply)
case .stunProbeReply(_):
()
case .data(let data):
try await self.handleData(data: data)
case .register(let register):
try await self.handleRegister(remoteAddress: remoteAddress, register: register)
case .registerAck(let registerAck):
await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
}
}
} }
} }
} }
@ -285,6 +246,48 @@ public class SDLContext {
await self.udpHoleActor?.send(type: .registerSuper, data: try registerSuper.serializedData(), remoteAddress: self.config.stunSocketAddress) await self.udpHoleActor?.send(type: .registerSuper, data: try registerSuper.serializedData(), remoteAddress: self.config.stunSocketAddress)
} }
private func sendStunRequest() async {
var stunRequest = SDLStunRequest()
stunRequest.clientID = self.config.clientId
stunRequest.networkID = self.config.networkAddress.networkId
stunRequest.ip = self.config.networkAddress.ip
stunRequest.mac = self.config.networkAddress.mac
stunRequest.natType = UInt32(self.natType.rawValue)
if let stunData = try? stunRequest.serializedData() {
let remoteAddress = self.config.stunSocketAddress
await self.udpHoleActor?.send(type: .stunRequest, data: stunData, remoteAddress: remoteAddress)
}
}
private func dispatchEvent(event: SDLUDPHoleActor.UDPHoleEvent) async throws {
switch event {
case .ready:
try await self.handleUDPHoleReady()
case .message(let remoteAddress, let message):
switch message {
case .registerSuperAck(let registerSuperAck):
await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
case .registerSuperNak(let registerSuperNak):
await self.handleRegisterSuperNak(nakPacket: registerSuperNak)
case .peerInfo(let peerInfo):
await self.puncherActor.handlePeerInfo(peerInfo: peerInfo)
case .event(let event):
try await self.handleEvent(event: event)
case .stunReply(let stunReply):
await self.handleStunReply(stunReply: stunReply)
case .stunProbeReply(_):
()
case .data(let data):
try await self.handleData(data: data)
case .register(let register):
try await self.handleRegister(remoteAddress: remoteAddress, register: register)
case .registerAck(let registerAck):
await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
}
}
}
private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async { private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async {
// rsa // rsa
let aesKey = try! self.rsaCipher.decode(data: Data(registerSuperAck.aesKey)) let aesKey = try! self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))