逻辑上的正确行调整

This commit is contained in:
anlicheng 2026-01-28 13:05:11 +08:00
parent 6faff2e6cc
commit cbfbbc9ac6
3 changed files with 114 additions and 83 deletions

View File

@ -24,17 +24,10 @@ actor SDLNATProberActor {
private enum State {
case idle
case waiting(step: Step)
case waiting
case finished
}
private enum Step: Int {
case step1 = 1
case step2 = 2
case step3 = 3
case step4 = 4
}
private var state: State = .idle
// MARK: - Dependencies
@ -53,6 +46,12 @@ actor SDLNATProberActor {
private var onFinished: ((NatType) -> Void)?
private var cookieId: UInt32 = 1
private var currentCookieId: UInt32?
// step -> SDLStunProbeReply
private var replies: [UInt32: SDLStunProbeReply] = [:]
private var timeoutTask: Task<Void, Never>?
// MARK: - Init
@ -63,117 +62,135 @@ actor SDLNATProberActor {
}
// MARK: - Public API
func probeNatType() async -> NatType {
return await withCheckedContinuation { continuation in
Task {
await self.start { natType in
continuation.resume(returning: natType)
}
}
}
}
/// NAT
func start(onFinished: @escaping (NatType) -> Void) async {
private func start(onFinished: @escaping (NatType) -> Void) async {
guard case .idle = state else {
logger.log("[NAT] probe already started", level: .warning)
return
}
self.onFinished = onFinished
transition(to: .waiting(step: .step1))
await self.sendProbe(step: .step1)
transition(to: .waiting)
let cookieId = self.cookieId
self.cookieId &+= 1
self.currentCookieId = cookieId
await self.sendProbe(cookie: cookieId)
self.timeoutTask = Task {
try? await Task.sleep(nanoseconds: 5_000_000_000)
await self.handleTimeout()
}
}
/// UDP STUN
func handleProbeReply(from address: SocketAddress, reply: SDLStunProbeReply) async {
guard case .waiting(let currentStep) = state else {
guard case .waiting = state, let cookieId = self.currentCookieId, cookieId == reply.cookie else {
return
}
switch currentStep {
case .step1:
replies[reply.step] = reply
// 退nat
if let step1 = replies[1] {
let localAddress = await self.udpHole.getLocalAddress()
if address == localAddress {
finish(.noNat)
return
}
natAddress1 = address
transition(to: .waiting(step: .step2))
await self.sendProbe(step: .step2)
case .step2:
natAddress2 = address
}
if let step1 = replies[1], let step2 = replies[2] {
// natAddress2 IPIPNAT;
// ip{dstIp, dstPort, srcIp, srcPort}, ip
if let ip1 = natAddress1?.ipAddress, let ip2 = natAddress2?.ipAddress, ip1 != ip2 {
if let addr1 = step1.socketAddress(), let addr2 = step2.socketAddress(), addr1 != addr2 {
finish(.symmetric)
return
}
transition(to: .waiting(step: .step3))
await self.sendProbe(step: .step3)
case .step3:
// step3: ip1:port1 <---- ip2:port2 (ipport)
}
// ,
if replies[1] != nil && replies[2] != nil && replies[3] != nil && replies[4] != nil {
// step3: ip2:port2 <---- ip1:port1 (ipport)
// IPNAT
finish(.fullCone)
case .step4:
finish(.coneRestricted)
if let step3 = replies[3] {
finish(.fullCone)
return
}
// step3: ip1:port1 <---- ip1:port2 (port)
// IPNAT
if let step4 = replies[4] {
finish(.coneRestricted)
return
}
}
}
/// Timer / Task
func handleTimeout() async {
guard case .waiting(let currentStep) = state else {
private func handleTimeout() async {
guard case .waiting = state else {
return
}
switch currentStep {
case .step3:
transition(to: .waiting(step: .step4))
await sendProbe(step: .step4)
case .step4:
finish(.portRestricted)
default:
if replies[1] == nil {
finish(.blocked)
} else if replies[3] != nil {
finish(.fullCone)
} else if replies[4] != nil {
finish(.coneRestricted)
} else {
finish(.portRestricted)
}
}
// MARK: - Internal helpers
private func sendProbe(step: Step) async {
private func sendProbe(cookie: UInt32) async {
let addressArray = config.stunProbeSocketAddressArray
let remote: SocketAddress
let attr: SDLProbeAttr
switch step {
case .step1:
remote = addressArray[0][0]
attr = .none
case .step2:
remote = addressArray[1][1]
attr = .none
case .step3:
remote = addressArray[0][0]
attr = .peer
case .step4:
remote = addressArray[0][0]
attr = .port
}
var stunProbe = SDLStunProbe()
stunProbe.cookie = self.cookieId
stunProbe.attr = UInt32(attr.rawValue)
self.cookieId &+= 1
await self.udpHole.send(type: .stunProbe, data: try! stunProbe.serializedData(), remoteAddress: remote)
await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 1, attr: .none), remoteAddress: addressArray[0][0])
await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 2, attr: .none), remoteAddress: addressArray[1][1])
await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 3, attr: .peer), remoteAddress: addressArray[0][0])
await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 4, attr: .port), remoteAddress: addressArray[0][0])
}
private func finish(_ type: NatType) {
guard case .finished = state else {
transition(to: .finished)
logger.log("[NAT] finished with \(type)", level: .info)
onFinished?(type)
onFinished = nil
guard case .waiting = state else {
return
}
transition(to: .finished)
onFinished?(type)
onFinished = nil
//
self.timeoutTask?.cancel()
self.timeoutTask = nil
}
private func transition(to newState: State) {
state = newState
}
private func makeProbePacket(cookieId: UInt32, step: UInt32, attr: SDLProbeAttr) -> Data {
var stunProbe = SDLStunProbe()
stunProbe.cookie = cookieId
stunProbe.step = step
stunProbe.attr = UInt32(attr.rawValue)
return try! stunProbe.serializedData()
}
}

View File

@ -268,9 +268,7 @@ public class SDLContext {
// nat
if let udpHoleActor = self.udpHoleActor {
self.proberActor = SDLNATProberActor(udpHole: udpHoleActor, config: self.config, logger: self.logger)
await self.proberActor?.start { natType in
self.natType = natType
}
self.natType = await self.proberActor!.probeNatType()
}
var registerSuper = SDLRegisterSuper()

View File

@ -396,6 +396,8 @@ struct SDLStunProbe: Sendable {
var attr: UInt32 = 0
var step: UInt32 = 0
var unknownFields = SwiftProtobuf.UnknownStorage()
init() {}
@ -408,6 +410,8 @@ struct SDLStunProbeReply: Sendable {
var cookie: UInt32 = 0
var step: UInt32 = 0
var port: UInt32 = 0
var ip: UInt32 = 0
@ -1362,6 +1366,7 @@ extension SDLStunProbe: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementat
static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .same(proto: "cookie"),
2: .same(proto: "attr"),
3: .same(proto: "step"),
]
mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
@ -1372,6 +1377,7 @@ extension SDLStunProbe: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementat
switch fieldNumber {
case 1: try { try decoder.decodeSingularUInt32Field(value: &self.cookie) }()
case 2: try { try decoder.decodeSingularUInt32Field(value: &self.attr) }()
case 3: try { try decoder.decodeSingularUInt32Field(value: &self.step) }()
default: break
}
}
@ -1384,12 +1390,16 @@ extension SDLStunProbe: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementat
if self.attr != 0 {
try visitor.visitSingularUInt32Field(value: self.attr, fieldNumber: 2)
}
if self.step != 0 {
try visitor.visitSingularUInt32Field(value: self.step, fieldNumber: 3)
}
try unknownFields.traverse(visitor: &visitor)
}
static func ==(lhs: SDLStunProbe, rhs: SDLStunProbe) -> Bool {
if lhs.cookie != rhs.cookie {return false}
if lhs.attr != rhs.attr {return false}
if lhs.step != rhs.step {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}
@ -1399,8 +1409,9 @@ extension SDLStunProbeReply: SwiftProtobuf.Message, SwiftProtobuf._MessageImplem
static let protoMessageName: String = "SDLStunProbeReply"
static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .same(proto: "cookie"),
2: .same(proto: "port"),
3: .same(proto: "ip"),
2: .same(proto: "step"),
3: .same(proto: "port"),
4: .same(proto: "ip"),
]
mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
@ -1410,8 +1421,9 @@ extension SDLStunProbeReply: SwiftProtobuf.Message, SwiftProtobuf._MessageImplem
// enabled. https://github.com/apple/swift-protobuf/issues/1034
switch fieldNumber {
case 1: try { try decoder.decodeSingularUInt32Field(value: &self.cookie) }()
case 2: try { try decoder.decodeSingularUInt32Field(value: &self.port) }()
case 3: try { try decoder.decodeSingularUInt32Field(value: &self.ip) }()
case 2: try { try decoder.decodeSingularUInt32Field(value: &self.step) }()
case 3: try { try decoder.decodeSingularUInt32Field(value: &self.port) }()
case 4: try { try decoder.decodeSingularUInt32Field(value: &self.ip) }()
default: break
}
}
@ -1421,17 +1433,21 @@ extension SDLStunProbeReply: SwiftProtobuf.Message, SwiftProtobuf._MessageImplem
if self.cookie != 0 {
try visitor.visitSingularUInt32Field(value: self.cookie, fieldNumber: 1)
}
if self.step != 0 {
try visitor.visitSingularUInt32Field(value: self.step, fieldNumber: 2)
}
if self.port != 0 {
try visitor.visitSingularUInt32Field(value: self.port, fieldNumber: 2)
try visitor.visitSingularUInt32Field(value: self.port, fieldNumber: 3)
}
if self.ip != 0 {
try visitor.visitSingularUInt32Field(value: self.ip, fieldNumber: 3)
try visitor.visitSingularUInt32Field(value: self.ip, fieldNumber: 4)
}
try unknownFields.traverse(visitor: &visitor)
}
static func ==(lhs: SDLStunProbeReply, rhs: SDLStunProbeReply) -> Bool {
if lhs.cookie != rhs.cookie {return false}
if lhs.step != rhs.step {return false}
if lhs.port != rhs.port {return false}
if lhs.ip != rhs.ip {return false}
if lhs.unknownFields != rhs.unknownFields {return false}