并发请求dns
This commit is contained in:
parent
10be6e6aeb
commit
5d41b45d6f
@ -12,8 +12,6 @@
|
|||||||
-behaviour(gen_server).
|
-behaviour(gen_server).
|
||||||
|
|
||||||
-include_lib("dns_erlang/include/dns.hrl").
|
-include_lib("dns_erlang/include/dns.hrl").
|
||||||
-include_lib("dns_erlang/include/dns_records.hrl").
|
|
||||||
-include_lib("dns_erlang/include/dns_terms.hrl").
|
|
||||||
-include("dns_proxy.hrl").
|
-include("dns_proxy.hrl").
|
||||||
|
|
||||||
%% API
|
%% API
|
||||||
@ -25,16 +23,12 @@
|
|||||||
|
|
||||||
-define(SERVER, ?MODULE).
|
-define(SERVER, ?MODULE).
|
||||||
-define(RESOLVER_POOL, dns_resolver_pool).
|
-define(RESOLVER_POOL, dns_resolver_pool).
|
||||||
%% 转发的超时设置
|
|
||||||
-define(UPSTREAM_TIMEOUT, 1000).
|
|
||||||
|
|
||||||
-record(state, {
|
-record(state, {
|
||||||
socket,
|
socket,
|
||||||
src_ip,
|
src_ip,
|
||||||
src_port,
|
src_port,
|
||||||
packet,
|
packet
|
||||||
question,
|
|
||||||
dns_servers = []
|
|
||||||
}).
|
}).
|
||||||
|
|
||||||
%%%===================================================================
|
%%%===================================================================
|
||||||
@ -57,10 +51,7 @@ handle(Pid) when is_pid(Pid) ->
|
|||||||
{ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} |
|
{ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} |
|
||||||
{stop, Reason :: term()} | ignore).
|
{stop, Reason :: term()} | ignore).
|
||||||
init([Sock, SrcIp, SrcPort, Packet]) ->
|
init([Sock, SrcIp, SrcPort, Packet]) ->
|
||||||
{ok, DNSServers} = application:get_env(dns_proxy, public_dns_servers),
|
{ok, #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}}.
|
||||||
%% 进程的最大存活时间
|
|
||||||
erlang:start_timer(5000, self(), handler_max_ttl),
|
|
||||||
{ok, #state{dns_servers = DNSServers, socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}}.
|
|
||||||
|
|
||||||
%% @private
|
%% @private
|
||||||
%% @doc Handling call messages
|
%% @doc Handling call messages
|
||||||
@ -81,26 +72,36 @@ handle_call(_Request, _From, State = #state{}) ->
|
|||||||
{noreply, NewState :: #state{}} |
|
{noreply, NewState :: #state{}} |
|
||||||
{noreply, NewState :: #state{}, timeout() | hibernate} |
|
{noreply, NewState :: #state{}, timeout() | hibernate} |
|
||||||
{stop, Reason :: term(), NewState :: #state{}}).
|
{stop, Reason :: term(), NewState :: #state{}}).
|
||||||
handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) ->
|
handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}) ->
|
||||||
case dns:decode_message(Packet) of
|
case dns:decode_message(Packet) of
|
||||||
Msg = #dns_message{qc = 1, questions = [Question|_]} ->
|
QueryMsg = #dns_message{qc = 1, questions = [Question|_]} ->
|
||||||
Qname = Question#dns_query.name,
|
|
||||||
lager:debug("[dns_handler] qname: ~p", [Qname]),
|
|
||||||
case dns_cache:lookup(Question) of
|
case dns_cache:lookup(Question) of
|
||||||
{hit, Cache} ->
|
{hit, Cache} ->
|
||||||
lager:debug("[dns_handler] hit cache: ~p", [Cache]),
|
lager:debug("[dns_handler] question: ~p, hit cache: ~p", [Question, Cache]),
|
||||||
Resp = build_response(Msg, Cache),
|
RespMsg = build_response(QueryMsg, Cache),
|
||||||
gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)),
|
gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(RespMsg)),
|
||||||
{stop, normal, State};
|
{stop, normal, State};
|
||||||
miss ->
|
miss ->
|
||||||
lager:debug("[dns_handler] cache is miss"),
|
lager:debug("[dns_handler] cache is miss"),
|
||||||
forward_to_upstream(DnsIp, DnsPort, Packet, Msg),
|
Ref = make_ref(),
|
||||||
%% 开启定时器,超时后递归请求后面的服务
|
forward_to_upstream(Ref, Packet, QueryMsg),
|
||||||
erlang:start_timer(?UPSTREAM_TIMEOUT, self(), {trigger_next, Msg}),
|
receive
|
||||||
{noreply, State#state{dns_servers = RestDnsServers, question = Question}}
|
{dns_resolver_reply, Ref, Resp} ->
|
||||||
|
case dns:decode_message(Resp) of
|
||||||
|
RespMsg = #dns_message{answers = Answers} ->
|
||||||
|
lager:debug("[dns_handler] get a response answers: ~p", [Answers]),
|
||||||
|
dns_cache:insert(Question, RespMsg),
|
||||||
|
gen_udp:send(Sock, SrcIp, SrcPort, Resp);
|
||||||
|
Other ->
|
||||||
|
lager:debug("[dns_handler] parse reply get error: ~p", [Other])
|
||||||
|
end,
|
||||||
|
{stop, normal, State}
|
||||||
|
after 5000 ->
|
||||||
|
{stop, normal, State}
|
||||||
|
end
|
||||||
end;
|
end;
|
||||||
Other ->
|
Other ->
|
||||||
lager:warning("[] decode msg get error: ~p", [Other]),
|
lager:warning("[dns_handler] decode dns query get error: ~p", [Other]),
|
||||||
{stop, normal, State}
|
{stop, normal, State}
|
||||||
end.
|
end.
|
||||||
|
|
||||||
@ -110,30 +111,8 @@ handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = Src
|
|||||||
{noreply, NewState :: #state{}} |
|
{noreply, NewState :: #state{}} |
|
||||||
{noreply, NewState :: #state{}, timeout() | hibernate} |
|
{noreply, NewState :: #state{}, timeout() | hibernate} |
|
||||||
{stop, Reason :: term(), NewState :: #state{}}).
|
{stop, Reason :: term(), NewState :: #state{}}).
|
||||||
%% 处理超时重试
|
handle_info(_Info, State) ->
|
||||||
handle_info({timeout, _, {trigger_next, Msg}}, State = #state{packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) ->
|
{noreply, State}.
|
||||||
forward_to_upstream(DnsIp, DnsPort, Packet, Msg),
|
|
||||||
erlang:start_timer(?UPSTREAM_TIMEOUT, self(), {trigger_next, Msg}),
|
|
||||||
{noreply, State#state{dns_servers = RestDnsServers}};
|
|
||||||
handle_info({timeout, _, {trigger_next, _}}, State = #state{dns_servers = []}) ->
|
|
||||||
{stop, normal, State};
|
|
||||||
|
|
||||||
handle_info({timeout, _, handler_max_ttl}, State) ->
|
|
||||||
lager:debug("[dns_handler] reach the max ttl"),
|
|
||||||
{stop, normal, State};
|
|
||||||
|
|
||||||
%% 收到请求
|
|
||||||
handle_info({dns_resolver_reply, Resp}, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, question = Question}) ->
|
|
||||||
%% 尝试解析
|
|
||||||
case dns:decode_message(Resp) of
|
|
||||||
Msg = #dns_message{answers = Answers} ->
|
|
||||||
lager:debug("[dns_handler] get a resolver reply: ~p, bin: ~p", [Msg, Answers]),
|
|
||||||
dns_cache:insert(Question, Msg),
|
|
||||||
gen_udp:send(Sock, SrcIp, SrcPort, Resp);
|
|
||||||
Other ->
|
|
||||||
lager:debug("[dns_handler] parse reply get error: ~p", [Other])
|
|
||||||
end,
|
|
||||||
{stop, normal, State}.
|
|
||||||
|
|
||||||
%% @private
|
%% @private
|
||||||
%% @doc This function is called by a gen_server when it is about to
|
%% @doc This function is called by a gen_server when it is about to
|
||||||
@ -157,13 +136,13 @@ code_change(_OldVsn, State = #state{}, _Extra) ->
|
|||||||
%%% Internal functions
|
%%% Internal functions
|
||||||
%%%===================================================================
|
%%%===================================================================
|
||||||
|
|
||||||
forward_to_upstream(TargetIp, TargetPort, Request, Msg) ->
|
-spec forward_to_upstream(Ref :: reference(), Request :: binary(), QueryMsg :: #dns_message{}) -> no_return().
|
||||||
|
forward_to_upstream(Ref, Request, QueryMsg) ->
|
||||||
ReceiverPid = self(),
|
ReceiverPid = self(),
|
||||||
poolboy:transaction(?RESOLVER_POOL, fun(Pid) ->
|
poolboy:transaction(?RESOLVER_POOL, fun(Pid) -> dns_resolver:forward(Pid, ReceiverPid, Ref, Request, QueryMsg) end).
|
||||||
dns_resolver:forward(Pid, ReceiverPid, TargetIp, TargetPort, Request, Msg)
|
|
||||||
end).
|
|
||||||
|
|
||||||
build_response(Query, #dns_cache{expire_at = ExpireAt, answers = Answers, authority = Authority, additional = Additional, rc = RCode, flags = #{aa := AA}}) ->
|
-spec build_response(QueryMsg :: #dns_message{}, Dns_cache :: #dns_cache{}) -> RespMsg :: #dns_message{}.
|
||||||
|
build_response(QueryMsg, #dns_cache{expire_at = ExpireAt, answers = Answers, authority = Authority, additional = Additional, rc = RCode, flags = #{aa := AA}}) ->
|
||||||
Now = os:system_time(second),
|
Now = os:system_time(second),
|
||||||
RemainingTTL = ExpireAt - Now,
|
RemainingTTL = ExpireAt - Now,
|
||||||
|
|
||||||
@ -171,7 +150,7 @@ build_response(Query, #dns_cache{expire_at = ExpireAt, answers = Answers, author
|
|||||||
Authority2 = [adjust_ttl(RR, RemainingTTL) || RR <- Authority],
|
Authority2 = [adjust_ttl(RR, RemainingTTL) || RR <- Authority],
|
||||||
Additional2 = [adjust_ttl(RR, RemainingTTL) || RR <- Additional],
|
Additional2 = [adjust_ttl(RR, RemainingTTL) || RR <- Additional],
|
||||||
|
|
||||||
Query#dns_message{
|
QueryMsg#dns_message{
|
||||||
qr = true,
|
qr = true,
|
||||||
ra = true,
|
ra = true,
|
||||||
aa = AA,
|
aa = AA,
|
||||||
@ -184,6 +163,7 @@ build_response(Query, #dns_cache{expire_at = ExpireAt, answers = Answers, author
|
|||||||
additional = Additional2
|
additional = Additional2
|
||||||
}.
|
}.
|
||||||
|
|
||||||
|
-spec adjust_ttl(RR :: any(), RemainingTTL :: integer()) -> any().
|
||||||
adjust_ttl(RR = #dns_rr{}, RemainingTTL) ->
|
adjust_ttl(RR = #dns_rr{}, RemainingTTL) ->
|
||||||
RR#dns_rr{ttl = max(0, RemainingTTL)};
|
RR#dns_rr{ttl = max(0, RemainingTTL)};
|
||||||
adjust_ttl(RR, _RemainingTTL) ->
|
adjust_ttl(RR, _RemainingTTL) ->
|
||||||
|
|||||||
@ -18,22 +18,23 @@
|
|||||||
%% gen_server callbacks
|
%% gen_server callbacks
|
||||||
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
|
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
|
||||||
|
|
||||||
-export([forward/6]).
|
-export([forward/5]).
|
||||||
|
|
||||||
-define(SERVER, ?MODULE).
|
-define(SERVER, ?MODULE).
|
||||||
-define(REQUEST_TTL, 5000).
|
-define(REQUEST_TTL, 5000).
|
||||||
|
|
||||||
-record(state, {
|
-record(state, {
|
||||||
socket,
|
socket,
|
||||||
tid
|
tid,
|
||||||
|
dns_servers = []
|
||||||
}).
|
}).
|
||||||
|
|
||||||
%%%===================================================================
|
%%%===================================================================
|
||||||
%%% API
|
%%% API
|
||||||
%%%===================================================================
|
%%%===================================================================
|
||||||
|
|
||||||
forward(Pid, ReceiverPid, TargetIp, TargetPort, Request, Msg) ->
|
forward(Pid, ReceiverPid, Ref, Request, QueryMsg) ->
|
||||||
gen_server:cast(Pid, {forward, ReceiverPid, TargetIp, TargetPort, Request, Msg}).
|
gen_server:cast(Pid, {forward, ReceiverPid, Ref, Request, QueryMsg}).
|
||||||
|
|
||||||
%% @doc Spawns the server and registers the local name (unique)
|
%% @doc Spawns the server and registers the local name (unique)
|
||||||
-spec(start_link(Args :: list()) ->
|
-spec(start_link(Args :: list()) ->
|
||||||
@ -51,11 +52,13 @@ start_link(Args) when is_list(Args) ->
|
|||||||
{ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} |
|
{ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} |
|
||||||
{stop, Reason :: term()} | ignore).
|
{stop, Reason :: term()} | ignore).
|
||||||
init([]) ->
|
init([]) ->
|
||||||
|
{ok, DnsServers} = application:get_env(dns_proxy, public_dns_servers),
|
||||||
|
|
||||||
{ok, Sock} = gen_udp:open(0, [binary, {active, true}]),
|
{ok, Sock} = gen_udp:open(0, [binary, {active, true}]),
|
||||||
%% 通过ets来保存映射关系
|
%% 通过ets来保存映射关系
|
||||||
Tid = ets:new(random_table(), [set, {read_concurrency, true}, {write_concurrency, true}, private]),
|
Tid = ets:new(random_table(), [set, {read_concurrency, true}, {write_concurrency, true}, private]),
|
||||||
|
|
||||||
{ok, #state{socket = Sock, tid = Tid}}.
|
{ok, #state{socket = Sock, tid = Tid, dns_servers = DnsServers}}.
|
||||||
|
|
||||||
%% @private
|
%% @private
|
||||||
%% @doc Handling call messages
|
%% @doc Handling call messages
|
||||||
@ -76,14 +79,14 @@ handle_call(_Request, _From, State = #state{}) ->
|
|||||||
{noreply, NewState :: #state{}} |
|
{noreply, NewState :: #state{}} |
|
||||||
{noreply, NewState :: #state{}, timeout() | hibernate} |
|
{noreply, NewState :: #state{}, timeout() | hibernate} |
|
||||||
{stop, Reason :: term(), NewState :: #state{}}).
|
{stop, Reason :: term(), NewState :: #state{}}).
|
||||||
handle_cast({forward, ReceiverPid, TargetIp, TargetPort, Request, #dns_message{id = TxId, questions = [#dns_query{name = QName}|_]}}, State = #state{socket = Socket, tid = Tid}) ->
|
handle_cast({forward, ReceiverPid, Ref, Request, #dns_message{id = TxId, questions = [#dns_query{name = QName, type = QType, class = QClass}|_]}}, State = #state{socket = Socket, tid = Tid, dns_servers = DnsServers}) ->
|
||||||
lager:debug("[dns_resolver] forward request to: ~p", [{TargetIp, TargetPort}]),
|
Keys = lists:foldl(fun({DnsIp, DnsPort}, Acc) ->
|
||||||
ok = gen_udp:send(Socket, TargetIp, TargetPort, Request),
|
ok = gen_udp:send(Socket, DnsIp, DnsPort, Request),
|
||||||
|
Key = {TxId, DnsIp, DnsPort, QName, QType, QClass},
|
||||||
Key = {TxId, TargetIp, TargetPort, QName},
|
true = ets:insert(Tid, {Key, Ref, ReceiverPid}),
|
||||||
true = ets:insert(Tid, {Key, ReceiverPid}),
|
[Key|Acc]
|
||||||
|
end, [], DnsServers),
|
||||||
erlang:start_timer(?REQUEST_TTL, self(), {clean_ticker, Key}),
|
erlang:start_timer(?REQUEST_TTL, self(), {clean_ticker, Keys}),
|
||||||
|
|
||||||
{noreply, State}.
|
{noreply, State}.
|
||||||
|
|
||||||
@ -95,27 +98,16 @@ handle_cast({forward, ReceiverPid, TargetIp, TargetPort, Request, #dns_message{i
|
|||||||
{stop, Reason :: term(), NewState :: #state{}}).
|
{stop, Reason :: term(), NewState :: #state{}}).
|
||||||
handle_info({udp, Socket, TargetIp, TargetPort, Resp}, State = #state{tid = Tid, socket = Socket}) ->
|
handle_info({udp, Socket, TargetIp, TargetPort, Resp}, State = #state{tid = Tid, socket = Socket}) ->
|
||||||
case dns:decode_message(Resp) of
|
case dns:decode_message(Resp) of
|
||||||
#dns_message{id = TxId, questions = [#dns_query{name = QName}|_]} ->
|
#dns_message{id = TxId, questions = [#dns_query{name = QName, type = QType, class = QClass}|_]} ->
|
||||||
Key = {TxId, TargetIp, TargetPort, QName},
|
Key = {TxId, TargetIp, TargetPort, QName, QName, QType, QClass},
|
||||||
case ets:take(Tid, Key) of
|
resolver_reply(ets:take(Tid, Key), Resp);
|
||||||
[{_, ReceiverPid}] ->
|
|
||||||
case is_process_alive(ReceiverPid) of
|
|
||||||
true ->
|
|
||||||
ReceiverPid ! {dns_resolver_reply, Resp};
|
|
||||||
false ->
|
|
||||||
ok
|
|
||||||
end;
|
|
||||||
[] ->
|
|
||||||
ok
|
|
||||||
end;
|
|
||||||
_ ->
|
_ ->
|
||||||
ok
|
ok
|
||||||
end,
|
end,
|
||||||
{noreply, State};
|
{noreply, State};
|
||||||
|
|
||||||
handle_info({timeout, _, {clean_ticker, Key}}, State = #state{tid = Tid}) ->
|
handle_info({timeout, _, {clean_ticker, Keys}}, State = #state{tid = Tid}) ->
|
||||||
true = ets:delete(Tid, Key),
|
lists:foreach(fun(Key) -> ets:delete(Tid, Key) end, Keys),
|
||||||
|
|
||||||
{noreply, State}.
|
{noreply, State}.
|
||||||
|
|
||||||
%% @private
|
%% @private
|
||||||
@ -140,5 +132,17 @@ code_change(_OldVsn, State = #state{}, _Extra) ->
|
|||||||
%%% Internal functions
|
%%% Internal functions
|
||||||
%%%===================================================================
|
%%%===================================================================
|
||||||
|
|
||||||
|
-spec random_table() -> atom().
|
||||||
random_table() ->
|
random_table() ->
|
||||||
list_to_atom("udp_ets:" ++ integer_to_list(erlang:unique_integer([monotonic, positive]))).
|
list_to_atom("udp_ets:" ++ integer_to_list(erlang:unique_integer([monotonic, positive]))).
|
||||||
|
|
||||||
|
-spec resolver_reply(list(), Resp :: binary()) -> no_return().
|
||||||
|
resolver_reply([{_, Ref, ReceiverPid}], Resp) when is_binary(Resp) ->
|
||||||
|
case is_process_alive(ReceiverPid) of
|
||||||
|
true ->
|
||||||
|
ReceiverPid ! {dns_resolver_reply, Ref, Resp};
|
||||||
|
false ->
|
||||||
|
ok
|
||||||
|
end;
|
||||||
|
resolver_reply(_, _) ->
|
||||||
|
ok.
|
||||||
Loading…
x
Reference in New Issue
Block a user