fix Context

This commit is contained in:
anlicheng 2026-04-14 20:32:21 +08:00
parent f72a9acf24
commit 3219efbd76
2 changed files with 191 additions and 150 deletions

View File

@ -328,30 +328,9 @@ actor SDLContextActor {
} }
// //
let messageStream = udpHole.messageStream
let messageTask = Task.detached { let messageTask = Task.detached {
for await (remoteAddress, message) in udpHole.messageStream { await self.consumeUDPHoleMessages(stream: messageStream, localAddress: localAddress)
if Task.isCancelled {
break
}
switch message {
case .stunProbeReply(let probeReply):
await self.proberActor.handleProbeReply(localAddress: localAddress, reply: probeReply)
case .register(let register):
try? await self.handleRegister(remoteAddress: remoteAddress, register: register)
case .registerAck(let registerAck):
await self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
case .data(let data):
do {
try await self.handleHoleData(data: data)
} catch let err {
SDLLogger.log("[SDLContext] handleHoleData get err: \(err)")
}
case .stunReply(_):
//SDLLogger.shared.log("[SDLContext] get a stunReply: \(stunReply)")
()
}
}
} }
self.udpHole = udpHole self.udpHole = udpHole
@ -597,133 +576,6 @@ actor SDLContextActor {
} }
} }
private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws {
let networkAddr = config.networkAddress
SDLLogger.log("[SDLContext] register packet: \(register), network_address: \(networkAddr)")
// tun,
if register.dstMac == networkAddr.mac && register.networkID == networkAddr.networkId {
// ack
var registerAck = SDLRegisterAck()
registerAck.networkID = networkAddr.networkId
registerAck.srcMac = networkAddr.mac
registerAck.dstMac = register.srcMac
self.sendPeerPacket(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)
// , super-nodenatudpnat
let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
self.sessionManager.addSession(session: session)
} else {
SDLLogger.log("[SDLContext] didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)")
}
}
private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) {
// tun,
let networkAddr = config.networkAddress
if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId {
let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
self.sessionManager.addSession(session: session)
} else {
SDLLogger.log("[SDLContext] didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)")
}
}
private func handleHoleData(data: SDLData) async throws {
guard let dataCipher = self.dataCipher else {
return
}
let mac = LayerPacket.MacAddress(data: data.dstMac)
let networkAddr = config.networkAddress
guard (data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast()) else {
return
}
let decyptedData = try dataCipher.decrypt(cipherText: Data(data.data))
let layerPacket = try LayerPacket(layerData: decyptedData)
self.flowTracer.inc(num: decyptedData.count, type: .inbound)
// arp
switch layerPacket.type {
case .arp:
// arp
if let arpPacket = ARPPacket(data: layerPacket.data) {
if arpPacket.targetIP == networkAddr.ip {
switch arpPacket.opcode {
case .request:
SDLLogger.log("[SDLContext] get arp request packet")
let response = ARPPacket.arpResponse(for: arpPacket, mac: networkAddr.mac, ip: networkAddr.ip)
await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
case .response:
SDLLogger.log("[SDLContext] get arp response packet")
await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
}
} else {
SDLLogger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))")
}
} else {
SDLLogger.log("[SDLContext] get invalid arp packet")
}
case .ipv4:
// ip
guard let ipPacket = IPPacket(layerPacket.data) else {
return
}
//
let identitySnapshot = self.snapshotPublisher.current()
let ruleMap = identitySnapshot.lookup(data.identityID)
if true || self.checkPolicy(ipPacket: ipPacket, ruleMap: ruleMap) {
// debug
if ipPacket.header.source == 168428037 {
SDLLogger.log("[SDLContext] hole data: \(Array(ipPacket.data)), len: \(ipPacket.data.count)", for: .trace)
}
let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
self.provider.packetFlow.writePacketObjects([packet])
SDLLogger.log("[SDLContext] hole identity: \(data.identityID), allow, data count: \(ipPacket.data.count)", for: .trace)
}
else {
SDLLogger.log("[SDLContext] not found identity: \(data.identityID) ruleMap", for: .debug)
//
await self.identifyStore.policyRequest(srcIdentityId: data.identityID, dstIdentityId: self.config.identityId, using: self.quicClient)
}
default:
SDLLogger.log("[SDLContext] get invalid packet", for: .debug)
}
}
private func checkPolicy(ipPacket: IPPacket, ruleMap: IdentityRuleMap?) -> Bool {
//
if let reverseFlowSession = ipPacket.flowSession()?.reverse(),
self.flowSessionManager.hasSession(reverseFlowSession) {
self.flowSessionManager.updateSession(reverseFlowSession)
return true
}
//
let proto = ipPacket.header.proto
// 访
switch ipPacket.transportPacket {
case .tcp(let tcpPacket):
if let ruleMap, ruleMap.isAllow(proto: proto, port: tcpPacket.header.dstPort) {
return true
}
case .udp(let udpPacket):
if let ruleMap, ruleMap.isAllow(proto: proto, port: udpPacket.dstPort) {
return true
}
case .icmp(_):
return true
default:
return false
}
return false
}
// , 线packetFlow // , 线packetFlow
private func startReader() { private func startReader() {
// //
@ -911,6 +763,166 @@ actor SDLContextActor {
} }
} }
// Hole
extension SDLContextActor {
private func consumeUDPHoleMessages(stream: AsyncStream<(SocketAddress, SDLHoleMessage)>, localAddress: SocketAddress) async {
for await (remoteAddress, message) in stream {
if Task.isCancelled {
break
}
switch message.inboundMessage {
case .control(let controlMessage):
await self.handleHoleControlMessage(controlMessage, localAddress: localAddress, remoteAddress: remoteAddress)
case .data(let data):
try? await self.handleHoleData(data: data)
}
}
}
private func handleHoleControlMessage(_ message: SDLHoleControlMessage, localAddress: SocketAddress, remoteAddress: SocketAddress) async {
switch message {
case .stunProbeReply(let probeReply):
await self.proberActor.handleProbeReply(localAddress: localAddress, reply: probeReply)
case .register(let register):
try? self.handleRegister(remoteAddress: remoteAddress, register: register)
case .registerAck(let registerAck):
self.handleRegisterAck(remoteAddress: remoteAddress, registerAck: registerAck)
case .stunReply(_):
//SDLLogger.shared.log("[SDLContext] get a stunReply: \(stunReply)")
()
}
}
private func handleRegister(remoteAddress: SocketAddress, register: SDLRegister) throws {
let networkAddr = config.networkAddress
SDLLogger.log("[SDLContext] register packet: \(register), network_address: \(networkAddr)")
// tun,
if register.dstMac == networkAddr.mac && register.networkID == networkAddr.networkId {
// ack
var registerAck = SDLRegisterAck()
registerAck.networkID = networkAddr.networkId
registerAck.srcMac = networkAddr.mac
registerAck.dstMac = register.srcMac
self.sendPeerPacket(type: .registerAck, data: try registerAck.serializedData(), remoteAddress: remoteAddress)
// , super-nodenatudpnat
let session = Session(dstMac: register.srcMac, natAddress: remoteAddress)
self.sessionManager.addSession(session: session)
} else {
SDLLogger.log("[SDLContext] didReadRegister get a invalid packet, because dst_ip not matched: \(register.dstMac)")
}
}
private func handleRegisterAck(remoteAddress: SocketAddress, registerAck: SDLRegisterAck) {
// tun,
let networkAddr = config.networkAddress
if registerAck.dstMac == networkAddr.mac && registerAck.networkID == networkAddr.networkId {
let session = Session(dstMac: registerAck.srcMac, natAddress: remoteAddress)
self.sessionManager.addSession(session: session)
} else {
SDLLogger.log("[SDLContext] didReadRegisterAck get a invalid packet, because dst_mac not matched: \(registerAck.dstMac)")
}
}
private func handleHoleData(data: SDLData) async throws {
guard let dataCipher = self.dataCipher else {
return
}
let mac = LayerPacket.MacAddress(data: data.dstMac)
let networkAddr = config.networkAddress
guard (data.dstMac == networkAddr.mac || mac.isBroadcast() || mac.isMulticast()) else {
return
}
let decyptedData = try dataCipher.decrypt(cipherText: Data(data.data))
let layerPacket = try LayerPacket(layerData: decyptedData)
self.flowTracer.inc(num: decyptedData.count, type: .inbound)
// arp
switch layerPacket.type {
case .arp:
// arp
if let arpPacket = ARPPacket(data: layerPacket.data) {
if arpPacket.targetIP == networkAddr.ip {
switch arpPacket.opcode {
case .request:
SDLLogger.log("[SDLContext] get arp request packet")
let response = ARPPacket.arpResponse(for: arpPacket, mac: networkAddr.mac, ip: networkAddr.ip)
await self.routeLayerPacket(dstMac: arpPacket.senderMAC, type: .arp, data: response.marshal())
case .response:
SDLLogger.log("[SDLContext] get arp response packet")
await self.arpServer.append(ip: arpPacket.senderIP, mac: arpPacket.senderMAC)
}
} else {
SDLLogger.log("[SDLContext] get invalid arp packet: \(arpPacket), target_ip: \(SDLUtil.int32ToIp(arpPacket.targetIP)), net ip: \(SDLUtil.int32ToIp(networkAddr.ip))")
}
} else {
SDLLogger.log("[SDLContext] get invalid arp packet")
}
case .ipv4:
// ip
guard let ipPacket = IPPacket(layerPacket.data) else {
return
}
//
let identitySnapshot = self.snapshotPublisher.current()
let ruleMap = identitySnapshot.lookup(data.identityID)
if true || self.checkPolicy(ipPacket: ipPacket, ruleMap: ruleMap) {
// debug
if ipPacket.header.source == 168428037 {
SDLLogger.log("[SDLContext] hole data: \(Array(ipPacket.data)), len: \(ipPacket.data.count)", for: .trace)
}
let packet = NEPacket(data: ipPacket.data, protocolFamily: 2)
self.provider.packetFlow.writePacketObjects([packet])
SDLLogger.log("[SDLContext] hole identity: \(data.identityID), allow, data count: \(ipPacket.data.count)", for: .trace)
}
else {
SDLLogger.log("[SDLContext] not found identity: \(data.identityID) ruleMap", for: .debug)
//
await self.identifyStore.policyRequest(srcIdentityId: data.identityID, dstIdentityId: self.config.identityId, using: self.quicClient)
}
default:
SDLLogger.log("[SDLContext] get invalid packet", for: .debug)
}
}
private func checkPolicy(ipPacket: IPPacket, ruleMap: IdentityRuleMap?) -> Bool {
//
if let reverseFlowSession = ipPacket.flowSession()?.reverse(),
self.flowSessionManager.hasSession(reverseFlowSession) {
self.flowSessionManager.updateSession(reverseFlowSession)
return true
}
//
let proto = ipPacket.header.proto
// 访
switch ipPacket.transportPacket {
case .tcp(let tcpPacket):
if let ruleMap, ruleMap.isAllow(proto: proto, port: tcpPacket.header.dstPort) {
return true
}
case .udp(let udpPacket):
if let ruleMap, ruleMap.isAllow(proto: proto, port: udpPacket.dstPort) {
return true
}
case .icmp(_):
return true
default:
return false
}
return false
}
}
private extension UInt32 { private extension UInt32 {
// ip // ip
func asIpAddress() -> String { func asIpAddress() -> String {

View File

@ -124,6 +124,35 @@ enum SDLHoleMessage {
case stunReply(SDLStunReply) case stunReply(SDLStunReply)
} }
enum SDLHoleControlMessage {
case register(SDLRegister)
case registerAck(SDLRegisterAck)
case stunProbeReply(SDLStunProbeReply)
case stunReply(SDLStunReply)
}
enum SDLHoleInboundMessage {
case control(SDLHoleControlMessage)
case data(SDLData)
}
extension SDLHoleMessage {
var inboundMessage: SDLHoleInboundMessage {
switch self {
case .data(let data):
return .data(data)
case .register(let register):
return .control(.register(register))
case .registerAck(let registerAck):
return .control(.registerAck(registerAck))
case .stunProbeReply(let stunProbeReply):
return .control(.stunProbeReply(stunProbeReply))
case .stunReply(let stunReply):
return .control(.stunReply(stunReply))
}
}
}
enum SDLQUICInboundMessage { enum SDLQUICInboundMessage {
// //
case welcome(SDLWelcome) case welcome(SDLWelcome)