fix cache
This commit is contained in:
parent
28355c1df0
commit
2a1e39b16e
21
apps/dns_proxy/include/dns_proxy.hrl
Normal file
21
apps/dns_proxy/include/dns_proxy.hrl
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
%%%-------------------------------------------------------------------
|
||||||
|
%%% @author anlicheng
|
||||||
|
%%% @copyright (C) 2025, <COMPANY>
|
||||||
|
%%% @doc
|
||||||
|
%%%
|
||||||
|
%%% @end
|
||||||
|
%%% Created : 04. 12月 2025 11:41
|
||||||
|
%%%-------------------------------------------------------------------
|
||||||
|
-author("anlicheng").
|
||||||
|
|
||||||
|
-record(dns_cache, {
|
||||||
|
%% {Qname, QType, QClass}
|
||||||
|
key,
|
||||||
|
answers = [],
|
||||||
|
authority = [],
|
||||||
|
additional = [],
|
||||||
|
rcode :: integer(),
|
||||||
|
flags = #{},
|
||||||
|
% unix time
|
||||||
|
expire_at :: integer()
|
||||||
|
}).
|
||||||
@ -1,17 +1,48 @@
|
|||||||
-module(dns_cache).
|
-module(dns_cache).
|
||||||
|
-include_lib("dns_proxy.hrl").
|
||||||
|
-include_lib("dns_erlang/include/dns.hrl").
|
||||||
|
-include_lib("dns_erlang/include/dns_records.hrl").
|
||||||
|
-include_lib("dns_erlang/include/dns_terms.hrl").
|
||||||
|
|
||||||
-export([init/0, lookup/1, insert/2]).
|
-export([init/0, lookup/1, insert/2]).
|
||||||
|
|
||||||
-define(TABLE, dns_cache).
|
-define(TABLE, dns_cache).
|
||||||
|
|
||||||
init() ->
|
init() ->
|
||||||
ets:new(?TABLE, [named_table, set, public, {read_concurrency, true}]).
|
ets:new(?TABLE, [named_table, set, public, {keypos, 2}, {read_concurrency, true}]).
|
||||||
|
|
||||||
lookup(Key) ->
|
lookup(#dns_query{name = Qname, type = QType, class = QClass}) ->
|
||||||
|
Key = {Qname, QType, QClass},
|
||||||
case ets:lookup(?TABLE, Key) of
|
case ets:lookup(?TABLE, Key) of
|
||||||
[{_Key, Value}] -> {hit, Value};
|
[Cache = #dns_cache{expire_at = ExpireAt}] ->
|
||||||
[] -> miss
|
Now = os:system_time(second),
|
||||||
|
case ExpireAt > Now of
|
||||||
|
true ->
|
||||||
|
{hit, Cache};
|
||||||
|
false ->
|
||||||
|
true = ets:delete(?TABLE, Key),
|
||||||
|
miss
|
||||||
|
end;
|
||||||
|
[] ->
|
||||||
|
miss
|
||||||
end.
|
end.
|
||||||
|
|
||||||
insert(Key, DNSMsg) ->
|
insert(#dns_query{name = Qname, type = QType, class = QClass},
|
||||||
ets:insert(?TABLE, {Key, DNSMsg}),
|
#dns_message{answers = Answers, authority = Authority, additional = Additional, aa = AA}) ->
|
||||||
ok.
|
|
||||||
|
TTLs = [RR#dns_rr.ttl || RR <- Answers] ++ [RR#dns_rr.ttl || RR <- Authority] ++ [RR#dns_rr.ttl || RR <- Additional],
|
||||||
|
TTL = lists:min(TTLs),
|
||||||
|
ExpireAt = os:system_time(second) + TTL,
|
||||||
|
|
||||||
|
Key = {Qname, QType, QClass},
|
||||||
|
Cache = #dns_cache{
|
||||||
|
key = Key,
|
||||||
|
answers = Answers,
|
||||||
|
authority = Authority,
|
||||||
|
additional = Additional,
|
||||||
|
rcode = 0,
|
||||||
|
flags = #{aa => AA},
|
||||||
|
% unix time
|
||||||
|
expire_at = ExpireAt
|
||||||
|
},
|
||||||
|
true = ets:insert(?TABLE, Cache).
|
||||||
|
|||||||
@ -14,6 +14,7 @@
|
|||||||
-include_lib("dns_erlang/include/dns.hrl").
|
-include_lib("dns_erlang/include/dns.hrl").
|
||||||
-include_lib("dns_erlang/include/dns_records.hrl").
|
-include_lib("dns_erlang/include/dns_records.hrl").
|
||||||
-include_lib("dns_erlang/include/dns_terms.hrl").
|
-include_lib("dns_erlang/include/dns_terms.hrl").
|
||||||
|
-include("dns_proxy.hrl").
|
||||||
|
|
||||||
%% API
|
%% API
|
||||||
-export([start_link/4]).
|
-export([start_link/4]).
|
||||||
@ -32,7 +33,7 @@
|
|||||||
src_ip,
|
src_ip,
|
||||||
src_port,
|
src_port,
|
||||||
packet,
|
packet,
|
||||||
qname,
|
question,
|
||||||
dns_servers = []
|
dns_servers = []
|
||||||
}).
|
}).
|
||||||
|
|
||||||
@ -85,10 +86,10 @@ handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = Src
|
|||||||
Msg = #dns_message{qc = 1, questions = [Question|_]} ->
|
Msg = #dns_message{qc = 1, questions = [Question|_]} ->
|
||||||
Qname = Question#dns_query.name,
|
Qname = Question#dns_query.name,
|
||||||
lager:debug("[dns_handler] qname: ~p", [Qname]),
|
lager:debug("[dns_handler] qname: ~p", [Qname]),
|
||||||
case dns_cache:lookup(Qname) of
|
case dns_cache:lookup(Question) of
|
||||||
{hit, R} ->
|
{hit, Cache} ->
|
||||||
lager:debug("[dns_handler] hit cache rr: ~p", [R]),
|
lager:debug("[dns_handler] hit cache: ~p", [Cache]),
|
||||||
Resp = build_response(Msg, R),
|
Resp = build_response(Msg, Cache),
|
||||||
gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)),
|
gen_udp:send(Sock, SrcIp, SrcPort, dns:encode_message(Resp)),
|
||||||
{stop, normal, State};
|
{stop, normal, State};
|
||||||
miss ->
|
miss ->
|
||||||
@ -96,7 +97,7 @@ handle_cast(handle, State = #state{socket = Sock, src_ip = SrcIp, src_port = Src
|
|||||||
forward_to_upstream(DnsIp, DnsPort, Packet, Msg),
|
forward_to_upstream(DnsIp, DnsPort, Packet, Msg),
|
||||||
%% 开启定时器,超时后递归请求后面的服务
|
%% 开启定时器,超时后递归请求后面的服务
|
||||||
erlang:start_timer(?UPSTREAM_TIMEOUT, self(), {trigger_next, Msg}),
|
erlang:start_timer(?UPSTREAM_TIMEOUT, self(), {trigger_next, Msg}),
|
||||||
{noreply, State#state{dns_servers = RestDnsServers, qname = Qname}}
|
{noreply, State#state{dns_servers = RestDnsServers, question = Question}}
|
||||||
end;
|
end;
|
||||||
Other ->
|
Other ->
|
||||||
lager:warning("[] decode msg get error: ~p", [Other]),
|
lager:warning("[] decode msg get error: ~p", [Other]),
|
||||||
@ -122,12 +123,12 @@ handle_info({timeout, _, handler_max_ttl}, State) ->
|
|||||||
{stop, normal, State};
|
{stop, normal, State};
|
||||||
|
|
||||||
%% 收到请求
|
%% 收到请求
|
||||||
handle_info({dns_resolver_reply, Resp}, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, qname = Qname}) ->
|
handle_info({dns_resolver_reply, Resp}, State = #state{socket = Sock, src_ip = SrcIp, src_port = SrcPort, question = Question}) ->
|
||||||
%% 尝试解析
|
%% 尝试解析
|
||||||
case dns:decode_message(Resp) of
|
case dns:decode_message(Resp) of
|
||||||
Msg = #dns_message{answers = Answers} ->
|
Msg = #dns_message{answers = Answers} ->
|
||||||
lager:debug("[dns_handler] get a resolver reply: ~p, bin: ~p", [Msg, Answers]),
|
lager:debug("[dns_handler] get a resolver reply: ~p, bin: ~p", [Msg, Answers]),
|
||||||
dns_cache:insert(Qname, Answers),
|
dns_cache:insert(Question, Msg),
|
||||||
gen_udp:send(Sock, SrcIp, SrcPort, Resp);
|
gen_udp:send(Sock, SrcIp, SrcPort, Resp);
|
||||||
Other ->
|
Other ->
|
||||||
lager:debug("[dns_handler] parse reply get error: ~p", [Other])
|
lager:debug("[dns_handler] parse reply get error: ~p", [Other])
|
||||||
@ -162,6 +163,39 @@ forward_to_upstream(TargetIp, TargetPort, Request, Msg) ->
|
|||||||
dns_resolver:forward(Pid, ReceiverPid, TargetIp, TargetPort, Request, Msg)
|
dns_resolver:forward(Pid, ReceiverPid, TargetIp, TargetPort, Request, Msg)
|
||||||
end).
|
end).
|
||||||
|
|
||||||
build_response(Req, RR) ->
|
build_response(Query, #dns_cache{expire_at = ExpireAt, answers = Answers, authority = Authority, additional = Additional}) ->
|
||||||
Msg = Req,
|
Now = os:system_time(second),
|
||||||
Msg#dns_message{answers = RR, qr = true, aa = true}.
|
RemainingTTL = ExpireAt - Now,
|
||||||
|
|
||||||
|
Adjust = fun(RR) -> RR#dns_rr{ttl = max(0, RemainingTTL)} end,
|
||||||
|
Answers2 = [Adjust(RR) || RR <- Answers],
|
||||||
|
Authority2 = [Adjust(RR) || RR <- Authority],
|
||||||
|
|
||||||
|
Additional0 = [Adjust(RR) || RR <- Additional],
|
||||||
|
Additional2 = add_opt_if_needed(Query, Additional0),
|
||||||
|
|
||||||
|
Query#dns_message{
|
||||||
|
qr = true,
|
||||||
|
ra = true,
|
||||||
|
answers = Answers2,
|
||||||
|
authority = Authority2,
|
||||||
|
additional = Additional2
|
||||||
|
}.
|
||||||
|
|
||||||
|
add_opt_if_needed(Query, Additional) ->
|
||||||
|
case dns_opt:find(Query#dns_message.additional) of
|
||||||
|
false ->
|
||||||
|
%% 客户端未使用 EDNS → 不在响应中加 OPT
|
||||||
|
Additional;
|
||||||
|
{ok, OptReq} ->
|
||||||
|
%% 客户端使用 EDNS → 构造自身的 OPT RR
|
||||||
|
UdpSize = dns_opt:udp_payload(OptReq),
|
||||||
|
DoBit = dns_opt:do_bit(OptReq),
|
||||||
|
|
||||||
|
OptResp = dns_opt:make(UdpSize, DoBit),
|
||||||
|
|
||||||
|
%% 去除上游响应里的旧 OPT(如果有)
|
||||||
|
Additional2 = [RR || RR <- Additional, RR#dns_rr.type =/= opt],
|
||||||
|
%% 新的 OPT 一定放在 Additional 最后
|
||||||
|
Additional2 ++ [OptResp]
|
||||||
|
end.
|
||||||
46
apps/dns_proxy/src/dns_opt.erl
Normal file
46
apps/dns_proxy/src/dns_opt.erl
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
%%--------------------------------------------------------------------
|
||||||
|
%% EDNS (OPT RR) Utility for dns_erlang
|
||||||
|
%%--------------------------------------------------------------------
|
||||||
|
-module(dns_opt).
|
||||||
|
|
||||||
|
-export([find/1, make/2, udp_payload/1, do_bit/1]).
|
||||||
|
|
||||||
|
-include_lib("dns_erlang/include/dns.hrl").
|
||||||
|
|
||||||
|
%%--------------------------------------------------------------------
|
||||||
|
%% 查找 OPT RR(RR type = opt)
|
||||||
|
%%--------------------------------------------------------------------
|
||||||
|
find(RRs) ->
|
||||||
|
case lists:dropwhile(fun(RR) -> RR#dns_rr.type =/= opt end, RRs) of
|
||||||
|
[] ->
|
||||||
|
false;
|
||||||
|
[RR|_] ->
|
||||||
|
{ok, RR}
|
||||||
|
end.
|
||||||
|
|
||||||
|
%%--------------------------------------------------------------------
|
||||||
|
%% 获取 DO bit(TTL 的最高 bit)
|
||||||
|
%%--------------------------------------------------------------------
|
||||||
|
do_bit(RR) ->
|
||||||
|
(RR#dns_rr.ttl band 16#8000) =/= 0.
|
||||||
|
|
||||||
|
%%--------------------------------------------------------------------
|
||||||
|
%% 获取 UDP payload size(来自 class 字段)
|
||||||
|
%%--------------------------------------------------------------------
|
||||||
|
udp_payload(RR) ->
|
||||||
|
RR#dns_rr.class.
|
||||||
|
|
||||||
|
%%--------------------------------------------------------------------
|
||||||
|
%% 构造新的 OPT RR(符合 RFC6891)
|
||||||
|
%% make(UdpPayloadSize, DoBit)
|
||||||
|
%%--------------------------------------------------------------------
|
||||||
|
make(UdpPayloadSize, DoBit) ->
|
||||||
|
TTL = if DoBit -> 16#8000; true -> 0 end,
|
||||||
|
#dns_rr{
|
||||||
|
name = <<>>, % Root label
|
||||||
|
type = opt,
|
||||||
|
class = UdpPayloadSize,
|
||||||
|
ttl = TTL,
|
||||||
|
% No EDNS options by default
|
||||||
|
data = []
|
||||||
|
}.
|
||||||
Loading…
x
Reference in New Issue
Block a user