This commit is contained in:
anlicheng 2025-12-03 23:29:15 +08:00
parent 1989c48ca0
commit 19f05de3a6
5 changed files with 174 additions and 46 deletions

View File

@ -13,6 +13,8 @@
-include_lib("dns_erlang/include/dns_records.hrl"). -include_lib("dns_erlang/include/dns_records.hrl").
-include_lib("dns_erlang/include/dns_terms.hrl"). -include_lib("dns_erlang/include/dns_terms.hrl").
-define(RESOLVER_POOL, dns_resolver_pool).
-export([start_link/4, init/4]). -export([start_link/4, init/4]).
start_link(Sock, Ip, Port, Packet) -> start_link(Sock, Ip, Port, Packet) ->
@ -36,11 +38,15 @@ init(Sock, Ip, Port, Packet) ->
exit(normal) exit(normal)
end. end.
forward_to_upstream(Sock, Ip, Port, Packet) -> forward_to_upstream(Sock, SrcIp, SrcPort, Request) ->
{SendSock, {UpIP, UpPort}} = dns_socket_pool:get_socket(), {ok, DNSServers} = application:get_env(dns_proxy, public_dns_servers),
ok = gen_udp:send(SendSock, UpIP, UpPort, Packet),
ReceiverPid = self(),
Ref = make_ref(),
poolboy:transaction(?RESOLVER_POOL, fun(Pid) ->
dns_resolver:forward(Pid, ReceiverPid, Ref, TargetIp, TargetPort, Request)
end),
inet:setopts(SendSock, [{active, once}]),
receive receive
{udp, SendSock, _UIp, _UPort, Resp} -> {udp, SendSock, _UIp, _UPort, Resp} ->
gen_udp:send(Sock, Ip, Port, Resp) gen_udp:send(Sock, Ip, Port, Resp)

View File

@ -0,0 +1,151 @@
%%%-------------------------------------------------------------------
%%% @author anlicheng
%%% @copyright (C) 2025, <COMPANY>
%%% @doc
%%%
%%% @end
%%% Created : 03. 12 2025 23:00
%%%-------------------------------------------------------------------
-module(dns_handler2).
-author("anlicheng").
-behaviour(gen_server).
-include_lib("dns_erlang/include/dns.hrl").
-include_lib("dns_erlang/include/dns_records.hrl").
-include_lib("dns_erlang/include/dns_terms.hrl").
%% API
-export([start_link/4]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-export([handle/1]).
-define(SERVER, ?MODULE).
-define(RESOLVER_POOL, dns_resolver_pool).
%%
-define(UPSTREAM_TIMEOUT, 1000).
-record(state, {
socket,
src_ip,
src_port,
packet,
dns_servers = []
}).
%%%===================================================================
%%% API
%%%===================================================================
start_link(Sock, Ip, Port, Packet) ->
gen_server:start_link(?MODULE, [Sock, Ip, Port, Packet], []).
handle(Pid) when is_pid(Pid) ->
gen_server:cast(Pid, handle).
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
%% @private
%% @doc Initializes the server
-spec(init(Args :: term()) ->
{ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} |
{stop, Reason :: term()} | ignore).
init([Sock, SrcIp, SrcPort, Packet]) ->
{ok, DNSServers} = application:get_env(dns_proxy, public_dns_servers),
{ok, #state{dns_servers = DNSServers, socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet}}.
%% @private
%% @doc Handling call messages
-spec(handle_call(Request :: term(), From :: {pid(), Tag :: term()},
State :: #state{}) ->
{reply, Reply :: term(), NewState :: #state{}} |
{reply, Reply :: term(), NewState :: #state{}, timeout() | hibernate} |
{noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), Reply :: term(), NewState :: #state{}} |
{stop, Reason :: term(), NewState :: #state{}}).
handle_call(_Request, _From, State = #state{}) ->
{reply, ok, State}.
%% @private
%% @doc Handling cast messages
-spec(handle_cast(Request :: term(), State :: #state{}) ->
{noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), NewState :: #state{}}).
handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) ->
case dns:decode_message(Packet) of
Msg = #dns_message{qc = 1, questions = [Question|_]} ->
Qname = Question#dns_query.name,
lager:debug("[dns_handler] qname: ~p", [Qname]),
case dns_cache:lookup(Qname) of
{hit, R} ->
Resp = build_response(Msg, R),
gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)),
{stop, normal};
miss ->
lager:debug("[dns_handler] cache is miss"),
forward_to_upstream(DnsIp, DnsPort, Packet),
%%
erlang:start_timer(?UPSTREAM_TIMEOUT, self(), trigger_next),
{noreply, State#state{dns_servers = RestDnsServers}}
end;
Other ->
lager:warning("[] decode msg get error: ~p", [Other]),
{stop, normal}
end.
%% @private
%% @doc Handling all non call/cast messages
-spec(handle_info(Info :: timeout() | term(), State :: #state{}) ->
{noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), NewState :: #state{}}).
%%
handle_info({timeout, _, trigger_next}, State = #state{packet = Packet, dns_servers = [{DnsIp, DnsPort}|RestDnsServers]}) ->
forward_to_upstream(DnsIp, DnsPort, Packet),
erlang:start_timer(?UPSTREAM_TIMEOUT, self(), trigger_next),
{noreply, State#state{dns_servers = RestDnsServers}};
handle_info({timeout, _, trigger_next}, State = #state{dns_servers = []}) ->
{stop, normal, State};
%%
handle_info({dns_resolver_reply, Resp}, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort}) ->
gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)),
{stop, normal, State}.
%% @private
%% @doc This function is called by a gen_server when it is about to
%% terminate. It should be the opposite of Module:init/1 and do any
%% necessary cleaning up. When it returns, the gen_server terminates
%% with Reason. The return value is ignored.
-spec(terminate(Reason :: (normal | shutdown | {shutdown, term()} | term()),
State :: #state{}) -> term()).
terminate(_Reason, _State = #state{}) ->
ok.
%% @private
%% @doc Convert process state when code is changed
-spec(code_change(OldVsn :: term() | {down, term()}, State :: #state{},
Extra :: term()) ->
{ok, NewState :: #state{}} | {error, Reason :: term()}).
code_change(_OldVsn, State = #state{}, _Extra) ->
{ok, State}.
%%%===================================================================
%%% Internal functions
%%%===================================================================
forward_to_upstream(TargetIp, TargetPort, Request) ->
ReceiverPid = self(),
poolboy:transaction(?RESOLVER_POOL, fun(Pid) ->
dns_resolver:forward(Pid, ReceiverPid, TargetIp, TargetPort, Request)
end).
build_response(Req, RR) ->
Msg = Req,
Msg#dns_message{answers=[RR], qr=true, aa=true}.

View File

@ -18,7 +18,7 @@
%% gen_server callbacks %% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-export([forward/6]). -export([forward/5]).
-define(SERVER, ?MODULE). -define(SERVER, ?MODULE).
@ -31,8 +31,8 @@
%%% API %%% API
%%%=================================================================== %%%===================================================================
forward(Pid, ReceiverPid, Ref, TargetIp, TargetPort, Request) -> forward(Pid, ReceiverPid, TargetIp, TargetPort, Request) ->
gen_server:cast(Pid, {forward, ReceiverPid, Ref, TargetIp, TargetPort, Request}). gen_server:cast(Pid, {forward, ReceiverPid, TargetIp, TargetPort, Request}).
%% @doc Spawns the server and registers the local name (unique) %% @doc Spawns the server and registers the local name (unique)
-spec(start_link(Args :: list()) -> -spec(start_link(Args :: list()) ->
@ -75,11 +75,11 @@ handle_call(_Request, _From, State = #state{}) ->
{noreply, NewState :: #state{}} | {noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} | {noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), NewState :: #state{}}). {stop, Reason :: term(), NewState :: #state{}}).
handle_cast({forward, ReceiverPid, Ref, TargetIp, TargetPort, Request}, State = #state{socket = Socket, tid = Tid}) -> handle_cast({forward, ReceiverPid, TargetIp, TargetPort, Request}, State = #state{socket = Socket, tid = Tid}) ->
case dns:decode_message(Request) of case dns:decode_message(Request) of
#dns_message{id = TxId, questions = [#dns_query{name = QName}|_]} -> #dns_message{id = TxId, questions = [#dns_query{name = QName}|_]} ->
ok = gen_udp:send(Socket, TargetIp, TargetPort, Request), ok = gen_udp:send(Socket, TargetIp, TargetPort, Request),
ok = ets:insert(Tid, {{TxId, TargetIp, TargetPort, QName}, {Ref, ReceiverPid}}); ok = ets:insert(Tid, {{TxId, TargetIp, TargetPort, QName}, ReceiverPid});
_ -> _ ->
ok ok
end, end,
@ -96,8 +96,13 @@ handle_info({udp, Socket, TargetIp, TargetPort, Resp}, State = #state{tid = Tid,
#dns_message{id = TxId, questions = [#dns_query{name = QName}|_]} -> #dns_message{id = TxId, questions = [#dns_query{name = QName}|_]} ->
Key = {TxId, TargetIp, TargetPort, QName}, Key = {TxId, TargetIp, TargetPort, QName},
case ets:take(Tid, Key) of case ets:take(Tid, Key) of
[{_, {Ref, ReceiverPid}}] -> [{_, ReceiverPid}] ->
ReceiverPid ! {xyz, Ref, Resp}; case is_process_alive(ReceiverPid) of
true ->
ReceiverPid ! {dns_resolver_reply, Resp};
false ->
ok
end;
[] -> [] ->
ok ok
end; end;

View File

@ -9,14 +9,6 @@ start_link() ->
init() -> init() ->
dns_cache:init(), dns_cache:init(),
%dns_zone_loader:load("priv/local.zone"), %dns_zone_loader:load("priv/local.zone"),
%% DNS
Upstreams = [
{{8,8,8,8}, 53},
{{1,1,1,1}, 53}
],
dns_socket_pool:start_link(Upstreams),
{ok, Sock} = gen_udp:open(?LISTEN_PORT, [binary, {active, true}]), {ok, Sock} = gen_udp:open(?LISTEN_PORT, [binary, {active, true}]),
io:format("DNS Forwarder started on UDP port ~p~n", [?LISTEN_PORT]), io:format("DNS Forwarder started on UDP port ~p~n", [?LISTEN_PORT]),
loop(Sock). loop(Sock).
@ -25,6 +17,6 @@ loop(Sock) ->
receive receive
{udp, Sock, Ip, Port, Packet} -> {udp, Sock, Ip, Port, Packet} ->
Res = dns_handler_sup:start_handler(Sock, Ip, Port, Packet), Res = dns_handler_sup:start_handler(Sock, Ip, Port, Packet),
lager:debug("ip: ~p, get a packet: ~p, handler res: ~p", [{Ip, Port}, Packet, Res]), lager:debug("[dns_server] ip: ~p, get a packet: ~p, handler res: ~p", [{Ip, Port}, Packet, Res]),
loop(Sock) loop(Sock)
end. end.

View File

@ -1,26 +0,0 @@
-module(dns_socket_pool).
-behaviour(gen_server).
-export([start_link/1, get_socket/0]).
-export([init/1, handle_call/3]).
-record(state, {sockets=[]}).
start_link(Ups) ->
gen_server:start_link({local, ?MODULE}, ?MODULE, Ups, []).
get_socket() ->
gen_server:call(?MODULE, get).
init(Upstreams) ->
%% DNS socket
Sockets = lists:map(fun({IP, Port}) ->
{ok, Sock} = gen_udp:open(0, [binary, {active, false}]),
{Sock, {IP, Port}}
end, Upstreams),
{ok, #state{sockets=Sockets}}.
handle_call(get, _From, State=#state{sockets=Socks}) ->
%% round-robin
[H|T] = Socks,
{reply, H, State#state{sockets=T ++ [H]}}.