diff --git a/Tun/Punchnet/Policy/IdentityRuleMap.swift b/Tun/Punchnet/Policy/IdentityRuleMap.swift new file mode 100644 index 0000000..52eebf4 --- /dev/null +++ b/Tun/Punchnet/Policy/IdentityRuleMap.swift @@ -0,0 +1,25 @@ +// +// RuleMap.swift +// punchnet +// +// Created by 安礼成 on 2026/2/5. +// + +struct IdentityRuleMap { + let ruleMap: [UInt32: [UInt32: UInt8]] + + init(ruleMap: [UInt32: [UInt32: UInt8]]) { + self.ruleMap = ruleMap + } + + func isAllow(proto: UInt32, port: UInt32) -> Bool { + if let portMap = self.ruleMap[proto], + let allowed = portMap[port], + allowed > 0 { + return true + } else { + return false + } + } + +} diff --git a/Tun/Punchnet/Policy/IdentitySnapshot.swift b/Tun/Punchnet/Policy/IdentitySnapshot.swift new file mode 100644 index 0000000..2df5f3f --- /dev/null +++ b/Tun/Punchnet/Policy/IdentitySnapshot.swift @@ -0,0 +1,25 @@ +// +// IdentitySnapshot.swift +// punchnet +// +// Created by 安礼成 on 2026/2/5. +// + +final class IdentitySnapshot { + typealias IdentityID = UInt32 + + private let identityMap: [IdentityID: IdentityRuleMap] + + init(identityMap: [IdentityID : IdentityRuleMap]) { + self.identityMap = identityMap + } + + func lookup(_ id: IdentityID) -> IdentityRuleMap? { + return self.identityMap[id] + } + + static func empty() -> IdentitySnapshot { + return IdentitySnapshot(identityMap: [:]) + } + +} diff --git a/Tun/Punchnet/Policy/IdentityStore.swift b/Tun/Punchnet/Policy/IdentityStore.swift new file mode 100644 index 0000000..6fff947 --- /dev/null +++ b/Tun/Punchnet/Policy/IdentityStore.swift @@ -0,0 +1,32 @@ +// +// IdentityStore.swift +// punchnet +// +// Created by 安礼成 on 2026/2/5. +// +import Foundation + +actor IdentityStore { + typealias IdentityID = UInt32 + + private let publisher: SnapshotPublisher + private let identityMap: [IdentityID: IdentityRuleMap] = [:] + + init(publisher: SnapshotPublisher) { + self.publisher = publisher + } + + func apply(_ id: IdentityID, ruleBytes: Data) { + +// if model.affectsRuntime { +// let snapshot = compileSnapshot(from: model) +// publisher.publish(snapshot) +// } + } + +// +// func compileSnapshot() -> IdentitySnapshot { +// +// } +// +} diff --git a/Tun/Punchnet/Policy/SnapshotPublisher.swift b/Tun/Punchnet/Policy/SnapshotPublisher.swift new file mode 100644 index 0000000..935809f --- /dev/null +++ b/Tun/Punchnet/Policy/SnapshotPublisher.swift @@ -0,0 +1,32 @@ +// +// SnapshotPublisher.swift +// punchnet +// +// Created by 安礼成 on 2026/2/5. +// +import Atomics + +final class SnapshotPublisher { + private let atomic: ManagedAtomic> + + init(initial snapshot: IdentitySnapshot) { + self.atomic = ManagedAtomic(.passRetained(snapshot)) + } + + deinit { + let ref = atomic.load(ordering: .relaxed) + ref.release() + } + + func publish(_ snapshot: IdentitySnapshot) { + let newRef = Unmanaged.passRetained(snapshot) + let oldRef = atomic.exchange(newRef, ordering: .acquiring) + oldRef.release() + } + + @inline(__always) + func current() -> IdentitySnapshot { + atomic.load(ordering: .relaxed).takeUnretainedValue() + } + +} diff --git a/Tun/Punchnet/SDLContextActor.swift b/Tun/Punchnet/SDLContextActor.swift index ade4f52..2f41a4e 100644 --- a/Tun/Punchnet/SDLContextActor.swift +++ b/Tun/Punchnet/SDLContextActor.swift @@ -68,9 +68,12 @@ actor SDLContextActor { // 处理内部的需要长时间运行的任务 private var loopChildWorkers: [Task] = [] - private let provider: NEPacketTunnelProvider + // 处理权限控制 + private let identifyStore: IdentityStore + private let snapshotPublisher: SnapshotPublisher + public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher) { self.provider = provider self.config = config @@ -82,6 +85,11 @@ actor SDLContextActor { self.puncherActor = SDLPuncherActor(querySocketAddress: config.stunSocketAddress) self.proberActor = SDLNATProberActor(addressArray: config.stunProbeSocketAddressArray) + + // 权限控制 + let snapshotPublisher = SnapshotPublisher(initial: IdentitySnapshot.empty()) + self.identifyStore = IdentityStore(publisher: snapshotPublisher) + self.snapshotPublisher = snapshotPublisher } public func start() { @@ -479,8 +487,18 @@ actor SDLContextActor { guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == networkAddr.ip else { return } - let packet = NEPacket(data: ipPacket.data, protocolFamily: 2) - self.provider.packetFlow.writePacketObjects([packet]) + + // 检查权限逻辑 + let identitySnapshot = self.snapshotPublisher.current() + if let ruleMap = identitySnapshot.lookup(data.identityID) { + if ruleMap.isAllow(proto: 2, port: 3) { + let packet = NEPacket(data: ipPacket.data, protocolFamily: 2) + self.provider.packetFlow.writePacketObjects([packet]) + } + } else { + // todo 向服务器请求权限逻辑 + + } default: SDLLogger.shared.log("[SDLContext] get invalid packet", level: .debug) } diff --git a/Tun/Punchnet/SDLMessage.pb.swift b/Tun/Punchnet/SDLMessage.pb.swift index 0cd6b90..b8f5d46 100644 --- a/Tun/Punchnet/SDLMessage.pb.swift +++ b/Tun/Punchnet/SDLMessage.pb.swift @@ -74,6 +74,8 @@ struct SDLRegisterSuper: @unchecked Sendable { var clientID: String = String() + /// 网络地址信息已经有https请求分配了 + /// 注册的时候需要带上(network_id, mac, ip, mask_len, hostname) var networkID: UInt32 = 0 var mac: Data = Data() @@ -398,6 +400,7 @@ struct SDLStunProbe: Sendable { var attr: UInt32 = 0 + /// 增加step是为了方便端上判断,收到的请求和响应之间的映射关系;服务器端原样返回 var step: UInt32 = 0 var unknownFields = SwiftProtobuf.UnknownStorage() @@ -412,6 +415,7 @@ struct SDLStunProbeReply: Sendable { var cookie: UInt32 = 0 + /// 增加step是为了方便端上判断,收到的请求和响应之间的映射关系;服务器端原样返回 var step: UInt32 = 0 var port: UInt32 = 0 @@ -457,6 +461,53 @@ struct SDLArpResponse: @unchecked Sendable { init() {} } +/// 权限请求查询相关 +struct SDLPolicyRequest: @unchecked Sendable { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + var networkID: UInt32 = 0 + + var srcIdentityID: UInt32 = 0 + + var dstIdentityID: UInt32 = 0 + + var sessionToken: Data = Data() + + var unknownFields = SwiftProtobuf.UnknownStorage() + + init() {} +} + +struct SDLPolicyResponse: @unchecked Sendable { + // SwiftProtobuf.Message conformance is added in an extension below. See the + // `Message` and `Message+*Additions` files in the SwiftProtobuf library for + // methods supported on all messages. + + var networkID: UInt32 = 0 + + var srcIdentityID: UInt32 = 0 + + var dstIdentityID: UInt32 = 0 + + /// 版本号,客户端需要比较版本号确定是否覆盖 + var version: UInt32 = 0 + + /// 总包数 + var totalNum: UInt32 = 0 + + /// 当前分片 + var index: UInt32 = 0 + + /// 4+1+2 的稀疏序列化规则 + var rules: Data = Data() + + var unknownFields = SwiftProtobuf.UnknownStorage() + + init() {} +} + // MARK: - Code below here is support for the SwiftProtobuf runtime. extension SDLV4Info: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { @@ -1534,3 +1585,121 @@ extension SDLArpResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplement return true } } + +extension SDLPolicyRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + static let protoMessageName: String = "SDLPolicyRequest" + static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 1: .standard(proto: "network_id"), + 2: .standard(proto: "src_identity_id"), + 3: .standard(proto: "dst_identity_id"), + 4: .standard(proto: "session_token"), + ] + + mutating func decodeMessage(decoder: inout D) throws { + while let fieldNumber = try decoder.nextFieldNumber() { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every case branch when no optimizations are + // enabled. https://github.com/apple/swift-protobuf/issues/1034 + switch fieldNumber { + case 1: try { try decoder.decodeSingularUInt32Field(value: &self.networkID) }() + case 2: try { try decoder.decodeSingularUInt32Field(value: &self.srcIdentityID) }() + case 3: try { try decoder.decodeSingularUInt32Field(value: &self.dstIdentityID) }() + case 4: try { try decoder.decodeSingularBytesField(value: &self.sessionToken) }() + default: break + } + } + } + + func traverse(visitor: inout V) throws { + if self.networkID != 0 { + try visitor.visitSingularUInt32Field(value: self.networkID, fieldNumber: 1) + } + if self.srcIdentityID != 0 { + try visitor.visitSingularUInt32Field(value: self.srcIdentityID, fieldNumber: 2) + } + if self.dstIdentityID != 0 { + try visitor.visitSingularUInt32Field(value: self.dstIdentityID, fieldNumber: 3) + } + if !self.sessionToken.isEmpty { + try visitor.visitSingularBytesField(value: self.sessionToken, fieldNumber: 4) + } + try unknownFields.traverse(visitor: &visitor) + } + + static func ==(lhs: SDLPolicyRequest, rhs: SDLPolicyRequest) -> Bool { + if lhs.networkID != rhs.networkID {return false} + if lhs.srcIdentityID != rhs.srcIdentityID {return false} + if lhs.dstIdentityID != rhs.dstIdentityID {return false} + if lhs.sessionToken != rhs.sessionToken {return false} + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +} + +extension SDLPolicyResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { + static let protoMessageName: String = "SDLPolicyResponse" + static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ + 1: .standard(proto: "network_id"), + 2: .standard(proto: "src_identity_id"), + 3: .standard(proto: "dst_identity_id"), + 4: .same(proto: "version"), + 5: .standard(proto: "total_num"), + 6: .same(proto: "index"), + 7: .same(proto: "rules"), + ] + + mutating func decodeMessage(decoder: inout D) throws { + while let fieldNumber = try decoder.nextFieldNumber() { + // The use of inline closures is to circumvent an issue where the compiler + // allocates stack space for every case branch when no optimizations are + // enabled. https://github.com/apple/swift-protobuf/issues/1034 + switch fieldNumber { + case 1: try { try decoder.decodeSingularUInt32Field(value: &self.networkID) }() + case 2: try { try decoder.decodeSingularUInt32Field(value: &self.srcIdentityID) }() + case 3: try { try decoder.decodeSingularUInt32Field(value: &self.dstIdentityID) }() + case 4: try { try decoder.decodeSingularUInt32Field(value: &self.version) }() + case 5: try { try decoder.decodeSingularUInt32Field(value: &self.totalNum) }() + case 6: try { try decoder.decodeSingularUInt32Field(value: &self.index) }() + case 7: try { try decoder.decodeSingularBytesField(value: &self.rules) }() + default: break + } + } + } + + func traverse(visitor: inout V) throws { + if self.networkID != 0 { + try visitor.visitSingularUInt32Field(value: self.networkID, fieldNumber: 1) + } + if self.srcIdentityID != 0 { + try visitor.visitSingularUInt32Field(value: self.srcIdentityID, fieldNumber: 2) + } + if self.dstIdentityID != 0 { + try visitor.visitSingularUInt32Field(value: self.dstIdentityID, fieldNumber: 3) + } + if self.version != 0 { + try visitor.visitSingularUInt32Field(value: self.version, fieldNumber: 4) + } + if self.totalNum != 0 { + try visitor.visitSingularUInt32Field(value: self.totalNum, fieldNumber: 5) + } + if self.index != 0 { + try visitor.visitSingularUInt32Field(value: self.index, fieldNumber: 6) + } + if !self.rules.isEmpty { + try visitor.visitSingularBytesField(value: self.rules, fieldNumber: 7) + } + try unknownFields.traverse(visitor: &visitor) + } + + static func ==(lhs: SDLPolicyResponse, rhs: SDLPolicyResponse) -> Bool { + if lhs.networkID != rhs.networkID {return false} + if lhs.srcIdentityID != rhs.srcIdentityID {return false} + if lhs.dstIdentityID != rhs.dstIdentityID {return false} + if lhs.version != rhs.version {return false} + if lhs.totalNum != rhs.totalNum {return false} + if lhs.index != rhs.index {return false} + if lhs.rules != rhs.rules {return false} + if lhs.unknownFields != rhs.unknownFields {return false} + return true + } +}