swiftlib_sdlan/Sources/sdlan/SDLContext.swift
2025-07-14 15:33:40 +08:00

834 lines
31 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//
// SDLContext.swift
// Tun
//
// Created by on 2024/2/29.
//
import Foundation
import NetworkExtension
import NIOCore
import Combine
//
/*
1. rsa
*/
class SDLContext {
//
/// One IPv4 route, described by a destination network and its subnet mask (both as strings).
struct Route {
    let dstAddress: String
    let subnetMask: String

    /// `destination:mask` rendering intended for log output.
    var debugInfo: String {
        let parts = [dstAddress, subnetMask]
        return parts.joined(separator: ":")
    }
}
//
/// Startup configuration of the context, plus lazily resolved STUN socket addresses.
final class Configuration {
    /// A STUN server: host name and the ports it listens on.
    struct StunServer {
        let host: String
        let ports: [Int]
    }

    /// Version value supplied by the caller.
    let version: UInt8
    /// Installation channel identifier supplied by the caller.
    let installedChannel: String
    /// Host and port of the super node.
    let superHost: String
    let superPort: Int
    /// Configured STUN servers.
    let stunServers: [StunServer]
    let clientId: String
    let token: String

    /// Address of the first port of the first STUN server, resolved on first access.
    /// Traps when `stunServers` or that server's `ports` is empty, or when resolution fails.
    lazy var stunSocketAddress: SocketAddress = {
        let primary = self.stunServers[0]
        let address = try! SocketAddress.makeAddressResolvingHost(primary.host, port: primary.ports[0])
        return address
    }()

    /// For every STUN server, the resolved addresses of its first two ports, in order.
    /// Traps when a server has fewer than two ports or when resolution fails.
    lazy var stunProbeSocketAddressArray: [[SocketAddress]] = {
        var probeAddresses: [[SocketAddress]] = []
        for server in self.stunServers {
            let first = try! SocketAddress.makeAddressResolvingHost(server.host, port: server.ports[0])
            let second = try! SocketAddress.makeAddressResolvingHost(server.host, port: server.ports[1])
            probeAddresses.append([first, second])
        }
        return probeAddresses
    }()

    init(version: UInt8, installedChannel: String, superHost: String, superPort: Int, stunServers: [StunServer], clientId: String, token: String) {
        self.clientId = clientId
        self.token = token
        self.version = version
        self.installedChannel = installedChannel
        self.stunServers = stunServers
        self.superHost = superHost
        self.superPort = superPort
    }
}
// Startup configuration passed to `init`.
let config: Configuration
// Device address of the tun interface. `init` only fills in `mac`; the whole value is
// replaced by the `devAddr` the super node sends (registerSuperAck / changeNetwork).
var devAddr: SDLDevAddr
// NAT address of this node (disabled; kept for reference).
//var natAddress: SDLNatAddress?
// Last detected NAT type; stays `.blocked` until `getNatType()` has produced a result.
var natType: NatType = .blocked
// AES cipher for tunnel payloads; nil until the super node has delivered a key.
var aesCipher: AESCipher?
// RSA cipher used to decrypt the AES key sent by the super node.
let rsaCipher: RSACipher
// UDP endpoint used for STUN requests, peer registration and data packets,
// together with the subscription to its event flow.
var udpHole: SDLUDPHole?
private var udpCancel: AnyCancellable?
// Client connection to the super node, together with the subscription to its event flow.
var superClient: SDLSuperClient?
private var superCancel: AnyCancellable?
// Task that reads packets from `provider.packetFlow` (created in `startReader()`).
private var readTask: Task<(), Never>?
// Packet tunnel provider whose packet flow and network settings this context drives.
let provider: PacketTunnelProvider
// Peer sessions keyed by destination MAC.
private var sessionManager: SessionManager
// Hole-punching tasks keyed by destination MAC.
private var holerManager: HolerManager
// IP-to-MAC table filled from ARP responses.
private var arpServer: ArpServer
// Cookie returned by the most recent `stunRequest`; incoming STUN replies are matched against it.
private var lastCookie: UInt32? = 0
// Subscription of the periodic STUN request timer.
private var stunCancel: AnyCancellable?
// Network path monitor and the subscription to its events.
private var monitor = SDLNetworkMonitor()
private var monitorCancel: AnyCancellable?
// Client used to send `NoticeMessage` payloads (upgrade prompts and alerts).
private var noticeClient: SDLNoticeClient
// Byte counters per traffic type (forward / p2p / inbound) and the report timer subscription.
private var flowTracer = SDLFlowTracerActor()
private var flowTracerCancel: AnyCancellable?
/// Creates a context for `provider` using `config`.
///
/// Loads (or generates) the RSA key pair and seeds the device address with the
/// persisted MAC address; no network activity happens until `start()` is called.
/// - Throws: any error raised by `RSACipher(keySize:)`.
init(provider: PacketTunnelProvider, config: Configuration) throws {
    self.config = config
    self.rsaCipher = try RSACipher(keySize: 1024)

    // Only the MAC is known locally; the remaining fields arrive from the super node.
    var initialDevAddr = SDLDevAddr()
    initialDevAddr.mac = Self.getMacAddress()
    self.devAddr = initialDevAddr
    self.provider = provider

    self.sessionManager = SessionManager()
    self.holerManager = HolerManager()
    self.arpServer = ArpServer(known_macs: [:])
    self.noticeClient = SDLNoticeClient()
}
/// Starts the super-node client, the UDP hole and the notice client, then begins
/// observing network path changes.
/// - Throws: errors raised by `startSuperClient()` or `startUDPHole()`.
func start() async throws {
    try await self.startSuperClient()
    try await self.startUDPHole()
    self.noticeClient.start()

    // `self` stores `monitorCancel`, so the subscription must not retain `self`;
    // a strong capture would form a cycle and `deinit` could never run.
    self.monitorCancel = self.monitor.eventFlow.sink { [weak self] event in
        switch event {
        case .changed:
            // The network path changed: probe the NAT type again.
            Task { [weak self] in
                guard let self else {
                    return
                }
                self.natType = await self.getNatType()
                NSLog("didNetworkPathChanged, nat type is: \(self.natType)")
            }
        case .unreachable:
            NSLog("didNetworkPathUnreachable")
        }
    }
    self.monitor.start()
}
/// Creates a fresh super-node client, subscribes to its events and starts it.
/// - Throws: errors raised by `SDLSuperClient.start()`.
private func startSuperClient() async throws {
    self.superClient = SDLSuperClient(host: config.superHost, port: config.superPort)

    // Drop the subscription of any previous client before installing the new one.
    self.superCancel?.cancel()
    // Capture `self` weakly: `self` owns `superCancel`, so a strong capture would be a retain cycle.
    self.superCancel = self.superClient?.eventFlow.sink { [weak self] event in
        Task.detached { [weak self] in
            await self?.handleSuperEvent(event: event)
        }
    }
    try await self.superClient?.start()
}
/// Handles one event emitted by the super-node client.
///
/// - `.ready`: registers with the super node, then applies the returned device address and AES key.
/// - `.closed`: clears the ARP table and restarts the client after 5 seconds.
/// - `.event`: reacts to NAT changes, register requests and network shutdown.
/// - `.command`: applies a network change and acknowledges it.
///
/// Data received from the super node is never force-unwrapped here: a key that cannot be
/// decrypted or an out-of-range error code is logged and the event is dropped.
private func handleSuperEvent(event: SDLSuperClient.SuperEvent) async {
    switch event {
    case .ready:
        NSLog("[SDLContext] get registerSuper, mac address: \(Self.formatMacAddress(mac: self.devAddr.mac))")
        guard let message = await self.superClient?.registerSuper(context: self) else {
            return
        }
        switch message.packet {
        case .registerSuperAck(let registerSuperAck):
            // The AES key is delivered RSA-encrypted; decrypt it with our private key.
            let aesKey: Data
            do {
                aesKey = try self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))
            } catch {
                SDLLogger.log("[SDLContext] registerSuperAck aes_key decode failed: \(error)", level: .error)
                return
            }
            let upgradeType = SDLUpgradeType(rawValue: registerSuperAck.upgradeType)
            NSLog("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count)")
            self.devAddr = registerSuperAck.devAddr
            if upgradeType == .force {
                // A forced upgrade is announced to the user and terminates the process.
                let forceUpgrade = NoticeMessage.UpgradeMessage(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress)
                self.noticeClient.send(data: forceUpgrade.binaryData)
                exit(-1)
            }
            // Apply the assigned address to the tun interface.
            await self.didNetworkConfigChanged(devAddr: self.devAddr)
            self.aesCipher = AESCipher(aesKey: aesKey)
            if upgradeType == .normal {
                let normalUpgrade = NoticeMessage.UpgradeMessage(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress)
                self.noticeClient.send(data: normalUpgrade.binaryData)
            }
        case .registerSuperNak(let nakPacket):
            let errorMessage = nakPacket.errorMessage
            // `UInt8(exactly:)` avoids the trap a plain `UInt8(...)` conversion would hit on values above 255.
            guard let rawCode = UInt8(exactly: nakPacket.errorCode),
                  let errorCode = SDLNAKErrorCode(rawValue: rawCode) else {
                SDLLogger.log("[SDLContext] registerSuperNak with unknown error code: \(nakPacket.errorCode)", level: .warning)
                return
            }
            switch errorCode {
            case .invalidToken, .nodeDisabled:
                // These errors are fatal: alert the user and terminate.
                let alertNotice = NoticeMessage.AlertMessage(alert: errorMessage)
                self.noticeClient.send(data: alertNotice.binaryData)
                exit(-1)
            case .noIpAddress, .networkFault, .internalFault:
                let alertNotice = NoticeMessage.AlertMessage(alert: errorMessage)
                self.noticeClient.send(data: alertNotice.binaryData)
            }
            SDLLogger.log("Get a SuperNak message exit", level: .error)
        default:
            ()
        }
    case .closed:
        SDLLogger.log("[SDLContext] super client closed", level: .debug)
        await self.arpServer.clear()
        // Reconnect after a short delay; a failed restart is logged instead of silently dropped.
        DispatchQueue.global().asyncAfter(deadline: .now() + 5) { [weak self] in
            Task { [weak self] in
                do {
                    try await self?.startSuperClient()
                } catch {
                    SDLLogger.log("[SDLContext] restart super client failed: \(error)", level: .error)
                }
            }
        }
    case .event(let evt):
        switch evt {
        case .natChanged(let natChangedEvent):
            // The peer's NAT mapping changed, so the cached session is no longer usable.
            let dstMac = natChangedEvent.mac
            NSLog("natChangedEvent, dstMac: \(dstMac)")
            await sessionManager.removeSession(dstMac: dstMac)
        case .sendRegister(let sendRegisterEvent):
            NSLog("sendRegisterEvent, ip: \(sendRegisterEvent)")
            let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
            if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) {
                // Send a register packet to the peer's NAT address.
                self.udpHole?.sendRegister(context: self, remoteAddress: remoteAddress, dst_mac: sendRegisterEvent.dstMac)
            }
        case .networkShutdown(let shutdownEvent):
            let alertNotice = NoticeMessage.AlertMessage(alert: shutdownEvent.message)
            self.noticeClient.send(data: alertNotice.binaryData)
            exit(-1)
        }
    case .command(let packetId, let command):
        switch command {
        case .changeNetwork(let changeNetworkCommand):
            // The new AES key is delivered RSA-encrypted. When it cannot be decrypted the
            // command is not applied and no acknowledgement is sent.
            let aesKey: Data
            do {
                aesKey = try self.rsaCipher.decode(data: Data(changeNetworkCommand.aesKey))
            } catch {
                SDLLogger.log("[SDLContext] change network command aes_key decode failed: \(error)", level: .error)
                return
            }
            NSLog("[SDLContext] change network command get aes_key len: \(aesKey.count)")
            self.devAddr = changeNetworkCommand.devAddr
            // Apply the new address to the tun interface.
            await self.didNetworkConfigChanged(devAddr: self.devAddr)
            self.aesCipher = AESCipher(aesKey: aesKey)
            var commandAck = SDLCommandAck()
            commandAck.status = true
            self.superClient?.commandAck(packetId: packetId, ack: commandAck)
        }
    }
}
/// Creates a fresh UDP hole, subscribes to its events and starts it.
/// - Throws: errors raised by `SDLUDPHole.start()`.
private func startUDPHole() async throws {
    self.udpHole = SDLUDPHole()

    // Drop the subscription of any previous hole before installing the new one.
    self.udpCancel?.cancel()
    // Capture `self` weakly: `self` owns `udpCancel`, so a strong capture would be a retain cycle.
    self.udpCancel = self.udpHole?.eventFlow.sink { [weak self] event in
        Task.detached { [weak self] in
            await self?.handleUDPEvent(event: event)
        }
    }
    try await self.udpHole?.start()
}
/// Handles one event emitted by the UDP hole.
///
/// - `.ready`: detects the NAT type and starts sending a STUN request every 5 seconds.
/// - `.closed`: restarts the UDP hole after 5 seconds.
/// - `.message`: processes register / registerAck / stunReply control packets.
/// - `.data`: decrypts a tunnel frame and either answers ARP or writes the IPv4 packet to the tun device.
private func handleUDPEvent(event: SDLUDPHole.UDPEvent) async {
    switch event {
    case .ready:
        self.natType = await self.getNatType()
        SDLLogger.log("[SDLContext] nat type is: \(self.natType)", level: .debug)
        // Fire once immediately (`Just`) and then on every timer tick. `self` owns
        // `stunCancel`, so the sink captures it weakly to avoid a retain cycle.
        let timer = Timer.publish(every: 5.0, on: .main, in: .common).autoconnect()
        self.stunCancel = Just(Date()).merge(with: timer).sink { [weak self] _ in
            guard let self else {
                return
            }
            self.lastCookie = self.udpHole?.stunRequest(context: self)
        }
    case .closed:
        // Restart after a short delay; a failed restart is logged instead of silently dropped.
        DispatchQueue.global().asyncAfter(deadline: .now() + 5) { [weak self] in
            Task { [weak self] in
                do {
                    try await self?.startUDPHole()
                } catch {
                    SDLLogger.log("[SDLContext] restart udp hole failed: \(error)", level: .error)
                }
            }
        }
    case .message(let remoteAddress, let message):
        switch message {
        case .register(let register):
            NSLog("register packet: \(register), dev_addr: \(self.devAddr)")
            // Only accept registers addressed to this device on this network.
            if register.dstMac == self.devAddr.mac && register.networkID == self.devAddr.networkID {
                // Acknowledge, then remember the peer's address as a session.
                self.udpHole?.sendRegisterAck(context: self, remoteAddress: remoteAddress, dst_mac: register.srcMac)
                let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
                await self.sessionManager.addSession(session: session)
            } else {
                SDLLogger.log("SDLContext didReadRegister get a invalid packet, because dst_mac not matched: \(register.dstMac)", level: .warning)
            }
        case .registerAck(let registerAck):
            // Only accept acks addressed to this device on this network.
            if registerAck.dstMac == self.devAddr.mac && registerAck.networkID == self.devAddr.networkID {
                let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
                await self.sessionManager.addSession(session: session)
            } else {
                SDLLogger.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning)
            }
        case .stunReply(let stunReply):
            let cookie = stunReply.cookie
            if cookie == self.lastCookie {
                //self.natAddress = stunReply.natAddress
                // Logging must never crash: fall back to the default description
                // when the reply cannot be rendered as JSON.
                let replyText = (try? stunReply.jsonString()) ?? "\(stunReply)"
                SDLLogger.log("[SDLContext] get a stunReply: \(replyText)")
            }
        default:
            ()
        }
    case .data(let data):
        // Accept frames addressed to this device, plus broadcast and multicast frames.
        let mac = LayerPacket.MacAddress(data: data.dstMac)
        guard (data.dstMac == self.devAddr.mac || mac.isBroadcast() || mac.isMulticast()) else {
            NSLog("[SDLContext] didReadData 1")
            return
        }
        // Fails when no cipher is installed yet or when decryption throws.
        guard let decryptedData = try? self.aesCipher?.decypt(data: Data(data.data)) else {
            NSLog("[SDLContext] didReadData 2")
            return
        }
        do {
            let layerPacket = try LayerPacket(layerData: decryptedData)
            await self.flowTracer.inc(num: decryptedData.count, type: .inbound)
            switch layerPacket.type {
            case .arp:
                if let arpPacket = ARPPacket(data: layerPacket.data) {
                    if arpPacket.targetIP == self.devAddr.netAddr {
                        switch arpPacket.opcode {
                        case .request:
                            // Answer with our own MAC for our own IP.
                            NSLog("[SDLContext] get arp request packet")
                            let response = ARPPacket.arpResponse(for: arpPacket, mac: self.devAddr.mac, ip: self.devAddr.netAddr)
                            await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
                        case .response:
                            // Learn the sender's IP-to-MAC mapping.
                            NSLog("[SDLContext] get arp response packet")
                            await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
                        }
                    } else {
                        NSLog("[SDLContext] get invalid arp packet, target_ip: \(arpPacket)")
                    }
                } else {
                    NSLog("[SDLContext] get invalid arp packet")
                }
            case .ipv4:
                NSLog("[SDLContext] get ipv4 packet")
                guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == self.devAddr.netAddr else {
                    return
                }
                // Protocol family 2 is AF_INET.
                let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
                self.provider.packetFlow.writePacketObjects([packet])
            default:
                NSLog("[SDLContext] get invalid packet")
            }
        } catch let err {
            NSLog("[SDLContext] didReadData err: \(err)")
        }
    }
}
//
/// Starts a 60-second timer that resets the flow counters and reports them to the super node.
///
/// The subscription is stored in `flowTracerCancel`; calling this again replaces the previous timer.
func flowReportTask() {
    // `self` owns `flowTracerCancel`, so the sink captures `self` weakly to avoid a retain cycle.
    self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common)
        .autoconnect()
        .sink { [weak self] _ in
            // The counters live in an actor, so reading them needs an async context.
            Task { [weak self] in
                guard let self else {
                    return
                }
                let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset()
                self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
            }
        }
}
//
/// Applies `devAddr` to the tunnel: builds the network settings, installs them on the
/// provider, drops all pending hole-punching tasks and restarts the packet reader.
///
/// - Parameters:
///   - devAddr: device address whose `netAddr` / `netBitLen` define the tunnel IP and mask.
///   - dnsServers: DNS servers to use; when nil, "8.8.8.8" and "114.114.114.114" are used.
///
/// Terminates the process when the settings cannot be applied.
private func didNetworkConfigChanged(devAddr: SDLDevAddr, dnsServers: [String]? = nil) async {
    let netAddress = SDLNetAddress(ip: devAddr.netAddr, maskLen: UInt8(devAddr.netBitLen))
    // Only the tunnel's own subnet is routed through the tun interface.
    let subnetRoute = Route(dstAddress: netAddress.networkAddress, subnetMask: netAddress.maskAddress)

    let settings = NEPacketTunnelNetworkSettings(tunnelRemoteAddress: "8.8.8.8")
    settings.mtu = 1460
    settings.dnsSettings = NEDNSSettings(servers: dnsServers ?? ["8.8.8.8", "114.114.114.114"])

    let ipv4 = NEIPv4Settings(addresses: [netAddress.ipAddress], subnetMasks: [netAddress.maskAddress])
    //NEIPv4Route.default()
    ipv4.includedRoutes = [
        NEIPv4Route(destinationAddress: subnetRoute.dstAddress, subnetMask: subnetRoute.subnetMask)
    ]
    settings.ipv4Settings = ipv4

    do {
        try await self.provider.setTunnelNetworkSettings(settings)
    } catch let err {
        SDLLogger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error)
        exit(-1)
    }
    await self.holerManager.cleanup()
    self.startReader()
}
// , 线packetFlow
/// Starts the task that reads outbound packets from the tun device, replacing any
/// previous reader so that only one task drains `packetFlow` at a time.
private func startReader() {
    self.readTask?.cancel()
    // `self` stores `readTask`, and the loop only ends on cancellation. The task therefore
    // captures `self` weakly; a strong capture would keep the context alive forever.
    self.readTask = Task(priority: .high) { [weak self] in
        while !Task.isCancelled {
            // Hold only the packet flow across the suspension point, not the context.
            guard let packetFlow = self?.provider.packetFlow else {
                return
            }
            let (packets, families) = await packetFlow.readPackets()
            // Protocol family 2 is AF_INET: only IPv4 packets are handled.
            for (data, family) in zip(packets, families) where family == 2 {
                guard let packet = IPPacket(data) else {
                    continue
                }
                Task.detached { [weak self] in
                    await self?.handleOutboundPacket(packet)
                }
            }
        }
    }
}

/// Routes one IPv4 packet read from the tun device.
///
/// Packets addressed to our own tunnel IP are written straight back to the tun device.
/// Packets for a known peer are sent to its MAC. For an unknown peer an ARP request is
/// broadcast and the packet itself is not sent.
private func handleOutboundPacket(_ packet: IPPacket) async {
    let dstIp = packet.header.destination
    if dstIp == self.devAddr.netAddr {
        let nePacket = NEPacket(data: packet.data, protocolFamily: 2)
        self.provider.packetFlow.writePacketObjects([nePacket])
        return
    }
    if let dstMac = await self.arpServer.query(ip: dstIp) {
        await self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data)
    } else {
        let broadcastMac = Data([0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF])
        let arpRequest: ARPPacket = ARPPacket.arpRequest(senderIP: self.devAddr.netAddr, senderMAC: self.devAddr.mac, targetIP: dstIp)
        await self.routeLayerPacket(dstMac: broadcastMac, type: .arp, data: arpRequest.marshal())
        NSLog("[SDLContext] dstIp: \(dstIp) arp query not found")
    }
}
/// Wraps `data` in a layer-2 frame, encrypts it and sends it towards `dstMac`.
///
/// With a live session the frame goes directly to the peer. Otherwise it is relayed
/// through the super node and a hole-punching task for that peer is requested.
/// The frame is dropped when no AES cipher is installed or encryption fails.
private func routeLayerPacket(dstMac: Data, type: LayerPacket.PacketType, data: Data) async {
    let frame = LayerPacket(dstMac: dstMac, srcMac: self.devAddr.mac, type: type, data: data)
    guard let cipherText = try? self.aesCipher?.encrypt(data: frame.marshal()) else {
        return
    }

    guard let session = await self.sessionManager.getSession(toAddress: dstMac) else {
        // No direct path yet: relay through the super node and account for it.
        self.udpHole?.forwardPacket(context: self, dst_mac: dstMac, data: cipherText)
        await self.flowTracer.inc(num: data.count, type: .forward)
        // Try to open a direct path for subsequent packets.
        await self.holerManager.addHoler(dstMac: dstMac) {
            self.holerTask(dstMac: dstMac)
        }
        return
    }

    // Direct peer-to-peer path.
    NSLog("[SDLContext] send packet by session: \(session)")
    self.udpHole?.sendPacket(context: self, session: session, data: cipherText)
    await self.flowTracer.inc(num: data.count, type: .p2p)
}
/// Stops background work owned by the context.
deinit {
    // Releasing a `Task` handle does not cancel the task, so the reader must be stopped explicitly.
    self.readTask?.cancel()
    self.stunCancel?.cancel()
    self.flowTracerCancel?.cancel()
    self.monitorCancel?.cancel()
    self.udpCancel?.cancel()
    self.superCancel?.cancel()
    self.udpHole = nil
    self.superClient = nil
}
}
// MARK: - RSA
extension SDLContext {
    /// RSA key pair of this installation: the public key as PEM text and the private key as DER.
    ///
    /// NOTE(review): the key pair is persisted in `UserDefaults`, which is not a protected
    /// store — consider whether the private key should live in the Keychain instead.
    struct RSACipher {
        let pubKey: String
        let privateKeyDER: Data

        /// Loads the persisted key pair, generating one of `keySize` bits when none is stored.
        /// - Throws: errors from key generation or from the PEM/DER conversion.
        init(keySize: Int) throws {
            let (privateDER, publicDER) = try Self.loadKeys(keySize: keySize)
            let privatePEM = SwKeyConvert.PrivateKey.derToPKCS1PEM(privateDER)
            self.pubKey = SwKeyConvert.PublicKey.derToPKCS8PEM(publicDER)
            self.privateKeyDER = try SwKeyConvert.PrivateKey.pemToPKCS1DER(privatePEM)
        }

        /// Decrypts `data` with the private key using PKCS#1 padding.
        /// - Throws: errors raised by `CC.RSA.decrypt`.
        public func decode(data: Data) throws -> Data {
            let (plain, _) = try CC.RSA.decrypt(data, derKey: self.privateKeyDER, tag: Data(), padding: .pkcs1, digest: .none)
            return plain
        }

        /// Returns `(privateKey, publicKey)` from `UserDefaults`, or generates and stores a new pair.
        /// `keySize` is only used when a new pair has to be generated.
        private static func loadKeys(keySize: Int) throws -> (Data, Data) {
            let defaults = UserDefaults.standard
            guard let storedPrivate = defaults.data(forKey: "privateKey"),
                  let storedPublic = defaults.data(forKey: "publicKey") else {
                let (newPrivate, newPublic) = try CC.RSA.generateKeyPair(keySize)
                defaults.setValue(newPrivate, forKey: "privateKey")
                defaults.setValue(newPublic, forKey: "publicKey")
                return (newPrivate, newPublic)
            }
            return (storedPrivate, storedPublic)
        }
    }
}
// MARK: - AES (CBC mode, PKCS#7 padding)
extension SDLContext {
    /// AES-CBC cipher with PKCS#7 padding used for tunnel payloads.
    ///
    /// The IV is the first 16 bytes of the key and is therefore the same for every message.
    /// NOTE(review): a fixed, key-derived IV weakens CBC; changing it would alter the wire
    /// format, so it has to be coordinated with the peers.
    struct AESCipher {
        let aesKey: Data
        let ivData: Data

        init(aesKey: Data) {
            let ivLength = 16
            self.ivData = Data(aesKey.prefix(ivLength))
            self.aesKey = aesKey
        }

        /// Decrypts `data`; throws whatever `CC.crypt` throws.
        func decypt(data: Data) throws -> Data {
            let plainText = try CC.crypt(.decrypt, blockMode: .cbc, algorithm: .aes, padding: .pkcs7Padding, data: data, key: self.aesKey, iv: self.ivData)
            return plainText
        }

        /// Encrypts `data`; throws whatever `CC.crypt` throws.
        func encrypt(data: Data) throws -> Data {
            let cipherText = try CC.crypt(.encrypt, blockMode: .cbc, algorithm: .aes, padding: .pkcs7Padding, data: data, key: self.aesKey, iv: self.ivData)
            return cipherText
        }
    }
}
// MARK: - Session management (a session expires after 10 seconds without use)
extension SDLContext {
    /// A direct path to a peer: its MAC, the address it is reachable at, and when it was last used.
    struct Session {
        /// MAC address of the peer.
        let dstMac: Data
        /// Address the peer's packets were received from.
        let natAddress: SocketAddress
        /// Unix time (seconds) of the last use.
        var lastTimestamp: Int32

        init(dstMac: Data, natAddress: SocketAddress) {
            self.dstMac = dstMac
            self.natAddress = natAddress
            self.lastTimestamp = Int32(Date().timeIntervalSince1970)
        }

        mutating func updateLastTimestamp(_ lastTimestamp: Int32) {
            self.lastTimestamp = lastTimestamp
        }
    }

    /// Actor-isolated table of sessions keyed by peer MAC, with idle expiry.
    actor SessionManager {
        private var sessions: [Data:Session] = [:]
        /// Maximum idle time of a session, in seconds.
        private let ttl: Int32 = 10

        /// Returns the session for `toAddress` when it was used within the last `ttl` seconds,
        /// refreshing its timestamp; an expired session is removed and nil is returned.
        func getSession(toAddress: Data) -> Session? {
            guard var session = self.sessions[toAddress] else {
                return nil
            }
            let now = Int32(Date().timeIntervalSince1970)
            // The original check (`lastTimestamp >= now + ttl`) could never be true for a
            // timestamp taken in the past, so every session was discarded on first lookup.
            guard now - session.lastTimestamp <= ttl else {
                self.sessions.removeValue(forKey: toAddress)
                return nil
            }
            session.updateLastTimestamp(now)
            self.sessions[toAddress] = session
            return session
        }

        /// Stores `session`, replacing any existing session for the same MAC.
        func addSession(session: Session) {
            self.sessions[session.dstMac] = session
        }

        /// Removes the session for `dstMac`, if any.
        func removeSession(dstMac: Data) {
            self.sessions.removeValue(forKey: dstMac)
        }
    }
}
// MARK: - ARP table (known IP-to-MAC mappings)
extension SDLContext {
    /// Actor-isolated table mapping IPv4 addresses (as `UInt32`) to MAC addresses.
    actor ArpServer {
        private var known_macs: [UInt32:Data] = [:]

        /// Creates the table pre-filled with `known_macs`.
        init(known_macs: [UInt32:Data]) {
            self.known_macs = known_macs
        }

        /// Returns the MAC recorded for `ip`, or nil when unknown.
        func query(ip: UInt32) -> Data? {
            let mac = known_macs[ip]
            return mac
        }

        /// Records (or overwrites) the MAC for `ip`.
        func append(ip: UInt32, mac: Data) {
            _ = known_macs.updateValue(mac, forKey: ip)
        }

        /// Forgets the entry for `ip`, if present.
        func remove(ip: UInt32) {
            known_macs[ip] = nil
        }

        /// Forgets every entry.
        func clear() {
            known_macs.removeAll()
        }
    }
}
// MARK: - Hole punching
extension SDLContext {
    /// Actor-isolated registry of hole-punching tasks, at most one per peer MAC.
    actor HolerManager {
        private var holers: [Data:Task<(), Never>] = [:]

        /// Registers a task for `dstMac` unless one that has not been cancelled is already
        /// registered; `creator` is only invoked when a new task is needed.
        func addHoler(dstMac: Data, creator: @escaping () -> Task<(), Never>) {
            if let existing = self.holers[dstMac], !existing.isCancelled {
                return
            }
            self.holers[dstMac] = creator()
        }

        /// Cancels every registered task and empties the registry.
        func cleanup() {
            holers.values.forEach { $0.cancel() }
            self.holers.removeAll()
        }
    }

    /// Creates a task that asks the super node about `dstMac` and, when an IPv4 address
    /// is returned, sends a register packet to that address.
    func holerTask(dstMac: Data) -> Task<(), Never> {
        return Task {
            guard let reply = try? await self.superClient?.queryInfo(context: self, dst_mac: dstMac) else {
                return
            }
            switch reply.packet {
            case .empty:
                SDLLogger.log("[SDLContext] hole query_info get empty: \(reply)", level: .debug)
            case .peerInfo(let peerInfo):
                guard let peerAddress = peerInfo.v4Info.socketAddress() else {
                    SDLLogger.log("[SDLContext] hole sock address is invalid: \(peerInfo.v4Info)", level: .warning)
                    return
                }
                SDLLogger.log("[SDLContext] hole sock address: \(peerAddress)", level: .warning)
                // Send a register packet to the peer's address.
                self.udpHole?.sendRegister(context: self, remoteAddress: peerAddress, dst_mac: dstMac)
            default:
                SDLLogger.log("[SDLContext] hole query_info is packet: \(reply)", level: .warning)
            }
        }
    }
}
// MARK: - NAT type detection
extension SDLContext {
    /// NAT classification reported by `getNatType()`.
    enum NatType: UInt8, Encodable {
        case blocked = 0
        case noNat = 1
        case fullCone = 2
        case portRestricted = 3
        case coneRestricted = 4
        case symmetric = 5
    }

    /// Sends one STUN probe to `remoteAddress` and returns the address carried in the reply,
    /// or nil when no reply arrives within the 5-second timeout passed to `stunProbe`.
    private func getNatAddress(remoteAddress: SocketAddress, attr: SDLProbeAttr) async -> SocketAddress? {
        let stunProbeReply = await self.udpHole?.stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: 5)
        return stunProbeReply?.socketAddress()
    }

    /// Probes the configured STUN servers to classify the NAT in front of this device.
    ///
    /// The probe needs two STUN servers with two ports each. When the configuration
    /// provides fewer, `.blocked` is returned instead of trapping on an out-of-range index.
    func getNatType() async -> NatType {
        let addressArray = config.stunProbeSocketAddressArray
        guard addressArray.count >= 2, !addressArray[0].isEmpty, addressArray[1].count >= 2 else {
            SDLLogger.log("[SDLContext] nat probe requires two stun servers with two ports each", level: .error)
            return .blocked
        }
        let primaryAddress = addressArray[0][0]
        let secondaryAddress = addressArray[1][1]

        // Step 1: ask server 1 (port 1) which address it sees us as.
        guard let natAddress1 = await getNatAddress(remoteAddress: primaryAddress, attr: .none) else {
            return .blocked
        }
        // The reported address equals our local address: no NAT in between.
        if natAddress1 == self.udpHole?.localAddress {
            return .noNat
        }
        // Step 2: ask server 2 (port 2) the same question.
        guard let natAddress2 = await getNatAddress(remoteAddress: secondaryAddress, attr: .none) else {
            return .blocked
        }
        // Different public IPs for different destinations are classified as symmetric.
        NSLog("nat_address1: \(natAddress1), nat_address2: \(natAddress2)")
        if let ipAddress1 = natAddress1.ipAddress, let ipAddress2 = natAddress2.ipAddress, ipAddress1 != ipAddress2 {
            return .symmetric
        }
        // Step 3: probe with `.peer`; a reply is classified as full cone.
        if let natAddress3 = await getNatAddress(remoteAddress: primaryAddress, attr: .peer) {
            NSLog("nat_address1: \(natAddress1), nat_address2: \(natAddress2), nat_address3: \(natAddress3)")
            return .fullCone
        }
        // Step 4: probe with `.port`; a reply is classified as cone restricted,
        // no reply as port restricted.
        if let natAddress4 = await getNatAddress(remoteAddress: primaryAddress, attr: .port) {
            NSLog("nat_address1: \(natAddress1), nat_address2: \(natAddress2), nat_address4: \(natAddress4)")
            return .coneRestricted
        } else {
            return .portRestricted
        }
    }
}
// MARK: - Client UUID and MAC address
extension SDLContext {
    /// Returns the client id stored under "gClientId", creating and storing a new one
    /// (a lowercase UUID without dashes) on first use.
    static func getUUID() -> String {
        let userDefaults = UserDefaults.standard
        if let uuid = userDefaults.value(forKey: "gClientId") as? String {
            return uuid
        }
        let uuid = UUID().uuidString.replacingOccurrences(of: "-", with: "").lowercased()
        userDefaults.setValue(uuid, forKey: "gClientId")
        return uuid
    }

    /// Returns the MAC address stored under "gMacAddress2", generating and storing a new
    /// one on first use. An already stored address is returned unchanged.
    static func getMacAddress() -> Data {
        let key = "gMacAddress2"
        let userDefaults = UserDefaults.standard
        if let mac = userDefaults.value(forKey: key) as? Data {
            return mac
        }
        let mac = generateMacAddress()
        userDefaults.setValue(mac, forKey: key)
        return mac
    }

    /// Generates a random 6-byte MAC address that is unicast and locally administered.
    ///
    /// A fully random first octet has the multicast bit set half of the time; receivers
    /// in this file accept frames by `isMulticast()`, so such an address would be treated
    /// as group traffic instead of identifying a single device.
    private static func generateMacAddress() -> Data {
        var octets = (0..<6).map { _ in UInt8.random(in: UInt8.min...UInt8.max) }
        // Clear bit 0 (multicast) and set bit 1 (locally administered) of the first octet.
        octets[0] = (octets[0] & 0xFE) | 0x02
        return Data(octets)
    }

    /// Formats `mac` as lowercase, colon-separated hex octets, e.g. "0a:1b:2c:3d:4e:5f".
    private static func formatMacAddress(mac: Data) -> String {
        return mac.map { String(format: "%02x", $0) }.joined(separator: ":")
    }
}