diff --git a/Tun/PacketTunnelProvider.swift b/Tun/PacketTunnelProvider.swift index 296e7da..cea144c 100644 --- a/Tun/PacketTunnelProvider.swift +++ b/Tun/PacketTunnelProvider.swift @@ -37,7 +37,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider { let stunServersStr = options["stun_servers"] as! String let noticePort = options["notice_port"] as! Int let token = options["token"] as! String - let networkCode = options["network_code"] as! String + let accessToken = options["access_token"] as! String let clientId = options["client_id"] as! String let remoteDnsServer = options["remote_dns_server"] as! String let hostname = options["hostname"] as! String @@ -61,7 +61,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider { return } - NSLog("[PacketTunnelProvider] client_id: \(clientId), token: \(token), network_code: \(networkCode)") + NSLog("[PacketTunnelProvider] client_id: \(clientId), token: \(token)") let config = SDLConfiguration(version: 1, installedChannel: installed_channel, @@ -71,7 +71,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider { clientId: clientId, noticePort: noticePort, token: token, - networkCode: networkCode, + accessToken: accessToken, remoteDnsServer: remoteDnsServer, hostname: hostname) // 加密算法 diff --git a/Tun/Punchnet/Actors/SDLNATProberActor.swift b/Tun/Punchnet/Actors/SDLNATProberActor.swift new file mode 100644 index 0000000..0209dc3 --- /dev/null +++ b/Tun/Punchnet/Actors/SDLNATProberActor.swift @@ -0,0 +1,179 @@ +// +// SDLNATProberActor.swift +// punchnet +// +// Created by 安礼成 on 2026/1/28. +// +import Foundation +import NIOCore + +actor SDLNATProberActor { + + // MARK: - NAT Type + + enum NatType: UInt8, Encodable { + case blocked = 0 + case noNat = 1 + case fullCone = 2 + case portRestricted = 3 + case coneRestricted = 4 + case symmetric = 5 + } + + // MARK: - Internal State + + private enum State { + case idle + case waiting(step: Step) + case finished + } + + private enum Step: Int { + case step1 = 1 + case step2 = 2 + case step3 = 3 + case step4 = 4 + } + + private var state: State = .idle + + // MARK: - Dependencies + + private let udpHole: SDLUDPHoleActor + private let config: SDLConfiguration + private let logger: SDLLogger + + // MARK: - Probe Data + + private var natAddress1: SocketAddress? + private var natAddress2: SocketAddress? + + // MARK: - Completion + + private var onFinished: ((NatType) -> Void)? + + private var cookieId: UInt32 = 1 + + // MARK: - Init + + init(udpHole: SDLUDPHoleActor, config: SDLConfiguration, logger: SDLLogger) { + self.udpHole = udpHole + self.config = config + self.logger = logger + } + + // MARK: - Public API + + /// 启动 NAT 探测(一次性) + func start(onFinished: @escaping (NatType) -> Void) async { + guard case .idle = state else { + logger.log("[NAT] probe already started", level: .warning) + return + } + + self.onFinished = onFinished + transition(to: .waiting(step: .step1)) + await self.sendProbe(step: .step1) + } + + /// UDP 层收到 STUN 响应后调用 + func handleProbeReply(from address: SocketAddress, reply: SDLStunProbeReply) async { + guard case .waiting(let currentStep) = state else { + return + } + + switch currentStep { + case .step1: + let localAddress = await self.udpHole.getLocalAddress() + if address == localAddress { + finish(.noNat) + return + } + + natAddress1 = address + transition(to: .waiting(step: .step2)) + await self.sendProbe(step: .step2) + + case .step2: + natAddress2 = address + // 如果natAddress2 的IP地址与上次回来的IP是不一样的,它就是对称型NAT; 这次的包也一定能发成功并收到 + // 如果ip地址变了,这说明{dstIp, dstPort, srcIp, srcPort}, 其中有一个变了;则用新的ip地址 + if let ip1 = natAddress1?.ipAddress, let ip2 = natAddress2?.ipAddress, ip1 != ip2 { + finish(.symmetric) + return + } + + transition(to: .waiting(step: .step3)) + await self.sendProbe(step: .step3) + case .step3: + // step3: ip1:port1 <---- ip2:port2 (ip地址和port都变的情况) + // 如果能收到的,说明是完全锥形 说明是IP地址限制锥型NAT,如果不能收到说明是端口限制锥型。 + finish(.fullCone) + case .step4: + finish(.coneRestricted) + } + } + + /// 超时事件(由外部 Timer / Task 驱动) + func handleTimeout() async { + guard case .waiting(let currentStep) = state else { + return + } + + switch currentStep { + case .step3: + transition(to: .waiting(step: .step4)) + await sendProbe(step: .step4) + case .step4: + finish(.portRestricted) + default: + finish(.blocked) + } + } + + // MARK: - Internal helpers + + private func sendProbe(step: Step) async { + let addressArray = config.stunProbeSocketAddressArray + + let remote: SocketAddress + let attr: SDLProbeAttr + + switch step { + case .step1: + remote = addressArray[0][0] + attr = .none + case .step2: + remote = addressArray[1][1] + attr = .none + case .step3: + remote = addressArray[0][0] + attr = .peer + case .step4: + remote = addressArray[0][0] + attr = .port + } + + var stunProbe = SDLStunProbe() + stunProbe.cookie = self.cookieId + stunProbe.attr = UInt32(attr.rawValue) + + self.cookieId &+= 1 + + await self.udpHole.send(type: .stunProbe, data: try! stunProbe.serializedData(), remoteAddress: remote) + } + + private func finish(_ type: NatType) { + guard case .finished = state else { + transition(to: .finished) + logger.log("[NAT] finished with \(type)", level: .info) + onFinished?(type) + onFinished = nil + return + } + } + + private func transition(to newState: State) { + state = newState + } +} diff --git a/Tun/Punchnet/Actors/SDLUDPHoleActor.swift b/Tun/Punchnet/Actors/SDLUDPHoleActor.swift index f6106aa..575eba5 100644 --- a/Tun/Punchnet/Actors/SDLUDPHoleActor.swift +++ b/Tun/Punchnet/Actors/SDLUDPHoleActor.swift @@ -12,32 +12,32 @@ import SwiftProtobuf // 处理和sn-server服务器之间的通讯 actor SDLUDPHoleActor { - typealias HoleMessage = (SocketAddress, SDLHoleInboundMessage) - private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) private let asyncChannel: NIOAsyncChannel, AddressedEnvelope> - private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: UDPMessage.self, bufferingPolicy: .unbounded) + private let (writeStream, writeContinuation) = AsyncStream.makeStream(of: UDPHoleOutboundMessage.self, bufferingPolicy: .unbounded) - private var cookieGenerator = SDLIdGenerator(seed: 1) - private var promises: [UInt32:EventLoopPromise] = [:] public var localAddress: SocketAddress? - - public let messageStream: AsyncStream - private let messageContinuation: AsyncStream.Continuation + public let eventStream: AsyncStream + private let eventContinuation: AsyncStream.Continuation private let logger: SDLLogger - struct UDPMessage { + struct UDPHoleOutboundMessage { let remoteAddress: SocketAddress let type: SDLPacketType let data: Data } - + + enum UDPHoleEvent { + case ready + case message(SocketAddress, SDLHoleInboundMessage) + } + // 启动函数 init(logger: SDLLogger) async throws { self.logger = logger - (self.messageStream, self.messageContinuation) = AsyncStream.makeStream(of: HoleMessage.self, bufferingPolicy: .unbounded) + (self.eventStream, self.eventContinuation) = AsyncStream.makeStream(of: UDPHoleEvent.self, bufferingPolicy: .unbounded) let bootstrap = DatagramBootstrap(group: group) .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) @@ -56,108 +56,65 @@ actor SDLUDPHoleActor { } func start() async throws { - try await withTaskCancellationHandler { - try await self.asyncChannel.executeThenClose {inbound, outbound in - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - defer { - self.logger.log("[SDLUDPHole] inbound closed", level: .warning) - } + try await self.asyncChannel.executeThenClose {inbound, outbound in + self.eventContinuation.yield(.ready) + + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + defer { + self.logger.log("[SDLUDPHole] inbound closed", level: .warning) + } + + for try await envelope in inbound { + try Task.checkCancellation() - for try await envelope in inbound { - try Task.checkCancellation() - - var buffer = envelope.data - let remoteAddress = envelope.remoteAddress - do { - if let message = try Self.decode(buffer: &buffer) { - switch message { - case .data(let data): - self.logger.log("[SDLUDPHole] read data: \(data.format()), from: \(remoteAddress)", level: .debug) - self.messageContinuation.yield((remoteAddress, .data(data))) - case .stunProbeReply(let probeReply): - // 执行并移除回调 - await self.trigger(probeReply: probeReply) - default: - () - } - } else { - self.logger.log("[SDLUDPHole] decode message, get null", level: .warning) - } - } catch let err { - self.logger.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning) - throw err + var buffer = envelope.data + let remoteAddress = envelope.remoteAddress + do { + if let message = try Self.decode(buffer: &buffer) { + self.eventContinuation.yield(.message(remoteAddress, message)) + } else { + self.logger.log("[SDLUDPHole] decode message, get null", level: .warning) } + } catch let err { + self.logger.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning) + throw err } } - - group.addTask { - defer { - self.logger.log("[SDLUDPHole] outbound closed", level: .warning) - } - - for await message in self.writeStream { - try Task.checkCancellation() - - var buffer = self.asyncChannel.channel.allocator.buffer(capacity: message.data.count + 1) - buffer.writeBytes([message.type.rawValue]) - buffer.writeBytes(message.data) - - let envelope = AddressedEnvelope(remoteAddress: message.remoteAddress, data: buffer) - try await outbound.write(envelope) - } - } - - if let _ = try await group.next() { - group.cancelAll() - } + } + + group.addTask { + defer { + self.logger.log("[SDLUDPHole] outbound closed", level: .warning) + } + + for await message in self.writeStream { + try Task.checkCancellation() + + var buffer = self.asyncChannel.channel.allocator.buffer(capacity: message.data.count + 1) + buffer.writeBytes([message.type.rawValue]) + buffer.writeBytes(message.data) + + let envelope = AddressedEnvelope(remoteAddress: message.remoteAddress, data: buffer) + try await outbound.write(envelope) + } + } + + if let _ = try await group.next() { + group.cancelAll() } } - } onCancel: { - self.writeContinuation.finish() - self.messageContinuation.finish() - self.logger.log("[SDLUDPHole] withTaskCancellationHandler cancel") } } - func getCookieId() -> UInt32 { - return self.cookieGenerator.nextId() - } - - // 探测tun信息 - func stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int = 5) async throws -> SDLStunProbeReply { - return try await self._stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: timeout).get() - } - - private func _stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int) -> EventLoopFuture { - let cookie = self.cookieGenerator.nextId() - var stunProbe = SDLStunProbe() - stunProbe.cookie = cookie - stunProbe.attr = UInt32(attr.rawValue) - self.send( type: .stunProbe, data: try! stunProbe.serializedData(), remoteAddress: remoteAddress) - self.logger.log("[SDLUDPHole] stunProbe: \(remoteAddress)", level: .debug) - - let promise = self.asyncChannel.channel.eventLoop.makePromise(of: SDLStunProbeReply.self) - self.promises[cookie] = promise - - return promise.futureResult - } - - private func trigger(probeReply: SDLStunProbeReply) { - let id = probeReply.cookie - // 执行并移除回调 - if let promise = self.promises[id] { - self.asyncChannel.channel.eventLoop.execute { - promise.succeed(probeReply) - } - self.promises.removeValue(forKey: id) - } + func getLocalAddress() -> SocketAddress? { + return self.localAddress } // MARK: client-client apis // 处理写入逻辑 func send(type: SDLPacketType, data: Data, remoteAddress: SocketAddress) { - let message = UDPMessage(remoteAddress: remoteAddress, type: type, data: data) + let message = UDPHoleOutboundMessage(remoteAddress: remoteAddress, type: type, data: data) self.writeContinuation.yield(message) } @@ -238,7 +195,7 @@ actor SDLUDPHoleActor { deinit { try? self.group.syncShutdownGracefully() self.writeContinuation.finish() - self.messageContinuation.finish() + self.eventContinuation.finish() } } diff --git a/Tun/Punchnet/SDLConfiguration.swift b/Tun/Punchnet/SDLConfiguration.swift index cc87e22..db9fc15 100644 --- a/Tun/Punchnet/SDLConfiguration.swift +++ b/Tun/Punchnet/SDLConfiguration.swift @@ -63,9 +63,9 @@ public class SDLConfiguration { let clientId: String let token: String - let networkCode: String + let accessToken: String - public init(version: UInt8, installedChannel: String, superHost: String, superPort: Int, stunServers: [StunServer], clientId: String, noticePort: Int, token: String, networkCode: String, remoteDnsServer: String, hostname: String) { + public init(version: UInt8, installedChannel: String, superHost: String, superPort: Int, stunServers: [StunServer], clientId: String, noticePort: Int, token: String, accessToken: String, remoteDnsServer: String, hostname: String) { self.version = version self.installedChannel = installedChannel self.superHost = superHost @@ -74,7 +74,7 @@ public class SDLConfiguration { self.clientId = clientId self.noticePort = noticePort self.token = token - self.networkCode = networkCode + self.accessToken = accessToken self.remoteDnsServer = remoteDnsServer self.hostname = hostname } diff --git a/Tun/Punchnet/SDLContext.swift b/Tun/Punchnet/SDLContext.swift index 7c5249e..fdf7dd8 100644 --- a/Tun/Punchnet/SDLContext.swift +++ b/Tun/Punchnet/SDLContext.swift @@ -30,10 +30,8 @@ public class SDLContext { let config: SDLConfiguration - // nat映射的相关信息, 暂时没有用处 - //var natAddress: SDLNatAddress? // nat的网络类型 - var natType: NatType = .blocked + var natType: SDLNATProberActor.NatType = .blocked // AES加密,授权通过后,对象才会被创建 var aesCipher: AESCipher @@ -51,6 +49,9 @@ public class SDLContext { // dns的client对象 var dnsClientActor: SDLDNSClientActor? + // 网络探测对象 + var proberActor: SDLNATProberActor? + // 数据包读取任务 private var readTask: Task<(), Never>? @@ -162,7 +163,6 @@ public class SDLContext { try Task.checkCancellation() if let udpHoleActor = self.udpHoleActor { - let cookie = await udpHoleActor.getCookieId() var stunRequest = SDLStunRequest() stunRequest.clientID = self.config.clientId stunRequest.networkID = self.config.networkAddress.networkId @@ -172,36 +172,39 @@ public class SDLContext { let remoteAddress = self.config.stunSocketAddress await udpHoleActor.send(type: .stunRequest, data: try stunRequest.serializedData(), remoteAddress: remoteAddress) - self.lastCookie = cookie } } } group.addTask { - if let messageStream = self.udpHoleActor?.messageStream { - for try await (remoteAddress, message) in messageStream { + if let eventStream = self.udpHoleActor?.eventStream { + for try await event in eventStream { try Task.checkCancellation() - switch message { - case .registerSuperAck(let registerSuperAck): - await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck) - case .registerSuperNak(let registerSuperNak): - await self.handleRegisterSuperNak(nakPacket: registerSuperNak) - case .peerInfo(let peerInfo): - () - case .event(let event): - try await self.handleEvent(event: event) - case .stunReply(let stunReply): - await self.handleStunReply(stunReply: stunReply) - case .stunProbeReply(_): - () - case .data(let data): - try await self.handleData(data: data) - case .register(let register): - try await self.handleRegister(remoteAddress: remoteAddress, register: register) - case .registerAck(let registerAck): - await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) - default: - self.logger.log("get unknown message: \(message)", level: .error) + + switch event { + case .ready: + try await self.handleUDPHoleReady() + case .message(let remoteAddress, let message): + switch message { + case .registerSuperAck(let registerSuperAck): + await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck) + case .registerSuperNak(let registerSuperNak): + await self.handleRegisterSuperNak(nakPacket: registerSuperNak) + case .peerInfo(let peerInfo): + await self.puncherActor.handlePeerInfo(peerInfo: peerInfo) + case .event(let event): + try await self.handleEvent(event: event) + case .stunReply(let stunReply): + await self.handleStunReply(stunReply: stunReply) + case .stunProbeReply(_): + () + case .data(let data): + try await self.handleData(data: data) + case .register(let register): + try await self.handleRegister(remoteAddress: remoteAddress, register: register) + case .registerAck(let registerAck): + await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) + } } } } @@ -220,7 +223,8 @@ public class SDLContext { switch event { case .changed: // 需要重新探测网络的nat类型 - self.natType = await self.getNatType() + //self.natType = await self.getNatType() + self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info) case .unreachable: self.logger.log("didNetworkPathUnreachable", level: .warning) @@ -258,24 +262,30 @@ public class SDLContext { } } -// private func handleMessage(remoteAddress: SocketAddress, message: SDLHoleInboundMessage) async throws { -// switch message { -//// case .ready: -//// await self.puncherActor.setSuperClientActor(superClientActor: self.superClientActor) -//// -//// self.logger.log("[SDLContext] get registerSuper, mac address: \(SDLUtil.formatMacAddress(mac: self.devAddr.mac))", level: .debug) -//// var registerSuper = SDLRegisterSuper() -//// registerSuper.version = UInt32(self.config.version) -//// registerSuper.clientID = self.config.clientId -//// registerSuper.devAddr = self.devAddr -//// registerSuper.pubKey = self.rsaCipher.pubKey -//// registerSuper.token = self.config.token -//// registerSuper.networkCode = self.config.networkCode -//// registerSuper.hostname = self.config.hostname -//// guard let message = try await self.superClientActor?.request(type: .registerSuper, data: try registerSuper.serializedData()) else { -//// return -//// } - /// + private func handleUDPHoleReady() async throws { + await self.puncherActor.setUDPHoleActor(udpHoleActor: self.udpHoleActor) + + // 开始探测nat的类型 + if let udpHoleActor = self.udpHoleActor { + self.proberActor = SDLNATProberActor(udpHole: udpHoleActor, config: self.config, logger: self.logger) + await self.proberActor?.start { natType in + self.natType = natType + } + } + + var registerSuper = SDLRegisterSuper() + registerSuper.pktID = 0 + registerSuper.clientID = self.config.clientId + registerSuper.networkID = self.config.networkAddress.networkId + registerSuper.mac = self.config.networkAddress.mac + registerSuper.ip = self.config.networkAddress.ip + registerSuper.maskLen = UInt32(self.config.networkAddress.maskLen) + registerSuper.hostname = self.config.hostname + registerSuper.pubKey = self.rsaCipher.pubKey + registerSuper.accessToken = self.config.accessToken + + await self.udpHoleActor?.send(type: .registerSuper, data: try registerSuper.serializedData(), remoteAddress: self.config.stunSocketAddress) + } private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async { // 需要对数据通过rsa的私钥解码 @@ -549,71 +559,6 @@ public class SDLContext { } } -// 网络类型探测 -extension SDLContext { - // 定义nat类型 - enum NatType: UInt8, Encodable { - case blocked = 0 - case noNat = 1 - case fullCone = 2 - case portRestricted = 3 - case coneRestricted = 4 - case symmetric = 5 - } - - // 获取当前所处的网络的nat类型 - func getNatType() async -> NatType { - guard let udpHole = self.udpHoleActor else { - return .blocked - } - - let addressArray = config.stunProbeSocketAddressArray - // step1: ip1:port1 <---- ip1:port1 - guard let natAddress1 = await getNatAddress(udpHole, remoteAddress: addressArray[0][0], attr: .none) else { - return .blocked - } - - // 网络没有在nat下 - if await natAddress1 == udpHole.localAddress { - return .noNat - } - - // step2: ip2:port2 <---- ip2:port2 - guard let natAddress2 = await getNatAddress(udpHole, remoteAddress: addressArray[1][1], attr: .none) else { - return .blocked - } - - // 如果natAddress2 的IP地址与上次回来的IP是不一样的,它就是对称型NAT; 这次的包也一定能发成功并收到 - // 如果ip地址变了,这说明{dstIp, dstPort, srcIp, srcPort}, 其中有一个变了;则用新的ip地址 - logger.log("[SDLNatProber] nat_address1: \(natAddress1), nat_address2: \(natAddress2)", level: .debug) - if let ipAddress1 = natAddress1.ipAddress, let ipAddress2 = natAddress2.ipAddress, ipAddress1 != ipAddress2 { - return .symmetric - } - - // step3: ip1:port1 <---- ip2:port2 (ip地址和port都变的情况) - // 如果能收到的,说明是完全锥形 说明是IP地址限制锥型NAT,如果不能收到说明是端口限制锥型。 - if let natAddress3 = await getNatAddress(udpHole, remoteAddress: addressArray[0][0], attr: .peer) { - logger.log("[SDLNatProber] nat_address1: \(natAddress1), nat_address2: \(natAddress2), nat_address3: \(natAddress3)", level: .debug) - return .fullCone - } - - // step3: ip1:port1 <---- ip1:port2 (port改变情况) - // 如果能收到的说明是IP地址限制锥型NAT,如果不能收到说明是端口限制锥型。 - if let natAddress4 = await getNatAddress(udpHole, remoteAddress: addressArray[0][0], attr: .port) { - logger.log("[SDLNatProber] nat_address1: \(natAddress1), nat_address2: \(natAddress2), nat_address4: \(natAddress4)", level: .debug) - return .coneRestricted - } else { - return .portRestricted - } - } - - private func getNatAddress(_ udpHole: SDLUDPHoleActor, remoteAddress: SocketAddress, attr: SDLProbeAttr) async -> SocketAddress? { - let stunProbeReply = try? await udpHole.stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: 5) - return stunProbeReply?.socketAddress() - } - -} - private extension UInt32 { // 转换成ip地址 func asIpAddress() -> String {