From 5d41b45d6f41487ad5fc3c263c8fd51454df3532 Mon Sep 17 00:00:00 2001 From: anlicheng <244108715@qq.com> Date: Thu, 4 Dec 2025 14:29:58 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B9=B6=E5=8F=91=E8=AF=B7=E6=B1=82dns?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/dns_proxy/src/dns_handler.erl | 86 +++++++++++------------------ apps/dns_proxy/src/dns_resolver.erl | 64 +++++++++++---------- 2 files changed, 67 insertions(+), 83 deletions(-) diff --git a/apps/dns_proxy/src/dns_handler.erl b/apps/dns_proxy/src/dns_handler.erl index 2bdfacb..90ca8e2 100644 --- a/apps/dns_proxy/src/dns_handler.erl +++ b/apps/dns_proxy/src/dns_handler.erl @@ -12,8 +12,6 @@ -behaviour(gen_server). -include_lib("dns_erlang/include/dns.hrl"). --include_lib("dns_erlang/include/dns_records.hrl"). --include_lib("dns_erlang/include/dns_terms.hrl"). -include("dns_proxy.hrl"). %% API @@ -25,16 +23,12 @@ -define(SERVER, ?MODULE). -define(RESOLVER_POOL, dns_resolver_pool). -%% 转发的超时设置 --define(UPSTREAM_TIMEOUT, 1000). -record(state, { socket, src_ip, src_port, - packet, - question, - dns_servers = [] + packet }). %%%=================================================================== @@ -57,10 +51,7 @@ handle(Pid) when is_pid(Pid) -> {ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} | {stop, Reason :: term()} | ignore). init([Sock, SrcIp, SrcPort, Packet]) -> - {ok, DNSServers} = application:get_env(dns_proxy, public_dns_servers), - %% 进程的最大存活时间 - erlang:start_timer(5000, self(), handler_max_ttl), - {ok, #state{dns_servers = DNSServers, socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}}. + {ok, #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}}. %% @private %% @doc Handling call messages @@ -81,26 +72,36 @@ handle_call(_Request, _From, State = #state{}) -> {noreply, NewState :: #state{}} | {noreply, NewState :: #state{}, timeout() | hibernate} | {stop, Reason :: term(), NewState :: #state{}}). -handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) -> +handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}) -> case dns:decode_message(Packet) of - Msg = #dns_message{qc = 1, questions = [Question|_]} -> - Qname = Question#dns_query.name, - lager:debug("[dns_handler] qname: ~p", [Qname]), + QueryMsg = #dns_message{qc = 1, questions = [Question|_]} -> case dns_cache:lookup(Question) of {hit, Cache} -> - lager:debug("[dns_handler] hit cache: ~p", [Cache]), - Resp = build_response(Msg, Cache), - gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)), + lager:debug("[dns_handler] question: ~p, hit cache: ~p", [Question, Cache]), + RespMsg = build_response(QueryMsg, Cache), + gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(RespMsg)), {stop, normal, State}; miss -> lager:debug("[dns_handler] cache is miss"), - forward_to_upstream(DnsIp, DnsPort, Packet, Msg), - %% 开启定时器,超时后递归请求后面的服务 - erlang:start_timer(?UPSTREAM_TIMEOUT, self(), {trigger_next, Msg}), - {noreply, State#state{dns_servers = RestDnsServers, question = Question}} + Ref = make_ref(), + forward_to_upstream(Ref, Packet, QueryMsg), + receive + {dns_resolver_reply, Ref, Resp} -> + case dns:decode_message(Resp) of + RespMsg = #dns_message{answers = Answers} -> + lager:debug("[dns_handler] get a response answers: ~p", [Answers]), + dns_cache:insert(Question, RespMsg), + gen_udp:send(Sock, SrcIp, SrcPort, Resp); + Other -> + lager:debug("[dns_handler] parse reply get error: ~p", [Other]) + end, + {stop, normal, State} + after 5000 -> + {stop, normal, State} + end end; Other -> - lager:warning("[] decode msg get error: ~p", [Other]), + lager:warning("[dns_handler] decode dns query get error: ~p", [Other]), {stop, normal, State} end. @@ -110,30 +111,8 @@ handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = Src {noreply, NewState :: #state{}} | {noreply, NewState :: #state{}, timeout() | hibernate} | {stop, Reason :: term(), NewState :: #state{}}). -%% 处理超时重试 -handle_info({timeout, _, {trigger_next, Msg}}, State = #state{packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) -> - forward_to_upstream(DnsIp, DnsPort, Packet, Msg), - erlang:start_timer(?UPSTREAM_TIMEOUT, self(), {trigger_next, Msg}), - {noreply, State#state{dns_servers = RestDnsServers}}; -handle_info({timeout, _, {trigger_next, _}}, State = #state{dns_servers = []}) -> - {stop, normal, State}; - -handle_info({timeout, _, handler_max_ttl}, State) -> - lager:debug("[dns_handler] reach the max ttl"), - {stop, normal, State}; - -%% 收到请求 -handle_info({dns_resolver_reply, Resp}, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, question = Question}) -> - %% 尝试解析 - case dns:decode_message(Resp) of - Msg = #dns_message{answers = Answers} -> - lager:debug("[dns_handler] get a resolver reply: ~p, bin: ~p", [Msg, Answers]), - dns_cache:insert(Question, Msg), - gen_udp:send(Sock, SrcIp, SrcPort, Resp); - Other -> - lager:debug("[dns_handler] parse reply get error: ~p", [Other]) - end, - {stop, normal, State}. +handle_info(_Info, State) -> + {noreply, State}. %% @private %% @doc This function is called by a gen_server when it is about to @@ -157,13 +136,13 @@ code_change(_OldVsn, State = #state{}, _Extra) -> %%% Internal functions %%%=================================================================== -forward_to_upstream(TargetIp, TargetPort, Request, Msg) -> +-spec forward_to_upstream(Ref :: reference(), Request :: binary(), QueryMsg :: #dns_message{}) -> no_return(). +forward_to_upstream(Ref, Request, QueryMsg) -> ReceiverPid = self(), - poolboy:transaction(?RESOLVER_POOL, fun(Pid) -> - dns_resolver:forward(Pid, ReceiverPid, TargetIp, TargetPort, Request, Msg) - end). + poolboy:transaction(?RESOLVER_POOL, fun(Pid) -> dns_resolver:forward(Pid, ReceiverPid, Ref, Request, QueryMsg) end). -build_response(Query, #dns_cache{expire_at = ExpireAt, answers = Answers, authority = Authority, additional = Additional, rc = RCode, flags = #{aa := AA}}) -> +-spec build_response(QueryMsg :: #dns_message{}, Dns_cache :: #dns_cache{}) -> RespMsg :: #dns_message{}. +build_response(QueryMsg, #dns_cache{expire_at = ExpireAt, answers = Answers, authority = Authority, additional = Additional, rc = RCode, flags = #{aa := AA}}) -> Now = os:system_time(second), RemainingTTL = ExpireAt - Now, @@ -171,7 +150,7 @@ build_response(Query, #dns_cache{expire_at = ExpireAt, answers = Answers, author Authority2 = [adjust_ttl(RR, RemainingTTL) || RR <- Authority], Additional2 = [adjust_ttl(RR, RemainingTTL) || RR <- Additional], - Query#dns_message{ + QueryMsg#dns_message{ qr = true, ra = true, aa = AA, @@ -184,6 +163,7 @@ build_response(Query, #dns_cache{expire_at = ExpireAt, answers = Answers, author additional = Additional2 }. +-spec adjust_ttl(RR :: any(), RemainingTTL :: integer()) -> any(). adjust_ttl(RR = #dns_rr{}, RemainingTTL) -> RR#dns_rr{ttl = max(0, RemainingTTL)}; adjust_ttl(RR, _RemainingTTL) -> diff --git a/apps/dns_proxy/src/dns_resolver.erl b/apps/dns_proxy/src/dns_resolver.erl index d3ca382..5783ed4 100644 --- a/apps/dns_proxy/src/dns_resolver.erl +++ b/apps/dns_proxy/src/dns_resolver.erl @@ -18,22 +18,23 @@ %% gen_server callbacks -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --export([forward/6]). +-export([forward/5]). -define(SERVER, ?MODULE). -define(REQUEST_TTL, 5000). -record(state, { socket, - tid + tid, + dns_servers = [] }). %%%=================================================================== %%% API %%%=================================================================== -forward(Pid, ReceiverPid, TargetIp, TargetPort, Request, Msg) -> - gen_server:cast(Pid, {forward, ReceiverPid, TargetIp, TargetPort, Request, Msg}). +forward(Pid, ReceiverPid, Ref, Request, QueryMsg) -> + gen_server:cast(Pid, {forward, ReceiverPid, Ref, Request, QueryMsg}). %% @doc Spawns the server and registers the local name (unique) -spec(start_link(Args :: list()) -> @@ -51,11 +52,13 @@ start_link(Args) when is_list(Args) -> {ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} | {stop, Reason :: term()} | ignore). init([]) -> + {ok, DnsServers} = application:get_env(dns_proxy, public_dns_servers), + {ok, Sock} = gen_udp:open(0, [binary, {active, true}]), %% 通过ets来保存映射关系 Tid = ets:new(random_table(), [set, {read_concurrency, true}, {write_concurrency, true}, private]), - {ok, #state{socket = Sock, tid = Tid}}. + {ok, #state{socket = Sock, tid = Tid, dns_servers = DnsServers}}. %% @private %% @doc Handling call messages @@ -76,14 +79,14 @@ handle_call(_Request, _From, State = #state{}) -> {noreply, NewState :: #state{}} | {noreply, NewState :: #state{}, timeout() | hibernate} | {stop, Reason :: term(), NewState :: #state{}}). -handle_cast({forward, ReceiverPid, TargetIp, TargetPort, Request, #dns_message{id = TxId, questions = [#dns_query{name = QName}|_]}}, State = #state{socket = Socket, tid = Tid}) -> - lager:debug("[dns_resolver] forward request to: ~p", [{TargetIp, TargetPort}]), - ok = gen_udp:send(Socket, TargetIp, TargetPort, Request), - - Key = {TxId, TargetIp, TargetPort, QName}, - true = ets:insert(Tid, {Key, ReceiverPid}), - - erlang:start_timer(?REQUEST_TTL, self(), {clean_ticker, Key}), +handle_cast({forward, ReceiverPid, Ref, Request, #dns_message{id = TxId, questions = [#dns_query{name = QName, type = QType, class = QClass}|_]}}, State = #state{socket = Socket, tid = Tid, dns_servers = DnsServers}) -> + Keys = lists:foldl(fun({DnsIp, DnsPort}, Acc) -> + ok = gen_udp:send(Socket, DnsIp, DnsPort, Request), + Key = {TxId, DnsIp, DnsPort, QName, QType, QClass}, + true = ets:insert(Tid, {Key, Ref, ReceiverPid}), + [Key|Acc] + end, [], DnsServers), + erlang:start_timer(?REQUEST_TTL, self(), {clean_ticker, Keys}), {noreply, State}. @@ -95,27 +98,16 @@ handle_cast({forward, ReceiverPid, TargetIp, TargetPort, Request, #dns_message{i {stop, Reason :: term(), NewState :: #state{}}). handle_info({udp, Socket, TargetIp, TargetPort, Resp}, State = #state{tid = Tid, socket = Socket}) -> case dns:decode_message(Resp) of - #dns_message{id = TxId, questions = [#dns_query{name = QName}|_]} -> - Key = {TxId, TargetIp, TargetPort, QName}, - case ets:take(Tid, Key) of - [{_, ReceiverPid}] -> - case is_process_alive(ReceiverPid) of - true -> - ReceiverPid ! {dns_resolver_reply, Resp}; - false -> - ok - end; - [] -> - ok - end; + #dns_message{id = TxId, questions = [#dns_query{name = QName, type = QType, class = QClass}|_]} -> + Key = {TxId, TargetIp, TargetPort, QName, QName, QType, QClass}, + resolver_reply(ets:take(Tid, Key), Resp); _ -> ok end, {noreply, State}; -handle_info({timeout, _, {clean_ticker, Key}}, State = #state{tid = Tid}) -> - true = ets:delete(Tid, Key), - +handle_info({timeout, _, {clean_ticker, Keys}}, State = #state{tid = Tid}) -> + lists:foreach(fun(Key) -> ets:delete(Tid, Key) end, Keys), {noreply, State}. %% @private @@ -140,5 +132,17 @@ code_change(_OldVsn, State = #state{}, _Extra) -> %%% Internal functions %%%=================================================================== +-spec random_table() -> atom(). random_table() -> - list_to_atom("udp_ets:" ++ integer_to_list(erlang:unique_integer([monotonic, positive]))). \ No newline at end of file + list_to_atom("udp_ets:" ++ integer_to_list(erlang:unique_integer([monotonic, positive]))). + +-spec resolver_reply(list(), Resp :: binary()) -> no_return(). +resolver_reply([{_, Ref, ReceiverPid}], Resp) when is_binary(Resp) -> + case is_process_alive(ReceiverPid) of + true -> + ReceiverPid ! {dns_resolver_reply, Ref, Resp}; + false -> + ok + end; +resolver_reply(_, _) -> + ok. \ No newline at end of file