From df236d4c1f291caed205e6d02e6f906afd5725b6 Mon Sep 17 00:00:00 2001 From: anlicheng <244108715@qq.com> Date: Thu, 29 Jan 2026 23:23:53 +0800 Subject: [PATCH] fix --- Tun/Punchnet/SDLContext.swift | 73 +++++---- Tun/Punchnet/SDLDNSClient.swift | 56 +++---- Tun/Punchnet/SDLMessage.swift | 7 +- Tun/Punchnet/SDLUDPHole.swift | 254 ++++++++++++++++---------------- 4 files changed, 203 insertions(+), 187 deletions(-) diff --git a/Tun/Punchnet/SDLContext.swift b/Tun/Punchnet/SDLContext.swift index aa8ecb9..9168eb3 100644 --- a/Tun/Punchnet/SDLContext.swift +++ b/Tun/Punchnet/SDLContext.swift @@ -119,12 +119,57 @@ public class SDLContext { } } + // 处理event事件流 group.addTask { if let eventStream = self.udpHole?.eventStream { for try await event in eventStream { try Task.checkCancellation() Task { - try await self.dispatchEvent(event: event) + switch event { + case .ready: + try await self.handleUDPHoleReady() + case .closed: + () + } + } + } + } + } + + // 处理数据流 + group.addTask { + if let dataStream = self.udpHole?.dataStream { + for try await data in dataStream { + try Task.checkCancellation() + Task { + try await self.handleData(data: data) + } + } + } + } + + // 处理signal信号流 + group.addTask { + if let signalStream = self.udpHole?.signalStream { + for try await(remoteAddress, signal) in signalStream { + try Task.checkCancellation() + Task { + switch signal { + case .registerSuperAck(let registerSuperAck): + await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck) + case .registerSuperNak(let registerSuperNak): + await self.handleRegisterSuperNak(nakPacket: registerSuperNak) + case .peerInfo(let peerInfo): + await self.puncherActor.handlePeerInfo(peerInfo: peerInfo) + case .event(let event): + try await self.handleEvent(event: event) + case .stunProbeReply(let probeReply): + await self.proberActor?.handleProbeReply(reply: probeReply) + case .register(let register): + try await self.handleRegister(remoteAddress: remoteAddress, register: register) + case .registerAck(let registerAck): + await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) + } } } } @@ -215,32 +260,6 @@ public class SDLContext { } } - private func dispatchEvent(event: SDLUDPHole.UDPHoleEvent) async throws { - switch event { - case .ready: - try await self.handleUDPHoleReady() - case .message(let remoteAddress, let message): - switch message { - case .registerSuperAck(let registerSuperAck): - await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck) - case .registerSuperNak(let registerSuperNak): - await self.handleRegisterSuperNak(nakPacket: registerSuperNak) - case .peerInfo(let peerInfo): - await self.puncherActor.handlePeerInfo(peerInfo: peerInfo) - case .event(let event): - try await self.handleEvent(event: event) - case .stunProbeReply(let probeReply): - await self.proberActor?.handleProbeReply(reply: probeReply) - case .data(let data): - try await self.handleData(data: data) - case .register(let register): - try await self.handleRegister(remoteAddress: remoteAddress, register: register) - case .registerAck(let registerAck): - await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) - } - } - } - private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async { // 需要对数据通过rsa的私钥解码 let aesKey = try! self.rsaCipher.decode(data: Data(registerSuperAck.aesKey)) diff --git a/Tun/Punchnet/SDLDNSClient.swift b/Tun/Punchnet/SDLDNSClient.swift index c35378e..1956fa9 100644 --- a/Tun/Punchnet/SDLDNSClient.swift +++ b/Tun/Punchnet/SDLDNSClient.swift @@ -10,7 +10,9 @@ import NIOCore import NIOPosix // 处理和sn-server服务器之间的通讯 -final class SDLDNSClient { +final class SDLDNSClient: ChannelInboundHandler { + typealias InboundIn = AddressedEnvelope + private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) private var channel: Channel? @@ -31,12 +33,31 @@ final class SDLDNSClient { let bootstrap = DatagramBootstrap(group: group) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .channelInitializer { channel in - channel.pipeline.addHandler(SDLDNSInboundHandler(packetContinuation: self.packetContinuation, logger: self.logger)) + channel.pipeline.addHandler(self) } self.channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() self.logger.log("[DNSClient] started", level: .debug) } + + // --MARK: ChannelInboundHandler delegate + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let envelope = unwrapInboundIn(data) + + var buffer = envelope.data + let remoteAddress = envelope.remoteAddress + self.logger.log("[DNSClient] read data: \(buffer), from: \(remoteAddress)", level: .debug) + + let len = buffer.readableBytes + if let bytes = buffer.readBytes(length: len) { + self.packetContinuation.yield(Data(bytes)) + } + } + + func channelInactive(context: ChannelHandlerContext) { + self.packetContinuation.finish() + } func forward(ipPacket: IPPacket) { guard let channel = self.channel else { @@ -59,37 +80,6 @@ final class SDLDNSClient { extension SDLDNSClient { - private final class SDLDNSInboundHandler: ChannelInboundHandler { - typealias InboundIn = AddressedEnvelope - - private var packetContinuation: AsyncStream.Continuation - private var logger: SDLLogger - - // --MARK: ChannelInboundHandler delegate - - init(packetContinuation: AsyncStream.Continuation, logger: SDLLogger) { - self.packetContinuation = packetContinuation - self.logger = logger - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let envelope = unwrapInboundIn(data) - - var buffer = envelope.data - let remoteAddress = envelope.remoteAddress - self.logger.log("[DNSClient] read data: \(buffer), from: \(remoteAddress)", level: .debug) - - let len = buffer.readableBytes - if let bytes = buffer.readBytes(length: len) { - self.packetContinuation.yield(Data(bytes)) - } - } - - func channelInactive(context: ChannelHandlerContext) { - self.packetContinuation.finish() - } - } - struct Helper { static let dnsServer: String = "100.100.100.100" // dns请求包的目标地址 diff --git a/Tun/Punchnet/SDLMessage.swift b/Tun/Punchnet/SDLMessage.swift index 079a5de..c5f75cc 100644 --- a/Tun/Punchnet/SDLMessage.swift +++ b/Tun/Punchnet/SDLMessage.swift @@ -93,8 +93,12 @@ extension SDLStunProbeReply { } // --MARK: 进来的消息, 这里需要采用代数类型来表示 +enum SDLHoleMessage { + case data(SDLData) + case signal(SDLHoleSignal) +} -enum SDLHoleInboundMessage { +enum SDLHoleSignal { case registerSuperAck(SDLRegisterSuperAck) case registerSuperNak(SDLRegisterSuperNak) @@ -103,7 +107,6 @@ enum SDLHoleInboundMessage { case stunProbeReply(SDLStunProbeReply) - case data(SDLData) case register(SDLRegister) case registerAck(SDLRegisterAck) } diff --git a/Tun/Punchnet/SDLUDPHole.swift b/Tun/Punchnet/SDLUDPHole.swift index c81ad2e..34aeed3 100644 --- a/Tun/Punchnet/SDLUDPHole.swift +++ b/Tun/Punchnet/SDLUDPHole.swift @@ -19,36 +19,87 @@ import NIOPosix import SwiftProtobuf // 处理和sn-server服务器之间的通讯 -final class SDLUDPHole { +final class SDLUDPHole: ChannelInboundHandler { + typealias InboundIn = AddressedEnvelope + private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) private var channel: Channel? - public let eventStream: AsyncStream - private let eventContinuation: AsyncStream.Continuation + public let signalStream: AsyncStream<(SocketAddress, SDLHoleSignal)> + private let signalContinuation: AsyncStream<(SocketAddress, SDLHoleSignal)>.Continuation + + public let dataStream: AsyncStream + private let dataContinuation: AsyncStream.Continuation + + public let eventStream: AsyncStream + private let eventContinuation: AsyncStream.Continuation + private let logger: SDLLogger - enum UDPHoleEvent { + enum HoleEvent { case ready - case message(SocketAddress, SDLHoleInboundMessage) + case closed } // 启动函数 init(logger: SDLLogger) throws { self.logger = logger - (self.eventStream, self.eventContinuation) = AsyncStream.makeStream(of: UDPHoleEvent.self, bufferingPolicy: .unbounded) + (self.signalStream, self.signalContinuation) = AsyncStream.makeStream(of: (SocketAddress, SDLHoleSignal).self, bufferingPolicy: .unbounded) + (self.dataStream, self.dataContinuation) = AsyncStream.makeStream(of: SDLData.self, bufferingPolicy: .unbounded) + (self.eventStream, self.eventContinuation) = AsyncStream.makeStream(of: HoleEvent.self, bufferingPolicy: .unbounded) } func start() throws { let bootstrap = DatagramBootstrap(group: group) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .channelInitializer { channel in - channel.pipeline.addHandler(SDLUDPHoleHandler(eventContinuation: self.eventContinuation, logger: self.logger)) + channel.pipeline.addHandler(self) } self.channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() self.logger.log("[UDPHole] started", level: .debug) } + // --MARK: ChannelInboundHandler delegate + + func channelActive(context: ChannelHandlerContext) { + self.eventContinuation.yield(.ready) + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let envelope = unwrapInboundIn(data) + + var buffer = envelope.data + let remoteAddress = envelope.remoteAddress + do { + if let message = try decode(buffer: &buffer) { + switch message { + case .data(let data): + self.dataContinuation.yield(data) + case .signal(let signal): + self.signalContinuation.yield((remoteAddress, signal)) + } + } else { + self.logger.log("[SDLUDPHole] decode message, get null", level: .warning) + } + } catch let err { + self.logger.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning) + } + } + + func channelInactive(context: ChannelHandlerContext) { + self.signalContinuation.finish() + self.dataContinuation.finish() + self.eventContinuation.yield(.closed) + self.eventContinuation.finish() + + context.close(promise: nil) + } + + func errorCaught(context: ChannelHandlerContext, error: any Error) { + context.close(promise: nil) + } + func getLocalAddress() -> SocketAddress? { return self.channel?.localAddress } @@ -69,6 +120,77 @@ final class SDLUDPHole { } } + // --MARK: 编解码器 + private func decode(buffer: inout ByteBuffer) throws -> SDLHoleMessage? { + guard let type = buffer.readInteger(as: UInt8.self), + let packetType = SDLPacketType(rawValue: type), + let bytes = buffer.readBytes(length: buffer.readableBytes) else { + return nil + } + + switch packetType { + case .data: + let dataPacket = try SDLData(serializedBytes: bytes) + return .data(dataPacket) + case .register: + let registerPacket = try SDLRegister(serializedBytes: bytes) + return .signal(.register(registerPacket)) + case .registerAck: + let registerAck = try SDLRegisterAck(serializedBytes: bytes) + return .signal(.registerAck(registerAck)) + case .stunProbeReply: + let stunProbeReply = try SDLStunProbeReply(serializedBytes: bytes) + return .signal(.stunProbeReply(stunProbeReply)) + case .registerSuperAck: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else { + return nil + } + return .signal(.registerSuperAck(registerSuperAck)) + case .registerSuperNak: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else { + return nil + } + return .signal(.registerSuperNak(registerSuperNak)) + + case .peerInfo: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else { + return nil + } + + return .signal(.peerInfo(peerInfo)) + case .event: + guard let eventVal = buffer.readInteger(as: UInt8.self), + let event = SDLEventType(rawValue: eventVal), + let bytes = buffer.readBytes(length: buffer.readableBytes) else { + return nil + } + + switch event { + case .natChanged: + guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else { + return nil + } + return .signal(.event(.natChanged(natChangedEvent))) + case .sendRegister: + guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else { + return nil + } + return .signal(.event(.sendRegister(sendRegisterEvent))) + case .networkShutdown: + guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else { + return nil + } + return .signal(.event(.networkShutdown(networkShutdownEvent))) + } + + default: + return nil + } + } + deinit { try? self.group.syncShutdownGracefully() self.eventContinuation.finish() @@ -76,121 +198,3 @@ final class SDLUDPHole { } } - -extension SDLUDPHole { - - final class SDLUDPHoleHandler: ChannelInboundHandler { - typealias InboundIn = AddressedEnvelope - - private var eventContinuation: AsyncStream.Continuation - private var logger: SDLLogger - - // --MARK: ChannelInboundHandler delegate - - init(eventContinuation: AsyncStream.Continuation, logger: SDLLogger) { - self.eventContinuation = eventContinuation - self.logger = logger - } - - func channelActive(context: ChannelHandlerContext) { - self.eventContinuation.yield(.ready) - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let envelope = unwrapInboundIn(data) - - var buffer = envelope.data - let remoteAddress = envelope.remoteAddress - do { - if let message = try decode(buffer: &buffer) { - self.eventContinuation.yield(.message(remoteAddress, message)) - } else { - self.logger.log("[SDLUDPHole] decode message, get null", level: .warning) - } - } catch let err { - self.logger.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning) - } - } - - func channelInactive(context: ChannelHandlerContext) { - self.eventContinuation.finish() - } - - func errorCaught(context: ChannelHandlerContext, error: any Error) { - context.close(promise: nil) - } - - // --MARK: 编解码器 - private func decode(buffer: inout ByteBuffer) throws -> SDLHoleInboundMessage? { - guard let type = buffer.readInteger(as: UInt8.self), - let packetType = SDLPacketType(rawValue: type), - let bytes = buffer.readBytes(length: buffer.readableBytes) else { - return nil - } - - switch packetType { - case .data: - let dataPacket = try SDLData(serializedBytes: bytes) - return .data(dataPacket) - case .register: - let registerPacket = try SDLRegister(serializedBytes: bytes) - return .register(registerPacket) - case .registerAck: - let registerAck = try SDLRegisterAck(serializedBytes: bytes) - return .registerAck(registerAck) - case .stunProbeReply: - let stunProbeReply = try SDLStunProbeReply(serializedBytes: bytes) - return .stunProbeReply(stunProbeReply) - case .registerSuperAck: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else { - return nil - } - return .registerSuperAck(registerSuperAck) - case .registerSuperNak: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else { - return nil - } - return .registerSuperNak(registerSuperNak) - - case .peerInfo: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else { - return nil - } - - return .peerInfo(peerInfo) - case .event: - guard let eventVal = buffer.readInteger(as: UInt8.self), - let event = SDLEventType(rawValue: eventVal), - let bytes = buffer.readBytes(length: buffer.readableBytes) else { - return nil - } - - switch event { - case .natChanged: - guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else { - return nil - } - return .event(.natChanged(natChangedEvent)) - case .sendRegister: - guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else { - return nil - } - return .event(.sendRegister(sendRegisterEvent)) - case .networkShutdown: - guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else { - return nil - } - return .event(.networkShutdown(networkShutdownEvent)) - } - - default: - return nil - } - } - - } - -}