From e36ecd0c29ab8fc8ec2b6b3e69132676cd748741 Mon Sep 17 00:00:00 2001 From: anlicheng <244108715@qq.com> Date: Wed, 28 Jan 2026 17:34:02 +0800 Subject: [PATCH] fix hole --- Tun/Punchnet/Actors/SDLUDPHoleActor.swift | 304 +++++++++++----------- Tun/Punchnet/SDLContext.swift | 91 +++---- 2 files changed, 198 insertions(+), 197 deletions(-) diff --git a/Tun/Punchnet/Actors/SDLUDPHoleActor.swift b/Tun/Punchnet/Actors/SDLUDPHoleActor.swift index 575eba5..c3f4dd2 100644 --- a/Tun/Punchnet/Actors/SDLUDPHoleActor.swift +++ b/Tun/Punchnet/Actors/SDLUDPHoleActor.swift @@ -1,3 +1,11 @@ +// +// SDLUDPHoleActor 2.swift +// punchnet +// +// Created by 安礼成 on 2026/1/28. +// + + // // SDLanServer.swift // Tun @@ -13,189 +21,179 @@ import SwiftProtobuf // 处理和sn-server服务器之间的通讯 actor SDLUDPHoleActor { private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - private let asyncChannel: NIOAsyncChannel, AddressedEnvelope> - private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: UDPHoleOutboundMessage.self, bufferingPolicy: .unbounded) + private var channel: Channel? - public var localAddress: SocketAddress? public let eventStream: AsyncStream private let eventContinuation: AsyncStream.Continuation - private let logger: SDLLogger - struct UDPHoleOutboundMessage { - let remoteAddress: SocketAddress - let type: SDLPacketType - let data: Data - } - enum UDPHoleEvent { case ready case message(SocketAddress, SDLHoleInboundMessage) } // 启动函数 - init(logger: SDLLogger) async throws { + init(logger: SDLLogger) throws { self.logger = logger - (self.eventStream, self.eventContinuation) = AsyncStream.makeStream(of: UDPHoleEvent.self, bufferingPolicy: .unbounded) - - let bootstrap = DatagramBootstrap(group: group) - .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - - self.asyncChannel = try await bootstrap.bind(host: "0.0.0.0", port: 0) - .flatMapThrowing { channel in - return try NIOAsyncChannel(wrappingChannelSynchronously: channel, configuration: .init( - inboundType: AddressedEnvelope.self, - outboundType: AddressedEnvelope.self - )) - } - .get() - - self.localAddress = self.asyncChannel.channel.localAddress - self.logger.log("[UDPHole] started and listening on: \(self.localAddress!)", level: .debug) } - func start() async throws { - try await self.asyncChannel.executeThenClose {inbound, outbound in - self.eventContinuation.yield(.ready) - - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - defer { - self.logger.log("[SDLUDPHole] inbound closed", level: .warning) - } - - for try await envelope in inbound { - try Task.checkCancellation() - - var buffer = envelope.data - let remoteAddress = envelope.remoteAddress - do { - if let message = try Self.decode(buffer: &buffer) { - self.eventContinuation.yield(.message(remoteAddress, message)) - } else { - self.logger.log("[SDLUDPHole] decode message, get null", level: .warning) - } - } catch let err { - self.logger.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning) - throw err - } - } - } - - group.addTask { - defer { - self.logger.log("[SDLUDPHole] outbound closed", level: .warning) - } - - for await message in self.writeStream { - try Task.checkCancellation() - - var buffer = self.asyncChannel.channel.allocator.buffer(capacity: message.data.count + 1) - buffer.writeBytes([message.type.rawValue]) - buffer.writeBytes(message.data) - - let envelope = AddressedEnvelope(remoteAddress: message.remoteAddress, data: buffer) - try await outbound.write(envelope) - } - } - - if let _ = try await group.next() { - group.cancelAll() - } + func start() throws { + let bootstrap = DatagramBootstrap(group: group) + .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .channelInitializer { channel in + channel.pipeline.addHandler(SDLUDPHoleHandler(eventContinuation: self.eventContinuation, logger: self.logger)) } - } + + self.channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() + self.logger.log("[UDPHole] started", level: .debug) } func getLocalAddress() -> SocketAddress? { - return self.localAddress + return self.channel?.localAddress } - // MARK: client-client apis - // 处理写入逻辑 + // MARK: 处理写入逻辑 func send(type: SDLPacketType, data: Data, remoteAddress: SocketAddress) { - let message = UDPHoleOutboundMessage(remoteAddress: remoteAddress, type: type, data: data) - self.writeContinuation.yield(message) - } - - //--MARK: 编解码器 - private static func decode(buffer: inout ByteBuffer) throws -> SDLHoleInboundMessage? { - guard let type = buffer.readInteger(as: UInt8.self), - let packetType = SDLPacketType(rawValue: type), - let bytes = buffer.readBytes(length: buffer.readableBytes) else { - return nil + guard let channel = self.channel else { + return } - switch packetType { - case .data: - let dataPacket = try SDLData(serializedBytes: bytes) - return .data(dataPacket) - case .register: - let registerPacket = try SDLRegister(serializedBytes: bytes) - return .register(registerPacket) - case .registerAck: - let registerAck = try SDLRegisterAck(serializedBytes: bytes) - return .registerAck(registerAck) - case .stunReply: - let stunReply = try SDLStunReply(serializedBytes: bytes) - return .stunReply(stunReply) - case .stunProbeReply: - let stunProbeReply = try SDLStunProbeReply(serializedBytes: bytes) - return .stunProbeReply(stunProbeReply) - case .registerSuperAck: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else { - return nil - } - return .registerSuperAck(registerSuperAck) - case .registerSuperNak: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else { - return nil - } - return .registerSuperNak(registerSuperNak) - - case .peerInfo: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else { - return nil - } - - return .peerInfo(peerInfo) - case .event: - guard let eventVal = buffer.readInteger(as: UInt8.self), - let event = SDLEventType(rawValue: eventVal), - let bytes = buffer.readBytes(length: buffer.readableBytes) else { - return nil - } - - switch event { - case .natChanged: - guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else { - return nil - } - return .event(.natChanged(natChangedEvent)) - case .sendRegister: - guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else { - return nil - } - return .event(.sendRegister(sendRegisterEvent)) - case .networkShutdown: - guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else { - return nil - } - return .event(.networkShutdown(networkShutdownEvent)) - } - - default: - return nil + var buffer = channel.allocator.buffer(capacity: data.count + 1) + buffer.writeBytes([type.rawValue]) + buffer.writeBytes(data) + + let envelope = AddressedEnvelope(remoteAddress: remoteAddress, data: buffer) + channel.eventLoop.execute { + channel.writeAndFlush(envelope, promise: nil) } } deinit { try? self.group.syncShutdownGracefully() - self.writeContinuation.finish() self.eventContinuation.finish() + self.channel?.close(promise: nil) + } + +} + +extension SDLUDPHoleActor { + + final class SDLUDPHoleHandler: ChannelInboundHandler { + typealias InboundIn = AddressedEnvelope + + private var eventContinuation: AsyncStream.Continuation + private var logger: SDLLogger + + // --MARK: ChannelInboundHandler delegate + + init(eventContinuation: AsyncStream.Continuation, logger: SDLLogger) { + self.eventContinuation = eventContinuation + self.logger = logger + } + + func channelActive(context: ChannelHandlerContext) { + self.eventContinuation.yield(.ready) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let envelope = unwrapInboundIn(data) + + var buffer = envelope.data + let remoteAddress = envelope.remoteAddress + do { + if let message = try decode(buffer: &buffer) { + self.eventContinuation.yield(.message(remoteAddress, message)) + } else { + self.logger.log("[SDLUDPHole] decode message, get null", level: .warning) + } + } catch let err { + self.logger.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning) + } + } + + func channelInactive(context: ChannelHandlerContext) { + self.eventContinuation.finish() + } + + func errorCaught(context: ChannelHandlerContext, error: any Error) { + context.close(promise: nil) + } + + // --MARK: 编解码器 + private func decode(buffer: inout ByteBuffer) throws -> SDLHoleInboundMessage? { + guard let type = buffer.readInteger(as: UInt8.self), + let packetType = SDLPacketType(rawValue: type), + let bytes = buffer.readBytes(length: buffer.readableBytes) else { + return nil + } + + switch packetType { + case .data: + let dataPacket = try SDLData(serializedBytes: bytes) + return .data(dataPacket) + case .register: + let registerPacket = try SDLRegister(serializedBytes: bytes) + return .register(registerPacket) + case .registerAck: + let registerAck = try SDLRegisterAck(serializedBytes: bytes) + return .registerAck(registerAck) + case .stunReply: + let stunReply = try SDLStunReply(serializedBytes: bytes) + return .stunReply(stunReply) + case .stunProbeReply: + let stunProbeReply = try SDLStunProbeReply(serializedBytes: bytes) + return .stunProbeReply(stunProbeReply) + case .registerSuperAck: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else { + return nil + } + return .registerSuperAck(registerSuperAck) + case .registerSuperNak: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else { + return nil + } + return .registerSuperNak(registerSuperNak) + + case .peerInfo: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else { + return nil + } + + return .peerInfo(peerInfo) + case .event: + guard let eventVal = buffer.readInteger(as: UInt8.self), + let event = SDLEventType(rawValue: eventVal), + let bytes = buffer.readBytes(length: buffer.readableBytes) else { + return nil + } + + switch event { + case .natChanged: + guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else { + return nil + } + return .event(.natChanged(natChangedEvent)) + case .sendRegister: + guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else { + return nil + } + return .event(.sendRegister(sendRegisterEvent)) + case .networkShutdown: + guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else { + return nil + } + return .event(.networkShutdown(networkShutdownEvent)) + } + + default: + return nil + } + } + } } diff --git a/Tun/Punchnet/SDLContext.swift b/Tun/Punchnet/SDLContext.swift index 8bd5fd0..3f4b8b6 100644 --- a/Tun/Punchnet/SDLContext.swift +++ b/Tun/Punchnet/SDLContext.swift @@ -149,30 +149,16 @@ public class SDLContext { } private func startUDPHole() async throws { - self.udpHoleActor = try await SDLUDPHoleActor(logger: self.logger) - + self.udpHoleActor = try SDLUDPHoleActor(logger: self.logger) + try await self.udpHoleActor?.start() + try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - try await self.udpHoleActor?.start() - } - group.addTask { while !Task.isCancelled { try Task.checkCancellation() try await Task.sleep(nanoseconds: 5 * 1_000_000_000) try Task.checkCancellation() - - if let udpHoleActor = self.udpHoleActor { - var stunRequest = SDLStunRequest() - stunRequest.clientID = self.config.clientId - stunRequest.networkID = self.config.networkAddress.networkId - stunRequest.ip = self.config.networkAddress.ip - stunRequest.mac = self.config.networkAddress.mac - stunRequest.natType = UInt32(self.natType.rawValue) - - let remoteAddress = self.config.stunSocketAddress - await udpHoleActor.send(type: .stunRequest, data: try stunRequest.serializedData(), remoteAddress: remoteAddress) - } + await self.sendStunRequest() } } @@ -180,32 +166,7 @@ public class SDLContext { if let eventStream = self.udpHoleActor?.eventStream { for try await event in eventStream { try Task.checkCancellation() - - switch event { - case .ready: - try await self.handleUDPHoleReady() - case .message(let remoteAddress, let message): - switch message { - case .registerSuperAck(let registerSuperAck): - await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck) - case .registerSuperNak(let registerSuperNak): - await self.handleRegisterSuperNak(nakPacket: registerSuperNak) - case .peerInfo(let peerInfo): - await self.puncherActor.handlePeerInfo(peerInfo: peerInfo) - case .event(let event): - try await self.handleEvent(event: event) - case .stunReply(let stunReply): - await self.handleStunReply(stunReply: stunReply) - case .stunProbeReply(_): - () - case .data(let data): - try await self.handleData(data: data) - case .register(let register): - try await self.handleRegister(remoteAddress: remoteAddress, register: register) - case .registerAck(let registerAck): - await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) - } - } + try? await self.dispatchEvent(event: event) } } } @@ -285,6 +246,48 @@ public class SDLContext { await self.udpHoleActor?.send(type: .registerSuper, data: try registerSuper.serializedData(), remoteAddress: self.config.stunSocketAddress) } + private func sendStunRequest() async { + var stunRequest = SDLStunRequest() + stunRequest.clientID = self.config.clientId + stunRequest.networkID = self.config.networkAddress.networkId + stunRequest.ip = self.config.networkAddress.ip + stunRequest.mac = self.config.networkAddress.mac + stunRequest.natType = UInt32(self.natType.rawValue) + + if let stunData = try? stunRequest.serializedData() { + let remoteAddress = self.config.stunSocketAddress + await self.udpHoleActor?.send(type: .stunRequest, data: stunData, remoteAddress: remoteAddress) + } + } + + private func dispatchEvent(event: SDLUDPHoleActor.UDPHoleEvent) async throws { + switch event { + case .ready: + try await self.handleUDPHoleReady() + case .message(let remoteAddress, let message): + switch message { + case .registerSuperAck(let registerSuperAck): + await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck) + case .registerSuperNak(let registerSuperNak): + await self.handleRegisterSuperNak(nakPacket: registerSuperNak) + case .peerInfo(let peerInfo): + await self.puncherActor.handlePeerInfo(peerInfo: peerInfo) + case .event(let event): + try await self.handleEvent(event: event) + case .stunReply(let stunReply): + await self.handleStunReply(stunReply: stunReply) + case .stunProbeReply(_): + () + case .data(let data): + try await self.handleData(data: data) + case .register(let register): + try await self.handleRegister(remoteAddress: remoteAddress, register: register) + case .registerAck(let registerAck): + await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) + } + } + } + private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async { // 需要对数据通过rsa的私钥解码 let aesKey = try! self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))