sdlan/apps/sdlan/src/dns_proxy/dns_handler.erl
2026-02-11 15:42:44 +08:00

263 lines
10 KiB
Erlang

%%%-------------------------------------------------------------------
%%% @author anlicheng
%%% @copyright (C) 2025, <COMPANY>
%%% @doc
%%%
%%% @end
%%% Created : 03. 12月 2025 23:00
%%%-------------------------------------------------------------------
-module(dns_handler).
-author("anlicheng").
-behaviour(gen_server).
-include_lib("dns_erlang/include/dns.hrl").
-include_lib("pkt/include/pkt.hrl").
-include("dns_proxy.hrl").
%% API
-export([start_link/0]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-export([handle_ip_packet/5]).
%% Alias for this module's name (expands to ?MODULE).
-define(SERVER, ?MODULE).
%% Name of the poolboy pool whose workers forward queries upstream.
-define(RESOLVER_POOL, dns_resolver_pool).
%% Protocol section: values compared against the IPv4 header protocol field.
-define(TCP_PROTOCOL, 6).
-define(UDP_PROTOCOL, 17).
%% gen_server state; it carries no fields.
-record(state, {}).
%%%===================================================================
%%% API
%%%===================================================================
%% @doc Starts an unregistered handler process linked to the caller.
-spec start_link() -> {ok, pid()} | ignore | {error, term()}.
start_link() ->
    InitArgs = [],
    Options = [],
    gen_server:start_link(?SERVER, InitArgs, Options).
%% @doc Asynchronously hands one packet to the handler process Pid.
%% Sock, SrcIp and SrcPort are passed through unchanged; the handler uses
%% them as the socket and destination when it sends its reply.
-spec handle_ip_packet(Pid :: pid(), Sock :: term(), SrcIp :: term(),
                       SrcPort :: term(), Packet :: term()) -> ok.
handle_ip_packet(Pid, Sock, SrcIp, SrcPort, Packet) when is_pid(Pid) ->
    Request = {handle_ip_packet, Sock, SrcIp, SrcPort, Packet},
    gen_server:cast(Pid, Request).
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
%% @private
%% @doc gen_server init callback. Accepts only the empty argument list and
%% starts with an empty state record.
-spec init(Args :: term()) ->
    {ok, State :: #state{}}
    | {ok, State :: #state{}, timeout() | hibernate}
    | {stop, Reason :: term()}
    | ignore.
init([]) ->
    InitialState = #state{},
    {ok, InitialState}.
%% @private
%% @doc gen_server call callback. No synchronous requests are defined, so
%% every call is answered with ok and the state is left untouched.
-spec handle_call(Request :: term(), From :: {pid(), Tag :: term()}, State :: #state{}) ->
    {reply, Reply :: term(), NewState :: #state{}}
    | {reply, Reply :: term(), NewState :: #state{}, timeout() | hibernate}
    | {noreply, NewState :: #state{}}
    | {noreply, NewState :: #state{}, timeout() | hibernate}
    | {stop, Reason :: term(), Reply :: term(), NewState :: #state{}}
    | {stop, Reason :: term(), NewState :: #state{}}.
handle_call(_Request, _From, #state{} = State) ->
    Reply = ok,
    {reply, Reply, State}.
%% @private
%% @doc gen_server cast callback. Handles exactly one captured IPv4 packet
%% and then stops the process with reason normal.
%%
%% The packet is decoded with pkt:ipv4/1. For UDP, the payload is resolved
%% as a DNS query and the answer is wrapped in a new IPv4/UDP packet whose
%% addresses and ports are the request's, swapped. That packet is sent
%% through Sock to SrcIp:SrcPort. Any other protocol is only logged.
%% A packet that pkt cannot decode crashes the process (badmatch or an
%% error raised inside pkt).
-spec handle_cast(Request :: term(), State :: #state{}) ->
    {noreply, NewState :: #state{}}
    | {noreply, NewState :: #state{}, timeout() | hibernate}
    | {stop, Reason :: term(), NewState :: #state{}}.
handle_cast({handle_ip_packet, Sock, SrcIp, SrcPort, IpPacket}, State) ->
    {#ipv4{saddr = ReqSAddr, daddr = ReqDAddr, p = Protocol}, ReqIpPayload} = pkt:ipv4(IpPacket),
    case Protocol of
        ?UDP_PROTOCOL ->
            {#udp{sport = ReqSPort, dport = ReqDPort}, UdpPayload} = pkt:udp(ReqIpPayload),
            case resolver(UdpPayload) of
                {ok, DnsResp} ->
                    %% Source and destination are swapped to address the reply.
                    RespIpPacket = build_ip_packet(ReqDAddr, ReqSAddr, ReqDPort, ReqSPort, DnsResp),
                    send_response(Sock, SrcIp, SrcPort, RespIpPacket);
                {error, Reason} ->
                    logger:notice("[dns_handler] resolver get error: ~p", [Reason])
            end;
        _Other ->
            logger:notice("[dns_handler] resolver invalid protocol: ~p", [Protocol])
    end,
    {stop, normal, State}.

%% Sends the reply packet and logs a failed send instead of dropping the
%% error silently. The handler is still best-effort: a failed send does not
%% crash the process.
send_response(Sock, DstIp, DstPort, RespIpPacket) ->
    case gen_udp:send(Sock, DstIp, DstPort, RespIpPacket) of
        ok ->
            ok;
        {error, SendReason} ->
            logger:notice("[dns_handler] send response get error: ~p", [SendReason])
    end.
%% @private
%% @doc gen_server info callback. Messages that are neither calls nor casts
%% are dropped and the state is kept as is.
-spec handle_info(Info :: timeout() | term(), State :: #state{}) ->
    {noreply, NewState :: #state{}}
    | {noreply, NewState :: #state{}, timeout() | hibernate}
    | {stop, Reason :: term(), NewState :: #state{}}.
handle_info(_Ignored, CurrentState) ->
    {noreply, CurrentState}.
%% @private
%% @doc gen_server terminate callback. There is nothing to release, so it
%% just returns ok.
-spec terminate(
    Reason :: (normal | shutdown | {shutdown, term()} | term()),
    State :: #state{}
) -> term().
terminate(_Reason, #state{}) ->
    ok.
%% @private
%% @doc gen_server code_change callback. The state needs no conversion and
%% is returned unchanged.
-spec code_change(OldVsn :: term() | {down, term()}, State :: #state{}, Extra :: term()) ->
    {ok, NewState :: #state{}} | {error, Reason :: term()}.
code_change(_OldVsn, #state{} = Unchanged, _Extra) ->
    {ok, Unchanged}.
%%%===================================================================
%%% Internal functions
%%%===================================================================
%% Decodes the raw DNS query and resolves it via resolver0/2. Returns the
%% encoded response on success, or an error tuple when the payload does not
%% decode to a single-question DNS message.
-spec resolver(Packet :: binary()) -> {ok, Resp :: binary()} | {error, Reason :: any()}.
resolver(Packet) when is_binary(Packet) ->
    Decoded = dns:decode_message(Packet),
    resolver0(Packet, Decoded).
%% Resolves one decoded query and returns the encoded DNS response.
%%
%% Resolution order:
%%   1. hostnames registered in sdlan_hostname_regedit,
%%   2. names that sdlan_domain_regedit:maybe_domain/1 accepts as internal
%%      (answered with NXDOMAIN, since step 1 found no host),
%%   3. the local dns_cache,
%%   4. an upstream resolver taken from the resolver pool.
%%
%% Anything that is not a decoded message with exactly one question is
%% logged and returned as an error tuple.
-spec resolver0(Packet :: binary(), Decoded :: term()) ->
    {ok, Resp :: binary()} | {error, Reason :: term()}.
resolver0(Packet, QueryMsg = #dns_message{qc = 1, questions = [Question = #dns_query{name = QName} | _]}) ->
    %% Check whether this is a built-in (registered) hostname.
    case sdlan_hostname_regedit:lookup(QName) of
        {ok, Ip} ->
            {ok, dns:encode_message(build_inbuilt_response(QueryMsg, Question, Ip))};
        error ->
            resolve_unregistered(Packet, QueryMsg, Question)
    end;
resolver0(_, Error) ->
    logger:warning("[dns_handler] decode dns_query get error: ~p", [Error]),
    {error, Error}.

%% Builds the reply for a registered hostname. The registry entry is encoded
%% as A record data, so it is only a valid answer to an A question. Any other
%% question type gets an empty NOERROR reply (NODATA) instead of a record
%% whose type and data disagree.
build_inbuilt_response(QueryMsg, #dns_query{name = QName, type = ?DNS_TYPE_A, class = QClass}, Ip) ->
    logger:debug("[dns_handler] punchnet inbuilt qnanme: ~p, ip: ~p", [QName, Ip]),
    Answer = #dns_rr{
        name = QName,
        type = ?DNS_TYPE_A,
        class = QClass,
        ttl = 300,
        data = #dns_rrdata_a{ip = Ip}
    },
    QueryMsg#dns_message{
        qr = true,
        ra = true,
        anc = 1,
        auc = 0,
        adc = 0,
        answers = [Answer],
        authority = [],
        additional = []
    };
build_inbuilt_response(QueryMsg, #dns_query{name = QName, type = QType}, _Ip) ->
    logger:debug("[dns_handler] punchnet inbuilt qname: ~p, no data for qtype: ~p", [QName, QType]),
    QueryMsg#dns_message{
        qr = true,
        ra = true,
        rc = ?DNS_RCODE_NOERROR,
        anc = 0,
        auc = 0,
        adc = 0,
        answers = [],
        authority = [],
        additional = []
    }.

%% Handles a name with no registered host. Names under an internal domain
%% suffix are answered with NXDOMAIN; everything else is resolved externally.
resolve_unregistered(Packet, QueryMsg, Question = #dns_query{name = QName}) ->
    %% Does the name hit an internal domain suffix?
    case sdlan_domain_regedit:maybe_domain(QName) of
        true ->
            logger:debug("[dns_handler] punchnet inbuilt qnanme: ~p, nxdomain", [QName]),
            {ok, dns:encode_message(build_nxdomain_response(QueryMsg))};
        false ->
            resolve_external(Packet, QueryMsg, Question)
    end.

%% Answers from the cache when possible, otherwise forwards the original
%% packet upstream and waits for the reply.
resolve_external(Packet, QueryMsg, Question = #dns_query{name = QName}) ->
    case dns_cache:lookup(Question) of
        {hit, Cache} ->
            logger:debug("[dns_handler] qname: ~p, hit cache answers: ~p", [QName, Cache#dns_cache.answers]),
            {ok, dns:encode_message(build_response(QueryMsg, Cache))};
        miss ->
            %% The reference pairs this request with its reply message.
            Ref = make_ref(),
            forward_to_upstream(Ref, Packet, QueryMsg),
            logger:debug("[dns_handler] cache is miss, forward_to_upstream"),
            await_upstream_reply(Ref, QueryMsg, Question)
    end.

%% Waits up to five seconds for the upstream reply tagged with Ref. A reply
%% that decodes is cached and relayed unchanged. A timeout or an undecodable
%% reply is reported as SERVFAIL: the failure says nothing about whether the
%% name exists, so NXDOMAIN would be a false answer.
await_upstream_reply(Ref, QueryMsg, Question) ->
    receive
        {dns_resolver_reply, Ref, Resp} ->
            case dns:decode_message(Resp) of
                RespMsg = #dns_message{answers = Answers} ->
                    logger:debug("[dns_handler] get a response answers: ~p", [Answers]),
                    dns_cache:insert(Question, RespMsg),
                    {ok, Resp};
                Error ->
                    logger:debug("[dns_handler] parse reply get error: ~p", [Error]),
                    {ok, dns:encode_message(build_servfail_response(QueryMsg))}
            end
    after 5000 ->
        logger:debug("[dns_handler] forward_to_upstream timeout"),
        {ok, dns:encode_message(build_servfail_response(QueryMsg))}
    end.

%% Builds an empty, non-authoritative SERVFAIL reply for QueryMsg.
build_servfail_response(QueryMsg) ->
    QueryMsg#dns_message{
        qr = true,
        aa = false,
        ra = true,
        rc = ?DNS_RCODE_SERVFAIL,
        anc = 0,
        auc = 0,
        adc = 0,
        answers = [],
        authority = [],
        additional = []
    }.
%% Checks a resolver worker out of the pool and asks it to forward Request
%% upstream. The calling process and Ref are handed to the worker, and the
%% caller later waits for a message tagged with that Ref.
%% Returns whatever dns_resolver:forward/5 returns; poolboy:transaction/2
%% passes the fun's result through. The previous no_return() spec was wrong
%% because this function does return.
-spec forward_to_upstream(Ref :: reference(), Request :: binary(), QueryMsg :: #dns_message{}) -> term().
forward_to_upstream(Ref, Request, QueryMsg) ->
    ReceiverPid = self(),
    poolboy:transaction(
        ?RESOLVER_POOL,
        fun(Pid) -> dns_resolver:forward(Pid, ReceiverPid, Ref, Request, QueryMsg) end
    ).
%% Turns a cache entry into a response for QueryMsg. Every resource record
%% in the three sections gets the time left until the entry expires as its
%% TTL. The response code and the aa flag come from the cache entry, and the
%% section counters are recomputed from the section lists.
-spec build_response(QueryMsg :: #dns_message{}, Dns_cache :: #dns_cache{}) -> RespMsg :: #dns_message{}.
build_response(QueryMsg, #dns_cache{expire_at = ExpireAt, answers = Answers, authority = Authority, additional = Additional, rc = RCode, flags = #{aa := AA}}) ->
    SecondsLeft = ExpireAt - os:system_time(second),
    Refresh = fun(Records) -> [adjust_ttl(Record, SecondsLeft) || Record <- Records] end,
    FreshAnswers = Refresh(Answers),
    FreshAuthority = Refresh(Authority),
    FreshAdditional = Refresh(Additional),
    QueryMsg#dns_message{
        qr = true,
        ra = true,
        aa = AA,
        rc = RCode,
        answers = FreshAnswers,
        anc = length(FreshAnswers),
        authority = FreshAuthority,
        auc = length(FreshAuthority),
        additional = FreshAdditional,
        adc = length(FreshAdditional)
    }.
%% Sets the TTL of a dns_rr record to RemainingTTL, clamped to zero when the
%% remaining time is not positive. Any other term is returned untouched.
-spec adjust_ttl(RR :: any(), RemainingTTL :: integer()) -> any().
adjust_ttl(#dns_rr{} = Record, RemainingTTL) when RemainingTTL > 0 ->
    Record#dns_rr{ttl = RemainingTTL};
adjust_ttl(#dns_rr{} = Record, _Expired) ->
    Record#dns_rr{ttl = 0};
adjust_ttl(Other, _RemainingTTL) ->
    Other.
%% Derives an authoritative NXDOMAIN reply from QueryMsg: the response,
%% authoritative and recursion-available flags are set, and all three record
%% sections are emptied. Fields not listed here are kept from the query.
-spec build_nxdomain_response(QueryMsg :: #dns_message{}) -> EmptyResp :: #dns_message{}.
build_nxdomain_response(QueryMsg) ->
    QueryMsg#dns_message{
        rc = ?DNS_RCODE_NXDOMAIN,
        qr = true,
        aa = true,
        ra = true,
        answers = [],
        anc = 0,
        authority = [],
        auc = 0,
        additional = [],
        adc = 0
    }.
%% Wraps a UDP payload in a UDP header and an option-less IPv4 header and
%% returns the complete packet. Both checksums are computed by dns_utils;
%% the IP checksum is taken over a header whose sum field is still zero.
-spec build_ip_packet(SAddr :: inet:ip4_address(), DAddr :: inet:ip4_address(), SPort :: integer(), DPort :: integer(), Payload :: binary()) -> IpPacket :: binary().
build_ip_packet(SAddr, DAddr, SPort, DPort, UdpPayload) when is_integer(SPort), is_integer(DPort), is_binary(UdpPayload) ->
    %% UDP length covers the 8-byte UDP header plus the payload.
    UdpLen = 8 + byte_size(UdpPayload),
    UdpSum = dns_utils:udp_checksum(SAddr, DAddr, SPort, DPort, UdpPayload),
    UdpHeader = pkt:udp(#udp{sport = SPort, dport = DPort, ulen = UdpLen, sum = UdpSum}),
    %% Total length adds the 20 bytes of an IPv4 header without options.
    Unsummed = #ipv4{
        len = 20 + UdpLen,
        ttl = 64,
        off = 0,
        mf = 0,
        sum = 0,
        p = ?UDP_PROTOCOL,
        saddr = SAddr,
        daddr = DAddr,
        opt = <<>>
    },
    Summed = Unsummed#ipv4{sum = dns_utils:ip_checksum(Unsummed)},
    IpHeader = pkt:ipv4(Summed),
    <<IpHeader/binary, UdpHeader/binary, UdpPayload/binary>>.