From d1452ce0b758f392a153bfe59aa9f33d8739943a Mon Sep 17 00:00:00 2001 From: anlicheng <244108715@qq.com> Date: Fri, 1 Aug 2025 12:01:37 +0800 Subject: [PATCH] fix --- Sources/Punchnet/SDLContext.swift | 37 ++++++----- Sources/Punchnet/SDLMessage.swift | 2 +- Sources/Punchnet/SDLNatProber.swift | 8 +-- Sources/Punchnet/SDLNoticeClient.swift | 85 +++++++++++--------------- Sources/Punchnet/SDLUDPHoleActor.swift | 2 + 5 files changed, 64 insertions(+), 70 deletions(-) diff --git a/Sources/Punchnet/SDLContext.swift b/Sources/Punchnet/SDLContext.swift index 373d0d0..42b6db6 100644 --- a/Sources/Punchnet/SDLContext.swift +++ b/Sources/Punchnet/SDLContext.swift @@ -70,7 +70,7 @@ public class SDLContext: @unchecked Sendable { private var monitorCancel: AnyCancellable? // 内部socket通讯 - private var noticeClient: SDLNoticeClient + private var noticeClient: SDLNoticeClient? // 流量统计 private var flowTracer = SDLFlowTracerActor() @@ -90,13 +90,13 @@ public class SDLContext: @unchecked Sendable { self.sessionManager = SessionManager() self.holerManager = HolerManager() self.arpServer = ArpServer(known_macs: [:]) - self.noticeClient = SDLNoticeClient() } public func start() async throws { self.udpHoleActor = try await SDLUDPHoleActor() self.superClientActor = try await SDLSuperClientActor(host: self.config.superHost, port: self.config.superPort) - + self.noticeClient = try await SDLNoticeClient() + try await withThrowingTaskGroup(of: Void.self) { group in group.addTask { try await self.udpHoleActor?.start() @@ -106,6 +106,10 @@ public class SDLContext: @unchecked Sendable { try await self.superClientActor?.start() } + group.addTask { + try await self.noticeClient?.start() + } + group.addTask { if let eventFlow = self.superClientActor?.eventFlow { for try await event in eventFlow { @@ -123,10 +127,16 @@ public class SDLContext: @unchecked Sendable { } + group.addTask { + while !Task.isCancelled { + try await Task.sleep(nanoseconds: 5 * 1_000_000_000) + self.lastCookie = await self.udpHoleActor?.stunRequest(context: self) + } + } + try await group.waitForAll() } -// self.noticeClient.start() // // 启动网络监控 // self.monitorCancel = self.monitor.eventFlow.sink { event in // switch event { @@ -169,7 +179,7 @@ public class SDLContext: @unchecked Sendable { if upgradeType == .force { let forceUpgrade = NoticeMessage.UpgradeMessage(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress) - self.noticeClient.send(data: forceUpgrade.binaryData) + await self.noticeClient?.send(data: forceUpgrade.binaryData) exit(-1) } @@ -179,7 +189,7 @@ public class SDLContext: @unchecked Sendable { if upgradeType == .normal { let normalUpgrade = NoticeMessage.UpgradeMessage(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress) - self.noticeClient.send(data: normalUpgrade.binaryData) + await self.noticeClient?.send(data: normalUpgrade.binaryData) } case .registerSuperNak(let nakPacket): @@ -191,11 +201,11 @@ public class SDLContext: @unchecked Sendable { switch errorCode { case .invalidToken, .nodeDisabled: let alertNotice = NoticeMessage.AlertMessage(alert: errorMessage) - self.noticeClient.send(data: alertNotice.binaryData) + await self.noticeClient?.send(data: alertNotice.binaryData) exit(-1) case .noIpAddress, .networkFault, .internalFault: let alertNotice = NoticeMessage.AlertMessage(alert: errorMessage) - self.noticeClient.send(data: alertNotice.binaryData) + await self.noticeClient?.send(data: alertNotice.binaryData) } NSLog("[SDLContext] Get a SuperNak message exit") default: @@ -226,7 +236,7 @@ public class SDLContext: @unchecked Sendable { case .networkShutdown(let shutdownEvent): let alertNotice = NoticeMessage.AlertMessage(alert: shutdownEvent.message) - self.noticeClient.send(data: alertNotice.binaryData) + await self.noticeClient?.send(data: alertNotice.binaryData) exit(-1) } case .command(let packetId, let command): @@ -267,14 +277,9 @@ public class SDLContext: @unchecked Sendable { switch event { case .ready: // 获取当前网络的类型 - self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config) + //self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config) SDLLogger.log("[SDLContext] nat type is: \(self.natType)", level: .debug) - let timer = Timer.publish(every: 5.0, on: .main, in: .common).autoconnect() -// self.stunCancel = Just(Date()).merge(with: timer).sink { _ in -// self.lastCookie = await self.udpHoleActor?.stunRequest(context: self) -// } - case .closed: DispatchQueue.main.asyncAfter(deadline: .now() + 5) { Task { @@ -289,7 +294,7 @@ public class SDLContext: @unchecked Sendable { // 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下 if register.dstMac == self.devAddr.mac && register.networkID == self.devAddr.networkID { // 回复ack包 - self.udpHoleActor?.sendRegisterAck(context: self, remoteAddress: remoteAddress, dst_mac: register.srcMac) + await self.udpHoleActor?.sendRegisterAck(context: self, remoteAddress: remoteAddress, dst_mac: register.srcMac) // 这里需要建立到来源的会话, 在复杂网络下,通过super-node查询到的nat地址不一定靠谱,需要通过udp包的来源地址作为nat地址 let session = Session(dstMac: register.srcMac, natAddress: remoteAddress) await self.sessionManager.addSession(session: session) diff --git a/Sources/Punchnet/SDLMessage.swift b/Sources/Punchnet/SDLMessage.swift index f5c1f35..9293a29 100644 --- a/Sources/Punchnet/SDLMessage.swift +++ b/Sources/Punchnet/SDLMessage.swift @@ -54,7 +54,7 @@ enum SDLUpgradeType: UInt32 { } // Id生成器 -struct SDLIdGenerator { +struct SDLIdGenerator: Sendable { // 消息体id private var packetId: UInt32 diff --git a/Sources/Punchnet/SDLNatProber.swift b/Sources/Punchnet/SDLNatProber.swift index 5201b89..18710d7 100644 --- a/Sources/Punchnet/SDLNatProber.swift +++ b/Sources/Punchnet/SDLNatProber.swift @@ -22,7 +22,7 @@ struct SDLNatProber { } // 获取当前所处的网络的nat类型 - static func getNatType(udpHole: SDLUDPHole?, config: SDLConfiguration) async -> NatType { + static func getNatType(udpHole: SDLUDPHoleActor?, config: SDLConfiguration) async -> NatType { guard let udpHole else { return .blocked } @@ -34,7 +34,7 @@ struct SDLNatProber { } // 网络没有在nat下 - if natAddress1 == udpHole.localAddress { + if await natAddress1 == udpHole.localAddress { return .noNat } @@ -67,8 +67,8 @@ struct SDLNatProber { } } - private static func getNatAddress(_ udpHole: SDLUDPHole, remoteAddress: SocketAddress, attr: SDLProbeAttr) async -> SocketAddress? { - let stunProbeReply = await udpHole.stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: 5) + private static func getNatAddress(_ udpHole: SDLUDPHoleActor, remoteAddress: SocketAddress, attr: SDLProbeAttr) async -> SocketAddress? { + let stunProbeReply = try? await udpHole.stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: 5) return stunProbeReply?.socketAddress() } diff --git a/Sources/Punchnet/SDLNoticeClient.swift b/Sources/Punchnet/SDLNoticeClient.swift index 7c0e904..73d0afc 100644 --- a/Sources/Punchnet/SDLNoticeClient.swift +++ b/Sources/Punchnet/SDLNoticeClient.swift @@ -15,72 +15,59 @@ import Foundation // import Foundation -@preconcurrency import NIOCore +import NIOCore import NIOPosix // 处理和sn-server服务器之间的通讯 -class SDLNoticeClient: ChannelInboundHandler, @unchecked Sendable { - public typealias InboundIn = AddressedEnvelope - public typealias OutboundOut = AddressedEnvelope - - var channel: Channel? +actor SDLNoticeClient { private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + private let asyncChannel: NIOAsyncChannel, AddressedEnvelope> private let remoteAddress: SocketAddress - - init() { - self.remoteAddress = try! SocketAddress(ipAddress: "127.0.0.1", port: 50195) - } + private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: Data.self, bufferingPolicy: .unbounded) // 启动函数 - func start() { + init() async throws { + self.remoteAddress = try! SocketAddress(ipAddress: "127.0.0.1", port: 50195) + let bootstrap = DatagramBootstrap(group: self.group) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .channelInitializer { channel in - // 接收缓冲区 - channel.pipeline.addHandler(self) + + self.asyncChannel = try await bootstrap.bind(host: "0.0.0.0", port: 0) + .flatMapThrowing {channel in + return try NIOAsyncChannel(wrappingChannelSynchronously: channel, configuration: .init( + inboundType: AddressedEnvelope.self, + outboundType: AddressedEnvelope.self + )) } + .get() - self.channel = try! bootstrap.bind(host: "0.0.0.0", port: 0).wait() - SDLLogger.log("[SDLNoticeClient] started and listening on: \(self.channel?.localAddress!)", level: .debug) + SDLLogger.log("[SDLNoticeClient] started and listening on: \(self.asyncChannel.channel.localAddress!)", level: .debug) } - // -- MARK: ChannelInboundHandler Methods - - public func channelActive(context: ChannelHandlerContext) { - - } - - // 接收到的消息, 消息需要根据类型分流 - public func channelRead(context: ChannelHandlerContext, data: NIOAny) { - context.fireChannelRead(data) - } - - public func errorCaught(context: ChannelHandlerContext, error: Error) { - // As we are not really interested getting notified on success or failure we just pass nil as promise to - // reduce allocations. - context.close(promise: nil) - self.channel = nil - } - - public func channelInactive(context: ChannelHandlerContext) { - self.channel = nil - context.close(promise: nil) + func start() async throws { + try await self.asyncChannel.executeThenClose { inbound, outbound in + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + defer { + self.writeContinuation.finish() + } + + for try await message in self.writeStream { + let buf = self.asyncChannel.channel.allocator.buffer(bytes: message) + let envelope = AddressedEnvelope(remoteAddress: self.remoteAddress, data: buf) + + try await outbound.write(envelope) + } + } + + try await group.waitForAll() + } + } } // 处理写入逻辑 func send(data: Data) { - guard let channel = self.channel else { - return - } - - let remoteAddress = self.remoteAddress - let allocator = channel.allocator - - channel.eventLoop.execute { [allocator] in - let buffer = allocator.buffer(bytes: data) - let envelope = AddressedEnvelope(remoteAddress: remoteAddress, data: buffer) - channel.writeAndFlush(self.wrapOutboundOut(envelope), promise: nil) - } + self.writeContinuation.yield(data) } deinit { diff --git a/Sources/Punchnet/SDLUDPHoleActor.swift b/Sources/Punchnet/SDLUDPHoleActor.swift index a4cede4..c4ec640 100644 --- a/Sources/Punchnet/SDLUDPHoleActor.swift +++ b/Sources/Punchnet/SDLUDPHoleActor.swift @@ -100,6 +100,8 @@ actor SDLUDPHoleActor { } } + + //eventFlow.send(.ready) try await group.waitForAll()