punchnet-macos/Tun/Punchnet/Actors/SDLContextActor.swift
2026-02-15 00:34:40 +08:00

725 lines
29 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//
// SDLContextActor.swift
// Tun
//
// Created by on 2024/2/29.
//
import Foundation
import NetworkExtension
import NIOCore
//
/*
1. rsa
*/
/// Runtime core of the tunnel.
///
/// Starts and supervises the long-lived clients (QUIC control channel, notice socket,
/// DNS forwarder, UDP hole), registers this node with the super node, and moves packets
/// between the TUN interface (`provider.packetFlow`) and peers.
actor SDLContextActor {
    /// Registration state with the super node.
    enum State {
        case unregistered
        case registered
    }

    nonisolated let config: SDLConfiguration
    private var state: State = .unregistered
    // NAT type detected by `proberActor`; reported in stun requests.
    var natType: SDLNATProberActor.NatType = .blocked
    // AES cipher used to encrypt/decrypt data packets.
    nonisolated let aesCipher: AESCipher
    // AES key received (RSA-encrypted) in the register-super ack.
    private var aesKey: Data?
    // Session token received in the register-super ack.
    private var sessionToken: Data?
    // RSA cipher: its public key is sent with register-super and it decrypts the AES key.
    nonisolated let rsaCipher: RSACipher
    // UDP socket used for stun, register and data traffic, plus its worker tasks.
    private var udpHole: SDLUDPHole?
    private var udpHoleWorkers: [Task<Void, Never>]?
    // DNS forwarding client and the worker that writes its replies into the TUN.
    private var dnsClient: SDLDNSClient?
    private var dnsWorker: Task<Void, Never>?
    // QUIC control channel and its receive worker.
    private var quicClient: SDLQUICClient?
    private var quicWorker: Task<Void, Never>?
    // Pending "re-send register-super if still unregistered" check (see `doRegisterSuper`).
    private var registerRetryTask: Task<Void, Never>?
    nonisolated private let puncherActor: SDLPuncherActor
    // NAT type prober.
    nonisolated private let proberActor: SDLNATProberActor
    // Task reading outbound packets from the TUN (see `startReader`).
    private var readTask: Task<(), Never>?
    private var sessionManager: SessionManager
    private var arpServer: ArpServerActor
    // Network path monitor and the worker consuming its events.
    private var monitor: SDLNetworkMonitor?
    private var monitorWorker: Task<Void, Never>?
    // Notice socket used to send alert messages (receiver presumably the host app — confirm).
    private var noticeClient: SDLNoticeClient?
    // Traffic counters (inbound / p2p / forward).
    nonisolated private let flowTracer = SDLFlowTracer()
    // Supervised reconnect loops created by `start()`.
    private var loopChildWorkers: [Task<Void, Never>] = []
    private let provider: NEPacketTunnelProvider
    // Identity / policy snapshot store.
    private let identifyStore: IdentityStore
    private let snapshotPublisher: SnapshotPublisher<IdentitySnapshot>
    private let policyRequesterActor: PolicyRequesterActor

    public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher) {
        self.provider = provider
        self.config = config
        self.rsaCipher = rsaCipher
        self.aesCipher = aesCipher
        self.sessionManager = SessionManager()
        self.arpServer = ArpServerActor(known_macs: [:])
        self.puncherActor = SDLPuncherActor(querySocketAddress: config.stunSocketAddress)
        self.proberActor = SDLNATProberActor(addressArray: config.stunProbeSocketAddressArray)
        // The store and this actor share one publisher: the store writes, `handleData` reads.
        let snapshotPublisher = SnapshotPublisher(initial: IdentitySnapshot.empty())
        self.identifyStore = IdentityStore(publisher: snapshotPublisher)
        self.snapshotPublisher = snapshotPublisher
        self.policyRequesterActor = PolicyRequesterActor(querySocketAddress: config.stunSocketAddress)
    }

    /// Starts the network monitor and one supervised reconnect loop per client
    /// (QUIC, notice, DNS, UDP hole). Each loop restarts its client after it closes or fails.
    public func start() {
        self.startMonitor()
        self.loopChildWorkers.append(spawnLoop {
            SDLLogger.shared.log("[SDLContext] try start quicClient")
            let quicClient = try await self.startQUICClient()
            SDLLogger.shared.log("[SDLContext] quicClient running!!!!")
            await quicClient.waitClose()
            SDLLogger.shared.log("[SDLContext] quicClient closed!!!!")
        })
        self.loopChildWorkers.append(spawnLoop {
            let noticeClient = try self.startNoticeClient()
            SDLLogger.shared.log("[SDLContext] noticeClient running!!!!")
            try await noticeClient.waitClose()
            SDLLogger.shared.log("[SDLContext] noticeClient closed!!!!")
        })
        self.loopChildWorkers.append(spawnLoop {
            let dnsClient = try await self.startDnsClient()
            SDLLogger.shared.log("[SDLContext] dns running!!!!")
            try await dnsClient.waitClose()
            SDLLogger.shared.log("[SDLContext] dns closed!!!!")
        })
        self.loopChildWorkers.append(spawnLoop {
            let udpHole = try await self.startUDPHole()
            SDLLogger.shared.log("[SDLContext] udp running!!!!")
            try await udpHole.waitClose()
            SDLLogger.shared.log("[SDLContext] udp closed!!!!")
        })
    }

    /// Connects the QUIC control channel, starts a worker that logs received messages,
    /// and sends register-super.
    /// - Throws: whatever `waitReady()` throws when the connection cannot be established.
    private func startQUICClient() async throws -> SDLQUICClient {
        // Drop the receive worker of a previous connection (mirrors startDnsClient/startUDPHole).
        self.quicWorker?.cancel()
        self.quicWorker = nil

        // NOTE(review): server endpoint is hard-coded; consider moving it into SDLConfiguration.
        let quicClient = SDLQUICClient(host: "118.178.229.213", port: 1365)
        quicClient.start()
        // Wait until the connection is usable before registering.
        try await quicClient.waitReady()
        SDLLogger.shared.log("[SDLContext] start quic client ready")
        self.quicWorker = Task.detached {
            for await message in quicClient.receiveStream(maxLen: 86400) {
                SDLLogger.shared.log("[SDLContext] quic client receive message: \(message)")
            }
        }
        self.quicClient = quicClient
        // Register this node with the super node over the new channel.
        self.doRegisterSuper()
        return quicClient
    }

    /// Creates and starts the notice client on `config.noticePort`.
    private func startNoticeClient() throws -> SDLNoticeClient {
        let noticeClient = try SDLNoticeClient(noticePort: self.config.noticePort, logger: SDLLogger.shared)
        noticeClient.start()
        SDLLogger.shared.log("[SDLContext] noticeClient started")
        self.noticeClient = noticeClient
        return noticeClient
    }

    /// (Re)starts the network path monitor and a worker that logs its events.
    private func startMonitor() {
        self.monitorWorker?.cancel()
        self.monitorWorker = nil
        let monitor = SDLNetworkMonitor()
        monitor.start()
        SDLLogger.shared.log("[SDLContext] monitor started")
        self.monitor = monitor
        self.monitorWorker = Task.detached {
            for await event in monitor.eventStream {
                switch event {
                case .changed:
                    // NAT re-detection on path change is currently disabled.
                    //self.natType = await self.getNatType()
                    SDLLogger.shared.log("didNetworkPathChanged, nat type is:", level: .info)
                case .unreachable:
                    SDLLogger.shared.log("didNetworkPathUnreachable", level: .warning)
                }
            }
        }
    }

    /// Starts the DNS client against `config.remoteDnsServer:15353` and a worker that
    /// writes every packet from `dnsClient.packetFlow` into the TUN as IPv4 (AF_INET = 2).
    private func startDnsClient() async throws -> SDLDNSClient {
        self.dnsWorker?.cancel()
        self.dnsWorker = nil
        let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(self.config.remoteDnsServer, port: 15353)
        let dnsClient = try await SDLDNSClient(dnsServerAddress: dnsSocketAddress, logger: SDLLogger.shared)
        try dnsClient.start()
        SDLLogger.shared.log("[SDLContext] dnsClient started")
        self.dnsClient = dnsClient
        self.dnsWorker = Task.detached {
            for await packet in dnsClient.packetFlow {
                if Task.isCancelled {
                    break
                }
                let nePacket = NEPacket(data: packet, protocolFamily: 2)
                self.provider.packetFlow.writePacketObjects([nePacket])
            }
        }
        return dnsClient
    }

    /// Starts the UDP hole and its workers:
    /// - a ping task sending a stun request on every tick of `SDLAsyncTimerStream`;
    /// - a message task dispatching received messages to the handlers below;
    /// - a one-shot NAT-type probe whose result is stored in `natType`.
    private func startUDPHole() async throws -> SDLUDPHole {
        self.udpHoleWorkers?.forEach {$0.cancel()}
        self.udpHoleWorkers = nil
        let udpHole = try SDLUDPHole()
        try udpHole.start()
        SDLLogger.shared.log("[SDLContext] udpHole started")
        // Local address, passed to the prober along with probe replies.
        let localAddress = udpHole.getLocalAddress()
        // Wait until the channel is active before starting workers.
        await udpHole.channelIsActived()
        // Periodic stun request (keep-alive / address report).
        let pingTask = Task.detached {
            let (stream, cont) = AsyncStream.makeStream(of: Void.self)
            let timerStream = SDLAsyncTimerStream()
            timerStream.start(cont)
            for await _ in stream {
                if Task.isCancelled {
                    break
                }
                await self.sendStunRequest()
            }
            SDLLogger.shared.log("[SDLContext] udp pingTask cancel")
        }
        // Dispatch of incoming UDP messages.
        let messageTask = Task.detached {
            for await (remoteAddress, message) in udpHole.messageStream {
                if Task.isCancelled {
                    break
                }
                switch message {
                case .stunProbeReply(let probeReply):
                    await self.proberActor.handleProbeReply(localAddress: localAddress, reply: probeReply)
                case .register(let register):
                    try? await self.handleRegister(remoteAddress: remoteAddress, register: register)
                case .registerAck(let registerAck):
                    await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
                case .data(let data):
                    try? await self.handleData(data: data)
//                case .policyReponse(let policyResponse):
//                    SDLLogger.shared.log("[SDLContext] get a policyResponse: \(policyResponse.totalNum) of \(policyResponse.index), bytes: \(policyResponse.rules.count)")
//                    //
//                    await self.identifyStore.apply(policyResponse: policyResponse)
                }
            }
            SDLLogger.shared.log("[SDLContext] udp signalTask cancel")
        }
        self.udpHole = udpHole
        self.udpHoleWorkers = [pingTask, messageTask]
        // Probe the NAT type using the new socket.
        Task {
            let natType = await self.proberActor.probeNatType(using: udpHole)
            self.setNatType(natType: natType)
            SDLLogger.shared.log("[SDLContext] nat_type is: \(natType)")
        }
        return udpHole
    }

    /// Stops the context: cancels the reconnect loops, every worker task, the pending
    /// register-super retry and the TUN reader.
    public func stop() async {
        self.loopChildWorkers.forEach { $0.cancel() }
        self.loopChildWorkers.removeAll()
        // Without this the 5s register-super retry keeps firing after stop.
        self.registerRetryTask?.cancel()
        self.registerRetryTask = nil
        self.quicWorker?.cancel()
        self.quicWorker = nil
        self.udpHoleWorkers?.forEach { $0.cancel() }
        self.udpHoleWorkers = nil
        self.dnsWorker?.cancel()
        self.dnsWorker = nil
        self.monitorWorker?.cancel()
        self.monitorWorker = nil
        self.readTask?.cancel()
        self.readTask = nil
    }

    private func setNatType(natType: SDLNATProberActor.NatType) {
        self.natType = natType
    }

    /// Sends a stun request to `config.stunSocketAddress`; no-op until a session token exists.
    private func sendStunRequest() {
        guard let sessionToken = self.sessionToken else {
            return
        }
        var stunRequest = SDLStunRequest()
        stunRequest.clientID = self.config.clientId
        stunRequest.networkID = self.config.networkAddress.networkId
        stunRequest.ip = self.config.networkAddress.ip
        stunRequest.mac = self.config.networkAddress.mac
        stunRequest.natType = UInt32(self.natType.rawValue)
        stunRequest.sessionToken = sessionToken
        if let stunData = try? stunRequest.serializedData() {
            let remoteAddress = self.config.stunSocketAddress
            self.udpHole?.send(type: .stunRequest, data: stunData, remoteAddress: remoteAddress)
        }
    }

    /// Handles a successful register-super reply.
    ///
    /// RSA-decrypts the AES key, stores key and session token, requests policy,
    /// applies the tunnel network settings, then marks the node registered and starts
    /// reading the TUN. A key that cannot be decoded is logged and ignored (state stays
    /// `.unregistered`, so the scheduled register-super retry sends a new request).
    /// Failing to apply network settings cancels the tunnel.
    private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async {
        // The AES key is encrypted with our RSA public key. Never crash on bad server input.
        let decodedKey: Data?
        do {
            decodedKey = try self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))
        } catch {
            SDLLogger.shared.log("[SDLContext] decode aes_key of registerSuperAck failed: \(error)", level: .error)
            return
        }
        guard let aesKey = decodedKey else {
            SDLLogger.shared.log("[SDLContext] registerSuperAck has no usable aes_key", level: .error)
            return
        }
        self.aesKey = aesKey
        self.sessionToken = registerSuperAck.sessionToken
        await self.triggerPolicy()
        SDLLogger.shared.log("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count)", level: .info)
        // Configure the TUN interface.
        do {
            try await self.setNetworkSettings(networkAddress: self.config.networkAddress, dnsServer: SDLDNSClient.Helper.dnsServer)
            SDLLogger.shared.log("[SDLContext] setNetworkSettings successed")
            self.state = .registered
            self.startReader()
        } catch let err {
            SDLLogger.shared.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error)
            self.provider.cancelTunnelWithError(err)
        }
    }

    /// Handles a register-super rejection: always forwards the message as an alert via the
    /// notice client; for `.invalidToken` / `.nodeDisabled` also cancels the tunnel.
    /// Unknown error codes are ignored.
    private func handleRegisterSuperNak(nakPacket: SDLRegisterSuperNak) {
        let errorMessage = nakPacket.errorMessage
        guard let errorCode = SDLNAKErrorCode(rawValue: UInt8(nakPacket.errorCode)) else {
            return
        }
        switch errorCode {
        case .invalidToken, .nodeDisabled:
            let alertNotice = NoticeMessage.alert(alert: errorMessage)
            self.noticeClient?.send(data: alertNotice)
            // Fatal: shut the tunnel down.
            let error = NSError(domain: "com.jihe.punchnet.tun", code: -1)
            self.provider.cancelTunnelWithError(error)
        case .noIpAddress, .networkFault, .internalFault:
            let alertNotice = NoticeMessage.alert(alert: errorMessage)
            self.noticeClient?.send(data: alertNotice)
        }
        SDLLogger.shared.log("[SDLContext] Get a SuperNak message exit", level: .warning)
    }

    /// Handles a server-pushed event.
    /// - `natChanged`: drops the session to that MAC.
    /// - `sendRegister`: sends a register message to the given NAT ip/port.
    /// - `networkShutdown`: alerts via the notice client and cancels the tunnel.
    /// - Throws: if the register message cannot be serialized.
    private func handleEvent(event: SDLEvent) async throws {
        switch event {
//        case .dropMacs(let dropMacsEvent):
//            SDLLogger.shared.log("[SDLContext] drop macs", level: .info)
//            await self.arpServer.dropMacs(macs: dropMacsEvent.macs)
        case .natChanged(let natChangedEvent):
            let dstMac = natChangedEvent.mac
            SDLLogger.shared.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info)
            await sessionManager.removeSession(dstMac: dstMac)
        case .sendRegister(let sendRegisterEvent):
            SDLLogger.shared.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)", level: .debug)
            let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
            if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) {
                var register = SDLRegister()
                register.networkID = self.config.networkAddress.networkId
                register.srcMac = self.config.networkAddress.mac
                register.dstMac = sendRegisterEvent.dstMac
                self.udpHole?.send(type: .register, data: try register.serializedData(), remoteAddress: remoteAddress)
            }
        case .networkShutdown(let shutdownEvent):
            let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message)
            self.noticeClient?.send(data: alertNotice)
            // Fatal: shut the tunnel down.
            let error = NSError(domain: "com.jihe.punchnet.tun", code: -2)
            self.provider.cancelTunnelWithError(error)
        }
    }

    /// Sends register-super over the QUIC channel and schedules a check 5 seconds later
    /// that re-sends it while the node is still `.unregistered`.
    ///
    /// Only one check is pending at a time: a new call replaces the previous one, so
    /// QUIC reconnects do not stack up parallel retry chains.
    private func doRegisterSuper() {
        var registerSuper = SDLRegisterSuper()
        registerSuper.clientID = self.config.clientId
        registerSuper.networkID = self.config.networkAddress.networkId
        registerSuper.mac = self.config.networkAddress.mac
        registerSuper.ip = self.config.networkAddress.ip
        registerSuper.maskLen = UInt32(self.config.networkAddress.maskLen)
        registerSuper.hostname = self.config.hostname
        registerSuper.pubKey = self.rsaCipher.pubKey
        registerSuper.accessToken = self.config.accessToken
        if let registerSuperData = try? registerSuper.serializedData() {
            SDLLogger.shared.log("[SDLContext] will send register super")
            self.quicClient?.send(type: .registerSuper, data: registerSuperData)
            // Check the registration state again after 5 seconds.
            self.registerRetryTask?.cancel()
            self.registerRetryTask = Task {
                do {
                    try await Task.sleep(for: .seconds(5))
                } catch {
                    // Cancelled by stop() or superseded by a newer request.
                    return
                }
                self.checkRegisterState()
            }
        }
    }

    /// Re-sends register-super if no ack has been processed yet.
    private func checkRegisterState() {
        if self.state == .unregistered {
            SDLLogger.shared.log("[SDLContext] register super failed, retry")
            self.doRegisterSuper()
        }
    }

    /// Handles a peer's register message.
    ///
    /// If it targets this node (matching MAC and network id), replies with a register ack
    /// and records a session to the sender at the address the message came from.
    /// - Throws: if the ack cannot be serialized.
    private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) async throws {
        let networkAddr = config.networkAddress
        SDLLogger.shared.log("register packet: \(register), network_address: \(networkAddr)", level: .debug)
        if register.dstMac == networkAddr.mac && register.networkID == networkAddr.networkId {
            var registerAck = SDLRegisterAck()
            registerAck.networkID = networkAddr.networkId
            registerAck.srcMac = networkAddr.mac
            registerAck.dstMac = register.srcMac
            self.udpHole?.send(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)
            // The observed source address is the peer's NAT address; use it for direct traffic.
            let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
            await self.sessionManager.addSession(session: session)
        } else {
            SDLLogger.shared.log("SDLContext didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)", level: .warning)
        }
    }

    /// Handles a peer's register ack: records a session at the sender's address if the
    /// ack targets this node.
    private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) async {
        let networkAddr = config.networkAddress
        if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId {
            let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
            await self.sessionManager.addSession(session: session)
        } else {
            SDLLogger.shared.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning)
        }
    }

    /// Handles an inbound data packet.
    ///
    /// Ignored until the AES key is known, when not addressed to this MAC (or broadcast /
    /// multicast), or when decryption fails. ARP requests for our IP are answered and ARP
    /// responses update the ARP table. IPv4 packets for our IP are written to the TUN if the
    /// sender identity's policy allows them (TCP/UDP by destination port, ICMP always).
    /// When no policy is known for that identity, a policy request is submitted instead.
    /// - Throws: if the decrypted payload is not a valid layer packet.
    private func handleData(data: SDLData) async throws {
        guard let aesKey = self.aesKey else {
            return
        }
        let mac = LayerPacket.MacAddress(data: data.dstMac)
        let networkAddr = config.networkAddress
        guard (data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast()) else {
            return
        }
        guard let decyptedData = try? self.aesCipher.decypt(aesKey: aesKey, data: Data(data.data)) else {
            return
        }
        let layerPacket = try LayerPacket(layerData: decyptedData)
        self.flowTracer.inc(num: decyptedData.count, type: .inbound)
        switch layerPacket.type {
        case .arp:
            if let arpPacket = ARPPacket(data: layerPacket.data) {
                if arpPacket.targetIP == networkAddr.ip {
                    switch arpPacket.opcode {
                    case .request:
                        SDLLogger.shared.log("[SDLContext] get arp request packet", level: .debug)
                        let response = ARPPacket.arpResponse(for: arpPacket, mac: networkAddr.mac, ip: networkAddr.ip)
                        await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
                    case .response:
                        SDLLogger.shared.log("[SDLContext] get arp response packet", level: .debug)
                        await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
                    }
                } else {
                    SDLLogger.shared.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))", level: .debug)
                }
            } else {
                SDLLogger.shared.log("[SDLContext] get invalid arp packet", level: .debug)
            }
        case .ipv4:
            guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == networkAddr.ip else {
                return
            }
            // Apply the sender identity's policy, if known.
            let identitySnapshot = self.snapshotPublisher.current()
            if let ruleMap = identitySnapshot.lookup(data.identityID) {
                let proto = ipPacket.header.proto
                switch TransportProtocol(rawValue: proto) {
                case .udp, .tcp:
                    if let dstPort = ipPacket.getDstPort(), ruleMap.isAllow(proto: proto, port: dstPort) {
                        let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
                        self.provider.packetFlow.writePacketObjects([packet])
                    }
                case .icmp:
                    let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
                    self.provider.packetFlow.writePacketObjects([packet])
                default:
                    ()
                }
            } else {
                // Unknown identity: ask for its policy (only once we hold a session).
                if self.sessionToken != nil {
                    var policyRequest = SDLPolicyRequest()
                    policyRequest.srcIdentityID = data.identityID
                    policyRequest.dstIdentityID = self.config.identityId
                    await self.policyRequesterActor.submitPolicyRequest(using: self.udpHole, request: &policyRequest)
                }
            }
        default:
            SDLLogger.shared.log("[SDLContext] get invalid packet", level: .debug)
        }
    }

//    public func flowReportTask() {
//        Task {
//            //
//            self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect()
//                .sink { _ in
//                    Task {
//                        let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset()
//                        await self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
//                    }
//                }
//        }
//    }

    /// Starts (or restarts) a detached task that reads packets from the TUN and passes
    /// every IPv4 packet (protocol family 2) to `dealPacket`, until cancelled.
    private func startReader() {
        // Only one reader at a time.
        self.readTask?.cancel()
        self.readTask = Task.detached(priority: .high) {
            while !Task.isCancelled {
                let (packets, numbers) = await self.provider.packetFlow.readPackets()
                for (data, number) in zip(packets, numbers) where number == 2 {
                    if let ipPacket = IPPacket(data) {
                        await self.dealPacket(packet: ipPacket)
                    }
                }
            }
        }
    }

    /// Routes one outbound IPv4 packet read from the TUN.
    ///
    /// DNS requests go to the DNS client; packets for our own IP are written back to the
    /// TUN; otherwise the destination MAC is resolved via the ARP table and the packet is
    /// routed, or an ARP request is broadcast when the MAC is unknown (the packet itself
    /// is not sent in that case).
    private func dealPacket(packet: IPPacket) async {
        let networkAddr = self.config.networkAddress
        if SDLDNSClient.Helper.isDnsRequestPacket(ipPacket: packet) {
            let destIp = packet.header.destination_ip
            SDLLogger.shared.log("[DNSQuery] destIp: \(destIp), int: \(packet.header.destination.asIpAddress())", level: .debug)
            self.dnsClient?.forward(ipPacket: packet)
            return
        }
        let dstIp = packet.header.destination
        // Addressed to ourselves: loop it back into the TUN.
        if dstIp == networkAddr.ip {
            let nePacket = NEPacket(data: packet.data, protocolFamily: 2)
            self.provider.packetFlow.writePacketObjects([nePacket])
            return
        }
        // Resolve the destination MAC from the ARP table.
        if let dstMac = await self.arpServer.query(ip: dstIp) {
            await self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data)
        }
        else {
            SDLLogger.shared.log("[SDLContext] dstIp: \(dstIp.asIpAddress()) arp query not found, broadcast", level: .debug)
            // Broadcast an ARP request for the unknown IP.
            let arpReqeust = ARPPacket.arpRequest(senderIP: networkAddr.ip, senderMAC: networkAddr.mac, targetIP: dstIp)
            await self.routeLayerPacket(dstMac: ARPPacket.broadcastMac , type: .arp, data: arpReqeust.marshal())
        }
    }

    /// Encrypts a layer-2 frame and sends it to `dstMac`.
    ///
    /// Broadcast frames go to the super node. Unicast frames go directly to the peer when a
    /// session exists (counted as p2p). Otherwise they are relayed via the super node
    /// (counted as forward) and a hole-punch register request is submitted for that peer.
    /// Does nothing until both the UDP hole and the AES key exist.
    private func routeLayerPacket(dstMac: Data, type: LayerPacket.PacketType, data: Data) async {
        let networkAddr = self.config.networkAddress
        // Wrap the payload into a layer-2 frame.
        let layerPacket = LayerPacket(dstMac: dstMac, srcMac: networkAddr.mac, type: type, data: data)
        guard let udpHole = self.udpHole, let aesKey = self.aesKey, let encodedPacket = try? self.aesCipher.encrypt(aesKey: aesKey, data: layerPacket.marshal()) else {
            return
        }
        var dataPacket = SDLData()
        dataPacket.networkID = networkAddr.networkId
        dataPacket.srcMac = networkAddr.mac
        dataPacket.dstMac = dstMac
        dataPacket.ttl = 255
        dataPacket.identityID = self.config.identityId
        dataPacket.data = encodedPacket
        // Don't crash the packet path on a serialization failure; drop the packet instead.
        guard let payload = try? dataPacket.serializedData() else {
            SDLLogger.shared.log("[SDLContext] serialize data packet failed, drop it", level: .error)
            return
        }
        if ARPPacket.isBroadcastMac(dstMac) {
            // Broadcasts are relayed by the super node.
            udpHole.send(type: .data, data: payload, remoteAddress: self.config.stunSocketAddress)
        }
        else {
            if let session = await self.sessionManager.getSession(toAddress: dstMac) {
                // Direct (p2p) path through an established session.
                SDLLogger.shared.log("[SDLContext] send packet by session: \(session)", level: .debug)
                udpHole.send(type: .data, data: payload, remoteAddress: session.natAddress)
                self.flowTracer.inc(num: payload.count, type: .p2p)
            }
            else {
                // No session yet: relay through the super node ...
                udpHole.send(type: .data, data: payload, remoteAddress: self.config.stunSocketAddress)
                self.flowTracer.inc(num: payload.count, type: .forward)
                // ... and try to punch a direct path for subsequent packets.
                Task.detached {
                    await self.puncherActor.submitRegisterRequest(using: udpHole, request: .init(srcMac: networkAddr.mac, dstMac: dstMac, networkId: networkAddr.networkId))
                }
            }
        }
    }

    /// Applies the tunnel settings: address/mask of this node, routes for the virtual
    /// subnet and the DNS server (/32), MTU 1250, and split DNS for the network domain.
    private func setNetworkSettings(networkAddress: SDLConfiguration.NetworkAddress, dnsServer: String) async throws {
        let routes: [NEIPv4Route] = [
            NEIPv4Route(destinationAddress: networkAddress.netAddress, subnetMask: networkAddress.maskAddress),
            NEIPv4Route(destinationAddress: dnsServer, subnetMask: "255.255.255.255"),
        ]
        let networkSettings = NEPacketTunnelNetworkSettings(tunnelRemoteAddress: "8.8.8.8")
        networkSettings.mtu = 1250
        // DNS: resolve only the network domain through our DNS server.
        let networkDomain = networkAddress.networkDomain
        let dnsSettings = NEDNSSettings(servers: [dnsServer])
        dnsSettings.searchDomains = [networkDomain]
        dnsSettings.matchDomains = [networkDomain]
        dnsSettings.matchDomainsNoSearch = false
        networkSettings.dnsSettings = dnsSettings
        let ipv4Settings = NEIPv4Settings(addresses: [networkAddress.ipAddress], subnetMasks: [networkAddress.maskAddress])
        // Only the routes above go through the tunnel (no default route).
        //NEIPv4Route.default()
        ipv4Settings.includedRoutes = routes
        networkSettings.ipv4Settings = ipv4Settings
        try await self.provider.setTunnelNetworkSettings(networkSettings)
    }

    /// Reads one batch from the TUN and returns only the IPv4 (protocol family 2) packets.
    func readPackets() async -> [Data] {
        let (packets, numbers) = await self.provider.packetFlow.readPackets()
        return zip(packets, numbers).compactMap { (data, number) in
            return number == 2 ? data : nil
        }
    }

    /// Runs `body` repeatedly on a detached task until that task is cancelled.
    ///
    /// `body` typically starts a client and waits for it to close. After each run, whether
    /// it returned or threw, the loop waits 2 seconds before retrying. This keeps a client
    /// that closes or fails immediately from spinning in a busy reconnect loop. Errors are
    /// logged; a `CancellationError` ends the loop.
    private func spawnLoop(_ body: @escaping () async throws -> Void) -> Task<Void, Never> {
        return Task.detached {
            while !Task.isCancelled {
                do {
                    try await body()
                } catch is CancellationError {
                    break
                } catch {
                    SDLLogger.shared.log("[SDLContext] loop worker failed: \(error), retry in 2s", level: .warning)
                }
                // A cancelled sleep just falls through to the loop condition.
                try? await Task.sleep(nanoseconds: 2_000_000_000)
            }
        }
    }

    /// Submits an initial policy request once a session token exists.
    /// TODO: `srcIdentityID` is a hard-coded placeholder (1234).
    private func triggerPolicy() async {
        if self.sessionToken != nil {
            var policyRequest = SDLPolicyRequest()
            policyRequest.srcIdentityID = 1234
            policyRequest.dstIdentityID = self.config.identityId
            await self.policyRequesterActor.submitPolicyRequest(using: self.udpHole, request: &policyRequest)
        }
    }

    deinit {
        self.udpHole = nil
        self.dnsClient = nil
    }
}
private extension UInt32 {
    /// Formats this integer as a dotted IP address string using `SDLUtil.int32ToIp`.
    func asIpAddress() -> String {
        SDLUtil.int32ToIp(self)
    }
}