This commit is contained in:
anlicheng 2025-12-03 23:41:04 +08:00
parent 19f05de3a6
commit f5cb87dc6a
4 changed files with 143 additions and 185 deletions

View File

@ -4,23 +4,82 @@
%%% @doc %%% @doc
%%% %%%
%%% @end %%% @end
%%% Created : 03. 12 2025 17:27 %%% Created : 03. 12 2025 23:00
%%%------------------------------------------------------------------- %%%-------------------------------------------------------------------
-module(dns_handler). -module(dns_handler).
-author("anlicheng"). -author("anlicheng").
-behaviour(gen_server).
-include_lib("dns_erlang/include/dns.hrl"). -include_lib("dns_erlang/include/dns.hrl").
-include_lib("dns_erlang/include/dns_records.hrl"). -include_lib("dns_erlang/include/dns_records.hrl").
-include_lib("dns_erlang/include/dns_terms.hrl"). -include_lib("dns_erlang/include/dns_terms.hrl").
-define(RESOLVER_POOL, dns_resolver_pool). %% API
-export([start_link/4]).
-export([start_link/4, init/4]). %% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-export([handle/1]).
-define(SERVER, ?MODULE).
-define(RESOLVER_POOL, dns_resolver_pool).
%%
-define(UPSTREAM_TIMEOUT, 1000).
-record(state, {
socket,
src_ip,
src_port,
packet,
dns_servers = []
}).
%%%===================================================================
%%% API
%%%===================================================================
start_link(Sock, Ip, Port, Packet) -> start_link(Sock, Ip, Port, Packet) ->
{ok, proc_lib:spawn(?MODULE, init, [Sock, Ip, Port, Packet])}. gen_server:start_link(?MODULE, [Sock, Ip, Port, Packet], []).
init(Sock, Ip, Port, Packet) -> handle(Pid) when is_pid(Pid) ->
gen_server:cast(Pid, handle).
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
%% @private
%% @doc Initializes the server
-spec(init(Args :: term()) ->
{ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} |
{stop, Reason :: term()} | ignore).
init([Sock, SrcIp, SrcPort, Packet]) ->
{ok, DNSServers} = application:get_env(dns_proxy, public_dns_servers),
%%
erlang:start_timer(5000, self(), handler_max_ttl),
{ok, #state{dns_servers = DNSServers, socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}}.
%% @private
%% @doc Handling call messages
-spec(handle_call(Request :: term(), From :: {pid(), Tag :: term()},
State :: #state{}) ->
{reply, Reply :: term(), NewState :: #state{}} |
{reply, Reply :: term(), NewState :: #state{}, timeout() | hibernate} |
{noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), Reply :: term(), NewState :: #state{}} |
{stop, Reason :: term(), NewState :: #state{}}).
handle_call(_Request, _From, State = #state{}) ->
{reply, ok, State}.
%% @private
%% @doc Handling cast messages
-spec(handle_cast(Request :: term(), State :: #state{}) ->
{noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), NewState :: #state{}}).
handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) ->
case dns:decode_message(Packet) of case dns:decode_message(Packet) of
Msg = #dns_message{qc = 1, questions = [Question|_]} -> Msg = #dns_message{qc = 1, questions = [Question|_]} ->
Qname = Question#dns_query.name, Qname = Question#dns_query.name,
@ -28,31 +87,70 @@ init(Sock, Ip, Port, Packet) ->
case dns_cache:lookup(Qname) of case dns_cache:lookup(Qname) of
{hit, R} -> {hit, R} ->
Resp = build_response(Msg, R), Resp = build_response(Msg, R),
gen_udp:send(Sock, Ip, Port, dns:encode_message(Resp)); gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)),
{stop, normal};
miss -> miss ->
lager:debug("[dns_handler] cache is miss"), lager:debug("[dns_handler] cache is miss"),
forward_to_upstream(Sock, Ip, Port, Packet) forward_to_upstream(DnsIp, DnsPort, Packet, Msg),
%%
erlang:start_timer(?UPSTREAM_TIMEOUT, self(), trigger_next),
{noreply, State#state{dns_servers = RestDnsServers}}
end; end;
Other -> Other ->
lager:warning("decode msg get error: ~p", [Other]), lager:warning("[] decode msg get error: ~p", [Other]),
exit(normal) {stop, normal}
end. end.
forward_to_upstream(Sock, SrcIp, SrcPort, Request) -> %% @private
{ok, DNSServers} = application:get_env(dns_proxy, public_dns_servers), %% @doc Handling all non call/cast messages
-spec(handle_info(Info :: timeout() | term(), State :: #state{}) ->
{noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), NewState :: #state{}}).
%%
handle_info({timeout, _, trigger_next}, State = #state{packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) ->
forward_to_upstream(DnsIp, DnsPort, Packet),
erlang:start_timer(?UPSTREAM_TIMEOUT, self(), trigger_next),
{noreply, State#state{dns_servers = RestDnsServers}};
handle_info({timeout, _, trigger_next}, State = #state{dns_servers = []}) ->
{stop, normal, State};
handle_info({timeout, _, handler_max_ttl}, State) ->
lager:debug("[dns_handler] reach the max ttl"),
{stop, normal, State};
%%
handle_info({dns_resolver_reply, Resp}, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort}) ->
gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)),
{stop, normal, State}.
%% @private
%% @doc This function is called by a gen_server when it is about to
%% terminate. It should be the opposite of Module:init/1 and do any
%% necessary cleaning up. When it returns, the gen_server terminates
%% with Reason. The return value is ignored.
-spec(terminate(Reason :: (normal | shutdown | {shutdown, term()} | term()),
State :: #state{}) -> term()).
terminate(_Reason, _State = #state{}) ->
ok.
%% @private
%% @doc Convert process state when code is changed
-spec(code_change(OldVsn :: term() | {down, term()}, State :: #state{},
Extra :: term()) ->
{ok, NewState :: #state{}} | {error, Reason :: term()}).
code_change(_OldVsn, State = #state{}, _Extra) ->
{ok, State}.
%%%===================================================================
%%% Internal functions
%%%===================================================================
forward_to_upstream(TargetIp, TargetPort, Request, Msg) ->
ReceiverPid = self(), ReceiverPid = self(),
Ref = make_ref(),
poolboy:transaction(?RESOLVER_POOL, fun(Pid) -> poolboy:transaction(?RESOLVER_POOL, fun(Pid) ->
dns_resolver:forward(Pid, ReceiverPid, Ref, TargetIp, TargetPort, Request) dns_resolver:forward(Pid, ReceiverPid, TargetIp, TargetPort, Request, Msg)
end), end).
receive
{udp, SendSock, _UIp, _UPort, Resp} ->
gen_udp:send(Sock, Ip, Port, Resp)
after 2000 ->
ok
end.
build_response(Req, RR) -> build_response(Req, RR) ->
Msg = Req, Msg = Req,

View File

@ -1,151 +0,0 @@
%%%-------------------------------------------------------------------
%%% @author anlicheng
%%% @copyright (C) 2025, <COMPANY>
%%% @doc
%%%
%%% @end
%%% Created : 03. 12 2025 23:00
%%%-------------------------------------------------------------------
-module(dns_handler2).
-author("anlicheng").
-behaviour(gen_server).
-include_lib("dns_erlang/include/dns.hrl").
-include_lib("dns_erlang/include/dns_records.hrl").
-include_lib("dns_erlang/include/dns_terms.hrl").
%% API
-export([start_link/4]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-export([handle/1]).
-define(SERVER, ?MODULE).
-define(RESOLVER_POOL, dns_resolver_pool).
%%
-define(UPSTREAM_TIMEOUT, 1000).
-record(state, {
socket,
src_ip,
src_port,
packet,
dns_servers = []
}).
%%%===================================================================
%%% API
%%%===================================================================
start_link(Sock, Ip, Port, Packet) ->
gen_server:start_link(?MODULE, [Sock, Ip, Port, Packet], []).
handle(Pid) when is_pid(Pid) ->
gen_server:cast(Pid, handle).
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
%% @private
%% @doc Initializes the server
-spec(init(Args :: term()) ->
{ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} |
{stop, Reason :: term()} | ignore).
init([Sock, SrcIp, SrcPort, Packet]) ->
{ok, DNSServers} = application:get_env(dns_proxy, public_dns_servers),
{ok, #state{dns_servers = DNSServers, socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}}.
%% @private
%% @doc Handling call messages
-spec(handle_call(Request :: term(), From :: {pid(), Tag :: term()},
State :: #state{}) ->
{reply, Reply :: term(), NewState :: #state{}} |
{reply, Reply :: term(), NewState :: #state{}, timeout() | hibernate} |
{noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), Reply :: term(), NewState :: #state{}} |
{stop, Reason :: term(), NewState :: #state{}}).
handle_call(_Request, _From, State = #state{}) ->
{reply, ok, State}.
%% @private
%% @doc Handling cast messages
-spec(handle_cast(Request :: term(), State :: #state{}) ->
{noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), NewState :: #state{}}).
handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) ->
case dns:decode_message(Packet) of
Msg = #dns_message{qc = 1, questions = [Question|_]} ->
Qname = Question#dns_query.name,
lager:debug("[dns_handler] qname: ~p", [Qname]),
case dns_cache:lookup(Qname) of
{hit, R} ->
Resp = build_response(Msg, R),
gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)),
{stop, normal};
miss ->
lager:debug("[dns_handler] cache is miss"),
forward_to_upstream(DnsIp, DnsPort, Packet),
%%
erlang:start_timer(?UPSTREAM_TIMEOUT, self(), trigger_next),
{noreply, State#state{dns_servers = RestDnsServers}}
end;
Other ->
lager:warning("[] decode msg get error: ~p", [Other]),
{stop, normal}
end.
%% @private
%% @doc Handling all non call/cast messages
-spec(handle_info(Info :: timeout() | term(), State :: #state{}) ->
{noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), NewState :: #state{}}).
%%
handle_info({timeout, _, trigger_next}, State = #state{packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) ->
forward_to_upstream(DnsIp, DnsPort, Packet),
erlang:start_timer(?UPSTREAM_TIMEOUT, self(), trigger_next),
{noreply, State#state{dns_servers = RestDnsServers}};
handle_info({timeout, _, trigger_next}, State = #state{dns_servers = []}) ->
{stop, normal, State};
%%
handle_info({dns_resolver_reply, Resp}, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort}) ->
gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)),
{stop, normal, State}.
%% @private
%% @doc This function is called by a gen_server when it is about to
%% terminate. It should be the opposite of Module:init/1 and do any
%% necessary cleaning up. When it returns, the gen_server terminates
%% with Reason. The return value is ignored.
-spec(terminate(Reason :: (normal | shutdown | {shutdown, term()} | term()),
State :: #state{}) -> term()).
terminate(_Reason, _State = #state{}) ->
ok.
%% @private
%% @doc Convert process state when code is changed
-spec(code_change(OldVsn :: term() | {down, term()}, State :: #state{},
Extra :: term()) ->
{ok, NewState :: #state{}} | {error, Reason :: term()}).
code_change(_OldVsn, State = #state{}, _Extra) ->
{ok, State}.
%%%===================================================================
%%% Internal functions
%%%===================================================================
forward_to_upstream(TargetIp, TargetPort, Request) ->
ReceiverPid = self(),
poolboy:transaction(?RESOLVER_POOL, fun(Pid) ->
dns_resolver:forward(Pid, ReceiverPid, TargetIp, TargetPort, Request)
end).
build_response(Req, RR) ->
Msg = Req,
Msg#dns_message{answers=[RR], qr=true, aa=true}.

View File

@ -18,9 +18,10 @@
%% gen_server callbacks %% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-export([forward/5]). -export([forward/6]).
-define(SERVER, ?MODULE). -define(SERVER, ?MODULE).
-define(REQUEST_TTL, 5000).
-record(state, { -record(state, {
socket, socket,
@ -31,8 +32,8 @@
%%% API %%% API
%%%=================================================================== %%%===================================================================
forward(Pid, ReceiverPid, TargetIp, TargetPort, Request) -> forward(Pid, ReceiverPid, TargetIp, TargetPort, Request, Msg) ->
gen_server:cast(Pid, {forward, ReceiverPid, TargetIp, TargetPort, Request}). gen_server:cast(Pid, {forward, ReceiverPid, TargetIp, TargetPort, Request, Msg}).
%% @doc Spawns the server and registers the local name (unique) %% @doc Spawns the server and registers the local name (unique)
-spec(start_link(Args :: list()) -> -spec(start_link(Args :: list()) ->
@ -75,14 +76,14 @@ handle_call(_Request, _From, State = #state{}) ->
{noreply, NewState :: #state{}} | {noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} | {noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), NewState :: #state{}}). {stop, Reason :: term(), NewState :: #state{}}).
handle_cast({forward, ReceiverPid, TargetIp, TargetPort, Request}, State = #state{socket = Socket, tid = Tid}) -> handle_cast({forward, ReceiverPid, TargetIp, TargetPort, Request, #dns_message{id = TxId, questions = [#dns_query{name = QName}|_]}}, State = #state{socket = Socket, tid = Tid}) ->
case dns:decode_message(Request) of
#dns_message{id = TxId, questions = [#dns_query{name = QName}|_]} ->
ok = gen_udp:send(Socket, TargetIp, TargetPort, Request), ok = gen_udp:send(Socket, TargetIp, TargetPort, Request),
ok = ets:insert(Tid, {{TxId, TargetIp, TargetPort, QName}, ReceiverPid});
_ -> Key = {TxId, TargetIp, TargetPort, QName},
ok ok = ets:insert(Tid, {Key, ReceiverPid}),
end,
erlang:start_timer(?REQUEST_TTL, self(), {clean_ticker, Key}),
{noreply, State}. {noreply, State}.
%% @private %% @private
@ -109,6 +110,11 @@ handle_info({udp, Socket, TargetIp, TargetPort, Resp}, State = #state{tid = Tid,
_ -> _ ->
ok ok
end, end,
{noreply, State};
handle_info({timeout, _, {clean_ticker, Key}}, State = #state{tid = Tid}) ->
true = ets:delete(Tid, Key),
{noreply, State}. {noreply, State}.
%% @private %% @private

View File

@ -16,7 +16,12 @@ init() ->
loop(Sock) -> loop(Sock) ->
receive receive
{udp, Sock, Ip, Port, Packet} -> {udp, Sock, Ip, Port, Packet} ->
Res = dns_handler_sup:start_handler(Sock, Ip, Port, Packet), lager:debug("[dns_server] ip: ~p, get a packet: ~p", [{Ip, Port}, Packet]),
lager:debug("[dns_server] ip: ~p, get a packet: ~p, handler res: ~p", [{Ip, Port}, Packet, Res]), case dns_handler_sup:start_handler(Sock, Ip, Port, Packet) of
{ok, HandlerPid} ->
dns_handler:handle(HandlerPid);
Error ->
lager:debug("[dns_server] start handler get error: ~p", [Error])
end,
loop(Sock) loop(Sock)
end. end.