简化逻辑
This commit is contained in:
parent
8ebfdd1edf
commit
66721ce7b1
@ -9,58 +9,20 @@ import Network
|
|||||||
|
|
||||||
actor SDLIPV6AssistClient {
|
actor SDLIPV6AssistClient {
|
||||||
|
|
||||||
struct Packet: Sendable {
|
|
||||||
enum IPVersion: UInt8, Sendable {
|
|
||||||
case ipv4 = 4
|
|
||||||
case ipv6 = 6
|
|
||||||
|
|
||||||
var protocolFamily: Int32 {
|
|
||||||
switch self {
|
|
||||||
case .ipv4:
|
|
||||||
return 2
|
|
||||||
case .ipv6:
|
|
||||||
return 30
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let packetId: UInt32
|
|
||||||
let ipPacketData: Data
|
|
||||||
let ipVersion: IPVersion
|
|
||||||
|
|
||||||
var protocolFamily: Int32 {
|
|
||||||
return self.ipVersion.protocolFamily
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private enum State {
|
private enum State {
|
||||||
case idle
|
case idle
|
||||||
case running
|
case running
|
||||||
case stopped
|
case stopped
|
||||||
}
|
}
|
||||||
|
|
||||||
private struct PendingPacket: Sendable {
|
|
||||||
let packetId: UInt32
|
|
||||||
let ipVersion: Packet.IPVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
private enum PacketParseError: Error {
|
|
||||||
case packetTooShort
|
|
||||||
case unmatchedPacketId(UInt32)
|
|
||||||
case invalidIPVersion(UInt8)
|
|
||||||
case unsupportedIPVersion(UInt8)
|
|
||||||
}
|
|
||||||
|
|
||||||
private var state: State = .idle
|
private var state: State = .idle
|
||||||
private var connection: NWConnection?
|
private var connection: NWConnection?
|
||||||
private var receiveTask: Task<Void, Never>?
|
private var receiveTask: Task<Void, Never>?
|
||||||
private let assistServerAddress: NWEndpoint
|
private let assistServerAddress: NWEndpoint
|
||||||
private var idGenerator: SDLIdGenerator
|
|
||||||
private var pendingPackets: [UInt32: PendingPacket] = [:]
|
|
||||||
|
|
||||||
// 用于对外输出收到的原始 IP 响应包
|
// 用于对外输出收到的原始 IP 响应包
|
||||||
let packetFlow: AsyncStream<Packet>
|
let packetFlow: AsyncStream<SDLV6AssistProbeReply>
|
||||||
private let packetContinuation: AsyncStream<Packet>.Continuation
|
private let packetContinuation: AsyncStream<SDLV6AssistProbeReply>.Continuation
|
||||||
private var didFinishPacketFlow = false
|
private var didFinishPacketFlow = false
|
||||||
|
|
||||||
// 用来处理关闭事件
|
// 用来处理关闭事件
|
||||||
@ -73,7 +35,7 @@ actor SDLIPV6AssistClient {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
let (packetStream, packetContinuation) = AsyncStream.makeStream(of: Packet.self, bufferingPolicy: .bufferingNewest(256))
|
let (packetStream, packetContinuation) = AsyncStream.makeStream(of: SDLV6AssistProbeReply.self, bufferingPolicy: .bufferingNewest(256))
|
||||||
self.packetFlow = packetStream
|
self.packetFlow = packetStream
|
||||||
self.packetContinuation = packetContinuation
|
self.packetContinuation = packetContinuation
|
||||||
|
|
||||||
@ -82,7 +44,6 @@ actor SDLIPV6AssistClient {
|
|||||||
self.closeContinuation = closeContinuation
|
self.closeContinuation = closeContinuation
|
||||||
|
|
||||||
self.assistServerAddress = .hostPort(host: NWEndpoint.Host(host), port: NWEndpoint.Port(integerLiteral: UInt16(assistServerInfo.port)))
|
self.assistServerAddress = .hostPort(host: NWEndpoint.Host(host), port: NWEndpoint.Port(integerLiteral: UInt16(assistServerInfo.port)))
|
||||||
self.idGenerator = SDLIdGenerator(seed: UInt32.random(in: 1..<UInt32.max))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func start() {
|
func start() {
|
||||||
@ -145,33 +106,16 @@ actor SDLIPV6AssistClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 发送 DNS 查询包(由 TUN 拦截到的原始 IP 包数据)
|
func probe() {
|
||||||
/// 当前实现会在原始 IP 包前增加 4 字节的 packetId,用于建立请求和响应的对应关系。
|
|
||||||
@discardableResult
|
|
||||||
func forward(ipPacketData: Data) -> UInt32? {
|
|
||||||
guard case .running = self.state, let connection = self.connection, connection.state == .ready else {
|
guard case .running = self.state, let connection = self.connection, connection.state == .ready else {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
var assistProbe = SDLV6AssistProbe()
|
||||||
|
assistProbe.assistToken = Data()
|
||||||
|
|
||||||
let ipVersion: Packet.IPVersion
|
if let data = try? assistProbe.serializedData() {
|
||||||
do {
|
connection.send(content: data, completion: .contentProcessed { _ in})
|
||||||
ipVersion = try self.parseIPVersion(packetData: ipPacketData)
|
|
||||||
} catch {
|
|
||||||
SDLLogger.log("[SDLIPV6AssistClient] Invalid outbound packet: \(error)", for: .debug)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let packetId = self.idGenerator.nextId()
|
|
||||||
self.pendingPackets[packetId] = .init(packetId: packetId, ipVersion: ipVersion)
|
|
||||||
|
|
||||||
let outboundPacket = Data(components: Data(uint32: packetId), ipPacketData)
|
|
||||||
connection.send(content: outboundPacket, completion: .contentProcessed { [weak self] error in
|
|
||||||
Task {
|
|
||||||
await self?.handleSendCompletion(packetId: packetId, error: error)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
return packetId
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func stop() {
|
func stop() {
|
||||||
@ -182,7 +126,6 @@ actor SDLIPV6AssistClient {
|
|||||||
self.state = .stopped
|
self.state = .stopped
|
||||||
self.receiveTask?.cancel()
|
self.receiveTask?.cancel()
|
||||||
self.receiveTask = nil
|
self.receiveTask = nil
|
||||||
self.pendingPackets.removeAll()
|
|
||||||
self.connection?.cancel()
|
self.connection?.cancel()
|
||||||
self.connection = nil
|
self.connection = nil
|
||||||
self.finishPacketFlowIfNeeded()
|
self.finishPacketFlowIfNeeded()
|
||||||
@ -232,24 +175,13 @@ actor SDLIPV6AssistClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
do {
|
do {
|
||||||
let packet = try self.parseInboundPacket(data)
|
let packet = try SDLV6AssistProbeReply(serializedBytes: data)
|
||||||
self.packetContinuation.yield(packet)
|
self.packetContinuation.yield(packet)
|
||||||
} catch {
|
} catch {
|
||||||
SDLLogger.log("[SDLIPV6AssistClient] Receive error: \(error)", for: .debug)
|
SDLLogger.log("[SDLIPV6AssistClient] Receive error: \(error)", for: .debug)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private func handleSendCompletion(packetId: UInt32, error: Error?) {
|
|
||||||
guard case .running = self.state else {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if let error {
|
|
||||||
self.pendingPackets.removeValue(forKey: packetId)
|
|
||||||
SDLLogger.log("[SDLIPV6AssistClient] Send error: \(error), packetId: \(packetId)", for: .debug)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private func didFinishReceiving(for connection: NWConnection) {
|
private func didFinishReceiving(for connection: NWConnection) {
|
||||||
guard case .running = self.state else {
|
guard case .running = self.state else {
|
||||||
return
|
return
|
||||||
@ -262,42 +194,6 @@ actor SDLIPV6AssistClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private func parseInboundPacket(_ data: Data) throws -> Packet {
|
|
||||||
guard data.count > 4 else {
|
|
||||||
throw PacketParseError.packetTooShort
|
|
||||||
}
|
|
||||||
|
|
||||||
let packetId = UInt32(data: Data(data.prefix(4)))
|
|
||||||
guard let pendingPacket = self.pendingPackets.removeValue(forKey: packetId) else {
|
|
||||||
throw PacketParseError.unmatchedPacketId(packetId)
|
|
||||||
}
|
|
||||||
|
|
||||||
let ipPacketData = Data(data.dropFirst(4))
|
|
||||||
let ipVersion = try self.parseIPVersion(packetData: ipPacketData)
|
|
||||||
|
|
||||||
if ipVersion != pendingPacket.ipVersion {
|
|
||||||
SDLLogger.log("[SDLIPV6AssistClient] packet version mismatch, packetId: \(packetId), request: \(pendingPacket.ipVersion.rawValue), response: \(ipVersion.rawValue)", for: .debug)
|
|
||||||
}
|
|
||||||
|
|
||||||
return .init(packetId: pendingPacket.packetId, ipPacketData: ipPacketData, ipVersion: ipVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
private func parseIPVersion(packetData: Data) throws -> Packet.IPVersion {
|
|
||||||
guard let firstByte = packetData.first else {
|
|
||||||
throw PacketParseError.packetTooShort
|
|
||||||
}
|
|
||||||
|
|
||||||
let rawVersion = firstByte >> 4
|
|
||||||
guard let ipVersion = Packet.IPVersion(rawValue: rawVersion) else {
|
|
||||||
throw PacketParseError.invalidIPVersion(rawVersion)
|
|
||||||
}
|
|
||||||
guard ipVersion == .ipv6 else {
|
|
||||||
throw PacketParseError.unsupportedIPVersion(rawVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ipVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
private func finishPacketFlowIfNeeded() {
|
private func finishPacketFlowIfNeeded() {
|
||||||
guard !self.didFinishPacketFlow else {
|
guard !self.didFinishPacketFlow else {
|
||||||
return
|
return
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user