逻辑上的正确行调整

This commit is contained in:
anlicheng 2026-01-28 13:05:11 +08:00
parent 6faff2e6cc
commit cbfbbc9ac6
3 changed files with 114 additions and 83 deletions

View File

@ -24,17 +24,10 @@ actor SDLNATProberActor {
private enum State { private enum State {
case idle case idle
case waiting(step: Step) case waiting
case finished case finished
} }
private enum Step: Int {
case step1 = 1
case step2 = 2
case step3 = 3
case step4 = 4
}
private var state: State = .idle private var state: State = .idle
// MARK: - Dependencies // MARK: - Dependencies
@ -54,6 +47,12 @@ actor SDLNATProberActor {
private var cookieId: UInt32 = 1 private var cookieId: UInt32 = 1
private var currentCookieId: UInt32?
// step -> SDLStunProbeReply
private var replies: [UInt32: SDLStunProbeReply] = [:]
private var timeoutTask: Task<Void, Never>?
// MARK: - Init // MARK: - Init
init(udpHole: SDLUDPHoleActor, config: SDLConfiguration, logger: SDLLogger) { init(udpHole: SDLUDPHoleActor, config: SDLConfiguration, logger: SDLLogger) {
@ -64,116 +63,134 @@ actor SDLNATProberActor {
// MARK: - Public API // MARK: - Public API
func probeNatType() async -> NatType {
return await withCheckedContinuation { continuation in
Task {
await self.start { natType in
continuation.resume(returning: natType)
}
}
}
}
/// NAT /// NAT
func start(onFinished: @escaping (NatType) -> Void) async { private func start(onFinished: @escaping (NatType) -> Void) async {
guard case .idle = state else { guard case .idle = state else {
logger.log("[NAT] probe already started", level: .warning)
return return
} }
self.onFinished = onFinished self.onFinished = onFinished
transition(to: .waiting(step: .step1)) transition(to: .waiting)
await self.sendProbe(step: .step1)
let cookieId = self.cookieId
self.cookieId &+= 1
self.currentCookieId = cookieId
await self.sendProbe(cookie: cookieId)
self.timeoutTask = Task {
try? await Task.sleep(nanoseconds: 5_000_000_000)
await self.handleTimeout()
}
} }
/// UDP STUN /// UDP STUN
func handleProbeReply(from address: SocketAddress, reply: SDLStunProbeReply) async { func handleProbeReply(from address: SocketAddress, reply: SDLStunProbeReply) async {
guard case .waiting(let currentStep) = state else { guard case .waiting = state, let cookieId = self.currentCookieId, cookieId == reply.cookie else {
return return
} }
switch currentStep { replies[reply.step] = reply
case .step1:
// 退nat
if let step1 = replies[1] {
let localAddress = await self.udpHole.getLocalAddress() let localAddress = await self.udpHole.getLocalAddress()
if address == localAddress { if address == localAddress {
finish(.noNat) finish(.noNat)
return return
} }
}
natAddress1 = address if let step1 = replies[1], let step2 = replies[2] {
transition(to: .waiting(step: .step2))
await self.sendProbe(step: .step2)
case .step2:
natAddress2 = address
// natAddress2 IPIPNAT; // natAddress2 IPIPNAT;
// ip{dstIp, dstPort, srcIp, srcPort}, ip // ip{dstIp, dstPort, srcIp, srcPort}, ip
if let ip1 = natAddress1?.ipAddress, let ip2 = natAddress2?.ipAddress, ip1 != ip2 { if let addr1 = step1.socketAddress(), let addr2 = step2.socketAddress(), addr1 != addr2 {
finish(.symmetric) finish(.symmetric)
return return
} }
}
transition(to: .waiting(step: .step3)) // ,
await self.sendProbe(step: .step3) if replies[1] != nil && replies[2] != nil && replies[3] != nil && replies[4] != nil {
case .step3: // step3: ip2:port2 <---- ip1:port1 (ipport)
// step3: ip1:port1 <---- ip2:port2 (ipport)
// IPNAT // IPNAT
finish(.fullCone) if let step3 = replies[3] {
case .step4: finish(.fullCone)
finish(.coneRestricted) return
}
// step3: ip1:port1 <---- ip1:port2 (port)
// IPNAT
if let step4 = replies[4] {
finish(.coneRestricted)
return
}
} }
} }
/// Timer / Task /// Timer / Task
func handleTimeout() async { private func handleTimeout() async {
guard case .waiting(let currentStep) = state else { guard case .waiting = state else {
return return
} }
switch currentStep { if replies[1] == nil {
case .step3:
transition(to: .waiting(step: .step4))
await sendProbe(step: .step4)
case .step4:
finish(.portRestricted)
default:
finish(.blocked) finish(.blocked)
} else if replies[3] != nil {
finish(.fullCone)
} else if replies[4] != nil {
finish(.coneRestricted)
} else {
finish(.portRestricted)
} }
} }
// MARK: - Internal helpers // MARK: - Internal helpers
private func sendProbe(step: Step) async { private func sendProbe(cookie: UInt32) async {
let addressArray = config.stunProbeSocketAddressArray let addressArray = config.stunProbeSocketAddressArray
let remote: SocketAddress await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 1, attr: .none), remoteAddress: addressArray[0][0])
let attr: SDLProbeAttr await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 2, attr: .none), remoteAddress: addressArray[1][1])
await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 3, attr: .peer), remoteAddress: addressArray[0][0])
switch step { await self.udpHole.send(type: .stunProbe, data: makeProbePacket(cookieId: cookie, step: 4, attr: .port), remoteAddress: addressArray[0][0])
case .step1:
remote = addressArray[0][0]
attr = .none
case .step2:
remote = addressArray[1][1]
attr = .none
case .step3:
remote = addressArray[0][0]
attr = .peer
case .step4:
remote = addressArray[0][0]
attr = .port
}
var stunProbe = SDLStunProbe()
stunProbe.cookie = self.cookieId
stunProbe.attr = UInt32(attr.rawValue)
self.cookieId &+= 1
await self.udpHole.send(type: .stunProbe, data: try! stunProbe.serializedData(), remoteAddress: remote)
} }
private func finish(_ type: NatType) { private func finish(_ type: NatType) {
guard case .finished = state else { guard case .waiting = state else {
transition(to: .finished)
logger.log("[NAT] finished with \(type)", level: .info)
onFinished?(type)
onFinished = nil
return return
} }
transition(to: .finished)
onFinished?(type)
onFinished = nil
//
self.timeoutTask?.cancel()
self.timeoutTask = nil
} }
private func transition(to newState: State) { private func transition(to newState: State) {
state = newState state = newState
} }
private func makeProbePacket(cookieId: UInt32, step: UInt32, attr: SDLProbeAttr) -> Data {
var stunProbe = SDLStunProbe()
stunProbe.cookie = cookieId
stunProbe.step = step
stunProbe.attr = UInt32(attr.rawValue)
return try! stunProbe.serializedData()
}
} }

View File

@ -268,9 +268,7 @@ public class SDLContext {
// nat // nat
if let udpHoleActor = self.udpHoleActor { if let udpHoleActor = self.udpHoleActor {
self.proberActor = SDLNATProberActor(udpHole: udpHoleActor, config: self.config, logger: self.logger) self.proberActor = SDLNATProberActor(udpHole: udpHoleActor, config: self.config, logger: self.logger)
await self.proberActor?.start { natType in self.natType = await self.proberActor!.probeNatType()
self.natType = natType
}
} }
var registerSuper = SDLRegisterSuper() var registerSuper = SDLRegisterSuper()

View File

@ -396,6 +396,8 @@ struct SDLStunProbe: Sendable {
var attr: UInt32 = 0 var attr: UInt32 = 0
var step: UInt32 = 0
var unknownFields = SwiftProtobuf.UnknownStorage() var unknownFields = SwiftProtobuf.UnknownStorage()
init() {} init() {}
@ -408,6 +410,8 @@ struct SDLStunProbeReply: Sendable {
var cookie: UInt32 = 0 var cookie: UInt32 = 0
var step: UInt32 = 0
var port: UInt32 = 0 var port: UInt32 = 0
var ip: UInt32 = 0 var ip: UInt32 = 0
@ -1362,6 +1366,7 @@ extension SDLStunProbe: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementat
static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .same(proto: "cookie"), 1: .same(proto: "cookie"),
2: .same(proto: "attr"), 2: .same(proto: "attr"),
3: .same(proto: "step"),
] ]
mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws { mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
@ -1372,6 +1377,7 @@ extension SDLStunProbe: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementat
switch fieldNumber { switch fieldNumber {
case 1: try { try decoder.decodeSingularUInt32Field(value: &self.cookie) }() case 1: try { try decoder.decodeSingularUInt32Field(value: &self.cookie) }()
case 2: try { try decoder.decodeSingularUInt32Field(value: &self.attr) }() case 2: try { try decoder.decodeSingularUInt32Field(value: &self.attr) }()
case 3: try { try decoder.decodeSingularUInt32Field(value: &self.step) }()
default: break default: break
} }
} }
@ -1384,12 +1390,16 @@ extension SDLStunProbe: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementat
if self.attr != 0 { if self.attr != 0 {
try visitor.visitSingularUInt32Field(value: self.attr, fieldNumber: 2) try visitor.visitSingularUInt32Field(value: self.attr, fieldNumber: 2)
} }
if self.step != 0 {
try visitor.visitSingularUInt32Field(value: self.step, fieldNumber: 3)
}
try unknownFields.traverse(visitor: &visitor) try unknownFields.traverse(visitor: &visitor)
} }
static func ==(lhs: SDLStunProbe, rhs: SDLStunProbe) -> Bool { static func ==(lhs: SDLStunProbe, rhs: SDLStunProbe) -> Bool {
if lhs.cookie != rhs.cookie {return false} if lhs.cookie != rhs.cookie {return false}
if lhs.attr != rhs.attr {return false} if lhs.attr != rhs.attr {return false}
if lhs.step != rhs.step {return false}
if lhs.unknownFields != rhs.unknownFields {return false} if lhs.unknownFields != rhs.unknownFields {return false}
return true return true
} }
@ -1399,8 +1409,9 @@ extension SDLStunProbeReply: SwiftProtobuf.Message, SwiftProtobuf._MessageImplem
static let protoMessageName: String = "SDLStunProbeReply" static let protoMessageName: String = "SDLStunProbeReply"
static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .same(proto: "cookie"), 1: .same(proto: "cookie"),
2: .same(proto: "port"), 2: .same(proto: "step"),
3: .same(proto: "ip"), 3: .same(proto: "port"),
4: .same(proto: "ip"),
] ]
mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws { mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
@ -1410,8 +1421,9 @@ extension SDLStunProbeReply: SwiftProtobuf.Message, SwiftProtobuf._MessageImplem
// enabled. https://github.com/apple/swift-protobuf/issues/1034 // enabled. https://github.com/apple/swift-protobuf/issues/1034
switch fieldNumber { switch fieldNumber {
case 1: try { try decoder.decodeSingularUInt32Field(value: &self.cookie) }() case 1: try { try decoder.decodeSingularUInt32Field(value: &self.cookie) }()
case 2: try { try decoder.decodeSingularUInt32Field(value: &self.port) }() case 2: try { try decoder.decodeSingularUInt32Field(value: &self.step) }()
case 3: try { try decoder.decodeSingularUInt32Field(value: &self.ip) }() case 3: try { try decoder.decodeSingularUInt32Field(value: &self.port) }()
case 4: try { try decoder.decodeSingularUInt32Field(value: &self.ip) }()
default: break default: break
} }
} }
@ -1421,17 +1433,21 @@ extension SDLStunProbeReply: SwiftProtobuf.Message, SwiftProtobuf._MessageImplem
if self.cookie != 0 { if self.cookie != 0 {
try visitor.visitSingularUInt32Field(value: self.cookie, fieldNumber: 1) try visitor.visitSingularUInt32Field(value: self.cookie, fieldNumber: 1)
} }
if self.step != 0 {
try visitor.visitSingularUInt32Field(value: self.step, fieldNumber: 2)
}
if self.port != 0 { if self.port != 0 {
try visitor.visitSingularUInt32Field(value: self.port, fieldNumber: 2) try visitor.visitSingularUInt32Field(value: self.port, fieldNumber: 3)
} }
if self.ip != 0 { if self.ip != 0 {
try visitor.visitSingularUInt32Field(value: self.ip, fieldNumber: 3) try visitor.visitSingularUInt32Field(value: self.ip, fieldNumber: 4)
} }
try unknownFields.traverse(visitor: &visitor) try unknownFields.traverse(visitor: &visitor)
} }
static func ==(lhs: SDLStunProbeReply, rhs: SDLStunProbeReply) -> Bool { static func ==(lhs: SDLStunProbeReply, rhs: SDLStunProbeReply) -> Bool {
if lhs.cookie != rhs.cookie {return false} if lhs.cookie != rhs.cookie {return false}
if lhs.step != rhs.step {return false}
if lhs.port != rhs.port {return false} if lhs.port != rhs.port {return false}
if lhs.ip != rhs.ip {return false} if lhs.ip != rhs.ip {return false}
if lhs.unknownFields != rhs.unknownFields {return false} if lhs.unknownFields != rhs.unknownFields {return false}