punchnet-macos/Tun/Punchnet/SDLContextActor.swift
2026-02-04 00:07:06 +08:00

582 lines
23 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//
// SDLContextActor.swift
// Tun
//
// Created by on 2024/2/29.
//
import Foundation
import NetworkExtension
import NIOCore
//
/*
1. rsa
*/
/// Owns one tunnel session end to end: the notice client, the DNS client, the UDP hole to the
/// super node, NAT probing / hole punching, and the packet pump between the TUN interface and
/// peers.
///
/// Peer traffic is AES-encrypted with a key that the super node returns in `registerSuperAck`,
/// encrypted with this client's RSA public key (which is sent in `registerSuper`).
actor SDLContextActor {
    /// Tunnel configuration (network address, tokens, server endpoints).
    nonisolated let config: SDLConfiguration
    /// NAT type reported in STUN requests; `.blocked` until a NAT probe finishes.
    var natType: SDLNATProberActor.NatType = .blocked
    /// Cipher used together with `aesKey` for peer payloads.
    nonisolated let aesCipher: AESCipher
    /// AES key from `registerSuperAck`. While it is `nil`, outgoing packets are dropped and
    /// incoming data packets are ignored.
    private var aesKey: Data?
    /// RSA cipher: its public key is sent in `registerSuper`, and it decrypts the AES key in the ack.
    nonisolated let rsaCipher: RSACipher
    /// Current UDP socket plus the tasks consuming its ping timer, data and signal streams.
    private var udpHole: SDLUDPHole?
    private var udpHoleWorkers: [Task<Void, Never>]?
    /// Adapter over the owning `NEPacketTunnelProvider` (network settings and packet I/O).
    nonisolated let providerAdapter: SDLTunnelProviderAdapter
    /// Hole-punching helper; recreated whenever a new UDP hole becomes ready.
    var puncherActor: SDLPuncherActor?
    /// DNS client for intercepted DNS queries, and the task writing its replies to the TUN.
    private var dnsClient: SDLDNSClient?
    private var dnsWorker: Task<Void, Never>?
    /// NAT type prober; recreated whenever a new UDP hole becomes ready.
    var proberActor: SDLNATProberActor?
    /// NAT probe for the current UDP hole; cancelled when the hole is replaced or on `stop()`.
    private var natProbeTask: Task<Void, Never>?
    /// Task reading packets from the TUN interface (started after `registerSuperAck`).
    private var readTask: Task<(), Never>?
    /// P2P sessions (peer MAC -> NAT address).
    private var sessionManager: SessionManager
    /// IP -> MAC table filled from ARP responses.
    private var arpServer: ArpServer
    /// Network path monitor and the task consuming its events.
    private var monitor: SDLNetworkMonitor?
    private var monitorWorker: Task<Void, Never>?
    /// Client that sends notices (assigned IP, alerts) to `config.noticePort`.
    private var noticeClient: SDLNoticeClient?
    /// Byte counters for inbound, P2P and forwarded traffic.
    nonisolated private let flowTracer = SDLFlowTracer()
    /// Supervisor loops started by `start()`; see `spawnLoop(_:)`.
    private var loopChildWorkers: [Task<Void, Never>] = []

    /// Protocol family passed to `NEPacket` for the IPv4 packets written to the TUN (`AF_INET`).
    private static let ipv4ProtocolFamily = sa_family_t(AF_INET)

    /// Creates the context. Call `start()` to bring up the monitor and the supervised clients.
    public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher) {
        self.config = config
        self.rsaCipher = rsaCipher
        self.aesCipher = aesCipher
        self.sessionManager = SessionManager()
        self.arpServer = ArpServer(known_macs: [:])
        self.providerAdapter = SDLTunnelProviderAdapter(provider: provider)
    }

    /// Starts the network monitor and three supervisor loops (notice client, DNS client,
    /// UDP hole). Each loop re-creates its client whenever the previous one closes or fails.
    public func start() {
        self.startMonitor()

        self.loopChildWorkers.append(spawnLoop {
            let noticeClient = try self.startNoticeClient()
            SDLLogger.shared.log("[SDLContext] noticeClient running!!!!")
            try await noticeClient.waitClose()
            SDLLogger.shared.log("[SDLContext] noticeClient closed!!!!")
        })

        self.loopChildWorkers.append(spawnLoop {
            let dnsClient = try await self.startDnsClient()
            SDLLogger.shared.log("[SDLContext] dns running!!!!")
            try await dnsClient.waitClose()
            SDLLogger.shared.log("[SDLContext] dns closed!!!!")
        })

        self.loopChildWorkers.append(spawnLoop {
            let udpHole = try await self.startUDPHole()
            SDLLogger.shared.log("[SDLContext] udp running!!!!")
            try await udpHole.waitClose()
            SDLLogger.shared.log("[SDLContext] udp closed!!!!")
        })
    }

    /// Creates and starts the notice client, stores it, and returns it.
    private func startNoticeClient() throws -> SDLNoticeClient {
        let noticeClient = try SDLNoticeClient(noticePort: self.config.noticePort, logger: SDLLogger.shared)
        noticeClient.start()
        SDLLogger.shared.log("[SDLContext] noticeClient started")
        self.noticeClient = noticeClient
        return noticeClient
    }

    /// Starts the network path monitor and a task that logs its path events.
    private func startMonitor() {
        self.monitorWorker?.cancel()
        self.monitorWorker = nil

        let monitor = SDLNetworkMonitor()
        monitor.start()
        SDLLogger.shared.log("[SDLContext] monitor started")
        self.monitor = monitor

        self.monitorWorker = Task.detached {
            for await event in monitor.eventStream {
                switch event {
                case .changed:
                    // NAT re-probing on path change is currently disabled.
                    //self.natType = await self.getNatType()
                    SDLLogger.shared.log("didNetworkPathChanged, nat type is:", level: .info)
                case .unreachable:
                    SDLLogger.shared.log("didNetworkPathUnreachable", level: .warning)
                }
            }
        }
    }

    /// Creates the DNS client for `config.remoteDnsServer` on port 15353. Also starts a task that
    /// writes every reply packet from the client back to the TUN as IPv4.
    private func startDnsClient() async throws -> SDLDNSClient {
        self.dnsWorker?.cancel()
        self.dnsWorker = nil

        let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(self.config.remoteDnsServer, port: 15353)
        let dnsClient = try await SDLDNSClient(dnsServerAddress: dnsSocketAddress, logger: SDLLogger.shared)
        try dnsClient.start()
        SDLLogger.shared.log("[SDLContext] dnsClient started")
        self.dnsClient = dnsClient

        self.dnsWorker = Task.detached {
            // Replies from the remote DNS server are injected into the TUN interface.
            for await packet in dnsClient.packetFlow {
                if Task.isCancelled {
                    break
                }
                let nePacket = NEPacket(data: packet, protocolFamily: Self.ipv4ProtocolFamily)
                self.providerAdapter.writePackets(packets: [nePacket])
            }
        }
        return dnsClient
    }

    /// Replaces the UDP hole: cancels the previous hole's consumers and opens a new socket.
    /// Once the channel is active, it registers with the super node. It then starts three
    /// consumers: a STUN ping timer, a data-packet handler and a signal dispatcher.
    private func startUDPHole() async throws -> SDLUDPHole {
        self.udpHoleWorkers?.forEach { $0.cancel() }
        self.udpHoleWorkers = nil

        let udpHole = try SDLUDPHole(logger: SDLLogger.shared)
        try udpHole.start()
        SDLLogger.shared.log("[SDLContext] udpHole started")
        self.udpHole = udpHole

        await udpHole.channelIsActived()
        await self.handleUDPHoleReady()

        // Periodic STUN request, driven by SDLAsyncTimerStream.
        let pingTask = Task.detached {
            let (stream, cont) = AsyncStream.makeStream(of: Void.self)
            let timerStream = SDLAsyncTimerStream()
            timerStream.start(cont)
            for await _ in stream {
                if Task.isCancelled {
                    break
                }
                SDLLogger.shared.log("[SDLContext] will do stunRequest22")
                await self.sendStunRequest()
                SDLLogger.shared.log("[SDLContext] will do stunRequest44")
            }
            SDLLogger.shared.log("[SDLContext] will do stunRequest55")
        }

        // Encrypted peer data. Malformed or undecodable packets are dropped on purpose (best effort).
        let dataTask = Task.detached {
            for await data in udpHole.dataStream {
                if Task.isCancelled {
                    break
                }
                try? await self.handleData(data: data)
            }
        }

        // Control messages from the super node and from peers.
        let signalTask = Task.detached {
            for await (remoteAddress, signal) in udpHole.signalStream {
                if Task.isCancelled {
                    break
                }
                switch signal {
                case .registerSuperAck(let registerSuperAck):
                    await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
                case .registerSuperNak(let registerSuperNak):
                    await self.handleRegisterSuperNak(nakPacket: registerSuperNak)
                case .peerInfo(let peerInfo):
                    await self.puncherActor?.handlePeerInfo(peerInfo: peerInfo)
                case .event(let event):
                    do {
                        try await self.handleEvent(event: event)
                    } catch {
                        SDLLogger.shared.log("[SDLContext] handleEvent failed: \(error)", level: .warning)
                    }
                case .stunProbeReply(let probeReply):
                    await self.proberActor?.handleProbeReply(reply: probeReply)
                case .register(let register):
                    do {
                        try await self.handleRegister(remoteAddress: remoteAddress, register: register)
                    } catch {
                        SDLLogger.shared.log("[SDLContext] handleRegister failed: \(error)", level: .warning)
                    }
                case .registerAck(let registerAck):
                    await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
                }
            }
        }

        self.udpHoleWorkers = [pingTask, dataTask, signalTask]
        return udpHole
    }

    /// Cancels every task owned by the context: supervisor loops, UDP hole consumers, the NAT
    /// probe, the DNS, monitor and TUN-reader tasks.
    public func stop() async {
        self.loopChildWorkers.forEach { $0.cancel() }
        self.loopChildWorkers.removeAll()
        self.udpHoleWorkers?.forEach { $0.cancel() }
        self.udpHoleWorkers = nil
        self.natProbeTask?.cancel()
        self.natProbeTask = nil
        self.dnsWorker?.cancel()
        self.dnsWorker = nil
        self.monitorWorker?.cancel()
        self.monitorWorker = nil
        self.readTask?.cancel()
        self.readTask = nil
    }

    /// Actor-isolated setter used from the detached NAT probe task.
    private func setNatType(natType: SDLNATProberActor.NatType) {
        self.natType = natType
    }

    /// Called once the UDP channel is active. Recreates the puncher and prober for the new
    /// socket, starts a NAT probe, and sends `registerSuper` to the super node.
    private func handleUDPHoleReady() async {
        guard let udpHole = self.udpHole else {
            return
        }

        self.puncherActor = SDLPuncherActor(udpHole: udpHole, querySocketAddress: config.stunSocketAddress)
        self.proberActor = SDLNATProberActor(udpHole: udpHole, addressArray: self.config.stunProbeSocketAddressArray)

        // Probe the NAT type in the background; a probe for a previous hole is no longer relevant.
        self.natProbeTask?.cancel()
        self.natProbeTask = Task.detached {
            if let natType = await self.proberActor?.probeNatType() {
                await self.setNatType(natType: natType)
                SDLLogger.shared.log("[SDLContext] nat_type is: \(natType)")
            }
        }

        // Register with the super node; the RSA public key lets it return an encrypted AES key.
        var registerSuper = SDLRegisterSuper()
        registerSuper.pktID = 0
        registerSuper.clientID = self.config.clientId
        registerSuper.networkID = self.config.networkAddress.networkId
        registerSuper.mac = self.config.networkAddress.mac
        registerSuper.ip = self.config.networkAddress.ip
        registerSuper.maskLen = UInt32(self.config.networkAddress.maskLen)
        registerSuper.hostname = self.config.hostname
        registerSuper.pubKey = self.rsaCipher.pubKey
        registerSuper.accessToken = self.config.accessToken
        if let registerSuperData = try? registerSuper.serializedData() {
            SDLLogger.shared.log("[SDLContext] will send register super")
            self.udpHole?.send(type: .registerSuper, data: registerSuperData, remoteAddress: self.config.stunSocketAddress)
        }
    }

    /// Sends a STUN request with this client's identity and the current NAT type to the super node.
    private func sendStunRequest() {
        var stunRequest = SDLStunRequest()
        stunRequest.clientID = self.config.clientId
        stunRequest.networkID = self.config.networkAddress.networkId
        stunRequest.ip = self.config.networkAddress.ip
        stunRequest.mac = self.config.networkAddress.mac
        stunRequest.natType = UInt32(self.natType.rawValue)
        SDLLogger.shared.log("[SDLContext] will send stun request")
        if let stunData = try? stunRequest.serializedData() {
            let remoteAddress = self.config.stunSocketAddress
            self.udpHole?.send(type: .stunRequest, data: stunData, remoteAddress: remoteAddress)
        }
    }

    /// Handles a successful registration with the super node:
    /// 1. decrypts the AES key with the RSA private key;
    /// 2. applies the tunnel network settings and reports the assigned IP via the notice client;
    /// 3. stores the key and starts the TUN reader.
    /// A key that cannot be decrypted is logged and the ack is ignored. A failure to apply the
    /// network settings terminates the process.
    private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async {
        let aesKey: Data
        do {
            aesKey = try self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))
        } catch {
            SDLLogger.shared.log("[SDLContext] failed to decrypt aes_key in registerSuperAck: \(error)", level: .error)
            return
        }
        SDLLogger.shared.log("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count)", level: .info)

        do {
            let ipAddress = try await self.providerAdapter.setNetworkSettings(networkAddress: self.config.networkAddress, dnsServer: SDLDNSClient.Helper.dnsServer)
            SDLLogger.shared.log("[SDLContext] setNetworkSettings successed")
            self.noticeClient?.send(data: NoticeMessage.ipAdress(ip: ipAddress))
            SDLLogger.shared.log("[SDLContext] send ip successed")
            // Install the key before the reader exists so no TUN packet is ever routed without it.
            self.aesKey = aesKey
            self.startReader()
            SDLLogger.shared.log("[SDLContext] reader started")
        } catch let err {
            SDLLogger.shared.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error)
            exit(-1)
        }
    }

    /// Handles a rejected registration. The server's message is forwarded as an alert. For
    /// `.invalidToken` / `.nodeDisabled` the process exits; the other codes are only reported.
    private func handleRegisterSuperNak(nakPacket: SDLRegisterSuperNak) {
        let errorMessage = nakPacket.errorMessage
        // `UInt8(exactly:)` rather than `UInt8(_:)`: an out-of-range code from the wire must not trap.
        guard let rawCode = UInt8(exactly: nakPacket.errorCode),
              let errorCode = SDLNAKErrorCode(rawValue: rawCode) else {
            SDLLogger.shared.log("[SDLContext] Get a SuperNak message with unknown error code: \(nakPacket.errorCode)", level: .warning)
            return
        }

        let alertNotice = NoticeMessage.alert(alert: errorMessage)
        self.noticeClient?.send(data: alertNotice)
        switch errorCode {
        case .invalidToken, .nodeDisabled:
            exit(-1)
        case .noIpAddress, .networkFault, .internalFault:
            break
        }
        SDLLogger.shared.log("[SDLContext] Get a SuperNak message exit", level: .warning)
    }

    /// Handles super-node events:
    /// - NAT changes drop the peer's session.
    /// - `sendRegister` sends a `register` packet to the peer's NAT address.
    /// - `networkShutdown` alerts the user and exits.
    /// - Throws: errors from serializing the `register` packet.
    private func handleEvent(event: SDLEvent) throws {
        switch event {
        case .natChanged(let natChangedEvent):
            let dstMac = natChangedEvent.mac
            SDLLogger.shared.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info)
            sessionManager.removeSession(dstMac: dstMac)
        case .sendRegister(let sendRegisterEvent):
            SDLLogger.shared.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)", level: .debug)
            let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
            if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) {
                var register = SDLRegister()
                register.networkID = self.config.networkAddress.networkId
                register.srcMac = self.config.networkAddress.mac
                register.dstMac = sendRegisterEvent.dstMac
                self.udpHole?.send(type: .register, data: try register.serializedData(), remoteAddress: remoteAddress)
            }
        case .networkShutdown(let shutdownEvent):
            let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message)
            self.noticeClient?.send(data: alertNotice)
            exit(-1)
        }
    }

    /// Handles a peer's `register` that targets this node. It replies with `registerAck` and
    /// records a P2P session at the address the packet came from. Packets for another MAC or
    /// network are logged and ignored.
    /// - Throws: errors from serializing the `registerAck` packet.
    private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws {
        let networkAddr = config.networkAddress
        SDLLogger.shared.log("register packet: \(register), network_address: \(networkAddr)", level: .debug)
        guard register.dstMac == networkAddr.mac && register.networkID == networkAddr.networkId else {
            SDLLogger.shared.log("SDLContext didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)", level: .warning)
            return
        }

        var registerAck = SDLRegisterAck()
        registerAck.networkID = networkAddr.networkId
        registerAck.srcMac = networkAddr.mac
        registerAck.dstMac = register.srcMac
        self.udpHole?.send(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)

        // The packet reached us directly, so its source is the peer's reachable NAT address.
        let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
        self.sessionManager.addSession(session: session)
    }

    /// Handles a peer's `registerAck` addressed to this node by recording a P2P session at the
    /// ack's source address. Packets for another MAC or network are logged and ignored.
    private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) {
        let networkAddr = config.networkAddress
        guard registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId else {
            SDLLogger.shared.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning)
            return
        }
        let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
        self.sessionManager.addSession(session: session)
    }

    /// Decrypts an incoming data packet addressed to this node (unicast, broadcast or multicast)
    /// and dispatches its layer-2 payload:
    /// - ARP requests for our IP get a reply.
    /// - ARP responses update the ARP table.
    /// - IPv4 packets for our IP are written to the TUN.
    /// Packets arriving before the AES key is known, or that fail decryption, are dropped.
    /// - Throws: errors from parsing the decrypted layer packet.
    private func handleData(data: SDLData) throws {
        guard let aesKey = self.aesKey else {
            return
        }
        let mac = LayerPacket.MacAddress(data: data.dstMac)
        let networkAddr = config.networkAddress
        guard data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast() else {
            return
        }
        guard let decryptedData = try? self.aesCipher.decypt(aesKey: aesKey, data: Data(data.data)) else {
            return
        }
        let layerPacket = try LayerPacket(layerData: decryptedData)
        self.flowTracer.inc(num: decryptedData.count, type: .inbound)

        switch layerPacket.type {
        case .arp:
            guard let arpPacket = ARPPacket(data: layerPacket.data) else {
                SDLLogger.shared.log("[SDLContext] get invalid arp packet", level: .debug)
                return
            }
            guard arpPacket.targetIP == networkAddr.ip else {
                SDLLogger.shared.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))", level: .debug)
                return
            }
            switch arpPacket.opcode {
            case .request:
                SDLLogger.shared.log("[SDLContext] get arp request packet", level: .debug)
                let response = ARPPacket.arpResponse(for: arpPacket, mac: networkAddr.mac, ip: networkAddr.ip)
                self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
            case .response:
                SDLLogger.shared.log("[SDLContext] get arp response packet", level: .debug)
                self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
            }
        case .ipv4:
            guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == networkAddr.ip else {
                return
            }
            let packet = NEPacket(data: ipPacket.data, protocolFamily: Self.ipv4ProtocolFamily)
            self.providerAdapter.writePackets(packets: [packet])
        default:
            SDLLogger.shared.log("[SDLContext] get invalid packet", level: .debug)
        }
    }

    // TODO: periodic flow reporting is disabled. The former Combine-based implementation is kept
    // for reference; it depended on a `superClient` that no longer exists in this actor.
    // public func flowReportTask() {
    //     Task {
    //         self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect()
    //             .sink { _ in
    //                 Task {
    //                     let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset()
    //                     await self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
    //                 }
    //             }
    //     }
    // }

    /// (Re)starts the task that reads packets from the TUN through the provider adapter. It
    /// hands each parsable IP packet to `dealPacket(packet:)`.
    private func startReader() {
        self.readTask?.cancel()
        self.readTask = Task.detached(priority: .high) {
            while !Task.isCancelled {
                let packets = await self.providerAdapter.readPackets()
                let ipPackets = packets.compactMap { IPPacket($0) }
                for ipPacket in ipPackets {
                    await self.dealPacket(packet: ipPacket)
                }
            }
        }
    }

    /// Routes one packet read from the TUN:
    /// - DNS queries go to the DNS client.
    /// - Packets for our own IP are looped back to the TUN.
    /// - Packets for known peers are sent as IPv4 layer packets.
    /// - Otherwise an ARP request for the destination is broadcast. The triggering packet itself
    ///   is not sent.
    private func dealPacket(packet: IPPacket) {
        let networkAddr = self.config.networkAddress
        if SDLDNSClient.Helper.isDnsRequestPacket(ipPacket: packet) {
            let destIp = packet.header.destination_ip
            SDLLogger.shared.log("[DNSQuery] destIp: \(destIp), int: \(packet.header.destination.asIpAddress())", level: .debug)
            self.dnsClient?.forward(ipPacket: packet)
            return
        }

        let dstIp = packet.header.destination
        if dstIp == networkAddr.ip {
            let nePacket = NEPacket(data: packet.data, protocolFamily: Self.ipv4ProtocolFamily)
            self.providerAdapter.writePackets(packets: [nePacket])
            return
        }

        if let dstMac = self.arpServer.query(ip: dstIp) {
            self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data)
        } else {
            SDLLogger.shared.log("[SDLContext] dstIp: \(dstIp.asIpAddress()) arp query not found, broadcast", level: .debug)
            let arpRequest = ARPPacket.arpRequest(senderIP: networkAddr.ip, senderMAC: networkAddr.mac, targetIP: dstIp)
            self.routeLayerPacket(dstMac: ARPPacket.broadcastMac, type: .arp, data: arpRequest.marshal())
        }
    }

    /// Wraps `data` in a layer-2 packet and encrypts it with the AES key (dropping it if there
    /// is no key yet). It then sends the result:
    /// - Broadcasts go to the super node.
    /// - Unicast goes directly over an existing P2P session when there is one.
    /// - Otherwise it is relayed via the super node, and a hole-punch register request is queued.
    private func routeLayerPacket(dstMac: Data, type: LayerPacket.PacketType, data: Data) {
        let networkAddr = self.config.networkAddress
        let layerPacket = LayerPacket(dstMac: dstMac, srcMac: networkAddr.mac, type: type, data: data)
        guard let aesKey = self.aesKey,
              let encodedPacket = try? self.aesCipher.encrypt(aesKey: aesKey, data: layerPacket.marshal()) else {
            return
        }

        var dataPacket = SDLData()
        dataPacket.networkID = networkAddr.networkId
        dataPacket.srcMac = networkAddr.mac
        dataPacket.dstMac = dstMac
        dataPacket.ttl = 255
        dataPacket.data = encodedPacket
        guard let payload = try? dataPacket.serializedData() else {
            SDLLogger.shared.log("[SDLContext] failed to serialize data packet", level: .error)
            return
        }

        if ARPPacket.isBroadcastMac(dstMac) {
            self.udpHole?.send(type: .data, data: payload, remoteAddress: self.config.stunSocketAddress)
        } else if let session = self.sessionManager.getSession(toAddress: dstMac) {
            SDLLogger.shared.log("[SDLContext] send packet by session: \(session)", level: .debug)
            self.udpHole?.send(type: .data, data: payload, remoteAddress: session.natAddress)
            self.flowTracer.inc(num: payload.count, type: .p2p)
        } else {
            self.udpHole?.send(type: .data, data: payload, remoteAddress: self.config.stunSocketAddress)
            self.flowTracer.inc(num: payload.count, type: .forward)
            // Ask the puncher to try establishing a direct session with this peer.
            self.puncherActor?.submitRegisterRequest(request: .init(srcMac: networkAddr.mac, dstMac: dstMac, networkId: networkAddr.networkId))
        }
    }

    /// Runs `body` repeatedly until the returned task is cancelled. A normal return restarts
    /// `body` immediately. `CancellationError` ends the loop. Any other error is logged and
    /// retried after 2 seconds.
    private func spawnLoop(_ body: @escaping () async throws -> Void) -> Task<Void, Never> {
        return Task.detached {
            while !Task.isCancelled {
                do {
                    try await body()
                } catch is CancellationError {
                    break
                } catch {
                    SDLLogger.shared.log("[SDLContext] worker failed: \(error), retry in 2s", level: .warning)
                    try? await Task.sleep(nanoseconds: 2_000_000_000)
                }
            }
        }
    }
}
private extension UInt32 {
    /// Formats this value as an IPv4 address string by delegating to `SDLUtil.int32ToIp`.
    func asIpAddress() -> String {
        SDLUtil.int32ToIp(self)
    }
}