punchnet-macos/Tun/Punchnet/Actors/SDLContextNew.swift
2026-01-07 16:29:03 +08:00

696 lines
29 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//
// SDLContext.swift
// Tun
//
// Created by on 2024/2/29.
//
import Foundation
import NetworkExtension
import NIOCore
import Combine
//
/*
1. rsa
*/
/// Core runtime of the Punchnet packet tunnel.
///
/// Owns the long-lived network clients (DNS forwarder, UDP hole-punching socket,
/// super-node client, notice client, network monitor), restarts them when they
/// fail, and moves packets between the tunnel provider and the peer-to-peer /
/// super-node forwarding paths.
@available(macOS 14, *)
public class SDLContextNew {

    /// A destination/subnet pair; `debugInfo` renders it as `address:mask`.
    struct Route {
        let dstAddress: String
        let subnetMask: String

        var debugInfo: String {
            return "\(dstAddress):\(subnetMask)"
        }
    }

    /// A pending request to punch a hole towards `dstMac`; fed into a per-peer throttle.
    struct RegisterRequest {
        let srcMac: Data
        let dstMac: Data
        let networkId: UInt32
    }

    let config: SDLConfiguration

    /// Address of the tun device. The MAC comes from `getMacAddress()`; the rest
    /// is replaced by the super node in `registerSuperAck` / `changeNetwork`.
    var devAddr: SDLDevAddr

    // NAT address learned via STUN (currently unused).
    //var natAddress: SDLNatAddress?

    /// NAT type advertised in STUN requests. The probing code in this file is
    /// commented out, so nothing here changes it from `.blocked`.
    var natType: SDLNatProber.NatType = .blocked

    /// Symmetric cipher used to encrypt/decrypt data-plane packets.
    var aesCipher: AESCipher
    /// AES session key; empty until the super node delivers one (RSA-encrypted).
    var aesKey: Data = Data()
    /// RSA cipher whose public key is sent in `registerSuper` and which decrypts the AES key.
    let rsaCipher: RSACipher

    var udpHoleActor: SDLUDPHoleActor?
    var superClientActor: SDLSuperClientActor?
    var providerActor: SDLTunnelProviderActor
    /// Forwards DNS queries read from the tun device to `config.remoteDnsServer`.
    var dnsClient: DNSClient?

    /// Task pumping packets out of the tunnel packet flow (see `startReader()`).
    private var readTask: Task<(), Never>?
    private var sessionManager: SessionManager
    private var arpServer: ArpServer
    /// Cookie of the most recent STUN request; replies carrying another cookie are ignored.
    private var lastCookie: UInt32?
    private var monitor: SDLNetworkMonitor?
    /// Local client used to push notices (alerts, upgrade prompts).
    private var noticeClient: SDLNoticeClient?
    /// Traffic counters (inbound / p2p / forwarded bytes).
    private var flowTracer = SDLFlowTracerActor()
    private var flowTracerCancel: AnyCancellable?
    /// One throttled subject per destination MAC so hole punching towards a given
    /// peer is attempted at most once every 5 seconds. Guarded by `locker`.
    private var holerPublishers: [Data: PassthroughSubject<RegisterRequest, Never>] = [:]
    /// Subscriptions for `holerPublishers`. Guarded by `locker`.
    private var bag = Set<AnyCancellable>()
    private var locker = NSLock()
    private let logger: SDLLogger
    /// Parent task of all supervised subsystems; cancelled by `stop()`.
    private var rootTask: Task<Void, Error>?

    /// Port on the remote DNS server that queries are forwarded to.
    private static let dnsForwardPort = 15353

    public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) {
        self.logger = logger
        self.config = config
        self.rsaCipher = rsaCipher
        self.aesCipher = aesCipher
        // The MAC address is persisted across launches (see getMacAddress()).
        var devAddr = SDLDevAddr()
        devAddr.mac = Self.getMacAddress()
        self.devAddr = devAddr
        self.sessionManager = SessionManager()
        self.arpServer = ArpServer(known_macs: [:])
        self.providerActor = SDLTunnelProviderActor(provider: provider, logger: logger)
    }

    // MARK: - Lifecycle

    /// Starts every subsystem and suspends until they all finish or one of them
    /// throws (e.g. `CancellationError` after `stop()`).
    ///
    /// Each network client runs in its own restart loop: on error it is logged
    /// and restarted after 2 seconds.
    public func start() async throws {
        self.rootTask = Task {
            try await withThrowingTaskGroup(of: Void.self) { group in
                group.addTask {
                    while !Task.isCancelled {
                        do {
                            try await self.startDnsClient()
                        } catch let err {
                            self.logger.log("[SDLContext] DNSClient get err: \(err)", level: .warning)
                            try await Task.sleep(for: .seconds(2))
                        }
                    }
                }
                group.addTask {
                    while !Task.isCancelled {
                        do {
                            try await self.startUDPHole()
                        } catch let err {
                            self.logger.log("[SDLContext] UDPHole get err: \(err)", level: .warning)
                            try await Task.sleep(for: .seconds(2))
                        }
                    }
                }
                group.addTask {
                    while !Task.isCancelled {
                        do {
                            try await self.startSuperClient()
                        } catch let err {
                            self.logger.log("[SDLContext] SuperClient get error: \(err), will restart", level: .warning)
                            // Learned IP->MAC mappings are dropped whenever the super client restarts.
                            await self.arpServer.clear()
                            try await Task.sleep(for: .seconds(2))
                        }
                    }
                }
                group.addTask {
                    await self.startMonitor()
                }
                group.addTask {
                    while !Task.isCancelled {
                        do {
                            try await self.startNoticeClient()
                        } catch let err {
                            self.logger.log("[SDLContext] noticeClient get err: \(err)", level: .warning)
                            try await Task.sleep(for: .seconds(2))
                        }
                    }
                }
                try await group.waitForAll()
            }
        }
        try await self.rootTask?.value
    }

    /// Cancels all subsystems, stops the packet reader and releases the clients.
    public func stop() async {
        self.rootTask?.cancel()
        self.readTask?.cancel()
        self.superClientActor = nil
        self.udpHoleActor = nil
        self.noticeClient = nil
        self.dnsClient = nil
        self.cancelPendingHoles()
    }

    // MARK: - Subsystems

    private func startNoticeClient() async throws {
        self.noticeClient = try await SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger)
        try await self.noticeClient?.start()
        self.logger.log("[SDLContext] notice_client task cancel", level: .warning)
    }

    /// Runs one UDP hole-punching session: the socket itself, a STUN request
    /// every 5 seconds, and the inbound event handler. The first of these three
    /// to finish cancels the other two.
    private func startUDPHole() async throws {
        self.udpHoleActor = try await SDLUDPHoleActor(logger: self.logger)

        try await withThrowingTaskGroup(of: Void.self) { group in
            group.addTask {
                try await self.udpHoleActor?.start()
            }
            group.addTask {
                while !Task.isCancelled {
                    // Task.sleep throws CancellationError when the group is cancelled.
                    try await Task.sleep(for: .seconds(5))
                    try Task.checkCancellation()
                    try await self.sendStunRequest()
                }
            }
            group.addTask {
                if let eventFlow = self.udpHoleActor?.eventFlow {
                    for try await event in eventFlow {
                        try Task.checkCancellation()
                        try await self.handleUDPEvent(event: event)
                    }
                }
            }
            if let _ = try await group.next() {
                group.cancelAll()
            }
        }
    }

    /// Sends one STUN request through the current UDP socket and records its cookie
    /// so the matching `stunReply` can be recognised. No-op when there is no socket.
    private func sendStunRequest() async throws {
        guard let udpHoleActor = self.udpHoleActor else {
            return
        }
        let cookie = await udpHoleActor.getCookieId()

        var stunRequest = SDLStunRequest()
        stunRequest.cookie = cookie
        stunRequest.clientID = self.config.clientId
        stunRequest.networkID = self.devAddr.networkID
        stunRequest.ip = self.devAddr.netAddr
        stunRequest.mac = self.devAddr.mac
        stunRequest.natType = UInt32(self.natType.rawValue)

        await udpHoleActor.send(type: .stunRequest, data: try stunRequest.serializedData(), remoteAddress: self.config.stunSocketAddress)
        self.lastCookie = cookie
    }

    /// Runs one super-node session: the connection plus its event handler. The
    /// first of the two to finish cancels the other.
    private func startSuperClient() async throws {
        self.superClientActor = try await SDLSuperClientActor(host: self.config.superHost, port: self.config.superPort, logger: self.logger)
        try await withThrowingTaskGroup(of: Void.self) { group in
            defer {
                self.logger.log("[SDLContext] super client task cancel", level: .warning)
            }
            group.addTask {
                try await self.superClientActor?.start()
            }
            group.addTask {
                if let eventFlow = self.superClientActor?.eventFlow {
                    for try await event in eventFlow {
                        try await self.handleSuperEvent(event: event)
                    }
                }
            }
            if let _ = try await group.next() {
                group.cancelAll()
            }
        }
    }

    /// Logs network path changes until the monitor's event stream ends.
    private func startMonitor() async {
        let monitor = SDLNetworkMonitor()
        self.monitor = monitor
        for await event in monitor.eventStream {
            switch event {
            case .changed:
                // TODO: re-probe the NAT type when the network path changes.
                //self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config, logger: self.logger)
                self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info)
            case .unreachable:
                self.logger.log("didNetworkPathUnreachable", level: .warning)
            }
        }
    }

    /// Runs one DNS forwarding session: the client plus a pump that writes DNS
    /// responses back into the tun device as IPv4 packets.
    private func startDnsClient() async throws {
        let remoteDnsServer = config.remoteDnsServer
        let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(remoteDnsServer, port: Self.dnsForwardPort)
        self.dnsClient = try await DNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger)

        try await withThrowingTaskGroup(of: Void.self) { group in
            defer {
                self.logger.log("[SDLContext] dns client task cancel", level: .warning)
            }
            group.addTask {
                try await self.dnsClient?.start()
            }
            group.addTask {
                if let packetFlow = self.dnsClient?.packetFlow {
                    for await packet in packetFlow {
                        let nePacket = NEPacket(data: packet, protocolFamily: 2)
                        await self.providerActor.writePackets(packets: [nePacket])
                    }
                }
            }
            if let _ = try await group.next() {
                group.cancelAll()
            }
        }
    }

    // MARK: - Super-node events

    /// Handles one event from the super node.
    ///
    /// - Throws: serialization/request errors and RSA decoding failures; these end
    ///   the current super-client session, which is then restarted by `start()`.
    private func handleSuperEvent(event: SDLSuperClientActor.SuperEvent) async throws {
        switch event {
        case .ready:
            self.logger.log("[SDLContext] get registerSuper, mac address: \(SDLUtil.formatMacAddress(mac: self.devAddr.mac))", level: .debug)
            var registerSuper = SDLRegisterSuper()
            registerSuper.version = UInt32(self.config.version)
            registerSuper.clientID = self.config.clientId
            registerSuper.devAddr = self.devAddr
            registerSuper.pubKey = self.rsaCipher.pubKey
            registerSuper.token = self.config.token
            registerSuper.networkCode = self.config.networkCode
            registerSuper.hostname = self.config.hostname
            guard let message = try await self.superClientActor?.request(type: .registerSuper, data: try registerSuper.serializedData()).get() else {
                return
            }

            switch message.packet {
            case .registerSuperAck(let registerSuperAck):
                // The AES key arrives encrypted with our RSA public key.
                let aesKey = try self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))
                let upgradeType = SDLUpgradeType(rawValue: registerSuperAck.upgradeType)
                self.logger.log("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count), network_id:\(registerSuperAck.devAddr.networkID)", level: .info)

                if upgradeType == .force {
                    let forceUpgrade = NoticeMessage.upgrade(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress)
                    await self.noticeClient?.send(data: forceUpgrade)
                    exit(-1)
                }

                await self.applyNetworkAssignment(devAddr: registerSuperAck.devAddr, aesKey: aesKey)

                if upgradeType == .normal {
                    let normalUpgrade = NoticeMessage.upgrade(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress)
                    await self.noticeClient?.send(data: normalUpgrade)
                }

            case .registerSuperNak(let nakPacket):
                let errorMessage = nakPacket.errorMessage
                // UInt8(exactly:) instead of UInt8(_:), which would trap on an out-of-range code.
                guard let rawCode = UInt8(exactly: nakPacket.errorCode),
                      let errorCode = SDLNAKErrorCode(rawValue: rawCode) else {
                    return
                }
                switch errorCode {
                case .invalidToken, .nodeDisabled:
                    let alertNotice = NoticeMessage.alert(alert: errorMessage)
                    await self.noticeClient?.send(data: alertNotice)
                    exit(-1)
                case .noIpAddress, .networkFault, .internalFault:
                    let alertNotice = NoticeMessage.alert(alert: errorMessage)
                    await self.noticeClient?.send(data: alertNotice)
                }
                self.logger.log("[SDLContext] Get a SuperNak message exit", level: .warning)

            default:
                break
            }

        case .event(let evt):
            switch evt {
            case .natChanged(let natChangedEvent):
                let dstMac = natChangedEvent.mac
                self.logger.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info)
                await sessionManager.removeSession(dstMac: dstMac)

            case .sendRegister(let sendRegisterEvent):
                self.logger.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)", level: .debug)
                let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
                if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) {
                    var register = SDLRegister()
                    register.networkID = self.devAddr.networkID
                    register.srcMac = self.devAddr.mac
                    register.dstMac = sendRegisterEvent.dstMac
                    await self.udpHoleActor?.send(type: .register, data: try register.serializedData(), remoteAddress: remoteAddress)
                }

            case .networkShutdown(let shutdownEvent):
                let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message)
                await self.noticeClient?.send(data: alertNotice)
                exit(-1)
            }

        case .command(let packetId, let command):
            switch command {
            case .changeNetwork(let changeNetworkCommand):
                let aesKey = try self.rsaCipher.decode(data: Data(changeNetworkCommand.aesKey))
                self.logger.log("[SDLContext] change network command get aes_key len: \(aesKey.count)", level: .info)

                await self.applyNetworkAssignment(devAddr: changeNetworkCommand.devAddr, aesKey: aesKey)

                var commandAck = SDLCommandAck()
                commandAck.status = true
                await self.superClientActor?.send(type: .commandAck, packetId: packetId, data: try commandAck.serializedData())
            }
        }
    }

    /// Installs a device address and AES key assigned by the super node, then
    /// reconfigures the tun interface and (re)starts the packet reader. The
    /// process exits if the tunnel settings cannot be applied.
    private func applyNetworkAssignment(devAddr: SDLDevAddr, aesKey: Data) async {
        self.devAddr = devAddr
        // Install the key before the reader starts, so outbound packets are never
        // encrypted with the previous (possibly empty) key.
        self.aesKey = aesKey
        do {
            try await self.providerActor.setNetworkSettings(devAddr: self.devAddr, dnsServer: DNSClient.Helper.dnsServer)
            self.startReader()
        } catch let err {
            self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error)
            exit(-1)
        }
    }

    // MARK: - UDP events

    /// Handles one event from the UDP hole-punching socket: peer registration,
    /// STUN replies and encrypted data packets.
    private func handleUDPEvent(event: SDLUDPHoleActor.UDPEvent) async throws {
        switch event {
        case .ready:
            // TODO: probe the NAT type once the socket is ready.
            //self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config)
            self.logger.log("[SDLContext] nat type is: \(self.natType)", level: .debug)

        case .message(let remoteAddress, let message):
            switch message {
            case .register(let register):
                self.logger.log("register packet: \(register), dev_addr: \(self.devAddr)", level: .debug)
                // Only accept registrations addressed to this device in this network.
                if register.dstMac == self.devAddr.mac && register.networkID == self.devAddr.networkID {
                    var registerAck = SDLRegisterAck()
                    registerAck.networkID = self.devAddr.networkID
                    registerAck.srcMac = self.devAddr.mac
                    registerAck.dstMac = register.srcMac
                    await self.udpHoleActor?.send(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)
                    // The sender is reachable directly at the address it came from.
                    let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
                    await self.sessionManager.addSession(session: session)
                } else {
                    self.logger.log("SDLContext didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)", level: .warning)
                }

            case .registerAck(let registerAck):
                if registerAck.dstMac == self.devAddr.mac && registerAck.networkID == self.devAddr.networkID {
                    let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
                    await self.sessionManager.addSession(session: session)
                } else {
                    self.logger.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning)
                }

            case .stunReply(let stunReply):
                if stunReply.cookie == self.lastCookie {
                    //self.natAddress = stunReply.natAddress
                    let replyDescription = (try? stunReply.jsonString()) ?? String(describing: stunReply)
                    self.logger.log("[SDLContext] get a stunReply: \(replyDescription)", level: .debug)
                }

            default:
                break
            }

        case .data(let sdlData):
            // Accept frames addressed to us, or broadcast/multicast frames.
            let mac = LayerPacket.MacAddress(data: sdlData.dstMac)
            guard sdlData.dstMac == self.devAddr.mac || mac.isBroadcast() || mac.isMulticast() else {
                return
            }
            guard let decryptedData = try? self.aesCipher.decypt(aesKey: self.aesKey, data: Data(sdlData.data)) else {
                return
            }

            do {
                let layerPacket = try LayerPacket(layerData: decryptedData)
                await self.flowTracer.inc(num: decryptedData.count, type: .inbound)

                switch layerPacket.type {
                case .arp:
                    guard let arpPacket = ARPPacket(data: layerPacket.data) else {
                        self.logger.log("[SDLContext] get invalid arp packet", level: .debug)
                        return
                    }
                    guard arpPacket.targetIP == self.devAddr.netAddr else {
                        self.logger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(self.devAddr.netAddr))", level: .debug)
                        return
                    }
                    switch arpPacket.opcode {
                    case .request:
                        self.logger.log("[SDLContext] get arp request packet", level: .debug)
                        let response = ARPPacket.arpResponse(for: arpPacket, mac: self.devAddr.mac, ip: self.devAddr.netAddr)
                        await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
                    case .response:
                        self.logger.log("[SDLContext] get arp response packet", level: .debug)
                        await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
                    }

                case .ipv4:
                    // Only IPv4 packets destined to our tun address are delivered.
                    guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == self.devAddr.netAddr else {
                        return
                    }
                    let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
                    await self.providerActor.writePackets(packets: [packet])

                default:
                    self.logger.log("[SDLContext] get invalid packet", level: .debug)
                }
            } catch let err {
                self.logger.log("[SDLContext] didReadData err: \(err)", level: .warning)
            }
        }
    }

    // Disabled: periodic flow report. It still references an older `superClient`
    // API and would need porting to `superClientActor` before re-enabling.
    // public func flowReportTask() {
    //     Task {
    //         self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect()
    //             .sink { _ in
    //                 Task {
    //                     let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset()
    //                     await self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
    //                 }
    //             }
    //     }
    // }

    // MARK: - Outbound packets

    /// (Re)starts the task that reads packets from the tunnel packet flow.
    /// Any previous reader is cancelled, and the loop honours cancellation so
    /// only one reader is ever active.
    private func startReader() {
        self.readTask?.cancel()
        self.readTask = Task(priority: .high) {
            while !Task.isCancelled {
                let packets = await self.providerActor.readPackets()
                for packet in packets {
                    await self.dealPacket(data: packet)
                }
            }
        }
    }

    /// Dispatches one packet read from the tun device.
    ///
    /// DNS queries go to the DNS forwarder. Packets to our own address are looped
    /// back. Other packets are sent to the peer whose MAC is known via ARP.
    /// Packets are processed inline (not in detached tasks) so their order is preserved.
    private func dealPacket(data: Data) async {
        guard let packet = IPPacket(data) else {
            return
        }

        if DNSClient.Helper.isDnsRequestPacket(ipPacket: packet) {
            self.logger.log("destIp: \(packet.header.destination_ip), int: \(packet.header.destination)", level: .debug)
            await self.dnsClient?.forward(ipPacket: packet)
            return
        }

        let dstIp = packet.header.destination
        if dstIp == self.devAddr.netAddr {
            let nePacket = NEPacket(data: packet.data, protocolFamily: 2)
            await self.providerActor.writePackets(packets: [nePacket])
            return
        }

        if let dstMac = await self.arpServer.query(ip: dstIp) {
            await self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data)
        } else {
            // MAC unknown: broadcast an ARP request. This packet itself is not sent.
            let broadcastMac = Data([0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF])
            let arpRequest = ARPPacket.arpRequest(senderIP: self.devAddr.netAddr, senderMAC: self.devAddr.mac, targetIP: dstIp)
            await self.routeLayerPacket(dstMac: broadcastMac, type: .arp, data: arpRequest.marshal())
            self.logger.log("[SDLContext] dstIp: \(dstIp) arp query not found", level: .debug)
        }
    }

    /// Wraps `data` in a layer-2 frame, encrypts it and sends it.
    ///
    /// The frame goes directly to the peer when a P2P session exists. Otherwise it
    /// is relayed via `config.stunSocketAddress`, and a throttled hole-punch
    /// attempt towards `dstMac` is scheduled.
    private func routeLayerPacket(dstMac: Data, type: LayerPacket.PacketType, data: Data) async {
        let layerPacket = LayerPacket(dstMac: dstMac, srcMac: self.devAddr.mac, type: type, data: data)
        guard let encodedPacket = try? self.aesCipher.encrypt(aesKey: self.aesKey, data: layerPacket.marshal()) else {
            return
        }

        var dataPacket = SDLData()
        dataPacket.networkID = self.devAddr.networkID
        dataPacket.srcMac = self.devAddr.mac
        dataPacket.dstMac = dstMac
        dataPacket.ttl = 255
        dataPacket.data = encodedPacket
        guard let payload = try? dataPacket.serializedData() else {
            self.logger.log("[SDLContext] serialize data packet failed", level: .warning)
            return
        }

        if let session = await self.sessionManager.getSession(toAddress: dstMac) {
            self.logger.log("[SDLContext] send packet by session: \(session)", level: .debug)
            await self.udpHoleActor?.send(type: .data, data: payload, remoteAddress: session.natAddress)
            await self.flowTracer.inc(num: data.count, type: .p2p)
        } else {
            await self.udpHoleActor?.send(type: .data, data: payload, remoteAddress: self.config.stunSocketAddress)
            await self.flowTracer.inc(num: data.count, type: .forward)
            let registerRequest = RegisterRequest(srcMac: self.devAddr.mac, dstMac: dstMac, networkId: self.devAddr.networkID)
            self.submitRegisterRequest(request: registerRequest)
        }
    }

    // MARK: - Hole punching

    /// Feeds `request` into the per-destination throttle. The first request for a
    /// peer triggers a hole-punch attempt immediately; later ones at most every 5 s.
    private func submitRegisterRequest(request: RegisterRequest) {
        let publisher = self.holerPublisher(for: request.dstMac)
        // Sent outside the lock; the throttle delivers on DispatchQueue.global().
        publisher.send(request)
    }

    /// Returns the throttled subject for `dstMac`, creating and subscribing it on first use.
    private func holerPublisher(for dstMac: Data) -> PassthroughSubject<RegisterRequest, Never> {
        self.locker.lock()
        defer {
            self.locker.unlock()
        }
        if let existing = self.holerPublishers[dstMac] {
            return existing
        }
        let publisher = PassthroughSubject<RegisterRequest, Never>()
        publisher.throttle(for: .seconds(5), scheduler: DispatchQueue.global(), latest: true)
            .sink { [weak self] request in
                // Weak capture: `bag` is owned by self, so a strong capture would leak it.
                guard let self else {
                    return
                }
                Task {
                    await self.tryHole(request: request)
                }
            }
            .store(in: &self.bag)
        self.holerPublishers[dstMac] = publisher
        return publisher
    }

    /// Drops all per-peer throttles and their subscriptions.
    private func cancelPendingHoles() {
        self.locker.lock()
        defer {
            self.locker.unlock()
        }
        self.holerPublishers.removeAll()
        self.bag.removeAll()
    }

    /// Asks the super node for the peer's public address and, if one is
    /// returned, sends it a `register` packet to open a direct path.
    private func tryHole(request: RegisterRequest) async {
        var queryInfo = SDLQueryInfo()
        queryInfo.dstMac = request.dstMac
        guard let message = try? await self.superClientActor?.request(type: .queryInfo, data: try queryInfo.serializedData()).get() else {
            return
        }

        switch message.packet {
        case .empty:
            self.logger.log("[SDLContext] hole query_info get empty: \(message)", level: .debug)
        case .peerInfo(let peerInfo):
            guard let remoteAddress = peerInfo.v4Info.socketAddress() else {
                self.logger.log("[SDLContext] hole sock address is invalid: \(peerInfo.v4Info)", level: .warning)
                return
            }
            self.logger.log("[SDLContext] hole sock address: \(remoteAddress)", level: .debug)
            var register = SDLRegister()
            register.networkID = request.networkId
            register.srcMac = request.srcMac
            register.dstMac = request.dstMac
            guard let payload = try? register.serializedData() else {
                return
            }
            await self.udpHoleActor?.send(type: .register, data: payload, remoteAddress: remoteAddress)
        default:
            self.logger.log("[SDLContext] hole query_info is packet: \(message)", level: .warning)
        }
    }

    deinit {
        self.rootTask?.cancel()
        self.readTask?.cancel()
        self.udpHoleActor = nil
        self.superClientActor = nil
        self.dnsClient = nil
    }

    // MARK: - MAC address

    /// Returns this device's MAC address, persisted in `UserDefaults` under
    /// "gMacAddress2". A new one is generated when none is stored, or when the
    /// stored value is not 6 bytes long.
    public static func getMacAddress() -> Data {
        let key = "gMacAddress2"
        let userDefaults = UserDefaults.standard
        if let mac = userDefaults.data(forKey: key), mac.count == 6 {
            return mac
        }
        let mac = generateMacAddress()
        userDefaults.set(mac, forKey: key)
        return mac
    }

    /// Generates a random 6-byte, locally administered, unicast MAC address.
    private static func generateMacAddress() -> Data {
        var octets = (0..<6).map { _ in UInt8.random(in: .min ... .max) }
        // Clear the I/G bit (unicast) and set the U/L bit (locally administered).
        // A multicast-looking address would be accepted by every peer, because
        // inbound frames are accepted for any multicast/broadcast destination.
        octets[0] = (octets[0] & 0xFE) | 0x02
        return Data(octets)
    }
}