fix cache

This commit is contained in:
anlicheng 2025-12-04 12:28:19 +08:00
parent 28355c1df0
commit 2a1e39b16e
4 changed files with 150 additions and 18 deletions

View File

@ -0,0 +1,21 @@
%%%-------------------------------------------------------------------
%%% @author anlicheng
%%% @copyright (C) 2025, <COMPANY>
%%% @doc
%%%
%%% @end
%%% Created : 04. 12 2025 11:41
%%%-------------------------------------------------------------------
-author("anlicheng").
-record(dns_cache, {
%% {Qname, QType, QClass}
key,
answers = [],
authority = [],
additional = [],
rcode :: integer(),
flags = #{},
% unix time
expire_at :: integer()
}).

View File

@ -1,17 +1,48 @@
-module(dns_cache). -module(dns_cache).
-include_lib("dns_proxy.hrl").
-include_lib("dns_erlang/include/dns.hrl").
-include_lib("dns_erlang/include/dns_records.hrl").
-include_lib("dns_erlang/include/dns_terms.hrl").
-export([init/0, lookup/1, insert/2]). -export([init/0, lookup/1, insert/2]).
-define(TABLE, dns_cache). -define(TABLE, dns_cache).
init() -> init() ->
ets:new(?TABLE, [named_table, set, public, {read_concurrency, true}]). ets:new(?TABLE, [named_table, set, public, {keypos, 2}, {read_concurrency, true}]).
lookup(Key) -> lookup(#dns_query{name = Qname, type = QType, class = QClass}) ->
Key = {Qname, QType, QClass},
case ets:lookup(?TABLE, Key) of case ets:lookup(?TABLE, Key) of
[{_Key, Value}] -> {hit, Value}; [Cache = #dns_cache{expire_at = ExpireAt}] ->
[] -> miss Now = os:system_time(second),
case ExpireAt > Now of
true ->
{hit, Cache};
false ->
true = ets:delete(?TABLE, Key),
miss
end;
[] ->
miss
end. end.
insert(Key, DNSMsg) -> insert(#dns_query{name = Qname, type = QType, class = QClass},
ets:insert(?TABLE, {Key, DNSMsg}), #dns_message{answers = Answers, authority = Authority, additional = Additional, aa = AA}) ->
ok.
TTLs = [RR#dns_rr.ttl || RR <- Answers] ++ [RR#dns_rr.ttl || RR <- Authority] ++ [RR#dns_rr.ttl || RR <- Additional],
TTL = lists:min(TTLs),
ExpireAt = os:system_time(second) + TTL,
Key = {Qname, QType, QClass},
Cache = #dns_cache{
key = Key,
answers = Answers,
authority = Authority,
additional = Additional,
rcode = 0,
flags = #{aa => AA},
% unix time
expire_at = ExpireAt
},
true = ets:insert(?TABLE, Cache).

View File

@ -14,6 +14,7 @@
-include_lib("dns_erlang/include/dns.hrl"). -include_lib("dns_erlang/include/dns.hrl").
-include_lib("dns_erlang/include/dns_records.hrl"). -include_lib("dns_erlang/include/dns_records.hrl").
-include_lib("dns_erlang/include/dns_terms.hrl"). -include_lib("dns_erlang/include/dns_terms.hrl").
-include("dns_proxy.hrl").
%% API %% API
-export([start_link/4]). -export([start_link/4]).
@ -32,7 +33,7 @@
src_ip, src_ip,
src_port, src_port,
packet, packet,
qname, question,
dns_servers = [] dns_servers = []
}). }).
@ -85,10 +86,10 @@ handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = Src
Msg = #dns_message{qc = 1, questions = [Question|_]} -> Msg = #dns_message{qc = 1, questions = [Question|_]} ->
Qname = Question#dns_query.name, Qname = Question#dns_query.name,
lager:debug("[dns_handler] qname: ~p", [Qname]), lager:debug("[dns_handler] qname: ~p", [Qname]),
case dns_cache:lookup(Qname) of case dns_cache:lookup(Question) of
{hit, R} -> {hit, Cache} ->
lager:debug("[dns_handler] hit cache rr: ~p", [R]), lager:debug("[dns_handler] hit cache: ~p", [Cache]),
Resp = build_response(Msg, R), Resp = build_response(Msg, Cache),
gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)), gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)),
{stop, normal, State}; {stop, normal, State};
miss -> miss ->
@ -96,7 +97,7 @@ handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = Src
forward_to_upstream(DnsIp, DnsPort, Packet, Msg), forward_to_upstream(DnsIp, DnsPort, Packet, Msg),
%% %%
erlang:start_timer(?UPSTREAM_TIMEOUT, self(), {trigger_next, Msg}), erlang:start_timer(?UPSTREAM_TIMEOUT, self(), {trigger_next, Msg}),
{noreply, State#state{dns_servers = RestDnsServers, qname = Qname}} {noreply, State#state{dns_servers = RestDnsServers, question = Question}}
end; end;
Other -> Other ->
lager:warning("[] decode msg get error: ~p", [Other]), lager:warning("[] decode msg get error: ~p", [Other]),
@ -122,12 +123,12 @@ handle_info({timeout, _, handler_max_ttl}, State) ->
{stop, normal, State}; {stop, normal, State};
%% %%
handle_info({dns_resolver_reply, Resp}, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, qname = Qname}) -> handle_info({dns_resolver_reply, Resp}, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, question = Question}) ->
%% %%
case dns:decode_message(Resp) of case dns:decode_message(Resp) of
Msg = #dns_message{answers = Answers} -> Msg = #dns_message{answers = Answers} ->
lager:debug("[dns_handler] get a resolver reply: ~p, bin: ~p", [Msg, Answers]), lager:debug("[dns_handler] get a resolver reply: ~p, bin: ~p", [Msg, Answers]),
dns_cache:insert(Qname, Answers), dns_cache:insert(Question, Msg),
gen_udp:send(Sock, SrcIp, SrcPort, Resp); gen_udp:send(Sock, SrcIp, SrcPort, Resp);
Other -> Other ->
lager:debug("[dns_handler] parse reply get error: ~p", [Other]) lager:debug("[dns_handler] parse reply get error: ~p", [Other])
@ -162,6 +163,39 @@ forward_to_upstream(TargetIp, TargetPort, Request, Msg) ->
dns_resolver:forward(Pid, ReceiverPid, TargetIp, TargetPort, Request, Msg) dns_resolver:forward(Pid, ReceiverPid, TargetIp, TargetPort, Request, Msg)
end). end).
build_response(Req, RR) -> build_response(Query, #dns_cache{expire_at = ExpireAt, answers = Answers, authority = Authority, additional = Additional}) ->
Msg = Req, Now = os:system_time(second),
Msg#dns_message{answers = RR, qr = true, aa = true}. RemainingTTL = ExpireAt - Now,
Adjust = fun(RR) -> RR#dns_rr{ttl = max(0, RemainingTTL)} end,
Answers2 = [Adjust(RR) || RR <- Answers],
Authority2 = [Adjust(RR) || RR <- Authority],
Additional0 = [Adjust(RR) || RR <- Additional],
Additional2 = add_opt_if_needed(Query, Additional0),
Query#dns_message{
qr = true,
ra = true,
answers = Answers2,
authority = Authority2,
additional = Additional2
}.
add_opt_if_needed(Query, Additional) ->
case dns_opt:find(Query#dns_message.additional) of
false ->
%% 使 EDNS OPT
Additional;
{ok, OptReq} ->
%% 使 EDNS OPT RR
UdpSize = dns_opt:udp_payload(OptReq),
DoBit = dns_opt:do_bit(OptReq),
OptResp = dns_opt:make(UdpSize, DoBit),
%% OPT
Additional2 = [RR || RR <- Additional, RR#dns_rr.type =/= opt],
%% OPT Additional
Additional2 ++ [OptResp]
end.

View File

@ -0,0 +1,46 @@
%%--------------------------------------------------------------------
%% EDNS (OPT RR) Utility for dns_erlang
%%--------------------------------------------------------------------
-module(dns_opt).
-export([find/1, make/2, udp_payload/1, do_bit/1]).
-include_lib("dns_erlang/include/dns.hrl").
%%--------------------------------------------------------------------
%% OPT RRRR type = opt
%%--------------------------------------------------------------------
find(RRs) ->
case lists:dropwhile(fun(RR) -> RR#dns_rr.type =/= opt end, RRs) of
[] ->
false;
[RR|_] ->
{ok, RR}
end.
%%--------------------------------------------------------------------
%% DO bitTTL bit
%%--------------------------------------------------------------------
do_bit(RR) ->
(RR#dns_rr.ttl band 16#8000) =/= 0.
%%--------------------------------------------------------------------
%% UDP payload size class
%%--------------------------------------------------------------------
udp_payload(RR) ->
RR#dns_rr.class.
%%--------------------------------------------------------------------
%% OPT RR RFC6891
%% make(UdpPayloadSize, DoBit)
%%--------------------------------------------------------------------
make(UdpPayloadSize, DoBit) ->
TTL = if DoBit -> 16#8000; true -> 0 end,
#dns_rr{
name = <<>>, % Root label
type = opt,
class = UdpPayloadSize,
ttl = TTL,
% No EDNS options by default
data = []
}.