diff --git a/Tun/Punchnet/SDLUDPHoleV6.swift b/Tun/Punchnet/SDLUDPHoleV6.swift new file mode 100644 index 0000000..32674ce --- /dev/null +++ b/Tun/Punchnet/SDLUDPHoleV6.swift @@ -0,0 +1,205 @@ +// +// SDLUDPHoleV6.swift +// Tun +// +// Created by 安礼成 on 2026/4/15. +// + +import Foundation +import NIOCore +import NIOPosix +import SwiftProtobuf + +// 处理和sn-server服务器之间的通讯 +final class SDLUDPHoleV6: ChannelInboundHandler { + typealias InboundIn = AddressedEnvelope + + private enum State: Equatable { + case idle + case ready + case stopping + case stopped + } + + private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + private var channel: Channel? + private var closeFuture: EventLoopFuture? + private var state: State = .idle + private var didFinishMessageStream: Bool = false + + public let messageStream: AsyncStream<(SocketAddress, SDLHoleMessage)> + private let messageContinuation: AsyncStream<(SocketAddress, SDLHoleMessage)>.Continuation + + // 启动函数 + init() throws { + let (stream, continuation) = AsyncStream.makeStream(of: (SocketAddress, SDLHoleMessage).self, bufferingPolicy: .bufferingNewest(2048)) + self.messageStream = stream + self.messageContinuation = continuation + } + + func start() throws -> SocketAddress { + let bootstrap = DatagramBootstrap(group: group) + .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .channelInitializer { channel in + channel.pipeline.addHandler(self) + } + + // 绑定到IPv6通配地址,只处理IPv6流量 + let channel = try bootstrap.bind(host: "::", port: 0).wait() + self.channel = channel + self.closeFuture = channel.closeFuture + self.state = .ready + precondition(channel.localAddress != nil, "UDP v6 channel has no localAddress after bind") + + return channel.localAddress! + } + + func waitClose() async throws { + switch self.state { + case .idle: + SDLLogger.log("[SDLUDPHoleV6] waitClose11", for: .debug) + return + case .ready, .stopping, .stopped: + guard let closeFuture = self.closeFuture else { + SDLLogger.log("[SDLUDPHoleV6] waitClose22", for: .debug) + return + } + try await closeFuture.get() + SDLLogger.log("[SDLUDPHoleV6] waitClose33", for: .debug) + } + } + + func stop() { + switch self.state { + case .stopping, .stopped: + return + case .idle: + self.state = .stopped + self.finishMessageStream() + return + case .ready: + self.state = .stopping + } + + self.finishMessageStream() + self.channel?.close(promise: nil) + } + + // --MARK: ChannelInboundHandler delegate + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + guard case .ready = self.state else { + return + } + + let envelope = unwrapInboundIn(data) + + var buffer = envelope.data + let remoteAddress = envelope.remoteAddress + + if let rawBytes = buffer.getBytes(at: buffer.readerIndex, length: buffer.readableBytes) { + SDLLogger.log("[SDLUDPHoleV6] get raw bytes: \(rawBytes.count), from: \(remoteAddress)", for: .debug) + } + + do { + if let message = try decode(buffer: &buffer) { + self.messageContinuation.yield((remoteAddress, message)) + } else { + SDLLogger.log("[SDLUDPHoleV6] decode message, get null", for: .debug) + } + } catch let err { + SDLLogger.log("[SDLUDPHoleV6] decode message, get error: \(err)", for: .debug) + } + } + + func channelInactive(context: ChannelHandlerContext) { + self.finishMessageStream() + self.channel = nil + self.state = .stopped + } + + func errorCaught(context: ChannelHandlerContext, error: any Error) { + SDLLogger.log("[SDLUDPHoleV6] channel error: \(error)", for: .debug) + self.finishMessageStream() + if self.state != .stopped { + self.state = .stopping + } + context.close(promise: nil) + } + + // MARK: 处理写入逻辑 + func send(type: SDLPacketType, data: Data, remoteAddress: SocketAddress) { + guard case .ready = self.state, let channel = self.channel else { + return + } + + var buffer = channel.allocator.buffer(capacity: data.count + 1) + buffer.writeBytes([type.rawValue]) + buffer.writeBytes(data) + + let envelope = AddressedEnvelope(remoteAddress: remoteAddress, data: buffer) + _ = channel.eventLoop.submit { + channel.writeAndFlush(envelope, promise: nil) + } + } + + // --MARK: 编解码器 + private func decode(buffer: inout ByteBuffer) throws -> SDLHoleMessage? { + guard let type = buffer.readInteger(as: UInt8.self), + let packetType = SDLPacketType(rawValue: type) else { + return nil + } + + switch packetType { + case .data: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let dataPacket = try? SDLData(serializedBytes: bytes) else { + return nil + } + return .data(dataPacket) + case .register: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let registerPacket = try? SDLRegister(serializedBytes: bytes) else { + return nil + } + return .register(registerPacket) + case .registerAck: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let registerAck = try? SDLRegisterAck(serializedBytes: bytes) else { + return nil + } + return .registerAck(registerAck) + case .stunProbeReply: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let stunProbeReply = try? SDLStunProbeReply(serializedBytes: bytes) else { + return nil + } + return .stunProbeReply(stunProbeReply) + case .stunReply: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let stunReply = try? SDLStunReply(serializedBytes: bytes) else { + return nil + } + return .stunReply(stunReply) + default: + SDLLogger.log("SDLUDPHoleV6 decode miss type: \(type)", for: .debug) + + return nil + } + } + + private func finishMessageStream() { + guard !self.didFinishMessageStream else { + return + } + + self.didFinishMessageStream = true + self.messageContinuation.finish() + } + + deinit { + self.stop() + try? self.group.syncShutdownGracefully() + } + +}