swiftlib_sdlan/Sources/sdlan/SDLUDPHole.swift
2025-07-15 00:17:20 +08:00

328 lines
12 KiB
Swift

//
// SDLUDPHole.swift
// Tun
//
// Created by on 2024/1/31.
//
import Foundation
import NIOCore
import NIOPosix
import Combine
// sn-server
/// UDP endpoint used both for super-node (STUN / relay) traffic and for direct
/// peer-to-peer traffic.
///
/// Every datagram is framed as a one-byte `SDLPacketType` tag followed by a
/// serialized message body (see `send(remoteAddress:type:data:)` and
/// `decode(buffer:)`). Decoded inbound traffic is published on `eventFlow`.
/// STUN probe replies are instead routed to the pending `stunProbe` caller
/// whose cookie they carry.
///
/// `eventFlow` subscribers are invoked on the channel's event-loop thread and
/// must not block it.
class SDLUDPHole: ChannelInboundHandler, @unchecked Sendable {
    public typealias InboundIn = AddressedEnvelope<ByteBuffer>
    public typealias OutboundOut = AddressedEnvelope<ByteBuffer>

    /// Completion for a STUN probe; receives `nil` when the probe times out.
    public typealias CallbackFun = (SDLStunProbeReply?) -> Void

    private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
    private var cookieGenerator = SDLIdGenerator(seed: 1)
    private let callbackManager = HoleCallbackManager()

    /// Local address of the bound socket, once the channel is active.
    public var localAddress: SocketAddress?
    /// The bound datagram channel, or `nil` before `start()` / after it goes inactive.
    public var channel: Channel?
    /// Stream of lifecycle events and decoded inbound traffic.
    public var eventFlow = PassthroughSubject<UDPEvent, Never>()

    /// Events published on `eventFlow`.
    enum UDPEvent {
        /// The socket is bound; `channel` is set and sends will go out.
        case ready
        /// An error was caught on the channel and it was closed.
        case closed
        /// A decoded non-data message (other than a STUN probe reply) and its sender.
        case message(SocketAddress, SDLHoleInboundMessage)
        /// A decoded data packet.
        case data(SDLData)
    }

    init() {
    }

    // MARK: - Super-node APIs

    /// Sends a STUN request to the configured STUN server.
    ///
    /// - Parameter ctx: Supplies the client id, device address and NAT type.
    /// - Returns: The cookie placed in the request, so the caller can match the reply.
    func stunRequest(context ctx: SDLContext) -> UInt32 {
        let cookie = self.cookieGenerator.nextId()
        let remoteAddress = ctx.config.stunSocketAddress

        var stunRequest = SDLStunRequest()
        stunRequest.cookie = cookie
        stunRequest.clientID = ctx.config.clientId
        stunRequest.networkID = ctx.devAddr.networkID
        stunRequest.ip = ctx.devAddr.netAddr
        stunRequest.mac = ctx.devAddr.mac
        stunRequest.natType = UInt32(ctx.natType.rawValue)

        // Describe the first configured server without subscripting, so an empty
        // server or port list cannot crash the logging path.
        var serverDescription = "unknown"
        if let server = ctx.config.stunServers.first, let port = server.ports.first {
            serverDescription = "\(server.host):\(port)"
        }
        SDLLogger.log("[SDLUDPHole] stunRequest: \(remoteAddress), host: \(serverDescription)", level: .warning)

        self.send(remoteAddress: remoteAddress, type: .stunRequest, data: try! stunRequest.serializedData())
        return cookie
    }

    // MARK: - NAT probing

    /// Sends a STUN probe and waits for the reply carrying the same cookie.
    ///
    /// - Parameters:
    ///   - remoteAddress: Probe target.
    ///   - attr: Probe attribute copied into the request.
    ///   - timeout: Seconds to wait before giving up.
    /// - Returns: The matching reply, or `nil` if none arrived within `timeout`.
    func stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int = 5) async -> SDLStunProbeReply? {
        return await withCheckedContinuation { continuation in
            self.stunProbe(remoteAddress: remoteAddress, attr: attr, timeout: timeout) { probeReply in
                continuation.resume(returning: probeReply)
            }
        }
    }

    /// Callback-based core of `stunProbe(remoteAddress:attr:timeout:)`.
    private func stunProbe(remoteAddress: SocketAddress, attr: SDLProbeAttr = .none, timeout: Int, callback: @escaping CallbackFun) {
        let cookie = self.cookieGenerator.nextId()
        var stunProbe = SDLStunProbe()
        stunProbe.cookie = cookie
        stunProbe.attr = UInt32(attr.rawValue)

        // Register before sending. Replies are dispatched from the event-loop
        // thread and can arrive before this method returns. A reply that arrives
        // before registration would find no waiter and be dropped, leaving the
        // caller to wait out the full timeout.
        self.callbackManager.addCallback(id: cookie, timeout: timeout, callback: callback)
        SDLLogger.log("[SDLUDPHole] stunProbe: \(remoteAddress)", level: .warning)
        self.send(remoteAddress: remoteAddress, type: .stunProbe, data: try! stunProbe.serializedData())
    }

    // MARK: - Client-to-client APIs

    /// Sends `data` directly to the peer behind `session`.
    func sendPacket(context ctx: SDLContext, session: Session, data: Data) {
        let remoteAddress = session.natAddress
        let packet = self.makeDataPacket(context: ctx, dstMac: session.dstMac, data: data)
        SDLLogger.log("[SDLUDPHole] sendPacket: \(remoteAddress), count: \(packet.count)", level: .debug)
        self.send(remoteAddress: remoteAddress, type: .data, data: packet)
    }

    /// Sends a data packet for `dst_mac` to the STUN/super-node address.
    /// The super node presumably relays it to the destination (TODO confirm).
    func forwardPacket(context ctx: SDLContext, dst_mac: Data, data: Data) {
        let remoteAddress = ctx.config.stunSocketAddress
        let packet = self.makeDataPacket(context: ctx, dstMac: dst_mac, data: data)
        SDLLogger.log("[SDLUDPHole] forwardPacket: \(remoteAddress), count: \(packet.count)", level: .debug)
        self.send(remoteAddress: remoteAddress, type: .data, data: packet)
    }

    /// Serializes an `SDLData` frame from this device to `dstMac` carrying `data`.
    private func makeDataPacket(context ctx: SDLContext, dstMac: Data, data: Data) -> Data {
        var dataPacket = SDLData()
        dataPacket.networkID = ctx.devAddr.networkID
        dataPacket.srcMac = ctx.devAddr.mac
        dataPacket.dstMac = dstMac
        dataPacket.ttl = 255
        dataPacket.data = data
        return try! dataPacket.serializedData()
    }

    /// Sends a register message for `dst_mac` to `remoteAddress`.
    func sendRegister(context ctx: SDLContext, remoteAddress: SocketAddress, dst_mac: Data) {
        var register = SDLRegister()
        register.networkID = ctx.devAddr.networkID
        register.srcMac = ctx.devAddr.mac
        register.dstMac = dst_mac
        SDLLogger.log("[SDLUDPHole] SendRegister: \(remoteAddress), src_mac: \(LayerPacket.MacAddress.description(data: ctx.devAddr.mac)), dst_mac: \(LayerPacket.MacAddress.description(data: dst_mac))", level: .debug)
        self.send(remoteAddress: remoteAddress, type: .register, data: try! register.serializedData())
    }

    /// Sends a register acknowledgement for `dst_mac` to `remoteAddress`.
    func sendRegisterAck(context ctx: SDLContext, remoteAddress: SocketAddress, dst_mac: Data) {
        var registerAck = SDLRegisterAck()
        registerAck.networkID = ctx.devAddr.networkID
        registerAck.srcMac = ctx.devAddr.mac
        registerAck.dstMac = dst_mac
        SDLLogger.log("[SDLUDPHole] SendRegisterAck: \(remoteAddress), \(registerAck)", level: .debug)
        self.send(remoteAddress: remoteAddress, type: .registerAck, data: try! registerAck.serializedData())
    }

    // MARK: - Lifecycle

    /// Binds a UDP socket on an ephemeral port on all interfaces and installs
    /// `self` as its inbound handler.
    ///
    /// - Throws: Any error raised while configuring or binding the channel.
    func start() async throws {
        // 5 MiB kernel socket buffers in both directions.
        let socketBufferSize: CInt = 5 * 1024 * 1024
        let bootstrap = DatagramBootstrap(group: self.group)
            .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
            .channelInitializer { channel in
                return channel.setOption(ChannelOptions.socketOption(.so_rcvbuf), value: socketBufferSize)
                    .flatMap {
                        channel.setOption(ChannelOptions.socketOption(.so_sndbuf), value: socketBufferSize)
                    }
                    .flatMap {
                        channel.pipeline.addHandler(self)
                    }
            }
        let channel = try await bootstrap.bind(host: "0.0.0.0", port: 0).get()
        let boundAddress = channel.localAddress.map { "\($0)" } ?? "unknown"
        SDLLogger.log("[UDPHole] started and listening on: \(boundAddress)", level: .debug)
        self.localAddress = channel.localAddress
        self.channel = channel
    }

    // MARK: - ChannelInboundHandler

    /// Records the now-active channel and publishes `.ready`.
    public func channelActive(context: ChannelHandlerContext) {
        // `channelActive` can fire before `start()` resumes from `bind` and stores
        // the channel. Record it here first. Otherwise a subscriber that reacts to
        // `.ready` by sending would be silently dropped by the `channel == nil`
        // guard in `send`.
        self.channel = context.channel
        self.localAddress = context.channel.localAddress
        self.eventFlow.send(.ready)
    }

    /// Decodes one datagram and dispatches it. Data packets and other messages
    /// go to `eventFlow`; STUN probe replies resolve the waiting probe.
    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let envelope = self.unwrapInboundIn(data)
        var buffer = envelope.data
        let remoteAddress = envelope.remoteAddress

        do {
            guard let message = try self.decode(buffer: &buffer) else {
                SDLLogger.log("[SDLUDPHole] decode message, get null", level: .warning)
                return
            }
            // Dispatch synchronously on the event-loop thread rather than in a
            // per-packet `Task`. Unstructured tasks run on the concurrent global
            // executor, which can reorder datagrams and send to the Combine
            // subject concurrently; Combine expects serial delivery.
            switch message {
            case .data(let data):
                SDLLogger.log("[SDLUDPHole] read data: \(data.format()), from: \(remoteAddress)", level: .debug)
                self.eventFlow.send(.data(data))
            case .stunProbeReply(let probeReply):
                self.callbackManager.fireCallback(message: probeReply)
            default:
                self.eventFlow.send(.message(remoteAddress, message))
            }
        } catch let err {
            SDLLogger.log("[SDLUDPHole] decode message, get error: \(err)", level: .debug)
        }
    }

    /// Logs the error, closes the channel and publishes `.closed`.
    public func errorCaught(context: ChannelHandlerContext, error: Error) {
        SDLLogger.log("[SDLUDPHole] get error: \(error)", level: .error)
        // We are not interested in being notified of success or failure, so we
        // pass nil as the promise to avoid an allocation.
        context.close(promise: nil)
        self.channel = nil
        self.eventFlow.send(.closed)
    }

    /// Drops the stored channel once it is no longer active.
    public func channelInactive(context: ChannelHandlerContext) {
        self.channel = nil
        context.close(promise: nil)
    }

    // MARK: - Sending

    /// Frames `data` as `[type byte][body]` and writes it to `remoteAddress`.
    /// Does nothing if no channel is currently stored.
    func send(remoteAddress: SocketAddress, type: SDLPacketType, data: Data) {
        guard let channel = self.channel else {
            return
        }
        channel.eventLoop.execute {
            var buffer = channel.allocator.buffer(capacity: data.count + 1)
            buffer.writeInteger(type.rawValue)
            buffer.writeBytes(data)
            let envelope = AddressedEnvelope<ByteBuffer>(remoteAddress: remoteAddress, data: buffer)
            channel.writeAndFlush(self.wrapOutboundOut(envelope), promise: nil)
        }
    }

    deinit {
        // Shut down asynchronously. `syncShutdownGracefully` blocks the caller,
        // and NIO traps if it is called from one of the group's own threads. That
        // is where the last reference can be released when the pipeline tears
        // down this handler.
        self.group.shutdownGracefully { _ in }
    }
}
// MARK: - Inbound packet decoding
extension SDLUDPHole {
    /// Parses one inbound datagram: a one-byte `SDLPacketType` tag followed by
    /// the serialized body of that packet type.
    ///
    /// - Parameter buffer: The datagram payload. The tag and body are consumed.
    /// - Returns: The decoded message. Returns `nil` (after logging) when no tag
    ///   byte is present or the tag is not a known `SDLPacketType`, and `nil`
    ///   without logging for packet types this socket does not accept.
    /// - Throws: Any error raised while decoding the packet body.
    func decode(buffer: inout ByteBuffer) throws -> SDLHoleInboundMessage? {
        guard
            let tag = buffer.readInteger(as: UInt8.self),
            let kind = SDLPacketType(rawValue: tag),
            let payload = buffer.readBytes(length: buffer.readableBytes)
        else {
            SDLLogger.log("[SDLUDPHole] decode error", level: .error)
            return nil
        }

        switch kind {
        case .data:
            return try .data(SDLData(serializedBytes: payload))
        case .register:
            return try .register(SDLRegister(serializedBytes: payload))
        case .registerAck:
            return try .registerAck(SDLRegisterAck(serializedBytes: payload))
        case .stunReply:
            return try .stunReply(SDLStunReply(serializedBytes: payload))
        case .stunProbeReply:
            return try .stunProbeReply(SDLStunProbeReply(serializedBytes: payload))
        default:
            // Any other packet type is not handled by this socket.
            return nil
        }
    }
}
// MARK: - STUN probe callback bookkeeping
extension SDLUDPHole {
    /// Lock-protected registry of pending STUN-probe completions, keyed by cookie.
    ///
    /// A registered completion is invoked at most once. It receives the matching
    /// reply, or `nil` when its timeout elapses, when it is flushed, or when it is
    /// displaced by a later registration with the same cookie. Completions are
    /// always invoked *outside* the lock, so a callback that re-enters the manager
    /// cannot deadlock on the non-recursive `NSLock`.
    private final class HoleCallbackManager {
        private var callbacks: [UInt32: CallbackFun] = [:]
        private let locker = NSLock()

        /// Registers `callback` under `id` and arms a timeout that resolves it
        /// with `nil` after `timeout` seconds if no reply has arrived by then.
        func addCallback(id: UInt32, timeout: Int, callback: @escaping CallbackFun) {
            let displaced = self.synchronized { () -> CallbackFun? in
                self.callbacks.updateValue(callback, forKey: id)
            }
            // A reused cookie would otherwise strand the earlier waiter, and with
            // it any checked continuation, forever.
            displaced?(nil)

            DispatchQueue.global().asyncAfter(deadline: .now() + Double(timeout)) {
                self.fireCallback(cookie: id)
            }
        }

        /// Resolves the waiter registered under `message.cookie`, if it is still pending.
        func fireCallback(message: SDLStunProbeReply) {
            let callback = self.synchronized { () -> CallbackFun? in
                self.callbacks.removeValue(forKey: message.cookie)
            }
            callback?(message)
        }

        /// Resolves every pending waiter with `nil`. `message` is not used.
        func fireAllCallbacks(message: SDLSuperInboundMessage) {
            let pending = self.synchronized { () -> [UInt32: CallbackFun] in
                let snapshot = self.callbacks
                self.callbacks.removeAll()
                return snapshot
            }
            for callback in pending.values {
                callback(nil)
            }
        }

        /// Times out the waiter registered under `cookie` with `nil`, if still pending.
        private func fireCallback(cookie: UInt32) {
            let callback = self.synchronized { () -> CallbackFun? in
                self.callbacks.removeValue(forKey: cookie)
            }
            callback?(nil)
        }

        /// Runs `body` while holding `locker`.
        private func synchronized<T>(_ body: () -> T) -> T {
            self.locker.lock()
            defer {
                self.locker.unlock()
            }
            return body()
        }
    }
}