//
//  SDLanServer.swift
//  Tun
//
//  Created by 安礼成 on 2024/1/31.
//

import Foundation
import NIOCore
import NIOPosix
import SwiftProtobuf

/// Handles communication with the sn-server over a single UDP channel.
///
/// Inbound datagrams are decoded into `SDLHoleMessage` values and published on
/// `messageStream`; outbound packets are written with `send(type:data:remoteAddress:)`.
///
/// NOTE(review): `state` and `channel` are read and written both by the public
/// methods and by the channel callbacks, with no synchronization visible in this
/// file — confirm the threading expectations with the callers.
final class SDLUDPHole: ChannelInboundHandler {
    typealias InboundIn = AddressedEnvelope<ByteBuffer>

    /// Lifecycle of the underlying UDP channel.
    private enum State: Equatable {
        case idle
        case ready
        case stopping
        case stopped
    }

    private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)

    private var channel: Channel?
    private var closeFuture: EventLoopFuture<Void>?
    private var state: State = .idle
    /// Ensures the stream continuation is finished only once.
    private var didFinishMessageStream: Bool = false

    /// Decoded inbound messages paired with the address they were received from.
    public let messageStream: AsyncStream<(SocketAddress, SDLHoleMessage)>
    private let messageContinuation: AsyncStream<(SocketAddress, SDLHoleMessage)>.Continuation

    /// Creates the message stream; no socket is opened until `start()` is called.
    init() throws {
        let (stream, continuation) = AsyncStream.makeStream(
            of: (SocketAddress, SDLHoleMessage).self,
            bufferingPolicy: .bufferingNewest(2048)
        )
        self.messageStream = stream
        self.messageContinuation = continuation
    }

    /// Binds the UDP socket and installs this object as the channel handler.
    ///
    /// This call blocks the calling thread until the bind completes.
    ///
    /// - Returns: The local address the socket was bound to.
    /// - Throws: Any error produced while binding the socket.
    func start() throws -> SocketAddress {
        let bootstrap = DatagramBootstrap(group: group)
            .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
            .channelInitializer { channel in
                channel.pipeline.addHandler(self)
            }

        // Bind to the IPv6 wildcard address, relying on SwiftNIO to create a
        // dual-stack socket that receives both IPv4 and IPv6 traffic.
        let channel = try bootstrap.bind(host: "::", port: 0).wait()

        self.channel = channel
        self.closeFuture = channel.closeFuture
        self.state = .ready

        // A bound channel without a local address is a programmer error.
        guard let localAddress = channel.localAddress else {
            preconditionFailure("UDP channel has no localAddress after bind")
        }
        return localAddress
    }

    /// Suspends until the channel's close future completes.
    ///
    /// Returns immediately when the hole was never started or no close future
    /// was recorded.
    ///
    /// - Throws: The error the close future fails with, if any.
    func waitClose() async throws {
        if self.state == .idle {
            SDLLogger.log("[SDLUDPHole] waitClose11", for: .debug)
            return
        }

        guard let closeFuture = self.closeFuture else {
            SDLLogger.log("[SDLUDPHole] waitClose22", for: .debug)
            return
        }

        try await closeFuture.get()
        SDLLogger.log("[SDLUDPHole] waitClose33", for: .debug)
    }

    /// Finishes the message stream and closes the channel if it is open.
    ///
    /// Calling this while already stopping or stopped does nothing.
    func stop() {
        switch self.state {
        case .stopping, .stopped:
            return
        case .idle:
            // Never started: there is no channel to close.
            self.state = .stopped
            self.finishMessageStream()
        case .ready:
            self.state = .stopping
            self.finishMessageStream()
            self.channel?.close(promise: nil)
        }
    }

    // MARK: - ChannelInboundHandler

    /// Decodes one inbound datagram and publishes it on `messageStream`.
    ///
    /// Datagrams are ignored unless the hole is in the `ready` state.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        guard case .ready = self.state else {
            return
        }

        let envelope = unwrapInboundIn(data)
        var buffer = envelope.data
        let remoteAddress = envelope.remoteAddress

        // Only the size is logged, so there is no need to copy the payload out.
        SDLLogger.log("[SDLUDPHole] get raw bytes: \(buffer.readableBytes), from: \(remoteAddress)", for: .debug)

        do {
            if let message = try decode(buffer: &buffer) {
                self.messageContinuation.yield((remoteAddress, message))
            } else {
                SDLLogger.log("[SDLUDPHole] decode message, get null", for: .debug)
            }
        } catch let err {
            SDLLogger.log("[SDLUDPHole] decode message, get error: \(err)", for: .debug)
        }
    }

    /// Marks the hole as stopped once the channel becomes inactive.
    func channelInactive(context: ChannelHandlerContext) {
        self.finishMessageStream()
        self.channel = nil
        self.state = .stopped
    }

    /// Logs the error, finishes the message stream and closes the channel.
    func errorCaught(context: ChannelHandlerContext, error: any Error) {
        SDLLogger.log("[SDLUDPHole] channel error: \(error)", for: .debug)
        self.finishMessageStream()
        if self.state != .stopped {
            self.state = .stopping
        }
        context.close(promise: nil)
    }

    // MARK: - Sending

    /// Sends one packet to `remoteAddress`: a single type byte followed by `data`.
    ///
    /// The packet is silently dropped unless the hole is in the `ready` state
    /// with an open channel. Write failures are not reported.
    ///
    /// - Parameters:
    ///   - type: Packet type; its raw value is written as the first byte.
    ///   - data: Payload bytes written after the type byte.
    ///   - remoteAddress: Destination of the datagram.
    func send(type: SDLPacketType, data: Data, remoteAddress: SocketAddress) {
        guard case .ready = self.state, let channel = self.channel else {
            return
        }

        var buffer = channel.allocator.buffer(capacity: data.count + 1)
        buffer.writeInteger(type.rawValue)
        buffer.writeBytes(data)

        let envelope = AddressedEnvelope<ByteBuffer>(remoteAddress: remoteAddress, data: buffer)
        // `Channel` operations may be invoked from any thread; NIO hops onto the
        // channel's event loop itself, so no explicit `submit` is needed.
        channel.writeAndFlush(envelope, promise: nil)
    }

    // MARK: - Codec

    /// Decodes a datagram consisting of a one-byte packet type followed by a
    /// serialized message.
    ///
    /// - Parameter buffer: The datagram contents; consumed while decoding.
    /// - Returns: The decoded message, or `nil` when the buffer is empty or the
    ///   packet type is unknown or not handled here.
    /// - Throws: The deserialization error when the payload cannot be parsed, so
    ///   the caller can log it.
    private func decode(buffer: inout ByteBuffer) throws -> SDLHoleMessage? {
        guard let type = buffer.readInteger(as: UInt8.self),
              let packetType = SDLPacketType(rawValue: type) else {
            return nil
        }

        switch packetType {
        case .data:
            guard let payload = Self.readPayload(from: &buffer) else { return nil }
            return .data(try SDLData(serializedBytes: payload))

        case .register:
            guard let payload = Self.readPayload(from: &buffer) else { return nil }
            return .register(try SDLRegister(serializedBytes: payload))

        case .registerAck:
            guard let payload = Self.readPayload(from: &buffer) else { return nil }
            return .registerAck(try SDLRegisterAck(serializedBytes: payload))

        case .stunProbeReply:
            guard let payload = Self.readPayload(from: &buffer) else { return nil }
            return .stunProbeReply(try SDLStunProbeReply(serializedBytes: payload))

        case .stunReply:
            guard let payload = Self.readPayload(from: &buffer) else { return nil }
            return .stunReply(try SDLStunReply(serializedBytes: payload))

        default:
            SDLLogger.log("SDLUDPHole decode miss type: \(type)", for: .debug)
            return nil
        }
    }

    /// Reads every remaining readable byte of `buffer`.
    private static func readPayload(from buffer: inout ByteBuffer) -> [UInt8]? {
        buffer.readBytes(length: buffer.readableBytes)
    }

    /// Finishes the message stream at most once.
    private func finishMessageStream() {
        guard !self.didFinishMessageStream else {
            return
        }

        self.didFinishMessageStream = true
        self.messageContinuation.finish()
    }

    deinit {
        self.stop()
        try? self.group.syncShutdownGracefully()
    }
}