//
// SDLContext.swift
// Tun
//
// Created by 安礼成 on 2024/2/29.
//

import Foundation
|
||
import NetworkExtension
|
||
import NIOCore
|
||
|
||
/// Shared context for the packet tunnel.
///
/// Owns the tunnel's long-running services (notice client, network monitor,
/// DNS client, UDP hole) and the state they share: the detected NAT type, the
/// AES session key delivered by the super node, peer sessions and the ARP cache.
/// The RSA cipher's public key is sent in `SDLRegisterSuper`; the cipher is then
/// used to decode the AES key carried in `SDLRegisterSuperAck`.
actor SDLContextActor {

    nonisolated let config: SDLConfiguration

    /// NAT type of the local network; `.blocked` until a probe completes.
    var natType: SDLNATProberActor.NatType = .blocked

    /// Cipher used to encrypt/decrypt data packets with `aesKey`.
    nonisolated let aesCipher: AESCipher

    /// AES session key; empty until a `registerSuperAck` has been decoded.
    var aesKey: Data = Data()

    /// RSA cipher; the public key is generated locally.
    nonisolated let rsaCipher: RSACipher

    // MARK: - Dependencies

    var udpHole: SDLUDPHole?

    nonisolated let providerAdapter: SDLTunnelProviderAdapter

    var puncherActor: SDLPuncherActor?

    /// DNS client that DNS queries read from the tunnel are forwarded to.
    var dnsClient: SDLDNSClient?

    /// NAT type prober; created once the UDP hole reports `.ready`.
    var proberActor: SDLNATProberActor?

    /// Task reading packets from the tunnel interface (see `startReader()`).
    private var readTask: Task<(), Never>?

    private var sessionManager: SessionManager

    private var arpServer: ArpServer

    /// Network path monitor.
    private var monitor: SDLNetworkMonitor?

    /// Client for the internal notice socket.
    private var noticeClient: SDLNoticeClient?

    /// Traffic counters (inbound / p2p / forwarded).
    nonisolated private let flowTracer = SDLFlowTracer()

    nonisolated private let logger: SDLLogger

    public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) {
        self.logger = logger
        self.config = config
        self.rsaCipher = rsaCipher
        self.aesCipher = aesCipher

        self.sessionManager = SessionManager()
        self.arpServer = ArpServer(known_macs: [:])
        self.providerAdapter = SDLTunnelProviderAdapter(provider: provider, logger: logger)
    }

    // MARK: - Services

    /// Starts the notice client and suspends until it closes.
    public func startNoticeClient() async throws {
        self.noticeClient = try SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger)
        self.logger.log("[SDLContext] noticeClient started")

        try await self.noticeClient?.waitClose()
    }

    /// Starts the network monitor and logs path events until its stream ends.
    public func startMonitor() async throws {
        let monitor = SDLNetworkMonitor()
        monitor.start()
        self.logger.log("[SDLContext] monitor started")
        self.monitor = monitor

        for await event in monitor.eventStream {
            switch event {
            case .changed:
                // The NAT type should be re-probed here (currently disabled):
                //self.natType = await self.getNatType()
                self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info)
            case .unreachable:
                self.logger.log("didNetworkPathUnreachable", level: .warning)
            }
        }
    }

    /// Starts the DNS client and writes its response packets back into the tunnel.
    ///
    /// Returns (after cancelling the sibling task) as soon as either the packet
    /// stream ends or the DNS channel closes.
    public func startDnsClient() async throws {
        // Remote DNS server is addressed on port 15353.
        let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(self.config.remoteDnsServer, port: 15353)
        let dnsClient = try await SDLDNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger)
        let channel = try dnsClient.start()
        self.logger.log("[SDLContext] dnsClient started")
        self.dnsClient = dnsClient

        try await withThrowingTaskGroup(of: Void.self) { group in
            // Forward DNS responses into the tunnel interface.
            group.addTask {
                for await packet in dnsClient.packetFlow {
                    try Task.checkCancellation()
                    let nePacket = NEPacket(data: packet, protocolFamily: 2)
                    self.providerAdapter.writePackets(packets: [nePacket])
                }
            }

            group.addTask {
                try await channel.closeFuture.get()
            }

            try await group.next()
            self.logger.log("[SDLContext] taskGroup cancel")
            group.cancelAll()
        }
    }

    /// Starts the UDP hole and runs its processing loops.
    ///
    /// Child tasks: channel close watcher, periodic (5 s) STUN request, event
    /// stream, data stream and signal stream. When any child finishes or throws,
    /// the remaining children are cancelled and this method returns/throws.
    public func startUDPHole() async throws {
        let udpHole = try SDLUDPHole(logger: self.logger)
        let channel = try udpHole.start()
        self.logger.log("[SDLContext] udpHole started")
        self.udpHole = udpHole

        try await withThrowingTaskGroup(of: Void.self) { group in
            group.addTask {
                try await channel.closeFuture.get()
            }

            // Periodic STUN request.
            group.addTask {
                while true {
                    try Task.checkCancellation()
                    try await Task.sleep(nanoseconds: 5 * 1_000_000_000)
                    try Task.checkCancellation()
                    await self.sendStunRequest()
                }
            }

            // Event stream.
            group.addTask {
                for try await event in udpHole.eventStream {
                    try Task.checkCancellation()
                    switch event {
                    case .ready:
                        await self.handleUDPHoleReady()
                    case .closed:
                        ()
                    }
                }
            }

            // Data stream. `handleData` never throws, so one malformed packet
            // cannot end this loop and cancel the whole group.
            group.addTask {
                for try await data in udpHole.dataStream {
                    try Task.checkCancellation()
                    await self.handleData(data: data)
                }
            }

            // Signal stream.
            group.addTask {
                for try await(remoteAddress, signal) in udpHole.signalStream {
                    try Task.checkCancellation()
                    switch signal {
                    case .registerSuperAck(let registerSuperAck):
                        await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
                    case .registerSuperNak(let registerSuperNak):
                        await self.handleRegisterSuperNak(nakPacket: registerSuperNak)
                    case .peerInfo(let peerInfo):
                        await self.puncherActor?.handlePeerInfo(peerInfo: peerInfo)
                    case .event(let event):
                        try await self.handleEvent(event: event)
                    case .stunProbeReply(let probeReply):
                        await self.proberActor?.handleProbeReply(reply: probeReply)
                    case .register(let register):
                        try await self.handleRegister(remoteAddress: remoteAddress, register: register)
                    case .registerAck(let registerAck):
                        await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
                    }
                }
            }

            try await group.next()
            group.cancelAll()
            self.logger.log("[SDLContext] taskGroup cancel")
        }
    }

    /// Cancels the tunnel reader and drops the UDP hole and notice client.
    public func stop() async {
        self.readTask?.cancel()
        self.readTask = nil

        self.udpHole = nil
        self.noticeClient = nil
    }

    // MARK: - Private

    private func setNatType(natType: SDLNATProberActor.NatType) {
        self.natType = natType
    }

    /// Called when the UDP hole becomes ready.
    ///
    /// Creates the puncher and NAT prober, then concurrently probes the NAT type
    /// and sends `registerSuper` to the STUN/super-node address.
    private func handleUDPHoleReady() async {
        if let udpHole = self.udpHole {
            self.puncherActor = SDLPuncherActor(udpHole: udpHole, querySocketAddress: config.stunSocketAddress, logger: logger)
            self.proberActor = SDLNATProberActor(udpHole: udpHole, addressArray: self.config.stunProbeSocketAddressArray, logger: self.logger)
        }

        await withDiscardingTaskGroup { group in
            // Probe the NAT type.
            group.addTask {
                if let natType = await self.proberActor?.probeNatType() {
                    await self.setNatType(natType: natType)
                    self.logger.log("[SDLContext] nat_type is: \(natType)")
                }
            }

            group.addTask {
                var registerSuper = SDLRegisterSuper()
                registerSuper.pktID = 0
                registerSuper.clientID = self.config.clientId
                registerSuper.networkID = self.config.networkAddress.networkId
                registerSuper.mac = self.config.networkAddress.mac
                registerSuper.ip = self.config.networkAddress.ip
                registerSuper.maskLen = UInt32(self.config.networkAddress.maskLen)
                registerSuper.hostname = self.config.hostname
                registerSuper.pubKey = self.rsaCipher.pubKey
                registerSuper.accessToken = self.config.accessToken

                if let registerSuperData = try? registerSuper.serializedData() {
                    self.logger.log("[SDLContext] will send register super")
                    await self.udpHole?.send(type: .registerSuper, data: registerSuperData, remoteAddress: self.config.stunSocketAddress)
                }
            }
        }
    }

    /// Sends a STUN request carrying this node's identity and current NAT type.
    private func sendStunRequest() async {
        var stunRequest = SDLStunRequest()
        stunRequest.clientID = self.config.clientId
        stunRequest.networkID = self.config.networkAddress.networkId
        stunRequest.ip = self.config.networkAddress.ip
        stunRequest.mac = self.config.networkAddress.mac
        stunRequest.natType = UInt32(self.natType.rawValue)

        if let stunData = try? stunRequest.serializedData() {
            let remoteAddress = self.config.stunSocketAddress
            self.udpHole?.send(type: .stunRequest, data: stunData, remoteAddress: remoteAddress)
        }
    }

    /// Handles the super node's registration ack.
    ///
    /// Decodes the AES key with the RSA cipher, installs it, then configures the
    /// tunnel network settings and starts the packet reader. Exits the process
    /// if the network settings cannot be applied.
    private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async {
        // The key comes from the network; a failed decode must not crash the tunnel.
        let aesKey: Data
        do {
            aesKey = try self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))
        } catch {
            self.logger.log("[SDLContext] failed to decode aes_key from registerSuperAck: \(error)", level: .error)
            return
        }

        self.logger.log("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count)", level: .info)

        // Install the key before suspending on `setNetworkSettings`, so data packets
        // handled on this actor during the suspension use the new key.
        self.aesKey = aesKey

        // Apply the tun interface settings.
        do {
            let ipAddress = try await self.providerAdapter.setNetworkSettings(networkAddress: self.config.networkAddress, dnsServer: SDLDNSClient.Helper.dnsServer)
            self.noticeClient?.send(data: NoticeMessage.ipAdress(ip: ipAddress))

            self.startReader()
        } catch let err {
            self.logger.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error)
            exit(-1)
        }
    }

    /// Handles a registration rejection from the super node.
    ///
    /// Alerts the user via the notice client. `.invalidToken` and `.nodeDisabled`
    /// terminate the process; the other codes are only reported.
    private func handleRegisterSuperNak(nakPacket: SDLRegisterSuperNak) {
        let errorMessage = nakPacket.errorMessage

        // `UInt8(exactly:)` yields nil for out-of-range codes, whereas `UInt8(_:)` would trap.
        guard let rawCode = UInt8(exactly: nakPacket.errorCode),
              let errorCode = SDLNAKErrorCode(rawValue: rawCode) else {
            self.logger.log("[SDLContext] registerSuperNak with unknown error code: \(nakPacket.errorCode)", level: .warning)
            return
        }

        let alertNotice = NoticeMessage.alert(alert: errorMessage)
        self.noticeClient?.send(data: alertNotice)
        // Log before the switch so the fatal path is recorded too.
        self.logger.log("[SDLContext] Get a SuperNak message exit", level: .warning)

        switch errorCode {
        case .invalidToken, .nodeDisabled:
            exit(-1)
        case .noIpAddress, .networkFault, .internalFault:
            break
        }
    }

    /// Handles a server-pushed event.
    private func handleEvent(event: SDLEvent) throws {
        switch event {
        case .natChanged(let natChangedEvent):
            let dstMac = natChangedEvent.mac
            self.logger.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info)
            sessionManager.removeSession(dstMac: dstMac)

        case .sendRegister(let sendRegisterEvent):
            self.logger.log("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)", level: .debug)
            let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
            if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) {
                // Send a register packet directly to the peer's NAT address.
                var register = SDLRegister()
                register.networkID = self.config.networkAddress.networkId
                register.srcMac = self.config.networkAddress.mac
                register.dstMac = sendRegisterEvent.dstMac
                self.udpHole?.send(type: .register, data: try register.serializedData(), remoteAddress: remoteAddress)
            }

        case .networkShutdown(let shutdownEvent):
            let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message)
            self.noticeClient?.send(data: alertNotice)
            exit(-1)
        }
    }

    /// Handles a peer's register packet: acks it and records a session to the sender.
    private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws {
        let networkAddr = config.networkAddress
        self.logger.log("register packet: \(register), network_address: \(networkAddr)", level: .debug)

        // Only accept packets addressed to our tun MAC within our network.
        if register.dstMac == networkAddr.mac && register.networkID == networkAddr.networkId {
            var registerAck = SDLRegisterAck()
            registerAck.networkID = networkAddr.networkId
            registerAck.srcMac = networkAddr.mac
            registerAck.dstMac = register.srcMac

            self.udpHole?.send(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)
            // Use the UDP source address as the peer's NAT address: in complex networks
            // the address learned via the super node is not always reliable.
            let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
            self.sessionManager.addSession(session: session)
        } else {
            self.logger.log("SDLContext didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)", level: .warning)
        }
    }

    /// Handles a peer's register ack by recording a session to the sender.
    private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) {
        // Only accept acks addressed to our tun MAC within our network.
        let networkAddr = config.networkAddress
        if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId {
            let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
            self.sessionManager.addSession(session: session)
        } else {
            self.logger.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning)
        }
    }

    /// Handles an inbound data packet from the UDP hole.
    ///
    /// Drops packets not addressed to us (unicast to our MAC, broadcast or
    /// multicast), packets that fail to decrypt or parse, and IPv4 packets for
    /// other IPs. Answers ARP requests for our IP, learns ARP responses, and
    /// writes matching IPv4 packets into the tunnel.
    private func handleData(data: SDLData) {
        let mac = LayerPacket.MacAddress(data: data.dstMac)
        let networkAddr = config.networkAddress

        guard data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast() else {
            return
        }

        guard let decyptedData = try? self.aesCipher.decypt(aesKey: self.aesKey, data: Data(data.data)) else {
            return
        }

        // The payload comes from a remote peer; a parse failure drops this packet only.
        let layerPacket: LayerPacket
        do {
            layerPacket = try LayerPacket(layerData: decyptedData)
        } catch {
            self.logger.log("[SDLContext] drop malformed layer packet: \(error)", level: .debug)
            return
        }

        self.flowTracer.inc(num: decyptedData.count, type: .inbound)

        switch layerPacket.type {
        case .arp:
            guard let arpPacket = ARPPacket(data: layerPacket.data) else {
                self.logger.log("[SDLContext] get invalid arp packet", level: .debug)
                return
            }
            guard arpPacket.targetIP == networkAddr.ip else {
                self.logger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))", level: .debug)
                return
            }
            switch arpPacket.opcode {
            case .request:
                self.logger.log("[SDLContext] get arp request packet", level: .debug)
                let response = ARPPacket.arpResponse(for: arpPacket, mac: networkAddr.mac, ip: networkAddr.ip)
                self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
            case .response:
                self.logger.log("[SDLContext] get arp response packet", level: .debug)
                self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
            }

        case .ipv4:
            guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == networkAddr.ip else {
                return
            }
            let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
            self.providerAdapter.writePackets(packets: [packet])

        default:
            self.logger.log("[SDLContext] get invalid packet", level: .debug)
        }
    }

    // Traffic reporting (disabled). Sketch of a once-per-minute report:
    // public func flowReportTask() {
    //     Task {
    //         // Report once per minute
    //         self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect()
    //             .sink { _ in
    //                 Task {
    //                     let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset()
    //                     await self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
    //                 }
    //             }
    //     }
    // }

    /// (Re)starts the task that reads packets from the tunnel interface.
    ///
    /// Any previous reader is cancelled. The new task loops until cancelled,
    /// handing each batch of parsed IPv4 packets to `batchProcessPackets`.
    private func startReader() {
        self.readTask?.cancel()

        self.readTask = Task(priority: .high) {
            while !Task.isCancelled {
                let packets = await self.providerAdapter.readPackets()
                let ipPackets = packets.compactMap { IPPacket($0) }
                await self.batchProcessPackets(batchSize: 20, packets: ipPackets)
            }
        }
    }

    /// Dispatches tunnel packets in their original order.
    ///
    /// `dealPacket` is synchronous and actor-isolated, so spawning a child task
    /// per packet gave no parallelism and could reorder packets. Packets are now
    /// processed in order, yielding after every `batchSize` packets so other work
    /// on this actor (e.g. inbound data) is not starved.
    ///
    /// - Parameters:
    ///   - batchSize: Packets per chunk between yields; values < 1 are treated as 1.
    ///   - packets: Packets read from the tunnel interface.
    private func batchProcessPackets(batchSize: Int, packets: [IPPacket]) async {
        let chunkSize = Swift.max(batchSize, 1)
        for startIndex in stride(from: 0, to: packets.count, by: chunkSize) {
            let endIndex = Swift.min(startIndex + chunkSize, packets.count)
            for packet in packets[startIndex..<endIndex] {
                self.dealPacket(packet: packet)
            }
            await Task.yield()
        }
    }

    /// Routes a single packet read from the tunnel.
    ///
    /// DNS queries go to the DNS client, and packets for our own IP are written
    /// back to the tunnel. Otherwise the packet is sent to the MAC cached for the
    /// destination IP, or an ARP broadcast is sent when the IP is not in the cache.
    private func dealPacket(packet: IPPacket) {
        let networkAddr = self.config.networkAddress
        if SDLDNSClient.Helper.isDnsRequestPacket(ipPacket: packet) {
            let destIp = packet.header.destination_ip
            self.logger.log("[DNSQuery] destIp: \(destIp), int: \(packet.header.destination.asIpAddress())", level: .debug)
            self.dnsClient?.forward(ipPacket: packet)
            return
        }

        let dstIp = packet.header.destination
        // Local delivery: the destination is this node's own IP.
        if dstIp == networkAddr.ip {
            let nePacket = NEPacket(data: packet.data, protocolFamily: 2)
            self.providerAdapter.writePackets(packets: [nePacket])
            return
        }

        // Look up the destination MAC in the ARP cache.
        if let dstMac = self.arpServer.query(ip: dstIp) {
            self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data)
        } else {
            self.logger.log("[SDLContext] dstIp: \(dstIp.asIpAddress()) arp query not found, broadcast", level: .debug)
            // Broadcast an ARP request for the unknown IP.
            let arpRequest = ARPPacket.arpRequest(senderIP: networkAddr.ip, senderMAC: networkAddr.mac, targetIP: dstIp)
            self.routeLayerPacket(dstMac: ARPPacket.broadcastMac, type: .arp, data: arpRequest.marshal())
        }
    }

    /// Wraps `data` in a layer-2 frame, encrypts it, and sends it to `dstMac`.
    ///
    /// Broadcast frames go through the super node. Unicast frames use an
    /// existing peer session if one exists; otherwise they are forwarded via the
    /// super node and a hole-punch request is submitted.
    private func routeLayerPacket(dstMac: Data, type: LayerPacket.PacketType, data: Data) {
        let networkAddr = self.config.networkAddress
        let layerPacket = LayerPacket(dstMac: dstMac, srcMac: networkAddr.mac, type: type, data: data)
        guard let encodedPacket = try? self.aesCipher.encrypt(aesKey: self.aesKey, data: layerPacket.marshal()) else {
            return
        }

        var dataPacket = SDLData()
        dataPacket.networkID = networkAddr.networkId
        dataPacket.srcMac = networkAddr.mac
        dataPacket.dstMac = dstMac
        dataPacket.ttl = 255
        dataPacket.data = encodedPacket

        guard let packetData = try? dataPacket.serializedData() else {
            self.logger.log("[SDLContext] failed to serialize data packet", level: .error)
            return
        }

        // Never try hole punching for broadcast frames; relay them via the super node.
        if ARPPacket.isBroadcastMac(dstMac) {
            self.udpHole?.send(type: .data, data: packetData, remoteAddress: self.config.stunSocketAddress)
        } else if let session = self.sessionManager.getSession(toAddress: dstMac) {
            // Direct P2P via the established session.
            self.logger.log("[SDLContext] send packet by session: \(session)", level: .debug)
            self.udpHole?.send(type: .data, data: packetData, remoteAddress: session.natAddress)
            self.flowTracer.inc(num: packetData.count, type: .p2p)
        } else {
            // Relay via the super node, and try to punch a hole for next time.
            self.udpHole?.send(type: .data, data: packetData, remoteAddress: self.config.stunSocketAddress)
            self.flowTracer.inc(num: packetData.count, type: .forward)
            self.puncherActor?.submitRegisterRequest(request: .init(srcMac: networkAddr.mac, dstMac: dstMac, networkId: networkAddr.networkId))
        }
    }

    deinit {
        // Stored properties are released automatically; only the reader task
        // needs an explicit cancel.
        readTask?.cancel()
    }
}
|
||
|
||
private extension UInt32 {
    /// Renders this integer as an IPv4 address string via `SDLUtil.int32ToIp`.
    func asIpAddress() -> String {
        SDLUtil.int32ToIp(self)
    }
}
|