punchnet-macos/Tun/Punchnet/SDLContext.swift
2026-01-29 22:13:31 +08:00

516 lines
21 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//
// SDLContext.swift
// Tun
//
// Created by on 2024/2/29.
//
import Foundation
import NetworkExtension
import NIOCore
import Combine
//
/*
1. rsa
*/
/// Runtime context of the packet tunnel.
///
/// Owns the UDP hole used for super-node and peer traffic, the DNS and notice
/// clients, the network monitor and the TUN read loop, and routes packets
/// between the TUN interface and remote peers.
@available(macOS 14, *)
public class SDLContext {
    /// A route entry: destination address plus subnet mask.
    struct Route {
        let dstAddress: String
        let subnetMask: String

        var debugInfo: String {
            return "\(dstAddress):\(subnetMask)"
        }
    }

    /// `protocolFamily` value for IPv4 packets handed to the TUN interface (AF_INET).
    private static let ipv4ProtocolFamily = sa_family_t(AF_INET)

    let config: SDLConfiguration
    // NAT type detected by the prober; reported in every STUN request.
    var natType: SDLNATProberActor.NatType = .blocked
    // Cipher used to encrypt/decrypt data-plane payloads.
    var aesCipher: AESCipher
    // AES key received (RSA-encrypted) in registerSuperAck; empty until then.
    var aesKey: Data = Data()
    // RSA cipher; its public key is sent in registerSuper and it decodes the AES key.
    let rsaCipher: RSACipher
    // UDP socket used for super-node and peer traffic.
    var udpHole: SDLUDPHole?
    var providerActor: SDLTunnelProviderActor
    var puncherActor: SDLPuncherActor
    // Forwards DNS queries captured on the TUN interface to `config.remoteDnsServer`.
    var dnsClient: SDLDNSClient?
    // NAT-type prober, created once the UDP hole reports `.ready`.
    var proberActor: SDLNATProberActor?
    // Task that drains packets from the TUN interface (see `startReader`).
    private var readTask: Task<(), Never>?
    private var sessionManager: SessionManager
    private var arpServer: ArpServer
    // NOTE(review): not read or written anywhere in this class.
    private var lastCookie: UInt32? = 0
    // Network path monitor; its events are only logged.
    private var monitor: SDLNetworkMonitor?
    // Sends NoticeMessage values (assigned IP, alerts) to `config.noticePort`.
    private var noticeClient: SDLNoticeClient?
    // Traffic counters (inbound / p2p / forward).
    private var flowTracer = SDLFlowTracerActor()
    private var flowTracerCancel: AnyCancellable?
    private let logger: SDLLogger

    public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) {
        self.logger = logger
        self.config = config
        self.rsaCipher = rsaCipher
        self.aesCipher = aesCipher
        self.sessionManager = SessionManager()
        self.arpServer = ArpServer(known_macs: [:])
        self.providerActor = SDLTunnelProviderActor(provider: provider, logger: logger)
        self.puncherActor = SDLPuncherActor(querySocketAddress: config.stunSocketAddress, logger: logger)
    }

    /// Starts all sockets and clients, then runs the long-lived background loops.
    ///
    /// As soon as any loop finishes, the remaining loops are cancelled and the
    /// method returns. Errors thrown during setup or by a loop are rethrown.
    public func start() async throws {
        // UDP hole
        self.udpHole = try SDLUDPHole(logger: self.logger)
        try self.udpHole?.start()
        self.logger.log("[SDLContext] udpHole started")

        // DNS client
        let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(self.config.remoteDnsServer, port: 15353)
        self.dnsClient = try await SDLDNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger)
        try self.dnsClient?.start()
        self.logger.log("[SDLContext] dnsClient started")

        // Notice client
        self.noticeClient = try SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger)
        try self.noticeClient?.start()
        self.logger.log("[SDLContext] noticeClient started")

        // Network monitor
        self.monitor = SDLNetworkMonitor()
        self.monitor?.start()
        self.logger.log("[SDLContext] monitor started")

        try await withThrowingTaskGroup(of: Void.self) { group in
            // Send a STUN request to the super node every 5 seconds.
            group.addTask {
                while true {
                    try Task.checkCancellation()
                    try await Task.sleep(nanoseconds: 5 * 1_000_000_000)
                    try Task.checkCancellation()
                    await self.sendStunRequest()
                }
            }

            // Dispatch events coming from the UDP hole.
            group.addTask {
                if let eventStream = self.udpHole?.eventStream {
                    for try await event in eventStream {
                        try Task.checkCancellation()
                        // Each event runs in its own unstructured task. Nothing awaits
                        // that task, so its errors must be logged here or they are lost.
                        Task {
                            do {
                                try await self.dispatchEvent(event: event)
                            } catch {
                                self.logger.log("[SDLContext] dispatchEvent get error: \(error)", level: .error)
                            }
                        }
                    }
                }
            }

            // Write DNS responses back into the TUN interface.
            group.addTask {
                if let packetFlow = self.dnsClient?.packetFlow {
                    for await packet in packetFlow {
                        let nePacket = NEPacket(data: packet, protocolFamily: SDLContext.ipv4ProtocolFamily)
                        await self.providerActor.writePackets(packets: [nePacket])
                    }
                }
            }

            // Log network path changes.
            group.addTask {
                if let eventStream = self.monitor?.eventStream {
                    for await event in eventStream {
                        switch event {
                        case .changed:
                            // NAT re-detection on path change is currently disabled:
                            //self.natType = await self.getNatType()
                            self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info)
                        case .unreachable:
                            self.logger.log("didNetworkPathUnreachable", level: .warning)
                        }
                    }
                }
            }

            // If any loop finishes, the others are cancelled.
            if let _ = try await group.next() {
                self.logger.log("[SDLContext] taskGroup cancel")
                group.cancelAll()
            }
        }
    }

    /// Stops the TUN read loop and releases the UDP hole and notice client.
    public func stop() async {
        self.readTask?.cancel()
        self.readTask = nil
        self.udpHole = nil
        self.noticeClient = nil
    }

    /// Handles the UDP hole becoming ready.
    ///
    /// Hands the hole to the puncher, then concurrently probes the NAT type and
    /// sends a registerSuper packet to the super node.
    private func handleUDPHoleReady() async throws {
        await self.puncherActor.setUDPHoleActor(udpHole: self.udpHole)
        await withTaskGroup(of: Void.self) { group in
            // Probe the NAT type.
            group.addTask {
                if let udpHole = self.udpHole {
                    self.proberActor = SDLNATProberActor(udpHole: udpHole, addressArray: self.config.stunProbeSocketAddressArray, logger: self.logger)
                    if let prober = self.proberActor {
                        self.natType = await prober.probeNatType()
                        self.logger.log("[SDLContext] nat_type is: \(self.natType)")
                    }
                }
            }

            // Register with the super node. The RSA public key is included so the
            // ack can carry an encrypted AES key.
            group.addTask {
                var registerSuper = SDLRegisterSuper()
                registerSuper.pktID = 0
                registerSuper.clientID = self.config.clientId
                registerSuper.networkID = self.config.networkAddress.networkId
                registerSuper.mac = self.config.networkAddress.mac
                registerSuper.ip = self.config.networkAddress.ip
                registerSuper.maskLen = UInt32(self.config.networkAddress.maskLen)
                registerSuper.hostname = self.config.hostname
                registerSuper.pubKey = self.rsaCipher.pubKey
                registerSuper.accessToken = self.config.accessToken
                if let registerSuperData = try? registerSuper.serializedData() {
                    self.logger.log("[SDLContext] will send register super")
                    self.udpHole?.send(type: .registerSuper, data: registerSuperData, remoteAddress: self.config.stunSocketAddress)
                }
            }
        }
    }

    /// Sends a STUN request carrying this node's identity and the current NAT type.
    private func sendStunRequest() async {
        var stunRequest = SDLStunRequest()
        stunRequest.clientID = self.config.clientId
        stunRequest.networkID = self.config.networkAddress.networkId
        stunRequest.ip = self.config.networkAddress.ip
        stunRequest.mac = self.config.networkAddress.mac
        stunRequest.natType = UInt32(self.natType.rawValue)
        if let stunData = try? stunRequest.serializedData() {
            let remoteAddress = self.config.stunSocketAddress
            self.udpHole?.send(type: .stunRequest, data: stunData, remoteAddress: remoteAddress)
        }
    }

    /// Routes a UDP hole event to the matching handler.
    private func dispatchEvent(event: SDLUDPHole.UDPHoleEvent) async throws {
        switch event {
        case .ready:
            try await self.handleUDPHoleReady()
        case .message(let remoteAddress, let message):
            switch message {
            case .registerSuperAck(let registerSuperAck):
                await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
            case .registerSuperNak(let registerSuperNak):
                await self.handleRegisterSuperNak(nakPacket: registerSuperNak)
            case .peerInfo(let peerInfo):
                await self.puncherActor.handlePeerInfo(peerInfo: peerInfo)
            case .event(let event):
                try await self.handleEvent(event: event)
            case .stunProbeReply(let probeReply):
                await self.proberActor?.handleProbeReply(reply: probeReply)
            case .data(let data):
                try await self.handleData(data: data)
            case .register(let register):
                try await self.handleRegister(remoteAddress: remoteAddress, register: register)
            case .registerAck(let registerAck):
                await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
            }
        }
    }

    /// Handles a successful registration with the super node.
    ///
    /// Decodes the RSA-encrypted AES key, installs it, configures the TUN
    /// interface, notifies the host of the assigned IP and starts the reader.
    /// If the key cannot be decoded, the error is logged and the ack is ignored.
    /// If the tunnel network settings cannot be applied, the process exits.
    private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async {
        // The AES key arrives encrypted; decode it with our RSA cipher.
        // This is data from the network, so a failure must not crash the process.
        let aesKey: Data
        do {
            aesKey = try self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))
        } catch {
            self.logger.log("[SDLContext] decode aes_key get error: \(error)", level: .error)
            return
        }
        self.logger.log("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count)", level: .info)

        // Install the key before the TUN reader starts, so no packet is
        // encrypted/decrypted with the initial empty key.
        self.aesKey = aesKey

        // Configure the TUN interface.
        do {
            let ipAddress = try await self.providerActor.setNetworkSettings(networkAddress: self.config.networkAddress, dnsServer: SDLDNSClient.Helper.dnsServer)
            self.noticeClient?.send(data: NoticeMessage.ipAdress(ip: ipAddress))
            self.startReader()
        } catch let err {
            self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error)
            exit(-1)
        }
    }

    /// Handles a registration rejection from the super node.
    ///
    /// Forwards the error message to the host as an alert. The process exits for
    /// `.invalidToken` and `.nodeDisabled`. Unknown error codes are ignored.
    private func handleRegisterSuperNak(nakPacket: SDLRegisterSuperNak) async {
        let errorMessage = nakPacket.errorMessage
        guard let errorCode = SDLNAKErrorCode(rawValue: UInt8(nakPacket.errorCode)) else {
            return
        }

        switch errorCode {
        case .invalidToken, .nodeDisabled:
            let alertNotice = NoticeMessage.alert(alert: errorMessage)
            self.noticeClient?.send(data: alertNotice)
            exit(-1)
        case .noIpAddress, .networkFault, .internalFault:
            let alertNotice = NoticeMessage.alert(alert: errorMessage)
            self.noticeClient?.send(data: alertNotice)
        }
        self.logger.log("[SDLContext] Get a SuperNak message exit", level: .warning)
    }

    /// Handles control events pushed by the super node.
    private func handleEvent(event: SDLEvent) async throws {
        switch event {
        case .natChanged(let natChangedEvent):
            // The peer's NAT mapping changed, so its cached session is stale.
            let dstMac = natChangedEvent.mac
            self.logger.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info)
            await sessionManager.removeSession(dstMac: dstMac)
        case .sendRegister(let sendRegisterEvent):
            self.logger.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)", level: .debug)
            let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
            if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) {
                // Send a register packet directly to the peer's NAT address.
                var register = SDLRegister()
                register.networkID = self.config.networkAddress.networkId
                register.srcMac = self.config.networkAddress.mac
                register.dstMac = sendRegisterEvent.dstMac
                self.udpHole?.send(type: .register, data: try register.serializedData(), remoteAddress: remoteAddress)
            }
        case .networkShutdown(let shutdownEvent):
            let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message)
            self.noticeClient?.send(data: alertNotice)
            exit(-1)
        }
    }

    /// Handles a register packet from a peer.
    ///
    /// If the packet targets this node (same MAC and network), replies with a
    /// registerAck and records a session to the sender's observed address.
    private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) async throws {
        let networkAddr = config.networkAddress
        self.logger.log("register packet: \(register), network_address: \(networkAddr)", level: .debug)

        if register.dstMac == networkAddr.mac && register.networkID == networkAddr.networkId {
            var registerAck = SDLRegisterAck()
            registerAck.networkID = networkAddr.networkId
            registerAck.srcMac = networkAddr.mac
            registerAck.dstMac = register.srcMac
            self.udpHole?.send(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)

            // The address the packet arrived from is used as the peer's NAT address.
            let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
            await self.sessionManager.addSession(session: session)
        } else {
            self.logger.log("SDLContext didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)", level: .warning)
        }
    }

    /// Handles a registerAck from a peer by recording a session to its observed address.
    private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) async {
        let networkAddr = config.networkAddress
        if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId {
            let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
            await self.sessionManager.addSession(session: session)
        } else {
            self.logger.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning)
        }
    }

    /// Handles an inbound data packet.
    ///
    /// Accepts it only if addressed to this node's MAC, broadcast or multicast.
    /// Then decrypts it and either answers/learns ARP or writes IPv4 traffic
    /// for this node into the TUN interface. Throws if the decrypted layer
    /// packet cannot be parsed.
    private func handleData(data: SDLData) async throws {
        let mac = LayerPacket.MacAddress(data: data.dstMac)
        let networkAddr = config.networkAddress
        guard data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast() else {
            return
        }
        guard let decryptedData = try? self.aesCipher.decypt(aesKey: self.aesKey, data: Data(data.data)) else {
            return
        }

        let layerPacket = try LayerPacket(layerData: decryptedData)
        await self.flowTracer.inc(num: decryptedData.count, type: .inbound)

        switch layerPacket.type {
        case .arp:
            if let arpPacket = ARPPacket(data: layerPacket.data) {
                if arpPacket.targetIP == networkAddr.ip {
                    switch arpPacket.opcode {
                    case .request:
                        self.logger.log("[SDLContext] get arp request packet", level: .debug)
                        let response = ARPPacket.arpResponse(for: arpPacket, mac: networkAddr.mac, ip: networkAddr.ip)
                        await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
                    case .response:
                        self.logger.log("[SDLContext] get arp response packet", level: .debug)
                        await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
                    }
                } else {
                    self.logger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))", level: .debug)
                }
            } else {
                self.logger.log("[SDLContext] get invalid arp packet", level: .debug)
            }
        case .ipv4:
            guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == networkAddr.ip else {
                return
            }
            let packet = NEPacket(data: ipPacket.data, protocolFamily: SDLContext.ipv4ProtocolFamily)
            await self.providerActor.writePackets(packets: [packet])
        default:
            self.logger.log("[SDLContext] get invalid packet", level: .debug)
        }
    }

    // Periodic (60 s) flow report — currently disabled.
    // public func flowReportTask() {
    // Task {
    // //
    // self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect()
    // .sink { _ in
    // Task {
    // let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset()
    // await self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
    // }
    // }
    // }
    // }

    /// Starts (or restarts) the loop that reads packets from the TUN interface.
    ///
    /// Any previous reader is cancelled first. Each batch of read packets is
    /// parsed as IP and processed in a separate task.
    private func startReader() {
        self.readTask?.cancel()
        self.readTask = Task(priority: .high) {
            // Check for cancellation on every iteration; otherwise `stop()` and
            // restarts could never end this loop.
            while !Task.isCancelled {
                let packets = await self.providerActor.readPackets()
                Task {
                    let ipPackets = packets.compactMap { IPPacket($0) }
                    await self.batchProcessPackets(batchSize: 20, packets: ipPackets)
                }
            }
        }
    }

    /// Processes `packets` in chunks of `batchSize`.
    ///
    /// Packets within a chunk are handled concurrently. Each chunk finishes
    /// before the next one starts.
    private func batchProcessPackets(batchSize: Int, packets: [IPPacket]) async {
        for startIndex in stride(from: 0, to: packets.count, by: batchSize) {
            let endIndex = Swift.min(startIndex + batchSize, packets.count)
            let chunkPackets = packets[startIndex..<endIndex]

            await withTaskGroup(of: Void.self) { group in
                for packet in chunkPackets {
                    group.addTask {
                        await self.dealPacket(packet: packet)
                    }
                }
            }
        }
    }

    /// Handles one outbound IP packet read from the TUN interface.
    ///
    /// DNS requests go to the DNS client. Packets addressed to this node are
    /// looped back into the TUN interface. Other packets are sent to the MAC
    /// resolved via ARP; if the MAC is unknown, a broadcast ARP request is sent
    /// instead and the original packet is dropped.
    private func dealPacket(packet: IPPacket) async {
        let networkAddr = self.config.networkAddress
        if SDLDNSClient.Helper.isDnsRequestPacket(ipPacket: packet) {
            let destIp = packet.header.destination_ip
            self.logger.log("[DNSQuery] destIp: \(destIp), int: \(packet.header.destination.asIpAddress())", level: .debug)
            self.dnsClient?.forward(ipPacket: packet)
            return
        }

        let dstIp = packet.header.destination
        // Traffic to our own IP goes straight back into the TUN interface.
        if dstIp == networkAddr.ip {
            let nePacket = NEPacket(data: packet.data, protocolFamily: SDLContext.ipv4ProtocolFamily)
            await self.providerActor.writePackets(packets: [nePacket])
            return
        }

        // Resolve the destination MAC via ARP.
        if let dstMac = await self.arpServer.query(ip: dstIp) {
            await self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data)
        } else {
            self.logger.log("[SDLContext] dstIp: \(dstIp.asIpAddress()) arp query not found, broadcast", level: .debug)
            // Unknown MAC: broadcast an ARP request.
            let arpRequest = ARPPacket.arpRequest(senderIP: networkAddr.ip, senderMAC: networkAddr.mac, targetIP: dstIp)
            await self.routeLayerPacket(dstMac: ARPPacket.broadcastMac, type: .arp, data: arpRequest.marshal())
        }
    }

    /// Wraps `data` in a layer-2 frame, encrypts it and sends it to `dstMac`.
    ///
    /// Broadcasts go to the super node. Unicast goes directly to the peer if a
    /// session exists. Otherwise it is relayed via the super node, and a
    /// register (hole-punch) request is submitted for that peer.
    private func routeLayerPacket(dstMac: Data, type: LayerPacket.PacketType, data: Data) async {
        let networkAddr = self.config.networkAddress
        let layerPacket = LayerPacket(dstMac: dstMac, srcMac: networkAddr.mac, type: type, data: data)
        guard let encodedPacket = try? self.aesCipher.encrypt(aesKey: self.aesKey, data: layerPacket.marshal()) else {
            return
        }

        var dataPacket = SDLData()
        dataPacket.networkID = networkAddr.networkId
        dataPacket.srcMac = networkAddr.mac
        dataPacket.dstMac = dstMac
        dataPacket.ttl = 255
        dataPacket.data = encodedPacket
        // This runs for every packet, so fail soft instead of `try!`.
        guard let data = try? dataPacket.serializedData() else {
            self.logger.log("[SDLContext] serialize data packet failed", level: .error)
            return
        }

        if ARPPacket.isBroadcastMac(dstMac) {
            // Broadcasts are relayed by the super node.
            self.udpHole?.send(type: .data, data: data, remoteAddress: self.config.stunSocketAddress)
        } else if let session = await self.sessionManager.getSession(toAddress: dstMac) {
            // Direct peer-to-peer path.
            self.logger.log("[SDLContext] send packet by session: \(session)", level: .debug)
            self.udpHole?.send(type: .data, data: data, remoteAddress: session.natAddress)
            await self.flowTracer.inc(num: data.count, type: .p2p)
        } else {
            // No session yet: relay via the super node and try to establish one.
            self.udpHole?.send(type: .data, data: data, remoteAddress: self.config.stunSocketAddress)
            await self.flowTracer.inc(num: data.count, type: .forward)
            await self.puncherActor.submitRegisterRequest(request: .init(srcMac: networkAddr.mac, dstMac: dstMac, networkId: networkAddr.networkId))
        }
    }

    deinit {
        self.udpHole = nil
        self.dnsClient = nil
    }
}
private extension UInt32 {
    /// Formats this value as an IP address string using `SDLUtil.int32ToIp`.
    func asIpAddress() -> String {
        SDLUtil.int32ToIp(self)
    }
}