支持基于ip包的dns请求解析

This commit is contained in:
anlicheng 2025-12-13 14:38:02 +08:00
parent a32f00a1d8
commit 8206710798
4 changed files with 169 additions and 22 deletions

View File

@ -12,34 +12,37 @@
-behaviour(gen_server).
-include_lib("dns_erlang/include/dns.hrl").
-include_lib("pkt/include/pkt.hrl").
-include("dns_proxy.hrl").
%% API
-export([start_link/4]).
-export([start_link/0]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-export([handle/1]).
-export([handle/5, handle_ip_packet/5]).
-define(SERVER, ?MODULE).
-define(RESOLVER_POOL, dns_resolver_pool).
-record(state, {
socket,
src_ip,
src_port,
packet
}).
%%
-define(TCP_PROTOCOL, 6).
-define(UDP_PROTOCOL, 17).
-record(state, {}).
%%%===================================================================
%%% API
%%%===================================================================
start_link(Sock, Ip, Port, Packet) ->
gen_server:start_link(?MODULE, [Sock, Ip, Port, Packet], []).
start_link() ->
gen_server:start_link(?MODULE, [], []).
handle(Pid) when is_pid(Pid) ->
gen_server:cast(Pid, handle).
handle_ip_packet(Pid, Sock, SrcIp, SrcPort, Packet) when is_pid(Pid) ->
gen_server:cast(Pid, {handle_ip_packet, Sock, SrcIp, SrcPort, Packet}).
handle(Pid, Sock, SrcIp, SrcPort, Packet) when is_pid(Pid) ->
gen_server:cast(Pid, {handle, Sock, SrcIp, SrcPort, Packet}).
%%%===================================================================
%%% gen_server callbacks
@ -50,8 +53,8 @@ handle(Pid) when is_pid(Pid) ->
-spec(init(Args :: term()) ->
{ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} |
{stop, Reason :: term()} | ignore).
init([Sock, SrcIp, SrcPort, Packet]) ->
{ok, #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}}.
init([]) ->
{ok, #state{}}.
%% @private
%% @doc Handling call messages
@ -72,7 +75,23 @@ handle_call(_Request, _From, State = #state{}) ->
{noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), NewState :: #state{}}).
handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}) ->
handle_cast({handle_ip_packet, Sock, SrcIp, SrcPort, IpPacket}, State) ->
{#ipv4{saddr = ReqSAddr, daddr = ReqDAddr, p = Protocol}, ReqIpPayload} = pkt:ipv4(IpPacket),
case Protocol =:= ?UDP_PROTOCOL of
true ->
{#udp{sport = ReqSPort, dport = ReqDPort}, UdpPayload} = pkt:udp(ReqIpPayload),
case resolver(UdpPayload) of
{ok, DnsResp} ->
RespIpPacket = build_ip_packet(ReqDAddr, ReqSAddr, ReqDPort, ReqSPort, DnsResp),
gen_udp:send(Sock, SrcIp, SrcPort, RespIpPacket);
{error, Reason} ->
lager:debug("[dns_handler] resolver get error: ~p", [Reason])
end;
false ->
lager:debug("[dns_handler] resolver invalid protocol: ~p", [Protocol])
end,
{stop, normal, State};
handle_cast({handle, Sock, SrcIp, SrcPort, Packet}, State) ->
case resolver(Packet) of
{ok, Resp} ->
gen_udp:send(Sock, SrcIp, SrcPort, Resp);
@ -210,3 +229,29 @@ search_inbuilt_domain(QName) when is_binary(QName) ->
false ->
error
end.
-spec build_ip_packet(SAddr :: inet:ip4_address(), DAddr :: inet:ip4_address(), SPort :: integer(), DPort :: integer(), Payload :: binary()) -> IpPacket :: binary().
build_ip_packet(SAddr, DAddr, SPort, DPort, Payload) when is_integer(SPort), is_integer(DPort), is_binary(Payload) ->
ULen = 8 + byte_size(Payload),
RespUdpHeader = pkt:udp(#udp{
sport = SPort,
dport = DPort,
ulen = ULen,
sum = dns_utils:udp_checksum(SAddr, DAddr, SPort, DPort, Payload)
}),
IpPayload = <<RespUdpHeader/binary, Payload/binary>>,
IpPacket0 = #ipv4{
len = 20 + ULen,
ttl = 64,
off = 0,
mf = 0,
sum = 0,
saddr = SAddr,
daddr = DAddr,
opt = <<>>
},
IpCheckSum = dns_utils:ip_checksum(IpPacket0),
IpHeader = pkt:ipv4(IpPacket0#ipv4{sum = IpCheckSum}),
<<IpHeader/binary, IpPayload/binary>>.

View File

@ -16,7 +16,7 @@
%% Supervisor callbacks
-export([init/1]).
-export([start_handler/4]).
-export([start_handler/0]).
-define(SERVER, ?MODULE).
@ -60,8 +60,8 @@ init([]) ->
%%% Internal functions
%%%===================================================================
start_handler(Sock, Ip, Port, Packet) ->
case supervisor:start_child(?MODULE, [Sock, Ip, Port, Packet]) of
start_handler() ->
case supervisor:start_child(?MODULE, []) of
{ok, Pid} ->
{ok, Pid};
{error, {already_started, Pid}} ->

View File

@ -17,9 +17,9 @@ loop(Sock) ->
receive
{udp, Sock, Ip, Port, Packet} ->
lager:debug("[dns_server] ip: ~p, get a packet: ~p", [{Ip, Port}, Packet]),
case dns_handler_sup:start_handler(Sock, Ip, Port, Packet) of
case dns_handler_sup:start_handler() of
{ok, HandlerPid} ->
dns_handler:handle(HandlerPid);
dns_handler:handle_ip_packet(HandlerPid, Sock, Ip, Port, Packet);
Error ->
lager:debug("[dns_server] start handler get error: ~p", [Error])
end,

View File

@ -9,8 +9,11 @@
-module(dns_utils).
-author("anlicheng").
-include_lib("pkt/include/pkt.hrl").
%% API
-export([ends_with/2, parse_address/1]).
-export([ends_with/2, parse_address/1, checksum/1, udp_checksum/5, ip_checksum/1]).
-export([test/0]).
-spec ends_with(Bin :: binary(), Suffix :: binary()) -> boolean().
ends_with(Bin, Suffix) when is_binary(Bin), is_binary(Suffix) ->
@ -26,3 +29,102 @@ parse_address(Ip = {Ip0, Ip1, Ip2, Ip3}) when is_integer(Ip0), is_integer(Ip1),
{ok, Ip};
parse_address(Bin) when is_binary(Bin) ->
inet:parse_address(binary_to_list(Bin)).
%%--------------------------------------------------------------------
%% @doc
%% Calculate 16-bit one's-complement checksum.
%%
%% Input:
%% Bin :: binary()
%%
%% Output:
%% Checksum :: 0..16#FFFF
%%
%% Usage:
%% Checksum = checksum(Bin).
%%
%% Notes:
%% - Bin is treated as big-endian 16-bit words
%% - If Bin length is odd, a zero byte is padded
%%--------------------------------------------------------------------
-spec checksum(binary()) -> non_neg_integer().
checksum(Bin) when is_binary(Bin) ->
Sum = checksum_sum(Bin, 0),
%% fold carry bits
Folded = fold16(Sum),
%% one's complement
(bnot Folded) band 16#FFFF.
checksum_sum(<<>>, Acc) ->
Acc;
checksum_sum(<<Word:16/big, Rest/binary>>, Acc) ->
checksum_sum(Rest, Acc + Word);
checksum_sum(<<Byte:8>>, Acc) ->
%% odd length: pad low byte with zero
checksum_sum(<<>>, Acc + (Byte bsl 8)).
fold16(S) when S > 16#FFFF ->
fold16((S band 16#FFFF) + (S bsr 16));
fold16(S) ->
S.
-spec udp_checksum(SAddr :: inet:ip4_address(), DAddr :: inet:ip4_address(), SPort :: integer(), DPort :: integer(), UDPPayload :: binary()) -> non_neg_integer().
udp_checksum({SA1, SA2, SA3, SA4}, {DA1, DA2, DA3, DA4}, SPort, DPort, UDPPayload) when is_integer(SPort), is_integer(DPort), is_binary(UDPPayload) ->
ULen = 8 + byte_size(UDPPayload),
PseudoHeader = <<SA1, SA2, SA3, SA4,
DA1, DA2, DA3, DA4,
0:8, 17:8,
ULen:16>>,
UDPHeader = <<SPort:16, DPort:16, ULen:16, 0:16>>,
CheckSum = checksum(<<PseudoHeader/binary, UDPHeader/binary, UDPPayload/binary>>),
case CheckSum of
0 ->
16#FFFF;
_ ->
CheckSum
end.
-spec ip_checksum(Ipv4 :: #ipv4{}) -> non_neg_integer().
ip_checksum(#ipv4{hl = HL, tos = ToS, len = Len,
id = Id, df = DF, mf = MF,
off = Off, ttl = TTL, p = P,
saddr = {SA1, SA2, SA3, SA4},
daddr = {DA1, DA2, DA3, DA4},
opt = Opt}) ->
IPBinForChecksum =
<<4:4, HL:4, %% Version=4 + IHL
ToS:8, %% Type of Service
Len:16/big, %% Total Length
Id:16/big, %% Identification
DF:1, MF:1, Off:14, %% Flags + Fragment offset
TTL:8, %% TTL
P:8, %% Protocol
0:16, %% checksum field set to 0 for calculation
SA1:8, SA2:8, SA3:8, SA4:8, %% Source IP
DA1:8, DA2:8, DA3:8, DA4:8, %% Dest IP
Opt/binary>>, %% Options ()
CheckSum = checksum(IPBinForChecksum),
case CheckSum of
0 ->
16#FFFF;
_ ->
CheckSum
end.
test() ->
Bin = <<69,0,0,77,48,179,0,0,64,17,28,168,100,123,0,2,100,100,100,100,252,230,0,53,0,57,6,92,152,24,1,0,0,1,0,0,0,0,0,0,2,100,98,7,95,100,110,115,45,115,100,4,95,117,100,112,8,112,117,110,99,104,110,101,116,2,116,115,3,110,101,116,0,0,12,0,1>>,
{IPPacket = #ipv4{
saddr = SAddr,
daddr = DAddr,
sum = IpSum
}, UdpPacket} = pkt:ipv4(Bin),
{UDP = #udp{sport = SPort, dport = DPort, sum = CheckSum}, UDPPayload} = pkt:udp(UdpPacket),
X = udp_checksum(SAddr, DAddr, SPort, DPort, UDPPayload),
lager:debug("ip_sum: ~p, y?: ~p, udp: ~p, checkSum: ~p, X is: ~p", [IpSum, ip_checksum(IPPacket), UDP, CheckSum, X]),
dns:decode_message(UDPPayload).