模块化,方便作为lib库

This commit is contained in:
anlicheng 2025-12-13 15:51:41 +08:00
parent 323c4c199d
commit f8882894f5
7 changed files with 124 additions and 76 deletions

View File

@ -41,9 +41,7 @@ insert(#dns_query{name = Qname, type = QType, class = QClass},
true -> true ->
TTL = lists:min(TTLs), TTL = lists:min(TTLs),
ExpireAt = os:system_time(second) + TTL, ExpireAt = os:system_time(second) + TTL,
lager:debug("min ttl is: ~p, expire_at: ~p", [TTL, ExpireAt]), lager:debug("min ttl is: ~p, expire_at: ~p", [TTL, ExpireAt]),
Key = {Qname, QType, QClass}, Key = {Qname, QType, QClass},
Cache = #dns_cache{ Cache = #dns_cache{
key = Key, key = Key,

View File

@ -20,7 +20,7 @@
%% gen_server callbacks %% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).
-export([handle/5, handle_ip_packet/5]). -export([handle_ip_packet/6]).
-define(SERVER, ?MODULE). -define(SERVER, ?MODULE).
-define(RESOLVER_POOL, dns_resolver_pool). -define(RESOLVER_POOL, dns_resolver_pool).
@ -38,11 +38,8 @@
start_link() -> start_link() ->
gen_server:start_link(?MODULE, [], []). gen_server:start_link(?MODULE, [], []).
handle_ip_packet(Pid, Sock, SrcIp, SrcPort, Packet) when is_pid(Pid) -> handle_ip_packet(Pid, Sock, SrcIp, SrcPort, Packet, InbuiltResolver) when is_pid(Pid) ->
gen_server:cast(Pid, {handle_ip_packet, Sock, SrcIp, SrcPort, Packet}). gen_server:cast(Pid, {handle_ip_packet, Sock, SrcIp, SrcPort, Packet, InbuiltResolver}).
handle(Pid, Sock, SrcIp, SrcPort, Packet) when is_pid(Pid) ->
gen_server:cast(Pid, {handle, Sock, SrcIp, SrcPort, Packet}).
%%%=================================================================== %%%===================================================================
%%% gen_server callbacks %%% gen_server callbacks
@ -75,14 +72,15 @@ handle_call(_Request, _From, State = #state{}) ->
{noreply, NewState :: #state{}} | {noreply, NewState :: #state{}} |
{noreply, NewState :: #state{}, timeout() | hibernate} | {noreply, NewState :: #state{}, timeout() | hibernate} |
{stop, Reason :: term(), NewState :: #state{}}). {stop, Reason :: term(), NewState :: #state{}}).
handle_cast({handle_ip_packet, Sock, SrcIp, SrcPort, IpPacket}, State) -> handle_cast({handle_ip_packet, Sock, SrcIp, SrcPort, IpPacket, InbuiltResolver}, State) ->
{#ipv4{saddr = ReqSAddr, daddr = ReqDAddr, p = Protocol}, ReqIpPayload} = pkt:ipv4(IpPacket), {#ipv4{saddr = ReqSAddr, daddr = ReqDAddr, p = Protocol}, ReqIpPayload} = pkt:ipv4(IpPacket),
case Protocol =:= ?UDP_PROTOCOL of case Protocol =:= ?UDP_PROTOCOL of
true -> true ->
{#udp{sport = ReqSPort, dport = ReqDPort}, UdpPayload} = pkt:udp(ReqIpPayload), {#udp{sport = ReqSPort, dport = ReqDPort}, UdpPayload} = pkt:udp(ReqIpPayload),
case resolver(UdpPayload) of case resolver(UdpPayload, InbuiltResolver) of
{ok, DnsResp} -> {ok, DnsResp} ->
RespIpPacket = build_ip_packet(ReqDAddr, ReqSAddr, ReqDPort, ReqSPort, DnsResp), RespIpPacket = build_ip_packet(ReqDAddr, ReqSAddr, ReqDPort, ReqSPort, DnsResp),
lager:debug("[dns_handler] ip packet: ~p", [RespIpPacket]),
gen_udp:send(Sock, SrcIp, SrcPort, RespIpPacket); gen_udp:send(Sock, SrcIp, SrcPort, RespIpPacket);
{error, Reason} -> {error, Reason} ->
lager:debug("[dns_handler] resolver get error: ~p", [Reason]) lager:debug("[dns_handler] resolver get error: ~p", [Reason])
@ -90,14 +88,6 @@ handle_cast({handle_ip_packet, Sock, SrcIp, SrcPort, IpPacket}, State) ->
false -> false ->
lager:debug("[dns_handler] resolver invalid protocol: ~p", [Protocol]) lager:debug("[dns_handler] resolver invalid protocol: ~p", [Protocol])
end, end,
{stop, normal, State};
handle_cast({handle, Sock, SrcIp, SrcPort, Packet}, State) ->
case resolver(Packet) of
{ok, Resp} ->
gen_udp:send(Sock, SrcIp, SrcPort, Resp);
{error, Reason} ->
lager:debug("[dns_handler] resolver get error: ~p", [Reason])
end,
{stop, normal, State}. {stop, normal, State}.
%% @private %% @private
@ -131,12 +121,12 @@ code_change(_OldVsn, State = #state{}, _Extra) ->
%%% Internal functions %%% Internal functions
%%%=================================================================== %%%===================================================================
-spec resolver(Packet :: binary()) -> {ok, Resp :: binary()} | {error, Reason :: any()}. -spec resolver(Packet :: binary(), InbuiltResolver :: tuple()) -> {ok, Resp :: binary()} | {error, Reason :: any()}.
resolver(Packet) when is_binary(Packet) -> resolver(Packet, InbuiltResolver) when is_binary(Packet) ->
resolver0(Packet, dns:decode_message(Packet)). resolver0(Packet, dns:decode_message(Packet), InbuiltResolver).
resolver0(Packet, QueryMsg = #dns_message{qc = 1, questions = [Question = #dns_query{name = QName, type = QType, class = QClass}|_]}) -> resolver0(Packet, QueryMsg = #dns_message{qc = 1, questions = [Question = #dns_query{name = QName, type = QType, class = QClass}|_]}, {M, F, A}) ->
%% %%
case search_inbuilt_domain(QName) of case erlang:apply(M, F, A ++ [QName]) of
{ok, Ip} -> {ok, Ip} ->
Answer = #dns_rr { Answer = #dns_rr {
name = QName, name = QName,
@ -157,6 +147,7 @@ resolver0(Packet, QueryMsg = #dns_message{qc = 1, questions = [Question = #dns_q
authority = [], authority = [],
additional = [] additional = []
}, },
lager:debug("[dns_handler] inbuilt qnanme: ~p, ip: ~p", [QName, Ip]),
{ok, dns:encode_message(RespMsg)}; {ok, dns:encode_message(RespMsg)};
error -> error ->
case dns_cache:lookup(Question) of case dns_cache:lookup(Question) of
@ -184,7 +175,7 @@ resolver0(Packet, QueryMsg = #dns_message{qc = 1, questions = [Question = #dns_q
end end
end end
end; end;
resolver0(_, Error) -> resolver0(_, Error, _InbuiltResolver) ->
lager:warning("[dns_handler] decode dns query get error: ~p", [Error]), lager:warning("[dns_handler] decode dns query get error: ~p", [Error]),
{error, Error}. {error, Error}.
@ -221,25 +212,16 @@ adjust_ttl(RR = #dns_rr{}, RemainingTTL) ->
adjust_ttl(RR, _RemainingTTL) -> adjust_ttl(RR, _RemainingTTL) ->
RR. RR.
search_inbuilt_domain(QName) when is_binary(QName) ->
Suffix = <<".iot.cn">>,
case dns_utils:ends_with(QName, Suffix) of
true ->
{ok, {192, 168, 1, 101}};
false ->
error
end.
-spec build_ip_packet(SAddr :: inet:ip4_address(), DAddr :: inet:ip4_address(), SPort :: integer(), DPort :: integer(), Payload :: binary()) -> IpPacket :: binary(). -spec build_ip_packet(SAddr :: inet:ip4_address(), DAddr :: inet:ip4_address(), SPort :: integer(), DPort :: integer(), Payload :: binary()) -> IpPacket :: binary().
build_ip_packet(SAddr, DAddr, SPort, DPort, Payload) when is_integer(SPort), is_integer(DPort), is_binary(Payload) -> build_ip_packet(SAddr, DAddr, SPort, DPort, UdpPayload) when is_integer(SPort), is_integer(DPort), is_binary(UdpPayload) ->
ULen = 8 + byte_size(Payload), ULen = 8 + byte_size(UdpPayload),
RespUdpHeader = pkt:udp(#udp{ RespUdpHeader = pkt:udp(#udp{
sport = SPort, sport = SPort,
dport = DPort, dport = DPort,
ulen = ULen, ulen = ULen,
sum = dns_utils:udp_checksum(SAddr, DAddr, SPort, DPort, Payload) sum = dns_utils:udp_checksum(SAddr, DAddr, SPort, DPort, UdpPayload)
}), }),
IpPayload = <<RespUdpHeader/binary, Payload/binary>>, IpPayload = <<RespUdpHeader/binary, UdpPayload/binary>>,
IpPacket0 = #ipv4{ IpPacket0 = #ipv4{
len = 20 + ULen, len = 20 + ULen,
@ -247,6 +229,7 @@ build_ip_packet(SAddr, DAddr, SPort, DPort, Payload) when is_integer(SPort), is_
off = 0, off = 0,
mf = 0, mf = 0,
sum = 0, sum = 0,
p = ?UDP_PROTOCOL,
saddr = SAddr, saddr = SAddr,
daddr = DAddr, daddr = DAddr,
opt = <<>> opt = <<>>

View File

@ -0,0 +1,19 @@
%%%-------------------------------------------------------------------
%%% @author anlicheng
%%% @copyright (C) 2025, <COMPANY>
%%% @doc
%%%
%%% @end
%%% Created : 13. 12 2025 15:20
%%%-------------------------------------------------------------------
-module(dns_proxy).
-author("anlicheng").
%% API
-export([start_proxy/2]).
start_proxy(Port, InbuiltResolver = {_M, _F, _A}) when is_integer(Port) ->
{ok, _} = dns_proxy_sup:start_resolver_pool(),
{ok, _} = dns_proxy_sup:start_handler_sup(),
{ok, _} = dns_proxy_sup:start_dns_server(Port, InbuiltResolver),
ok.

View File

@ -8,8 +8,8 @@
-behaviour(supervisor). -behaviour(supervisor).
-export([start_link/0]). -export([start_link/0]).
-export([init/1]). -export([init/1]).
-export([start_resolver_pool/0, start_handler_sup/0, start_dns_server/2]).
-define(SERVER, ?MODULE). -define(SERVER, ?MODULE).
@ -27,12 +27,24 @@ start_link() ->
%% modules => modules()} % optional %% modules => modules()} % optional
init([]) -> init([]) ->
SupFlags = #{strategy => one_for_one, intensity => 1000, period => 3600}, SupFlags = #{strategy => one_for_one, intensity => 1000, period => 3600},
{ok, {SupFlags, []}}.
%% internal functions
start_resolver_pool() ->
{ok, PoolArgs} = application:get_env(dns_proxy, dns_resolver_pool), {ok, PoolArgs} = application:get_env(dns_proxy, dns_resolver_pool),
ResolverPoolSpec = poolboy:child_spec(dns_resolver_pool, [{name, {local, dns_resolver_pool}}|PoolArgs], []), PoolSpec = poolboy:child_spec(dns_resolver_pool, [{name, {local, dns_resolver_pool}}|PoolArgs], []),
case supervisor:start_child(?MODULE, PoolSpec) of
{ok, Pid} ->
{ok, Pid};
{error, {already_started, Pid}} ->
{ok, Pid};
StartError ->
StartError
end.
ChildSpecs = [ start_handler_sup() ->
#{ Spec = #{
id => dns_handler_sup, id => dns_handler_sup,
start => {dns_handler_sup, start_link, []}, start => {dns_handler_sup, start_link, []},
restart => permanent, restart => permanent,
@ -40,15 +52,29 @@ init([]) ->
type => supervisor, type => supervisor,
modules => ['dns_handler_sup'] modules => ['dns_handler_sup']
}, },
#{ case supervisor:start_child(?MODULE, Spec) of
{ok, Pid} ->
{ok, Pid};
{error, {already_started, Pid}} ->
{ok, Pid};
StartError ->
StartError
end.
start_dns_server(Port, InbuiltResolver) when is_integer(Port) ->
Spec = #{
id => dns_server, id => dns_server,
start => {dns_server, start_link, []}, start => {dns_server, start_link, [Port, InbuiltResolver]},
restart => permanent, restart => permanent,
shutdown => 2000, shutdown => 2000,
type => worker, type => worker,
modules => ['dns_server'] modules => ['dns_server']
} },
], case supervisor:start_child(?MODULE, Spec) of
{ok, {SupFlags, [ResolverPoolSpec|ChildSpecs]}}. {ok, Pid} ->
{ok, Pid};
%% internal functions {error, {already_started, Pid}} ->
{ok, Pid};
StartError ->
StartError
end.

View File

@ -1,27 +1,24 @@
-module(dns_server). -module(dns_server).
-export([start_link/0, init/0]). -export([start_link/2, init/2]).
-define(LISTEN_PORT, 15353). start_link(Port, InbuiltResolver) when is_integer(Port) ->
{ok, spawn_link(?MODULE, init, [Port, InbuiltResolver])}.
start_link() -> init(Port, InbuiltResolver) ->
{ok, spawn_link(?MODULE, init, [])}.
init() ->
dns_cache:init(), dns_cache:init(),
%dns_zone_loader:load("priv/local.zone"), {ok, Sock} = gen_udp:open(Port, [binary, {active, true}]),
{ok, Sock} = gen_udp:open(?LISTEN_PORT, [binary, {active, true}]), io:format("DNS Forwarder started on UDP port ~p~n", [Port]),
io:format("DNS Forwarder started on UDP port ~p~n", [?LISTEN_PORT]), loop(Sock, InbuiltResolver).
loop(Sock).
loop(Sock) -> loop(Sock, InbuiltResolver) ->
receive receive
{udp, Sock, Ip, Port, Packet} -> {udp, Sock, Ip, Port, Packet} ->
lager:debug("[dns_server] ip: ~p, get a packet: ~p", [{Ip, Port}, Packet]), lager:debug("[dns_server] ip: ~p, get a packet: ~p", [{Ip, Port}, Packet]),
case dns_handler_sup:start_handler() of case dns_handler_sup:start_handler() of
{ok, HandlerPid} -> {ok, HandlerPid} ->
dns_handler:handle_ip_packet(HandlerPid, Sock, Ip, Port, Packet); dns_handler:handle_ip_packet(HandlerPid, Sock, Ip, Port, Packet, InbuiltResolver);
Error -> Error ->
lager:debug("[dns_server] start handler get error: ~p", [Error]) lager:debug("[dns_server] start handler get error: ~p", [Error])
end, end,
loop(Sock) loop(Sock, InbuiltResolver)
end. end.

View File

@ -113,7 +113,8 @@ ip_checksum(#ipv4{hl = HL, tos = ToS, len = Len,
end. end.
test() -> test() ->
Bin = <<69,0,0,77,48,179,0,0,64,17,28,168,100,123,0,2,100,100,100,100,252,230,0,53,0,57,6,92,152,24,1,0,0,1,0,0,0,0,0,0,2,100,98,7,95,100,110,115,45,115,100,4,95,117,100,112,8,112,117,110,99,104,110,101,116,2,116,115,3,110,101,116,0,0,12,0,1>>, %Bin = <<69,0,0,77,48,179,0,0,64,17,28,168,100,123,0,2,100,100,100,100,252,230,0,53,0,57,6,92,152,24,1,0,0,1,0,0,0,0,0,0,2,100,98,7,95,100,110,115,45,115,100,4,95,117,100,112,8,112,117,110,99,104,110,101,116,2,116,115,3,110,101,116,0,0,12,0,1>>,
Bin = <<69,0,0,93,0,0,0,0,64,6,77,86,100,100,100,100,100,123,0,2,0,53,196,102,0,73,39,7,215,192,129,128,0,1,0,1,0,0,0,0,2,108,98,7,95,100,110,115,45,115,100,4,95,117,100,112,8,112,117,110,99,104,110,101,116,2,116,115,3,110,101,116,0,0,12,0,1,192,12,0,12,0,1,0,0,1,44,0,4,192,168,1,101>>,
{IPPacket = #ipv4{ {IPPacket = #ipv4{
saddr = SAddr, saddr = SAddr,
@ -125,6 +126,6 @@ test() ->
X = udp_checksum(SAddr, DAddr, SPort, DPort, UDPPayload), X = udp_checksum(SAddr, DAddr, SPort, DPort, UDPPayload),
lager:debug("ip_sum: ~p, y?: ~p, udp: ~p, checkSum: ~p, X is: ~p", [IpSum, ip_checksum(IPPacket), UDP, CheckSum, X]), lager:debug("ip_sum: ~p, =: ~p, udp: ~p, checkSum: ~p, =: ~p", [IpSum, ip_checksum(IPPacket), UDP, CheckSum, X]),
dns:decode_message(UDPPayload). dns:decode_message(UDPPayload).

View File

@ -0,0 +1,24 @@
%%%-------------------------------------------------------------------
%%% @author anlicheng
%%% @copyright (C) 2025, <COMPANY>
%%% @doc
%%%
%%% @end
%%% Created : 13. 12 2025 15:44
%%%-------------------------------------------------------------------
-module(inbuilt_dns_resolver).
-author("anlicheng").
%% API
-export([resolve/1]).
-spec resolve(QName :: binary()) -> {ok, IpAddr :: inet:ip4_address()} | error.
resolve(QName) when is_binary(QName) ->
Suffix = <<".punchnet.ts.net">>,
case dns_utils:ends_with(QName, Suffix) of
true ->
Ip4 = rand:uniform(254),
{ok, {192, 168, 1, Ip4}};
false ->
error
end.