This commit is contained in:
anlicheng 2026-01-29 23:23:53 +08:00
parent faebe09da0
commit df236d4c1f
4 changed files with 203 additions and 187 deletions

View File

@ -119,12 +119,57 @@ public class SDLContext {
}
}
// event
group.addTask {
if let eventStream = self.udpHole?.eventStream {
for try await event in eventStream {
try Task.checkCancellation()
Task {
try await self.dispatchEvent(event: event)
switch event {
case .ready:
try await self.handleUDPHoleReady()
case .closed:
()
}
}
}
}
}
//
group.addTask {
if let dataStream = self.udpHole?.dataStream {
for try await data in dataStream {
try Task.checkCancellation()
Task {
try await self.handleData(data: data)
}
}
}
}
// signal
group.addTask {
if let signalStream = self.udpHole?.signalStream {
for try await(remoteAddress, signal) in signalStream {
try Task.checkCancellation()
Task {
switch signal {
case .registerSuperAck(let registerSuperAck):
await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
case .registerSuperNak(let registerSuperNak):
await self.handleRegisterSuperNak(nakPacket: registerSuperNak)
case .peerInfo(let peerInfo):
await self.puncherActor.handlePeerInfo(peerInfo: peerInfo)
case .event(let event):
try await self.handleEvent(event: event)
case .stunProbeReply(let probeReply):
await self.proberActor?.handleProbeReply(reply: probeReply)
case .register(let register):
try await self.handleRegister(remoteAddress: remoteAddress, register: register)
case .registerAck(let registerAck):
await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
}
}
}
}
@ -215,32 +260,6 @@ public class SDLContext {
}
}
private func dispatchEvent(event: SDLUDPHole.UDPHoleEvent) async throws {
switch event {
case .ready:
try await self.handleUDPHoleReady()
case .message(let remoteAddress, let message):
switch message {
case .registerSuperAck(let registerSuperAck):
await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
case .registerSuperNak(let registerSuperNak):
await self.handleRegisterSuperNak(nakPacket: registerSuperNak)
case .peerInfo(let peerInfo):
await self.puncherActor.handlePeerInfo(peerInfo: peerInfo)
case .event(let event):
try await self.handleEvent(event: event)
case .stunProbeReply(let probeReply):
await self.proberActor?.handleProbeReply(reply: probeReply)
case .data(let data):
try await self.handleData(data: data)
case .register(let register):
try await self.handleRegister(remoteAddress: remoteAddress, register: register)
case .registerAck(let registerAck):
await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
}
}
}
private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async {
// rsa
let aesKey = try! self.rsaCipher.decode(data: Data(registerSuperAck.aesKey))

View File

@ -10,7 +10,9 @@ import NIOCore
import NIOPosix
// sn-server
final class SDLDNSClient {
final class SDLDNSClient: ChannelInboundHandler {
typealias InboundIn = AddressedEnvelope<ByteBuffer>
private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
private var channel: Channel?
@ -31,12 +33,31 @@ final class SDLDNSClient {
let bootstrap = DatagramBootstrap(group: group)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.channelInitializer { channel in
channel.pipeline.addHandler(SDLDNSInboundHandler(packetContinuation: self.packetContinuation, logger: self.logger))
channel.pipeline.addHandler(self)
}
self.channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait()
self.logger.log("[DNSClient] started", level: .debug)
}
// --MARK: ChannelInboundHandler delegate
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let envelope = unwrapInboundIn(data)
var buffer = envelope.data
let remoteAddress = envelope.remoteAddress
self.logger.log("[DNSClient] read data: \(buffer), from: \(remoteAddress)", level: .debug)
let len = buffer.readableBytes
if let bytes = buffer.readBytes(length: len) {
self.packetContinuation.yield(Data(bytes))
}
}
func channelInactive(context: ChannelHandlerContext) {
self.packetContinuation.finish()
}
func forward(ipPacket: IPPacket) {
guard let channel = self.channel else {
@ -59,37 +80,6 @@ final class SDLDNSClient {
extension SDLDNSClient {
private final class SDLDNSInboundHandler: ChannelInboundHandler {
typealias InboundIn = AddressedEnvelope<ByteBuffer>
private var packetContinuation: AsyncStream<Data>.Continuation
private var logger: SDLLogger
// --MARK: ChannelInboundHandler delegate
init(packetContinuation: AsyncStream<Data>.Continuation, logger: SDLLogger) {
self.packetContinuation = packetContinuation
self.logger = logger
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let envelope = unwrapInboundIn(data)
var buffer = envelope.data
let remoteAddress = envelope.remoteAddress
self.logger.log("[DNSClient] read data: \(buffer), from: \(remoteAddress)", level: .debug)
let len = buffer.readableBytes
if let bytes = buffer.readBytes(length: len) {
self.packetContinuation.yield(Data(bytes))
}
}
func channelInactive(context: ChannelHandlerContext) {
self.packetContinuation.finish()
}
}
struct Helper {
static let dnsServer: String = "100.100.100.100"
// dns

View File

@ -93,8 +93,12 @@ extension SDLStunProbeReply {
}
// --MARK: ,
enum SDLHoleMessage {
case data(SDLData)
case signal(SDLHoleSignal)
}
enum SDLHoleInboundMessage {
enum SDLHoleSignal {
case registerSuperAck(SDLRegisterSuperAck)
case registerSuperNak(SDLRegisterSuperNak)
@ -103,7 +107,6 @@ enum SDLHoleInboundMessage {
case stunProbeReply(SDLStunProbeReply)
case data(SDLData)
case register(SDLRegister)
case registerAck(SDLRegisterAck)
}

View File

@ -19,36 +19,87 @@ import NIOPosix
import SwiftProtobuf
// sn-server
final class SDLUDPHole {
final class SDLUDPHole: ChannelInboundHandler {
typealias InboundIn = AddressedEnvelope<ByteBuffer>
private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
private var channel: Channel?
public let eventStream: AsyncStream<UDPHoleEvent>
private let eventContinuation: AsyncStream<UDPHoleEvent>.Continuation
public let signalStream: AsyncStream<(SocketAddress, SDLHoleSignal)>
private let signalContinuation: AsyncStream<(SocketAddress, SDLHoleSignal)>.Continuation
public let dataStream: AsyncStream<SDLData>
private let dataContinuation: AsyncStream<SDLData>.Continuation
public let eventStream: AsyncStream<HoleEvent>
private let eventContinuation: AsyncStream<HoleEvent>.Continuation
private let logger: SDLLogger
enum UDPHoleEvent {
enum HoleEvent {
case ready
case message(SocketAddress, SDLHoleInboundMessage)
case closed
}
//
init(logger: SDLLogger) throws {
self.logger = logger
(self.eventStream, self.eventContinuation) = AsyncStream.makeStream(of: UDPHoleEvent.self, bufferingPolicy: .unbounded)
(self.signalStream, self.signalContinuation) = AsyncStream.makeStream(of: (SocketAddress, SDLHoleSignal).self, bufferingPolicy: .unbounded)
(self.dataStream, self.dataContinuation) = AsyncStream.makeStream(of: SDLData.self, bufferingPolicy: .unbounded)
(self.eventStream, self.eventContinuation) = AsyncStream.makeStream(of: HoleEvent.self, bufferingPolicy: .unbounded)
}
func start() throws {
let bootstrap = DatagramBootstrap(group: group)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.channelInitializer { channel in
channel.pipeline.addHandler(SDLUDPHoleHandler(eventContinuation: self.eventContinuation, logger: self.logger))
channel.pipeline.addHandler(self)
}
self.channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait()
self.logger.log("[UDPHole] started", level: .debug)
}
// --MARK: ChannelInboundHandler delegate
func channelActive(context: ChannelHandlerContext) {
self.eventContinuation.yield(.ready)
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let envelope = unwrapInboundIn(data)
var buffer = envelope.data
let remoteAddress = envelope.remoteAddress
do {
if let message = try decode(buffer: &buffer) {
switch message {
case .data(let data):
self.dataContinuation.yield(data)
case .signal(let signal):
self.signalContinuation.yield((remoteAddress, signal))
}
} else {
self.logger.log("[SDLUDPHole] decode message, get null", level: .warning)
}
} catch let err {
self.logger.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning)
}
}
func channelInactive(context: ChannelHandlerContext) {
self.signalContinuation.finish()
self.dataContinuation.finish()
self.eventContinuation.yield(.closed)
self.eventContinuation.finish()
context.close(promise: nil)
}
func errorCaught(context: ChannelHandlerContext, error: any Error) {
context.close(promise: nil)
}
func getLocalAddress() -> SocketAddress? {
return self.channel?.localAddress
}
@ -69,6 +120,77 @@ final class SDLUDPHole {
}
}
// --MARK:
private func decode(buffer: inout ByteBuffer) throws -> SDLHoleMessage? {
guard let type = buffer.readInteger(as: UInt8.self),
let packetType = SDLPacketType(rawValue: type),
let bytes = buffer.readBytes(length: buffer.readableBytes) else {
return nil
}
switch packetType {
case .data:
let dataPacket = try SDLData(serializedBytes: bytes)
return .data(dataPacket)
case .register:
let registerPacket = try SDLRegister(serializedBytes: bytes)
return .signal(.register(registerPacket))
case .registerAck:
let registerAck = try SDLRegisterAck(serializedBytes: bytes)
return .signal(.registerAck(registerAck))
case .stunProbeReply:
let stunProbeReply = try SDLStunProbeReply(serializedBytes: bytes)
return .signal(.stunProbeReply(stunProbeReply))
case .registerSuperAck:
guard let bytes = buffer.readBytes(length: buffer.readableBytes),
let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else {
return nil
}
return .signal(.registerSuperAck(registerSuperAck))
case .registerSuperNak:
guard let bytes = buffer.readBytes(length: buffer.readableBytes),
let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else {
return nil
}
return .signal(.registerSuperNak(registerSuperNak))
case .peerInfo:
guard let bytes = buffer.readBytes(length: buffer.readableBytes),
let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else {
return nil
}
return .signal(.peerInfo(peerInfo))
case .event:
guard let eventVal = buffer.readInteger(as: UInt8.self),
let event = SDLEventType(rawValue: eventVal),
let bytes = buffer.readBytes(length: buffer.readableBytes) else {
return nil
}
switch event {
case .natChanged:
guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else {
return nil
}
return .signal(.event(.natChanged(natChangedEvent)))
case .sendRegister:
guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else {
return nil
}
return .signal(.event(.sendRegister(sendRegisterEvent)))
case .networkShutdown:
guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else {
return nil
}
return .signal(.event(.networkShutdown(networkShutdownEvent)))
}
default:
return nil
}
}
deinit {
try? self.group.syncShutdownGracefully()
self.eventContinuation.finish()
@ -76,121 +198,3 @@ final class SDLUDPHole {
}
}
extension SDLUDPHole {
final class SDLUDPHoleHandler: ChannelInboundHandler {
typealias InboundIn = AddressedEnvelope<ByteBuffer>
private var eventContinuation: AsyncStream<UDPHoleEvent>.Continuation
private var logger: SDLLogger
// --MARK: ChannelInboundHandler delegate
init(eventContinuation: AsyncStream<UDPHoleEvent>.Continuation, logger: SDLLogger) {
self.eventContinuation = eventContinuation
self.logger = logger
}
func channelActive(context: ChannelHandlerContext) {
self.eventContinuation.yield(.ready)
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let envelope = unwrapInboundIn(data)
var buffer = envelope.data
let remoteAddress = envelope.remoteAddress
do {
if let message = try decode(buffer: &buffer) {
self.eventContinuation.yield(.message(remoteAddress, message))
} else {
self.logger.log("[SDLUDPHole] decode message, get null", level: .warning)
}
} catch let err {
self.logger.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning)
}
}
func channelInactive(context: ChannelHandlerContext) {
self.eventContinuation.finish()
}
func errorCaught(context: ChannelHandlerContext, error: any Error) {
context.close(promise: nil)
}
// --MARK:
private func decode(buffer: inout ByteBuffer) throws -> SDLHoleInboundMessage? {
guard let type = buffer.readInteger(as: UInt8.self),
let packetType = SDLPacketType(rawValue: type),
let bytes = buffer.readBytes(length: buffer.readableBytes) else {
return nil
}
switch packetType {
case .data:
let dataPacket = try SDLData(serializedBytes: bytes)
return .data(dataPacket)
case .register:
let registerPacket = try SDLRegister(serializedBytes: bytes)
return .register(registerPacket)
case .registerAck:
let registerAck = try SDLRegisterAck(serializedBytes: bytes)
return .registerAck(registerAck)
case .stunProbeReply:
let stunProbeReply = try SDLStunProbeReply(serializedBytes: bytes)
return .stunProbeReply(stunProbeReply)
case .registerSuperAck:
guard let bytes = buffer.readBytes(length: buffer.readableBytes),
let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else {
return nil
}
return .registerSuperAck(registerSuperAck)
case .registerSuperNak:
guard let bytes = buffer.readBytes(length: buffer.readableBytes),
let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else {
return nil
}
return .registerSuperNak(registerSuperNak)
case .peerInfo:
guard let bytes = buffer.readBytes(length: buffer.readableBytes),
let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else {
return nil
}
return .peerInfo(peerInfo)
case .event:
guard let eventVal = buffer.readInteger(as: UInt8.self),
let event = SDLEventType(rawValue: eventVal),
let bytes = buffer.readBytes(length: buffer.readableBytes) else {
return nil
}
switch event {
case .natChanged:
guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else {
return nil
}
return .event(.natChanged(natChangedEvent))
case .sendRegister:
guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else {
return nil
}
return .event(.sendRegister(sendRegisterEvent))
case .networkShutdown:
guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else {
return nil
}
return .event(.networkShutdown(networkShutdownEvent))
}
default:
return nil
}
}
}
}