//
//  SDLanServer.swift
//  Tun
//
//  Created by 安礼成 on 2024/1/31.
//

import Foundation
import NIOCore
import NIOPosix
import Combine

/// Handles the UDP communication with the sn-server and with peer clients.
///
/// The object is both the owner of the datagram channel and the inbound handler
/// installed in that channel's pipeline. Inbound datagrams are decoded and
/// published on `eventFlow`; STUN probe replies are routed to the pending
/// probe callback that carries the matching cookie.
class SDLUDPHole: ChannelInboundHandler, @unchecked Sendable {
    public typealias InboundIn = AddressedEnvelope<ByteBuffer>
    public typealias OutboundOut = AddressedEnvelope<ByteBuffer>

    /// Probe completion callback; receives `nil` when no reply could be obtained.
    public typealias CallbackFun = (SDLStunProbeReply?) -> Void

    private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)

    private var cookieGenerator = SDLIdGenerator(seed: 1)
    private let callbackManager = HoleCallbackManager()

    public var localAddress: SocketAddress?
    public var channel: Channel?

    public var eventFlow = PassthroughSubject<UDPEvent, Never>()

    /// Events published on `eventFlow`.
    enum UDPEvent {
        case ready
        case closed
        case message(SocketAddress, SDLHoleInboundMessage)
        case data(SDLData)
    }

    init() {
    }

    // MARK: - super_node apis

    /// Sends a STUN request to the configured stun address.
    ///
    /// - Returns: The cookie placed in the request. The cookie is returned even
    ///   when the request could not be serialized (the failure is logged).
    func stunRequest(context ctx: SDLContext) -> UInt32 {
        let cookie = self.cookieGenerator.nextId()
        let remoteAddress = ctx.config.stunSocketAddress

        var request = SDLStunRequest()
        request.cookie = cookie
        request.clientID = ctx.config.clientId
        request.networkID = ctx.devAddr.networkID
        request.ip = ctx.devAddr.netAddr
        request.mac = ctx.devAddr.mac
        request.natType = UInt32(ctx.natType.rawValue)

        // The host/port pair is only used for logging, so never index blindly into
        // possibly-empty collections here.
        var stunHost = "-"
        if let server = ctx.config.stunServers.first {
            let port = server.ports.first.map { String(describing: $0) } ?? "-"
            stunHost = "\(server.host):\(port)"
        }
        SDLLogger.log("[SDLUDPHole] stunRequest: \(remoteAddress), host: \(stunHost)", level: .warning)

        if let payload = self.serialized(.stunRequest, { try request.serializedData() }) {
            self.send(remoteAddress: remoteAddress, type: .stunRequest, data: payload)
        }

        return cookie
    }

    /// Probes `remoteAddress` and waits for the matching probe reply.
    ///
    /// - Parameters:
    ///   - remoteAddress: Destination of the probe.
    ///   - attr: Probe attribute copied into the probe packet.
    ///   - timeout: Maximum time to wait for the reply, in seconds.
    /// - Returns: The reply, or `nil` when the probe could not be sent, the channel
    ///   was closed, or no reply arrived within `timeout` seconds.
    func stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int = 5) async -> SDLStunProbeReply? {
        return await withCheckedContinuation { continuation in
            self.stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: timeout) { probeReply in
                continuation.resume(returning: probeReply)
            }
        }
    }

    /// Callback flavour of the probe. `callback` is invoked exactly once: with the
    /// reply, or with `nil` on failure / timeout / channel close.
    private func stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int, callback: @escaping CallbackFun) {
        let cookie = self.cookieGenerator.nextId()

        var probe = SDLStunProbe()
        probe.cookie = cookie
        probe.attr = UInt32(attr.rawValue)

        // Without a channel nothing can be sent; fail fast instead of waiting.
        guard self.channel != nil,
              let payload = self.serialized(.stunProbe, { try probe.serializedData() }) else {
            callback(nil)
            return
        }

        // Register first so that a fast reply cannot arrive before the callback exists.
        self.callbackManager.addCallback(id: cookie, callback: callback)

        self.send(remoteAddress: remoteAddress, type: .stunProbe, data: payload)
        SDLLogger.log("[SDLUDPHole] stunProbe: \(remoteAddress)", level: .warning)

        // Resolve the probe with nil if no reply shows up in time. When the reply has
        // already been delivered the callback is gone and this is a no-op.
        let delay = TimeAmount.seconds(Int64(max(timeout, 0)))
        _ = self.group.next().scheduleTask(in: delay) { [weak self] in
            guard let self = self else {
                return
            }
            self.callbackManager.fireTimeout(id: cookie)
        }
    }

    // MARK: - client-client apis

    /// Sends a data packet directly to another session.
    func sendPacket(context ctx: SDLContext, session: Session, data: Data) {
        let remoteAddress = session.natAddress

        var dataPacket = SDLData()
        dataPacket.networkID = ctx.devAddr.networkID
        dataPacket.srcMac = ctx.devAddr.mac
        dataPacket.dstMac = session.dstMac
        dataPacket.ttl = 255
        dataPacket.data = data

        guard let packet = self.serialized(.data, { try dataPacket.serializedData() }) else {
            return
        }

        SDLLogger.log("[SDLUDPHole] sendPacket: \(remoteAddress), count: \(packet.count)", level: .debug)
        self.send(remoteAddress: remoteAddress, type: .data, data: packet)
    }

    /// Forwards a data packet through the sn server; `data` is already encrypted.
    func forwardPacket(context ctx: SDLContext, dst_mac: Data, data: Data) {
        let remoteAddress = ctx.config.stunSocketAddress

        var dataPacket = SDLData()
        dataPacket.networkID = ctx.devAddr.networkID
        dataPacket.srcMac = ctx.devAddr.mac
        dataPacket.dstMac = dst_mac
        dataPacket.ttl = 255
        dataPacket.data = data

        guard let packet = self.serialized(.data, { try dataPacket.serializedData() }) else {
            return
        }

        SDLLogger.log("[SDLUDPHole] forwardPacket: \(remoteAddress), count: \(packet.count)", level: .debug)
        self.send(remoteAddress: remoteAddress, type: .data, data: packet)
    }

    /// Sends a register packet.
    func sendRegister(context ctx: SDLContext, remoteAddress: SocketAddress, dst_mac: Data) {
        var register = SDLRegister()
        register.networkID = ctx.devAddr.networkID
        register.srcMac = ctx.devAddr.mac
        register.dstMac = dst_mac

        SDLLogger.log("[SDLUDPHole] SendRegister: \(remoteAddress), src_mac: \(LayerPacket.MacAddress.description(data: ctx.devAddr.mac)), dst_mac: \(LayerPacket.MacAddress.description(data: dst_mac))", level: .debug)

        guard let payload = self.serialized(.register, { try register.serializedData() }) else {
            return
        }
        self.send(remoteAddress: remoteAddress, type: .register, data: payload)
    }

    /// Replies with a registerAck packet.
    func sendRegisterAck(context ctx: SDLContext, remoteAddress: SocketAddress, dst_mac: Data) {
        var registerAck = SDLRegisterAck()
        registerAck.networkID = ctx.devAddr.networkID
        registerAck.srcMac = ctx.devAddr.mac
        registerAck.dstMac = dst_mac

        SDLLogger.log("[SDLUDPHole] SendRegisterAck: \(remoteAddress), \(registerAck)", level: .debug)

        guard let payload = self.serialized(.registerAck, { try registerAck.serializedData() }) else {
            return
        }
        self.send(remoteAddress: remoteAddress, type: .registerAck, data: payload)
    }

    /// Binds the UDP socket on an ephemeral port of 0.0.0.0 and installs this
    /// object as the channel's handler.
    ///
    /// - Throws: Whatever the bootstrap's bind future fails with.
    func start() async throws {
        let bootstrap = DatagramBootstrap(group: self.group)
            .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
            .channelInitializer { channel in
                // Receive buffer, then send buffer, then the handler itself.
                return channel.setOption(ChannelOptions.socketOption(.so_rcvbuf), value: 5 * 1024 * 1024)
                    .flatMap {
                        channel.setOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_SNDBUF), value: 5 * 1024 * 1024)
                    }.flatMap {
                        channel.pipeline.addHandler(self)
                    }
            }

        let channel = try await bootstrap.bind(host: "0.0.0.0", port: 0).get()

        let boundAddress = channel.localAddress.map { "\($0)" } ?? "unknown"
        SDLLogger.log("[UDPHole] started and listening on: \(boundAddress)", level: .debug)

        self.localAddress = channel.localAddress
        self.channel = channel
    }

    // MARK: - ChannelInboundHandler Methods

    public func channelActive(context: ChannelHandlerContext) {
        // Publish the channel before announcing readiness: subscribers reacting to
        // `.ready` may send immediately, and `send` drops packets while `channel` is nil.
        self.channel = context.channel
        self.localAddress = context.channel.localAddress
        self.eventFlow.send(.ready)
    }

    /// Decodes an inbound datagram and dispatches it by message type.
    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let envelope = self.unwrapInboundIn(data)

        var buffer = envelope.data
        let remoteAddress = envelope.remoteAddress

        let decoded: SDLHoleInboundMessage?
        do {
            decoded = try decode(buffer: &buffer)
        } catch let err {
            SDLLogger.log("[SDLUDPHole] decode message, get error: \(err)", level: .debug)
            return
        }

        guard let message = decoded else {
            SDLLogger.log("[SDLUDPHole] decode message, get null", level: .warning)
            return
        }

        // Delivery happens off the event loop, as before.
        Task {
            switch message {
            case .data(let data):
                SDLLogger.log("[SDLUDPHole] read data: \(data.format()), from: \(remoteAddress)", level: .debug)
                self.eventFlow.send(.data(data))
            case .stunProbeReply(let probeReply):
                self.callbackManager.fireCallback(message: probeReply)
            default:
                self.eventFlow.send(.message(remoteAddress, message))
            }
        }
    }

    public func errorCaught(context: ChannelHandlerContext, error: Error) {
        SDLLogger.log("[SDLUDPHole] get error: \(error)", level: .error)
        // As we are not really interested getting notified on success or failure we just pass nil as promise to
        // reduce allocations.
        context.close(promise: nil)
        self.channel = nil
        // No reply can arrive any more; release everyone still waiting on a probe.
        self.callbackManager.fireAllCallbacks()
        self.eventFlow.send(.closed)
    }

    public func channelInactive(context: ChannelHandlerContext) {
        self.channel = nil
        // No reply can arrive any more; release everyone still waiting on a probe.
        self.callbackManager.fireAllCallbacks()
        context.close(promise: nil)
    }

    /// Writes one datagram: a single type byte followed by `data`.
    ///
    /// The packet is silently dropped when no channel is available.
    func send(remoteAddress: SocketAddress, type: SDLPacketType, data: Data) {
        guard let channel = self.channel else {
            return
        }

        channel.eventLoop.execute {
            var buffer = channel.allocator.buffer(capacity: data.count + 1)
            buffer.writeInteger(type.rawValue)
            buffer.writeBytes(data)

            let envelope = AddressedEnvelope<ByteBuffer>(remoteAddress: remoteAddress, data: buffer)
            channel.writeAndFlush(self.wrapOutboundOut(envelope), promise: nil)
        }
    }

    /// Runs `encode` and returns its result, logging (instead of crashing on) a
    /// serialization failure.
    private func serialized(_ type: SDLPacketType, _ encode: () throws -> Data) -> Data? {
        do {
            return try encode()
        } catch {
            SDLLogger.log("[SDLUDPHole] failed to serialize \(type) packet: \(error)", level: .error)
            return nil
        }
    }

    deinit {
        // Make sure nobody stays suspended on a probe that can no longer complete.
        self.callbackManager.fireAllCallbacks()
        try? self.group.syncShutdownGracefully()
    }

}

// MARK: - Codec
extension SDLUDPHole {

    /// Decodes one datagram: the first byte selects the packet type, the rest is
    /// the serialized message.
    ///
    /// - Returns: The decoded message, or `nil` when the datagram is empty, the
    ///   type byte is unknown, or the type is not handled here.
    /// - Throws: The deserialization error of the selected message type.
    func decode(buffer: inout ByteBuffer) throws -> SDLHoleInboundMessage? {
        guard let type = buffer.readInteger(as: UInt8.self),
              let packetType = SDLPacketType(rawValue: type),
              let bytes = buffer.readBytes(length: buffer.readableBytes) else {
            SDLLogger.log("[SDLUDPHole] decode error", level: .error)
            return nil
        }

        switch packetType {
        case .data:
            return .data(try SDLData(serializedBytes: bytes))
        case .register:
            return .register(try SDLRegister(serializedBytes: bytes))
        case .registerAck:
            return .registerAck(try SDLRegisterAck(serializedBytes: bytes))
        case .stunReply:
            return .stunReply(try SDLStunReply(serializedBytes: bytes))
        case .stunProbeReply:
            return .stunProbeReply(try SDLStunProbeReply(serializedBytes: bytes))
        default:
            return nil
        }
    }

}

// MARK: - Callback manager
extension SDLUDPHole {

    /// Pending probe callbacks keyed by cookie.
    ///
    /// All access goes through `lock`, and a callback is always removed from the
    /// table before it is invoked, so each callback runs at most once no matter
    /// how reply, timeout and close interleave.
    private final class HoleCallbackManager {
        private let lock = NSLock()
        private var callbacks: [UInt32: CallbackFun] = [:]

        /// Stores `callback` under `id`.
        func addCallback(id: UInt32, callback: @escaping CallbackFun) {
            lock.lock()
            defer { lock.unlock() }
            callbacks[id] = callback
        }

        /// A reply arrived: run and remove the callback registered for its cookie.
        func fireCallback(message: SDLStunProbeReply) {
            takeCallback(id: message.cookie)?(message)
        }

        /// The probe timed out or could not complete: run the callback with `nil`.
        func fireTimeout(id: UInt32) {
            takeCallback(id: id)?(nil)
        }

        /// Retained for source compatibility; `message` is not used.
        func fireAllCallbacks(message: SDLSuperInboundMessage) {
            fireAllCallbacks()
        }

        /// Runs every pending callback with `nil` and clears the table.
        func fireAllCallbacks() {
            lock.lock()
            let pending = Array(callbacks.values)
            callbacks.removeAll()
            lock.unlock()

            // Invoke outside the lock so a callback may safely re-enter the manager.
            for callback in pending {
                callback(nil)
            }
        }

        /// Atomically removes and returns the callback registered for `id`.
        private func takeCallback(id: UInt32) -> CallbackFun? {
            lock.lock()
            defer { lock.unlock() }
            return callbacks.removeValue(forKey: id)
        }
    }

}