diff --git a/Tun/Punchnet/Policy/IdentityRuleMap.swift b/Tun/Punchnet/Policy/IdentityRuleMap.swift index 52eebf4..f0bfd29 100644 --- a/Tun/Punchnet/Policy/IdentityRuleMap.swift +++ b/Tun/Punchnet/Policy/IdentityRuleMap.swift @@ -6,17 +6,17 @@ // struct IdentityRuleMap { - let ruleMap: [UInt32: [UInt32: UInt8]] + // map[proto][port] + let ruleMap: [UInt8: [UInt16: Bool]] - init(ruleMap: [UInt32: [UInt32: UInt8]]) { + init(ruleMap: [UInt8: [UInt16: Bool]]) { self.ruleMap = ruleMap } - func isAllow(proto: UInt32, port: UInt32) -> Bool { + func isAllow(proto: UInt8, port: UInt16) -> Bool { if let portMap = self.ruleMap[proto], - let allowed = portMap[port], - allowed > 0 { - return true + let allowed = portMap[port] { + return allowed } else { return false } diff --git a/Tun/Punchnet/Policy/IdentityStore.swift b/Tun/Punchnet/Policy/IdentityStore.swift index 6fff947..157135b 100644 --- a/Tun/Punchnet/Policy/IdentityStore.swift +++ b/Tun/Punchnet/Policy/IdentityStore.swift @@ -5,28 +5,106 @@ // Created by 安礼成 on 2026/2/5. // import Foundation +import NIO + +final class IdentitySession { + var version: UInt32 + var totalNum: UInt32 + private var parts: [UInt32: SDLPolicyResponse] = [:] + + init(part: SDLPolicyResponse) { + self.version = part.version + self.totalNum = part.totalNum + self.parts[part.index] = part + } + + func merge(part: SDLPolicyResponse) { + if part.version < version { + // 低版本数据丢弃 + } else if part.version == version { + self.parts[part.index] = part + } else { + self.parts.removeAll() + self.parts[part.index] = part + } + } + + func process() -> Data? { + // parts是连续的,从0开始,并且数量等于total_num + let indexs = parts.keys.sorted().map { UInt32($0) } + guard indexs.count == self.totalNum && isContinuousFromZero(indexs: indexs) else { + return nil + } + + var rulesData: Data = Data() + for i in 0.. Bool { + guard !indexs.isEmpty else { + return false + } + + return indexs.enumerated().allSatisfy { idx, value in + idx == value + } + } + +} actor IdentityStore { typealias IdentityID = UInt32 + nonisolated private let alloctor = ByteBufferAllocator() + private let publisher: SnapshotPublisher - private let identityMap: [IdentityID: IdentityRuleMap] = [:] + private var identityMap: [IdentityID: IdentityRuleMap] = [:] + private var sessions: [IdentityID: IdentitySession] = [:] init(publisher: SnapshotPublisher) { self.publisher = publisher } - func apply(_ id: IdentityID, ruleBytes: Data) { + func apply(policyResponse: SDLPolicyResponse) { + let id = policyResponse.srcIdentityID + let session = self.sessions[id, default: IdentitySession(part: policyResponse)] + session.merge(part: policyResponse) -// if model.affectsRuntime { -// let snapshot = compileSnapshot(from: model) -// publisher.publish(snapshot) -// } + // 判断一下是否接受完成 + if let rulesData = session.process() { + var buffer = alloctor.buffer(bytes: rulesData) + var ruleMap: [UInt8: [UInt16: Bool]] = [:] + while true { + guard let proto = buffer.readInteger(endianness: .big, as: UInt8.self), + let port = buffer.readInteger(endianness: .big, as: UInt16.self) else { + break + } + ruleMap[proto, default: [:]][port] = true + } + self.identityMap[id] = IdentityRuleMap(ruleMap: ruleMap) + + // 删除当前的session信息 + self.sessions.removeValue(forKey: id) + + SDLLogger.shared.log("[IdentitySession] get compile Snapshot rules nums: \(self.identityMap[id]?.ruleMap.count), success: \(self.identityMap[id]?.isAllow(proto: 1, port: 80))") + + // 发布新的快照信息 + let snapshot = compileSnapshot() + publisher.publish(snapshot) + } else { + self.sessions[id] = session + } } -// -// func compileSnapshot() -> IdentitySnapshot { -// -// } -// + private func compileSnapshot() -> IdentitySnapshot { + return IdentitySnapshot(identityMap: identityMap) + } + } diff --git a/Tun/Punchnet/Policy/PolicyRequesterActor.swift b/Tun/Punchnet/Policy/PolicyRequesterActor.swift new file mode 100644 index 0000000..4a0674f --- /dev/null +++ b/Tun/Punchnet/Policy/PolicyRequesterActor.swift @@ -0,0 +1,56 @@ +// +// SDLPuncherActor.swift +// Tun +// +// Created by 安礼成 on 2026/1/7. +// + +import Foundation +import NIOCore + +actor PolicyRequesterActor { + nonisolated private let cooldown: Duration = .seconds(5) + + // identityId + private var coolingDown: Set = [] + + // 处理各个请求的版本问题, map[identityId] = version + private var versions: [UInt32: UInt32] = [:] + // 处理holer + nonisolated private let querySocketAddress: SocketAddress + + init(querySocketAddress: SocketAddress) { + self.querySocketAddress = querySocketAddress + } + + // 提交权限请求 + func submitPolicyRequest(using udpHole: SDLUDPHole?, request: inout SDLPolicyRequest) { + let identityId = request.srcIdentityID + guard let udpHole, !coolingDown.contains(identityId) else { + return + } + + // 触发一次打洞 + coolingDown.insert(identityId) + + let version = self.versions[identityId, default: 1] + request.version = version + // 更新请求的版本问题 + self.versions[identityId] = version + 1 + // 发送请求 + if let queryData = try? request.serializedData() { + udpHole.send(type: .policyRequest, data: queryData, remoteAddress: self.querySocketAddress) + } + + Task { + // 启动冷却期 + try? await Task.sleep(for: .seconds(5)) + self.endCooldown(for: identityId) + } + } + + private func endCooldown(for key: UInt32) { + self.coolingDown.remove(key) + } + +} diff --git a/Tun/Punchnet/Policy/SnapshotPublisher.swift b/Tun/Punchnet/Policy/SnapshotPublisher.swift index 935809f..e4b7559 100644 --- a/Tun/Punchnet/Policy/SnapshotPublisher.swift +++ b/Tun/Punchnet/Policy/SnapshotPublisher.swift @@ -12,11 +12,6 @@ final class SnapshotPublisher { init(initial snapshot: IdentitySnapshot) { self.atomic = ManagedAtomic(.passRetained(snapshot)) } - - deinit { - let ref = atomic.load(ordering: .relaxed) - ref.release() - } func publish(_ snapshot: IdentitySnapshot) { let newRef = Unmanaged.passRetained(snapshot) @@ -29,4 +24,9 @@ final class SnapshotPublisher { atomic.load(ordering: .relaxed).takeUnretainedValue() } + deinit { + let ref = atomic.load(ordering: .relaxed) + ref.release() + } + } diff --git a/Tun/Punchnet/SDLConfiguration.swift b/Tun/Punchnet/SDLConfiguration.swift index 45d96f1..616b55f 100644 --- a/Tun/Punchnet/SDLConfiguration.swift +++ b/Tun/Punchnet/SDLConfiguration.swift @@ -81,6 +81,7 @@ public class SDLConfiguration { let networkAddress: NetworkAddress let hostname: String let accessToken: String + let identityId: UInt32 public init(version: UInt8, installedChannel: String, @@ -92,6 +93,7 @@ public class SDLConfiguration { hostname: String, noticePort: Int, accessToken: String, + identityId: UInt32, remoteDnsServer: String) { self.version = version @@ -103,6 +105,7 @@ public class SDLConfiguration { self.networkAddress = networkAddress self.noticePort = noticePort self.accessToken = accessToken + self.identityId = identityId self.remoteDnsServer = remoteDnsServer self.hostname = hostname } @@ -118,6 +121,7 @@ extension SDLConfiguration { let stunServersStr = options["stun_servers"] as? String, let noticePort = options["notice_port"] as? Int, let accessToken = options["access_token"] as? String, + let identityId = options["identity_id"] as? UInt32, let clientId = options["client_id"] as? String, let remoteDnsServer = options["remote_dns_server"] as? String, let hostname = options["hostname"] as? String, @@ -144,6 +148,7 @@ extension SDLConfiguration { hostname: hostname, noticePort: noticePort, accessToken: accessToken, + identityId: identityId, remoteDnsServer: remoteDnsServer) } diff --git a/Tun/Punchnet/SDLContextActor.swift b/Tun/Punchnet/SDLContextActor.swift index 2f41a4e..1fd536b 100644 --- a/Tun/Punchnet/SDLContextActor.swift +++ b/Tun/Punchnet/SDLContextActor.swift @@ -73,6 +73,7 @@ actor SDLContextActor { // 处理权限控制 private let identifyStore: IdentityStore private let snapshotPublisher: SnapshotPublisher + private let policyRequesterActor: PolicyRequesterActor public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher) { self.provider = provider @@ -90,6 +91,7 @@ actor SDLContextActor { let snapshotPublisher = SnapshotPublisher(initial: IdentitySnapshot.empty()) self.identifyStore = IdentityStore(publisher: snapshotPublisher) self.snapshotPublisher = snapshotPublisher + self.policyRequesterActor = PolicyRequesterActor(querySocketAddress: config.stunSocketAddress) } public func start() { @@ -241,6 +243,10 @@ actor SDLContextActor { try? await self.handleRegister(remoteAddress: remoteAddress, register: register) case .registerAck(let registerAck): await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) + case .policyReponse(let policyResponse): + SDLLogger.shared.log("[SDLContext] get a policyResponse: \(policyResponse.totalNum) of \(policyResponse.index), bytes: \(policyResponse.rules.count)") + // 处理权限的请求问题 + await self.identifyStore.apply(policyResponse: policyResponse) } } @@ -298,7 +304,6 @@ actor SDLContextActor { stunRequest.natType = UInt32(self.natType.rawValue) stunRequest.sessionToken = sessionToken - SDLLogger.shared.log("[SDLContext] send stun request: \(stunRequest)") if let stunData = try? stunRequest.serializedData() { let remoteAddress = self.config.stunSocketAddress self.udpHole?.send(type: .stunRequest, data: stunData, remoteAddress: remoteAddress) @@ -310,6 +315,8 @@ actor SDLContextActor { self.aesKey = try! self.rsaCipher.decode(data: Data(registerSuperAck.aesKey)) self.sessionToken = registerSuperAck.sessionToken + await self.triggerPolicy() + SDLLogger.shared.log("[SDLContext] get registerSuperAck, aes_key len: \(self.aesKey!.count)", level: .info) // 服务器分配的tun网卡信息 do { @@ -496,15 +503,24 @@ actor SDLContextActor { self.provider.packetFlow.writePacketObjects([packet]) } } else { - // todo 向服务器请求权限逻辑 - + // 向服务器请求权限逻辑 + if let sessionToken = self.sessionToken { + var policyRequest = SDLPolicyRequest() + policyRequest.clientID = self.config.clientId + policyRequest.networkID = self.config.networkAddress.networkId + policyRequest.mac = self.config.networkAddress.mac + policyRequest.srcIdentityID = data.identityID + policyRequest.dstIdentityID = self.config.identityId + policyRequest.sessionToken = sessionToken + + await self.policyRequesterActor.submitPolicyRequest(using: self.udpHole, request: &policyRequest) + } } default: SDLLogger.shared.log("[SDLContext] get invalid packet", level: .debug) } } - // 流量统计 // public func flowReportTask() { // Task { @@ -585,6 +601,7 @@ actor SDLContextActor { dataPacket.srcMac = networkAddr.mac dataPacket.dstMac = dstMac dataPacket.ttl = 255 + dataPacket.identityID = self.config.identityId dataPacket.data = encodedPacket let data = try! dataPacket.serializedData() @@ -665,6 +682,22 @@ actor SDLContextActor { } } + // todo 测试代码 + private func triggerPolicy() async { + // 向服务器请求权限逻辑 + if let sessionToken = self.sessionToken { + var policyRequest = SDLPolicyRequest() + policyRequest.clientID = self.config.clientId + policyRequest.networkID = self.config.networkAddress.networkId + policyRequest.mac = self.config.networkAddress.mac + policyRequest.srcIdentityID = 1234 + policyRequest.dstIdentityID = self.config.identityId + policyRequest.sessionToken = sessionToken + + await self.policyRequesterActor.submitPolicyRequest(using: self.udpHole, request: &policyRequest) + } + } + deinit { self.udpHole = nil self.dnsClient = nil diff --git a/Tun/Punchnet/SDLMessage.pb.swift b/Tun/Punchnet/SDLMessage.pb.swift index b8f5d46..cf6bda9 100644 --- a/Tun/Punchnet/SDLMessage.pb.swift +++ b/Tun/Punchnet/SDLMessage.pb.swift @@ -467,12 +467,18 @@ struct SDLPolicyRequest: @unchecked Sendable { // `Message` and `Message+*Additions` files in the SwiftProtobuf library for // methods supported on all messages. + var clientID: String = String() + var networkID: UInt32 = 0 + var mac: Data = Data() + var srcIdentityID: UInt32 = 0 var dstIdentityID: UInt32 = 0 + var version: UInt32 = 0 + var sessionToken: Data = Data() var unknownFields = SwiftProtobuf.UnknownStorage() @@ -1589,10 +1595,13 @@ extension SDLArpResponse: SwiftProtobuf.Message, SwiftProtobuf._MessageImplement extension SDLPolicyRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding { static let protoMessageName: String = "SDLPolicyRequest" static let _protobuf_nameMap: SwiftProtobuf._NameMap = [ - 1: .standard(proto: "network_id"), - 2: .standard(proto: "src_identity_id"), - 3: .standard(proto: "dst_identity_id"), - 4: .standard(proto: "session_token"), + 1: .standard(proto: "client_id"), + 2: .standard(proto: "network_id"), + 3: .same(proto: "mac"), + 4: .standard(proto: "src_identity_id"), + 5: .standard(proto: "dst_identity_id"), + 6: .same(proto: "version"), + 7: .standard(proto: "session_token"), ] mutating func decodeMessage(decoder: inout D) throws { @@ -1601,35 +1610,50 @@ extension SDLPolicyRequest: SwiftProtobuf.Message, SwiftProtobuf._MessageImpleme // allocates stack space for every case branch when no optimizations are // enabled. https://github.com/apple/swift-protobuf/issues/1034 switch fieldNumber { - case 1: try { try decoder.decodeSingularUInt32Field(value: &self.networkID) }() - case 2: try { try decoder.decodeSingularUInt32Field(value: &self.srcIdentityID) }() - case 3: try { try decoder.decodeSingularUInt32Field(value: &self.dstIdentityID) }() - case 4: try { try decoder.decodeSingularBytesField(value: &self.sessionToken) }() + case 1: try { try decoder.decodeSingularStringField(value: &self.clientID) }() + case 2: try { try decoder.decodeSingularUInt32Field(value: &self.networkID) }() + case 3: try { try decoder.decodeSingularBytesField(value: &self.mac) }() + case 4: try { try decoder.decodeSingularUInt32Field(value: &self.srcIdentityID) }() + case 5: try { try decoder.decodeSingularUInt32Field(value: &self.dstIdentityID) }() + case 6: try { try decoder.decodeSingularUInt32Field(value: &self.version) }() + case 7: try { try decoder.decodeSingularBytesField(value: &self.sessionToken) }() default: break } } } func traverse(visitor: inout V) throws { + if !self.clientID.isEmpty { + try visitor.visitSingularStringField(value: self.clientID, fieldNumber: 1) + } if self.networkID != 0 { - try visitor.visitSingularUInt32Field(value: self.networkID, fieldNumber: 1) + try visitor.visitSingularUInt32Field(value: self.networkID, fieldNumber: 2) + } + if !self.mac.isEmpty { + try visitor.visitSingularBytesField(value: self.mac, fieldNumber: 3) } if self.srcIdentityID != 0 { - try visitor.visitSingularUInt32Field(value: self.srcIdentityID, fieldNumber: 2) + try visitor.visitSingularUInt32Field(value: self.srcIdentityID, fieldNumber: 4) } if self.dstIdentityID != 0 { - try visitor.visitSingularUInt32Field(value: self.dstIdentityID, fieldNumber: 3) + try visitor.visitSingularUInt32Field(value: self.dstIdentityID, fieldNumber: 5) + } + if self.version != 0 { + try visitor.visitSingularUInt32Field(value: self.version, fieldNumber: 6) } if !self.sessionToken.isEmpty { - try visitor.visitSingularBytesField(value: self.sessionToken, fieldNumber: 4) + try visitor.visitSingularBytesField(value: self.sessionToken, fieldNumber: 7) } try unknownFields.traverse(visitor: &visitor) } static func ==(lhs: SDLPolicyRequest, rhs: SDLPolicyRequest) -> Bool { + if lhs.clientID != rhs.clientID {return false} if lhs.networkID != rhs.networkID {return false} + if lhs.mac != rhs.mac {return false} if lhs.srcIdentityID != rhs.srcIdentityID {return false} if lhs.dstIdentityID != rhs.dstIdentityID {return false} + if lhs.version != rhs.version {return false} if lhs.sessionToken != rhs.sessionToken {return false} if lhs.unknownFields != rhs.unknownFields {return false} return true diff --git a/Tun/Punchnet/SDLMessage.swift b/Tun/Punchnet/SDLMessage.swift index 64d047b..6d924c8 100644 --- a/Tun/Punchnet/SDLMessage.swift +++ b/Tun/Punchnet/SDLMessage.swift @@ -32,6 +32,10 @@ enum SDLPacketType: UInt8 { case stunProbe = 0x32 case stunProbeReply = 0x33 + // 权限控制 + case policyRequest = 0xb0 + case policyResponse = 0xb1 + case data = 0xFF } @@ -103,6 +107,8 @@ enum SDLHoleSignal { case register(SDLRegister) case registerAck(SDLRegisterAck) + + case policyReponse(SDLPolicyResponse) } // 命令类型 diff --git a/Tun/Punchnet/SDLUDPHole.swift b/Tun/Punchnet/SDLUDPHole.swift index cbc8c21..e3c3a2a 100644 --- a/Tun/Punchnet/SDLUDPHole.swift +++ b/Tun/Punchnet/SDLUDPHole.swift @@ -137,9 +137,11 @@ final class SDLUDPHole: ChannelInboundHandler { // --MARK: 编解码器 private func decode(buffer: inout ByteBuffer) throws -> SDLHoleMessage? { + let rawType = buffer.getInteger(at: 0, endianness: .big, as: UInt8.self) + guard let type = buffer.readInteger(as: UInt8.self), let packetType = SDLPacketType(rawValue: type) else { - SDLLogger.shared.log("[SDLUDPHole] decode error 11") + SDLLogger.shared.log("[SDLUDPHole] decode error 11: \(rawType)") return nil } @@ -186,6 +188,12 @@ final class SDLUDPHole: ChannelInboundHandler { return nil } return .signal(.peerInfo(peerInfo)) + case .policyResponse: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let policyResponse = try? SDLPolicyResponse(serializedBytes: bytes) else { + return nil + } + return .signal(.policyReponse(policyResponse)) case .event: guard let eventVal = buffer.readInteger(as: UInt8.self), let event = SDLEventType(rawValue: eventVal), @@ -220,7 +228,6 @@ final class SDLUDPHole: ChannelInboundHandler { return nil } return .signal(.event(.refreshAuth(refreshAuthEvent))) - case .networkShutdown: guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else { SDLLogger.shared.log("[SDLUDPHole] decode error 18") diff --git a/punchnet/Core/SystemConfig.swift b/punchnet/Core/SystemConfig.swift index e36dc18..583e7ac 100644 --- a/punchnet/Core/SystemConfig.swift +++ b/punchnet/Core/SystemConfig.swift @@ -26,7 +26,7 @@ struct SystemConfig { static let stunServers = "118.178.229.213:1365,1366;118.178.229.213:1365,1366" //static let stunServers = "127.0.0.1:1265,1266;127.0.0.1:1265,1266" - static func getOptions(networkId: UInt32, networkDomain: String, ip: String, maskLen: UInt8, accessToken: String, hostname: String, noticePort: Int) -> [String: NSObject]? { + static func getOptions(networkId: UInt32, networkDomain: String, ip: String, maskLen: UInt8, accessToken: String, identityId: UInt32, hostname: String, noticePort: Int) -> [String: NSObject]? { guard let superIp = DNSResolver.resolveAddrInfos(superHost).first else { return nil } @@ -39,6 +39,7 @@ struct SystemConfig { "installed_channel": installedChannel as NSObject, "client_id": clientId as NSObject, "access_token": accessToken as NSObject, + "identity_id": identityId as NSObject, "super_ip": superIp as NSObject, "super_port": superPort as NSObject, "stun_servers": stunServers as NSObject, diff --git a/punchnet/Views/Network/NetworkDisconnctedView.swift b/punchnet/Views/Network/NetworkDisconnctedView.swift index 6b16dc3..bb9db6b 100644 --- a/punchnet/Views/Network/NetworkDisconnctedView.swift +++ b/punchnet/Views/Network/NetworkDisconnctedView.swift @@ -61,6 +61,7 @@ struct NetworkDisconnctedView: View { ip: "10.211.179.1", maskLen: 24, accessToken: "accessToken1234", + identityId: 1234, hostname: "mysql", noticePort: 1234) // token存在则优先使用token