From 5ce84689599d353f77ec1f1027d89d52b349e1fe Mon Sep 17 00:00:00 2001 From: anlicheng <244108715@qq.com> Date: Tue, 14 Apr 2026 18:45:27 +0800 Subject: [PATCH] fix udp hole --- Tun/Punchnet/Actors/SDLContextActor.swift | 78 +------- Tun/Punchnet/Actors/SDLPuncherActor.swift | 13 +- Tun/Punchnet/SDLUDPHole.swift | 3 +- Tun/Punchnet/SDLUDPHoleV6.swift | 214 ---------------------- 4 files changed, 21 insertions(+), 287 deletions(-) delete mode 100644 Tun/Punchnet/SDLUDPHoleV6.swift diff --git a/Tun/Punchnet/Actors/SDLContextActor.swift b/Tun/Punchnet/Actors/SDLContextActor.swift index 65bd4a3..9cc7d60 100644 --- a/Tun/Punchnet/Actors/SDLContextActor.swift +++ b/Tun/Punchnet/Actors/SDLContextActor.swift @@ -48,9 +48,7 @@ actor SDLContextActor { // 依赖的变量 private var udpHole: SDLUDPHole? private var udpHoleWorkers: [Task]? - private var udpHoleV6: SDLUDPHoleV6? - private var udpHoleV6Workers: [Task]? - private var udpHoleV6LocalAddress: SocketAddress? + private var udpHoleLocalAddress: SocketAddress? // dns的client对象 private var dnsClient: DNSCloudClient? @@ -150,13 +148,6 @@ actor SDLContextActor { try await udpHole.waitClose() SDLLogger.log("[SDLContext] udp closed!!!!") } - - await self.supervisor.addWorker(name: "udpHoleV6") { - let udpHoleV6 = try await self.startUDPHoleV6() - SDLLogger.log("[SDLContext] udpV6 running!!!!") - try await udpHoleV6.waitClose() - SDLLogger.log("[SDLContext] udpV6 closed!!!!") - } } public func waitForReady() async throws { @@ -214,7 +205,7 @@ actor SDLContextActor { await self.handleRegisterSuperNak(nakPacket: registerSuperNak) case .peerInfo(let peerInfo): //SDLLogger.shared.log("[SDLContext] peer message: \(peerInfo)") - await self.puncherActor.handlePeerInfo(using: self.udpHole, udpHoleV6: self.udpHoleV6, peerInfo: peerInfo) + await self.puncherActor.handlePeerInfo(using: self.udpHole, peerInfo: peerInfo) case .event(let event): await self.handleEvent(event: event) case .policyReponse(let policyResponse): @@ -364,6 +355,7 @@ actor SDLContextActor { } self.udpHole = udpHole + self.udpHoleLocalAddress = localAddress self.udpHoleWorkers = [pingTask, messageTask] // 开始探测nat的类型 @@ -372,42 +364,6 @@ actor SDLContextActor { return udpHole } - private func startUDPHoleV6() async throws -> SDLUDPHoleV6 { - self.udpHoleV6Workers?.forEach {$0.cancel()} - self.udpHoleV6Workers = nil - - // 启动udp ipv6服务器 - let udpHoleV6 = try SDLUDPHoleV6() - let localAddress = try udpHoleV6.start() - SDLLogger.log("[SDLContext] udpHoleV6 started, on address: \(localAddress)") - - // ip地址只会收到到 register:registerAck | data - let messageTask = Task.detached { - for await (remoteAddress, message) in udpHoleV6.messageStream { - if Task.isCancelled { - break - } - - switch message { - case .register(let register): - try? await self.handleRegister(remoteAddress: remoteAddress, register: register) - case .registerAck(let registerAck): - await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) - case .data(let data): - try? await self.handleHoleData(data: data) - default: - () - } - } - } - - self.udpHoleV6 = udpHoleV6 - self.udpHoleV6LocalAddress = localAddress - self.udpHoleV6Workers = [messageTask] - - return udpHoleV6 - } - // 处理context的停止问题 public func stop() async { self.resumeReadyWaiters(.failure(CancellationError())) @@ -421,12 +377,7 @@ actor SDLContextActor { self.udpHoleWorkers = nil self.udpHole?.stop() self.udpHole = nil - - self.udpHoleV6Workers?.forEach { $0.cancel() } - self.udpHoleV6Workers = nil - self.udpHoleV6?.stop() - self.udpHoleV6 = nil - self.udpHoleV6LocalAddress = nil + self.udpHoleLocalAddress = nil self.quicWorker?.cancel() self.quicWorker = nil @@ -516,7 +467,7 @@ actor SDLContextActor { stunRequest.mac = self.config.networkAddress.mac stunRequest.natType = UInt32(self.natType.rawValue) stunRequest.sessionToken = sessionToken - if let v6Info = self.makeUDPHoleV6Info() { + if let v6Info = self.makeCurrentV6Info() { stunRequest.v6Info = v6Info } @@ -892,8 +843,7 @@ actor SDLContextActor { let networkAddr = self.config.networkAddress // 将数据封装层2层的数据包 let layerPacket = LayerPacket(dstMac: dstMac, srcMac: networkAddr.mac, type: type, data: data) - guard (self.udpHole != nil || self.udpHoleV6 != nil), - let dataCipher = self.dataCipher, + guard let dataCipher = self.dataCipher, let encodedPacket = try? dataCipher.encrypt(plainText: layerPacket.marshal()) else { return } @@ -992,18 +942,11 @@ actor SDLContextActor { // 发送给peer的数据 private func sendPeerPacket(type: SDLPacketType, data: Data, remoteAddress: SocketAddress) { - switch remoteAddress { - case .v4: - self.udpHole?.send(type: type, data: data, remoteAddress: remoteAddress) - case .v6: - self.udpHoleV6?.send(type: type, data: data, remoteAddress: remoteAddress) - default: - SDLLogger.log("[SDLContext] unsupported remoteAddress: \(remoteAddress)", for: .debug) - } + self.udpHole?.send(type: type, data: data, remoteAddress: remoteAddress) } - private func makeUDPHoleV6Info() -> SDLV6Info? { - guard let port = self.udpHoleV6LocalAddress?.port else { + private func makeCurrentV6Info() -> SDLV6Info? { + guard let port = self.udpHoleLocalAddress?.port else { return nil } @@ -1098,8 +1041,7 @@ actor SDLContextActor { deinit { self.udpHole = nil - self.udpHoleV6 = nil - self.udpHoleV6LocalAddress = nil + self.udpHoleLocalAddress = nil self.dnsClient = nil } } diff --git a/Tun/Punchnet/Actors/SDLPuncherActor.swift b/Tun/Punchnet/Actors/SDLPuncherActor.swift index d7e3fbd..3a94618 100644 --- a/Tun/Punchnet/Actors/SDLPuncherActor.swift +++ b/Tun/Punchnet/Actors/SDLPuncherActor.swift @@ -93,7 +93,7 @@ actor SDLPuncherActor { quicClient.send(type: .queryInfo, data: queryData) } - func handlePeerInfo(using udpHole: SDLUDPHole?, udpHoleV6: SDLUDPHoleV6?, peerInfo: SDLPeerInfo) async { + func handlePeerInfo(using udpHole: SDLUDPHole?, peerInfo: SDLPeerInfo) async { let now = Date() self.cleanupExpiredEntries(now: now) @@ -108,6 +108,11 @@ actor SDLPuncherActor { entry.markCoolingDown() self.requestEntries[peerInfo.dstMac] = entry + guard let udpHole else { + SDLLogger.log("[SDLPuncherActor] udpHole is nil when peerInfo arrived", for: .debug) + return + } + var register = SDLRegister() register.networkID = entry.request.networkId register.srcMac = entry.request.srcMac @@ -119,7 +124,7 @@ actor SDLPuncherActor { } // 并行发送register请求 - if let udpHole, peerInfo.hasV4Info { + if peerInfo.hasV4Info { if let remoteAddress = try? await peerInfo.v4Info.socketAddress() { SDLLogger.log("[SDLContext] hole sock address: \(remoteAddress)", for: .punchnet) udpHole.send(type: .register, data: registerData, remoteAddress: remoteAddress) @@ -128,10 +133,10 @@ actor SDLPuncherActor { } } - if let udpHoleV6, peerInfo.hasV6Info { + if peerInfo.hasV6Info { if let remoteAddress = try? await peerInfo.v6Info.socketAddress() { SDLLogger.log("[SDLContext] hole sock address v6: \(remoteAddress)", for: .punchnet) - udpHoleV6.send(type: .register, data: registerData, remoteAddress: remoteAddress) + udpHole.send(type: .register, data: registerData, remoteAddress: remoteAddress) } else { SDLLogger.log("[SDLPuncherActor] failed to resolve peerInfo.v6Info", for: .debug) } diff --git a/Tun/Punchnet/SDLUDPHole.swift b/Tun/Punchnet/SDLUDPHole.swift index ae1dd85..4d65b28 100644 --- a/Tun/Punchnet/SDLUDPHole.swift +++ b/Tun/Punchnet/SDLUDPHole.swift @@ -57,7 +57,8 @@ final class SDLUDPHole: ChannelInboundHandler { channel.pipeline.addHandler(self) } - let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() + // 绑定到IPv6通配地址,依赖SwiftNIO创建dual-stack socket,同时接收IPv4/IPv6流量 + let channel = try bootstrap.bind(host: "::", port: 0).wait() self.channel = channel self.closeFuture = channel.closeFuture self.state = .ready diff --git a/Tun/Punchnet/SDLUDPHoleV6.swift b/Tun/Punchnet/SDLUDPHoleV6.swift deleted file mode 100644 index fc03611..0000000 --- a/Tun/Punchnet/SDLUDPHoleV6.swift +++ /dev/null @@ -1,214 +0,0 @@ -// -// SDLUDPHoleV6.swift -// Tun -// -// Created by 安礼成 on 2026/4/14. -// - -import Foundation -import NIOCore -import NIOPosix -import SwiftProtobuf - -// 处理和sn-server服务器之间的通讯, ipv6版本 -final class SDLUDPHoleV6: ChannelInboundHandler { - typealias InboundIn = AddressedEnvelope - - private enum State: Equatable { - case idle - case ready - case stopping - case stopped - } - - private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - private var channel: Channel? - private var closeFuture: EventLoopFuture? - private var state: State = .idle - private var didFinishMessageStream: Bool = false - - public let messageStream: AsyncStream<(SocketAddress, SDLHoleMessage)> - private let messageContinuation: AsyncStream<(SocketAddress, SDLHoleMessage)>.Continuation - - // 启动函数 - init() throws { - let (stream, continuation) = AsyncStream.makeStream(of: (SocketAddress, SDLHoleMessage).self, bufferingPolicy: .bufferingNewest(2048)) - self.messageStream = stream - self.messageContinuation = continuation - } - - func start() throws -> SocketAddress { - switch self.state { - case .ready: - guard let channel = self.channel else { - preconditionFailure("SDLUDPHoleV6 is ready but channel is nil") - } - precondition(channel.localAddress != nil, "UDP IPv6 channel has no localAddress after bind") - return channel.localAddress! - case .stopping, .stopped: - preconditionFailure("SDLUDPHoleV6 cannot be restarted after stop") - case .idle: - break - } - - let bootstrap = DatagramBootstrap(group: group) - .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .channelInitializer { channel in - channel.pipeline.addHandler(self) - } - - let channel = try bootstrap.bind(host: "::", port: 0).wait() - self.channel = channel - self.closeFuture = channel.closeFuture - self.state = .ready - precondition(channel.localAddress != nil, "UDP IPv6 channel has no localAddress after bind") - - return channel.localAddress! - } - - func waitClose() async throws { - switch self.state { - case .idle: - return - case .ready, .stopping, .stopped: - guard let closeFuture = self.closeFuture else { - return - } - try await closeFuture.get() - } - } - - func stop() { - switch self.state { - case .stopping, .stopped: - return - case .idle: - self.state = .stopped - self.finishMessageStream() - return - case .ready: - self.state = .stopping - } - - self.finishMessageStream() - self.channel?.close(promise: nil) - } - - // --MARK: ChannelInboundHandler delegate - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - guard case .ready = self.state else { - return - } - - let envelope = unwrapInboundIn(data) - - var buffer = envelope.data - let remoteAddress = envelope.remoteAddress - - if let rawBytes = buffer.getBytes(at: buffer.readerIndex, length: buffer.readableBytes) { - SDLLogger.log("[SDLUDPHoleV6] get raw bytes: \(rawBytes.count), from: \(remoteAddress)", for: .debug) - } - - do { - if let message = try decode(buffer: &buffer) { - self.messageContinuation.yield((remoteAddress, message)) - } else { - SDLLogger.log("[SDLUDPHoleV6] decode message, get null", for: .debug) - } - } catch let err { - SDLLogger.log("[SDLUDPHoleV6] decode message, get error: \(err)", for: .debug) - } - } - - func channelInactive(context: ChannelHandlerContext) { - self.finishMessageStream() - self.channel = nil - self.state = .stopped - } - - func errorCaught(context: ChannelHandlerContext, error: any Error) { - SDLLogger.log("[SDLUDPHoleV6] channel error: \(error)", for: .debug) - self.finishMessageStream() - if self.state != .stopped { - self.state = .stopping - } - context.close(promise: nil) - } - - // MARK: 处理写入逻辑 - func send(type: SDLPacketType, data: Data, remoteAddress: SocketAddress) { - guard case .ready = self.state, let channel = self.channel else { - return - } - - var buffer = channel.allocator.buffer(capacity: data.count + 1) - buffer.writeBytes([type.rawValue]) - buffer.writeBytes(data) - - let envelope = AddressedEnvelope(remoteAddress: remoteAddress, data: buffer) - _ = channel.eventLoop.submit { - channel.writeAndFlush(envelope, promise: nil) - } - } - - // --MARK: 编解码器 - private func decode(buffer: inout ByteBuffer) throws -> SDLHoleMessage? { - guard let type = buffer.readInteger(as: UInt8.self), - let packetType = SDLPacketType(rawValue: type) else { - return nil - } - - switch packetType { - case .data: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let dataPacket = try? SDLData(serializedBytes: bytes) else { - return nil - } - return .data(dataPacket) - case .register: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let registerPacket = try? SDLRegister(serializedBytes: bytes) else { - return nil - } - return .register(registerPacket) - case .registerAck: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let registerAck = try? SDLRegisterAck(serializedBytes: bytes) else { - return nil - } - return .registerAck(registerAck) - case .stunProbeReply: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let stunProbeReply = try? SDLStunProbeReply(serializedBytes: bytes) else { - return nil - } - return .stunProbeReply(stunProbeReply) - case .stunReply: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let stunReply = try? SDLStunReply(serializedBytes: bytes) else { - return nil - } - return .stunReply(stunReply) - default: - SDLLogger.log("SDLUDPHoleV6 decode miss type: \(type)", for: .debug) - - return nil - } - } - - private func finishMessageStream() { - guard !self.didFinishMessageStream else { - return - } - - self.didFinishMessageStream = true - self.messageContinuation.finish() - } - - deinit { - self.stop() - try? self.group.syncShutdownGracefully() - } - -}