fix punch

This commit is contained in:
anlicheng 2026-01-07 17:42:27 +08:00
parent ccb3a22707
commit 9d3a380063
2 changed files with 95 additions and 65 deletions

View File

@ -0,0 +1,89 @@
//
// SDLPuncherActor.swift
// Tun
//
// Created by on 2026/1/7.
//
import Foundation
actor SDLPuncherActor {
// dstMac
private var coolingDown: Set<Data> = []
private let cooldown: Duration = .seconds(5)
private var superClientActor: SDLSuperClientActor?
private var udpHoleActor: SDLUDPHoleActor?
// holer
private var logger: SDLLogger
struct RegisterRequest {
let srcMac: Data
let dstMac: Data
let networkId: UInt32
}
init(logger: SDLLogger) {
self.logger = logger
}
func setSuperClientActor(superClientActor: SDLSuperClientActor?) {
self.superClientActor = superClientActor
}
func setUDPHoleActor(udpHoleActor: SDLUDPHoleActor?) {
self.udpHoleActor = udpHoleActor
}
func submitRegisterRequest(request: RegisterRequest) {
let dstMac = request.dstMac
guard !coolingDown.contains(dstMac) else {
return
}
//
coolingDown.insert(dstMac)
Task {
await self.tryHole(request: request)
//
try? await Task.sleep(for: .seconds(5))
self.endCooldown(for: dstMac)
}
}
private func endCooldown(for key: Data) {
self.coolingDown.remove(key)
}
private func tryHole(request: RegisterRequest) async {
var queryInfo = SDLQueryInfo()
queryInfo.dstMac = request.dstMac
guard let message = try? await self.superClientActor?.request(type: .queryInfo, data: try queryInfo.serializedData()).get() else {
return
}
switch message.packet {
case .empty:
self.logger.log("[SDLContext] hole query_info get empty: \(message)", level: .debug)
case .peerInfo(let peerInfo):
if let remoteAddress = peerInfo.v4Info.socketAddress() {
self.logger.log("[SDLContext] hole sock address: \(remoteAddress)", level: .debug)
// register
var register = SDLRegister()
register.networkID = request.networkId
register.srcMac = request.srcMac
register.dstMac = request.dstMac
await self.udpHoleActor?.send(type: .register, data: try! register.serializedData(), remoteAddress: remoteAddress)
} else {
self.logger.log("[SDLContext] hole sock address is invalid: \(peerInfo.v4Info)", level: .warning)
}
default:
self.logger.log("[SDLContext] hole query_info is packet: \(message)", level: .warning)
}
}
}

View File

@ -51,6 +51,7 @@ public class SDLContext {
var udpHoleActor: SDLUDPHoleActor?
var superClientActor: SDLSuperClientActor?
var providerActor: SDLTunnelProviderActor
var puncherActor: SDLPuncherActor
// dnsclient
var dnsClient: DNSClient?
@ -74,20 +75,9 @@ public class SDLContext {
private var flowTracer = SDLFlowTracerActor()
private var flowTracerCancel: AnyCancellable?
// holer
private var holerPublishers: [Data:PassthroughSubject<RegisterRequest, Never>] = [:]
private var bag = Set<AnyCancellable>()
private var locker = NSLock()
private let logger: SDLLogger
private var rootTask: Task<Void, Error>?
struct RegisterRequest {
let srcMac: Data
let dstMac: Data
let networkId: UInt32
}
public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) {
self.logger = logger
@ -103,6 +93,7 @@ public class SDLContext {
self.sessionManager = SessionManager()
self.arpServer = ArpServer(known_macs: [:])
self.providerActor = SDLTunnelProviderActor(provider: provider, logger: logger)
self.puncherActor = SDLPuncherActor(logger: logger)
}
public func start() async throws {
@ -297,6 +288,8 @@ public class SDLContext {
private func handleSuperEvent(event: SDLSuperClientActor.SuperEvent) async throws {
switch event {
case .ready:
await self.puncherActor.setSuperClientActor(superClientActor: self.superClientActor)
self.logger.log("[SDLContext] get registerSuper, mac address: \(SDLUtil.formatMacAddress(mac: self.devAddr.mac))", level: .debug)
var registerSuper = SDLRegisterSuper()
registerSuper.version = UInt32(self.config.version)
@ -413,6 +406,7 @@ public class SDLContext {
private func handleUDPEvent(event: SDLUDPHoleActor.UDPEvent) async throws {
switch event {
case .ready:
await self.puncherActor.setUDPHoleActor(udpHoleActor: self.udpHoleActor)
//
//self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config)
self.logger.log("[SDLContext] nat type is: \(self.natType)", level: .debug)
@ -602,60 +596,7 @@ public class SDLContext {
await self.flowTracer.inc(num: data.count, type: .forward)
//
let registerRequest = RegisterRequest(srcMac: self.devAddr.mac, dstMac: dstMac, networkId: self.devAddr.networkID)
self.submitRegisterRequest(request: registerRequest)
}
}
private func submitRegisterRequest(request: RegisterRequest) {
self.locker.lock()
defer {
self.locker.unlock()
}
let dstMac = request.dstMac
if let publisher = self.holerPublishers[dstMac] {
publisher.send(request)
} else {
let publisher = PassthroughSubject<RegisterRequest, Never>()
publisher.throttle(for: .seconds(5), scheduler: DispatchQueue.global(), latest: true)
.sink { request in
Task {
await self.tryHole(request: request)
}
}
.store(in: &self.bag)
self.holerPublishers[dstMac] = publisher
}
}
private func tryHole(request: RegisterRequest) async {
var queryInfo = SDLQueryInfo()
queryInfo.dstMac = request.dstMac
guard let message = try? await self.superClientActor?.request(type: .queryInfo, data: try queryInfo.serializedData()).get() else {
return
}
switch message.packet {
case .empty:
self.logger.log("[SDLContext] hole query_info get empty: \(message)", level: .debug)
case .peerInfo(let peerInfo):
if let remoteAddress = peerInfo.v4Info.socketAddress() {
self.logger.log("[SDLContext] hole sock address: \(remoteAddress)", level: .debug)
// register
var register = SDLRegister()
register.networkID = request.networkId
register.srcMac = request.srcMac
register.dstMac = request.dstMac
await self.udpHoleActor?.send(type: .register, data: try! register.serializedData(), remoteAddress: remoteAddress)
} else {
self.logger.log("[SDLContext] hole sock address is invalid: \(peerInfo.v4Info)", level: .warning)
}
default:
self.logger.log("[SDLContext] hole query_info is packet: \(message)", level: .warning)
await self.puncherActor.submitRegisterRequest(request: .init(srcMac: self.devAddr.mac, dstMac: dstMac, networkId: self.devAddr.networkID))
}
}