diff --git a/apps/dns_proxy/src/dns_handler.erl b/apps/dns_proxy/src/dns_handler.erl index a6696cf..3d32395 100644 --- a/apps/dns_proxy/src/dns_handler.erl +++ b/apps/dns_proxy/src/dns_handler.erl @@ -4,23 +4,82 @@ %%% @doc %%% %%% @end -%%% Created : 03. 12月 2025 17:27 +%%% Created : 03. 12月 2025 23:00 %%%------------------------------------------------------------------- -module(dns_handler). -author("anlicheng"). +-behaviour(gen_server). + -include_lib("dns_erlang/include/dns.hrl"). -include_lib("dns_erlang/include/dns_records.hrl"). -include_lib("dns_erlang/include/dns_terms.hrl"). --define(RESOLVER_POOL, dns_resolver_pool). +%% API +-export([start_link/4]). --export([start_link/4, init/4]). +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). +-export([handle/1]). + +-define(SERVER, ?MODULE). +-define(RESOLVER_POOL, dns_resolver_pool). +%% 转发的超时设置 +-define(UPSTREAM_TIMEOUT, 1000). + +-record(state, { + socket, + src_ip, + src_port, + packet, + dns_servers = [] +}). + +%%%=================================================================== +%%% API +%%%=================================================================== start_link(Sock, Ip, Port, Packet) -> - {ok, proc_lib:spawn(?MODULE, init, [Sock, Ip, Port, Packet])}. + gen_server:start_link(?MODULE, [Sock, Ip, Port, Packet], []). -init(Sock, Ip, Port, Packet) -> +handle(Pid) when is_pid(Pid) -> + gen_server:cast(Pid, handle). + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== + +%% @private +%% @doc Initializes the server +-spec(init(Args :: term()) -> + {ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} | + {stop, Reason :: term()} | ignore). +init([Sock, SrcIp, SrcPort, Packet]) -> + {ok, DNSServers} = application:get_env(dns_proxy, public_dns_servers), + %% 进程的最大存活时间 + erlang:start_timer(5000, self(), handler_max_ttl), + {ok, #state{dns_servers = DNSServers, socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}}. + +%% @private +%% @doc Handling call messages +-spec(handle_call(Request :: term(), From :: {pid(), Tag :: term()}, + State :: #state{}) -> + {reply, Reply :: term(), NewState :: #state{}} | + {reply, Reply :: term(), NewState :: #state{}, timeout() | hibernate} | + {noreply, NewState :: #state{}} | + {noreply, NewState :: #state{}, timeout() | hibernate} | + {stop, Reason :: term(), Reply :: term(), NewState :: #state{}} | + {stop, Reason :: term(), NewState :: #state{}}). +handle_call(_Request, _From, State = #state{}) -> + {reply, ok, State}. + +%% @private +%% @doc Handling cast messages +-spec(handle_cast(Request :: term(), State :: #state{}) -> + {noreply, NewState :: #state{}} | + {noreply, NewState :: #state{}, timeout() | hibernate} | + {stop, Reason :: term(), NewState :: #state{}}). +handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) -> case dns:decode_message(Packet) of Msg = #dns_message{qc = 1, questions = [Question|_]} -> Qname = Question#dns_query.name, @@ -28,31 +87,70 @@ init(Sock, Ip, Port, Packet) -> case dns_cache:lookup(Qname) of {hit, R} -> Resp = build_response(Msg, R), - gen_udp:send(Sock, Ip, Port, dns:encode_message(Resp)); + gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)), + {stop, normal}; miss -> lager:debug("[dns_handler] cache is miss"), - forward_to_upstream(Sock, Ip, Port, Packet) + forward_to_upstream(DnsIp, DnsPort, Packet, Msg), + %% 开启定时器,超时后递归请求后面的服务 + erlang:start_timer(?UPSTREAM_TIMEOUT, self(), trigger_next), + {noreply, State#state{dns_servers = RestDnsServers}} end; Other -> - lager:warning("decode msg get error: ~p", [Other]), - exit(normal) + lager:warning("[] decode msg get error: ~p", [Other]), + {stop, normal} end. -forward_to_upstream(Sock, SrcIp, SrcPort, Request) -> - {ok, DNSServers} = application:get_env(dns_proxy, public_dns_servers), +%% @private +%% @doc Handling all non call/cast messages +-spec(handle_info(Info :: timeout() | term(), State :: #state{}) -> + {noreply, NewState :: #state{}} | + {noreply, NewState :: #state{}, timeout() | hibernate} | + {stop, Reason :: term(), NewState :: #state{}}). +%% 处理超时重试 +handle_info({timeout, _, trigger_next}, State = #state{packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) -> + forward_to_upstream(DnsIp, DnsPort, Packet), + erlang:start_timer(?UPSTREAM_TIMEOUT, self(), trigger_next), + {noreply, State#state{dns_servers = RestDnsServers}}; +handle_info({timeout, _, trigger_next}, State = #state{dns_servers = []}) -> + {stop, normal, State}; +handle_info({timeout, _, handler_max_ttl}, State) -> + lager:debug("[dns_handler] reach the max ttl"), + {stop, normal, State}; + +%% 收到请求 +handle_info({dns_resolver_reply, Resp}, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort}) -> + gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)), + {stop, normal, State}. + +%% @private +%% @doc This function is called by a gen_server when it is about to +%% terminate. It should be the opposite of Module:init/1 and do any +%% necessary cleaning up. When it returns, the gen_server terminates +%% with Reason. The return value is ignored. +-spec(terminate(Reason :: (normal | shutdown | {shutdown, term()} | term()), + State :: #state{}) -> term()). +terminate(_Reason, _State = #state{}) -> + ok. + +%% @private +%% @doc Convert process state when code is changed +-spec(code_change(OldVsn :: term() | {down, term()}, State :: #state{}, + Extra :: term()) -> + {ok, NewState :: #state{}} | {error, Reason :: term()}). +code_change(_OldVsn, State = #state{}, _Extra) -> + {ok, State}. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== + +forward_to_upstream(TargetIp, TargetPort, Request, Msg) -> ReceiverPid = self(), - Ref = make_ref(), poolboy:transaction(?RESOLVER_POOL, fun(Pid) -> - dns_resolver:forward(Pid, ReceiverPid, Ref, TargetIp, TargetPort, Request) - end), - - receive - {udp, SendSock, _UIp, _UPort, Resp} -> - gen_udp:send(Sock, Ip, Port, Resp) - after 2000 -> - ok - end. + dns_resolver:forward(Pid, ReceiverPid, TargetIp, TargetPort, Request, Msg) + end). build_response(Req, RR) -> Msg = Req, diff --git a/apps/dns_proxy/src/dns_handler2.erl b/apps/dns_proxy/src/dns_handler2.erl deleted file mode 100644 index 64a9631..0000000 --- a/apps/dns_proxy/src/dns_handler2.erl +++ /dev/null @@ -1,151 +0,0 @@ -%%%------------------------------------------------------------------- -%%% @author anlicheng -%%% @copyright (C) 2025, -%%% @doc -%%% -%%% @end -%%% Created : 03. 12月 2025 23:00 -%%%------------------------------------------------------------------- --module(dns_handler2). --author("anlicheng"). - --behaviour(gen_server). - --include_lib("dns_erlang/include/dns.hrl"). --include_lib("dns_erlang/include/dns_records.hrl"). --include_lib("dns_erlang/include/dns_terms.hrl"). - -%% API --export([start_link/4]). - -%% gen_server callbacks --export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --export([handle/1]). - --define(SERVER, ?MODULE). --define(RESOLVER_POOL, dns_resolver_pool). -%% 转发的超时设置 --define(UPSTREAM_TIMEOUT, 1000). - --record(state, { - socket, - src_ip, - src_port, - packet, - dns_servers = [] -}). - -%%%=================================================================== -%%% API -%%%=================================================================== - -start_link(Sock, Ip, Port, Packet) -> - gen_server:start_link(?MODULE, [Sock, Ip, Port, Packet], []). - -handle(Pid) when is_pid(Pid) -> - gen_server:cast(Pid, handle). - -%%%=================================================================== -%%% gen_server callbacks -%%%=================================================================== - -%% @private -%% @doc Initializes the server --spec(init(Args :: term()) -> - {ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} | - {stop, Reason :: term()} | ignore). -init([Sock, SrcIp, SrcPort, Packet]) -> - {ok, DNSServers} = application:get_env(dns_proxy, public_dns_servers), - {ok, #state{dns_servers = DNSServers, socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}}. - -%% @private -%% @doc Handling call messages --spec(handle_call(Request :: term(), From :: {pid(), Tag :: term()}, - State :: #state{}) -> - {reply, Reply :: term(), NewState :: #state{}} | - {reply, Reply :: term(), NewState :: #state{}, timeout() | hibernate} | - {noreply, NewState :: #state{}} | - {noreply, NewState :: #state{}, timeout() | hibernate} | - {stop, Reason :: term(), Reply :: term(), NewState :: #state{}} | - {stop, Reason :: term(), NewState :: #state{}}). -handle_call(_Request, _From, State = #state{}) -> - {reply, ok, State}. - -%% @private -%% @doc Handling cast messages --spec(handle_cast(Request :: term(), State :: #state{}) -> - {noreply, NewState :: #state{}} | - {noreply, NewState :: #state{}, timeout() | hibernate} | - {stop, Reason :: term(), NewState :: #state{}}). -handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) -> - case dns:decode_message(Packet) of - Msg = #dns_message{qc = 1, questions = [Question|_]} -> - Qname = Question#dns_query.name, - lager:debug("[dns_handler] qname: ~p", [Qname]), - case dns_cache:lookup(Qname) of - {hit, R} -> - Resp = build_response(Msg, R), - gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)), - {stop, normal}; - miss -> - lager:debug("[dns_handler] cache is miss"), - forward_to_upstream(DnsIp, DnsPort, Packet), - %% 开启定时器,超时后递归请求后面的服务 - erlang:start_timer(?UPSTREAM_TIMEOUT, self(), trigger_next), - {noreply, State#state{dns_servers = RestDnsServers}} - end; - Other -> - lager:warning("[] decode msg get error: ~p", [Other]), - {stop, normal} - end. - -%% @private -%% @doc Handling all non call/cast messages --spec(handle_info(Info :: timeout() | term(), State :: #state{}) -> - {noreply, NewState :: #state{}} | - {noreply, NewState :: #state{}, timeout() | hibernate} | - {stop, Reason :: term(), NewState :: #state{}}). -%% 处理超时重试 -handle_info({timeout, _, trigger_next}, State = #state{packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) -> - forward_to_upstream(DnsIp, DnsPort, Packet), - erlang:start_timer(?UPSTREAM_TIMEOUT, self(), trigger_next), - {noreply, State#state{dns_servers = RestDnsServers}}; -handle_info({timeout, _, trigger_next}, State = #state{dns_servers = []}) -> - {stop, normal, State}; - -%% 收到请求 -handle_info({dns_resolver_reply, Resp}, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort}) -> - gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)), - {stop, normal, State}. - -%% @private -%% @doc This function is called by a gen_server when it is about to -%% terminate. It should be the opposite of Module:init/1 and do any -%% necessary cleaning up. When it returns, the gen_server terminates -%% with Reason. The return value is ignored. --spec(terminate(Reason :: (normal | shutdown | {shutdown, term()} | term()), - State :: #state{}) -> term()). -terminate(_Reason, _State = #state{}) -> - ok. - -%% @private -%% @doc Convert process state when code is changed --spec(code_change(OldVsn :: term() | {down, term()}, State :: #state{}, - Extra :: term()) -> - {ok, NewState :: #state{}} | {error, Reason :: term()}). -code_change(_OldVsn, State = #state{}, _Extra) -> - {ok, State}. - -%%%=================================================================== -%%% Internal functions -%%%=================================================================== - -forward_to_upstream(TargetIp, TargetPort, Request) -> - ReceiverPid = self(), - poolboy:transaction(?RESOLVER_POOL, fun(Pid) -> - dns_resolver:forward(Pid, ReceiverPid, TargetIp, TargetPort, Request) - end). - -build_response(Req, RR) -> - Msg = Req, - Msg#dns_message{answers=[RR], qr=true, aa=true}. \ No newline at end of file diff --git a/apps/dns_proxy/src/dns_resolver.erl b/apps/dns_proxy/src/dns_resolver.erl index a776de5..99239d7 100644 --- a/apps/dns_proxy/src/dns_resolver.erl +++ b/apps/dns_proxy/src/dns_resolver.erl @@ -18,9 +18,10 @@ %% gen_server callbacks -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --export([forward/5]). +-export([forward/6]). -define(SERVER, ?MODULE). +-define(REQUEST_TTL, 5000). -record(state, { socket, @@ -31,8 +32,8 @@ %%% API %%%=================================================================== -forward(Pid, ReceiverPid, TargetIp, TargetPort, Request) -> - gen_server:cast(Pid, {forward, ReceiverPid, TargetIp, TargetPort, Request}). +forward(Pid, ReceiverPid, TargetIp, TargetPort, Request, Msg) -> + gen_server:cast(Pid, {forward, ReceiverPid, TargetIp, TargetPort, Request, Msg}). %% @doc Spawns the server and registers the local name (unique) -spec(start_link(Args :: list()) -> @@ -75,14 +76,14 @@ handle_call(_Request, _From, State = #state{}) -> {noreply, NewState :: #state{}} | {noreply, NewState :: #state{}, timeout() | hibernate} | {stop, Reason :: term(), NewState :: #state{}}). -handle_cast({forward, ReceiverPid, TargetIp, TargetPort, Request}, State = #state{socket = Socket, tid = Tid}) -> - case dns:decode_message(Request) of - #dns_message{id = TxId, questions = [#dns_query{name = QName}|_]} -> - ok = gen_udp:send(Socket, TargetIp, TargetPort, Request), - ok = ets:insert(Tid, {{TxId, TargetIp, TargetPort, QName}, ReceiverPid}); - _ -> - ok - end, +handle_cast({forward, ReceiverPid, TargetIp, TargetPort, Request, #dns_message{id = TxId, questions = [#dns_query{name = QName}|_]}}, State = #state{socket = Socket, tid = Tid}) -> + ok = gen_udp:send(Socket, TargetIp, TargetPort, Request), + + Key = {TxId, TargetIp, TargetPort, QName}, + ok = ets:insert(Tid, {Key, ReceiverPid}), + + erlang:start_timer(?REQUEST_TTL, self(), {clean_ticker, Key}), + {noreply, State}. %% @private @@ -109,6 +110,11 @@ handle_info({udp, Socket, TargetIp, TargetPort, Resp}, State = #state{tid = Tid, _ -> ok end, + {noreply, State}; + +handle_info({timeout, _, {clean_ticker, Key}}, State = #state{tid = Tid}) -> + true = ets:delete(Tid, Key), + {noreply, State}. %% @private diff --git a/apps/dns_proxy/src/dns_server.erl b/apps/dns_proxy/src/dns_server.erl index b144776..d231d11 100644 --- a/apps/dns_proxy/src/dns_server.erl +++ b/apps/dns_proxy/src/dns_server.erl @@ -16,7 +16,12 @@ init() -> loop(Sock) -> receive {udp, Sock, Ip, Port, Packet} -> - Res = dns_handler_sup:start_handler(Sock, Ip, Port, Packet), - lager:debug("[dns_server] ip: ~p, get a packet: ~p, handler res: ~p", [{Ip, Port}, Packet, Res]), + lager:debug("[dns_server] ip: ~p, get a packet: ~p", [{Ip, Port}, Packet]), + case dns_handler_sup:start_handler(Sock, Ip, Port, Packet) of + {ok, HandlerPid} -> + dns_handler:handle(HandlerPid); + Error -> + lager:debug("[dns_server] start handler get error: ~p", [Error]) + end, loop(Sock) end. \ No newline at end of file