From a4f2ed84282984d3e351bdfcd96b212bbd49571d Mon Sep 17 00:00:00 2001 From: anlicheng <244108715@qq.com> Date: Sun, 25 Jan 2026 21:07:34 +0800 Subject: [PATCH] fix register --- apps/sdlan/src/sdlan_network.erl | 132 ++++++++--------------- apps/sdlan/src/sdlan_register_worker.erl | 41 +++---- apps/sdlan/src/sdlan_stun_pool.erl | 14 ++- 3 files changed, 79 insertions(+), 108 deletions(-) diff --git a/apps/sdlan/src/sdlan_network.erl b/apps/sdlan/src/sdlan_network.erl index 78e1cb6..22696e8 100644 --- a/apps/sdlan/src/sdlan_network.erl +++ b/apps/sdlan/src/sdlan_network.erl @@ -23,7 +23,7 @@ %% API -export([start_link/2]). --export([get_name/1, get_pid/1, lookup_pid/1, assign_ip_addr/4, peer_info/3, unregister/3, debug_info/1, get_network_id/1, get_used_map/1, arp_query/2]). +-export([get_name/1, get_pid/1, lookup_pid/1, attach/6, peer_info/3, unregister/3, debug_info/1, get_network_id/1, get_used_map/1, arp_query/2]). -export([forward/5, update_hole/6, disable_client/2, dropout_client/2]). -export([test_event/1]). @@ -39,8 +39,8 @@ -record(endpoint, { client_id :: binary(), ip :: integer(), - monitor_ref :: reference(), - hole :: undefined | #hole{}, + hostname :: binary(), + hole :: #hole{}, %% 记录ip和ip_v6的映射关系, #{ip_addr :: integer() => {}} v6_info :: undefined | #sdl_v6_info{} }). @@ -91,11 +91,10 @@ lookup_pid(Id) when is_integer(Id) -> get_name(Id) when is_integer(Id) -> list_to_atom("sdlan_network:" ++ integer_to_list(Id)). --spec attach(Pid :: pid(), ChannelPid :: pid(), ClientId :: binary(), Mac :: binary()) -> - {ok, Domain :: binary(), NetAddr :: integer(), MaskLen :: integer(), AesKey :: binary()} | {error, Reason :: any()}. - -attach(Pid, ChannelPid, ClientId, Mac, Ip, HostName) when is_pid(Pid), is_pid(ChannelPid), is_binary(ClientId), is_binary(Mac) -> - gen_server:call(Pid, {attach, ChannelPid, ClientId, Mac, Ip, HostName}). +-spec attach(Pid :: pid(), Peer :: {Ip :: inet:ip4_address(), Port :: integer()}, ClientId :: binary(), Mac :: binary(), Ip :: inet:ip4_address(), HostName :: binary()) -> + {ok, Domain :: binary(), MaskLen :: integer(), AesKey :: binary()} | {error, Reason :: any()}. +attach(Pid, Peer, ClientId, Mac, Ip, HostName) when is_pid(Pid), is_binary(ClientId), is_binary(Mac) -> + gen_server:call(Pid, {attach, Peer, ClientId, Mac, Ip, HostName}). -spec get_network_id(Pid :: pid()) -> {ok, NetworkId :: integer()}. get_network_id(Pid) when is_pid(Pid) -> @@ -158,7 +157,6 @@ start_link(Name, Id) when is_atom(Name), is_integer(Id) -> {ok, State :: #state{}} | {ok, State :: #state{}, timeout() | hibernate} | {stop, Reason :: term()} | ignore). init([Id]) when is_integer(Id) -> - erlang:process_flag(trap_exit, true), case sdlan_api:get_network(Id) of {ok, #{<<"ipaddr">> := Null}} when Null == <<"null">>; Null == <<"NULL">> -> ignore; @@ -175,7 +173,8 @@ init([Id]) when is_integer(Id) -> erlang:start_timer(?FLOW_REPORT_INTERVAL, self(), flow_report_ticker), sdlan_domain_regedit:insert(Domain), - {ok, #state{network_id = Id, name = Name, domain = Domain, ipaddr = IpAddr, owner_id = OwnerId, mask_len = MaskLen, aes_key = AesKey, throttle_key = ThrottleKey}}; + {ok, #state{network_id = Id, name = Name, domain = Domain, ipaddr = IpAddr, owner_id = OwnerId, + mask_len = MaskLen, aes_key = AesKey, throttle_key = ThrottleKey}}; {error, Reason} -> logger:warning("[sdlan_network] load network: ~p, get error: ~p", [Id, Reason]), ignore @@ -191,21 +190,17 @@ init([Id]) when is_integer(Id) -> {noreply, NewState :: #state{}, timeout() | hibernate} | {stop, Reason :: term(), Reply :: term(), NewState :: #state{}} | {stop, Reason :: term(), NewState :: #state{}}). -%% 给客户端分配ip地址, TODO 这里要绑定hole -handle_call({attach, ClientId, Mac, Ip, HostName}, _From, +%% 给客户端分配ip地址 +handle_call({attach, Peer, ClientId, Mac, Ip, Hostname}, _From, State = #state{network_id = NetworkId, domain = Domain, endpoints = Endpoints, mask_len = MaskLen, aes_key = AesKey}) -> - %% 分配ip地址的时候,以mac地址为唯一基准 - logger:debug("[sdlan_network] alloc_ip, network_id: ~p, client_id: ~p, mac: ~p, net_addr: ~p", + logger:debug("[sdlan_network] alloc_ip, network_id: ~p, client_id: ~p, mac: ~p, ip_addr: ~p", [NetworkId, ClientId, sdlan_util:format_mac(Mac), sdlan_ipaddr:int_to_ipv4(Ip)]), - %% 关闭之前的channel - maybe_close_channel(maps:get(Mac, Endpoints, undefined)), %% 添加域名->ip的映射关系 - sdlan_hostname_regedit:insert(HostName, Domain, Ip), + sdlan_hostname_regedit:insert(Hostname, Domain, Ip), + NEndpoints = maps:put(Mac, #endpoint{client_id = ClientId, ip = Ip, hostname = Hostname, hole = #hole{peer = Peer, nat_type = 0}}, Endpoints), - NEndpoints = maps:put(Mac, #endpoint{client_id = ClientId, ip = Ip}, Endpoints), - - {reply, {ok, Domain, Ip, MaskLen, AesKey}, State#state{endpoints = NEndpoints}}; + {reply, {ok, Domain, MaskLen, AesKey}, State#state{endpoints = NEndpoints}}; handle_call(get_used_map, _From, State = #state{endpoints = Endpoints}) -> UsedInfos = maps:map(fun(_, #endpoint{hole = Hole, v6_info = V6Info}) -> @@ -257,20 +252,20 @@ handle_call({arp_query, TargetIp}, _From, State = #state{endpoints = Endpoints}) %% 网络存在的nat_peer信息 handle_call({peer_info, SrcMac, DstMac}, _From, State = #state{endpoints = Endpoints}) -> case maps:find(DstMac, Endpoints) of - {ok, #endpoint{channel_pid = DstChannelPid, hole = #hole{peer = DstNatPeer, nat_type = DstNatType}, v6_info = DstV6Info}} -> + {ok, #endpoint{hole = #hole{peer = DstNatPeer, nat_type = DstNatType}, v6_info = DstV6Info}} -> %% 让目标服务器发送sendRegister事件(2024-06-25 新增,提高打洞的成功率) - case maps:get(SrcMac, Endpoints, undefined) of - #endpoint{hole = #hole{peer = {SrcNatIp, SrcNatPort}, nat_type = NatType}, v6_info = SrcV6Info} -> - Event = sdlan_pb:encode_msg(#sdl_send_register_event { - dst_mac = SrcMac, - nat_ip = sdlan_ipaddr:ipv4_to_int(SrcNatIp), - nat_type = NatType, - nat_port = SrcNatPort, - v6_info = SrcV6Info - }), - sdlan_channel:send_event(DstChannelPid, ?PACKET_EVENT_SEND_REGISTER, Event); - _ -> - ok + maybe + {ok, #endpoint{hole = #hole{peer = {SrcNatIp, SrcNatPort}, nat_type = NatType}, v6_info = SrcV6Info}} ?= maps:find(SrcMac, Endpoints), + Event = sdlan_pb:encode_msg(#sdl_send_register_event { + dst_mac = SrcMac, + nat_ip = sdlan_ipaddr:ipv4_to_int(SrcNatIp), + nat_type = NatType, + nat_port = SrcNatPort, + v6_info = SrcV6Info + }), + + EventPacket = <>, + sdlan_stun_pool:send_packets([{DstNatPeer, EventPacket}]) end, {reply, {ok, {DstNatPeer, DstNatType}, DstV6Info}, State}; _ -> @@ -377,13 +372,10 @@ handle_cast({update_hole, ClientId, Mac, Peer, NatType, V6Info}, State = #state{ mac = Mac, ip = Ip }), - broadcast(fun(#endpoint{channel_pid = ChannelPid}) -> - case is_process_alive(ChannelPid) of - true -> - sdlan_channel:send_event(ChannelPid, ?PACKET_EVENT_NAT_CHANGED, NatChangedEvent); - false -> - ok - end + EventPacket = <>, + + broadcast(fun(#endpoint{hole = #hole{peer = Peer}}) -> + sdlan_stun_pool:send_packet(Peer, EventPacket) end, [Mac], Endpoints); false -> ok @@ -404,30 +396,7 @@ handle_cast({update_hole, ClientId, Mac, Peer, NatType, V6Info}, State = #state{ handle_info({timeout, _, flow_report_ticker}, State = #state{network_id = NetworkId, forward_bytes = ForwardBytes}) -> erlang:start_timer(?FLOW_REPORT_INTERVAL, self(), flow_report_ticker), catch sdlan_api:network_forward_report(NetworkId, ForwardBytes), - {noreply, State#state{forward_bytes = 0}}; - -handle_info({'EXIT', _Pid, shutdown}, State = #state{network_id = NetworkId, endpoints = Endpoints}) -> - logger:warning("[sdlan_network] network: ~p, get shutdown message", [NetworkId]), - - NetworkShutdownEvent = sdlan_pb:encode_msg(#sdl_network_shutdown_event { - message = <<"Network shutdown">> - }), - broadcast(fun(#endpoint{channel_pid = ChannelPid}) -> - case is_pid(ChannelPid) andalso is_process_alive(ChannelPid) of - true -> - sdlan_channel:send_event(ChannelPid, ?PACKET_EVENT_NETWORK_SHUTDOWN, NetworkShutdownEvent), - sdlan_channel:stop(ChannelPid, normal); - false -> - ok - end - end, Endpoints), - - {stop, shutdown, State}; -%% Channel进程退出, hole里面的数据也需要清理 -handle_info({'DOWN', _MRef, process, ChannelPid, Reason}, State = #state{network_id = NetworkId, endpoints = Endpoints}) -> - logger:notice("[sdlan_network] network_id: ~p, channel_pid: ~p, close with reason: ~p", [NetworkId, ChannelPid, Reason]), - NUsedMap = maps:filter(fun(_, #endpoint{channel_pid = ChannelPid0}) -> ChannelPid =/= ChannelPid0 end, Endpoints), - {noreply, State#state{endpoints = NUsedMap}}. + {noreply, State#state{forward_bytes = 0}}. %% @private %% @doc This function is called by a gen_server when it is about to @@ -442,15 +411,9 @@ terminate(Reason, #state{network_id = NetworkId, endpoints = Endpoints}) -> NetworkShutdownEvent = sdlan_pb:encode_msg(#sdl_network_shutdown_event { message = <<"Network shutdown">> }), - broadcast(fun(#endpoint{channel_pid = ChannelPid, monitor_ref = MRef}) -> - is_reference(MRef) andalso demonitor(MRef), - case is_pid(ChannelPid) andalso is_process_alive(ChannelPid) of - true -> - sdlan_channel:send_event(ChannelPid, ?PACKET_EVENT_NETWORK_SHUTDOWN, NetworkShutdownEvent), - sdlan_channel:stop(ChannelPid, normal); - false -> - ok - end + EventPacket = <>, + broadcast(fun(#endpoint{hole = #hole{peer = Peer}}) -> + sdlan_stun_pool:send_packet(Peer, EventPacket) end, Endpoints), ok. @@ -480,18 +443,6 @@ limiting_check(ThrottleKey) -> end end. --spec maybe_close_channel(undefined | #endpoint{}) -> no_return(). -maybe_close_channel(#endpoint{channel_pid = ChannelPid0, monitor_ref = MRef0}) -> - case is_pid(ChannelPid0) andalso is_process_alive(ChannelPid0) of - true -> - is_reference(MRef0) andalso demonitor(MRef0), - sdlan_channel:stop(ChannelPid0, channel_rebind); - false -> - ok - end; -maybe_close_channel(_) -> - ok. - -spec broadcast(Fun :: binary(), Endpoints :: map()) -> no_return(). broadcast(Fun, Endpoints) when is_function(Fun, 1), is_map(Endpoints) -> broadcast(Fun, [], Endpoints). @@ -507,6 +458,17 @@ broadcast(Fun, ExcludeMacs, Endpoints) when is_function(Fun, 1), is_map(Endpoint end end, Endpoints). +-spec broadcast_peers(Fun :: binary(), ExcludeMacs :: [binary()], Endpoints :: map()) -> no_return(). +broadcast_peers(Fun, ExcludeMacs, Endpoints) when is_function(Fun, 1), is_map(Endpoints), is_list(ExcludeMacs) -> + maps:filtermap(fun(Mac, Endpoint) -> + case lists:member(Mac, ExcludeMacs) of + true -> + ok; + false -> + Fun(Endpoint) + end + end, Endpoints). + %% 解析IpAddr: <<"192.168.172/24">> -spec parse_ipaddr(IpAddr0 :: binary()) -> {IpAddr :: binary(), MaskLen :: integer()}. parse_ipaddr(IpAddr0) when is_binary(IpAddr0) -> diff --git a/apps/sdlan/src/sdlan_register_worker.erl b/apps/sdlan/src/sdlan_register_worker.erl index 081de24..4a7d179 100644 --- a/apps/sdlan/src/sdlan_register_worker.erl +++ b/apps/sdlan/src/sdlan_register_worker.erl @@ -25,15 +25,15 @@ -define(NAK_HOSTNAME_USED, 6). %% API --export([start_link/4, do_work/4]). +-export([start_link/4, do_register/4]). -spec start_link(Sock :: inet:socket(), Ip :: inet:ip4_address(), Port :: integer(), Packet :: binary()) -> {ok, pid()}. start_link(Sock, Ip, Port, Packet) -> - {ok, erlang:spawn_link(?MODULE, do_work, [Sock, Ip, Port, Packet])}. + {ok, erlang:spawn_link(?MODULE, do_register, [Sock, Ip, Port, Packet])}. -do_work(Sock, Ip, Port, <>) -> +do_register(Sock, SrcIp, SrcPort, <>) -> #sdl_register_super{ - version = Version, + version = _Version, client_id = ClientId, network_id = NetworkId, mac = Mac, @@ -45,7 +45,9 @@ do_work(Sock, Ip, Port, <>) -> } = sdlan_pb:decode_msg(Body, sdl_register_super), %% 参数检查 - logger:debug("[sdlan_channel] client_id: ~p, ip: ~p, mac: ~p, host_name: ~p, access_token: ~p, network_id: ~p", [ClientId, Ip, Mac, HostName, AccessToken, NetworkId]), + logger:debug("[sdlan_channel] client_id: ~p, ip: ~p, mac: ~p, host_name: ~p, access_token: ~p, network_id: ~p", + [ClientId, Ip, Mac, HostName, AccessToken, NetworkId]), + true = (Mac =/= <<>> andalso PubKey =/= <<>> andalso ClientId =/= <<>>), %% Mac地址不能是广播地址 true = not (sdlan_util:is_multicast_mac(Mac) orelse sdlan_util:is_broadcast_mac(Mac)), @@ -65,7 +67,7 @@ do_work(Sock, Ip, Port, <>) -> %% 建立到network的对应关系 case sdlan_network:get_pid(NetworkId) of NetworkPid when is_pid(NetworkPid) -> - try sdlan_network:attach(NetworkPid, self(), ClientId, Mac, Ip, HostName) of + case sdlan_network:attach(NetworkPid, {SrcIp, SrcPort}, ClientId, Mac, Ip, HostName) of {ok, AesKey, SessionToken} -> RsaPubKey = sdlan_cipher:rsa_pem_decode(PubKey), EncodedAesKey = rsa_encode(AesKey, RsaPubKey), @@ -77,33 +79,31 @@ do_work(Sock, Ip, Port, <>) -> %% 发送确认信息 Reply = <>, - gen_udp:send(Sock, Ip, Port, Reply), + gen_udp:send(Sock, SrcIp, SrcPort, Reply), %% 设置节点的在线状态 Result = sdlan_api:node_online(ClientId, NetworkId, sdlan_ipaddr:int_to_ipv4(Ip)), logger:debug("[sdlan_channel] client_id: ~p, set none online result is: ~p", [ClientId, Result]); {error, no_ip} -> - logger:warning("[sdlan_channel] client_id: ~p, register get error: no_ip", [ClientId, Token]), - gen_udp:send(Sock, Ip, Port, register_nak_reply(PacketId, ?NAK_NO_IP, <<"No Ip address">>)); + logger:warning("[sdlan_channel] client_id: ~p, register get error: no_ip", [ClientId]), + gen_udp:send(Sock, SrcIp, SrcPort, register_nak_reply(PacketId, ?NAK_NO_IP, <<"No Ip address">>)); {error, host_name_used} -> - logger:warning("[sdlan_channel] client_id: ~p, token: ~p, register get error: host_name_used", [ClientId, Token]), - gen_udp:send(Sock, Ip, Port, register_nak_reply(PacketId, ?NAK_HOSTNAME_USED, <<"Host Name Used">>)); + logger:warning("[sdlan_channel] client_id: ~p, register get error: host_name_used", [ClientId]), + gen_udp:send(Sock, SrcIp, SrcPort, register_nak_reply(PacketId, ?NAK_HOSTNAME_USED, <<"Host Name Used">>)); {error, client_disabled} -> - logger:warning("[sdlan_channel] client_id: ~p, token: ~p, register get error: client_disabled", [ClientId, Token]), - gen_udp:send(Sock, Ip, Port, register_nak_reply(PacketId, ?NAK_NODE_DISABLE, <<"Client Connection Disable">>)) - catch _:Error:Stack -> - logger:warning("[sdlan_channel] get error: ~p, stack: ~p", [Error, Stack]) + logger:warning("[sdlan_channel] client_id: ~p, register get error: client_disabled", [ClientId]), + gen_udp:send(Sock, SrcIp, SrcPort, register_nak_reply(PacketId, ?NAK_NODE_DISABLE, <<"Client Connection Disable">>)) end; undefined -> - logger:warning("[sdlan_channel] client_id: ~p, token: ~p, register get error: network not found", [ClientId, Token]), - gen_udp:send(Sock, Ip, Port, register_nak_reply(PacketId, ?NAK_INTERNAL_FAULT, <<"Internal Error">>)) + logger:warning("[sdlan_channel] client_id: ~p, register get error: network not found", [ClientId]), + gen_udp:send(Sock, SrcIp, SrcPort, register_nak_reply(PacketId, ?NAK_INTERNAL_FAULT, <<"Internal Error">>)) end; {ok, #{<<"error">> := #{<<"code">> := Code, <<"message">> := Message}}} -> logger:warning("[sdlan_channel] network_id: ~p, client_id: ~p, register get error: ~ts, error_code: ~p", [NetworkId, ClientId, Message, Code]), - gen_udp:send(Sock, Ip, Port, register_nak_reply(PacketId, Code, Message)); + gen_udp:send(Sock, SrcIp, SrcPort, register_nak_reply(PacketId, Code, Message)); {error, Reason} -> logger:warning("[sdlan_channel] network_id: ~p, client_id: ~p, register get error: ~p", [NetworkId, ClientId, Reason]), - gen_udp:send(Sock, Ip, Port, register_nak_reply(PacketId, ?NAK_NETWORK_FAULT, <<"Network Error">>)) + gen_udp:send(Sock, SrcIp, SrcPort, register_nak_reply(PacketId, ?NAK_NETWORK_FAULT, <<"Network Error">>)) end, exit(normal). @@ -113,7 +113,8 @@ register_nak_reply(PacketId, ErrorCode, ErrorMsg) when is_integer(PacketId), is_ error_code = ErrorCode, error_message = ErrorMsg }), - <>. + <>. +-spec rsa_encode(PlainText :: binary(), RsaPubKey :: public_key:rsa_public_key()) -> binary(). rsa_encode(PlainText, RsaPubKey) when is_binary(PlainText) -> iolist_to_binary(sdlan_cipher:rsa_encrypt(PlainText, RsaPubKey)). \ No newline at end of file diff --git a/apps/sdlan/src/sdlan_stun_pool.erl b/apps/sdlan/src/sdlan_stun_pool.erl index 6596def..729d91f 100644 --- a/apps/sdlan/src/sdlan_stun_pool.erl +++ b/apps/sdlan/src/sdlan_stun_pool.erl @@ -16,7 +16,7 @@ -behaviour(gen_server). %% API --export([start_link/0, send_packets/1]). +-export([start_link/0, send_packets/1, send_packet/2]). %% gen_server callbacks -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). @@ -33,6 +33,9 @@ %%% API %%%=================================================================== +send_packet(Peer, Packet) -> + gen_server:cast(?SERVER, {send_packet, Peer, Packet}). + send_packets(Packets) when is_list(Packets) -> gen_server:cast(?SERVER, {send_packets, Packets}). @@ -94,11 +97,16 @@ handle_call(_Request, _From, State = #state{}) -> {stop, Reason :: term(), NewState :: #state{}}). %% 当前node下的转发,基于进程间的通讯 +handle_cast({send_packet, {Ip, Port}, Packet}, State = #state{workers = Workers, idx = Idx, num = Num}) -> + {Sock, _} = element(Idx, Workers), + gen_udp:send(Sock, Ip, Port, Packet), + NewIdx = (Idx rem Num) + 1, + {noreply, State#state{idx = NewIdx}}; + handle_cast({send_packets, Packets}, State = #state{workers = Workers, idx = Idx, num = Num}) -> {Sock, _} = element(Idx, Workers), - lists:foreach(fun({Ip, Port, Data}) -> gen_udp:send(Sock, Ip, Port, Data) end, Packets), + lists:foreach(fun({{Ip, Port}, Data}) -> gen_udp:send(Sock, Ip, Port, Data) end, Packets), NewIdx = (Idx rem Num) + 1, - {noreply, State#state{idx = NewIdx}}. %% @private