diff --git a/Tun/Punchnet/SDLContextActor.swift b/Tun/Punchnet/SDLContextActor.swift index c579ac4..1954199 100644 --- a/Tun/Punchnet/SDLContextActor.swift +++ b/Tun/Punchnet/SDLContextActor.swift @@ -25,6 +25,9 @@ actor SDLContextActor { // aes private var aesKey: Data? + // session token + private var sessionToken: Data? + // rsa的相关配置, public_key是本地生成的 nonisolated let rsaCipher: RSACipher @@ -159,7 +162,7 @@ actor SDLContextActor { self.udpHoleWorkers = nil // 启动udp服务器 - let udpHole = try SDLUDPHole(logger: SDLLogger.shared) + let udpHole = try SDLUDPHole() try udpHole.start() SDLLogger.shared.log("[SDLContext] udpHole started") self.udpHole = udpHole @@ -177,12 +180,10 @@ actor SDLContextActor { if Task.isCancelled { break } - SDLLogger.shared.log("[SDLContext] will do stunRequest22") await self.sendStunRequest() - SDLLogger.shared.log("[SDLContext] will do stunRequest44") } - SDLLogger.shared.log("[SDLContext] will do stunRequest55") + SDLLogger.shared.log("[SDLContext] pingTask cancel") } // 处理数据流 @@ -284,15 +285,19 @@ actor SDLContextActor { } private func sendStunRequest() { + guard let sessionToken = self.sessionToken else { + return + } + var stunRequest = SDLStunRequest() stunRequest.clientID = self.config.clientId stunRequest.networkID = self.config.networkAddress.networkId stunRequest.ip = self.config.networkAddress.ip stunRequest.mac = self.config.networkAddress.mac stunRequest.natType = UInt32(self.natType.rawValue) + stunRequest.sessionToken = sessionToken - SDLLogger.shared.log("[SDLContext] will send stun request") - + SDLLogger.shared.log("[SDLContext] send stun request: \(stunRequest)") if let stunData = try? stunRequest.serializedData() { let remoteAddress = self.config.stunSocketAddress self.udpHole?.send(type: .stunRequest, data: stunData, remoteAddress: remoteAddress) @@ -301,23 +306,20 @@ actor SDLContextActor { private func handleRegisterSuperAck(registerSuperAck: SDLRegisterSuperAck) async { // 需要对数据通过rsa的私钥解码 - let aesKey = try! self.rsaCipher.decode(data: Data(registerSuperAck.aesKey)) + self.aesKey = try! self.rsaCipher.decode(data: Data(registerSuperAck.aesKey)) + self.sessionToken = registerSuperAck.sessionToken - SDLLogger.shared.log("[SDLContext] get registerSuperAck, aes_key len: \(aesKey.count)", level: .info) + SDLLogger.shared.log("[SDLContext] get registerSuperAck, aes_key len: \(self.aesKey!.count)", level: .info) // 服务器分配的tun网卡信息 do { let ipAddress = try await self.providerAdapter.setNetworkSettings(networkAddress: self.config.networkAddress, dnsServer: SDLDNSClient.Helper.dnsServer) SDLLogger.shared.log("[SDLContext] setNetworkSettings successed") self.noticeClient?.send(data: NoticeMessage.ipAdress(ip: ipAddress)) - SDLLogger.shared.log("[SDLContext] send ip successed") self.startReader() - SDLLogger.shared.log("[SDLContext] reader started") } catch let err { SDLLogger.shared.log("[SDLContext] setTunnelNetworkSettings get error: \(err)", level: .error) exit(-1) } - - self.aesKey = aesKey } private func handleRegisterSuperNak(nakPacket: SDLRegisterSuperNak) { @@ -341,6 +343,9 @@ actor SDLContextActor { private func handleEvent(event: SDLEvent) throws { switch event { + case .dropMacs(let dropMacsEvent): + SDLLogger.shared.log("[SDLContext] drop macs", level: .info) + () case .natChanged(let natChangedEvent): let dstMac = natChangedEvent.mac SDLLogger.shared.log("[SDLContext] natChangedEvent, dstMac: \(dstMac)", level: .info) @@ -356,7 +361,9 @@ actor SDLContextActor { register.dstMac = sendRegisterEvent.dstMac self.udpHole?.send(type: .register, data: try register.serializedData(), remoteAddress: remoteAddress) } - + case .refreshAuth(let refreshAuthEvent): + SDLLogger.shared.log("[SDLContext] refresh auth", level: .info) + () case .networkShutdown(let shutdownEvent): let alertNotice = NoticeMessage.alert(alert: shutdownEvent.message) self.noticeClient?.send(data: alertNotice) diff --git a/Tun/Punchnet/SDLMessage.swift b/Tun/Punchnet/SDLMessage.swift index a7f0165..64d047b 100644 --- a/Tun/Punchnet/SDLMessage.swift +++ b/Tun/Punchnet/SDLMessage.swift @@ -21,15 +21,9 @@ enum SDLPacketType: UInt8 { case queryInfo = 0x06 case peerInfo = 0x07 - case ping = 0x08 - case pong = 0x09 - // 事件类型 case event = 0x10 - // 流量统计 - case flowTracer = 0x15 - case register = 0x20 case registerAck = 0x21 @@ -113,13 +107,17 @@ enum SDLHoleSignal { // 命令类型 enum SDLEventType: UInt8 { + case dropMacs = 0x02 case natChanged = 0x03 case sendRegister = 0x04 + case refreshAuth = 0x05 case networkShutdown = 0xFF } enum SDLEvent { + case dropMacs(SDLDropMacsEvent) case natChanged(SDLNatChangedEvent) case sendRegister(SDLSendRegisterEvent) + case refreshAuth(SDLRefreshAuthEvent) case networkShutdown(SDLNetworkShutdownEvent) } diff --git a/Tun/Punchnet/SDLUDPHole.swift b/Tun/Punchnet/SDLUDPHole.swift index 08bfb10..cbc8c21 100644 --- a/Tun/Punchnet/SDLUDPHole.swift +++ b/Tun/Punchnet/SDLUDPHole.swift @@ -35,16 +35,13 @@ final class SDLUDPHole: ChannelInboundHandler { private var cont: CheckedContinuation? private var isReady: Bool = false - private let logger: SDLLogger - enum HoleEvent { case ready case closed } // 启动函数 - init(logger: SDLLogger) throws { - self.logger = logger + init() throws { (self.signalStream, self.signalContinuation) = AsyncStream.makeStream(of: (SocketAddress, SDLHoleSignal).self, bufferingPolicy: .unbounded) (self.dataStream, self.dataContinuation) = AsyncStream.makeStream(of: SDLData.self, bufferingPolicy: .unbounded) } @@ -57,7 +54,7 @@ final class SDLUDPHole: ChannelInboundHandler { } let channel = try bootstrap.bind(host: "0.0.0.0", port: 0).wait() - self.logger.log("[UDPHole] started", level: .debug) + SDLLogger.shared.log("[UDPHole] started", level: .debug) self.channel = channel } @@ -100,10 +97,10 @@ final class SDLUDPHole: ChannelInboundHandler { self.signalContinuation.yield((remoteAddress, signal)) } } else { - self.logger.log("[SDLUDPHole] decode message, get null", level: .warning) + SDLLogger.shared.log("[SDLUDPHole] decode message, get null", level: .warning) } } catch let err { - self.logger.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning) + SDLLogger.shared.log("[SDLUDPHole] decode message, get error: \(err)", level: .warning) } } @@ -141,72 +138,102 @@ final class SDLUDPHole: ChannelInboundHandler { // --MARK: 编解码器 private func decode(buffer: inout ByteBuffer) throws -> SDLHoleMessage? { guard let type = buffer.readInteger(as: UInt8.self), - let packetType = SDLPacketType(rawValue: type), - let bytes = buffer.readBytes(length: buffer.readableBytes) else { + let packetType = SDLPacketType(rawValue: type) else { + SDLLogger.shared.log("[SDLUDPHole] decode error 11") return nil } switch packetType { case .data: - let dataPacket = try SDLData(serializedBytes: bytes) + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let dataPacket = try? SDLData(serializedBytes: bytes) else { + return nil + } return .data(dataPacket) case .register: - let registerPacket = try SDLRegister(serializedBytes: bytes) + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let registerPacket = try? SDLRegister(serializedBytes: bytes) else { + return nil + } return .signal(.register(registerPacket)) case .registerAck: - let registerAck = try SDLRegisterAck(serializedBytes: bytes) + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let registerAck = try? SDLRegisterAck(serializedBytes: bytes) else { + return nil + } return .signal(.registerAck(registerAck)) case .stunProbeReply: - let stunProbeReply = try SDLStunProbeReply(serializedBytes: bytes) + guard let bytes = buffer.readBytes(length: buffer.readableBytes), + let stunProbeReply = try? SDLStunProbeReply(serializedBytes: bytes) else { + return nil + } return .signal(.stunProbeReply(stunProbeReply)) case .registerSuperAck: guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else { + let registerSuperAck = try? SDLRegisterSuperAck(serializedBytes: bytes) else { return nil } return .signal(.registerSuperAck(registerSuperAck)) case .registerSuperNak: guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else { + let registerSuperNak = try? SDLRegisterSuperNak(serializedBytes: bytes) else { return nil } return .signal(.registerSuperNak(registerSuperNak)) - case .peerInfo: guard let bytes = buffer.readBytes(length: buffer.readableBytes), - let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else { + let peerInfo = try? SDLPeerInfo(serializedBytes: bytes) else { return nil } - return .signal(.peerInfo(peerInfo)) case .event: guard let eventVal = buffer.readInteger(as: UInt8.self), let event = SDLEventType(rawValue: eventVal), let bytes = buffer.readBytes(length: buffer.readableBytes) else { + SDLLogger.shared.log("[SDLUDPHole] decode error 15") return nil } switch event { + case .dropMacs: + guard let dropMacsEvent = try? SDLDropMacsEvent(serializedBytes: bytes) else { + SDLLogger.shared.log("[SDLUDPHole] decode error 16") + return nil + } + return .signal(.event(.dropMacs(dropMacsEvent))) + case .natChanged: guard let natChangedEvent = try? SDLNatChangedEvent(serializedBytes: bytes) else { + SDLLogger.shared.log("[SDLUDPHole] decode error 16") return nil } return .signal(.event(.natChanged(natChangedEvent))) case .sendRegister: guard let sendRegisterEvent = try? SDLSendRegisterEvent(serializedBytes: bytes) else { + SDLLogger.shared.log("[SDLUDPHole] decode error 17") return nil } return .signal(.event(.sendRegister(sendRegisterEvent))) + case .refreshAuth: + guard let refreshAuthEvent = try? SDLRefreshAuthEvent(serializedBytes: bytes) else { + SDLLogger.shared.log("[SDLUDPHole] decode error 17") + return nil + } + return .signal(.event(.refreshAuth(refreshAuthEvent))) + case .networkShutdown: guard let networkShutdownEvent = try? SDLNetworkShutdownEvent(serializedBytes: bytes) else { + SDLLogger.shared.log("[SDLUDPHole] decode error 18") return nil } return .signal(.event(.networkShutdown(networkShutdownEvent))) } - default: + SDLLogger.shared.log("SDLUDPHole decode miss type: \(type)") + return nil } + } deinit {