fix context

This commit is contained in:
anlicheng 2026-02-02 12:07:29 +08:00
parent 352dff8e19
commit 57dd0d9538
6 changed files with 168 additions and 116 deletions

View File

@ -13,47 +13,41 @@ enum TunnelError: Error {
}
class PacketTunnelProvider: NEPacketTunnelProvider {
var context: SDLContext?
var contextSupervisor: SDLContextSupervisor?
private var rootTask: Task<Void, Error>?
override func startTunnel(options: [String: NSObject]?, completionHandler: @escaping (Error?) -> Void) {
let logger = SDLLogger(level: .debug)
// host: "192.168.0.101", port: 1265
guard let options, let config = SDLConfiguration.parse(options: options) else {
completionHandler(TunnelError.invalidConfiguration)
return
}
//
guard self.context == nil else {
guard self.contextSupervisor == nil else {
completionHandler(TunnelError.invalidContext)
return
}
//
let rsaCipher = try! CCRSACipher(keySize: 1024)
let aesChiper = CCAESChiper()
let logger = SDLLogger(level: .debug)
self.rootTask = Task {
do {
self.context = SDLContext(provider: self, config: config, rsaCipher: rsaCipher, aesCipher: aesChiper, logger: logger)
try await self.context?.start()
} catch let err {
logger.log("[PacketTunnelProvider] exit with error: \(err)")
exit(-1)
}
self.contextSupervisor = SDLContextSupervisor()
await self.contextSupervisor?.start(provider: self, config: config, rsaCipher: rsaCipher, aesCipher: aesChiper, logger: logger)
completionHandler(nil)
}
completionHandler(nil)
}
override func stopTunnel(with reason: NEProviderStopReason, completionHandler: @escaping () -> Void) {
// Add code here to start the process of stopping the tunnel.
self.rootTask?.cancel()
Task {
await self.context?.stop()
await self.contextSupervisor?.stop()
}
self.context = nil
self.contextSupervisor = nil
self.rootTask = nil
completionHandler()

View File

@ -0,0 +1,44 @@
//
// SDLContextSupervisor.swift
// Tun
//
// Created by on 2026/2/2.
//
import Foundation
import NetworkExtension
actor SDLContextSupervisor {
private var context: SDLContextActor?
private var tasks: [Task<Void, Never>] = []
public func start(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) async {
let context = SDLContextActor(provider: provider, config: config, rsaCipher: rsaCipher, aesCipher: aesCipher, logger: logger)
self.context = context
tasks.append(spawnLoop { try await context.startNoticeClient()})
tasks.append(spawnLoop { try await context.startUDPHole()})
tasks.append(spawnLoop { try await context.startDnsClient()})
tasks.append(spawnLoop { try await context.startMonitor()})
}
func stop() {
tasks.forEach {$0.cancel()}
tasks.removeAll()
}
private func spawnLoop(_ body: @escaping () async throws -> Void) -> Task<Void, Never> {
return Task.detached {
while !Task.isCancelled {
do {
try await body()
} catch is CancellationError {
break
} catch {
try? await Task.sleep(nanoseconds: 2_000_000_000)
}
}
}
}
}

View File

@ -13,7 +13,7 @@ import NIOCore
/*
1. rsa的加解密逻辑
*/
actor SDLContext {
actor SDLContextActor {
nonisolated let config: SDLConfiguration
// nat
@ -57,7 +57,6 @@ actor SDLContext {
public init(provider: NEPacketTunnelProvider, config: SDLConfiguration, rsaCipher: RSACipher, aesCipher: AESCipher, logger: SDLLogger) {
self.logger = logger
self.config = config
self.rsaCipher = rsaCipher
self.aesCipher = aesCipher
@ -67,29 +66,72 @@ actor SDLContext {
self.providerAdapter = SDLTunnelProviderAdapter(provider: provider, logger: logger)
}
public func start() async throws {
// udp
self.udpHole = try SDLUDPHole(logger: self.logger)
try self.udpHole?.start()
self.logger.log("[SDLContext] udpHole started")
// dns
let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(self.config.remoteDnsServer, port: 15353)
self.dnsClient = try await SDLDNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger)
try self.dnsClient?.start()
self.logger.log("[SDLContext] dnsClient started")
public func startNoticeClient() async throws {
// noticeClient
self.noticeClient = try SDLNoticeClient(noticePort: self.config.noticePort, logger: self.logger)
try self.noticeClient?.start()
self.logger.log("[SDLContext] noticeClient started")
try await self.noticeClient?.waitClose()
}
public func startMonitor() async throws {
// monitor
self.monitor = SDLNetworkMonitor()
self.monitor?.start()
let monitor = SDLNetworkMonitor()
monitor.start()
self.logger.log("[SDLContext] monitor started")
self.monitor = monitor
for await event in monitor.eventStream {
switch event {
case .changed:
// nat
//self.natType = await self.getNatType()
self.logger.log("didNetworkPathChanged, nat type is: \(self.natType)", level: .info)
case .unreachable:
self.logger.log("didNetworkPathUnreachable", level: .warning)
}
}
}
public func startDnsClient() async throws {
// dns
let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(self.config.remoteDnsServer, port: 15353)
let dnsClient = try await SDLDNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger)
let channel = try dnsClient.start()
self.logger.log("[SDLContext] dnsClient started")
self.dnsClient = dnsClient
try await withThrowingTaskGroup(of: Void.self) {group in
group.addTask {
//
for await packet in dnsClient.packetFlow {
try Task.checkCancellation()
let nePacket = NEPacket(data: packet, protocolFamily: 2)
self.providerAdapter.writePackets(packets: [nePacket])
}
}
group.addTask {
try await channel.closeFuture.get()
}
try await group.next()
self.logger.log("[SDLContext] taskGroup cancel")
group.cancelAll()
}
}
public func startUDPHole() async throws {
// udp
let udpHole = try SDLUDPHole(logger: self.logger)
let channel = try udpHole.start()
self.logger.log("[SDLContext] udpHole started")
self.udpHole = udpHole
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await channel.closeFuture.get()
}
// UDP
group.addTask {
while true {
@ -102,86 +144,51 @@ actor SDLContext {
// event
group.addTask {
if let eventStream = await self.udpHole?.eventStream {
for try await event in eventStream {
try Task.checkCancellation()
switch event {
case .ready:
await self.handleUDPHoleReady()
case .closed:
()
}
for try await event in udpHole.eventStream {
try Task.checkCancellation()
switch event {
case .ready:
await self.handleUDPHoleReady()
case .closed:
()
}
}
}
//
group.addTask {
if let dataStream = await self.udpHole?.dataStream {
for try await data in dataStream {
try Task.checkCancellation()
Task {
try await self.handleData(data: data)
}
}
for try await data in udpHole.dataStream {
try Task.checkCancellation()
try await self.handleData(data: data)
}
}
// signal
group.addTask {
if let signalStream = await self.udpHole?.signalStream {
for try await(remoteAddress, signal) in signalStream {
try Task.checkCancellation()
Task {
switch signal {
case .registerSuperAck(let registerSuperAck):
await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
case .registerSuperNak(let registerSuperNak):
await self.handleRegisterSuperNak(nakPacket: registerSuperNak)
case .peerInfo(let peerInfo):
await self.puncherActor?.handlePeerInfo(peerInfo: peerInfo)
case .event(let event):
try await self.handleEvent(event: event)
case .stunProbeReply(let probeReply):
await self.proberActor?.handleProbeReply(reply: probeReply)
case .register(let register):
try await self.handleRegister(remoteAddress: remoteAddress, register: register)
case .registerAck(let registerAck):
await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
}
}
for try await(remoteAddress, signal) in udpHole.signalStream {
try Task.checkCancellation()
switch signal {
case .registerSuperAck(let registerSuperAck):
await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck)
case .registerSuperNak(let registerSuperNak):
await self.handleRegisterSuperNak(nakPacket: registerSuperNak)
case .peerInfo(let peerInfo):
await self.puncherActor?.handlePeerInfo(peerInfo: peerInfo)
case .event(let event):
try await self.handleEvent(event: event)
case .stunProbeReply(let probeReply):
await self.proberActor?.handleProbeReply(reply: probeReply)
case .register(let register):
try await self.handleRegister(remoteAddress: remoteAddress, register: register)
case .registerAck(let registerAck):
await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
}
}
}
// DNS
group.addTask {
if let packetFlow = await self.dnsClient?.packetFlow {
for await packet in packetFlow {
let nePacket = NEPacket(data: packet, protocolFamily: 2)
self.providerAdapter.writePackets(packets: [nePacket])
}
}
}
// Monitor
group.addTask {
for await event in await self.monitor!.eventStream {
switch event {
case .changed:
// nat
//self.natType = await self.getNatType()
self.logger.log("didNetworkPathChanged, nat type is: \(await self.natType)", level: .info)
case .unreachable:
self.logger.log("didNetworkPathUnreachable", level: .warning)
}
}
}
if let _ = try await group.next() {
self.logger.log("[SDLContext] taskGroup cancel")
group.cancelAll()
}
try await group.next()
group.cancelAll()
self.logger.log("[SDLContext] taskGroup cancel")
}
}
@ -409,11 +416,15 @@ actor SDLContext {
//
self.readTask = Task(priority: .high) {
repeat {
while true {
if Task.isCancelled {
return
}
let packets = await self.providerAdapter.readPackets()
let ipPackets = packets.compactMap { IPPacket($0) }
await self.batchProcessPackets(batchSize: 20, packets: ipPackets)
} while true
}
}
}

View File

@ -29,15 +29,18 @@ final class SDLDNSClient: ChannelInboundHandler {
(self.packetFlow, self.packetContinuation) = AsyncStream.makeStream(of: Data.self, bufferingPolicy: .unbounded)
}
func start() throws {
func start() throws -> Channel {
let bootstrap = DatagramBootstrap(group: group)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.channelInitializer { channel in
channel.pipeline.addHandler(self)
}
self.channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait()
let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait()
self.logger.log("[DNSClient] started", level: .debug)
self.channel = channel
return channel
}
// --MARK: ChannelInboundHandler delegate
@ -79,7 +82,6 @@ final class SDLDNSClient: ChannelInboundHandler {
}
extension SDLDNSClient {
struct Helper {
static let dnsServer: String = "100.100.100.100"
// dns

View File

@ -21,7 +21,7 @@ import NIOPosix
// sn-server
final class SDLNoticeClient {
private let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
private var channel: Channel?
private var channel: Channel
private let logger: SDLLogger
private let noticePort: Int
@ -29,9 +29,7 @@ final class SDLNoticeClient {
init(noticePort: Int, logger: SDLLogger) throws {
self.logger = logger
self.noticePort = noticePort
}
func start() throws {
let bootstrap = DatagramBootstrap(group: self.group)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.channelInitializer { channel in
@ -44,19 +42,19 @@ final class SDLNoticeClient {
//
func send(data: Data) {
guard let channel = self.channel else {
return
}
if let remoteAddress = try? SocketAddress(ipAddress: "127.0.0.1", port: noticePort) {
let buf = channel.allocator.buffer(bytes: data)
let envelope = AddressedEnvelope<ByteBuffer>(remoteAddress: remoteAddress, data: buf)
channel.eventLoop.execute {
channel.writeAndFlush(envelope, promise: nil)
self.channel.eventLoop.execute {
self.channel.writeAndFlush(envelope, promise: nil)
}
}
}
func waitClose() async throws {
try await self.channel.closeFuture.get()
}
deinit {
try? self.group.syncShutdownGracefully()
}

View File

@ -49,15 +49,18 @@ final class SDLUDPHole: ChannelInboundHandler {
(self.eventStream, self.eventContinuation) = AsyncStream.makeStream(of: HoleEvent.self, bufferingPolicy: .unbounded)
}
func start() throws {
func start() throws -> Channel {
let bootstrap = DatagramBootstrap(group: group)
.channelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.channelInitializer { channel in
channel.pipeline.addHandler(self)
}
self.channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait()
let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait()
self.logger.log("[UDPHole] started", level: .debug)
self.channel = channel
return channel
}
// --MARK: ChannelInboundHandler delegate