并发请求dns

This commit is contained in:
anlicheng 2025-12-04 14:29:58 +08:00
parent 10be6e6aeb
commit 5d41b45d6f
2 changed files with 67 additions and 83 deletions

View File

@ -12,8 +12,6 @@
-behaviour(gen_server). -behaviour(gen_server).
-include_lib("dns_erlang/include/dns.hrl"). -include_lib("dns_erlang/include/dns.hrl").
-include_lib("dns_erlang/include/dns_records.hrl").
-include_lib("dns_erlang/include/dns_terms.hrl").
-include("dns_proxy.hrl"). -include("dns_proxy.hrl").
%% API %% API
@ -25,16 +23,12 @@
-define(SERVER, ?MODULE). -define(SERVER, ?MODULE).
-define(RESOLVER_POOL, dns_resolver_pool). -define(RESOLVER_POOL, dns_resolver_pool).
%%
-define(UPSTREAM_TIMEOUT, 1000).
-record(state, { -record(state, {
socket, socket,
src_ip, src_ip,
src_port, src_port,
packet, packet
question,
dns_servers = []
}). }).
%%%=================================================================== %%%===================================================================
@ -57,10 +51,7 @@ handle(Pid) when is_pid(Pid) ->
{ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} | {ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} |
{stop, Reason :: term()} | ignore). {stop, Reason :: term()} | ignore).
init([Sock, SrcIp, SrcPort, Packet]) -> init([Sock, SrcIp, SrcPort, Packet]) ->
{ok, DNSServers} = application:get_env(dns_proxy, public_dns_servers), {ok, #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}}.
%%
erlang:start_timer(5000, self(), handler_max_ttl),
{ok, #state{dns_servers = DNSServers, socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}}.
%% @private %% @private
%% @doc Handling call messages %% @doc Handling call messages
@ -81,26 +72,36 @@ handle_call(_Request, _From, State = #state{}) ->
{noreply, NewState :: #state{}} | {noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} | {noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), NewState :: #state{}}). {stop, Reason :: term(), NewState :: #state{}}).
handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) -> handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}) ->
case dns:decode_message(Packet) of case dns:decode_message(Packet) of
Msg = #dns_message{qc = 1, questions = [Question|_]} -> QueryMsg = #dns_message{qc = 1, questions = [Question|_]} ->
Qname = Question#dns_query.name,
lager:debug("[dns_handler] qname: ~p", [Qname]),
case dns_cache:lookup(Question) of case dns_cache:lookup(Question) of
{hit, Cache} -> {hit, Cache} ->
lager:debug("[dns_handler] hit cache: ~p", [Cache]), lager:debug("[dns_handler] question: ~p, hit cache: ~p", [Question, Cache]),
Resp = build_response(Msg, Cache), RespMsg = build_response(QueryMsg, Cache),
gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)), gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(RespMsg)),
{stop, normal, State}; {stop, normal, State};
miss -> miss ->
lager:debug("[dns_handler] cache is miss"), lager:debug("[dns_handler] cache is miss"),
forward_to_upstream(DnsIp, DnsPort, Packet, Msg), Ref = make_ref(),
%% forward_to_upstream(Ref, Packet, QueryMsg),
erlang:start_timer(?UPSTREAM_TIMEOUT, self(), {trigger_next, Msg}), receive
{noreply, State#state{dns_servers = RestDnsServers, question = Question}} {dns_resolver_reply, Ref, Resp} ->
case dns:decode_message(Resp) of
RespMsg = #dns_message{answers = Answers} ->
lager:debug("[dns_handler] get a response answers: ~p", [Answers]),
dns_cache:insert(Question, RespMsg),
gen_udp:send(Sock, SrcIp, SrcPort, Resp);
Other ->
lager:debug("[dns_handler] parse reply get error: ~p", [Other])
end,
{stop, normal, State}
after 5000 ->
{stop, normal, State}
end
end; end;
Other -> Other ->
lager:warning("[] decode msg get error: ~p", [Other]), lager:warning("[dns_handler] decode dns query get error: ~p", [Other]),
{stop, normal, State} {stop, normal, State}
end. end.
@ -110,30 +111,8 @@ handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = Src
{noreply, NewState :: #state{}} | {noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} | {noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), NewState :: #state{}}). {stop, Reason :: term(), NewState :: #state{}}).
%% handle_info(_Info, State) ->
handle_info({timeout, _, {trigger_next, Msg}}, State = #state{packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) -> {noreply, State}.
forward_to_upstream(DnsIp, DnsPort, Packet, Msg),
erlang:start_timer(?UPSTREAM_TIMEOUT, self(), {trigger_next, Msg}),
{noreply, State#state{dns_servers = RestDnsServers}};
handle_info({timeout, _, {trigger_next, _}}, State = #state{dns_servers = []}) ->
{stop, normal, State};
handle_info({timeout, _, handler_max_ttl}, State) ->
lager:debug("[dns_handler] reach the max ttl"),
{stop, normal, State};
%%
handle_info({dns_resolver_reply, Resp}, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, question = Question}) ->
%%
case dns:decode_message(Resp) of
Msg = #dns_message{answers = Answers} ->
lager:debug("[dns_handler] get a resolver reply: ~p, bin: ~p", [Msg, Answers]),
dns_cache:insert(Question, Msg),
gen_udp:send(Sock, SrcIp, SrcPort, Resp);
Other ->
lager:debug("[dns_handler] parse reply get error: ~p", [Other])
end,
{stop, normal, State}.
%% @private %% @private
%% @doc This function is called by a gen_server when it is about to %% @doc This function is called by a gen_server when it is about to
@ -157,13 +136,13 @@ code_change(_OldVsn, State = #state{}, _Extra) ->
%%% Internal functions %%% Internal functions
%%%=================================================================== %%%===================================================================
forward_to_upstream(TargetIp, TargetPort, Request, Msg) -> -spec forward_to_upstream(Ref :: reference(), Request :: binary(), QueryMsg :: #dns_message{}) -> no_return().
forward_to_upstream(Ref, Request, QueryMsg) ->
ReceiverPid = self(), ReceiverPid = self(),
poolboy:transaction(?RESOLVER_POOL, fun(Pid) -> poolboy:transaction(?RESOLVER_POOL, fun(Pid) -> dns_resolver:forward(Pid, ReceiverPid, Ref, Request, QueryMsg) end).
dns_resolver:forward(Pid, ReceiverPid, TargetIp, TargetPort, Request, Msg)
end).
build_response(Query, #dns_cache{expire_at = ExpireAt, answers = Answers, authority = Authority, additional = Additional, rc = RCode, flags = #{aa := AA}}) -> -spec build_response(QueryMsg :: #dns_message{}, Dns_cache :: #dns_cache{}) -> RespMsg :: #dns_message{}.
build_response(QueryMsg, #dns_cache{expire_at = ExpireAt, answers = Answers, authority = Authority, additional = Additional, rc = RCode, flags = #{aa := AA}}) ->
Now = os:system_time(second), Now = os:system_time(second),
RemainingTTL = ExpireAt - Now, RemainingTTL = ExpireAt - Now,
@ -171,7 +150,7 @@ build_response(Query, #dns_cache{expire_at = ExpireAt, answers = Answers, author
Authority2 = [adjust_ttl(RR, RemainingTTL) || RR <- Authority], Authority2 = [adjust_ttl(RR, RemainingTTL) || RR <- Authority],
Additional2 = [adjust_ttl(RR, RemainingTTL) || RR <- Additional], Additional2 = [adjust_ttl(RR, RemainingTTL) || RR <- Additional],
Query#dns_message{ QueryMsg#dns_message{
qr = true, qr = true,
ra = true, ra = true,
aa = AA, aa = AA,
@ -184,6 +163,7 @@ build_response(Query, #dns_cache{expire_at = ExpireAt, answers = Answers, author
additional = Additional2 additional = Additional2
}. }.
-spec adjust_ttl(RR :: any(), RemainingTTL :: integer()) -> any().
adjust_ttl(RR = #dns_rr{}, RemainingTTL) -> adjust_ttl(RR = #dns_rr{}, RemainingTTL) ->
RR#dns_rr{ttl = max(0, RemainingTTL)}; RR#dns_rr{ttl = max(0, RemainingTTL)};
adjust_ttl(RR, _RemainingTTL) -> adjust_ttl(RR, _RemainingTTL) ->

View File

@ -18,22 +18,23 @@
%% gen_server callbacks %% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-export([forward/6]). -export([forward/5]).
-define(SERVER, ?MODULE). -define(SERVER, ?MODULE).
-define(REQUEST_TTL, 5000). -define(REQUEST_TTL, 5000).
-record(state, { -record(state, {
socket, socket,
tid tid,
dns_servers = []
}). }).
%%%=================================================================== %%%===================================================================
%%% API %%% API
%%%=================================================================== %%%===================================================================
forward(Pid, ReceiverPid, TargetIp, TargetPort, Request, Msg) -> forward(Pid, ReceiverPid, Ref, Request, QueryMsg) ->
gen_server:cast(Pid, {forward, ReceiverPid, TargetIp, TargetPort, Request, Msg}). gen_server:cast(Pid, {forward, ReceiverPid, Ref, Request, QueryMsg}).
%% @doc Spawns the server and registers the local name (unique) %% @doc Spawns the server and registers the local name (unique)
-spec(start_link(Args :: list()) -> -spec(start_link(Args :: list()) ->
@ -51,11 +52,13 @@ start_link(Args) when is_list(Args) ->
{ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} | {ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} |
{stop, Reason :: term()} | ignore). {stop, Reason :: term()} | ignore).
init([]) -> init([]) ->
{ok, DnsServers} = application:get_env(dns_proxy, public_dns_servers),
{ok, Sock} = gen_udp:open(0, [binary, {active, true}]), {ok, Sock} = gen_udp:open(0, [binary, {active, true}]),
%% ets来保存映射关系 %% ets来保存映射关系
Tid = ets:new(random_table(), [set, {read_concurrency, true}, {write_concurrency, true}, private]), Tid = ets:new(random_table(), [set, {read_concurrency, true}, {write_concurrency, true}, private]),
{ok, #state{socket = Sock, tid = Tid}}. {ok, #state{socket = Sock, tid = Tid, dns_servers = DnsServers}}.
%% @private %% @private
%% @doc Handling call messages %% @doc Handling call messages
@ -76,14 +79,14 @@ handle_call(_Request, _From, State = #state{}) ->
{noreply, NewState :: #state{}} | {noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} | {noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), NewState :: #state{}}). {stop, Reason :: term(), NewState :: #state{}}).
handle_cast({forward, ReceiverPid, TargetIp, TargetPort, Request, #dns_message{id = TxId, questions = [#dns_query{name = QName}|_]}}, State = #state{socket = Socket, tid = Tid}) -> handle_cast({forward, ReceiverPid, Ref, Request, #dns_message{id = TxId, questions = [#dns_query{name = QName, type = QType, class = QClass}|_]}}, State = #state{socket = Socket, tid = Tid, dns_servers = DnsServers}) ->
lager:debug("[dns_resolver] forward request to: ~p", [{TargetIp, TargetPort}]), Keys = lists:foldl(fun({DnsIp, DnsPort}, Acc) ->
ok = gen_udp:send(Socket, TargetIp, TargetPort, Request), ok = gen_udp:send(Socket, DnsIp, DnsPort, Request),
Key = {TxId, DnsIp, DnsPort, QName, QType, QClass},
Key = {TxId, TargetIp, TargetPort, QName}, true = ets:insert(Tid, {Key, Ref, ReceiverPid}),
true = ets:insert(Tid, {Key, ReceiverPid}), [Key|Acc]
end, [], DnsServers),
erlang:start_timer(?REQUEST_TTL, self(), {clean_ticker, Key}), erlang:start_timer(?REQUEST_TTL, self(), {clean_ticker, Keys}),
{noreply, State}. {noreply, State}.
@ -95,27 +98,16 @@ handle_cast({forward, ReceiverPid, TargetIp, TargetPort, Request, #dns_message{i
{stop, Reason :: term(), NewState :: #state{}}). {stop, Reason :: term(), NewState :: #state{}}).
handle_info({udp, Socket, TargetIp, TargetPort, Resp}, State = #state{tid = Tid, socket = Socket}) -> handle_info({udp, Socket, TargetIp, TargetPort, Resp}, State = #state{tid = Tid, socket = Socket}) ->
case dns:decode_message(Resp) of case dns:decode_message(Resp) of
#dns_message{id = TxId, questions = [#dns_query{name = QName}|_]} -> #dns_message{id = TxId, questions = [#dns_query{name = QName, type = QType, class = QClass}|_]} ->
Key = {TxId, TargetIp, TargetPort, QName}, Key = {TxId, TargetIp, TargetPort, QName, QName, QType, QClass},
case ets:take(Tid, Key) of resolver_reply(ets:take(Tid, Key), Resp);
[{_, ReceiverPid}] ->
case is_process_alive(ReceiverPid) of
true ->
ReceiverPid ! {dns_resolver_reply, Resp};
false ->
ok
end;
[] ->
ok
end;
_ -> _ ->
ok ok
end, end,
{noreply, State}; {noreply, State};
handle_info({timeout, _, {clean_ticker, Key}}, State = #state{tid = Tid}) -> handle_info({timeout, _, {clean_ticker, Keys}}, State = #state{tid = Tid}) ->
true = ets:delete(Tid, Key), lists:foreach(fun(Key) -> ets:delete(Tid, Key) end, Keys),
{noreply, State}. {noreply, State}.
%% @private %% @private
@ -140,5 +132,17 @@ code_change(_OldVsn, State = #state{}, _Extra) ->
%%% Internal functions %%% Internal functions
%%%=================================================================== %%%===================================================================
-spec random_table() -> atom().
random_table() -> random_table() ->
list_to_atom("udp_ets:" ++ integer_to_list(erlang:unique_integer([monotonic, positive]))). list_to_atom("udp_ets:" ++ integer_to_list(erlang:unique_integer([monotonic, positive]))).
-spec resolver_reply(list(), Resp :: binary()) -> no_return().
resolver_reply([{_, Ref, ReceiverPid}], Resp) when is_binary(Resp) ->
case is_process_alive(ReceiverPid) of
true ->
ReceiverPid ! {dns_resolver_reply, Ref, Resp};
false ->
ok
end;
resolver_reply(_, _) ->
ok.