%%%-------------------------------------------------------------------
%%% @author anlicheng
%%% @copyright (C) 2026,
%%% @doc
%%% gen_statem process that owns one QUIC connection. It is started with
%%% a connection handle, accepts a stream on it, and then handles the
%%% frames received on that stream and the casts sent through
%%% send_event/2 and command/4.
%%% @end
%%% Created : 11. Feb 2026 23:00
%%%-------------------------------------------------------------------
-module(sdlan_quic_channel).
-author("anlicheng").
-include("sdlan.hrl").
-include("sdlan_pb.hrl").

-behaviour(gen_statem).

%% Heartbeat watchdog interval (passed to erlang:start_timer/3, so milliseconds)
-define(PING_TICKER, 15000).

%% Error codes returned when registration fails
%% Network error
-define(NAK_NETWORK_FAULT, 4).
%% Internal error
-define(NAK_INTERNAL_FAULT, 5).

%% API
-export([start_link/2]).
-export([send_event/2, command/4, stop/2]).
-export([test_rules/2]).

%% gen_statem callbacks
-export([init/1, handle_event/4, terminate/3, code_change/4, callback_mode/0]).

-record(state, {
    conn :: quicer:connection_handle(),
    %% Maximum packet size
    max_packet_size = 16384,
    %% Heartbeat interval
    heartbeat_sec = 10,

    stream :: undefined | quicer:stream_handle(),
    %% Accumulator used while parsing the protocol framing
    buf = <<>>,

    client_id :: undefined | binary(),
    network_id = 0 :: integer(),
    %% Network related information id
    network_pid :: undefined | pid(),
    %% MAC address
    mac :: undefined | binary(),
    ip = 0 :: integer(),

    %% Correlates requests with their responses
    pkt_id = 1,
    %% #{pkt_id => {Ref, ReceiverPid}}
    pending_commands = #{},

    ping_counter = 0,
    %% Offline callback
    offline_cb :: undefined | fun()
}).

%%%===================================================================
%%% API
%%%===================================================================

%% Rule test helper: looks up the rules for the two identity ids, logs
%% them and returns them encoded as one binary.
%% NOTE(review): the binary constructor inside the fun reads `<>` in this
%% copy of the source; its segment list appears to have been lost and must
%% be restored from the original file before compiling.
test_rules(SrcIdentityId, DstIdentityId) when is_integer(SrcIdentityId), is_integer(DstIdentityId) ->
    {ok, Rules} = get_rules(SrcIdentityId, DstIdentityId),
    logger:debug("[sdlan_channel] test_rules policy_request src_identity_id: ~p, dst_identity_id: ~p, rules: ~p",
        [SrcIdentityId, DstIdentityId, Rules]),
    iolist_to_binary(lists:map(fun({Proto, Port}) -> <> end, Rules)).

%% Asynchronously asks the channel process to send an already encoded
%% protobuf event. Returns `ok' immediately (gen_statem:cast/2); the event
%% is only handled by the `registered' state clause of handle_event/4.
-spec send_event(Pid :: pid(), Event :: binary()) -> ok.
send_event(Pid, ProtobufEvent) when is_pid(Pid), is_binary(ProtobufEvent) ->
    gen_statem:cast(Pid, {send_event, ProtobufEvent}).
%% Asynchronously asks the channel process to send a command. The
%% {Ref, ReceiverPid} pair is stored by the channel so that the matching
%% acknowledgement can be forwarded to ReceiverPid. Returns `ok'
%% immediately (gen_statem:cast/2).
-spec command(Pid :: pid(), Ref :: reference(), ReceiverPid :: pid(), {Tag :: atom(), SubCommand :: any()}) -> ok.
command(Pid, Ref, ReceiverPid, SubCommand) when is_pid(Pid), is_pid(ReceiverPid) ->
    gen_statem:cast(Pid, {command, Ref, ReceiverPid, SubCommand}).

%% Stops the channel process with Reason, waiting at most 2000 ms.
-spec stop(Pid :: pid(), Reason :: term()) -> ok.
stop(Pid, Reason) when is_pid(Pid) ->
    gen_statem:stop(Pid, Reason, 2000).

%% @doc Creates a gen_statem process which calls Module:init/1 to
%% initialize. To ensure a synchronized start-up procedure, this
%% function does not return until Module:init/1 has returned.
%% Limits is a property list; init/1 reads the keys `max_packet_size'
%% and `heartbeat_sec' from it.
-spec start_link(Conn :: quicer:connection_handle(), Limits :: proplists:proplist()) -> gen_statem:start_ret().
start_link(Conn, Limits) when is_list(Limits) ->
    gen_statem:start_link(?MODULE, [Conn, Limits], []).

%%%===================================================================
%%% gen_statem callbacks
%%%===================================================================

%% @private
%% @doc Whenever a gen_statem is started using gen_statem:start/[3,4] or
%% gen_statem:start_link/[3,4], this function is called by the new
%% process to initialize.
%% Starts in state `initializing' and queues the internal `do_init' event
%% so that the stream accept happens outside of init/1.
%% NOTE(review): trap_exit is not enabled here, although handle_event/4
%% contains an `{'EXIT', _, _}' info clause; confirm whether
%% process_flag(trap_exit, true) was intended.
init([Conn, Limits]) ->
    MaxPacketSize = proplists:get_value(max_packet_size, Limits, 16384),
    HeartbeatSec = proplists:get_value(heartbeat_sec, Limits, 10),
    {ok, initializing, #state{conn = Conn, max_packet_size = MaxPacketSize, heartbeat_sec = HeartbeatSec},
        [{next_event, internal, do_init}]}.

%% @private
%% @doc This function is called by a gen_statem when it needs to find out
%% the callback mode of the callback module.
callback_mode() ->
    handle_event_function.

%% @private
%% @doc If callback_mode is handle_event_function, then whenever a
%% gen_statem receives an event from call/2, cast/2, or as a normal
%% process message, this function is called.
%% NOTE(review): every binary pattern/constructor written as `<>` below is
%% reproduced verbatim from this copy of the source; the segment lists
%% (frame tags, length prefixes, payload variables such as Body) appear to
%% have been lost and must be restored from the original file before
%% compiling. Clause order is kept unchanged for that reason.

%% Ask quicer for the peer's stream, then wait for it.
handle_event(internal, do_init, initializing, State=#state{conn = Conn}) ->
    logger:debug("[sdlan_quic_channel] call do_init of conn: ~p", [Conn]),
    {ok, _} = quicer:async_accept_stream(Conn, #{active => true}),
    {next_state, waiting_stream, State};

%% Handle messages received from quic
handle_event(info, {quic, dgram_state_changed, Conn, Opts = #{dgram_send_enabled := true}}, _, State=#state{conn = Conn}) ->
    logger:debug("[sdlan_quic_channel] dgram_state_changed, opts: ~p", [Opts]),
    {keep_state, State};

%% The stream has arrived: send the welcome packet and remember the stream.
handle_event(info, {quic, new_stream, Stream, Opts}, waiting_stream, State=#state{max_packet_size = MaxPacketSize, heartbeat_sec = HeartbeatSec}) ->
    logger:debug("[sdlan_quic_channel] call new_stream: ~p, opts: ~p", [Stream, Opts]),
    Ipv6Assist = case application:get_env(sdlan, ipv6_assist_info) of
        {ok, {V6Bytes, Port}} ->
            #'SDLV6Info' {
                v6 = V6Bytes,
                port = Port
            };
        _ ->
            undefined
    end,

    %% Send the welcome message
    WelcomePkt = sdlan_pb:encode_msg(#'SDLWelcome'{
        version = 1,
        max_bidi_streams = 1,
        max_packet_size = MaxPacketSize,
        heartbeat_sec = HeartbeatSec,
        ipv6_assist = Ipv6Assist
    }),
    quic_send(Stream, <>),
    logger:debug("[sdlan_quic_channel] get stream: ~p, send welcome", [Stream]),
    {next_state, initialized, State#state{stream = Stream}};

%% Stream level shutdown notifications stop the process in any state.
handle_event(info, {quic, closed, Stream, _Props}, _StateName, State = #state{stream = Stream}) ->
    {stop, connection_closed, State};
handle_event(info, {quic, send_shutdown_complete, Stream, _Props}, _StateName, State = #state{stream = Stream}) ->
    {stop, connection_shutdown, State};
handle_event(info, {quic, transport_shutdown, Stream, _Props}, _StateName, State = #state{stream = Stream}) ->
    {stop, transport_shutdown, State};

%% Handle data delivered by quicer; it is converted into `frame' events
%% that the internal clauses below understand.
handle_event(info, {quic, Data, Stream, _Props}, _StateName, State = #state{stream = Stream, buf = Buf, max_packet_size = MaxPacketSize}) when is_binary(Data) ->
    case decode_frames(<>, MaxPacketSize) of
        {error, Reason} ->
            {stop, Reason, State};
        {ok, NBuf, Frames} ->
            %% Each complete frame becomes its own internal event; the
            %% unparsed remainder stays in the buffer.
            Actions = [{next_event, internal, {frame, Frame}} || Frame <- Frames],
            {keep_state, State#state{buf = NBuf}, Actions}
    end;

%% Handle internal packet (frame) events
%% Registration ('SDLRegisterSuper'): authenticate, attach to the network,
%% acknowledge and move to `registered'.
handle_event(internal, {frame, <>}, initialized, State=#state{stream = Stream}) ->
    #'SDLRegisterSuper'{
        client_id = ClientId,
        network_id = NetworkId,
        mac = Mac,
        ip = Ip,
        mask_len = MaskLen,
        hostname = HostName,
        pub_key = PubKey,
        access_token = AccessToken} = sdlan_pb:decode_msg(Body, 'SDLRegisterSuper'),

    true = (Mac =/= <<>> andalso PubKey =/= <<>> andalso ClientId =/= <<>>),
    %% The MAC address must not be a multicast or broadcast address
    true = not (sdlan_util:is_multicast_mac(Mac) orelse sdlan_util:is_broadcast_mac(Mac)),

    MacBinStr = sdlan_util:format_mac(Mac),
    IpAddr = sdlan_util:int_to_ipv4(Ip),
    Params = #{
        <<"network_id">> => NetworkId,
        <<"client_id">> => ClientId,
        <<"mac">> => MacBinStr,
        <<"ip">> => IpAddr,
        <<"mask_len">> => MaskLen,
        <<"hostname">> => HostName,
        <<"access_token">> => AccessToken
    },

    %% Parameter check
    %% NOTE(review): this debug line writes the access token to the log;
    %% consider redacting it.
    logger:debug("[sdlan_quic_channel] client_id: ~p, ip: ~p, mac: ~p, host_name: ~p, access_token: ~p, network_id: ~p",
        [ClientId, Ip, Mac, HostName, AccessToken, NetworkId]),
    case sdlan_api:auth_access_token(Params) of
        {ok, #{<<"result">> := <<"ok">>}} ->
            %% Establish the association with the network
            case sdlan_network:get_pid(NetworkId) of
                NetworkPid when is_pid(NetworkPid) ->
                    {ok, Algorithm, Key, RegionId, SessionToken} = sdlan_network:attach(NetworkPid, self(), ClientId, Mac, Ip, HostName),
                    RsaPubKey = sdlan_cipher:rsa_pem_decode(PubKey),
                    RegisterSuperAck = sdlan_pb:encode_msg(#'SDLRegisterSuperAck'{
                        algorithm = Algorithm,
                        key = rsa_encode(Key, RsaPubKey),
                        region_id = RegionId,
                        session_token = SessionToken
                    }),

                    %% Send the acknowledgement
                    quic_send(Stream, <>),

                    %% Mark the node as online
                    Result = sdlan_api:set_node_status(#{
                        <<"network_id">> => NetworkId,
                        <<"client_id">> => ClientId,
                        <<"access_token">> => AccessToken,
                        <<"status">> => 1
                    }),
                    logger:debug("[sdlan_quic_channel] client_id: ~p, set none online result is: ~p", [ClientId, Result]),

                    %% Run from terminate/3 to mark the node offline. The
                    %% call result is deliberately not bound to `Result':
                    %% that name is already bound above, so binding it
                    %% inside the fun would be a match that raises
                    %% badmatch whenever the two results differ.
                    OfflineCb = fun() ->
                        sdlan_api:set_node_status(#{
                            <<"network_id">> => NetworkId,
                            <<"client_id">> => ClientId,
                            <<"access_token">> => AccessToken,
                            <<"status">> => 0
                        })
                    end,

                    {next_state, registered, State#state{network_id = NetworkId, network_pid = NetworkPid, client_id = ClientId, mac = Mac, ip = Ip, offline_cb = OfflineCb}};
                undefined ->
                    logger:warning("[sdlan_quic_channel] client_id: ~p, register get error: network not found", [ClientId]),
                    quic_send(Stream, register_nak_reply(?NAK_INTERNAL_FAULT, <<"Internal Error">>)),
                    {stop, normal, State}
            end;
        {ok, #{<<"error">> := #{<<"code">> := Code, <<"message">> := Message}}} ->
            logger:warning("[sdlan_quic_channel] network_id: ~p, client_id: ~p, register get error: ~ts, error_code: ~p", [NetworkId, ClientId, Message, Code]),
            quic_send(Stream, register_nak_reply(Code, Message)),
            {stop, normal, State};
        {error, Reason} ->
            logger:warning("[sdlan_quic_channel] network_id: ~p, client_id: ~p, register get error: ~p", [NetworkId, ClientId, Reason]),
            quic_send(Stream, register_nak_reply(?NAK_NETWORK_FAULT, <<"Network Error">>)),
            {stop, normal, State}
    end;

%% Peer lookup ('SDLQueryInfo'): answer with an 'SDLPeerInfo', empty when
%% the network process reports `error'.
handle_event(internal, {frame, <>}, registered, #state{stream = Stream, network_pid = NetworkPid, mac = SrcMac}) when is_pid(NetworkPid) ->
    #'SDLQueryInfo'{dst_mac = DstMac} = sdlan_pb:decode_msg(Body, 'SDLQueryInfo'),
    case sdlan_network:peer_info(NetworkPid, SrcMac, DstMac) of
        error ->
            logger:debug("[sdlan_channel] query_info src_mac is: ~p, dst_mac: ~p, nat_peer not found",
                [sdlan_util:format_mac(SrcMac), sdlan_util:format_mac(DstMac)]),
            EmptyResponse = sdlan_pb:encode_msg(#'SDLPeerInfo'{
                dst_mac = DstMac,
                v4_info = undefined,
                v6_info = undefined
            }),
            quic_send(Stream, <>),
            keep_state_and_data;
        {ok, {NatPeer = {{Ip0, Ip1, Ip2, Ip3}, NatPort}, NatType}, V6Info} ->
            logger:debug("[sdlan_channel] query_info src_mac is: ~p, dst_mac: ~p, nat_peer: ~p",
                [sdlan_util:format_mac(SrcMac), sdlan_util:format_mac(DstMac), NatPeer]),
            PeerInfo = sdlan_pb:encode_msg(#'SDLPeerInfo'{
                dst_mac = DstMac,
                v4_info = #'SDLV4Info' {
                    port = NatPort,
                    v4 = <>,
                    nat_type = NatType
                },
                v6_info = V6Info
            }),
            quic_send(Stream, <>),
            keep_state_and_data
    end;

%% ARP lookup
handle_event(internal, {frame, <>}, registered, #state{stream = Stream, network_id = NetworkId, network_pid = NetworkPid}) when is_pid(NetworkPid) ->
    #'SDLArpRequest'{target_ip = TargetIp, origin_ip = OriginIp, context = Context} = sdlan_pb:decode_msg(Body, 'SDLArpRequest'),
    case sdlan_network:arp_request(NetworkPid, TargetIp) of
        error ->
            logger:debug("[sdlan_channel] network: ~p, arp_request target_ip: ~p, mac not found",
                [NetworkId, sdlan_util:int_to_ipv4(TargetIp)]),
            EmptyArpResponsePkt = sdlan_pb:encode_msg(#'SDLArpResponse'{
                target_ip = TargetIp,
                target_mac = <<>>,
                origin_ip = OriginIp,
                context = Context
            }),
            quic_send(Stream, <>),
            keep_state_and_data;
        {ok, Mac} ->
            logger:debug("[sdlan_channel] network: ~p, arp_request target_ip: ~p, mac: ~p",
                [NetworkId, sdlan_util:int_to_ipv4(TargetIp), sdlan_util:format_mac(Mac)]),
            ArpResponsePkt = sdlan_pb:encode_msg(#'SDLArpResponse'{
                target_ip = TargetIp,
                target_mac = Mac,
                origin_ip = OriginIp,
                context = Context
            }),
            quic_send(Stream, <>),
            keep_state_and_data
    end;

%% Policy request ('SDLPolicyRequest'): reply with the encoded rules.
%% When the decoded message does not match, nothing is sent.
handle_event(internal, {frame, <>}, registered, #state{stream = Stream, network_pid = NetworkPid}) when is_pid(NetworkPid) ->
    maybe
        #'SDLPolicyRequest'{src_identity_id = SrcIdentityId, dst_identity_id = DstIdentityId, version = Version} ?= sdlan_pb:decode_msg(Body, 'SDLPolicyRequest'),
        {ok, Rules} = get_rules(SrcIdentityId, DstIdentityId),
        logger:debug("[sdlan_channel] policy_request src_identity_id: ~p, dst_identity_id: ~p, rules: ~p",
            [SrcIdentityId, DstIdentityId, Rules]),

        RuleBin = iolist_to_binary(lists:map(fun({Proto, Port}) -> <> end, Rules)),
        PolicyResponsePkt = sdlan_pb:encode_msg(#'SDLPolicyResponse'{
            src_identity_id = SrcIdentityId,
            dst_identity_id = DstIdentityId,
            version = Version,
            rules = RuleBin
        }),
        quic_send(Stream, <>)
    end,
    keep_state_and_data;

%% Handle command responses ('SDLCommandAck'): forward the ack to the
%% process that issued the command and forget the pending entry.
handle_event(internal, {frame, <>}, registered, State=#state{pending_commands = PendingCommands}) ->
    maybe
        CommandAck = sdlan_pb:decode_msg(Body, 'SDLCommandAck'),
        #'SDLCommandAck'{pkt_id = PktId} ?= CommandAck,
        {{Ref, ReceiverPid}, RestPendingCommands} ?= maps:take(PktId, PendingCommands),
        %% Sending to a dead pid is a harmless no-op, so no liveness check
        %% is made; is_process_alive/1 would raise badarg for a pid on
        %% another node.
        ReceiverPid ! {quic_command_ack, Ref, CommandAck},
        {keep_state, State#state{pending_commands = RestPendingCommands}}
    else
        _ ->
            keep_state_and_data
    end;

%% Frame accepted in any state: a reply is sent and ping_counter is
%% incremented (presumably the ping/pong heartbeat; the tag is not visible
%% in this copy).
handle_event(internal, {frame, <>}, _StateName, State = #state{stream = Stream, ping_counter = PingCounter}) ->
    quic_send(Stream, <>),
    {keep_state, State#state{ping_counter = PingCounter + 1}};

%% Unregister
handle_event(internal, {frame, <>}, registered, State=#state{client_id = ClientId, mac = Mac, network_pid = NetworkPid}) when is_pid(NetworkPid) ->
    logger:warning("[sdlan_channel] unregister client_id: ~p", [ClientId]),
    sdlan_network:unregister(NetworkPid, ClientId, Mac),
    {stop, normal, State};

%% Heartbeat watchdog: stop when no ping frame was counted since the last
%% tick, otherwise reset the counter.
%% NOTE(review): no call that arms the first ping_ticker timer is visible
%% in this file; confirm where the watchdog is started.
handle_event(info, {timeout, _, ping_ticker}, _, State = #state{client_id = ClientId, ping_counter = PingCounter}) ->
    %% Arm the next heartbeat check
    erlang:start_timer(?PING_TICKER, self(), ping_ticker),
    case PingCounter > 0 of
        true ->
            {keep_state, State#state{ping_counter = 0}};
        false ->
            logger:debug("[sdlan_channel] client_id: ~p, ping losted", [ClientId]),
            {stop, normal, State#state{ping_counter = 0}}
    end;

%% Send an event to the client
handle_event(cast, {send_event, Event}, registered, #state{stream = Stream}) ->
    quic_send(Stream, <>),
    keep_state_and_data;

%% Send a command to the client and remember who is waiting for its ack
handle_event(cast, {command, Ref, ReceiverPid, SubCommand}, registered, State=#state{stream = Stream, pkt_id = PktId, pending_commands = PendingCommands, client_id = ClientId}) ->
    CommandPkt = sdlan_pb:encode_msg(#'SDLCommand'{
        pkt_id = PktId,
        command = SubCommand
    }),
    logger:debug("[sdlan_channel] client_id: ~p, will send Command: ~p", [ClientId, SubCommand]),
    quic_send(Stream, <>),
    {keep_state, State#state{pkt_id = PktId + 1, pending_commands = maps:put(PktId, {Ref, ReceiverPid}, PendingCommands)}};

handle_event(info, {'EXIT', _, _}, _StateName, State) ->
    {stop, connection_closed, State};

%% Anything else is logged and ignored.
handle_event(EventType, Info, StateName, State) ->
    logger:notice("[sdlan_quic_channel] state: ~p, state_name: ~p, event_type: ~p, info: ~p", [State, StateName, EventType, Info]),
    keep_state_and_data.

%% @private
%% @doc This function is called by a gen_statem when it is about to
%% terminate. It should be the opposite of Module:init/1 and do any
%% necessary cleaning up. When it returns, the gen_statem terminates with
%% Reason. The return value is ignored.
%% Closes the stream (when one was accepted) and the connection, then runs
%% the offline callback if registration installed one.
terminate(Reason, _StateName, _State = #state{conn = Conn, stream = Stream, offline_cb = OfflineCb}) ->
    Stream /= undefined andalso quicer:close_stream(Stream),
    quicer:close_connection(Conn),
    logger:warning("[sdlan_quic_conn] terminate closed with reason: ~p", [Reason]),
    %% Trigger the client's offline logic
    is_function(OfflineCb) andalso OfflineCb(),
    ok.

%% @private
%% @doc Convert process state when code is changed
code_change(_OldVsn, StateName, State = #state{}, _Extra) ->
    {ok, StateName, State}.

%%%===================================================================
%%% Internal functions
%%%===================================================================

%% Two cases:
%% 1. several complete requests were received
%% 2. the data is incomplete, in which case it is left unprocessed
%% Returns the complete frames in arrival order together with the
%% unconsumed remainder, or {error, frame_too_large} when a frame's length
%% exceeds MaxPacketSize.
-spec decode_frames(Buf :: binary(), MaxPacketSize :: integer()) -> {ok, RestBin::binary(), Frames :: list()} | {error, Reason :: any()}.
decode_frames(Buf, MaxPacketSize) when is_binary(Buf) ->
    decode_frames0(Buf, MaxPacketSize, []).

%% Accumulates frames in reverse and reverses once at the end.
decode_frames0(<>, MaxPacketSize, _Frames) when Len > MaxPacketSize ->
    {error, frame_too_large};
decode_frames0(<>, MaxPacketSize, Frames) ->
    decode_frames0(Rest, MaxPacketSize, [Frame|Frames]);
decode_frames0(Rest, _MaxPacketSize, Frames) ->
    {ok, Rest, lists:reverse(Frames)}.

-spec register_nak_reply(ErrorCode :: integer(), ErrorMsg :: binary()) -> binary().
%% Builds the registration rejection packet from an encoded
%% 'SDLRegisterSuperNak' message.
%% NOTE(review): the binary constructors written as `<>` in this block are
%% reproduced verbatim from this copy of the source; their segment lists
%% appear to have been lost and must be restored from the original file
%% before compiling.
register_nak_reply(ErrorCode, ErrorMsg) when is_integer(ErrorCode), is_binary(ErrorMsg) ->
    RegisterNakReply = sdlan_pb:encode_msg(#'SDLRegisterSuperNak'{
        error_code = ErrorCode,
        error_message = ErrorMsg
    }),
    <>.

%% Encrypts PlainText with the given RSA public key and flattens the
%% result of sdlan_cipher:rsa_encrypt/2 into a single binary.
rsa_encode(PlainText, RsaPubKey) when is_binary(PlainText) ->
    iolist_to_binary(sdlan_cipher:rsa_encrypt(PlainText, RsaPubKey)).

%% Sends one packet on the stream. Returns `ok' on success and exits the
%% calling process with {quic_send_failed, Reason} when quicer reports
%% {error, Reason}.
%% NOTE(review): confirm quicer:send/2 has no other error shape; any other
%% return value would fail here with case_clause.
-spec quic_send(Stream :: quicer:stream_handle(), Packet :: binary()) -> ok.
quic_send(Stream, Packet) when is_binary(Packet) ->
    Len = byte_size(Packet),
    case quicer:send(Stream, <>) of
        {ok, _} ->
            ok;
        {error, Reason} ->
            exit({quic_send_failed, Reason})
    end.

%% Looks up the policy ids of both identities and returns the rules that
%% rule_ets yields for that pair of policy id sets.
-spec get_rules(SrcIdentityId :: integer(), DstIdentityId :: integer()) -> {ok, [{Proto :: integer(), Port :: integer()}]}.
get_rules(SrcIdentityId, DstIdentityId) when is_integer(SrcIdentityId), is_integer(DstIdentityId) ->
    SrcPolicyIds = identity_policy_ets:get_policies(SrcIdentityId),
    DstPolicyIds = identity_policy_ets:get_policies(DstIdentityId),
    rule_ets:get_rules(SrcPolicyIds, DstPolicyIds).