diff --git a/Tun/Punchnet/Actors/SDLNATProberActor.swift b/Tun/Punchnet/Actors/SDLNATProberActor.swift index fb4fce1..56e7d0c 100644 --- a/Tun/Punchnet/Actors/SDLNATProberActor.swift +++ b/Tun/Punchnet/Actors/SDLNATProberActor.swift @@ -21,168 +21,153 @@ actor SDLNATProberActor { } // MARK: - Internal State - - private enum State { - case idle - case waiting - case finished + + class ProbeSession { + var cookieId: UInt32 + // 建立step -> SDLStunProbeReply的映射关系 + var replies: [UInt32: SDLStunProbeReply] + var timeoutTask: Task? + var continuation: CheckedContinuation + + private var isFinished: Bool = false + + init(cookieId: UInt32, timeoutTask: Task? = nil, continuation: CheckedContinuation) { + self.cookieId = cookieId + self.replies = [:] + self.timeoutTask = timeoutTask + self.continuation = continuation + } + + func finished(with type: NatType) { + guard !isFinished else { + return + } + + self.continuation.resume(returning: type) + // 取消定时器 + self.timeoutTask?.cancel() + self.isFinished = true + } } - private var state: State = .idle - // MARK: - Dependencies private let udpHole: SDLUDPHoleActor - private let config: SDLConfiguration + private let addressArray: [[SocketAddress]] private let logger: SDLLogger - // MARK: - Probe Data - - private var natAddress1: SocketAddress? - private var natAddress2: SocketAddress? - // MARK: - Completion - - private var onFinished: ((NatType) -> Void)? - private var cookieId: UInt32 = 1 - private var currentCookieId: UInt32? + private var sessions: [UInt32: ProbeSession] = [:] - // 建立step -> SDLStunProbeReply的映射关系 - private var replies: [UInt32: SDLStunProbeReply] = [:] - private var timeoutTask: Task? - // MARK: - Init - init(udpHole: SDLUDPHoleActor, config: SDLConfiguration, logger: SDLLogger) { + init(udpHole: SDLUDPHoleActor, addressArray: [[SocketAddress]], logger: SDLLogger) { self.udpHole = udpHole - self.config = config + self.addressArray = addressArray self.logger = logger } // MARK: - Public API func probeNatType() async -> NatType { + let cookieId = self.cookieId + self.cookieId &+= 1 + return await withCheckedContinuation { continuation in + let timeoutTask = Task { + try? await Task.sleep(nanoseconds: 5_000_000_000) + await self.handleTimeout(cookie: cookieId) + } + + let session = ProbeSession( + cookieId: cookieId, + timeoutTask: timeoutTask, + continuation: continuation + ) + self.sessions[cookieId] = session Task { - await self.start { natType in - continuation.resume(returning: natType) - } + await self.sendProbe(cookie: cookieId) } } } - /// 启动 NAT 探测(一次性) - private func start(onFinished: @escaping (NatType) -> Void) async { - guard case .idle = state else { - return - } - - self.onFinished = onFinished - transition(to: .waiting) - - let cookieId = self.cookieId - self.cookieId &+= 1 - self.currentCookieId = cookieId - - await self.sendProbe(cookie: cookieId) - - self.timeoutTask = Task { - try? await Task.sleep(nanoseconds: 5_000_000_000) - await self.handleTimeout() - } - } - /// UDP 层收到 STUN 响应后调用 func handleProbeReply(from address: SocketAddress, reply: SDLStunProbeReply) async { - guard case .waiting = state, let cookieId = self.currentCookieId, cookieId == reply.cookie else { + guard let session = self.sessions[reply.cookie] else { return } - replies[reply.step] = reply + session.replies[reply.step] = reply // 提前退出的情况,没有nat映射 - if let step1 = replies[1] { + if let step1 = session.replies[1] { let localAddress = await self.udpHole.getLocalAddress() if address == localAddress { - finish(.noNat) + finish(cookie: session.cookieId, .noNat) return } } - if let step1 = replies[1], let step2 = replies[2] { + if let step1 = session.replies[1], let step2 = session.replies[2] { // 如果natAddress2 的IP地址与上次回来的IP是不一样的,它就是对称型NAT; 这次的包也一定能发成功并收到 // 如果ip地址变了,这说明{dstIp, dstPort, srcIp, srcPort}, 其中有一个变了;则用新的ip地址 if let addr1 = step1.socketAddress(), let addr2 = step2.socketAddress(), addr1 != addr2 { - finish(.symmetric) + finish(cookie: session.cookieId, .symmetric) return } } // 收到了所有的响应, 优先判断 - if replies[1] != nil && replies[2] != nil && replies[3] != nil && replies[4] != nil { + if session.replies[1] != nil && session.replies[2] != nil && session.replies[3] != nil && session.replies[4] != nil { // step3: ip2:port2 <---- ip1:port1 (ip地址和port都变的情况) // 如果能收到的,说明是完全锥形 说明是IP地址限制锥型NAT,如果不能收到说明是端口限制锥型。 - if let step3 = replies[3] { - finish(.fullCone) + if let step3 = session.replies[3] { + finish(cookie: session.cookieId, .fullCone) return } // step3: ip1:port1 <---- ip1:port2 (port改变情况) // 如果能收到的说明是IP地址限制锥型NAT,如果不能收到说明是端口限制锥型。 - if let step4 = replies[4] { - finish(.coneRestricted) + if let step4 = session.replies[4] { + finish(cookie: session.cookieId, .coneRestricted) return } } } /// 超时事件(由外部 Timer / Task 驱动) - private func handleTimeout() async { - guard case .waiting = state else { + private func handleTimeout(cookie: UInt32) async { + guard let session = self.sessions[cookie] else { return } - if replies[1] == nil { - finish(.blocked) - } else if replies[3] != nil { - finish(.fullCone) - } else if replies[4] != nil { - finish(.coneRestricted) + if session.replies[1] == nil { + finish(cookie: cookie, .blocked) + } else if session.replies[3] != nil { + finish(cookie: cookie, .fullCone) + } else if session.replies[4] != nil { + finish(cookie: cookie, .coneRestricted) } else { - finish(.portRestricted) + finish(cookie: cookie, .portRestricted) + } + } + + private func finish(cookie: UInt32, _ type: NatType) { + if let session = self.sessions.removeValue(forKey: cookie) { + session.finished(with: type) } } // MARK: - Internal helpers private func sendProbe(cookie: UInt32) async { - let addressArray = config.stunProbeSocketAddressArray - await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 1, attr: .none), remoteAddress: addressArray[0][0]) await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 2, attr: .none), remoteAddress: addressArray[1][1]) await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 3, attr: .peer), remoteAddress: addressArray[0][0]) await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 4, attr: .port), remoteAddress: addressArray[0][0]) } - - private func finish(_ type: NatType) { - guard case .waiting = state else { - return - } - - transition(to: .finished) - onFinished?(type) - onFinished = nil - - // 取消定时器 - self.timeoutTask?.cancel() - self.timeoutTask = nil - } - - private func transition(to newState: State) { - state = newState - } private func makeProbePacket(cookieId: UInt32, step: UInt32, attr: SDLProbeAttr) -> Data { var stunProbe = SDLStunProbe() diff --git a/Tun/Punchnet/SDLContext.swift b/Tun/Punchnet/SDLContext.swift index 5d45499..8bd5fd0 100644 --- a/Tun/Punchnet/SDLContext.swift +++ b/Tun/Punchnet/SDLContext.swift @@ -267,7 +267,7 @@ public class SDLContext { // 开始探测nat的类型 if let udpHoleActor = self.udpHoleActor { - self.proberActor = SDLNATProberActor(udpHole: udpHoleActor, config: self.config, logger: self.logger) + self.proberActor = SDLNATProberActor(udpHole: udpHoleActor, addressArray: self.config.stunProbeSocketAddressArray, logger: self.logger) self.natType = await self.proberActor!.probeNatType() }