增加对域名的支持

This commit is contained in:
anlicheng 2025-12-13 17:44:05 +08:00
parent 7ca620eca7
commit 06ba79bf83
9 changed files with 84 additions and 131 deletions

View File

@ -39,6 +39,7 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
let token = options["token"] as! String
let networkCode = options["network_code"] as! String
let clientId = options["client_id"] as! String
let remoteDnsServer = options["remote_dns_server"] as! String
let stunServers = stunServersStr.split(separator: ";").compactMap { server -> SDLConfiguration.StunServer? in
let parts = server.split(separator: ":", maxSplits: 2)
@ -69,7 +70,8 @@ class PacketTunnelProvider: NEPacketTunnelProvider {
clientId: clientId,
noticePort: noticePort,
token: token,
networkCode: networkCode)
networkCode: networkCode,
remoteDnsServer: remoteDnsServer)
//
let rsaCipher = try! CCRSACipher(keySize: 1024)
let aesChiper = CCAESChiper()

View File

@ -100,3 +100,18 @@ actor DNSClient {
}
}
extension DNSClient {
struct Helper {
static let dnsServer: String = "100.100.100.100"
// dns
static let dnsDestIpAddr: UInt32 = 1684300900
// dns
static func isDnsRequestPacket(ipPacket: IPPacket) -> Bool {
return ipPacket.header.destination == dnsDestIpAddr
}
}
}

View File

@ -1,119 +0,0 @@
//
// DNSUtil.swift
// punchnet
//
// Created by on 2025/12/9.
//
import Foundation
import Network
struct DNSUtil {
static let dnsServers: [String] = ["100.100.100.100"]
// dns
static let dnsDestIpAddr: UInt32 = 1684300900
// dns
static func isDnsRequestPacket(ipPacket: IPPacket) -> Bool {
return ipPacket.header.destination == dnsDestIpAddr
}
// // DNS Header
// struct DNSHeader {
// var id: UInt16
// var flags: UInt16
// var qdCount: UInt16
// var anCount: UInt16
// var nsCount: UInt16
// var arCount: UInt16
// }
//
// // DNS Question
// struct DNSQuestion {
// var name: String
// var type: UInt16
// var qclass: UInt16
// }
//
// // DNS Label
// func parseName(from data: Data, offset: inout Int) -> String {
// var labels: [String] = []
// while true {
// let length = Int(data[offset])
// offset += 1
// if length == 0 {
// break
// }
// let labelData = data[offset..<(offset + length)]
// if let label = String(data: labelData, encoding: .utf8) {
// labels.append(label)
// }
// offset += length
// }
// return labels.joined(separator: ".")
// }
//
// // DNS
// func parseDNSRequest(_ data: Data) -> (DNSHeader, [DNSQuestion])? {
// guard data.count >= 12 else { return nil } // DNS Header 12
//
// let header = DNSHeader(
// id: data.uint16(at: 0),
// flags: data.uint16(at: 2),
// qdCount: data.uint16(at: 4),
// anCount: data.uint16(at: 6),
// nsCount: data.uint16(at: 8),
// arCount: data.uint16(at: 10)
// )
//
// var offset = 12
// var questions: [DNSQuestion] = []
//
// for _ in 0..<header.qdCount {
// let name = parseName(from: data, offset: &offset)
// let type = data.uint16(at: offset)
// offset += 2
// let qclass = data.uint16(at: offset)
// offset += 2
//
// let question = DNSQuestion(name: name, type: type, qclass: qclass)
// questions.append(question)
// }
//
// return (header, questions)
// }
//
// //
// let dnsPacket: [UInt8] = [
// 0x12, 0x34, // Transaction ID
// 0x01, 0x00, // Flags
// 0x00, 0x01, // QDCOUNT
// 0x00, 0x00, // ANCOUNT
// 0x00, 0x00, // NSCOUNT
// 0x00, 0x00, // ARCOUNT
// 0x03, 0x77, 0x77, 0x77, // w w w
// 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, // g o o g l e
// 0x03, 0x63, 0x6f, 0x6d, // c o m
// 0x00, // End of name
// 0x00, 0x01, // QTYPE A
// 0x00, 0x01 // QCLASS IN
// ]
//
// if let data = Data(exactly: dnsPacket), let (header, questions) = parseDNSRequest(data) {
// print("Transaction ID: \(header.id)")
// print("Flags: \(header.flags)")
// print("Questions count: \(header.qdCount)")
// for q in questions {
// print("Question: \(q.name), type: \(q.type), class: \(q.qclass)")
// }
// }
}
// Helper
private extension Data {
func uint16(at offset: Int) -> UInt16 {
let subdata = self[offset..<offset+2]
return subdata.withUnsafeBytes { $0.load(as: UInt16.self).bigEndian }
}
}

View File

@ -54,7 +54,6 @@ enum TransportProtocol: UInt8 {
case icmp = 1
case tcp = 6
case udp = 17
}
struct IPPacket {
@ -80,4 +79,8 @@ struct IPPacket {
self.data = data
}
//
func getPayload() -> Data {
return data.subdata(in: 20..<data.count)
}
}

View File

@ -31,6 +31,8 @@ public class SDLConfiguration {
let stunServers: [StunServer]
let remoteDnsServer: String
let noticePort: Int
lazy var stunSocketAddress: SocketAddress = {
@ -52,7 +54,7 @@ public class SDLConfiguration {
let token: String
let networkCode: String
public init(version: UInt8, installedChannel: String, superHost: String, superPort: Int, stunServers: [StunServer], clientId: String, noticePort: Int, token: String, networkCode: String) {
public init(version: UInt8, installedChannel: String, superHost: String, superPort: Int, stunServers: [StunServer], clientId: String, noticePort: Int, token: String, networkCode: String, remoteDnsServer: String) {
self.version = version
self.installedChannel = installedChannel
self.superHost = superHost
@ -62,6 +64,7 @@ public class SDLConfiguration {
self.noticePort = noticePort
self.token = token
self.networkCode = networkCode
self.remoteDnsServer = remoteDnsServer
}
}

View File

@ -251,7 +251,8 @@ public class SDLContext: @unchecked Sendable {
}
private func startDnsClient() async throws {
let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost("127.0.0.1", port: 15353)
let remoteDnsServer = config.remoteDnsServer
let dnsSocketAddress = try SocketAddress.makeAddressResolvingHost(remoteDnsServer, port: 15353)
self.dnsClient = try await DNSClient(dnsServerAddress: dnsSocketAddress, logger: self.logger)
try await withThrowingTaskGroup(of: Void.self) { group in
@ -303,7 +304,7 @@ public class SDLContext: @unchecked Sendable {
}
// tun
await self.didNetworkConfigChanged(devAddr: self.devAddr, dnsServers: ["100.100.100.100"])
await self.didNetworkConfigChanged(devAddr: self.devAddr, dnsServer: DNSClient.Helper.dnsServer)
self.aesKey = aesKey
if upgradeType == .normal {
@ -359,7 +360,7 @@ public class SDLContext: @unchecked Sendable {
self.devAddr = changeNetworkCommand.devAddr
// tun
await self.didNetworkConfigChanged(devAddr: self.devAddr, dnsServers: ["100.100.100.100"])
await self.didNetworkConfigChanged(devAddr: self.devAddr, dnsServer: DNSClient.Helper.dnsServer)
self.aesKey = aesKey
var commandAck = SDLCommandAck()
@ -477,11 +478,11 @@ public class SDLContext: @unchecked Sendable {
// }
//
private func didNetworkConfigChanged(devAddr: SDLDevAddr, dnsServers: [String]) async {
private func didNetworkConfigChanged(devAddr: SDLDevAddr, dnsServer: String) async {
let netAddress = SDLNetAddress(ip: devAddr.netAddr, maskLen: UInt8(devAddr.netBitLen))
let routes = [
Route(dstAddress: netAddress.networkAddress, subnetMask: netAddress.maskAddress),
Route(dstAddress: "100.100.100.100", subnetMask: "255.255.255.255")
Route(dstAddress: dnsServer, subnetMask: "255.255.255.255")
]
// Add code here to start the process of connecting the tunnel.
@ -490,9 +491,10 @@ public class SDLContext: @unchecked Sendable {
// DNS
let dnsSettings = NEDNSSettings(servers: dnsServers)
dnsSettings.searchDomains = ["punchnet.ts.net"]
dnsSettings.matchDomains = ["punchnet.ts.net"]
let networkDomain = devAddr.networkDomain
let dnsSettings = NEDNSSettings(servers: [dnsServer])
dnsSettings.searchDomains = [networkDomain]
dnsSettings.matchDomains = [networkDomain]
dnsSettings.matchDomainsNoSearch = false
networkSettings.dnsSettings = dnsSettings
self.logger.log("[SDLContext] Tun started at network ip: \(netAddress.ipAddress), mask: \(netAddress.maskAddress)", level: .info)
@ -541,7 +543,7 @@ public class SDLContext: @unchecked Sendable {
return
}
if DNSUtil.isDnsRequestPacket(ipPacket: packet) {
if DNSClient.Helper.isDnsRequestPacket(ipPacket: packet) {
let destIp = packet.header.destination_ip
NSLog("destIp: \(destIp), int: \(packet.header.destination)")
await self.dnsClient?.forward(ipPacket: packet)

View File

@ -65,6 +65,8 @@ struct SDLDevAddr: @unchecked Sendable {
var netBitLen: UInt32 = 0
var networkDomain: String = String()
var unknownFields = SwiftProtobuf.UnknownStorage()
init() {}
@ -559,6 +561,7 @@ extension SDLDevAddr: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementatio
2: .same(proto: "mac"),
3: .standard(proto: "net_addr"),
4: .standard(proto: "net_bit_len"),
5: .standard(proto: "network_domain"),
]
mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
@ -571,6 +574,7 @@ extension SDLDevAddr: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementatio
case 2: try { try decoder.decodeSingularBytesField(value: &self.mac) }()
case 3: try { try decoder.decodeSingularUInt32Field(value: &self.netAddr) }()
case 4: try { try decoder.decodeSingularUInt32Field(value: &self.netBitLen) }()
case 5: try { try decoder.decodeSingularStringField(value: &self.networkDomain) }()
default: break
}
}
@ -589,6 +593,9 @@ extension SDLDevAddr: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementatio
if self.netBitLen != 0 {
try visitor.visitSingularUInt32Field(value: self.netBitLen, fieldNumber: 4)
}
if !self.networkDomain.isEmpty {
try visitor.visitSingularStringField(value: self.networkDomain, fieldNumber: 5)
}
try unknownFields.traverse(visitor: &visitor)
}
@ -597,6 +604,7 @@ extension SDLDevAddr: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementatio
if lhs.mac != rhs.mac {return false}
if lhs.netAddr != rhs.netAddr {return false}
if lhs.netBitLen != rhs.netBitLen {return false}
if lhs.networkDomain != rhs.networkDomain {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}

View File

@ -0,0 +1,38 @@
//
// UDPPacket.swift
// Tun
//
// Created by on 2025/12/13.
//
import Foundation
struct UDPHeader {
let sourcePort: UInt16
let destinationPort: UInt16
let length: UInt16
let checksum: UInt16
}
struct UDPPacket {
let header: UDPHeader
let payload: Data
init?(_ data: Data) {
// UDP header 8
guard data.count >= 8 else {
return nil
}
let header = UDPHeader(sourcePort: UInt16(bytes: (data[0], data[1])),
destinationPort: UInt16(bytes: (data[2], data[3])),
length: UInt16(bytes: (data[4], data[5])),
checksum: UInt16(bytes: (data[6], data[7]))
)
// UDP payload = length - 8
let payloadLength = Int(header.length) - 8
self.header = header
self.payload = data.subdata(in: 8..<(8 + payloadLength))
}
}

View File

@ -40,6 +40,7 @@ struct SystemConfig {
"super_ip": superIp as NSObject,
"super_port": superPort as NSObject,
"stun_servers": stunServers as NSObject,
"remote_dns_server": superIp as NSObject,
"notice_port": noticePort as NSObject
]