fix
This commit is contained in:
parent
cd24d2ac7c
commit
b55d9913cc
@ -47,11 +47,8 @@ public class SDLContext: @unchecked Sendable {
|
||||
let rsaCipher: RSACipher
|
||||
|
||||
// 依赖的变量
|
||||
var udpHole: SDLUDPHole?
|
||||
private var udpCancel: AnyCancellable?
|
||||
|
||||
var superClient: SDLSuperClient?
|
||||
private var superCancel: AnyCancellable?
|
||||
var udpHoleActor: SDLUDPHoleActor?
|
||||
var superClientActor: SDLSuperClientActor?
|
||||
|
||||
// 数据包读取任务
|
||||
private var readTask: Task<(), Never>?
|
||||
@ -97,53 +94,67 @@ public class SDLContext: @unchecked Sendable {
|
||||
}
|
||||
|
||||
public func start() async throws {
|
||||
try await self.startSuperClient()
|
||||
try await self.startUDPHole()
|
||||
self.noticeClient.start()
|
||||
self.udpHoleActor = try await SDLUDPHoleActor()
|
||||
self.superClientActor = try await SDLSuperClientActor(host: self.config.superHost, port: self.config.superPort)
|
||||
|
||||
// 启动网络监控
|
||||
self.monitorCancel = self.monitor.eventFlow.sink { event in
|
||||
switch event {
|
||||
case .changed:
|
||||
// 需要重新探测网络的nat类型
|
||||
Task {
|
||||
self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config)
|
||||
NSLog("didNetworkPathChanged, nat type is: \(self.natType)")
|
||||
try await withThrowingTaskGroup(of: Void.self) { group in
|
||||
group.addTask {
|
||||
try await self.udpHoleActor?.start()
|
||||
}
|
||||
case .unreachable:
|
||||
NSLog("didNetworkPathUnreachable")
|
||||
|
||||
group.addTask {
|
||||
try await self.superClientActor?.start()
|
||||
}
|
||||
|
||||
group.addTask {
|
||||
if let eventFlow = self.superClientActor?.eventFlow {
|
||||
for try await event in eventFlow {
|
||||
try await self.handleSuperEvent(event: event)
|
||||
}
|
||||
}
|
||||
self.monitor.start()
|
||||
}
|
||||
|
||||
group.addTask {
|
||||
if let eventFlow = self.udpHoleActor?.eventFlow {
|
||||
for try await event in eventFlow {
|
||||
try await self.handleUDPEvent(event: event)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
try await group.waitForAll()
|
||||
}
|
||||
|
||||
// self.noticeClient.start()
|
||||
// // 启动网络监控
|
||||
// self.monitorCancel = self.monitor.eventFlow.sink { event in
|
||||
// switch event {
|
||||
// case .changed:
|
||||
// // 需要重新探测网络的nat类型
|
||||
// Task {
|
||||
// self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config)
|
||||
// NSLog("didNetworkPathChanged, nat type is: \(self.natType)")
|
||||
// }
|
||||
// case .unreachable:
|
||||
// NSLog("didNetworkPathUnreachable")
|
||||
// }
|
||||
// }
|
||||
// self.monitor.start()
|
||||
}
|
||||
|
||||
public func stop() async {
|
||||
self.superCancel?.cancel()
|
||||
self.superClient = nil
|
||||
|
||||
self.udpCancel?.cancel()
|
||||
self.udpHole = nil
|
||||
self.superClientActor = nil
|
||||
self.udpHoleActor = nil
|
||||
|
||||
self.readTask?.cancel()
|
||||
}
|
||||
|
||||
private func startSuperClient() async throws {
|
||||
self.superClient = SDLSuperClient(host: config.superHost, port: config.superPort)
|
||||
// 建立super的绑定关系
|
||||
self.superCancel?.cancel()
|
||||
self.superCancel = self.superClient?.eventFlow.sink { event in
|
||||
Task {
|
||||
await self.handleSuperEvent(event: event)
|
||||
}
|
||||
}
|
||||
try await self.superClient?.start()
|
||||
}
|
||||
|
||||
private func handleSuperEvent(event: SDLSuperClient.SuperEvent) async {
|
||||
private func handleSuperEvent(event: SDLSuperClientActor.SuperEvent) async throws {
|
||||
switch event {
|
||||
case .ready:
|
||||
NSLog("[SDLContext] get registerSuper, mac address: \(Self.formatMacAddress(mac: self.devAddr.mac))")
|
||||
guard let message = await self.superClient?.registerSuper(context: self) else {
|
||||
guard let message = try await self.superClientActor?.registerSuper(context: self) else {
|
||||
return
|
||||
}
|
||||
|
||||
@ -195,9 +206,9 @@ public class SDLContext: @unchecked Sendable {
|
||||
NSLog("[SDLContext] super client closed")
|
||||
await self.arpServer.clear()
|
||||
DispatchQueue.main.asyncAfter(deadline: .now() + 5) {
|
||||
Task {@MainActor in
|
||||
try await self.startSuperClient()
|
||||
}
|
||||
// Task {@MainActor in
|
||||
// try await self.startSuperClient()
|
||||
// }
|
||||
}
|
||||
case .event(let evt):
|
||||
switch evt {
|
||||
@ -210,7 +221,7 @@ public class SDLContext: @unchecked Sendable {
|
||||
let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
|
||||
if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) {
|
||||
// 发送register包
|
||||
self.udpHole?.sendRegister(context: self, remoteAddress: remoteAddress, dst_mac: sendRegisterEvent.dstMac)
|
||||
await self.udpHoleActor?.sendRegister(context: self, remoteAddress: remoteAddress, dst_mac: sendRegisterEvent.dstMac)
|
||||
}
|
||||
|
||||
case .networkShutdown(let shutdownEvent):
|
||||
@ -233,26 +244,26 @@ public class SDLContext: @unchecked Sendable {
|
||||
var commandAck = SDLCommandAck()
|
||||
commandAck.status = true
|
||||
|
||||
self.superClient?.commandAck(packetId: packetId, ack: commandAck)
|
||||
await self.superClientActor?.commandAck(packetId: packetId, ack: commandAck)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private func startUDPHole() async throws {
|
||||
self.udpHole = SDLUDPHole()
|
||||
|
||||
self.udpCancel?.cancel()
|
||||
self.udpCancel = self.udpHole?.eventFlow.sink { event in
|
||||
Task.detached {
|
||||
await self.handleUDPEvent(event: event)
|
||||
}
|
||||
// self.udpHole = SDLUDPHole()
|
||||
//
|
||||
// self.udpCancel?.cancel()
|
||||
// self.udpCancel = self.udpHole?.eventFlow.sink { event in
|
||||
// Task.detached {
|
||||
// await self.handleUDPEvent(event: event)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// try await self.udpHole?.start()
|
||||
}
|
||||
|
||||
try await self.udpHole?.start()
|
||||
}
|
||||
|
||||
private func handleUDPEvent(event: SDLUDPHole.UDPEvent) async {
|
||||
private func handleUDPEvent(event: SDLUDPHoleActor.UDPEvent) async throws {
|
||||
switch event {
|
||||
case .ready:
|
||||
// 获取当前网络的类型
|
||||
@ -260,9 +271,9 @@ public class SDLContext: @unchecked Sendable {
|
||||
SDLLogger.log("[SDLContext] nat type is: \(self.natType)", level: .debug)
|
||||
|
||||
let timer = Timer.publish(every: 5.0, on: .main, in: .common).autoconnect()
|
||||
self.stunCancel = Just(Date()).merge(with: timer).sink { _ in
|
||||
self.lastCookie = self.udpHole?.stunRequest(context: self)
|
||||
}
|
||||
// self.stunCancel = Just(Date()).merge(with: timer).sink { _ in
|
||||
// self.lastCookie = await self.udpHoleActor?.stunRequest(context: self)
|
||||
// }
|
||||
|
||||
case .closed:
|
||||
DispatchQueue.main.asyncAfter(deadline: .now() + 5) {
|
||||
@ -278,7 +289,7 @@ public class SDLContext: @unchecked Sendable {
|
||||
// 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下
|
||||
if register.dstMac == self.devAddr.mac && register.networkID == self.devAddr.networkID {
|
||||
// 回复ack包
|
||||
self.udpHole?.sendRegisterAck(context: self, remoteAddress: remoteAddress, dst_mac: register.srcMac)
|
||||
self.udpHoleActor?.sendRegisterAck(context: self, remoteAddress: remoteAddress, dst_mac: register.srcMac)
|
||||
// 这里需要建立到来源的会话, 在复杂网络下,通过super-node查询到的nat地址不一定靠谱,需要通过udp包的来源地址作为nat地址
|
||||
let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
|
||||
await self.sessionManager.addSession(session: session)
|
||||
@ -367,7 +378,7 @@ public class SDLContext: @unchecked Sendable {
|
||||
.sink { _ in
|
||||
Task {
|
||||
let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset()
|
||||
self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
|
||||
await self.superClientActor?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -467,13 +478,13 @@ public class SDLContext: @unchecked Sendable {
|
||||
// 通过session发送到对端
|
||||
if let session = await self.sessionManager.getSession(toAddress: dstMac) {
|
||||
NSLog("[SDLContext] send packet by session: \(session)")
|
||||
self.udpHole?.sendPacket(context: self, session: session, data: encodedPacket)
|
||||
await self.udpHoleActor?.sendPacket(context: self, session: session, data: encodedPacket)
|
||||
|
||||
await self.flowTracer.inc(num: data.count, type: .p2p)
|
||||
}
|
||||
else {
|
||||
// 通过super_node进行转发
|
||||
self.udpHole?.forwardPacket(context: self, dst_mac: dstMac, data: encodedPacket)
|
||||
await self.udpHoleActor?.forwardPacket(context: self, dst_mac: dstMac, data: encodedPacket)
|
||||
// 流量统计
|
||||
await self.flowTracer.inc(num: data.count, type: .forward)
|
||||
|
||||
@ -486,7 +497,7 @@ public class SDLContext: @unchecked Sendable {
|
||||
|
||||
func holerTask(dstMac: Data) -> Task<(), Never> {
|
||||
return Task {
|
||||
guard let message = try? await self.superClient?.queryInfo(context: self, dst_mac: dstMac) else {
|
||||
guard let message = try? await self.superClientActor?.queryInfo(context: self, dst_mac: dstMac) else {
|
||||
return
|
||||
}
|
||||
|
||||
@ -497,7 +508,7 @@ public class SDLContext: @unchecked Sendable {
|
||||
if let remoteAddress = peerInfo.v4Info.socketAddress() {
|
||||
SDLLogger.log("[SDLContext] hole sock address: \(remoteAddress)", level: .warning)
|
||||
// 发送register包
|
||||
self.udpHole?.sendRegister(context: self, remoteAddress: remoteAddress, dst_mac: dstMac)
|
||||
await self.udpHoleActor?.sendRegister(context: self, remoteAddress: remoteAddress, dst_mac: dstMac)
|
||||
} else {
|
||||
SDLLogger.log("[SDLContext] hole sock address is invalid: \(peerInfo.v4Info)", level: .warning)
|
||||
}
|
||||
@ -509,8 +520,8 @@ public class SDLContext: @unchecked Sendable {
|
||||
|
||||
deinit {
|
||||
self.stunCancel?.cancel()
|
||||
self.udpHole = nil
|
||||
self.superClient = nil
|
||||
self.udpHoleActor = nil
|
||||
self.superClientActor = nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -1,567 +0,0 @@
|
||||
//
|
||||
// SDLContext.swift
|
||||
// Tun
|
||||
//
|
||||
// Created by 安礼成 on 2024/2/29.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import NetworkExtension
|
||||
import NIOCore
|
||||
import Combine
|
||||
|
||||
// 上下文环境变量,全局共享
|
||||
/*
|
||||
1. 处理rsa的加解密逻辑
|
||||
*/
|
||||
|
||||
public class SDLContextNew: @unchecked Sendable {
|
||||
|
||||
// 路由信息
|
||||
struct Route {
|
||||
let dstAddress: String
|
||||
let subnetMask: String
|
||||
|
||||
var debugInfo: String {
|
||||
return "\(dstAddress):\(subnetMask)"
|
||||
}
|
||||
}
|
||||
|
||||
let config: SDLConfiguration
|
||||
|
||||
// tun网络地址信息
|
||||
var devAddr: SDLDevAddr
|
||||
|
||||
// nat映射的相关信息, 暂时没有用处
|
||||
//var natAddress: SDLNatAddress?
|
||||
// nat的网络类型
|
||||
var natType: SDLNatProber.NatType = .blocked
|
||||
|
||||
// AES加密,授权通过后,对象才会被创建
|
||||
var aesCipher: AESCipher
|
||||
|
||||
// aes
|
||||
var aesKey: Data = Data()
|
||||
|
||||
// rsa的相关配置, public_key是本地生成的
|
||||
let rsaCipher: RSACipher
|
||||
|
||||
// 依赖的变量
|
||||
var udpHoleActor: SDLUDPHoleActor?
|
||||
private var udpCancel: AnyCancellable?
|
||||
|
||||
var superClientActor: SDLSuperClientActor?
|
||||
private var superCancel: AnyCancellable?
|
||||
|
||||
// 数据包读取任务
|
||||
private var readTask: Task<(), Never>?
|
||||
|
||||
let provider: NEPacketTunnelProvider
|
||||
|
||||
private var sessionManager: SessionManager
|
||||
private var holerManager: HolerManager
|
||||
private var arpServer: ArpServer
|
||||
|
||||
// 记录最后发送的stunRequest的cookie
|
||||
private var lastCookie: UInt32? = 0
|
||||
|
||||
// 定时器
|
||||
private var stunCancel: AnyCancellable?
|
||||
|
||||
// 网络状态变化的健康
|
||||
private var monitor = SDLNetworkMonitor()
|
||||
private var monitorCancel: AnyCancellable?
|
||||
|
||||
// 内部socket通讯
|
||||
private var noticeClient: SDLNoticeClient
|
||||
|
||||
// 流量统计
|
||||
private var flowTracer = SDLFlowTracerActor()
|
||||
private var flowTracerCancel: AnyCancellable?
|
||||
|
||||
public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher) {
|
||||
self.config = config
|
||||
self.rsaCipher = rsaCipher
|
||||
self.aesCipher = aesCipher
|
||||
|
||||
// 生成mac地址
|
||||
var devAddr = SDLDevAddr()
|
||||
devAddr.mac = Self.getMacAddress()
|
||||
self.devAddr = devAddr
|
||||
|
||||
self.provider = provider
|
||||
self.sessionManager = SessionManager()
|
||||
self.holerManager = HolerManager()
|
||||
self.arpServer = ArpServer(known_macs: [:])
|
||||
self.noticeClient = SDLNoticeClient()
|
||||
}
|
||||
|
||||
public func start() async throws {
|
||||
try await self.startSuperClient()
|
||||
try await self.startUDPHole()
|
||||
self.noticeClient.start()
|
||||
|
||||
// 启动网络监控
|
||||
self.monitorCancel = self.monitor.eventFlow.sink { event in
|
||||
switch event {
|
||||
case .changed:
|
||||
// 需要重新探测网络的nat类型
|
||||
Task {
|
||||
self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config)
|
||||
NSLog("didNetworkPathChanged, nat type is: \(self.natType)")
|
||||
}
|
||||
case .unreachable:
|
||||
NSLog("didNetworkPathUnreachable")
|
||||
}
|
||||
}
|
||||
self.monitor.start()
|
||||
}
|
||||
|
||||
public func stop() async {
|
||||
self.superCancel?.cancel()
|
||||
self.superClient = nil
|
||||
|
||||
self.udpCancel?.cancel()
|
||||
self.udpHole = nil
|
||||
|
||||
self.readTask?.cancel()
|
||||
}
|
||||
|
||||
private func startSuperClient() async throws {
|
||||
self.superClient = SDLSuperClient(host: config.superHost, port: config.superPort)
|
||||
// 建立super的绑定关系
|
||||
self.superCancel?.cancel()
|
||||
self.superCancel = self.superClient?.eventFlow.sink { event in
|
||||
Task {
|
||||
await self.handleSuperEvent(event: event)
|
||||
}
|
||||
}
|
||||
try await self.superClient?.start()
|
||||
}
|
||||
|
||||
private func handleSuperEvent(event: SDLSuperClient.SuperEvent) async {
|
||||
switch event {
|
||||
case .ready:
|
||||
NSLog("[SDLContext] get registerSuper, mac address: \(Self.formatMacAddress(mac: self.devAddr.mac))")
|
||||
guard let message = await self.superClient?.registerSuper(context: self) else {
|
||||
return
|
||||
}
|
||||
|
||||
switch message.packet {
|
||||
case .registerSuperAck(let registerSuperAck):
|
||||
// 需要对数据通过rsa的私钥解码
|
||||
let aesKey = try! self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))
|
||||
let upgradeType = SDLUpgradeType(rawValue: registerSuperAck.upgradeType)
|
||||
|
||||
NSLog("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count)")
|
||||
self.devAddr = registerSuperAck.devAddr
|
||||
|
||||
if upgradeType == .force {
|
||||
let forceUpgrade = NoticeMessage.UpgradeMessage(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress)
|
||||
self.noticeClient.send(data: forceUpgrade.binaryData)
|
||||
exit(-1)
|
||||
}
|
||||
|
||||
// 服务器分配的tun网卡信息
|
||||
await self.didNetworkConfigChanged(devAddr: self.devAddr)
|
||||
self.aesKey = aesKey
|
||||
|
||||
if upgradeType == .normal {
|
||||
let normalUpgrade = NoticeMessage.UpgradeMessage(prompt: registerSuperAck.upgradePrompt, address: registerSuperAck.upgradeAddress)
|
||||
self.noticeClient.send(data: normalUpgrade.binaryData)
|
||||
}
|
||||
|
||||
case .registerSuperNak(let nakPacket):
|
||||
let errorMessage = nakPacket.errorMessage
|
||||
guard let errorCode = SDLNAKErrorCode(rawValue: UInt8(nakPacket.errorCode)) else {
|
||||
return
|
||||
}
|
||||
|
||||
switch errorCode {
|
||||
case .invalidToken, .nodeDisabled:
|
||||
let alertNotice = NoticeMessage.AlertMessage(alert: errorMessage)
|
||||
self.noticeClient.send(data: alertNotice.binaryData)
|
||||
exit(-1)
|
||||
case .noIpAddress, .networkFault, .internalFault:
|
||||
let alertNotice = NoticeMessage.AlertMessage(alert: errorMessage)
|
||||
self.noticeClient.send(data: alertNotice.binaryData)
|
||||
}
|
||||
NSLog("[SDLContext] Get a SuperNak message exit")
|
||||
default:
|
||||
()
|
||||
}
|
||||
|
||||
case .closed:
|
||||
NSLog("[SDLContext] super client closed")
|
||||
await self.arpServer.clear()
|
||||
DispatchQueue.main.asyncAfter(deadline: .now() + 5) {
|
||||
Task {@MainActor in
|
||||
try await self.startSuperClient()
|
||||
}
|
||||
}
|
||||
case .event(let evt):
|
||||
switch evt {
|
||||
case .natChanged(let natChangedEvent):
|
||||
let dstMac = natChangedEvent.mac
|
||||
NSLog("[SDLContext] natChangedEvent, dstMac: \(dstMac)")
|
||||
await sessionManager.removeSession(dstMac: dstMac)
|
||||
case .sendRegister(let sendRegisterEvent):
|
||||
NSLog("[SDLContext] sendRegisterEvent, ip: \(sendRegisterEvent)")
|
||||
let address = SDLUtil.int32ToIp(sendRegisterEvent.natIp)
|
||||
if let remoteAddress = try? SocketAddress.makeAddressResolvingHost(address, port: Int(sendRegisterEvent.natPort)) {
|
||||
// 发送register包
|
||||
self.udpHole?.sendRegister(context: self, remoteAddress: remoteAddress, dst_mac: sendRegisterEvent.dstMac)
|
||||
}
|
||||
|
||||
case .networkShutdown(let shutdownEvent):
|
||||
let alertNotice = NoticeMessage.AlertMessage(alert: shutdownEvent.message)
|
||||
self.noticeClient.send(data: alertNotice.binaryData)
|
||||
exit(-1)
|
||||
}
|
||||
case .command(let packetId, let command):
|
||||
switch command {
|
||||
case .changeNetwork(let changeNetworkCommand):
|
||||
// 需要对数据通过rsa的私钥解码
|
||||
let aesKey = try! self.rsaCipher.decode(data: Data(changeNetworkCommand.aesKey))
|
||||
NSLog("[SDLContext] change network command get aes_key len: \(aesKey.count)")
|
||||
self.devAddr = changeNetworkCommand.devAddr
|
||||
|
||||
// 服务器分配的tun网卡信息
|
||||
await self.didNetworkConfigChanged(devAddr: self.devAddr)
|
||||
self.aesKey = aesKey
|
||||
|
||||
var commandAck = SDLCommandAck()
|
||||
commandAck.status = true
|
||||
|
||||
self.superClient?.commandAck(packetId: packetId, ack: commandAck)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private func startUDPHole() async throws {
|
||||
self.udpHole = SDLUDPHole()
|
||||
|
||||
self.udpCancel?.cancel()
|
||||
self.udpCancel = self.udpHole?.eventFlow.sink { event in
|
||||
Task.detached {
|
||||
await self.handleUDPEvent(event: event)
|
||||
}
|
||||
}
|
||||
|
||||
try await self.udpHole?.start()
|
||||
}
|
||||
|
||||
private func handleUDPEvent(event: SDLUDPHole.UDPEvent) async {
|
||||
switch event {
|
||||
case .ready:
|
||||
// 获取当前网络的类型
|
||||
self.natType = await SDLNatProber.getNatType(udpHole: self.udpHole, config: self.config)
|
||||
SDLLogger.log("[SDLContext] nat type is: \(self.natType)", level: .debug)
|
||||
|
||||
let timer = Timer.publish(every: 5.0, on: .main, in: .common).autoconnect()
|
||||
self.stunCancel = Just(Date()).merge(with: timer).sink { _ in
|
||||
self.lastCookie = self.udpHole?.stunRequest(context: self)
|
||||
}
|
||||
|
||||
case .closed:
|
||||
DispatchQueue.main.asyncAfter(deadline: .now() + 5) {
|
||||
Task {
|
||||
try await self.startUDPHole()
|
||||
}
|
||||
}
|
||||
|
||||
case .message(let remoteAddress, let message):
|
||||
switch message {
|
||||
case .register(let register):
|
||||
NSLog("register packet: \(register), dev_addr: \(self.devAddr)")
|
||||
// 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下
|
||||
if register.dstMac == self.devAddr.mac && register.networkID == self.devAddr.networkID {
|
||||
// 回复ack包
|
||||
self.udpHole?.sendRegisterAck(context: self, remoteAddress: remoteAddress, dst_mac: register.srcMac)
|
||||
// 这里需要建立到来源的会话, 在复杂网络下,通过super-node查询到的nat地址不一定靠谱,需要通过udp包的来源地址作为nat地址
|
||||
let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
|
||||
await self.sessionManager.addSession(session: session)
|
||||
} else {
|
||||
SDLLogger.log("SDLContext didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)", level: .warning)
|
||||
}
|
||||
case .registerAck(let registerAck):
|
||||
// 判断目标地址是否是tun的网卡地址, 并且是在同一个网络下
|
||||
if registerAck.dstMac == self.devAddr.mac && registerAck.networkID == self.devAddr.networkID {
|
||||
let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
|
||||
await self.sessionManager.addSession(session: session)
|
||||
} else {
|
||||
SDLLogger.log("SDLContext didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)", level: .warning)
|
||||
}
|
||||
case .stunReply(let stunReply):
|
||||
let cookie = stunReply.cookie
|
||||
if cookie == self.lastCookie {
|
||||
// 记录下当前在nat上的映射信息,暂时没有用;后续会用来判断网络类型
|
||||
//self.natAddress = stunReply.natAddress
|
||||
SDLLogger.log("[SDLContext] get a stunReply: \(try! stunReply.jsonString())")
|
||||
}
|
||||
default:
|
||||
()
|
||||
}
|
||||
|
||||
case .data(let data):
|
||||
let mac = LayerPacket.MacAddress(data: data.dstMac)
|
||||
guard (data.dstMac == self.devAddr.mac || mac.isBroadcast() || mac.isMulticast()) else {
|
||||
NSLog("[SDLContext] didReadData 1")
|
||||
return
|
||||
}
|
||||
|
||||
guard let decyptedData = try? self.aesCipher.decypt(aesKey: self.aesKey, data: Data(data.data)) else {
|
||||
NSLog("[SDLContext] didReadData 2")
|
||||
return
|
||||
}
|
||||
|
||||
do {
|
||||
let layerPacket = try LayerPacket(layerData: decyptedData)
|
||||
|
||||
await self.flowTracer.inc(num: decyptedData.count, type: .inbound)
|
||||
// 处理arp请求
|
||||
switch layerPacket.type {
|
||||
case .arp:
|
||||
// 判断如果收到的是arp请求
|
||||
if let arpPacket = ARPPacket(data: layerPacket.data) {
|
||||
if arpPacket.targetIP == self.devAddr.netAddr {
|
||||
switch arpPacket.opcode {
|
||||
case .request:
|
||||
NSLog("[SDLContext] get arp request packet")
|
||||
let response = ARPPacket.arpResponse(for: arpPacket, mac: self.devAddr.mac, ip: self.devAddr.netAddr)
|
||||
await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
|
||||
case .response:
|
||||
NSLog("[SDLContext] get arp response packet")
|
||||
await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
|
||||
}
|
||||
} else {
|
||||
NSLog("[SDLContext] get invalid arp packet, target_ip: \(arpPacket)")
|
||||
}
|
||||
} else {
|
||||
NSLog("[SDLContext] get invalid arp packet")
|
||||
}
|
||||
case .ipv4:
|
||||
NSLog("[SDLContext] get ipv4 packet")
|
||||
guard let ipPacket = IPPacket(layerPacket.data), ipPacket.header.destination == self.devAddr.netAddr else {
|
||||
return
|
||||
}
|
||||
|
||||
let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
|
||||
self.provider.packetFlow.writePacketObjects([packet])
|
||||
default:
|
||||
NSLog("[SDLContext] get invalid packet")
|
||||
}
|
||||
} catch let err {
|
||||
NSLog("[SDLContext] didReadData err: \(err)")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 流量统计
|
||||
public func flowReportTask() {
|
||||
Task {
|
||||
// 每分钟汇报一次
|
||||
self.flowTracerCancel = Timer.publish(every: 60.0, on: .main, in: .common).autoconnect()
|
||||
.sink { _ in
|
||||
Task {
|
||||
let (forwardNum, p2pNum, inboundNum) = await self.flowTracer.reset()
|
||||
self.superClient?.flowReport(forwardNum: forwardNum, p2pNum: p2pNum, inboundNum: inboundNum)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 网络改变时需要重新配置网络信息
|
||||
private func didNetworkConfigChanged(devAddr: SDLDevAddr, dnsServers: [String]? = nil) async {
|
||||
let netAddress = SDLNetAddress(ip: devAddr.netAddr, maskLen: UInt8(devAddr.netBitLen))
|
||||
let routes = [Route(dstAddress: netAddress.networkAddress, subnetMask: netAddress.maskAddress)]
|
||||
|
||||
// Add code here to start the process of connecting the tunnel.
|
||||
let networkSettings = NEPacketTunnelNetworkSettings(tunnelRemoteAddress: "8.8.8.8")
|
||||
networkSettings.mtu = 1460
|
||||
|
||||
// 设置网卡的DNS解析
|
||||
if let dnsServers {
|
||||
networkSettings.dnsSettings = NEDNSSettings(servers: dnsServers)
|
||||
} else {
|
||||
networkSettings.dnsSettings = NEDNSSettings(servers: ["8.8.8.8", "114.114.114.114"])
|
||||
}
|
||||
|
||||
NSLog("[SDLContext] Tun started at network ip: \(netAddress.ipAddress), mask: \(netAddress.maskAddress)")
|
||||
|
||||
let ipv4Settings = NEIPv4Settings(addresses: [netAddress.ipAddress], subnetMasks: [netAddress.maskAddress])
|
||||
// 设置路由表
|
||||
//NEIPv4Route.default()
|
||||
ipv4Settings.includedRoutes = routes.map { route in
|
||||
NEIPv4Route(destinationAddress: route.dstAddress, subnetMask: route.subnetMask)
|
||||
}
|
||||
networkSettings.ipv4Settings = ipv4Settings
|
||||
// 网卡配置设置必须成功
|
||||
do {
|
||||
try await self.provider.setTunnelNetworkSettings(networkSettings)
|
||||
|
||||
await self.holerManager.cleanup()
|
||||
self.startReader()
|
||||
|
||||
NSLog("[SDLContext] setTunnelNetworkSettings success, start read packet")
|
||||
} catch let err {
|
||||
NSLog("[SDLContext] setTunnelNetworkSettings get error: \(err)")
|
||||
exit(-1)
|
||||
}
|
||||
}
|
||||
|
||||
// 开始读取数据, 用单独的线程处理packetFlow
|
||||
private func startReader() {
|
||||
// 停止之前的任务
|
||||
self.readTask?.cancel()
|
||||
|
||||
// 开启新的任务
|
||||
self.readTask = Task(priority: .high) {
|
||||
repeat {
|
||||
if Task.isCancelled {
|
||||
break
|
||||
}
|
||||
|
||||
let (packets, numbers) = await self.provider.packetFlow.readPackets()
|
||||
for (data, number) in zip(packets, numbers) where number == 2 {
|
||||
if let packet = IPPacket(data) {
|
||||
Task.detached {
|
||||
let dstIp = packet.header.destination
|
||||
// 本地通讯, 目标地址是本地服务器的ip地址
|
||||
if dstIp == self.devAddr.netAddr {
|
||||
let nePacket = NEPacket(data: packet.data, protocolFamily: 2)
|
||||
self.provider.packetFlow.writePacketObjects([nePacket])
|
||||
return
|
||||
}
|
||||
|
||||
// 查找arp缓存中是否有目标mac地址
|
||||
if let dstMac = await self.arpServer.query(ip: dstIp) {
|
||||
await self.routeLayerPacket(dstMac: dstMac, type: .ipv4, data: packet.data)
|
||||
}
|
||||
else {
|
||||
// 构造arp请求
|
||||
let broadcastMac = Data([0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF])
|
||||
let arpReqeust: ARPPacket = ARPPacket.arpRequest(senderIP: self.devAddr.netAddr, senderMAC: self.devAddr.mac, targetIP: dstIp)
|
||||
await self.routeLayerPacket(dstMac: broadcastMac, type: .arp, data: arpReqeust.marshal())
|
||||
|
||||
NSLog("[SDLContext] dstIp: \(dstIp) arp query not found")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} while true
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private func routeLayerPacket(dstMac: Data, type: LayerPacket.PacketType, data: Data) async {
|
||||
// 将数据封装层2层的数据包
|
||||
let layerPacket = LayerPacket(dstMac: dstMac, srcMac: self.devAddr.mac, type: type, data: data)
|
||||
guard let encodedPacket = try? self.aesCipher.encrypt(aesKey: self.aesKey, data: layerPacket.marshal()) else {
|
||||
return
|
||||
}
|
||||
|
||||
// 通过session发送到对端
|
||||
if let session = await self.sessionManager.getSession(toAddress: dstMac) {
|
||||
NSLog("[SDLContext] send packet by session: \(session)")
|
||||
self.udpHole?.sendPacket(context: self, session: session, data: encodedPacket)
|
||||
|
||||
await self.flowTracer.inc(num: data.count, type: .p2p)
|
||||
}
|
||||
else {
|
||||
// 通过super_node进行转发
|
||||
self.udpHole?.forwardPacket(context: self, dst_mac: dstMac, data: encodedPacket)
|
||||
// 流量统计
|
||||
await self.flowTracer.inc(num: data.count, type: .forward)
|
||||
|
||||
// 尝试打洞
|
||||
await self.holerManager.addHoler(dstMac: dstMac) {
|
||||
self.holerTask(dstMac: dstMac)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func holerTask(dstMac: Data) -> Task<(), Never> {
|
||||
return Task {
|
||||
guard let message = try? await self.superClient?.queryInfo(context: self, dst_mac: dstMac) else {
|
||||
return
|
||||
}
|
||||
|
||||
switch message.packet {
|
||||
case .empty:
|
||||
SDLLogger.log("[SDLContext] hole query_info get empty: \(message)", level: .debug)
|
||||
case .peerInfo(let peerInfo):
|
||||
if let remoteAddress = peerInfo.v4Info.socketAddress() {
|
||||
SDLLogger.log("[SDLContext] hole sock address: \(remoteAddress)", level: .warning)
|
||||
// 发送register包
|
||||
self.udpHole?.sendRegister(context: self, remoteAddress: remoteAddress, dst_mac: dstMac)
|
||||
} else {
|
||||
SDLLogger.log("[SDLContext] hole sock address is invalid: \(peerInfo.v4Info)", level: .warning)
|
||||
}
|
||||
default:
|
||||
SDLLogger.log("[SDLContext] hole query_info is packet: \(message)", level: .warning)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deinit {
|
||||
self.stunCancel?.cancel()
|
||||
self.udpHole = nil
|
||||
self.superClient = nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
//--MARK: 获取设备的UUID
|
||||
|
||||
extension SDLContextNew {
|
||||
|
||||
public static func getUUID() -> String {
|
||||
let userDefaults = UserDefaults.standard
|
||||
if let uuid = userDefaults.value(forKey: "gClientId") as? String {
|
||||
return uuid
|
||||
} else {
|
||||
let uuid = UUID().uuidString.replacingOccurrences(of: "-", with: "").lowercased()
|
||||
userDefaults.setValue(uuid, forKey: "gClientId")
|
||||
|
||||
return uuid
|
||||
}
|
||||
}
|
||||
|
||||
// 获取mac地址
|
||||
public static func getMacAddress() -> Data {
|
||||
let key = "gMacAddress2"
|
||||
|
||||
let userDefaults = UserDefaults.standard
|
||||
if let mac = userDefaults.value(forKey: key) as? Data {
|
||||
return mac
|
||||
}
|
||||
else {
|
||||
let mac = generateMacAddress()
|
||||
userDefaults.setValue(mac, forKey: key)
|
||||
|
||||
return mac
|
||||
}
|
||||
}
|
||||
|
||||
// 随机生成mac地址
|
||||
private static func generateMacAddress() -> Data {
|
||||
var macAddress = [UInt8](repeating: 0, count: 6)
|
||||
for i in 0..<6 {
|
||||
macAddress[i] = UInt8.random(in: 0...255)
|
||||
}
|
||||
|
||||
return Data(macAddress)
|
||||
}
|
||||
|
||||
// 将mac地址转换成字符串
|
||||
private static func formatMacAddress(mac: Data) -> String {
|
||||
let bytes = [UInt8](mac)
|
||||
|
||||
return bytes.map { String(format: "%02X", $0) }.joined(separator: ":").lowercased()
|
||||
}
|
||||
|
||||
}
|
||||
@ -1,368 +0,0 @@
|
||||
//
|
||||
// SDLWebsocketClient.swift
|
||||
// Tun
|
||||
//
|
||||
// Created by 安礼成 on 2024/3/28.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import NIOCore
|
||||
import NIOPosix
|
||||
import Combine
|
||||
|
||||
// --MARK: 和SuperNode的客户端
|
||||
class SDLSuperClient: ChannelInboundHandler, @unchecked Sendable {
|
||||
public typealias InboundIn = ByteBuffer
|
||||
public typealias OutboundOut = ByteBuffer
|
||||
|
||||
public typealias CallbackFun = (SDLSuperInboundMessage?) -> Void
|
||||
|
||||
private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||
private var channel: Channel?
|
||||
|
||||
// id生成器
|
||||
var idGenerator = SDLIdGenerator(seed: 1)
|
||||
private var callbackManager = SuperCallbackManager()
|
||||
|
||||
let host: String
|
||||
let port: Int
|
||||
|
||||
private var pingCancel: AnyCancellable?
|
||||
|
||||
public var eventFlow = PassthroughSubject<SuperEvent, Never>()
|
||||
|
||||
// 定义事件类型
|
||||
enum SuperEvent {
|
||||
case ready
|
||||
case closed
|
||||
case event(SDLEvent)
|
||||
case command(UInt32, SDLCommand)
|
||||
}
|
||||
|
||||
init(host: String, port: Int) {
|
||||
self.host = host
|
||||
self.port = port
|
||||
}
|
||||
|
||||
func start() async throws {
|
||||
let bootstrap = ClientBootstrap(group: self.group)
|
||||
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
|
||||
.channelInitializer { channel in
|
||||
return channel.pipeline.addHandlers([
|
||||
ByteToMessageHandler(FixedHeaderDelimiterCoder()),
|
||||
MessageToByteHandler(FixedHeaderDelimiterCoder()),
|
||||
self
|
||||
])
|
||||
}
|
||||
|
||||
do {
|
||||
NSLog("super client connect: \(self.host):\(self.port)")
|
||||
self.channel = try await bootstrap.connect(host: self.host, port: self.port).get()
|
||||
} catch let err {
|
||||
NSLog("super client get error: \(err)")
|
||||
self.eventFlow.send(.closed)
|
||||
}
|
||||
}
|
||||
|
||||
// -- MARK: apis
|
||||
|
||||
func commandAck(packetId: UInt32, ack: SDLCommandAck) {
|
||||
guard let data = try? ack.serializedData() else {
|
||||
return
|
||||
}
|
||||
|
||||
self.send(type: .commandAck, packetId: packetId, data: data)
|
||||
}
|
||||
|
||||
func registerSuper(context ctx: SDLContext) async -> SDLSuperInboundMessage? {
|
||||
return await withCheckedContinuation { c in
|
||||
self.registerSuper(context: ctx) { message in
|
||||
c.resume(returning: message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func registerSuper(context ctx: SDLContext, callback: @escaping CallbackFun) {
|
||||
var registerSuper = SDLRegisterSuper()
|
||||
registerSuper.version = UInt32(ctx.config.version)
|
||||
registerSuper.clientID = ctx.config.clientId
|
||||
registerSuper.devAddr = ctx.devAddr
|
||||
registerSuper.pubKey = ctx.rsaCipher.pubKey
|
||||
registerSuper.token = ctx.config.token
|
||||
|
||||
let data = try! registerSuper.serializedData()
|
||||
|
||||
self.write(type: .registerSuper, data: data, callback: callback)
|
||||
}
|
||||
|
||||
func queryInfo(context ctx: SDLContext, dst_mac: Data) async throws -> SDLSuperInboundMessage? {
|
||||
return await withCheckedContinuation { c in
|
||||
self.queryInfo(context: ctx, dst_mac: dst_mac) { message in
|
||||
c.resume(returning: message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 查询目标服务器的相关信息
|
||||
func queryInfo(context ctx: SDLContext, dst_mac: Data, callback: @escaping CallbackFun) {
|
||||
var queryInfo = SDLQueryInfo()
|
||||
queryInfo.dstMac = dst_mac
|
||||
|
||||
self.write(type: .queryInfo, data: try! queryInfo.serializedData(), callback: callback)
|
||||
}
|
||||
|
||||
func unregister(context ctx: SDLContext) throws {
|
||||
self.send(type: .unregisterSuper, packetId: 0, data: Data())
|
||||
}
|
||||
|
||||
func ping() {
|
||||
self.send(type: .ping, packetId: 0, data: Data())
|
||||
}
|
||||
|
||||
func flowReport(forwardNum: UInt32, p2pNum: UInt32, inboundNum: UInt32) {
|
||||
var flow = SDLFlows()
|
||||
flow.forwardNum = forwardNum
|
||||
flow.p2PNum = p2pNum
|
||||
flow.inboundNum = inboundNum
|
||||
|
||||
self.send(type: .flowTracer, packetId: 0, data: try! flow.serializedData())
|
||||
}
|
||||
|
||||
// --MARK: ChannelInboundHandler
|
||||
|
||||
public func channelActive(context: ChannelHandlerContext) {
|
||||
self.startPingTicker()
|
||||
self.eventFlow.send(.ready)
|
||||
}
|
||||
|
||||
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
var buffer = self.unwrapInboundIn(data)
|
||||
if let message = decode(buffer: &buffer) {
|
||||
SDLLogger.log("[SDLSuperTransport] read message: \(message)", level: .warning)
|
||||
|
||||
switch message.packet {
|
||||
case .event(let event):
|
||||
self.eventFlow.send(.event(event))
|
||||
case .command(let command):
|
||||
self.eventFlow.send(.command(message.msgId, command))
|
||||
default:
|
||||
self.callbackManager.fireCallback(message: message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public func errorCaught(context: ChannelHandlerContext, error: Error) {
|
||||
SDLLogger.log("[SDLSuperTransport] error: \(error)", level: .warning)
|
||||
self.channel = nil
|
||||
self.eventFlow.send(.closed)
|
||||
context.close(promise: nil)
|
||||
}
|
||||
|
||||
public func channelInactive(context: ChannelHandlerContext) {
|
||||
SDLLogger.log("[SDLSuperTransport] channelInactive", level: .warning)
|
||||
self.channel = nil
|
||||
context.close(promise: nil)
|
||||
}
|
||||
|
||||
func write(type: SDLPacketType, data: Data, callback: @escaping CallbackFun) {
|
||||
guard let channel = self.channel else {
|
||||
return
|
||||
}
|
||||
|
||||
SDLLogger.log("[SDLSuperTransport] will write data: \(data)", level: .debug)
|
||||
|
||||
let packetId = idGenerator.nextId()
|
||||
self.callbackManager.addCallback(id: packetId, callback: callback)
|
||||
|
||||
channel.eventLoop.execute {
|
||||
var buffer = channel.allocator.buffer(capacity: data.count + 5)
|
||||
buffer.writeInteger(packetId, as: UInt32.self)
|
||||
buffer.writeBytes([type.rawValue])
|
||||
buffer.writeBytes(data)
|
||||
|
||||
channel.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
|
||||
}
|
||||
}
|
||||
|
||||
func send(type: SDLPacketType, packetId: UInt32, data: Data) {
|
||||
guard let channel = self.channel else {
|
||||
return
|
||||
}
|
||||
|
||||
channel.eventLoop.execute {
|
||||
var buffer = channel.allocator.buffer(capacity: data.count + 5)
|
||||
buffer.writeInteger(packetId, as: UInt32.self)
|
||||
buffer.writeBytes([type.rawValue])
|
||||
buffer.writeBytes(data)
|
||||
|
||||
channel.writeAndFlush(self.wrapOutboundOut(buffer), promise: nil)
|
||||
}
|
||||
}
|
||||
|
||||
// --MARK: 心跳机制
|
||||
|
||||
private func startPingTicker() {
|
||||
self.pingCancel = Timer.publish(every: 5.0, on: .main, in: .common).autoconnect()
|
||||
.sink { _ in
|
||||
// 保持和super-node的心跳机制
|
||||
self.ping()
|
||||
}
|
||||
}
|
||||
|
||||
deinit {
|
||||
self.pingCancel?.cancel()
|
||||
try! group.syncShutdownGracefully()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// 基于2字节固定长度的分包协议
|
||||
extension SDLSuperClient {
|
||||
private final class FixedHeaderDelimiterCoder: ByteToMessageDecoder, MessageToByteEncoder {
|
||||
typealias InboundIn = ByteBuffer
|
||||
typealias InboundOut = ByteBuffer
|
||||
|
||||
func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
|
||||
guard let len = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) else {
|
||||
return .needMoreData
|
||||
}
|
||||
|
||||
if buffer.readableBytes >= len + 2 {
|
||||
buffer.moveReaderIndex(forwardBy: 2)
|
||||
if let bytes = buffer.readBytes(length: Int(len)) {
|
||||
context.fireChannelRead(self.wrapInboundOut(ByteBuffer(bytes: bytes)))
|
||||
}
|
||||
return .continue
|
||||
} else {
|
||||
return .needMoreData
|
||||
}
|
||||
}
|
||||
|
||||
func encode(data: ByteBuffer, out: inout ByteBuffer) throws {
|
||||
let len = data.readableBytes
|
||||
out.writeInteger(UInt16(len))
|
||||
out.writeBytes(data.readableBytesView)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 回调函数管理器
|
||||
extension SDLSuperClient {
|
||||
|
||||
private struct SuperCallbackManager {
|
||||
// 对应请求体和相应的关系
|
||||
private var callbacks: [UInt32:CallbackFun] = [:]
|
||||
|
||||
mutating func addCallback(id: UInt32, callback: @escaping CallbackFun) {
|
||||
self.callbacks[id] = callback
|
||||
}
|
||||
|
||||
mutating func fireCallback(message: SDLSuperInboundMessage) {
|
||||
if let callback = self.callbacks[message.msgId] {
|
||||
callback(message)
|
||||
self.callbacks.removeValue(forKey: message.msgId)
|
||||
}
|
||||
}
|
||||
|
||||
mutating func fireAllCallbacks(message: SDLSuperInboundMessage) {
|
||||
for (_, callback) in self.callbacks {
|
||||
callback(nil)
|
||||
}
|
||||
self.callbacks.removeAll()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// --MARK: 编解码器
|
||||
extension SDLSuperClient {
|
||||
// 消息格式为: <<MsgId:32, Type:8, Body/binary>>
|
||||
func decode(buffer: inout ByteBuffer) -> SDLSuperInboundMessage? {
|
||||
guard let msgId = buffer.readInteger(as: UInt32.self),
|
||||
let type = buffer.readInteger(as: UInt8.self),
|
||||
let messageType = SDLPacketType(rawValue: type) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch messageType {
|
||||
case .empty:
|
||||
return .init(msgId: msgId, packet: .empty)
|
||||
case .registerSuperAck:
|
||||
guard let bytes = buffer.readBytes(length: buffer.readableBytes),
|
||||
let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else {
|
||||
return nil
|
||||
}
|
||||
return .init(msgId: msgId, packet: .registerSuperAck(registerSuperAck))
|
||||
|
||||
case .registerSuperNak:
|
||||
guard let bytes = buffer.readBytes(length: buffer.readableBytes),
|
||||
let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else {
|
||||
return nil
|
||||
}
|
||||
return .init(msgId: msgId, packet: .registerSuperNak(registerSuperNak))
|
||||
|
||||
case .peerInfo:
|
||||
guard let bytes = buffer.readBytes(length: buffer.readableBytes),
|
||||
let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return .init(msgId: msgId, packet: .peerInfo(peerInfo))
|
||||
case .pong:
|
||||
return .init(msgId: msgId, packet: .pong)
|
||||
|
||||
case .command:
|
||||
guard let commandVal = buffer.readInteger(as: UInt8.self),
|
||||
let command = SDLCommandType(rawValue: commandVal),
|
||||
let bytes = buffer.readBytes(length: buffer.readableBytes) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch command {
|
||||
case .changeNetwork:
|
||||
guard let changeNetworkCommand = try? SDLChangeNetworkCommand(serializedBytes: bytes) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
return .init(msgId: msgId, packet: .command(.changeNetwork(changeNetworkCommand)))
|
||||
}
|
||||
|
||||
case .event:
|
||||
guard let eventVal = buffer.readInteger(as: UInt8.self),
|
||||
let event = SDLEventType(rawValue: eventVal),
|
||||
let bytes = buffer.readBytes(length: buffer.readableBytes) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch event {
|
||||
case .natChanged:
|
||||
guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else {
|
||||
return nil
|
||||
}
|
||||
return .init(msgId: msgId, packet: .event(.natChanged(natChangedEvent)))
|
||||
case .sendRegister:
|
||||
guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else {
|
||||
return nil
|
||||
}
|
||||
return .init(msgId: msgId, packet: .event(.sendRegister(sendRegisterEvent)))
|
||||
case .networkShutdown:
|
||||
guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else {
|
||||
return nil
|
||||
}
|
||||
return .init(msgId: msgId, packet: .event(.networkShutdown(networkShutdownEvent)))
|
||||
}
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
extension ByteToMessageHandler: @unchecked @retroactive Sendable {
|
||||
|
||||
}
|
||||
|
||||
extension MessageToByteHandler: @unchecked @retroactive Sendable {
|
||||
|
||||
}
|
||||
@ -9,8 +9,6 @@ import Foundation
|
||||
import NIOCore
|
||||
import NIOPosix
|
||||
|
||||
|
||||
|
||||
// --MARK: 和SuperNode的客户端
|
||||
actor SDLSuperClientActor {
|
||||
private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||
|
||||
@ -1,307 +0,0 @@
|
||||
//
|
||||
// SDLanServer.swift
|
||||
// Tun
|
||||
//
|
||||
// Created by 安礼成 on 2024/1/31.
|
||||
//
|
||||
|
||||
import Foundation
|
||||
import NIOCore
|
||||
import NIOPosix
|
||||
import Combine
|
||||
|
||||
// 处理和sn-server服务器之间的通讯
|
||||
class SDLUDPHole: ChannelInboundHandler, @unchecked Sendable {
|
||||
public typealias InboundIn = AddressedEnvelope<ByteBuffer>
|
||||
public typealias OutboundOut = AddressedEnvelope<ByteBuffer>
|
||||
|
||||
// 回调函数
|
||||
public typealias CallbackFun = (SDLStunProbeReply?) -> Void
|
||||
|
||||
private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
|
||||
|
||||
private var cookieGenerator = SDLIdGenerator(seed: 1)
|
||||
private var callbackManager = HoleCallbackManager()
|
||||
|
||||
public var localAddress: SocketAddress?
|
||||
public var channel: Channel?
|
||||
|
||||
public var eventFlow = PassthroughSubject<UDPEvent, Never>()
|
||||
|
||||
// 定义事件类型
|
||||
enum UDPEvent {
|
||||
case ready
|
||||
case closed
|
||||
case message(SocketAddress, SDLHoleInboundMessage)
|
||||
case data(SDLData)
|
||||
}
|
||||
|
||||
init() {
|
||||
|
||||
}
|
||||
|
||||
// MARK: super_node apis
|
||||
|
||||
func stunRequest(context ctx: SDLContext) -> UInt32 {
|
||||
let cookie = self.cookieGenerator.nextId()
|
||||
let remoteAddress = ctx.config.stunSocketAddress
|
||||
|
||||
var stunRequest = SDLStunRequest()
|
||||
stunRequest.cookie = cookie
|
||||
stunRequest.clientID = ctx.config.clientId
|
||||
stunRequest.networkID = ctx.devAddr.networkID
|
||||
stunRequest.ip = ctx.devAddr.netAddr
|
||||
stunRequest.mac = ctx.devAddr.mac
|
||||
stunRequest.natType = UInt32(ctx.natType.rawValue)
|
||||
|
||||
SDLLogger.log("[SDLUDPHole] stunRequest: \(remoteAddress), host: \(ctx.config.stunServers[0].host):\(ctx.config.stunServers[0].ports[0])", level: .warning)
|
||||
|
||||
self.send(remoteAddress: remoteAddress, type: .stunRequest, data: try! stunRequest.serializedData())
|
||||
|
||||
return cookie
|
||||
}
|
||||
|
||||
// 探测tun信息
|
||||
func stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int = 5) async -> SDLStunProbeReply? {
|
||||
return await withCheckedContinuation { continuation in
|
||||
self.stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: timeout) { probeReply in
|
||||
continuation.resume(returning: probeReply)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int, callback: @escaping CallbackFun) {
|
||||
let cookie = self.cookieGenerator.nextId()
|
||||
|
||||
var stunProbe = SDLStunProbe()
|
||||
stunProbe.cookie = cookie
|
||||
stunProbe.attr = UInt32(attr.rawValue)
|
||||
|
||||
self.send(remoteAddress: remoteAddress, type: .stunProbe, data: try! stunProbe.serializedData())
|
||||
|
||||
SDLLogger.log("[SDLUDPHole] stunProbe: \(remoteAddress)", level: .warning)
|
||||
|
||||
self.callbackManager.addCallback(id: cookie, callback: callback)
|
||||
}
|
||||
|
||||
// MARK: client-client apis
|
||||
|
||||
// 发送数据包到其他session
|
||||
func sendPacket(context ctx: SDLContext, session: Session, data: Data) {
|
||||
let remoteAddress = session.natAddress
|
||||
|
||||
var dataPacket = SDLData()
|
||||
dataPacket.networkID = ctx.devAddr.networkID
|
||||
dataPacket.srcMac = ctx.devAddr.mac
|
||||
dataPacket.dstMac = session.dstMac
|
||||
dataPacket.ttl = 255
|
||||
dataPacket.data = data
|
||||
let packet = try! dataPacket.serializedData()
|
||||
|
||||
SDLLogger.log("[SDLUDPHole] sendPacket: \(remoteAddress), count: \(packet.count)", level: .debug)
|
||||
|
||||
self.send(remoteAddress: remoteAddress, type: .data, data: packet)
|
||||
}
|
||||
|
||||
// 通过sn服务器转发数据包, data已经是加密过后的数据
|
||||
func forwardPacket(context ctx: SDLContext, dst_mac: Data, data: Data) {
|
||||
let remoteAddress = ctx.config.stunSocketAddress
|
||||
|
||||
var dataPacket = SDLData()
|
||||
dataPacket.networkID = ctx.devAddr.networkID
|
||||
dataPacket.srcMac = ctx.devAddr.mac
|
||||
dataPacket.dstMac = dst_mac
|
||||
dataPacket.ttl = 255
|
||||
dataPacket.data = data
|
||||
|
||||
let packet = try! dataPacket.serializedData()
|
||||
|
||||
NSLog("[SDLContext] forward packet, remoteAddress: \(remoteAddress), data size: \(packet.count)")
|
||||
|
||||
self.send(remoteAddress: remoteAddress, type: .data, data: packet)
|
||||
}
|
||||
|
||||
// 发送register包
|
||||
func sendRegister(context ctx: SDLContext, remoteAddress: SocketAddress, dst_mac: Data) {
|
||||
var register = SDLRegister()
|
||||
register.networkID = ctx.devAddr.networkID
|
||||
register.srcMac = ctx.devAddr.mac
|
||||
register.dstMac = dst_mac
|
||||
|
||||
SDLLogger.log("[SDLUDPHole] SendRegister: \(remoteAddress), src_mac: \(LayerPacket.MacAddress.description(data: ctx.devAddr.mac)), dst_mac: \(LayerPacket.MacAddress.description(data: dst_mac))", level: .debug)
|
||||
|
||||
self.send(remoteAddress: remoteAddress, type: .register, data: try! register.serializedData())
|
||||
}
|
||||
|
||||
// 回复registerAck
|
||||
func sendRegisterAck(context ctx: SDLContext, remoteAddress: SocketAddress, dst_mac: Data) {
|
||||
var registerAck = SDLRegisterAck()
|
||||
registerAck.networkID = ctx.devAddr.networkID
|
||||
registerAck.srcMac = ctx.devAddr.mac
|
||||
registerAck.dstMac = dst_mac
|
||||
|
||||
SDLLogger.log("[SDLUDPHole] SendRegisterAck: \(remoteAddress), \(registerAck)", level: .debug)
|
||||
|
||||
self.send(remoteAddress: remoteAddress, type: .registerAck, data: try! registerAck.serializedData())
|
||||
}
|
||||
|
||||
// 启动函数
|
||||
func start() async throws {
|
||||
|
||||
let bootstrap = DatagramBootstrap(group: self.group)
|
||||
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
|
||||
.channelInitializer { channel in
|
||||
// 接收缓冲区
|
||||
return channel.setOption(ChannelOptions.socketOption(.so_rcvbuf), value: 5 * 1024 * 1024)
|
||||
.flatMap {
|
||||
channel.setOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_SNDBUF), value: 5 * 1024 * 1024)
|
||||
}.flatMap {
|
||||
channel.pipeline.addHandler(self)
|
||||
}
|
||||
}
|
||||
|
||||
let channel = try await bootstrap.bind(host: "0.0.0.0", port: 0).get()
|
||||
|
||||
SDLLogger.log("[UDPHole] started and listening on: \(channel.localAddress!)", level: .debug)
|
||||
self.localAddress = channel.localAddress
|
||||
self.channel = channel
|
||||
}
|
||||
|
||||
// -- MARK: ChannelInboundHandler Methods
|
||||
|
||||
public func channelActive(context: ChannelHandlerContext) {
|
||||
self.eventFlow.send(.ready)
|
||||
}
|
||||
|
||||
// 接收到的消息, 消息需要根据类型分流
|
||||
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
|
||||
let envelope = self.unwrapInboundIn(data)
|
||||
var buffer = envelope.data
|
||||
let remoteAddress = envelope.remoteAddress
|
||||
|
||||
do {
|
||||
if let message = try decode(buffer: &buffer) {
|
||||
Task {
|
||||
switch message {
|
||||
case .data(let data):
|
||||
SDLLogger.log("[SDLUDPHole] read data: \(data.format()), from: \(remoteAddress)", level: .debug)
|
||||
self.eventFlow.send(.data(data))
|
||||
case .stunProbeReply(let probeReply):
|
||||
self.callbackManager.fireCallback(message: probeReply)
|
||||
default:
|
||||
self.eventFlow.send(.message(remoteAddress, message))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
SDLLogger.log("[SDLUDPHole] decode message, get null", level: .warning)
|
||||
}
|
||||
} catch let err {
|
||||
SDLLogger.log("[SDLUDPHole] decode message, get error: \(err)", level: .debug)
|
||||
}
|
||||
}
|
||||
|
||||
public func errorCaught(context: ChannelHandlerContext, error: Error) {
|
||||
SDLLogger.log("[SDLUDPHole] get error: \(error)", level: .error)
|
||||
// As we are not really interested getting notified on success or failure we just pass nil as promise to
|
||||
// reduce allocations.
|
||||
context.close(promise: nil)
|
||||
self.channel = nil
|
||||
self.eventFlow.send(.closed)
|
||||
}
|
||||
|
||||
public func channelInactive(context: ChannelHandlerContext) {
|
||||
self.channel = nil
|
||||
context.close(promise: nil)
|
||||
}
|
||||
|
||||
// 处理写入逻辑
|
||||
func send(remoteAddress: SocketAddress, type: SDLPacketType, data: Data) {
|
||||
guard let channel = self.channel else {
|
||||
return
|
||||
}
|
||||
|
||||
channel.eventLoop.execute {
|
||||
var buffer = channel.allocator.buffer(capacity: data.count + 1)
|
||||
buffer.writeBytes([type.rawValue])
|
||||
buffer.writeBytes(data)
|
||||
|
||||
let envelope = AddressedEnvelope<ByteBuffer>(remoteAddress: remoteAddress, data: buffer)
|
||||
channel.writeAndFlush(self.wrapOutboundOut(envelope), promise: nil)
|
||||
}
|
||||
}
|
||||
|
||||
deinit {
|
||||
try? self.group.syncShutdownGracefully()
|
||||
}
|
||||
}
|
||||
|
||||
//--MARK: 编解码器
|
||||
extension SDLUDPHole {
|
||||
|
||||
func decode(buffer: inout ByteBuffer) throws -> SDLHoleInboundMessage? {
|
||||
guard let type = buffer.readInteger(as: UInt8.self),
|
||||
let packetType = SDLPacketType(rawValue: type),
|
||||
let bytes = buffer.readBytes(length: buffer.readableBytes) else {
|
||||
SDLLogger.log("[SDLUDPHole] decode error", level: .error)
|
||||
return nil
|
||||
}
|
||||
|
||||
switch packetType {
|
||||
case .data:
|
||||
let dataPacket = try SDLData(serializedBytes: bytes)
|
||||
return .data(dataPacket)
|
||||
case .register:
|
||||
let registerPacket = try SDLRegister(serializedBytes: bytes)
|
||||
return .register(registerPacket)
|
||||
case .registerAck:
|
||||
let registerAck = try SDLRegisterAck(serializedBytes: bytes)
|
||||
return .registerAck(registerAck)
|
||||
case .stunReply:
|
||||
let stunReply = try SDLStunReply(serializedBytes: bytes)
|
||||
return .stunReply(stunReply)
|
||||
case .stunProbeReply:
|
||||
let stunProbeReply = try SDLStunProbeReply(serializedBytes: bytes)
|
||||
return .stunProbeReply(stunProbeReply)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --MARK: 回调函数管理器
|
||||
extension SDLUDPHole {
|
||||
|
||||
private struct HoleCallbackManager {
|
||||
// 存储回调函数和对应的超时任务
|
||||
private var callbacks: [UInt32: CallbackFun] = [:]
|
||||
|
||||
//private var timeoutCallbacks: [UInt32: CallbackFun] = [:]
|
||||
|
||||
// 添加回调并设置超时
|
||||
mutating func addCallback(id: UInt32, callback: @escaping CallbackFun) {
|
||||
// 存储回调
|
||||
self.callbacks[id] = callback
|
||||
}
|
||||
|
||||
// 正常触发回调(收到响应)
|
||||
mutating func fireCallback(message: SDLStunProbeReply) {
|
||||
let id = message.cookie
|
||||
// 执行并移除回调
|
||||
if let callback = callbacks[id] {
|
||||
callback(message)
|
||||
self.callbacks.removeValue(forKey: id)
|
||||
}
|
||||
}
|
||||
|
||||
// 触发所有回调(清理场景)
|
||||
mutating func fireAllCallbacks(message: SDLSuperInboundMessage) {
|
||||
// 触发所有回调
|
||||
for callback in callbacks.values {
|
||||
callback(nil)
|
||||
}
|
||||
self.callbacks.removeAll()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user