sdlan/src/quic/sdlan_quic_channel.erl
2026-04-16 16:35:36 +08:00

442 lines
20 KiB
Erlang
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

%%%-------------------------------------------------------------------
%%% @author anlicheng
%%% @copyright (C) 2026, <COMPANY>
%%% @doc
%%%
%%% @end
%%% Created : 11. 2月 2026 23:00
%%%-------------------------------------------------------------------
-module(sdlan_quic_channel).
-author("anlicheng").
-include("sdlan.hrl").
-include("sdlan_pb.hrl").
-behaviour(gen_statem).
%% Interval, in milliseconds, between heartbeat checks (ping_ticker timer).
-define(PING_TICKER, 15000).
%% Error codes sent back to the client when registration fails.
%% Network error (the access-token check returned {error, _}).
-define(NAK_NETWORK_FAULT, 4).
%% Internal error (e.g. the target network process was not found).
-define(NAK_INTERNAL_FAULT, 5).
%% API
-export([start_link/2]).
-export([send_event/2, command/4, stop/2]).
-export([test_rules/2]).
%% gen_statem callbacks
-export([init/1, handle_event/4, terminate/3, code_change/4, callback_mode/0]).
%% Per-connection state carried through every gen_statem callback.
-record(state, {
%% QUIC connection handle handed to start_link/2.
conn :: quicer:connection_handle(),
%% Maximum packet size: advertised in the welcome message and used as the
%% upper bound for a single frame when decoding inbound data.
max_packet_size = 16384,
%% Heartbeat interval advertised to the client in the welcome message
%% (presumably seconds, judging by the field name).
heartbeat_sec = 10,
%% Control stream; set once the peer opens it (new_stream event).
stream :: undefined | quicer:stream_handle(),
%% Accumulator holding not-yet-complete bytes for protocol frame parsing.
buf = <<>>,
%% Client identifier taken from the register request.
client_id :: undefined | binary(),
%% Network identifier taken from the register request.
network_id = 0 :: integer(),
%% Pid of the sdlan_network process this client is attached to.
network_pid :: undefined | pid(),
%% MAC address taken from the register request.
mac :: undefined | binary(),
%% IP address taken from the register request, kept as an integer.
ip = 0 :: integer(),
%% Next packet id for outgoing commands; correlates requests and responses.
pkt_id = 1,
%% Commands awaiting an ack: #{pkt_id => {Ref, ReceiverPid}}
pending_commands = #{},
%% Number of PINGs received since the last ping_ticker check.
ping_counter = 0,
%% Offline callback, invoked from terminate/3 when set.
offline_cb :: undefined | fun()
}).
%%%===================================================================
%%% API
%%%===================================================================
%% Debug helper: looks up the rules for a source/destination identity pair
%% and returns them packed as consecutive 3-byte entries
%% (<<Proto:8, Port:16>> per rule), the same encoding used in policy replies.
test_rules(SrcIdentityId, DstIdentityId) when is_integer(SrcIdentityId), is_integer(DstIdentityId) ->
    {ok, Rules} = get_rules(SrcIdentityId, DstIdentityId),
    logger:debug("[sdlan_channel] test_rules policy_request src_identity_id: ~p, dst_identity_id: ~p, rules: ~p", [SrcIdentityId, DstIdentityId, Rules]),
    %% A fun (rather than a filtering generator pattern) keeps the crash on
    %% any rule that is not a {Proto, Port} pair.
    EncodeRule = fun({Proto, Port}) -> <<Proto:8, Port:16>> end,
    Packed = [EncodeRule(Rule) || Rule <- Rules],
    iolist_to_binary(Packed).
%% @doc Asynchronously asks the channel process to forward an already
%% encoded protobuf event to the client. The cast is fire-and-forget:
%% `ok' is returned immediately and says nothing about delivery.
-spec send_event(Pid :: pid(), Event :: binary()) -> ok.
send_event(Pid, ProtobufEvent) when is_pid(Pid), is_binary(ProtobufEvent) ->
    gen_statem:cast(Pid, {send_event, ProtobufEvent}).
%% @doc Asynchronously asks the channel process to send a command to the
%% client. When the client acknowledges it, the channel sends
%% `{quic_command_ack, Ref, CommandAck}' to ReceiverPid. The cast is
%% fire-and-forget: `ok' is returned immediately.
-spec command(Pid :: pid(), Ref :: reference(), ReceiverPid :: pid(), {Tag :: atom(), SubCommand :: any()}) -> ok.
command(Pid, Ref, ReceiverPid, SubCommand) when is_pid(Pid), is_pid(ReceiverPid) ->
    gen_statem:cast(Pid, {command, Ref, ReceiverPid, SubCommand}).
%% @doc Synchronously stops the channel process with the given reason,
%% waiting at most two seconds for it to terminate.
-spec stop(Pid :: pid(), Reason :: term()) -> ok.
stop(Pid, Reason) when is_pid(Pid) ->
    StopTimeoutMs = 2000,
    gen_statem:stop(Pid, Reason, StopTimeoutMs).
%% @doc Starts a linked channel process for one QUIC connection.
%% `Limits' is a proplist; init/1 reads the optional `max_packet_size'
%% and `heartbeat_sec' keys from it. Returns once init/1 has returned.
start_link(Conn, Limits) when is_list(Limits) ->
    InitArgs = [Conn, Limits],
    gen_statem:start_link(?MODULE, InitArgs, []).
%%%===================================================================
%%% gen_statem callbacks
%%%===================================================================
%% @private
%% @doc Builds the initial state from the connection handle and the limits
%% proplist, then enters `initializing' and immediately queues the internal
%% `do_init' event so the stream accept happens outside of init/1.
init([Conn, Limits]) ->
    InitialState = #state{
        conn = Conn,
        max_packet_size = proplists:get_value(max_packet_size, Limits, 16384),
        heartbeat_sec = proplists:get_value(heartbeat_sec, Limits, 10)
    },
    {ok, initializing, InitialState, [{next_event, internal, do_init}]}.
%% @private
%% @doc Selects the gen_statem callback mode: every event, in every state,
%% is dispatched to handle_event/4.
callback_mode() ->
    handle_event_function.
%% @private
%% @doc Single event handler for every state (callback mode is
%% handle_event_function). The state names used below are:
%%   initializing   - just started; `do_init' requests a stream accept
%%   waiting_stream - waiting for the peer to open the control stream
%%   initialized    - welcome sent; waiting for the register request
%%   registered     - client authenticated and attached to its network
%% Inbound stream data is split into length-prefixed frames and each frame
%% is re-injected as an `internal' `{frame, Binary}' event, so the packet
%% clauses below match on the packet-type prefix of a complete frame.
handle_event(internal, do_init, initializing, State = #state{conn = Conn}) ->
    logger:debug("[sdlan_quic_channel] call do_init of conn: ~p", [Conn]),
    {ok, _} = quicer:async_accept_stream(Conn, #{active => true}),
    {next_state, waiting_stream, State};
%% Messages delivered by quicer.
handle_event(info, {quic, dgram_state_changed, Conn, Opts = #{dgram_send_enabled := true}}, _, State = #state{conn = Conn}) ->
    logger:debug("[sdlan_quic_channel] dgram_state_changed, opts: ~p", [Opts]),
    {keep_state, State};
handle_event(info, {quic, new_stream, Stream, Opts}, waiting_stream, State = #state{max_packet_size = MaxPacketSize, heartbeat_sec = HeartbeatSec}) ->
    logger:debug("[sdlan_quic_channel] call new_stream: ~p, opts: ~p", [Stream, Opts]),
    Ipv6Assist =
        case application:get_env(sdlan, ipv6_assist_info) of
            {ok, {V6Bytes, Port}} ->
                #'SDLV6Info'{
                    v6 = V6Bytes,
                    port = Port
                };
            _ ->
                undefined
        end,
    %% Send the welcome message advertising the channel limits.
    WelcomePkt = sdlan_pb:encode_msg(#'SDLWelcome'{
        version = 1,
        max_bidi_streams = 1,
        max_packet_size = MaxPacketSize,
        heartbeat_sec = HeartbeatSec,
        ipv6_assist = Ipv6Assist
    }),
    quic_send(Stream, <<?PACKET_WELCOME, WelcomePkt/binary>>),
    logger:debug("[sdlan_quic_channel] get stream: ~p, send welcome", [Stream]),
    {next_state, initialized, State#state{stream = Stream}};
handle_event(info, {quic, closed, Stream, _Props}, _StateName, State = #state{stream = Stream}) ->
    {stop, connection_closed, State};
handle_event(info, {quic, send_shutdown_complete, Stream, _Props}, _StateName, State = #state{stream = Stream}) ->
    {stop, connection_shutdown, State};
handle_event(info, {quic, transport_shutdown, Stream, _Props}, _StateName, State = #state{stream = Stream}) ->
    {stop, transport_shutdown, State};
%% Raw stream data from quicer: append to the buffer, split off every
%% complete frame and queue each one as an internal {frame, _} event.
handle_event(info, {quic, Data, Stream, _Props}, _StateName, State = #state{stream = Stream, buf = Buf, max_packet_size = MaxPacketSize}) when is_binary(Data) ->
    case decode_frames(<<Buf/binary, Data/binary>>, MaxPacketSize) of
        {error, Reason} ->
            {stop, Reason, State};
        {ok, NBuf, Frames} ->
            Actions = [{next_event, internal, {frame, Frame}} || Frame <- Frames],
            {keep_state, State#state{buf = NBuf}, Actions}
    end;
%% Internal packet handling.
%% Registration: validate the request, check the access token, attach to
%% the network process and reply with an ack (or a nak followed by a stop).
handle_event(internal, {frame, <<?PACKET_REGISTER_SUPER, Body/binary>>}, initialized, State = #state{stream = Stream}) ->
    #'SDLRegisterSuper'{
        client_id = ClientId, network_id = NetworkId, mac = Mac, ip = Ip, mask_len = MaskLen,
        hostname = HostName, pub_key = PubKey, access_token = AccessToken
    } = sdlan_pb:decode_msg(Body, 'SDLRegisterSuper'),
    %% Assertive checks: a malformed request crashes this channel process.
    true = (Mac =/= <<>> andalso PubKey =/= <<>> andalso ClientId =/= <<>>),
    %% The MAC address must not be a multicast or broadcast address.
    true = not (sdlan_util:is_multicast_mac(Mac) orelse sdlan_util:is_broadcast_mac(Mac)),
    MacBinStr = sdlan_util:format_mac(Mac),
    IpAddr = sdlan_util:int_to_ipv4(Ip),
    Params = #{
        <<"network_id">> => NetworkId,
        <<"client_id">> => ClientId,
        <<"mac">> => MacBinStr,
        <<"ip">> => IpAddr,
        <<"mask_len">> => MaskLen,
        <<"hostname">> => HostName,
        <<"access_token">> => AccessToken
    },
    %% Parameter check.
    logger:debug("[sdlan_quic_channel] client_id: ~p, ip: ~p, mac: ~p, host_name: ~p, access_token: ~p, network_id: ~p",
        [ClientId, Ip, Mac, HostName, AccessToken, NetworkId]),
    case sdlan_api:auth_access_token(Params) of
        {ok, #{<<"result">> := <<"ok">>}} ->
            %% Bind this channel to its network process.
            case sdlan_network:get_pid(NetworkId) of
                NetworkPid when is_pid(NetworkPid) ->
                    {ok, Algorithm, Key, RegionId, SessionToken} = sdlan_network:attach(NetworkPid, self(), ClientId, Mac, Ip, HostName),
                    RsaPubKey = sdlan_cipher:rsa_pem_decode(PubKey),
                    RegisterSuperAck = sdlan_pb:encode_msg(#'SDLRegisterSuperAck'{
                        algorithm = Algorithm,
                        key = rsa_encode(Key, RsaPubKey),
                        region_id = RegionId,
                        session_token = SessionToken
                    }),
                    %% Send the acknowledgement.
                    quic_send(Stream, <<?PACKET_REGISTER_SUPER_ACK, RegisterSuperAck/binary>>),
                    %% Mark the node as online.
                    Result = sdlan_api:set_node_status(#{
                        <<"network_id">> => NetworkId,
                        <<"client_id">> => ClientId,
                        <<"access_token">> => AccessToken,
                        <<"status">> => 1
                    }),
                    logger:debug("[sdlan_quic_channel] client_id: ~p, set none online result is: ~p", [ClientId, Result]),
                    %% The call's value is deliberately not bound to a variable:
                    %% a fun body shares the bindings of the enclosing clause, so
                    %% `Result = ...' here would be a match against the online
                    %% result above and could fail with badmatch in terminate/3.
                    OfflineCb = fun() ->
                        sdlan_api:set_node_status(#{
                            <<"network_id">> => NetworkId,
                            <<"client_id">> => ClientId,
                            <<"access_token">> => AccessToken,
                            <<"status">> => 0
                        })
                    end,
                    {next_state, registered, State#state{network_id = NetworkId, network_pid = NetworkPid, client_id = ClientId, mac = Mac, ip = Ip, offline_cb = OfflineCb}};
                undefined ->
                    logger:warning("[sdlan_quic_channel] client_id: ~p, register get error: network not found", [ClientId]),
                    quic_send(Stream, register_nak_reply(?NAK_INTERNAL_FAULT, <<"Internal Error">>)),
                    {stop, normal, State}
            end;
        {ok, #{<<"error">> := #{<<"code">> := Code, <<"message">> := Message}}} ->
            logger:warning("[sdlan_quic_channel] network_id: ~p, client_id: ~p, register get error: ~ts, error_code: ~p", [NetworkId, ClientId, Message, Code]),
            quic_send(Stream, register_nak_reply(Code, Message)),
            {stop, normal, State};
        {error, Reason} ->
            logger:warning("[sdlan_quic_channel] network_id: ~p, client_id: ~p, register get error: ~p", [NetworkId, ClientId, Reason]),
            quic_send(Stream, register_nak_reply(?NAK_NETWORK_FAULT, <<"Network Error">>)),
            {stop, normal, State}
    end;
%% Peer lookup: reply with the peer's address info, or with an empty
%% SDLPeerInfo when the network process does not know the peer.
handle_event(internal, {frame, <<?PACKET_QUERY_INFO, Body/binary>>}, registered, #state{stream = Stream, network_pid = NetworkPid, mac = SrcMac}) when is_pid(NetworkPid) ->
    #'SDLQueryInfo'{dst_mac = DstMac} = sdlan_pb:decode_msg(Body, 'SDLQueryInfo'),
    case sdlan_network:peer_info(NetworkPid, SrcMac, DstMac) of
        error ->
            logger:debug("[sdlan_channel] query_info src_mac is: ~p, dst_mac: ~p, nat_peer not found",
                [sdlan_util:format_mac(SrcMac), sdlan_util:format_mac(DstMac)]),
            EmptyResponse = sdlan_pb:encode_msg(#'SDLPeerInfo'{
                dst_mac = DstMac,
                v4_info = undefined,
                v6_info = undefined
            }),
            quic_send(Stream, <<?PACKET_PEER_INFO, EmptyResponse/binary>>),
            keep_state_and_data;
        {ok, {NatPeer = {{Ip0, Ip1, Ip2, Ip3}, NatPort}, NatType}, V6Info} ->
            logger:debug("[sdlan_channel] query_info src_mac is: ~p, dst_mac: ~p, nat_peer: ~p",
                [sdlan_util:format_mac(SrcMac), sdlan_util:format_mac(DstMac), NatPeer]),
            PeerInfo = sdlan_pb:encode_msg(#'SDLPeerInfo'{
                dst_mac = DstMac,
                v4_info = #'SDLV4Info'{
                    port = NatPort,
                    v4 = <<Ip0, Ip1, Ip2, Ip3>>,
                    nat_type = NatType
                },
                v6_info = V6Info
            }),
            quic_send(Stream, <<?PACKET_PEER_INFO, PeerInfo/binary>>),
            keep_state_and_data
    end;
%% ARP lookup: an unknown target is answered with an empty target_mac.
handle_event(internal, {frame, <<?PACKET_ARP_REQUEST, Body/binary>>}, registered, #state{stream = Stream, network_id = NetworkId, network_pid = NetworkPid}) when is_pid(NetworkPid) ->
    #'SDLArpRequest'{target_ip = TargetIp, origin_ip = OriginIp, context = Context} = sdlan_pb:decode_msg(Body, 'SDLArpRequest'),
    case sdlan_network:arp_request(NetworkPid, TargetIp) of
        error ->
            logger:debug("[sdlan_channel] network: ~p, arp_request target_ip: ~p, mac not found", [NetworkId, sdlan_util:int_to_ipv4(TargetIp)]),
            EmptyArpResponsePkt = sdlan_pb:encode_msg(#'SDLArpResponse'{
                target_ip = TargetIp,
                target_mac = <<>>,
                origin_ip = OriginIp,
                context = Context
            }),
            quic_send(Stream, <<?PACKET_ARP_RESPONSE, EmptyArpResponsePkt/binary>>),
            keep_state_and_data;
        {ok, Mac} ->
            logger:debug("[sdlan_channel] network: ~p, arp_request target_ip: ~p, mac: ~p", [NetworkId, sdlan_util:int_to_ipv4(TargetIp), sdlan_util:format_mac(Mac)]),
            ArpResponsePkt = sdlan_pb:encode_msg(#'SDLArpResponse'{
                target_ip = TargetIp,
                target_mac = Mac,
                origin_ip = OriginIp,
                context = Context
            }),
            quic_send(Stream, <<?PACKET_ARP_RESPONSE, ArpResponsePkt/binary>>),
            keep_state_and_data
    end;
%% Policy lookup: rules are packed as <<Proto:8, Port:16>> per entry.
handle_event(internal, {frame, <<?PACKET_POLICY_REQUEST, Body/binary>>}, registered, #state{stream = Stream, network_pid = NetworkPid}) when is_pid(NetworkPid) ->
    maybe
        #'SDLPolicyRequest'{src_identity_id = SrcIdentityId, dst_identity_id = DstIdentityId, version = Version} ?= sdlan_pb:decode_msg(Body, 'SDLPolicyRequest'),
        {ok, Rules} = get_rules(SrcIdentityId, DstIdentityId),
        logger:debug("[sdlan_channel] policy_request src_identity_id: ~p, dst_identity_id: ~p, rules: ~p", [SrcIdentityId, DstIdentityId, Rules]),
        RuleBin = iolist_to_binary(lists:map(fun({Proto, Port}) -> <<Proto:8, Port:16>> end, Rules)),
        PolicyResponsePkt = sdlan_pb:encode_msg(#'SDLPolicyResponse'{
            src_identity_id = SrcIdentityId,
            dst_identity_id = DstIdentityId,
            version = Version,
            rules = RuleBin
        }),
        quic_send(Stream, <<?PACKET_POLICY_REPLY, PolicyResponsePkt/binary>>)
    end,
    keep_state_and_data;
%% Command acknowledgement: route the ack to whoever issued the command.
handle_event(internal, {frame, <<?PACKET_COMMAND_ACK, Body/binary>>}, registered, State = #state{pending_commands = PendingCommands}) ->
    maybe
        CommandAck = sdlan_pb:decode_msg(Body, 'SDLCommandAck'),
        #'SDLCommandAck'{pkt_id = PktId} ?= CommandAck,
        {{Ref, ReceiverPid}, RestPendingCommands} ?= maps:take(PktId, PendingCommands),
        %% No liveness check: sending to a dead pid is a silent no-op, whereas
        %% is_process_alive/1 raises badarg for a pid on another node.
        ReceiverPid ! {quic_command_ack, Ref, CommandAck},
        {keep_state, State#state{pending_commands = RestPendingCommands}}
    else
        %% Unknown pkt_id (maps:take/2 returned `error') or unexpected shape.
        _ ->
            keep_state_and_data
    end;
%% Heartbeat: answer with PONG and count the ping for the ticker check.
handle_event(internal, {frame, <<?PACKET_PING>>}, _StateName, State = #state{stream = Stream, ping_counter = PingCounter}) ->
    quic_send(Stream, <<?PACKET_PONG>>),
    {keep_state, State#state{ping_counter = PingCounter + 1}};
%% Unregister.
handle_event(internal, {frame, <<?PACKET_UNREGISTER>>}, registered, State = #state{client_id = ClientId, mac = Mac, network_pid = NetworkPid}) when is_pid(NetworkPid) ->
    logger:warning("[sdlan_channel] unregister client_id: ~p", [ClientId]),
    sdlan_network:unregister(NetworkPid, ClientId, Mac),
    {stop, normal, State};
%% Heartbeat supervision: stop the channel when no PING arrived since the
%% previous tick.
%% NOTE(review): nothing visible in this module starts the first
%% ping_ticker timer, so this clause only runs if it is armed elsewhere -
%% confirm.
handle_event(info, {timeout, _, ping_ticker}, _, State = #state{client_id = ClientId, ping_counter = PingCounter}) ->
    %% Arm the next heartbeat check.
    erlang:start_timer(?PING_TICKER, self(), ping_ticker),
    case PingCounter > 0 of
        true ->
            {keep_state, State#state{ping_counter = 0}};
        false ->
            logger:debug("[sdlan_channel] client_id: ~p, ping losted", [ClientId]),
            {stop, normal, State#state{ping_counter = 0}}
    end;
%% Forward an event to the client.
handle_event(cast, {send_event, Event}, registered, #state{stream = Stream}) ->
    quic_send(Stream, <<?PACKET_EVENT, Event/binary>>),
    keep_state_and_data;
%% Send a command to the client and remember who is waiting for its ack.
handle_event(cast, {command, Ref, ReceiverPid, SubCommand}, registered, State = #state{stream = Stream, pkt_id = PktId, pending_commands = PendingCommands, client_id = ClientId}) ->
    CommandPkt = sdlan_pb:encode_msg(#'SDLCommand'{
        pkt_id = PktId,
        command = SubCommand
    }),
    logger:debug("[sdlan_channel] client_id: ~p, will send Command: ~p", [ClientId, SubCommand]),
    quic_send(Stream, <<?PACKET_COMMAND, CommandPkt/binary>>),
    {keep_state, State#state{pkt_id = PktId + 1, pending_commands = maps:put(PktId, {Ref, ReceiverPid}, PendingCommands)}};
handle_event(info, {'EXIT', _, _}, _StateName, State) ->
    {stop, connection_closed, State};
%% Anything else (including packets that arrive in the wrong state) is
%% logged and ignored.
handle_event(EventType, Info, StateName, State) ->
    logger:notice("[sdlan_quic_channel] state: ~p, state_name: ~p, event_type: ~p, info: ~p", [State, StateName, EventType, Info]),
    keep_state_and_data.
%% @private
%% @doc Cleanup on shutdown: closes the stream (when one was accepted) and
%% the connection, logs the reason, then runs the offline callback if one
%% was installed during registration. The return value is ignored.
terminate(Reason, _StateName, #state{conn = Conn, stream = Stream, offline_cb = OfflineCb}) ->
    case Stream /= undefined of
        true -> quicer:close_stream(Stream);
        false -> ok
    end,
    quicer:close_connection(Conn),
    logger:warning("[sdlan_quic_conn] terminate closed with reason: ~p", [Reason]),
    %% Trigger the client's offline logic.
    case is_function(OfflineCb) of
        true -> OfflineCb();
        false -> ok
    end,
    ok.
%% @private
%% @doc Hot code upgrade hook: state name and state data are carried over
%% unchanged. Only a #state{} record is accepted as state data.
code_change(_OldVsn, StateName, State, _Extra) when is_record(State, state) ->
    {ok, StateName, State}.
%%%===================================================================
%%% Internal functions
%%%===================================================================
%% Splits a buffer into length-prefixed frames (16-bit big-endian length
%% followed by that many payload bytes). Two situations are possible:
%% 1. one or more complete frames are available - they are returned in
%%    arrival order together with whatever bytes remain;
%% 2. the leading frame is incomplete - it is left in the remainder.
%% A declared length above MaxPacketSize yields {error, frame_too_large},
%% even if the payload bytes have not arrived yet.
-spec decode_frames(Buf :: binary(), MaxPacketSize :: integer()) -> {ok, RestBin::binary(), Frames :: list()} | {error, Reason :: any()}.
decode_frames(Buf, MaxPacketSize) when is_binary(Buf) ->
    decode_frames0(Buf, MaxPacketSize, []).

decode_frames0(<<Size:16, Tail/binary>> = Pending, Limit, Acc) ->
    case Tail of
        _ when Size > Limit ->
            {error, frame_too_large};
        <<Payload:Size/binary, More/binary>> ->
            decode_frames0(More, Limit, [Payload | Acc]);
        _ ->
            %% Header present but payload still incomplete: keep it buffered.
            {ok, Pending, lists:reverse(Acc)}
    end;
decode_frames0(Pending, _Limit, Acc) ->
    %% Fewer than two bytes left: not even a full length header.
    {ok, Pending, lists:reverse(Acc)}.
%% Builds a complete registration-rejected packet: the
%% PACKET_REGISTER_SUPER_NAK type prefix followed by the encoded
%% SDLRegisterSuperNak message carrying the error code and text.
-spec register_nak_reply(ErrorCode :: integer(), ErrorMsg :: binary()) -> binary().
register_nak_reply(ErrorCode, ErrorMsg) when is_integer(ErrorCode), is_binary(ErrorMsg) ->
    Nak = #'SDLRegisterSuperNak'{error_code = ErrorCode, error_message = ErrorMsg},
    Encoded = sdlan_pb:encode_msg(Nak),
    <<?PACKET_REGISTER_SUPER_NAK, Encoded/binary>>.
%% Encrypts PlainText with the given RSA public key via sdlan_cipher and
%% flattens the result into a single binary.
rsa_encode(PlainText, RsaPubKey) when is_binary(PlainText) ->
    CipherData = sdlan_cipher:rsa_encrypt(PlainText, RsaPubKey),
    iolist_to_binary(CipherData).
%% Frames Packet with a 16-bit length prefix and writes it to the stream.
%% Returns `ok' on success. Any failure terminates the calling process with
%% exit reason {quic_send_failed, Why}:
%%   - packet_too_large when the packet does not fit the 16-bit length
%%     field (a truncated length would corrupt the framing on the wire);
%%   - the reason reported by quicer:send/2 otherwise.
-spec quic_send(Stream :: quicer:stream_handle(), Packet :: binary()) -> ok.
quic_send(_Stream, Packet) when is_binary(Packet), byte_size(Packet) > 16#FFFF ->
    exit({quic_send_failed, packet_too_large});
quic_send(Stream, Packet) when is_binary(Packet) ->
    Len = byte_size(Packet),
    case quicer:send(Stream, <<Len:16, Packet/binary>>) of
        {ok, _} ->
            ok;
        {error, Reason} ->
            exit({quic_send_failed, Reason});
        %% quicer appears to report some send failures as a 3-tuple
        %% ({error, Type, Reason}) - confirm against the quicer version in
        %% use; this clause keeps them from surfacing as a case_clause.
        {error, Type, Reason} ->
            exit({quic_send_failed, {Type, Reason}})
    end.
%% Resolves the policy ids attached to the source and destination identities
%% and returns the rules that rule_ets reports for that pair of policy sets.
-spec get_rules(SrcIdentityId :: integer(), DstIdentityId :: integer()) -> {ok, [{Proto :: integer(), Port :: integer()}]}.
get_rules(SrcIdentityId, DstIdentityId) when is_integer(SrcIdentityId), is_integer(DstIdentityId) ->
    rule_ets:get_rules(
        identity_policy_ets:get_policies(SrcIdentityId),
        identity_policy_ets:get_policies(DstIdentityId)
    ).