punchnet-macos/Tun/Punchnet/SDLContextActor.swift
2026-02-02 12:07:29 +08:00

528 lines
22 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//
// SDLContextActor.swift
// Tun
//
// Created by on 2024/2/29.
//
import Foundation
import NetworkExtension
import NIOCore
//
/*
1. rsa
*/
/// Central state of the packet tunnel: owns the UDP hole to the super node, the DNS client,
/// the NAT prober / puncher and the peer session table, and moves packets between the tunnel
/// interface and the overlay network.
actor SDLContextActor {
    /// Configuration supplied at construction (addresses, credentials, ports).
    nonisolated let config: SDLConfiguration
    /// NAT type reported by the prober; `.blocked` until probing finishes.
    var natType: SDLNATProberActor.NatType = .blocked
    /// Cipher used with `aesKey` to encrypt outgoing and decrypt incoming data packets.
    nonisolated let aesCipher: AESCipher
    /// Session key delivered (RSA-encrypted) in `registerSuperAck`; empty until then.
    var aesKey: Data = Data()
    /// Its public key is sent in `registerSuper`; it decodes the AES key in `registerSuperAck`.
    nonisolated let rsaCipher: RSACipher
    /// UDP socket used for super-node signalling and peer-to-peer data.
    var udpHole: SDLUDPHole?
    /// Wrapper around the `NEPacketTunnelProvider` (packet I/O and network settings).
    nonisolated let providerAdapter: SDLTunnelProviderAdapter
    /// Created once the UDP hole is ready; handles peer info and register (hole-punch) requests.
    var puncherActor: SDLPuncherActor?
    /// Forwards DNS queries read from the tunnel; its replies are written back to the tunnel.
    var dnsClient: SDLDNSClient?
    /// Created once the UDP hole is ready; probes the NAT type.
    var proberActor: SDLNATProberActor?
    /// Loop that drains packets from the tunnel interface (see `startReader()`).
    private var readTask: Task<(), Never>?
    /// Known peer sessions (peer MAC -> NAT address) used for direct P2P sends.
    private var sessionManager: SessionManager
    /// IP -> MAC cache filled from ARP responses.
    private var arpServer: ArpServer
    /// Network path monitor started by `startMonitor()`.
    private var monitor: SDLNetworkMonitor?
    /// Sends notices (assigned IP, alerts) via `config.noticePort`.
    private var noticeClient: SDLNoticeClient?
    /// Byte counters for inbound, P2P and forwarded traffic.
    nonisolated private let flowTracer = SDLFlowTracer()
    nonisolated private let logger: SDLLogger

    public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) {
        self.logger = logger
        self.config = config
        self.rsaCipher = rsaCipher
        self.aesCipher = aesCipher
        self.sessionManager = SessionManager()
        self.arpServer = ArpServer(known_macs: [:])
        self.providerAdapter = SDLTunnelProviderAdapter(provider: provider, logger: logger)
    }

    /// Creates the notice client and suspends until it closes.
    public func startNoticeClient() async throws {
        self.noticeClient = try SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger)
        self.logger.log("[SDLContext] noticeClient started")
        try await self.noticeClient?.waitClose()
    }

    /// Starts the network path monitor and logs its events until the event stream ends.
    public func startMonitor() async throws {
        let monitor = SDLNetworkMonitor()
        monitor.start()
        self.logger.log("[SDLContext] monitor started")
        self.monitor = monitor
        for await event in monitor.eventStream {
            switch event {
            case .changed:
                // Re-probing the NAT type on path change is currently disabled.
                //self.natType = await self.getNatType()
                self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info)
            case .unreachable:
                self.logger.log("didNetworkPathUnreachable", level: .warning)
            }
        }
    }

    /// Starts the DNS client (remote server on port 15353) and pumps its replies into the
    /// tunnel until its channel closes or the reply stream ends.
    public func startDnsClient() async throws {
        let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(self.config.remoteDnsServer, port: 15353)
        let dnsClient = try await SDLDNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger)
        let channel = try dnsClient.start()
        self.logger.log("[SDLContext] dnsClient started")
        self.dnsClient = dnsClient
        try await withThrowingTaskGroup(of: Void.self) { group in
            group.addTask {
                // Write every DNS reply back into the tunnel as an IPv4 (AF_INET = 2) packet.
                for await packet in dnsClient.packetFlow {
                    try Task.checkCancellation()
                    let nePacket = NEPacket(data: packet, protocolFamily: 2)
                    self.providerAdapter.writePackets(packets: [nePacket])
                }
            }
            group.addTask {
                try await channel.closeFuture.get()
            }
            // Whichever child finishes first ends the DNS client; cancel the other.
            try await group.next()
            self.logger.log("[SDLContext] taskGroup cancel")
            group.cancelAll()
        }
    }

    /// Starts the UDP hole and runs its keep-alive, event, data and signal loops until the
    /// channel closes or one loop ends.
    ///
    /// Errors raised while handling a single data/signal message are logged and the loop
    /// keeps going; previously one malformed packet ended its child task, which made
    /// `group.next()` throw and tore the whole UDP hole down.
    public func startUDPHole() async throws {
        let udpHole = try SDLUDPHole(logger: self.logger)
        let channel = try udpHole.start()
        self.logger.log("[SDLContext] udpHole started")
        self.udpHole = udpHole
        try await withThrowingTaskGroup(of: Void.self) { group in
            group.addTask {
                try await channel.closeFuture.get()
            }
            // Keep-alive: send a STUN request to the super node every 5 seconds.
            group.addTask {
                while true {
                    try Task.checkCancellation()
                    try await Task.sleep(nanoseconds: 5 * 1_000_000_000)
                    try Task.checkCancellation()
                    await self.sendStunRequest()
                }
            }
            // Socket lifecycle events.
            group.addTask {
                for try await event in udpHole.eventStream {
                    try Task.checkCancellation()
                    switch event {
                    case .ready:
                        await self.handleUDPHoleReady()
                    case .closed:
                        ()
                    }
                }
            }
            // Encrypted data packets.
            group.addTask {
                for try await data in udpHole.dataStream {
                    try Task.checkCancellation()
                    do {
                        try await self.handleData(data: data)
                    } catch {
                        self.logger.log("[SDLContext] handle data get error: \(error)", level: .warning)
                    }
                }
            }
            // Signalling messages from the super node and from peers.
            group.addTask {
                for try await (remoteAddress, signal) in udpHole.signalStream {
                    try Task.checkCancellation()
                    do {
                        switch signal {
                        case .registerSuperAck(let registerSuperAck):
                            await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
                        case .registerSuperNak(let registerSuperNak):
                            await self.handleRegisterSuperNak(nakPacket: registerSuperNak)
                        case .peerInfo(let peerInfo):
                            await self.puncherActor?.handlePeerInfo(peerInfo: peerInfo)
                        case .event(let event):
                            try await self.handleEvent(event: event)
                        case .stunProbeReply(let probeReply):
                            await self.proberActor?.handleProbeReply(reply: probeReply)
                        case .register(let register):
                            try await self.handleRegister(remoteAddress: remoteAddress, register: register)
                        case .registerAck(let registerAck):
                            await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
                        }
                    } catch {
                        self.logger.log("[SDLContext] handle signal get error: \(error)", level: .warning)
                    }
                }
            }
            // The first child to finish (or fail) ends the UDP hole; cancel the rest.
            try await group.next()
            group.cancelAll()
            self.logger.log("[SDLContext] taskGroup cancel")
        }
    }

    /// Drops the UDP hole and notice client references and cancels the tunnel reader.
    public func stop() async {
        self.udpHole = nil
        self.noticeClient = nil
        self.readTask?.cancel()
        self.readTask = nil
    }

    private func setNatType(natType: SDLNATProberActor.NatType) {
        self.natType = natType
    }

    /// Once the UDP hole is ready: creates the puncher and prober, then concurrently
    /// probes the NAT type and sends `registerSuper` to the super node.
    private func handleUDPHoleReady() async {
        if let udpHole = self.udpHole {
            self.puncherActor = SDLPuncherActor(udpHole: udpHole, querySocketAddress: config.stunSocketAddress, logger: logger)
            self.proberActor = SDLNATProberActor(udpHole: udpHole, addressArray: self.config.stunProbeSocketAddressArray, logger: self.logger)
        }
        await withDiscardingTaskGroup { group in
            group.addTask {
                if let natType = await self.proberActor?.probeNatType() {
                    await self.setNatType(natType: natType)
                    self.logger.log("[SDLContext] nat_type is: \(natType)")
                }
            }
            group.addTask {
                var registerSuper = SDLRegisterSuper()
                registerSuper.pktID = 0
                registerSuper.clientID = self.config.clientId
                registerSuper.networkID = self.config.networkAddress.networkId
                registerSuper.mac = self.config.networkAddress.mac
                registerSuper.ip = self.config.networkAddress.ip
                registerSuper.maskLen = UInt32(self.config.networkAddress.maskLen)
                registerSuper.hostname = self.config.hostname
                registerSuper.pubKey = self.rsaCipher.pubKey
                registerSuper.accessToken = self.config.accessToken
                if let registerSuperData = try? registerSuper.serializedData() {
                    self.logger.log("[SDLContext] will send register super")
                    await self.udpHole?.send(type: .registerSuper, data: registerSuperData, remoteAddress: self.config.stunSocketAddress)
                } else {
                    self.logger.log("[SDLContext] serialize register super failed", level: .error)
                }
            }
        }
    }

    /// Reports this node's identity and current NAT type to the super node
    /// (called every 5 seconds from `startUDPHole()`).
    private func sendStunRequest() async {
        var stunRequest = SDLStunRequest()
        stunRequest.clientID = self.config.clientId
        stunRequest.networkID = self.config.networkAddress.networkId
        stunRequest.ip = self.config.networkAddress.ip
        stunRequest.mac = self.config.networkAddress.mac
        stunRequest.natType = UInt32(self.natType.rawValue)
        if let stunData = try? stunRequest.serializedData() {
            let remoteAddress = self.config.stunSocketAddress
            self.udpHole?.send(type: .stunRequest, data: stunData, remoteAddress: remoteAddress)
        }
    }

    /// Handles the super node's registration ack: decodes the AES session key with our RSA
    /// key, applies the tunnel network settings, reports the assigned IP and starts the reader.
    private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async {
        // The ack arrives from the network; a key that fails to decode must not crash the
        // extension (this used to be `try!`). Log and ignore the packet instead.
        let aesKey: Data
        do {
            aesKey = try self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))
        } catch {
            self.logger.log("[SDLContext] get registerSuperAck, decode aes_key get error: \(error)", level: .error)
            return
        }
        self.logger.log("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count)", level: .info)
        // Install the key before suspending below so that anything running while the
        // network settings are applied (inbound data, the reader) already sees it.
        self.aesKey = aesKey
        do {
            let ipAddress = try await self.providerAdapter.setNetworkSettings(networkAddress: self.config.networkAddress, dnsServer: SDLDNSClient.Helper.dnsServer)
            self.noticeClient?.send(data: NoticeMessage.ipAdress(ip: ipAddress))
            self.startReader()
        } catch let err {
            self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error)
            exit(-1)
        }
    }

    /// Handles a registration rejection: forwards the server's message as an alert and
    /// terminates the process for `.invalidToken` / `.nodeDisabled`.
    private func handleRegisterSuperNak(nakPacket: SDLRegisterSuperNak) {
        let errorMessage = nakPacket.errorMessage
        // Log before possibly exiting so fatal NAKs leave a trace too.
        self.logger.log("[SDLContext] Get a SuperNak message, error_code: \(nakPacket.errorCode), error_message: \(errorMessage)", level: .warning)
        // `UInt8(exactly:)` instead of `UInt8(_:)`: the code comes from the network and the
        // trapping initializer would crash on any value above 255.
        guard let rawCode = UInt8(exactly: nakPacket.errorCode),
              let errorCode = SDLNAKErrorCode(rawValue: rawCode) else {
            self.logger.log("[SDLContext] SuperNak has unknown error_code: \(nakPacket.errorCode)", level: .warning)
            return
        }
        let alertNotice = NoticeMessage.alert(alert: errorMessage)
        self.noticeClient?.send(data: alertNotice)
        switch errorCode {
        case .invalidToken, .nodeDisabled:
            exit(-1)
        case .noIpAddress, .networkFault, .internalFault:
            ()
        }
    }

    /// Handles server events: drops the session of a peer whose NAT changed, sends a
    /// `register` to a peer's NAT address, or alerts and exits on network shutdown.
    private func handleEvent(event: SDLEvent) throws {
        switch event {
        case .natChanged(let natChangedEvent):
            let dstMac = natChangedEvent.mac
            self.logger.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info)
            sessionManager.removeSession(dstMac: dstMac)
        case .sendRegister(let sendRegisterEvent):
            self.logger.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)", level: .debug)
            let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
            if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) {
                var register = SDLRegister()
                register.networkID = self.config.networkAddress.networkId
                register.srcMac = self.config.networkAddress.mac
                register.dstMac = sendRegisterEvent.dstMac
                self.udpHole?.send(type: .register, data: try register.serializedData(), remoteAddress: remoteAddress)
            }
        case .networkShutdown(let shutdownEvent):
            let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message)
            self.noticeClient?.send(data: alertNotice)
            exit(-1)
        }
    }

    /// Answers a peer's `register` with a `registerAck` and records a session to the
    /// sender's observed address, but only if it targets this node in this network.
    private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws {
        let networkAddr = config.networkAddress
        self.logger.log("register packet: \(register), network_address: \(networkAddr)", level: .debug)
        if register.dstMac == networkAddr.mac && register.networkID == networkAddr.networkId {
            var registerAck = SDLRegisterAck()
            registerAck.networkID = networkAddr.networkId
            registerAck.srcMac = networkAddr.mac
            registerAck.dstMac = register.srcMac
            self.udpHole?.send(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)
            // The packet reached us directly, so `remoteAddress` is the peer's NAT address.
            let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
            self.sessionManager.addSession(session: session)
        } else {
            self.logger.log("SDLContext didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)", level: .warning)
        }
    }

    /// Records a session to the peer that acknowledged our `register`, if the ack targets us.
    private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) {
        let networkAddr = config.networkAddress
        if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId {
            let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
            self.sessionManager.addSession(session: session)
        } else {
            self.logger.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning)
        }
    }

    /// Decrypts an inbound data packet addressed to us (or broadcast/multicast) and
    /// dispatches it: ARP is answered or cached, IPv4 for our IP is written to the tunnel.
    ///
    /// - Throws: when the decrypted payload is not a valid layer-2 frame.
    private func handleData(data: SDLData) throws {
        let mac = LayerPacket.MacAddress(data: data.dstMac)
        let networkAddr = config.networkAddress
        guard (data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast()) else {
            return
        }
        guard let decyptedData = try? self.aesCipher.decypt(aesKey: self.aesKey, data: Data(data.data)) else {
            return
        }
        let layerPacket = try LayerPacket(layerData: decyptedData)
        self.flowTracer.inc(num: decyptedData.count, type: .inbound)
        switch layerPacket.type {
        case .arp:
            if let arpPacket = ARPPacket(data: layerPacket.data) {
                if arpPacket.targetIP == networkAddr.ip {
                    switch arpPacket.opcode {
                    case .request:
                        self.logger.log("[SDLContext] get arp request packet", level: .debug)
                        let response = ARPPacket.arpResponse(for: arpPacket, mac: networkAddr.mac, ip: networkAddr.ip)
                        self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
                    case .response:
                        self.logger.log("[SDLContext] get arp response packet", level: .debug)
                        self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
                    }
                } else {
                    self.logger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))", level: .debug)
                }
            } else {
                self.logger.log("[SDLContext] get invalid arp packet", level: .debug)
            }
        case .ipv4:
            guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == networkAddr.ip else {
                return
            }
            let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
            self.providerAdapter.writePackets(packets: [packet])
        default:
            self.logger.log("[SDLContext] get invalid packet", level: .debug)
        }
    }

    // public func flowReportTask() {
    //     Task {
    //         self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect()
    //             .sink { _ in
    //                 Task {
    //                     let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset()
    //                     await self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
    //                 }
    //             }
    //     }
    // }

    /// Starts (or restarts) the loop that reads packets from the tunnel interface.
    private func startReader() {
        // Cancel any previous reader so only one loop drains the packet flow.
        self.readTask?.cancel()
        self.readTask = Task(priority: .high) {
            while !Task.isCancelled {
                let packets = await self.providerAdapter.readPackets()
                let ipPackets = packets.compactMap { IPPacket($0) }
                await self.batchProcessPackets(batchSize: 20, packets: ipPackets)
            }
        }
    }

    /// Routes packets read from the tunnel, preserving their read order.
    ///
    /// - Parameters:
    ///   - batchSize: Packets handled between suspension points, letting other queued actor
    ///     work (e.g. inbound UDP data) interleave. Values below 1 are treated as 1.
    ///   - packets: Packets in the order they were read.
    private func batchProcessPackets(batchSize: Int, packets: [IPPacket]) async {
        // `dealPacket` is synchronous and isolated to this actor. Spawning one child task
        // per packet gave no parallelism (each child had to hop back onto the actor) and left
        // the order in which children reached the actor unspecified, which could reorder
        // packets of a flow. Handle them in order instead.
        let step = Swift.max(batchSize, 1)
        for startIndex in stride(from: 0, to: packets.count, by: step) {
            let endIndex = Swift.min(startIndex + step, packets.count)
            for packet in packets[startIndex..<endIndex] {
                self.dealPacket(packet: packet)
            }
            await Task.yield()
        }
    }

    /// Dispatches one packet read from the tunnel: DNS queries go to the DNS client, packets
    /// for our own IP are looped back, and anything else is sent to the peer MAC resolved via
    /// ARP (or triggers a broadcast ARP request when the MAC is unknown).
    private func dealPacket(packet: IPPacket) {
        let networkAddr = self.config.networkAddress
        if SDLDNSClient.Helper.isDnsRequestPacket(ipPacket: packet) {
            let destIp = packet.header.destination_ip
            self.logger.log("[DNSQuery] destIp: \(destIp), int: \(packet.header.destination)", level: .debug)
            self.dnsClient?.forward(ipPacket: packet)
            return
        }
        let dstIp = packet.header.destination
        // Packet addressed to ourselves: write it straight back into the tunnel.
        if dstIp == networkAddr.ip {
            let nePacket = NEPacket(data: packet.data, protocolFamily: 2)
            self.providerAdapter.writePackets(packets: [nePacket])
            return
        }
        if let dstMac = self.arpServer.query(ip: dstIp) {
            self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data)
        } else {
            self.logger.log("[SDLContext] dstIp: \(dstIp.asIpAddress()) arp query not found, broadcast", level: .debug)
            // MAC unknown: broadcast an ARP request. This packet itself is not sent.
            let arpReqeust = ARPPacket.arpRequest(senderIP: networkAddr.ip, senderMAC: networkAddr.mac, targetIP: dstIp)
            self.routeLayerPacket(dstMac: ARPPacket.broadcastMac, type: .arp, data: arpReqeust.marshal())
        }
    }

    /// Wraps `data` in an encrypted layer-2 frame and sends it: broadcasts go to the super
    /// node, known peers are reached directly, and unknown peers are relayed through the
    /// super node while a hole-punch register request is queued.
    private func routeLayerPacket(dstMac: Data, type: LayerPacket.PacketType, data: Data) {
        let networkAddr = self.config.networkAddress
        let layerPacket = LayerPacket(dstMac: dstMac, srcMac: networkAddr.mac, type: type, data: data)
        guard let encodedPacket = try? self.aesCipher.encrypt(aesKey: self.aesKey, data: layerPacket.marshal()) else {
            return
        }
        var dataPacket = SDLData()
        dataPacket.networkID = networkAddr.networkId
        dataPacket.srcMac = networkAddr.mac
        dataPacket.dstMac = dstMac
        dataPacket.ttl = 255
        dataPacket.data = encodedPacket
        // Was `try!`: a serialization failure should drop the packet, not crash the extension.
        guard let payload = try? dataPacket.serializedData() else {
            self.logger.log("[SDLContext] serialize data packet failed", level: .error)
            return
        }
        if ARPPacket.isBroadcastMac(dstMac) {
            self.udpHole?.send(type: .data, data: payload, remoteAddress: self.config.stunSocketAddress)
        } else if let session = self.sessionManager.getSession(toAddress: dstMac) {
            self.logger.log("[SDLContext] send packet by session: \(session)", level: .debug)
            self.udpHole?.send(type: .data, data: payload, remoteAddress: session.natAddress)
            self.flowTracer.inc(num: payload.count, type: .p2p)
        } else {
            // No direct session yet: relay via the super node and ask the puncher to set one up.
            self.udpHole?.send(type: .data, data: payload, remoteAddress: self.config.stunSocketAddress)
            self.flowTracer.inc(num: payload.count, type: .forward)
            self.puncherActor?.submitRegisterRequest(request: .init(srcMac: networkAddr.mac, dstMac: dstMac, networkId: networkAddr.networkId))
        }
    }

    deinit {
        self.udpHole = nil
        self.dnsClient = nil
    }
}
private extension UInt32 {
    /// Formats this value as an IP address string using `SDLUtil.int32ToIp`.
    func asIpAddress() -> String {
        let formatted = SDLUtil.int32ToIp(self)
        return formatted
    }
}