//
//  SDLContext.swift
//  Tun
//
//  Created by 安礼成 on 2024/2/29.
//

import Foundation
import NetworkExtension
import NIOCore
import Combine

// Shared context for the tunnel: configuration, crypto, sessions, hole punching
// and the bridge between the tun interface (packetFlow) and the UDP/super-node links.
/*
 1. Handles the RSA encryption/decryption logic
 */
class SDLContext {

    // Routing entry installed into the tunnel's IPv4 settings.
    struct Route {
        let dstAddress: String
        let subnetMask: String

        var debugInfo: String {
            return "\(dstAddress):\(subnetMask)"
        }
    }

    // Configuration of the client.
    final class Configuration {
        struct StunServer {
            let host: String
            let ports: [Int]
        }

        // Current client version.
        let version: UInt8
        // Installation channel.
        let installedChannel: String

        let superHost: String
        let superPort: Int

        let stunServers: [StunServer]

        // NOTE(review): crashes (`try!`, `[0]`) if there is no STUN server configured or the
        // host cannot be resolved — confirm callers always provide a resolvable server.
        lazy var stunSocketAddress: SocketAddress = {
            let stunServer = stunServers[0]
            return try! SocketAddress.makeAddressResolvingHost(stunServer.host, port: stunServer.ports[0])
        }()

        // Probe addresses used for NAT type detection: for every STUN server, its first two ports.
        // NOTE(review): each server must list at least two ports, and `getNatType` reads
        // index [1][1], so at least two servers are required — otherwise this traps.
        lazy var stunProbeSocketAddressArray: [[SocketAddress]] = {
            return stunServers.map { stunServer in
                [
                    try! SocketAddress.makeAddressResolvingHost(stunServer.host, port: stunServer.ports[0]),
                    try! SocketAddress.makeAddressResolvingHost(stunServer.host, port: stunServer.ports[1])
                ]
            }
        }()

        let clientId: String
        let token: String

        init(version: UInt8, installedChannel: String, superHost: String, superPort: Int, stunServers: [StunServer], clientId: String, token: String) {
            self.version = version
            self.installedChannel = installedChannel
            self.superHost = superHost
            self.superPort = superPort
            self.stunServers = stunServers
            self.clientId = clientId
            self.token = token
        }
    }

    let config: Configuration

    // Address information of the tun interface (assigned by the super node after registration).
    var devAddr: SDLDevAddr

    // NAT mapping information; currently unused.
    //var natAddress: SDLNatAddress?
    // Detected NAT type of the current network.
    var natType: NatType = .blocked

    // AES cipher; only created once registration with the super node succeeded.
    var aesCipher: AESCipher?

    // RSA configuration; the key pair is generated locally.
    let rsaCipher: RSACipher

    // Dependencies.
    var udpHole: SDLUDPHole?
    private var udpCancel: AnyCancellable?

    var superClient: SDLSuperClient?
    private var superCancel: AnyCancellable?

    // Task that reads packets from the tun interface.
    private var readTask: Task<(), Never>?

    let provider: PacketTunnelProvider

    private var sessionManager: SessionManager
    private var holerManager: HolerManager
    private var arpServer: ArpServer

    // Cookie of the most recently sent STUN request.
    private var lastCookie: UInt32? = 0

    // Periodic STUN request timer.
    private var stunCancel: AnyCancellable?

    // Network path change monitor.
    private var monitor = SDLNetworkMonitor()
    private var monitorCancel: AnyCancellable?

    // Internal socket used to notify the host app.
    private var noticeClient: SDLNoticeClient

    // Traffic statistics.
    private var flowTracer = SDLFlowTracerActor()
    private var flowTracerCancel: AnyCancellable?

    init(provider: PacketTunnelProvider, config: Configuration) throws {
        self.config = config
        self.rsaCipher = try RSACipher(keySize: 1024)

        // Load (or generate and persist) the MAC address of this node.
        var devAddr = SDLDevAddr()
        devAddr.mac = Self.getMacAddress()
        self.devAddr = devAddr

        self.provider = provider
        self.sessionManager = SessionManager()
        self.holerManager = HolerManager()
        self.arpServer = ArpServer(known_macs: [:])
        self.noticeClient = SDLNoticeClient()
    }

    /// Starts the super-node client, the UDP hole socket, the notice client and the
    /// network path monitor (which re-probes the NAT type when the path changes).
    func start() async throws {
        try await self.startSuperClient()
        try await self.startUDPHole()
        self.noticeClient.start()

        // Start monitoring network path changes.
        self.monitorCancel = self.monitor.eventFlow.sink { event in
            switch event {
            case .changed:
                // The NAT type has to be probed again on a new network.
                Task {
                    self.natType = await self.getNatType()
                    NSLog("didNetworkPathChanged, nat type is: \(self.natType)")
                }
            case .unreachable:
                NSLog("didNetworkPathUnreachable")
            }
        }
        self.monitor.start()
    }

    /// Creates a new super-node client, subscribes to its events and starts it.
    private func startSuperClient() async throws {
        self.superClient = SDLSuperClient(host: config.superHost, port: config.superPort)
        // Re-bind the event subscription to the new client.
        self.superCancel?.cancel()
        self.superCancel = self.superClient?.eventFlow.sink { event in
            Task.detached {
                await self.handleSuperEvent(event: event)
            }
        }
        try await self.superClient?.start()
    }

    /// Dispatches one event emitted by the super-node client.
    private func handleSuperEvent(event: SDLSuperClient.SuperEvent) async {
        switch event {
        case .ready:
            NSLog("[SDLContext] get registerSuper, mac address: \(Self.formatMacAddress(mac: self.devAddr.mac))")
            guard let message = await self.superClient?.registerSuper(context: self) else {
                return
            }

            switch message.packet {
            case .registerSuperAck(let registerSuperAck):
                // The AES key is encrypted with our RSA public key; decrypt it with the private key.
                let aesKey = try! self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))
                let upgradeType = SDLUpgradeType(rawValue: registerSuperAck.upgradeType)
                NSLog("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count)")
                self.devAddr = registerSuperAck.devAddr

                if upgradeType == .force {
                    let forceUpgrade = NoticeMessage.UpgradeMessage(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress)
                    self.noticeClient.send(data: forceUpgrade.binaryData)
                    exit(-1)
                }

                // Apply the tun interface settings assigned by the server.
                await self.didNetworkConfigChanged(devAddr: self.devAddr)
                self.aesCipher = AESCipher(aesKey: aesKey)

                if upgradeType == .normal {
                    let normalUpgrade = NoticeMessage.UpgradeMessage(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress)
                    self.noticeClient.send(data: normalUpgrade.binaryData)
                }

            case .registerSuperNak(let nakPacket):
                let errorMessage = nakPacket.errorMessage
                guard let errorCode = SDLNAKErrorCode(rawValue: UInt8(nakPacket.errorCode)) else {
                    return
                }

                switch errorCode {
                case .invalidToken, .nodeDisabled:
                    // Fatal: notify the app and terminate the extension.
                    let alertNotice = NoticeMessage.AlertMessage(alert: errorMessage)
                    self.noticeClient.send(data: alertNotice.binaryData)
                    exit(-1)
                case .noIpAddress, .networkFault, .internalFault:
                    let alertNotice = NoticeMessage.AlertMessage(alert: errorMessage)
                    self.noticeClient.send(data: alertNotice.binaryData)
                }
                SDLLogger.log("Get a SuperNak message exit", level: .error)

            default:
                ()
            }

        case .closed:
            SDLLogger.log("[SDLContext] super client closed", level: .debug)
            await self.arpServer.clear()
            // Reconnect to the super node after 5 seconds.
            DispatchQueue.global().asyncAfter(deadline: .now() + 5) {
                Task {
                    try await self.startSuperClient()
                }
            }

        case .event(let evt):
            switch evt {
            case .natChanged(let natChangedEvent):
                // The peer's NAT mapping changed: drop the stale P2P session.
                let dstMac = natChangedEvent.mac
                NSLog("natChangedEvent, dstMac: \(dstMac)")
                await sessionManager.removeSession(dstMac: dstMac)
            case .sendRegister(let sendRegisterEvent):
                NSLog("sendRegisterEvent, ip: \(sendRegisterEvent)")
                let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
                if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) {
                    // Send a register packet to the peer's NAT address.
                    self.udpHole?.sendRegister(context: self, remoteAddress: remoteAddress, dst_mac: sendRegisterEvent.dstMac)
                }
            case .networkShutdown(let shutdownEvent):
                let alertNotice = NoticeMessage.AlertMessage(alert: shutdownEvent.message)
                self.noticeClient.send(data: alertNotice.binaryData)
                exit(-1)
            }

        case .command(let packetId, let command):
            switch command {
            case .changeNetwork(let changeNetworkCommand):
                // The AES key is encrypted with our RSA public key; decrypt it with the private key.
                let aesKey = try! self.rsaCipher.decode(data: Data(changeNetworkCommand.aesKey))
                NSLog("[SDLContext] change network command get aes_key len: \(aesKey.count)")
                self.devAddr = changeNetworkCommand.devAddr

                // Apply the tun interface settings assigned by the server.
                await self.didNetworkConfigChanged(devAddr: self.devAddr)
                self.aesCipher = AESCipher(aesKey: aesKey)

                var commandAck = SDLCommandAck()
                commandAck.status = true

                self.superClient?.commandAck(packetId: packetId, ack: commandAck)
            }
        }
    }

    /// Creates a new UDP hole socket, subscribes to its events and starts it.
    private func startUDPHole() async throws {
        self.udpHole = SDLUDPHole()

        self.udpCancel?.cancel()
        self.udpCancel = self.udpHole?.eventFlow.sink { event in
            Task.detached {
                await self.handleUDPEvent(event: event)
            }
        }

        try await self.udpHole?.start()
    }

    /// Dispatches one event emitted by the UDP hole socket.
    private func handleUDPEvent(event: SDLUDPHole.UDPEvent) async {
        switch event {
        case .ready:
            // Detect the NAT type of the current network.
            self.natType = await self.getNatType()
            SDLLogger.log("[SDLContext] nat type is: \(self.natType)", level: .debug)

            // Send a STUN request immediately and then every 5 seconds.
            let timer = Timer.publish(every: 5.0, on: .main, in: .common).autoconnect()
            self.stunCancel = Just(Date()).merge(with: timer).sink { _ in
                self.lastCookie = self.udpHole?.stunRequest(context: self)
            }

        case .closed:
            // Recreate the UDP socket after 5 seconds.
            DispatchQueue.global().asyncAfter(deadline: .now() + 5) {
                Task {
                    try await self.startUDPHole()
                }
            }

        case .message(let remoteAddress, let message):
            switch message {
            case .register(let register):
                NSLog("register packet: \(register), dev_addr: \(self.devAddr)")
                // Accept only packets addressed to our tun MAC within the same network.
                if register.dstMac == self.devAddr.mac && register.networkID == self.devAddr.networkID {
                    // Reply with an ack.
                    self.udpHole?.sendRegisterAck(context: self, remoteAddress: remoteAddress, dst_mac: register.srcMac)
                    // Create a session to the sender. In complex networks the NAT address reported by
                    // the super node is not always reliable, so the UDP source address is used instead.
                    let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
                    await self.sessionManager.addSession(session: session)
                } else {
                    SDLLogger.log("SDLContext didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)", level: .warning)
                }

            case .registerAck(let registerAck):
                // Accept only packets addressed to our tun MAC within the same network.
                if registerAck.dstMac == self.devAddr.mac && registerAck.networkID == self.devAddr.networkID {
                    let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
                    await self.sessionManager.addSession(session: session)
                } else {
                    SDLLogger.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning)
                }

            case .stunReply(let stunReply):
                let cookie = stunReply.cookie
                if cookie == self.lastCookie {
                    // The current NAT mapping is not used yet; it may later help determine the network type.
                    //self.natAddress = stunReply.natAddress
                    SDLLogger.log("[SDLContext] get a stunReply: \(try! stunReply.jsonString())")
                }

            default:
                ()
            }

        case .data(let data):
            let mac = LayerPacket.MacAddress(data: data.dstMac)
            guard (data.dstMac == self.devAddr.mac || mac.isBroadcast() || mac.isMulticast()) else {
                NSLog("[SDLContext] didReadData 1")
                return
            }

            guard let decyptedData = try? self.aesCipher?.decypt(data: Data(data.data)) else {
                NSLog("[SDLContext] didReadData 2")
                return
            }

            do {
                let layerPacket = try LayerPacket(layerData: decyptedData)

                await self.flowTracer.inc(num: decyptedData.count, type: .inbound)

                switch layerPacket.type {
                case .arp:
                    // Handle ARP requests/responses addressed to our tun IP.
                    if let arpPacket = ARPPacket(data: layerPacket.data) {
                        if arpPacket.targetIP == self.devAddr.netAddr {
                            switch arpPacket.opcode {
                            case .request:
                                NSLog("[SDLContext] get arp request packet")
                                let response = ARPPacket.arpResponse(for: arpPacket, mac: self.devAddr.mac, ip: self.devAddr.netAddr)
                                await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
                            case .response:
                                NSLog("[SDLContext] get arp response packet")
                                await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
                            }
                        } else {
                            NSLog("[SDLContext] get invalid arp packet, target_ip: \(arpPacket)")
                        }
                    } else {
                        NSLog("[SDLContext] get invalid arp packet")
                    }

                case .ipv4:
                    NSLog("[SDLContext] get ipv4 packet")
                    guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == self.devAddr.netAddr else {
                        return
                    }
                    // Hand the IPv4 packet to the tun interface (protocol family 2 = AF_INET).
                    let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
                    self.provider.packetFlow.writePacketObjects([packet])

                default:
                    NSLog("[SDLContext] get invalid packet")
                }
            } catch let err {
                NSLog("[SDLContext] didReadData err: \(err)")
            }
        }
    }

    /// Starts reporting the traffic statistics to the super node once per minute.
    func flowReportTask() {
        Task {
            // Report once per minute.
            self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect()
                .sink { _ in
                    Task {
                        let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset()
                        self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
                    }
                }
        }
    }

    /// Reconfigures the tun interface after the server assigned (new) address information.
    /// Exits the process if the settings cannot be applied.
    /// - Parameters:
    ///   - devAddr: address information assigned by the super node.
    ///   - dnsServers: DNS servers to install; defaults to 8.8.8.8 and 114.114.114.114 when nil.
    private func didNetworkConfigChanged(devAddr: SDLDevAddr, dnsServers: [String]? = nil) async {
        let netAddress = SDLNetAddress(ip: devAddr.netAddr, maskLen: UInt8(devAddr.netBitLen))
        let routes = [Route(dstAddress: netAddress.networkAddress, subnetMask: netAddress.maskAddress)]

        // Build the tunnel settings for the assigned address.
        // NOTE(review): the tunnel remote address looks like a placeholder value — confirm.
        let networkSettings = NEPacketTunnelNetworkSettings(tunnelRemoteAddress: "8.8.8.8")
        networkSettings.mtu = 1460

        // DNS resolvers of the interface.
        if let dnsServers {
            networkSettings.dnsSettings = NEDNSSettings(servers: dnsServers)
        } else {
            networkSettings.dnsSettings = NEDNSSettings(servers: ["8.8.8.8", "114.114.114.114"])
        }

        let ipv4Settings = NEIPv4Settings(addresses: [netAddress.ipAddress], subnetMasks: [netAddress.maskAddress])
        // Routing table: only the assigned virtual network goes through the tunnel.
        //NEIPv4Route.default()
        ipv4Settings.includedRoutes = routes.map { route in
            NEIPv4Route(destinationAddress: route.dstAddress, subnetMask: route.subnetMask)
        }
        networkSettings.ipv4Settings = ipv4Settings

        // Applying the interface settings must succeed.
        do {
            try await self.provider.setTunnelNetworkSettings(networkSettings)
            await self.holerManager.cleanup()
            self.startReader()
        } catch let err {
            SDLLogger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error)
            exit(-1)
        }
    }

    /// (Re)starts the task that reads packets from packetFlow and routes them to peers.
    private func startReader() {
        // Stop the previous reader.
        self.readTask?.cancel()

        // Start a new reader that runs until it is cancelled.
        self.readTask = Task(priority: .high) {
            while !Task.isCancelled {
                let (packets, numbers) = await self.provider.packetFlow.readPackets()
                // Only IPv4 packets (protocol family 2 = AF_INET) are handled.
                for (data, number) in zip(packets, numbers) where number == 2 {
                    guard let packet = IPPacket(data) else {
                        continue
                    }
                    // NOTE(review): one detached task per packet means packets may be sent out of order.
                    Task.detached {
                        let dstIp = packet.header.destination
                        // Local traffic addressed to our own tun IP: loop it back.
                        if dstIp == self.devAddr.netAddr {
                            let nePacket = NEPacket(data: packet.data, protocolFamily: 2)
                            self.provider.packetFlow.writePacketObjects([nePacket])
                            return
                        }

                        // Look up the destination MAC in the ARP cache.
                        if let dstMac = await self.arpServer.query(ip: dstIp) {
                            await self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data)
                        } else {
                            // Unknown destination: broadcast an ARP request.
                            let broadcastMac = Data([0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF])
                            let arpReqeust: ARPPacket = ARPPacket.arpRequest(senderIP: self.devAddr.netAddr, senderMAC: self.devAddr.mac, targetIP: dstIp)
                            await self.routeLayerPacket(dstMac: broadcastMac, type: .arp, data: arpReqeust.marshal())
                            NSLog("[SDLContext] dstIp: \(dstIp) arp query not found")
                        }
                    }
                }
            }
        }
    }

    /// Wraps `data` into an encrypted layer-2 packet and sends it to `dstMac`, either
    /// directly through a live P2P session or forwarded by the super node (in which
    /// case hole punching to that peer is attempted).
    private func routeLayerPacket(dstMac: Data, type: LayerPacket.PacketType, data: Data) async {
        // Encapsulate the data into a layer-2 packet.
        let layerPacket = LayerPacket(dstMac: dstMac, srcMac: self.devAddr.mac, type: type, data: data)
        guard let encodedPacket = try? self.aesCipher?.encrypt(data: layerPacket.marshal()) else {
            return
        }

        // Send directly to the peer through an established session.
        if let session = await self.sessionManager.getSession(toAddress: dstMac) {
            NSLog("[SDLContext] send packet by session: \(session)")
            self.udpHole?.sendPacket(context: self, session: session, data: encodedPacket)
            await self.flowTracer.inc(num: data.count, type: .p2p)
        } else {
            // Forward through the super node.
            self.udpHole?.forwardPacket(context: self, dst_mac: dstMac, data: encodedPacket)
            // Traffic statistics.
            await self.flowTracer.inc(num: data.count, type: .forward)

            // Try to punch a hole to the peer.
            await self.holerManager.addHoler(dstMac: dstMac) {
                self.holerTask(dstMac: dstMac)
            }
        }
    }

    deinit {
        self.stunCancel?.cancel()
        self.udpHole = nil
        self.superClient = nil
    }
}

//--MARK: RSA encryption
extension SDLContext {
    struct RSACipher {
        let pubKey: String
        let privateKeyDER: Data

        init(keySize: Int) throws {
            let (privateKey, publicKey) = try Self.loadKeys(keySize: keySize)
            let privKeyStr = SwKeyConvert.PrivateKey.derToPKCS1PEM(privateKey)

            self.pubKey = SwKeyConvert.PublicKey.derToPKCS8PEM(publicKey)
            self.privateKeyDER = try SwKeyConvert.PrivateKey.pemToPKCS1DER(privKeyStr)
        }

        /// Decrypts `data` (PKCS#1 v1.5 padding) with the local private key.
        public func decode(data: Data) throws -> Data {
            let tag = Data()
            let (decryptedData, _) = try CC.RSA.decrypt(data, derKey: self.privateKeyDER, tag: tag, padding: .pkcs1, digest: .none)
            return decryptedData
        }

        /// Returns the persisted (private, public) key pair, generating and persisting a new one if missing.
        /// NOTE(review): the private key is stored in plain UserDefaults; the Keychain would be safer.
        private static func loadKeys(keySize: Int) throws -> (Data, Data) {
            if let privateKey = UserDefaults.standard.data(forKey: "privateKey"), let publicKey = UserDefaults.standard.data(forKey: "publicKey") {
                return (privateKey, publicKey)
            } else {
                let (privateKey, publicKey) = try CC.RSA.generateKeyPair(keySize)
                UserDefaults.standard.setValue(privateKey, forKey: "privateKey")
                UserDefaults.standard.setValue(publicKey, forKey: "publicKey")
                return (privateKey, publicKey)
            }
        }
    }
}

// --MARK: AES encryption (AES-CBC, PKCS#7 padding)
extension SDLContext {
    struct AESCipher {
        let aesKey: Data
        let ivData: Data

        // NOTE(review): the IV is the first 16 bytes of the key and is the same for every
        // message, so equal plaintexts give equal ciphertexts. Changing it requires a
        // coordinated protocol change with the peers and the super node.
        init(aesKey: Data) {
            self.aesKey = aesKey
            self.ivData = Data(aesKey.prefix(16))
        }

        func decypt(data: Data) throws -> Data {
            return try CC.crypt(.decrypt, blockMode: .cbc, algorithm: .aes, padding: .pkcs7Padding, data: data, key: aesKey, iv: ivData)
        }

        func encrypt(data: Data) throws -> Data {
            return try CC.crypt(.encrypt, blockMode: .cbc, algorithm: .aes, padding: .pkcs7Padding, data: data, key: aesKey, iv: ivData)
        }
    }
}

// --MARK: Session management. A session stays valid for 10s after its last use; every use refreshes it.
extension SDLContext {
    struct Session {
        // MAC address of the peer.
        let dstMac: Data
        // The peer's address as mapped by its NAT.
        let natAddress: SocketAddress
        // Last time the session was used (seconds since 1970).
        var lastTimestamp: Int32

        init(dstMac: Data, natAddress: SocketAddress) {
            self.dstMac = dstMac
            self.natAddress = natAddress
            self.lastTimestamp = Int32(Date().timeIntervalSince1970)
        }

        mutating func updateLastTimestamp(_ lastTimestamp: Int32) {
            self.lastTimestamp = lastTimestamp
        }
    }

    actor SessionManager {
        private var sessions: [Data:Session] = [:]

        // Idle time (seconds) after which a session expires.
        private let ttl: Int32 = 10

        /// Returns the live session for `toAddress`, refreshing its last-use time.
        /// An expired session (unused for more than `ttl` seconds) is removed and nil is returned.
        func getSession(toAddress: Data) -> Session? {
            let timestamp = Int32(Date().timeIntervalSince1970)
            guard var session = self.sessions[toAddress] else {
                return nil
            }

            // Alive while it has been used within the last `ttl` seconds.
            guard timestamp - session.lastTimestamp <= ttl else {
                self.sessions.removeValue(forKey: toAddress)
                return nil
            }

            session.updateLastTimestamp(timestamp)
            self.sessions[toAddress] = session
            return session
        }

        func addSession(session: Session) {
            self.sessions[session.dstMac] = session
        }

        func removeSession(dstMac: Data) {
            self.sessions.removeValue(forKey: dstMac)
        }
    }
}

// --MARK: ARP cache (IP -> MAC)
extension SDLContext {
    actor ArpServer {
        private var known_macs: [UInt32:Data] = [:]

        init(known_macs: [UInt32:Data]) {
            self.known_macs = known_macs
        }

        func query(ip: UInt32) -> Data? {
            return self.known_macs[ip]
        }

        func append(ip: UInt32, mac: Data) {
            self.known_macs[ip] = mac
        }

        func remove(ip: UInt32) {
            self.known_macs.removeValue(forKey: ip)
        }

        func clear() {
            self.known_macs = [:]
        }
    }
}

// --MARK: Hole punching management
extension SDLContext {
    actor HolerManager {
        private var holers: [Data:Task<(), Never>] = [:]

        /// Registers a hole-punching task for `dstMac` unless one is already tracked.
        /// NOTE(review): an existing entry is only replaced when its task was cancelled; a task
        /// that finished normally stays here, so punching to that peer is not retried until
        /// `cleanup()` runs — confirm this is intended.
        func addHoler(dstMac: Data, creator: @escaping () -> Task<(), Never>) {
            if let task = self.holers[dstMac] {
                if task.isCancelled {
                    self.holers[dstMac] = creator()
                }
            } else {
                self.holers[dstMac] = creator()
            }
        }

        /// Cancels and forgets every tracked hole-punching task.
        func cleanup() {
            for holer in holers.values {
                holer.cancel()
            }
            self.holers.removeAll()
        }
    }

    /// Queries the super node for the peer's address and, if valid, sends it a register packet.
    func holerTask(dstMac: Data) -> Task<(), Never> {
        return Task {
            guard let message = try? await self.superClient?.queryInfo(context: self, dst_mac: dstMac) else {
                return
            }

            switch message.packet {
            case .empty:
                SDLLogger.log("[SDLContext] hole query_info get empty: \(message)", level: .debug)
            case .peerInfo(let peerInfo):
                if let remoteAddress = peerInfo.v4Info.socketAddress() {
                    SDLLogger.log("[SDLContext] hole sock address: \(remoteAddress)", level: .warning)
                    // Send a register packet to the peer.
                    self.udpHole?.sendRegister(context: self, remoteAddress: remoteAddress, dst_mac: dstMac)
                } else {
                    SDLLogger.log("[SDLContext] hole sock address is invalid: \(peerInfo.v4Info)", level: .warning)
                }
            default:
                SDLLogger.log("[SDLContext] hole query_info is packet: \(message)", level: .warning)
            }
        }
    }
}

//--MARK: NAT type detection
extension SDLContext {
    // NAT types.
    enum NatType: UInt8, Encodable {
        case blocked = 0
        case noNat = 1
        case fullCone = 2
        case portRestricted = 3
        case coneRestricted = 4
        case symmetric = 5
    }

    /// Sends a STUN probe to `remoteAddress` (5s timeout) and returns our mapped address, if any.
    private func getNatAddress(remoteAddress: SocketAddress, attr: SDLProbeAttr) async -> SocketAddress? {
        let stunProbeReply = await self.udpHole?.stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: 5)
        return stunProbeReply?.socketAddress()
    }

    /// Detects the NAT type of the current network with a sequence of STUN probes.
    func getNatType() async -> NatType {
        let addressArray = config.stunProbeSocketAddressArray
        // step1: ip1:port1 <---- ip1:port1
        guard let natAddress1 = await getNatAddress(remoteAddress: addressArray[0][0], attr: .none) else {
            return .blocked
        }

        // The mapped address equals our local address: not behind a NAT.
        if natAddress1 == self.udpHole?.localAddress {
            return .noNat
        }

        // step2: ip2:port2 <---- ip2:port2
        guard let natAddress2 = await getNatAddress(remoteAddress: addressArray[1][1], attr: .none) else {
            return .blocked
        }

        // If the IP mapped for the second server differs from the first one, the NAT is
        // symmetric (the mapping depends on the destination); this probe is always answered.
        NSLog("nat_address1: \(natAddress1), nat_address2: \(natAddress2)")
        if let ipAddress1 = natAddress1.ipAddress, let ipAddress2 = natAddress2.ipAddress, ipAddress1 != ipAddress2 {
            return .symmetric
        }

        // step3: ip1:port1 <---- ip2:port2 (reply from a different IP and port)
        // If the reply arrives, the NAT is full cone.
        if let natAddress3 = await getNatAddress(remoteAddress: addressArray[0][0], attr: .peer) {
            NSLog("nat_address1: \(natAddress1), nat_address2: \(natAddress2), nat_address3: \(natAddress3)")
            return .fullCone
        }

        // step4: ip1:port1 <---- ip1:port2 (reply from a different port only)
        // If the reply arrives, the NAT is address-restricted cone; otherwise port-restricted cone.
        if let natAddress4 = await getNatAddress(remoteAddress: addressArray[0][0], attr: .port) {
            NSLog("nat_address1: \(natAddress1), nat_address2: \(natAddress2), nat_address4: \(natAddress4)")
            return .coneRestricted
        } else {
            return .portRestricted
        }
    }
}

//--MARK: Device identifiers
extension SDLContext {
    /// Returns the persisted client id, generating a lowercase hex UUID without dashes on first use.
    static func getUUID() -> String {
        let userDefaults = UserDefaults.standard
        if let uuid = userDefaults.value(forKey: "gClientId") as? String {
            return uuid
        } else {
            let uuid = UUID().uuidString.replacingOccurrences(of: "-", with: "").lowercased()
            userDefaults.setValue(uuid, forKey: "gClientId")
            return uuid
        }
    }

    /// Returns the persisted MAC address of this node, generating one on first use.
    static func getMacAddress() -> Data {
        let key = "gMacAddress2"
        let userDefaults = UserDefaults.standard
        if let mac = userDefaults.value(forKey: key) as? Data {
            return mac
        } else {
            let mac = generateMacAddress()
            userDefaults.setValue(mac, forKey: key)
            return mac
        }
    }

    /// Generates a random 6-byte locally administered unicast MAC address.
    private static func generateMacAddress() -> Data {
        var macAddress = (0..<6).map { _ in UInt8.random(in: UInt8.min...UInt8.max) }
        // Clear the I/G bit (otherwise the address would be multicast) and set the U/L bit
        // (locally administered) in the first octet.
        macAddress[0] = (macAddress[0] & 0xFE) | 0x02
        return Data(macAddress)
    }

    /// Formats a MAC address as lowercase, colon-separated hex, e.g. "02:1a:2b:3c:4d:5e".
    private static func formatMacAddress(mac: Data) -> String {
        return mac.map { String(format: "%02x", $0) }.joined(separator: ":")
    }
}