diff --git a/Tun/Punchnet/Actors/SDLNATProberActor.swift b/Tun/Punchnet/Actors/SDLNATProberActor.swift index 0209dc3..fb4fce1 100644 --- a/Tun/Punchnet/Actors/SDLNATProberActor.swift +++ b/Tun/Punchnet/Actors/SDLNATProberActor.swift @@ -24,17 +24,10 @@ actor SDLNATProberActor { private enum State { case idle - case waiting(step: Step) + case waiting case finished } - private enum Step: Int { - case step1 = 1 - case step2 = 2 - case step3 = 3 - case step4 = 4 - } - private var state: State = .idle // MARK: - Dependencies @@ -53,6 +46,12 @@ actor SDLNATProberActor { private var onFinished: ((NatType) -> Void)? private var cookieId: UInt32 = 1 + + private var currentCookieId: UInt32? + + // 建立step -> SDLStunProbeReply的映射关系 + private var replies: [UInt32: SDLStunProbeReply] = [:] + private var timeoutTask: Task? // MARK: - Init @@ -63,117 +62,135 @@ actor SDLNATProberActor { } // MARK: - Public API - + + func probeNatType() async -> NatType { + return await withCheckedContinuation { continuation in + Task { + await self.start { natType in + continuation.resume(returning: natType) + } + } + } + } + /// 启动 NAT 探测(一次性) - func start(onFinished: @escaping (NatType) -> Void) async { + private func start(onFinished: @escaping (NatType) -> Void) async { guard case .idle = state else { - logger.log("[NAT] probe already started", level: .warning) return } self.onFinished = onFinished - transition(to: .waiting(step: .step1)) - await self.sendProbe(step: .step1) + transition(to: .waiting) + + let cookieId = self.cookieId + self.cookieId &+= 1 + self.currentCookieId = cookieId + + await self.sendProbe(cookie: cookieId) + + self.timeoutTask = Task { + try? await Task.sleep(nanoseconds: 5_000_000_000) + await self.handleTimeout() + } } /// UDP 层收到 STUN 响应后调用 func handleProbeReply(from address: SocketAddress, reply: SDLStunProbeReply) async { - guard case .waiting(let currentStep) = state else { + guard case .waiting = state, let cookieId = self.currentCookieId, cookieId == reply.cookie else { return } - - switch currentStep { - case .step1: + + replies[reply.step] = reply + + // 提前退出的情况,没有nat映射 + if let step1 = replies[1] { let localAddress = await self.udpHole.getLocalAddress() if address == localAddress { finish(.noNat) return } - - natAddress1 = address - transition(to: .waiting(step: .step2)) - await self.sendProbe(step: .step2) - - case .step2: - natAddress2 = address + } + + if let step1 = replies[1], let step2 = replies[2] { // 如果natAddress2 的IP地址与上次回来的IP是不一样的,它就是对称型NAT; 这次的包也一定能发成功并收到 // 如果ip地址变了,这说明{dstIp, dstPort, srcIp, srcPort}, 其中有一个变了;则用新的ip地址 - if let ip1 = natAddress1?.ipAddress, let ip2 = natAddress2?.ipAddress, ip1 != ip2 { + if let addr1 = step1.socketAddress(), let addr2 = step2.socketAddress(), addr1 != addr2 { finish(.symmetric) return } - - transition(to: .waiting(step: .step3)) - await self.sendProbe(step: .step3) - case .step3: - // step3: ip1:port1 <---- ip2:port2 (ip地址和port都变的情况) + } + + // 收到了所有的响应, 优先判断 + if replies[1] != nil && replies[2] != nil && replies[3] != nil && replies[4] != nil { + // step3: ip2:port2 <---- ip1:port1 (ip地址和port都变的情况) // 如果能收到的,说明是完全锥形 说明是IP地址限制锥型NAT,如果不能收到说明是端口限制锥型。 - finish(.fullCone) - case .step4: - finish(.coneRestricted) + if let step3 = replies[3] { + finish(.fullCone) + return + } + + // step3: ip1:port1 <---- ip1:port2 (port改变情况) + // 如果能收到的说明是IP地址限制锥型NAT,如果不能收到说明是端口限制锥型。 + if let step4 = replies[4] { + finish(.coneRestricted) + return + } } } - + /// 超时事件(由外部 Timer / Task 驱动) - func handleTimeout() async { - guard case .waiting(let currentStep) = state else { + private func handleTimeout() async { + guard case .waiting = state else { return } - switch currentStep { - case .step3: - transition(to: .waiting(step: .step4)) - await sendProbe(step: .step4) - case .step4: - finish(.portRestricted) - default: + if replies[1] == nil { finish(.blocked) + } else if replies[3] != nil { + finish(.fullCone) + } else if replies[4] != nil { + finish(.coneRestricted) + } else { + finish(.portRestricted) } } // MARK: - Internal helpers - private func sendProbe(step: Step) async { + private func sendProbe(cookie: UInt32) async { let addressArray = config.stunProbeSocketAddressArray - - let remote: SocketAddress - let attr: SDLProbeAttr - - switch step { - case .step1: - remote = addressArray[0][0] - attr = .none - case .step2: - remote = addressArray[1][1] - attr = .none - case .step3: - remote = addressArray[0][0] - attr = .peer - case .step4: - remote = addressArray[0][0] - attr = .port - } - - var stunProbe = SDLStunProbe() - stunProbe.cookie = self.cookieId - stunProbe.attr = UInt32(attr.rawValue) - self.cookieId &+= 1 - - await self.udpHole.send(type: .stunProbe, data: try! stunProbe.serializedData(), remoteAddress: remote) + await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 1, attr: .none), remoteAddress: addressArray[0][0]) + await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 2, attr: .none), remoteAddress: addressArray[1][1]) + await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 3, attr: .peer), remoteAddress: addressArray[0][0]) + await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 4, attr: .port), remoteAddress: addressArray[0][0]) } private func finish(_ type: NatType) { - guard case .finished = state else { - transition(to: .finished) - logger.log("[NAT] finished with \(type)", level: .info) - onFinished?(type) - onFinished = nil + guard case .waiting = state else { return } + + transition(to: .finished) + onFinished?(type) + onFinished = nil + + // 取消定时器 + self.timeoutTask?.cancel() + self.timeoutTask = nil } private func transition(to newState: State) { state = newState } + + private func makeProbePacket(cookieId: UInt32, step: UInt32, attr: SDLProbeAttr) -> Data { + var stunProbe = SDLStunProbe() + stunProbe.cookie = cookieId + stunProbe.step = step + stunProbe.attr = UInt32(attr.rawValue) + + return try! stunProbe.serializedData() + } + } diff --git a/Tun/Punchnet/SDLContext.swift b/Tun/Punchnet/SDLContext.swift index fdf7dd8..5d45499 100644 --- a/Tun/Punchnet/SDLContext.swift +++ b/Tun/Punchnet/SDLContext.swift @@ -268,9 +268,7 @@ public class SDLContext { // 开始探测nat的类型 if let udpHoleActor = self.udpHoleActor { self.proberActor = SDLNATProberActor(udpHole: udpHoleActor, config: self.config, logger: self.logger) - await self.proberActor?.start { natType in - self.natType = natType - } + self.natType = await self.proberActor!.probeNatType() } var registerSuper = SDLRegisterSuper() diff --git a/Tun/Punchnet/SDLMessage.pb.swift b/Tun/Punchnet/SDLMessage.pb.swift index 9c11032..f07d3be 100644 --- a/Tun/Punchnet/SDLMessage.pb.swift +++ b/Tun/Punchnet/SDLMessage.pb.swift @@ -396,6 +396,8 @@ struct SDLStunProbe: Sendable { var attr: UInt32 = 0 + var step: UInt32 = 0 + var unknownFields = SwiftProtobuf.UnknownStorage() init() {} @@ -408,6 +410,8 @@ struct SDLStunProbeReply: Sendable { var cookie: UInt32 = 0 + var step: UInt32 = 0 + var port: UInt32 = 0 var ip: UInt32 = 0 @@ -1362,6 +1366,7 @@ extension SDLStunProbe: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementat static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ 1: .same(proto: "cookie"), 2: .same(proto: "attr"), + 3: .same(proto: "step"), ] mutating func decodeMessage(decoder: inout D) throws { @@ -1372,6 +1377,7 @@ extension SDLStunProbe: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementat switch fieldNumber { case 1: try { try decoder.decodeSingularUInt32Field(value: &self.cookie) }() case 2: try { try decoder.decodeSingularUInt32Field(value: &self.attr) }() + case 3: try { try decoder.decodeSingularUInt32Field(value: &self.step) }() default: break } } @@ -1384,12 +1390,16 @@ extension SDLStunProbe: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementat if self.attr != 0 { try visitor.visitSingularUInt32Field(value: self.attr, fieldNumber: 2) } + if self.step != 0 { + try visitor.visitSingularUInt32Field(value: self.step, fieldNumber: 3) + } try unknownFields.traverse(visitor: &visitor) } static func ==(lhs: SDLStunProbe, rhs: SDLStunProbe) -> Bool { if lhs.cookie != rhs.cookie {return false} if lhs.attr != rhs.attr {return false} + if lhs.step != rhs.step {return false} if lhs.unknownFields != rhs.unknownFields {return false} return true } @@ -1399,8 +1409,9 @@ extension SDLStunProbeReply: SwiftProtobuf.Message, SwiftProtobuf._MessageImplem static let protoMessageName: String = "SDLStunProbeReply" static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ 1: .same(proto: "cookie"), - 2: .same(proto: "port"), - 3: .same(proto: "ip"), + 2: .same(proto: "step"), + 3: .same(proto: "port"), + 4: .same(proto: "ip"), ] mutating func decodeMessage(decoder: inout D) throws { @@ -1410,8 +1421,9 @@ extension SDLStunProbeReply: SwiftProtobuf.Message, SwiftProtobuf._MessageImplem // enabled. https://github.com/apple/swift-protobuf/issues/1034 switch fieldNumber { case 1: try { try decoder.decodeSingularUInt32Field(value: &self.cookie) }() - case 2: try { try decoder.decodeSingularUInt32Field(value: &self.port) }() - case 3: try { try decoder.decodeSingularUInt32Field(value: &self.ip) }() + case 2: try { try decoder.decodeSingularUInt32Field(value: &self.step) }() + case 3: try { try decoder.decodeSingularUInt32Field(value: &self.port) }() + case 4: try { try decoder.decodeSingularUInt32Field(value: &self.ip) }() default: break } } @@ -1421,17 +1433,21 @@ extension SDLStunProbeReply: SwiftProtobuf.Message, SwiftProtobuf._MessageImplem if self.cookie != 0 { try visitor.visitSingularUInt32Field(value: self.cookie, fieldNumber: 1) } + if self.step != 0 { + try visitor.visitSingularUInt32Field(value: self.step, fieldNumber: 2) + } if self.port != 0 { - try visitor.visitSingularUInt32Field(value: self.port, fieldNumber: 2) + try visitor.visitSingularUInt32Field(value: self.port, fieldNumber: 3) } if self.ip != 0 { - try visitor.visitSingularUInt32Field(value: self.ip, fieldNumber: 3) + try visitor.visitSingularUInt32Field(value: self.ip, fieldNumber: 4) } try unknownFields.traverse(visitor: &visitor) } static func ==(lhs: SDLStunProbeReply, rhs: SDLStunProbeReply) -> Bool { if lhs.cookie != rhs.cookie {return false} + if lhs.step != rhs.step {return false} if lhs.port != rhs.port {return false} if lhs.ip != rhs.ip {return false} if lhs.unknownFields != rhs.unknownFields {return false}