//
//  SDLanServer.swift
//  Tun
//
//  Created by 安礼成 on 2024/1/31.
//
import Foundation
|
||
import NIOCore
|
||
import NIOPosix
|
||
import SwiftProtobuf
|
||
|
||
/// Owns one UDP socket (plus a private single-thread event loop group) and
/// publishes decoded inbound packets on `messageStream`.
///
/// Per the original author's note, this handles communication with the sn-server.
///
/// Lifecycle: `idle` --start()--> `ready` --stop()/errorCaught--> `stopping`
/// --channelInactive--> `stopped`. A stopped instance cannot be restarted.
///
/// NOTE(review): `state`, `channel` and `didFinishMessageStream` are touched both
/// from caller threads (`start`, `stop`, `send`) and from the channel's event loop
/// (`channelRead`, `channelInactive`, `errorCaught`) without synchronization —
/// confirm callers serialize access, or protect these with a lock.
/// NOTE(review): the channel pipeline retains this handler, so `deinit` cannot run
/// while the socket is open; call `stop()` to release it.
final class SDLUDPHole: ChannelInboundHandler {
    typealias InboundIn = AddressedEnvelope<ByteBuffer>

    /// Lifecycle states; see the class documentation for the transitions.
    private enum State: Equatable {
        case idle
        case ready
        case stopping
        case stopped
    }

    /// Dedicated single-thread event loop group; shut down in `deinit`.
    private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)

    /// The bound channel; set by `start()`, cleared in `channelInactive`.
    private var channel: Channel?

    /// Retained separately from `channel` so `waitClose()` still works after the channel is cleared.
    private var closeFuture: EventLoopFuture<Void>?

    private var state: State = .idle

    /// Ensures `messageContinuation.finish()` is called at most once.
    private var didFinishMessageStream: Bool = false

    /// Decoded inbound packets paired with the sender's address. Finishes when the
    /// hole is stopped, or when its channel errors or closes.
    public let messageStream: AsyncStream<(SocketAddress, SDLHoleMessage)>

    private let messageContinuation: AsyncStream<(SocketAddress, SDLHoleMessage)>.Continuation

    /// Creates an idle hole; no socket is opened until `start()` is called.
    ///
    /// The stream keeps only the newest 2048 unconsumed messages; older ones are dropped.
    ///
    /// - Throws: Nothing at present; `throws` is kept for source compatibility with callers.
    init() throws {
        let (stream, continuation) = AsyncStream.makeStream(
            of: (SocketAddress, SDLHoleMessage).self,
            bufferingPolicy: .bufferingNewest(2048)
        )
        self.messageStream = stream
        self.messageContinuation = continuation
    }

    /// Binds the UDP socket if needed and returns its local address.
    ///
    /// Calling this again while `ready` returns the existing local address without rebinding.
    ///
    /// - Returns: The socket's local address, including the OS-assigned port.
    /// - Throws: Any error raised while binding the socket.
    /// - Precondition: `stop()` has not been called and the channel has not closed.
    func start() throws -> SocketAddress {
        switch self.state {
        case .ready:
            guard let channel = self.channel else {
                preconditionFailure("SDLUDPHole is ready but channel is nil")
            }
            return Self.boundLocalAddress(of: channel)
        case .stopping, .stopped:
            preconditionFailure("SDLUDPHole cannot be restarted after stop")
        case .idle:
            break
        }

        let bootstrap = DatagramBootstrap(group: self.group)
            .channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
            .channelInitializer { channel in
                channel.pipeline.addHandler(self)
            }

        // Bind to the IPv6 wildcard address on an OS-assigned port, intending a dual-stack
        // socket that receives both IPv4 and IPv6 traffic. IPV6_V6ONLY is not set here,
        // so this relies on the platform default.
        let channel = try bootstrap.bind(host: "::", port: 0).wait()
        self.channel = channel
        self.closeFuture = channel.closeFuture
        self.state = .ready

        return Self.boundLocalAddress(of: channel)
    }

    /// Returns `channel.localAddress`, trapping if a bound channel reports none.
    private static func boundLocalAddress(of channel: Channel) -> SocketAddress {
        guard let address = channel.localAddress else {
            preconditionFailure("UDP channel has no localAddress after bind")
        }
        return address
    }

    /// Suspends until the socket has closed.
    ///
    /// Returns immediately if no socket was ever bound.
    ///
    /// - Throws: Any error the channel's close future fails with.
    func waitClose() async throws {
        guard self.state != .idle, let closeFuture = self.closeFuture else {
            return
        }
        try await closeFuture.get()
    }

    /// Finishes the message stream and starts closing the socket. Safe to call repeatedly.
    ///
    /// Closing is asynchronous; use `waitClose()` to await completion.
    func stop() {
        switch self.state {
        case .stopping, .stopped:
            return
        case .idle:
            self.state = .stopped
            self.finishMessageStream()
        case .ready:
            self.state = .stopping
            self.finishMessageStream()
            self.channel?.close(promise: nil)
        }
    }

    // MARK: - ChannelInboundHandler

    /// Decodes one inbound datagram and yields it on `messageStream`.
    ///
    /// Datagrams arriving outside the `ready` state are dropped. Datagrams that cannot be
    /// decoded are logged and dropped.
    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        guard case .ready = self.state else {
            return
        }

        let envelope = self.unwrapInboundIn(data)
        var buffer = envelope.data
        let remoteAddress = envelope.remoteAddress

        // Only the length is logged, so read it directly rather than copying the payload out.
        SDLLogger.log("[SDLUDPHole] get raw bytes: \(buffer.readableBytes), from: \(remoteAddress)", for: .debug)

        do {
            if let message = try self.decode(buffer: &buffer) {
                self.messageContinuation.yield((remoteAddress, message))
            } else {
                SDLLogger.log("[SDLUDPHole] decode message, get null", for: .debug)
            }
        } catch {
            SDLLogger.log("[SDLUDPHole] decode message, get error: \(error)", for: .debug)
        }
    }

    /// Runs once the socket has closed. It finishes the stream, drops the channel
    /// reference and marks the hole `stopped`.
    func channelInactive(context: ChannelHandlerContext) {
        self.finishMessageStream()
        self.channel = nil
        self.state = .stopped
    }

    /// Logs the error, finishes the stream and closes the channel.
    ///
    /// NOTE(review): this tears the hole down on every error, including transient
    /// per-datagram ones — confirm that is intended.
    func errorCaught(context: ChannelHandlerContext, error: any Error) {
        SDLLogger.log("[SDLUDPHole] channel error: \(error)", for: .debug)
        self.finishMessageStream()
        if self.state != .stopped {
            self.state = .stopping
        }
        context.close(promise: nil)
    }

    // MARK: - Write path

    /// Sends `data` to `remoteAddress`, prefixed with the one-byte packet type.
    ///
    /// Does nothing unless the hole is `ready`. The write is fire-and-forget: failures
    /// are not reported to the caller.
    func send(type: SDLPacketType, data: Data, remoteAddress: SocketAddress) {
        guard case .ready = self.state, let channel = self.channel else {
            return
        }

        var buffer = channel.allocator.buffer(capacity: data.count + 1)
        // The raw value is a single byte (decode reads it back as a UInt8).
        buffer.writeInteger(type.rawValue)
        buffer.writeBytes(data)

        let envelope = AddressedEnvelope<ByteBuffer>(remoteAddress: remoteAddress, data: buffer)
        // `Channel` operations are thread-safe; NIO hops to the event loop itself if needed.
        channel.writeAndFlush(envelope, promise: nil)
    }

    // MARK: - Codec

    /// Decodes a datagram made of a one-byte `SDLPacketType` followed by a
    /// protobuf-serialized payload.
    ///
    /// - Parameter buffer: The datagram; it is fully consumed.
    /// - Returns: The decoded message. Returns `nil` if the buffer is empty, the type byte
    ///   is unknown, or the type is not one this hole handles.
    /// - Throws: The protobuf decoding error when the payload is malformed.
    private func decode(buffer: inout ByteBuffer) throws -> SDLHoleMessage? {
        guard let type = buffer.readInteger(as: UInt8.self),
              let packetType = SDLPacketType(rawValue: type) else {
            return nil
        }

        // Everything after the type byte is the protobuf payload.
        let payload = buffer.readBytes(length: buffer.readableBytes) ?? []

        switch packetType {
        case .data:
            return .data(try SDLData(serializedBytes: payload))
        case .register:
            return .register(try SDLRegister(serializedBytes: payload))
        case .registerAck:
            return .registerAck(try SDLRegisterAck(serializedBytes: payload))
        case .stunProbeReply:
            return .stunProbeReply(try SDLStunProbeReply(serializedBytes: payload))
        case .stunReply:
            return .stunReply(try SDLStunReply(serializedBytes: payload))
        default:
            SDLLogger.log("SDLUDPHole decode miss type: \(type)", for: .debug)
            return nil
        }
    }

    /// Finishes `messageStream` exactly once, no matter how many shutdown paths reach here.
    private func finishMessageStream() {
        guard !self.didFinishMessageStream else {
            return
        }
        self.didFinishMessageStream = true
        self.messageContinuation.finish()
    }

    deinit {
        self.stop()
        // The last reference to this handler can be released on an event loop thread:
        // when the closed channel's pipeline drops it. NIO traps if
        // `syncShutdownGracefully()` runs on any event loop thread, so shut down
        // asynchronously there and only block when it is safe to do so.
        if MultiThreadedEventLoopGroup.currentEventLoop == nil {
            try? self.group.syncShutdownGracefully()
        } else {
            self.group.shutdownGracefully(queue: .global()) { _ in }
        }
    }
}
|