307 lines
11 KiB
Swift
307 lines
11 KiB
Swift
//
|
||
// SDLanServer.swift
|
||
// Tun
|
||
//
|
||
// Created by 安礼成 on 2024/1/31.
|
||
//
|
||
|
||
import Foundation
|
||
import NIOCore
|
||
import NIOPosix
|
||
import Combine
|
||
|
||
/// Handles UDP communication with the sn-server (super node) and with peer sessions.
///
/// Inbound datagrams are decoded on the channel's event loop. STUN probe replies complete the
/// matching `stunProbe` call; every other decoded message is published on `eventFlow`.
class SDLUDPHole: ChannelInboundHandler, @unchecked Sendable {

    public typealias InboundIn = AddressedEnvelope<ByteBuffer>

    public typealias OutboundOut = AddressedEnvelope<ByteBuffer>

    /// Completion for a STUN probe: the matching reply, or `nil` on timeout or when the socket
    /// has not been started.
    public typealias CallbackFun = (SDLStunProbeReply?) -> Void

    private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)

    // NOTE(review): mutated from whichever thread calls stunRequest/stunProbe; not synchronized.
    private var cookieGenerator = SDLIdGenerator(seed: 1)

    /// Pending probe callbacks keyed by cookie. Only accessed on the channel's event loop.
    private var callbackManager = HoleCallbackManager()

    public var localAddress: SocketAddress?

    public var channel: Channel?

    /// Socket events. Every value is sent from the channel's event loop, in arrival order.
    public var eventFlow = PassthroughSubject<UDPEvent, Never>()

    /// Events published on `eventFlow`.
    enum UDPEvent {
        /// The channel became active; `channel` and `localAddress` are already set.
        case ready
        /// The channel became inactive.
        case closed
        /// A decoded non-data message (other than a probe reply) together with its sender.
        case message(SocketAddress, SDLHoleInboundMessage)
        /// A decoded data packet.
        case data(SDLData)
    }

    /// One in-flight STUN probe. Delivers its result exactly once: the reply or the timeout,
    /// whichever happens first. Only touched on the channel's event loop, hence `@unchecked Sendable`.
    private final class PendingProbe: @unchecked Sendable {
        private var callback: CallbackFun?
        var timeoutTask: Scheduled<Void>?

        init(callback: @escaping CallbackFun) {
            self.callback = callback
        }

        func resolve(_ reply: SDLStunProbeReply?) {
            guard let callback = self.callback else {
                return
            }
            self.callback = nil
            // Cancelling also drops the PendingProbe <-> Scheduled reference cycle.
            self.timeoutTask?.cancel()
            self.timeoutTask = nil
            callback(reply)
        }
    }

    init() {
    }

    // MARK: super_node apis

    /// Sends a STUN request describing this device to the sn-server and returns its cookie.
    ///
    /// The reply is not awaited here; it is published on `eventFlow` as a `.message`.
    func stunRequest(context ctx: SDLContext) -> UInt32 {
        let cookie = self.cookieGenerator.nextId()
        let remoteAddress = ctx.config.stunSocketAddress

        var stunRequest = SDLStunRequest()
        stunRequest.cookie = cookie
        stunRequest.clientID = ctx.config.clientId
        stunRequest.networkID = ctx.devAddr.networkID
        stunRequest.ip = ctx.devAddr.netAddr
        stunRequest.mac = ctx.devAddr.mac
        stunRequest.natType = UInt32(ctx.natType.rawValue)

        // Use `first` rather than `[0]` so an empty server list cannot crash a log statement.
        let server = ctx.config.stunServers.first
        let host = server.map { "\($0.host)" } ?? "-"
        let port = server.flatMap { $0.ports.first }.map { "\($0)" } ?? "-"
        SDLLogger.log("[SDLUDPHole] stunRequest: \(remoteAddress), host: \(host):\(port)", level: .warning)

        self.send(remoteAddress: remoteAddress, type: .stunRequest, data: try! stunRequest.serializedData())

        return cookie
    }

    /// Sends a STUN probe to `remoteAddress` and waits for the matching reply.
    ///
    /// - Parameters:
    ///   - remoteAddress: Probe target.
    ///   - attr: Probe attribute, sent as `SDLStunProbe.attr`.
    ///   - timeout: Seconds to wait for the reply.
    /// - Returns: The reply, or `nil` if none arrived within `timeout` or the socket is not started.
    func stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int = 5) async -> SDLStunProbeReply? {
        return await withCheckedContinuation { continuation in
            // The callback form guarantees exactly one invocation, so exactly one resume.
            self.stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: timeout) { probeReply in
                continuation.resume(returning: probeReply)
            }
        }
    }

    /// Callback form of `stunProbe`. `callback` runs exactly once: with the reply, with `nil`
    /// after `timeout` seconds, or immediately with `nil` if the socket is not started.
    private func stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int, callback: @escaping CallbackFun) {
        guard let channel = self.channel else {
            SDLLogger.log("[SDLUDPHole] stunProbe: \(remoteAddress), channel not ready", level: .warning)
            callback(nil)
            return
        }

        let cookie = self.cookieGenerator.nextId()

        var stunProbe = SDLStunProbe()
        stunProbe.cookie = cookie
        stunProbe.attr = UInt32(attr.rawValue)
        let data = try! stunProbe.serializedData()

        let pending = PendingProbe(callback: callback)
        let eventLoop = channel.eventLoop
        eventLoop.execute {
            // Register before sending so a fast reply can never arrive ahead of its callback.
            self.callbackManager.addCallback(id: cookie) { reply in
                pending.resolve(reply)
            }
            pending.timeoutTask = eventLoop.scheduleTask(in: .seconds(Int64(max(timeout, 0)))) {
                SDLLogger.log("[SDLUDPHole] stunProbe timeout: \(remoteAddress), cookie: \(cookie)", level: .warning)
                // Clear first so resolve() does not cancel the task that is currently running.
                pending.timeoutTask = nil
                pending.resolve(nil)
                // Evict the now-spent manager entry. Its callback forwards to `pending`, which has
                // already resolved, so this delivery is a no-op; the manager just drops the entry.
                var expired = SDLStunProbeReply()
                expired.cookie = cookie
                self.callbackManager.fireCallback(message: expired)
            }
            self.send(remoteAddress: remoteAddress, type: .stunProbe, data: data)
        }

        SDLLogger.log("[SDLUDPHole] stunProbe: \(remoteAddress)", level: .warning)
    }

    // MARK: client-client apis

    /// Sends `data` directly to a peer session at its NAT address.
    func sendPacket(context ctx: SDLContext, session: Session, data: Data) {
        let remoteAddress = session.natAddress

        var dataPacket = SDLData()
        dataPacket.networkID = ctx.devAddr.networkID
        dataPacket.srcMac = ctx.devAddr.mac
        dataPacket.dstMac = session.dstMac
        dataPacket.ttl = 255
        dataPacket.data = data
        let packet = try! dataPacket.serializedData()

        SDLLogger.log("[SDLUDPHole] sendPacket: \(remoteAddress), count: \(packet.count)", level: .debug)

        self.send(remoteAddress: remoteAddress, type: .data, data: packet)
    }

    /// Relays a packet addressed to `dst_mac` through the sn-server. `data` is already encrypted.
    func forwardPacket(context ctx: SDLContext, dst_mac: Data, data: Data) {
        let remoteAddress = ctx.config.stunSocketAddress

        var dataPacket = SDLData()
        dataPacket.networkID = ctx.devAddr.networkID
        dataPacket.srcMac = ctx.devAddr.mac
        dataPacket.dstMac = dst_mac
        dataPacket.ttl = 255
        dataPacket.data = data

        let packet = try! dataPacket.serializedData()

        // Per-packet log goes through SDLLogger at debug level (NSLog is expensive on a hot path).
        SDLLogger.log("[SDLUDPHole] forwardPacket: \(remoteAddress), count: \(packet.count)", level: .debug)

        self.send(remoteAddress: remoteAddress, type: .data, data: packet)
    }

    /// Sends a register packet to `remoteAddress` for the peer identified by `dst_mac`.
    func sendRegister(context ctx: SDLContext, remoteAddress: SocketAddress, dst_mac: Data) {
        var register = SDLRegister()
        register.networkID = ctx.devAddr.networkID
        register.srcMac = ctx.devAddr.mac
        register.dstMac = dst_mac

        SDLLogger.log("[SDLUDPHole] SendRegister: \(remoteAddress), src_mac: \(LayerPacket.MacAddress.description(data: ctx.devAddr.mac)), dst_mac: \(LayerPacket.MacAddress.description(data: dst_mac))", level: .debug)

        self.send(remoteAddress: remoteAddress, type: .register, data: try! register.serializedData())
    }

    /// Replies to a register packet with a registerAck.
    func sendRegisterAck(context ctx: SDLContext, remoteAddress: SocketAddress, dst_mac: Data) {
        var registerAck = SDLRegisterAck()
        registerAck.networkID = ctx.devAddr.networkID
        registerAck.srcMac = ctx.devAddr.mac
        registerAck.dstMac = dst_mac

        SDLLogger.log("[SDLUDPHole] SendRegisterAck: \(remoteAddress), \(registerAck)", level: .debug)

        self.send(remoteAddress: remoteAddress, type: .registerAck, data: try! registerAck.serializedData())
    }

    /// Binds the UDP socket on 0.0.0.0 with an ephemeral port and installs `self` as its handler.
    func start() async throws {
        let bootstrap = DatagramBootstrap(group: self.group)
            .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
            .channelInitializer { channel in
                // 5 MiB receive and send buffers.
                return channel.setOption(ChannelOptions.socketOption(.so_rcvbuf), value: 5 * 1024 * 1024)
                    .flatMap {
                        channel.setOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_SNDBUF), value: 5 * 1024 * 1024)
                    }.flatMap {
                        channel.pipeline.addHandler(self)
                    }
            }

        let channel = try await bootstrap.bind(host: "0.0.0.0", port: 0).get()

        SDLLogger.log("[UDPHole] started and listening on: \(channel.localAddress.map { "\($0)" } ?? "unknown")", level: .debug)
        self.localAddress = channel.localAddress
        self.channel = channel
    }

    // MARK: - ChannelInboundHandler Methods

    public func channelActive(context: ChannelHandlerContext) {
        // Publish the channel before `.ready` so subscribers can send right away. `start()` only
        // assigns it after awaiting the bind future, which can race with this callback.
        self.channel = context.channel
        self.localAddress = context.channel.localAddress
        self.eventFlow.send(.ready)
    }

    /// Decodes one datagram and routes it. Probe replies complete pending probes; everything else
    /// goes to `eventFlow`.
    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let envelope = self.unwrapInboundIn(data)
        var buffer = envelope.data
        let remoteAddress = envelope.remoteAddress

        let message: SDLHoleInboundMessage
        do {
            guard let decoded = try self.decode(buffer: &buffer) else {
                SDLLogger.log("[SDLUDPHole] decode message, get null", level: .warning)
                return
            }
            message = decoded
        } catch let err {
            SDLLogger.log("[SDLUDPHole] decode message, get error: \(err)", level: .debug)
            return
        }

        // Dispatch synchronously on the event loop, not one `Task` per datagram. This keeps packets
        // in arrival order and keeps `callbackManager` confined to this loop.
        switch message {
        case .data(let data):
            SDLLogger.log("[SDLUDPHole] read data: \(data.format()), from: \(remoteAddress)", level: .debug)
            self.eventFlow.send(.data(data))
        case .stunProbeReply(let probeReply):
            self.callbackManager.fireCallback(message: probeReply)
        default:
            self.eventFlow.send(.message(remoteAddress, message))
        }
    }

    public func errorCaught(context: ChannelHandlerContext, error: Error) {
        SDLLogger.log("[SDLUDPHole] get error: \(error)", level: .error)
        self.channel = nil
        // As we are not really interested getting notified on success or failure we just pass nil as
        // promise to reduce allocations. Closing fires `channelInactive`, which publishes `.closed`.
        context.close(promise: nil)
    }

    public func channelInactive(context: ChannelHandlerContext) {
        // Single place that reports closure, whether it was caused by an error or not.
        self.channel = nil
        self.eventFlow.send(.closed)
        context.fireChannelInactive()
    }

    /// Writes a datagram consisting of a one-byte packet type followed by `data`.
    /// Silently drops the datagram when the socket is not started.
    func send(remoteAddress: SocketAddress, type: SDLPacketType, data: Data) {
        guard let channel = self.channel else {
            return
        }

        channel.eventLoop.execute {
            var buffer = channel.allocator.buffer(capacity: data.count + 1)
            buffer.writeBytes([type.rawValue])
            buffer.writeBytes(data)

            let envelope = AddressedEnvelope<ByteBuffer>(remoteAddress: remoteAddress, data: buffer)
            channel.writeAndFlush(self.wrapOutboundOut(envelope), promise: nil)
        }
    }

    deinit {
        try? self.group.syncShutdownGracefully()
    }

}
|
||
|
||
// MARK: - Codec

extension SDLUDPHole {

    /// Decodes one inbound datagram.
    ///
    /// Wire format: a one-byte `SDLPacketType` tag followed by the serialized message body, which
    /// occupies the rest of `buffer`.
    ///
    /// - Returns: The decoded message, or `nil` when the tag is missing or unknown, or when the
    ///   packet type is not one this endpoint accepts.
    /// - Throws: Whatever the message type's `init(serializedBytes:)` throws for a malformed body.
    func decode(buffer: inout ByteBuffer) throws -> SDLHoleInboundMessage? {
        guard
            let tag = buffer.readInteger(as: UInt8.self),
            let packetType = SDLPacketType(rawValue: tag),
            let body = buffer.readBytes(length: buffer.readableBytes)
        else {
            SDLLogger.log("[SDLUDPHole] decode error", level: .error)
            return nil
        }

        let message: SDLHoleInboundMessage?
        switch packetType {
        case .data:
            message = try .data(SDLData(serializedBytes: body))
        case .register:
            message = try .register(SDLRegister(serializedBytes: body))
        case .registerAck:
            message = try .registerAck(SDLRegisterAck(serializedBytes: body))
        case .stunReply:
            message = try .stunReply(SDLStunReply(serializedBytes: body))
        case .stunProbeReply:
            message = try .stunProbeReply(SDLStunProbeReply(serializedBytes: body))
        default:
            // Types this endpoint only sends (e.g. `.stunRequest`, `.stunProbe`) are not decoded.
            message = nil
        }
        return message
    }

}
|
||
|
||
// MARK: - Callback manager

extension SDLUDPHole {

    /// Registry of pending STUN-probe callbacks, keyed by probe cookie.
    ///
    /// Each callback fires at most once through this registry: it is unregistered before it is
    /// invoked. This type implements no timeouts and no locking; callers must serialize access.
    private struct HoleCallbackManager {

        private var callbacks: [UInt32: CallbackFun] = [:]

        /// Registers `callback` under `id`, replacing any callback already registered for that id.
        mutating func addCallback(id: UInt32, callback: @escaping CallbackFun) {
            self.callbacks[id] = callback
        }

        /// Delivers `message` to the callback registered under `message.cookie` (if any) and
        /// unregisters it.
        mutating func fireCallback(message: SDLStunProbeReply) {
            // Unregister before invoking so the registry never holds an entry whose callback already ran.
            guard let callback = self.callbacks.removeValue(forKey: message.cookie) else {
                return
            }
            callback(message)
        }

        /// Completes every pending callback with `nil` and empties the registry (cleanup path).
        ///
        /// - Parameter message: Unused; kept so existing call sites keep compiling.
        mutating func fireAllCallbacks(message: SDLSuperInboundMessage) {
            // Snapshot and clear first, so the registry is already empty while callbacks run.
            let pending = self.callbacks
            self.callbacks.removeAll()
            for callback in pending.values {
                callback(nil)
            }
        }

    }

}
|