diff --git a/Tun/Punchnet/Actors/SDLContextActor.swift b/Tun/Punchnet/Actors/SDLContextActor.swift index f696110..0eaed43 100644 --- a/Tun/Punchnet/Actors/SDLContextActor.swift +++ b/Tun/Punchnet/Actors/SDLContextActor.swift @@ -131,18 +131,41 @@ actor SDLContextActor { } private func startQUICClient() async throws -> SDLQUICClient { + self.quicWorker?.cancel() + self.quicClient?.stop() + // 启动monitor let quicClient = SDLQUICClient(host: "118.178.229.213", port: 1365) quicClient.start() // 等待quic准备好 try await quicClient.waitReady() + try await Task.sleep(for: .seconds(0.2)) SDLLogger.shared.log("[SDLContext] start quic client ready") - + self.quicWorker = Task.detached { - for await message in quicClient.receiveStream(maxLen: 86400) { - SDLLogger.shared.log("[SDLContext] quic client receive message: \(message)") + let reader = quicClient.getReader() + while let frame = try? await reader.next() { + if let message = SDLQUICCodec.decode(frame: frame) { + switch message { + case .welcome(let welcome): + SDLLogger.shared.log("[SDLContext] quic welcome: \(welcome)") + case .registerSuperAck(let registerSuperAck): + await self.handleRegisterSuperAck(registerSuperAck: registerSuperAck) + case .registerSuperNak(let registerSuperNak): + await self.handleRegisterSuperNak(nakPacket: registerSuperNak) + case .peerInfo(let peerInfo): + SDLLogger.shared.log("[SDLContext] peer message: \(peerInfo)") + case .event(let event): + await self.handleEvent(event: event) + case .policyReponse(let policyResponse): + // 处理权限的请求问题 + await self.identifyStore.apply(policyResponse: policyResponse) + } + + } } + } self.quicClient = quicClient @@ -258,11 +281,6 @@ actor SDLContextActor { await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck) case .data(let data): try? await self.handleData(data: data) - -// case .policyReponse(let policyResponse): -// SDLLogger.shared.log("[SDLContext] get a policyResponse: \(policyResponse.totalNum) of \(policyResponse.index), bytes: \(policyResponse.rules.count)") -// // 处理权限的请求问题 -// await self.identifyStore.apply(policyResponse: policyResponse) } } @@ -278,8 +296,6 @@ actor SDLContextActor { self.setNatType(natType: natType) SDLLogger.shared.log("[SDLContext] nat_type is: \(natType)") } - - return udpHole } @@ -367,7 +383,7 @@ actor SDLContextActor { } - private func handleEvent(event: SDLEvent) async throws { + private func handleEvent(event: SDLEvent) async { switch event { // case .dropMacs(let dropMacsEvent): // SDLLogger.shared.log("[SDLContext] drop macs", level: .info) @@ -385,7 +401,7 @@ actor SDLContextActor { register.networkID = self.config.networkAddress.networkId register.srcMac = self.config.networkAddress.mac register.dstMac = sendRegisterEvent.dstMac - self.udpHole?.send(type: .register, data: try register.serializedData(), remoteAddress: remoteAddress) + self.udpHole?.send(type: .register, data: try! register.serializedData(), remoteAddress: remoteAddress) } case .networkShutdown(let shutdownEvent): let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message) @@ -576,10 +592,12 @@ actor SDLContextActor { // 处理读取的每个数据包 private func dealPacket(packet: IPPacket) async { let networkAddr = self.config.networkAddress + + // TODO if SDLDNSClient.Helper.isDnsRequestPacket(ipPacket: packet) { let destIp = packet.header.destination_ip SDLLogger.shared.log("[DNSQuery] destIp: \(destIp), int: \(packet.header.destination.asIpAddress())", level: .debug) - self.dnsClient?.forward(ipPacket: packet) + //self.dnsClient?.forward(ipPacket: packet) return } diff --git a/Tun/Punchnet/Actors/SDLQuicClient.swift b/Tun/Punchnet/Actors/SDLQuicClient.swift index 38c1878..71778c7 100644 --- a/Tun/Punchnet/Actors/SDLQuicClient.swift +++ b/Tun/Punchnet/Actors/SDLQuicClient.swift @@ -15,15 +15,15 @@ enum SDLQUICError: Error { case connectionCancelled case timeout case decodeError(String) + case packetTooLarge } final class SDLQUICClient { private let transport: SDLQUICTransport - private let allocator = ByteBufferAllocator() private let queue = DispatchQueue(label: "com.sdl.QUICClient.queue") // 专用队列保证线程安全 - private var closeCont: CheckedContinuation? - private var readyCont: CheckedContinuation? + private let (closeStream, closeCont) = AsyncStream.makeStream(of: Void.self) + private let (readyStream, readyCont) = AsyncStream.makeStream(of: Void.self) init(host: String, port: UInt16) { self.transport = SDLQUICTransport(host: host, port: port) @@ -33,23 +33,17 @@ final class SDLQUICClient { self.transport.start(queue: self.queue) { event in switch event { case .ready: - self.readyCont?.resume() - self.readyCont = nil - case .failed(_): - self.closeCont?.resume() - self.closeCont = nil - case .cancelled: - self.closeCont?.resume() - self.closeCont = nil + self.readyCont.yield() + self.readyCont.finish() + case .failed(_), .cancelled: + self.closeCont.yield() + self.closeCont.finish() } } } - func receiveStream(maxLen: Int) -> AsyncCompactMapSequence, SDLQUICInboundMessage> { - return transport.receiveMessageStream(maxLen: maxLen).compactMap { data in - var buf = self.allocator.buffer(bytes: data) - return try? QUICCodec.decode(buffer: &buf) - } + func getReader() -> SDLQUICReader { + return transport.getReader() } func send(type: SDLPacketType, data: Data) { @@ -57,23 +51,14 @@ final class SDLQUICClient { } func waitReady() async throws { - return try await withCheckedThrowingContinuation { cont in - self.readyCont = cont - } + for await _ in readyStream {} } func waitClose() async { - return await withCheckedContinuation { cont in - self.closeCont = cont - } + for await _ in closeStream {} } - deinit { - self.readyCont?.resume(throwing: SDLQUICError.connectionCancelled) - self.readyCont = nil - self.closeCont?.resume() - self.closeCont = nil - + func stop() { self.transport.stop() } @@ -89,7 +74,19 @@ final class SDLQUICTransport { private let connection: NWConnection init(host: String, port: UInt16) { - let params = NWParameters(quic: .init(alpn: ["punchnet/1.0"])) + let options = NWProtocolQUIC.Options(alpn: ["punchnet/1.0"]) + + // TODO 这里设置证书的校验逻辑 + sec_protocol_options_set_verify_block( + options.securityProtocolOptions, + { metadata, trust, complete in + // 你可以自己决定是否信任 + complete(true) // true = 接受证书 + }, + DispatchQueue.global() + ) + + let params = NWParameters(quic: options) self.connection = NWConnection(host: .init(host), port: .init(rawValue: port)!, using: params) } @@ -107,56 +104,13 @@ final class SDLQUICTransport { connection.start(queue: queue) } - func receiveMessageStream(maxLen: Int) -> AsyncStream { - let connection = self.connection - - return AsyncStream { continuation in - var buffer = Data() - - func tryParse() { - while true { - // 至少要有长度 - guard buffer.count >= 2 else { - return - } - - let len0 = UInt16(bigEndian: buffer.withUnsafeBytes { $0.load(as: UInt16.self) }) - let len = Int(len0) - - // 数据不够一个完整包 - guard buffer.count >= 2 + len else { - return - } - - // 取 body - let body = buffer.subdata(in: 2 ..< 2 + len) - continuation.yield(body) - - // 移除已消费 - buffer.removeSubrange(0 ..< 2 + len) - } - } - - func loopReceive() { - connection.receive(minimumIncompleteLength: 1, maximumLength: maxLen) { data, _, _, error in - if let data, !data.isEmpty { - buffer.append(data) - tryParse() - } - if error == nil { - loopReceive() - } else { - continuation.finish() - } - } - } - - loopReceive() - } + func getReader() -> SDLQUICReader { + return SDLQUICReader(connection: self.connection) } func send(type: SDLPacketType, data: Data) { - var len = UInt16(data.count).bigEndian + var len = UInt16(data.count + 1).bigEndian + var packet = Data(Data(bytes: &len, count: 2)) packet.append(type.rawValue) packet.append(data) @@ -170,82 +124,194 @@ final class SDLQUICTransport { } -extension SDLQUICClient { +actor SDLQUICReader: AsyncIteratorProtocol { - struct QUICCodec { - // --MARK: 编解码器 - public static func decode(buffer: inout ByteBuffer) throws -> SDLQUICInboundMessage? { - guard let type = buffer.readInteger(as: UInt8.self), - let packetType = SDLPacketType(rawValue: type) else { - return nil + typealias Element = ByteBuffer + + private let allocator = ByteBufferAllocator() + private var buffer: ByteBuffer + // 用来缓存包,有可能一次读取到多个包 + private var packets: [ByteBuffer] = [] + // 单个包最大64K + private let maxPacketSize: Int + // 最大缓冲区区为2M + private let maxBufferSize: Int + + // 是否已经读取完成 + private var isComplete: Bool = false + + private let connection: NWConnection + + init(connection: NWConnection, maxPacketSize: Int = 64 * 1024, maxBufferSize: Int = 2 * 1024 * 1024) { + self.connection = connection + self.maxBufferSize = maxBufferSize + self.maxPacketSize = maxPacketSize + self.buffer = allocator.buffer(capacity: maxBufferSize) + } + + func next() async throws -> ByteBuffer? { + // 如果还有包 + if !self.packets.isEmpty { + return self.packets.removeFirst() + } + + // 尝试读取,并返回 + self.packets = try await self.readPacket() + if !self.packets.isEmpty { + return self.packets.removeFirst() + } else { + return nil + } + } + + private func readPacket() async throws -> [ByteBuffer] { + while true { + if self.isComplete { + return try parseFrames() } - switch packetType { - case .welcome: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let welcome = try? SDLWelcome(serializedBytes: bytes) else { - return nil + let (isComplete, data) = try await readOnce() + self.isComplete = isComplete + + if !data.isEmpty { + buffer.writeBytes(data) + // 尝试解析出完整的包 + let packets = try parseFrames() + if !packets.isEmpty { + return packets } - return .welcome(welcome) - - case .registerSuperAck: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else { - return nil - } - return .registerSuperAck(registerSuperAck) - case .registerSuperNak: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else { - return nil - } - return .registerSuperNak(registerSuperNak) - case .peerInfo: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else { - return nil - } - return .peerInfo(peerInfo) - case .policyResponse: - guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let policyResponse = try? SDLPolicyResponse(serializedBytes: bytes) else { - return nil - } - return .policyReponse(policyResponse) - case .event: - guard let eventVal = buffer.readInteger(as: UInt8.self), - let event = SDLEventType(rawValue: eventVal), - let bytes = buffer.readBytes(length: buffer.readableBytes) else { - SDLLogger.shared.log("[SDLUDPHole] decode error 15") - return nil + } + } + } + + // 尝试解析数据 + private func parseFrames() throws -> [ByteBuffer] { + guard buffer.readableBytes >= 2 else { + return [] + } + + var frames: [ByteBuffer] = [] + while true { + guard let len = buffer.getInteger(at: buffer.readerIndex, endianness: .big, as: UInt16.self) else { + break + } + + if len > self.maxPacketSize { + throw SDLQUICError.packetTooLarge + } + + guard buffer.readableBytes >= len + 2 else { + break + } + + buffer.moveReaderIndex(forwardBy: 2) + if let buf = buffer.readSlice(length: Int(len)) { + frames.append(buf) + } + } + + if buffer.readerIndex > maxBufferSize / 10 * 6 { + buffer.discardReadBytes() + } + + return frames + } + + // 读取一次数据 + private func readOnce() async throws -> (Bool, Data) { + return try await withCheckedThrowingContinuation { cont in + connection.receive(minimumIncompleteLength: 1, maximumLength: maxPacketSize) { data, _, isComplete, error in + if let error { + cont.resume(throwing: error) + return } - switch event { - case .natChanged: - guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else { - SDLLogger.shared.log("[SDLUDPHole] decode error 16") - return nil - } - return .event(.natChanged(natChangedEvent)) - case .sendRegister: - guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else { - SDLLogger.shared.log("[SDLUDPHole] decode error 17") - return nil - } - return .event(.sendRegister(sendRegisterEvent)) - case .networkShutdown: - guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else { - SDLLogger.shared.log("[SDLUDPHole] decode error 18") - return nil - } - return .event(.networkShutdown(networkShutdownEvent)) + if let data, !data.isEmpty { + SDLLogger.shared.log("[SDLQUICTransport] read bytes: \(data.count)") + cont.resume(returning: (isComplete, data)) + } else { + cont.resume(returning: (isComplete, Data())) } - default: - SDLLogger.shared.log("SDLUDPHole decode miss type: \(type)") - - return nil } } } } + +struct SDLQUICCodec { + // --MARK: 编解码器 + public static func decode(frame: ByteBuffer) -> SDLQUICInboundMessage? { + var buffer = frame + guard let type = buffer.readInteger(as: UInt8.self), + let packetType = SDLPacketType(rawValue: type) else { + return nil + } + + switch packetType { + case .welcome: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let welcome = try? SDLWelcome(serializedBytes: bytes) else { + return nil + } + return .welcome(welcome) + + case .registerSuperAck: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else { + return nil + } + return .registerSuperAck(registerSuperAck) + case .registerSuperNak: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else { + return nil + } + return .registerSuperNak(registerSuperNak) + case .peerInfo: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else { + return nil + } + return .peerInfo(peerInfo) + case .policyResponse: + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let policyResponse = try? SDLPolicyResponse(serializedBytes: bytes) else { + return nil + } + return .policyReponse(policyResponse) + case .event: + guard let eventVal = buffer.readInteger(as: UInt8.self), + let event = SDLEventType(rawValue: eventVal), + let bytes = buffer.readBytes(length: buffer.readableBytes) else { + SDLLogger.shared.log("[SDLUDPHole] decode error 15") + return nil + } + + switch event { + case .natChanged: + guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else { + SDLLogger.shared.log("[SDLUDPHole] decode error 16") + return nil + } + return .event(.natChanged(natChangedEvent)) + case .sendRegister: + guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else { + SDLLogger.shared.log("[SDLUDPHole] decode error 17") + return nil + } + return .event(.sendRegister(sendRegisterEvent)) + case .networkShutdown: + guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else { + SDLLogger.shared.log("[SDLUDPHole] decode error 18") + return nil + } + return .event(.networkShutdown(networkShutdownEvent)) + } + default: + SDLLogger.shared.log("SDLUDPHole decode miss type: \(type)") + + return nil + } + } +} + diff --git a/tracelog.sh b/tracelog.sh index 4bacfda..528edfd 100755 --- a/tracelog.sh +++ b/tracelog.sh @@ -1,3 +1,3 @@ #! /bin/sh -log stream --predicate 'subsystem == "com.jihe.punchnet"' --info +log stream --predicate 'subsystem == "com.jihe.punchnet"' --info --style compact