fix policy

This commit is contained in:
anlicheng 2026-02-05 15:25:22 +08:00
parent e79c3270ea
commit 3947c1f6da
6 changed files with 304 additions and 3 deletions

View File

@ -0,0 +1,25 @@
//
// RuleMap.swift
// punchnet
//
// Created by on 2026/2/5.
//
struct IdentityRuleMap {
let ruleMap: [UInt32: [UInt32: UInt8]]
init(ruleMap: [UInt32: [UInt32: UInt8]]) {
self.ruleMap = ruleMap
}
func isAllow(proto: UInt32, port: UInt32) -> Bool {
if let portMap = self.ruleMap[proto],
let allowed = portMap[port],
allowed > 0 {
return true
} else {
return false
}
}
}

View File

@ -0,0 +1,25 @@
//
// IdentitySnapshot.swift
// punchnet
//
// Created by on 2026/2/5.
//
final class IdentitySnapshot {
typealias IdentityID = UInt32
private let identityMap: [IdentityID: IdentityRuleMap]
init(identityMap: [IdentityID : IdentityRuleMap]) {
self.identityMap = identityMap
}
func lookup(_ id: IdentityID) -> IdentityRuleMap? {
return self.identityMap[id]
}
static func empty() -> IdentitySnapshot {
return IdentitySnapshot(identityMap: [:])
}
}

View File

@ -0,0 +1,32 @@
//
// IdentityStore.swift
// punchnet
//
// Created by on 2026/2/5.
//
import Foundation
actor IdentityStore {
typealias IdentityID = UInt32
private let publisher: SnapshotPublisher<IdentitySnapshot>
private let identityMap: [IdentityID: IdentityRuleMap] = [:]
init(publisher: SnapshotPublisher<IdentitySnapshot>) {
self.publisher = publisher
}
func apply(_ id: IdentityID, ruleBytes: Data) {
// if model.affectsRuntime {
// let snapshot = compileSnapshot(from: model)
// publisher.publish(snapshot)
// }
}
//
// func compileSnapshot() -> IdentitySnapshot {
//
// }
//
}

View File

@ -0,0 +1,32 @@
//
// SnapshotPublisher.swift
// punchnet
//
// Created by on 2026/2/5.
//
import Atomics
final class SnapshotPublisher<IdentitySnapshot: AnyObject> {
private let atomic: ManagedAtomic<Unmanaged<IdentitySnapshot>>
init(initial snapshot: IdentitySnapshot) {
self.atomic = ManagedAtomic(.passRetained(snapshot))
}
deinit {
let ref = atomic.load(ordering: .relaxed)
ref.release()
}
func publish(_ snapshot: IdentitySnapshot) {
let newRef = Unmanaged.passRetained(snapshot)
let oldRef = atomic.exchange(newRef, ordering: .acquiring)
oldRef.release()
}
@inline(__always)
func current() -> IdentitySnapshot {
atomic.load(ordering: .relaxed).takeUnretainedValue()
}
}

View File

@ -68,9 +68,12 @@ actor SDLContextActor {
//
private var loopChildWorkers: [Task<Void, Never>] = []
private let provider: NEPacketTunnelProvider
//
private let identifyStore: IdentityStore
private let snapshotPublisher: SnapshotPublisher<IdentitySnapshot>
public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher) {
self.provider = provider
self.config = config
@ -82,6 +85,11 @@ actor SDLContextActor {
self.puncherActor = SDLPuncherActor(querySocketAddress: config.stunSocketAddress)
self.proberActor = SDLNATProberActor(addressArray: config.stunProbeSocketAddressArray)
//
let snapshotPublisher = SnapshotPublisher(initial: IdentitySnapshot.empty())
self.identifyStore = IdentityStore(publisher: snapshotPublisher)
self.snapshotPublisher = snapshotPublisher
}
public func start() {
@ -479,8 +487,18 @@ actor SDLContextActor {
guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == networkAddr.ip else {
return
}
let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
self.provider.packetFlow.writePacketObjects([packet])
//
let identitySnapshot = self.snapshotPublisher.current()
if let ruleMap = identitySnapshot.lookup(data.identityID) {
if ruleMap.isAllow(proto: 2, port: 3) {
let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
self.provider.packetFlow.writePacketObjects([packet])
}
} else {
// todo
}
default:
SDLLogger.shared.log("[SDLContext] get invalid packet", level: .debug)
}

View File

@ -74,6 +74,8 @@ struct SDLRegisterSuper: @unchecked Sendable {
var clientID: String = String()
/// https
/// (network_id, mac, ip, mask_len, hostname)
var networkID: UInt32 = 0
var mac: Data = Data()
@ -398,6 +400,7 @@ struct SDLStunProbe: Sendable {
var attr: UInt32 = 0
/// step便
var step: UInt32 = 0
var unknownFields = SwiftProtobuf.UnknownStorage()
@ -412,6 +415,7 @@ struct SDLStunProbeReply: Sendable {
var cookie: UInt32 = 0
/// step便
var step: UInt32 = 0
var port: UInt32 = 0
@ -457,6 +461,53 @@ struct SDLArpResponse: @unchecked Sendable {
init() {}
}
///
struct SDLPolicyRequest: @unchecked Sendable {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.
var networkID: UInt32 = 0
var srcIdentityID: UInt32 = 0
var dstIdentityID: UInt32 = 0
var sessionToken: Data = Data()
var unknownFields = SwiftProtobuf.UnknownStorage()
init() {}
}
struct SDLPolicyResponse: @unchecked Sendable {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.
var networkID: UInt32 = 0
var srcIdentityID: UInt32 = 0
var dstIdentityID: UInt32 = 0
///
var version: UInt32 = 0
///
var totalNum: UInt32 = 0
///
var index: UInt32 = 0
/// 4+1+2
var rules: Data = Data()
var unknownFields = SwiftProtobuf.UnknownStorage()
init() {}
}
// MARK: - Code below here is support for the SwiftProtobuf runtime.
extension SDLV4Info: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
@ -1534,3 +1585,121 @@ extension SDLArpResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplement
return true
}
}
extension SDLPolicyRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
static let protoMessageName: String = "SDLPolicyRequest"
static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .standard(proto: "network_id"),
2: .standard(proto: "src_identity_id"),
3: .standard(proto: "dst_identity_id"),
4: .standard(proto: "session_token"),
]
mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
while let fieldNumber = try decoder.nextFieldNumber() {
// The use of inline closures is to circumvent an issue where the compiler
// allocates stack space for every case branch when no optimizations are
// enabled. https://github.com/apple/swift-protobuf/issues/1034
switch fieldNumber {
case 1: try { try decoder.decodeSingularUInt32Field(value: &self.networkID) }()
case 2: try { try decoder.decodeSingularUInt32Field(value: &self.srcIdentityID) }()
case 3: try { try decoder.decodeSingularUInt32Field(value: &self.dstIdentityID) }()
case 4: try { try decoder.decodeSingularBytesField(value: &self.sessionToken) }()
default: break
}
}
}
func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
if self.networkID != 0 {
try visitor.visitSingularUInt32Field(value: self.networkID, fieldNumber: 1)
}
if self.srcIdentityID != 0 {
try visitor.visitSingularUInt32Field(value: self.srcIdentityID, fieldNumber: 2)
}
if self.dstIdentityID != 0 {
try visitor.visitSingularUInt32Field(value: self.dstIdentityID, fieldNumber: 3)
}
if !self.sessionToken.isEmpty {
try visitor.visitSingularBytesField(value: self.sessionToken, fieldNumber: 4)
}
try unknownFields.traverse(visitor: &visitor)
}
static func ==(lhs: SDLPolicyRequest, rhs: SDLPolicyRequest) -> Bool {
if lhs.networkID != rhs.networkID {return false}
if lhs.srcIdentityID != rhs.srcIdentityID {return false}
if lhs.dstIdentityID != rhs.dstIdentityID {return false}
if lhs.sessionToken != rhs.sessionToken {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}
}
extension SDLPolicyResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
static let protoMessageName: String = "SDLPolicyResponse"
static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
1: .standard(proto: "network_id"),
2: .standard(proto: "src_identity_id"),
3: .standard(proto: "dst_identity_id"),
4: .same(proto: "version"),
5: .standard(proto: "total_num"),
6: .same(proto: "index"),
7: .same(proto: "rules"),
]
mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
while let fieldNumber = try decoder.nextFieldNumber() {
// The use of inline closures is to circumvent an issue where the compiler
// allocates stack space for every case branch when no optimizations are
// enabled. https://github.com/apple/swift-protobuf/issues/1034
switch fieldNumber {
case 1: try { try decoder.decodeSingularUInt32Field(value: &self.networkID) }()
case 2: try { try decoder.decodeSingularUInt32Field(value: &self.srcIdentityID) }()
case 3: try { try decoder.decodeSingularUInt32Field(value: &self.dstIdentityID) }()
case 4: try { try decoder.decodeSingularUInt32Field(value: &self.version) }()
case 5: try { try decoder.decodeSingularUInt32Field(value: &self.totalNum) }()
case 6: try { try decoder.decodeSingularUInt32Field(value: &self.index) }()
case 7: try { try decoder.decodeSingularBytesField(value: &self.rules) }()
default: break
}
}
}
func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
if self.networkID != 0 {
try visitor.visitSingularUInt32Field(value: self.networkID, fieldNumber: 1)
}
if self.srcIdentityID != 0 {
try visitor.visitSingularUInt32Field(value: self.srcIdentityID, fieldNumber: 2)
}
if self.dstIdentityID != 0 {
try visitor.visitSingularUInt32Field(value: self.dstIdentityID, fieldNumber: 3)
}
if self.version != 0 {
try visitor.visitSingularUInt32Field(value: self.version, fieldNumber: 4)
}
if self.totalNum != 0 {
try visitor.visitSingularUInt32Field(value: self.totalNum, fieldNumber: 5)
}
if self.index != 0 {
try visitor.visitSingularUInt32Field(value: self.index, fieldNumber: 6)
}
if !self.rules.isEmpty {
try visitor.visitSingularBytesField(value: self.rules, fieldNumber: 7)
}
try unknownFields.traverse(visitor: &visitor)
}
static func ==(lhs: SDLPolicyResponse, rhs: SDLPolicyResponse) -> Bool {
if lhs.networkID != rhs.networkID {return false}
if lhs.srcIdentityID != rhs.srcIdentityID {return false}
if lhs.dstIdentityID != rhs.dstIdentityID {return false}
if lhs.version != rhs.version {return false}
if lhs.totalNum != rhs.totalNum {return false}
if lhs.index != rhs.index {return false}
if lhs.rules != rhs.rules {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}
}