diff --git a/apps/dns_proxy/src/dns_handler.erl b/apps/dns_proxy/src/dns_handler.erl index cb5e5ce..45505ea 100644 --- a/apps/dns_proxy/src/dns_handler.erl +++ b/apps/dns_proxy/src/dns_handler.erl @@ -12,34 +12,37 @@ -behaviour(gen_server). -include_lib("dns_erlang/include/dns.hrl"). +-include_lib("pkt/include/pkt.hrl"). -include("dns_proxy.hrl"). %% API --export([start_link/4]). +-export([start_link/0]). %% gen_server callbacks -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --export([handle/1]). +-export([handle/5, handle_ip_packet/5]). -define(SERVER, ?MODULE). -define(RESOLVER_POOL, dns_resolver_pool). --record(state, { - socket, - src_ip, - src_port, - packet -}). +%% 协议部分 +-define(TCP_PROTOCOL, 6). +-define(UDP_PROTOCOL, 17). + +-record(state, {}). %%%=================================================================== %%% API %%%=================================================================== -start_link(Sock, Ip, Port, Packet) -> - gen_server:start_link(?MODULE, [Sock, Ip, Port, Packet], []). +start_link() -> + gen_server:start_link(?MODULE, [], []). -handle(Pid) when is_pid(Pid) -> - gen_server:cast(Pid, handle). +handle_ip_packet(Pid, Sock, SrcIp, SrcPort, Packet) when is_pid(Pid) -> + gen_server:cast(Pid, {handle_ip_packet, Sock, SrcIp, SrcPort, Packet}). + +handle(Pid, Sock, SrcIp, SrcPort, Packet) when is_pid(Pid) -> + gen_server:cast(Pid, {handle, Sock, SrcIp, SrcPort, Packet}). %%%=================================================================== %%% gen_server callbacks @@ -50,8 +53,8 @@ handle(Pid) when is_pid(Pid) -> -spec(init(Args :: term()) -> {ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} | {stop, Reason :: term()} | ignore). -init([Sock, SrcIp, SrcPort, Packet]) -> - {ok, #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}}. +init([]) -> + {ok, #state{}}. %% @private %% @doc Handling call messages @@ -72,7 +75,23 @@ handle_call(_Request, _From, State = #state{}) -> {noreply, NewState :: #state{}} | {noreply, NewState :: #state{}, timeout() | hibernate} | {stop, Reason :: term(), NewState :: #state{}}). -handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}) -> +handle_cast({handle_ip_packet, Sock, SrcIp, SrcPort, IpPacket}, State) -> + {#ipv4{saddr = ReqSAddr, daddr = ReqDAddr, p = Protocol}, ReqIpPayload} = pkt:ipv4(IpPacket), + case Protocol =:= ?UDP_PROTOCOL of + true -> + {#udp{sport = ReqSPort, dport = ReqDPort}, UdpPayload} = pkt:udp(ReqIpPayload), + case resolver(UdpPayload) of + {ok, DnsResp} -> + RespIpPacket = build_ip_packet(ReqDAddr, ReqSAddr, ReqDPort, ReqSPort, DnsResp), + gen_udp:send(Sock, SrcIp, SrcPort, RespIpPacket); + {error, Reason} -> + lager:debug("[dns_handler] resolver get error: ~p", [Reason]) + end; + false -> + lager:debug("[dns_handler] resolver invalid protocol: ~p", [Protocol]) + end, + {stop, normal, State}; +handle_cast({handle, Sock, SrcIp, SrcPort, Packet}, State) -> case resolver(Packet) of {ok, Resp} -> gen_udp:send(Sock, SrcIp, SrcPort, Resp); @@ -209,4 +228,30 @@ search_inbuilt_domain(QName) when is_binary(QName) -> {ok, {192, 168, 1, 101}}; false -> error - end. \ No newline at end of file + end. + +-spec build_ip_packet(SAddr :: inet:ip4_address(), DAddr :: inet:ip4_address(), SPort :: integer(), DPort :: integer(), Payload :: binary()) -> IpPacket :: binary(). +build_ip_packet(SAddr, DAddr, SPort, DPort, Payload) when is_integer(SPort), is_integer(DPort), is_binary(Payload) -> + ULen = 8 + byte_size(Payload), + RespUdpHeader = pkt:udp(#udp{ + sport = SPort, + dport = DPort, + ulen = ULen, + sum = dns_utils:udp_checksum(SAddr, DAddr, SPort, DPort, Payload) + }), + IpPayload = <>, + + IpPacket0 = #ipv4{ + len = 20 + ULen, + ttl = 64, + off = 0, + mf = 0, + sum = 0, + saddr = SAddr, + daddr = DAddr, + opt = <<>> + }, + IpCheckSum = dns_utils:ip_checksum(IpPacket0), + IpHeader = pkt:ipv4(IpPacket0#ipv4{sum = IpCheckSum}), + + <>. \ No newline at end of file diff --git a/apps/dns_proxy/src/dns_handler_sup.erl b/apps/dns_proxy/src/dns_handler_sup.erl index 6fad53a..7c7e023 100644 --- a/apps/dns_proxy/src/dns_handler_sup.erl +++ b/apps/dns_proxy/src/dns_handler_sup.erl @@ -16,7 +16,7 @@ %% Supervisor callbacks -export([init/1]). --export([start_handler/4]). +-export([start_handler/0]). -define(SERVER, ?MODULE). @@ -60,8 +60,8 @@ init([]) -> %%% Internal functions %%%=================================================================== -start_handler(Sock, Ip, Port, Packet) -> - case supervisor:start_child(?MODULE, [Sock, Ip, Port, Packet]) of +start_handler() -> + case supervisor:start_child(?MODULE, []) of {ok, Pid} -> {ok, Pid}; {error, {already_started, Pid}} -> diff --git a/apps/dns_proxy/src/dns_server.erl b/apps/dns_proxy/src/dns_server.erl index d231d11..dbb9d69 100644 --- a/apps/dns_proxy/src/dns_server.erl +++ b/apps/dns_proxy/src/dns_server.erl @@ -17,9 +17,9 @@ loop(Sock) -> receive {udp, Sock, Ip, Port, Packet} -> lager:debug("[dns_server] ip: ~p, get a packet: ~p", [{Ip, Port}, Packet]), - case dns_handler_sup:start_handler(Sock, Ip, Port, Packet) of + case dns_handler_sup:start_handler() of {ok, HandlerPid} -> - dns_handler:handle(HandlerPid); + dns_handler:handle_ip_packet(HandlerPid, Sock, Ip, Port, Packet); Error -> lager:debug("[dns_server] start handler get error: ~p", [Error]) end, diff --git a/apps/dns_proxy/src/dns_utils.erl b/apps/dns_proxy/src/dns_utils.erl index 7cef0ef..69f5ea2 100644 --- a/apps/dns_proxy/src/dns_utils.erl +++ b/apps/dns_proxy/src/dns_utils.erl @@ -9,8 +9,11 @@ -module(dns_utils). -author("anlicheng"). +-include_lib("pkt/include/pkt.hrl"). + %% API --export([ends_with/2, parse_address/1]). +-export([ends_with/2, parse_address/1, checksum/1, udp_checksum/5, ip_checksum/1]). +-export([test/0]). -spec ends_with(Bin :: binary(), Suffix :: binary()) -> boolean(). ends_with(Bin, Suffix) when is_binary(Bin), is_binary(Suffix) -> @@ -26,3 +29,102 @@ parse_address(Ip = {Ip0, Ip1, Ip2, Ip3}) when is_integer(Ip0), is_integer(Ip1), {ok, Ip}; parse_address(Bin) when is_binary(Bin) -> inet:parse_address(binary_to_list(Bin)). + + +%%-------------------------------------------------------------------- +%% @doc +%% Calculate 16-bit one's-complement checksum. +%% +%% Input: +%% Bin :: binary() +%% +%% Output: +%% Checksum :: 0..16#FFFF +%% +%% Usage: +%% Checksum = checksum(Bin). +%% +%% Notes: +%% - Bin is treated as big-endian 16-bit words +%% - If Bin length is odd, a zero byte is padded +%%-------------------------------------------------------------------- +-spec checksum(binary()) -> non_neg_integer(). +checksum(Bin) when is_binary(Bin) -> + Sum = checksum_sum(Bin, 0), + %% fold carry bits + Folded = fold16(Sum), + %% one's complement + (bnot Folded) band 16#FFFF. +checksum_sum(<<>>, Acc) -> + Acc; +checksum_sum(<>, Acc) -> + checksum_sum(Rest, Acc + Word); +checksum_sum(<>, Acc) -> + %% odd length: pad low byte with zero + checksum_sum(<<>>, Acc + (Byte bsl 8)). + +fold16(S) when S > 16#FFFF -> + fold16((S band 16#FFFF) + (S bsr 16)); +fold16(S) -> + S. + +-spec udp_checksum(SAddr :: inet:ip4_address(), DAddr :: inet:ip4_address(), SPort :: integer(), DPort :: integer(), UDPPayload :: binary()) -> non_neg_integer(). +udp_checksum({SA1, SA2, SA3, SA4}, {DA1, DA2, DA3, DA4}, SPort, DPort, UDPPayload) when is_integer(SPort), is_integer(DPort), is_binary(UDPPayload) -> + ULen = 8 + byte_size(UDPPayload), + PseudoHeader = <>, + UDPHeader = <>, + CheckSum = checksum(<>), + case CheckSum of + 0 -> + 16#FFFF; + _ -> + CheckSum + end. + +-spec ip_checksum(Ipv4 :: #ipv4{}) -> non_neg_integer(). +ip_checksum(#ipv4{hl = HL, tos = ToS, len = Len, + id = Id, df = DF, mf = MF, + off = Off, ttl = TTL, p = P, + saddr = {SA1, SA2, SA3, SA4}, + daddr = {DA1, DA2, DA3, DA4}, + opt = Opt}) -> + + IPBinForChecksum = + <<4:4, HL:4, %% Version=4 + IHL + ToS:8, %% Type of Service + Len:16/big, %% Total Length + Id:16/big, %% Identification + DF:1, MF:1, Off:14, %% Flags + Fragment offset + TTL:8, %% TTL + P:8, %% Protocol + 0:16, %% checksum field set to 0 for calculation + SA1:8, SA2:8, SA3:8, SA4:8, %% Source IP + DA1:8, DA2:8, DA3:8, DA4:8, %% Dest IP + Opt/binary>>, %% Options (可选) + CheckSum = checksum(IPBinForChecksum), + case CheckSum of + 0 -> + 16#FFFF; + _ -> + CheckSum + end. + +test() -> + Bin = <<69,0,0,77,48,179,0,0,64,17,28,168,100,123,0,2,100,100,100,100,252,230,0,53,0,57,6,92,152,24,1,0,0,1,0,0,0,0,0,0,2,100,98,7,95,100,110,115,45,115,100,4,95,117,100,112,8,112,117,110,99,104,110,101,116,2,116,115,3,110,101,116,0,0,12,0,1>>, + + {IPPacket = #ipv4{ + saddr = SAddr, + daddr = DAddr, + sum = IpSum + }, UdpPacket} = pkt:ipv4(Bin), + + {UDP = #udp{sport = SPort, dport = DPort, sum = CheckSum}, UDPPayload} = pkt:udp(UdpPacket), + + X = udp_checksum(SAddr, DAddr, SPort, DPort, UDPPayload), + + lager:debug("ip_sum: ~p, y?: ~p, udp: ~p, checkSum: ~p, X is: ~p", [IpSum, ip_checksum(IPPacket), UDP, CheckSum, X]), + + dns:decode_message(UDPPayload). \ No newline at end of file