%% Copyright (C) 2024 by sysmocom - s.f.m.c. GmbH <info@sysmocom.de>
%% Author: Vadim Yanitskiy <vyanitskiy@sysmocom.de>
%%
%% All Rights Reserved
%%
%% SPDX-License-Identifier: AGPL-3.0-or-later
%%
%% This program is free software; you can redistribute it and/or modify
%% it under the terms of the GNU Affero General Public License as
%% published by the Free Software Foundation; either version 3 of the
%% License, or (at your option) any later version.
%%
%% This program is distributed in the hope that it will be useful,
%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%% GNU General Public License for more details.
%%
%% You should have received a copy of the GNU Affero General Public License
%% along with this program.  If not, see <https://www.gnu.org/licenses/>.
%%
%% Additional Permission under GNU AGPL version 3 section 7:
%%
%% If you modify this Program, or any covered work, by linking or
%% combining it with runtime libraries of Erlang/OTP as released by
%% Ericsson on https://www.erlang.org (or a modified version of these
%% libraries), containing parts covered by the terms of the Erlang Public
%% License (https://www.erlang.org/EPLICENSE), the licensors of this
%% Program grant you additional permission to convey the resulting work
%% without the need to license the runtime libraries of Erlang/OTP under
%% the GNU Affero General Public License. Corresponding Source for a
%% non-source form of such a combination shall include the source code
%% for the parts of the runtime libraries of Erlang/OTP used as well as
%% that of the covered work.

-module(sctp_server).
-behaviour(gen_server).

-export([init/1,
         handle_info/2,
         handle_call/3,
         handle_cast/2,
         terminate/2]).
-export([start_link/4,
         send_data/2,
         fetch_conn_list/0,
         shutdown/0]).

-include_lib("kernel/include/logger.hrl").
-include_lib("kernel/include/inet.hrl").
-include_lib("kernel/include/inet_sctp.hrl").

-include("s1gw_metrics.hrl").
-include("s1ap.hrl").

%% Transport address of a peer eNB: {IP address, port number}.
-type addr_port() :: {inet:ip_address(),
                      inet:port_number()}.

%% Per-connection information as returned by fetch_conn_list/0.
%% The pid is 'undefined' once the connection's handler process has
%% terminated while the association itself is still known to the server.
-type conn_info() :: #{pid => pid() | undefined,
                       aid => gen_sctp:assoc_id(),
                       addr => inet:ip_address(),
                       port => inet:port_number()
                      }.


%% Per-association state, stored in #server_state.clients keyed by the
%% SCTP association ID.
-record(client_state, {addr_port :: addr_port(),
                       %% handler process; 'undefined' after it terminated
                       pid :: pid() | undefined
                      }).

%% gen_server state.
-record(server_state, {%% listening one-to-many SCTP socket
                       sock :: gen_sctp:sctp_socket(),
                       %% known associations: Aid -> #client_state{}
                       clients :: dict:dict(gen_sctp:assoc_id(), #client_state{}),
                       %% callback module spawning/feeding per-association handlers
                       handler :: module(),
                       %% opaque term passed to Handler:start_link/2
                       priv :: term()
                      }).


%% ------------------------------------------------------------------
%% public API
%% ------------------------------------------------------------------

%% Start the SCTP server as a gen_server registered locally as ?MODULE.
%% BindAddr/BindPort select the local endpoint to listen on (BindAddr may
%% also be given as a string, see init/1).  Handler is the callback module
%% used to start one handler process per eNB association; Priv is passed
%% verbatim to Handler:start_link/2.
start_link(BindAddr, BindPort, Handler, Priv) ->
    InitArgs = [BindAddr, BindPort, Handler, Priv],
    StartOpts = [],
    gen_server:start_link({local, ?MODULE}, ?MODULE, InitArgs, StartOpts).


%% Ask the server to send Data to the eNB on association Aid.
%% Fire-and-forget: this is a cast, the transmission itself is performed
%% by handle_cast/2 and its outcome is not reported back.
-spec send_data(gen_sctp:assoc_id(), term()) -> ok.
send_data(Aid, Data) ->
    Request = {send_data, Aid, Data},
    gen_server:cast(?MODULE, Request).


%% Synchronously fetch one conn_info() map per eNB association currently
%% known to the server.
-spec fetch_conn_list() -> [conn_info()].
fetch_conn_list() ->
    Request = fetch_conn_list,
    gen_server:call(?MODULE, Request).


%% Stop the server synchronously with reason 'normal' (no timeout);
%% terminate/2 then closes the eNB connections and the listening socket.
-spec shutdown() -> ok.
shutdown() ->
    gen_server:stop(?MODULE, normal, infinity).


%% ------------------------------------------------------------------
%% gen_server API
%% ------------------------------------------------------------------

%% gen_server init: open a seqpacket (one-to-many) SCTP socket bound to
%% BindAddr:BindPort in active mode and start listening for associations.
%% A string BindAddr is parsed with inet:parse_address/1 first; an
%% unparsable address, or a failure to open/listen, crashes init.
init([BindAddrStr, BindPort, Handler, Priv]) when is_list(BindAddrStr) ->
    {ok, BindAddr} = inet:parse_address(BindAddrStr),
    init([BindAddr, BindPort, Handler, Priv]);

init([BindAddr, BindPort, Handler, Priv]) ->
    %% handler processes are started via Handler:start_link/2, so trap
    %% exits to receive their termination as {'EXIT', Pid, Reason}
    process_flag(trap_exit, true),
    %% active mode: SCTP events arrive as {sctp, ...} messages (handle_info/2)
    {ok, Sock} = gen_sctp:open([{ip, BindAddr},
                                {port, BindPort},
                                {type, seqpacket},
                                {reuseaddr, true},
                                {active, true}]),
    ok = gen_sctp:listen(Sock, true),
    %% only announce the server once listen/2 actually succeeded
    ?LOG_INFO("SCTP server listening on ~w:~w", [BindAddr, BindPort]),
    {ok, #server_state{sock = Sock,
                       clients = dict:new(),
                       handler = Handler,
                       priv = Priv}}.


%% Synchronous requests:
%%  * fetch_conn_list - reply with a conn_info() map per known association;
%%  * anything else   - log it and reply {error, not_implemented}.
handle_call(fetch_conn_list, _From,
            #server_state{clients = Clients} = State) ->
    ConnList = [gen_conn_info(Entry) || Entry <- dict:to_list(Clients)],
    {reply, ConnList, State};

handle_call(Request, From, State) ->
    ?LOG_ERROR("unknown ~p() from ~p: ~p", [?FUNCTION_NAME, From, Request]),
    {reply, {error, not_implemented}, State}.


%% Asynchronous requests:
%%  * {send_data, Aid, Data} - hand Data to sctp_common:send_data_to/3 for
%%    association Aid on the listening socket (its result is not checked);
%%  * anything else - logged and dropped.
handle_cast({send_data, Aid, Data},
            #server_state{sock = ListenSock} = State) ->
    Target = {ListenSock, Aid},
    sctp_common:send_data_to(enb, Target, Data),
    {noreply, State};

handle_cast(Request, State) ->
    ?LOG_ERROR("unknown ~p(): ~p", [?FUNCTION_NAME, Request]),
    {noreply, State}.


%% Handle SCTP events coming from gen_sctp module
handle_info({sctp, _Socket, FromAddr, FromPort, {AncData, Data}}, S0) ->
    S1 = sctp_recv({FromAddr, FromPort, AncData, Data}, S0),
    {noreply, S1};

%% The listening socket (a port linked to this process, which traps exits)
%% has terminated.  Without it no eNB can be served any more, so stop the
%% server instead of silently running on with a dead socket.
handle_info({'EXIT', Sock, Reason},
            #server_state{sock = Sock} = S) ->
    ?LOG_ERROR("SCTP socket ~p terminated with reason ~p", [Sock, Reason]),
    {stop, {sctp_sock_exit, Reason}, S};

%% Handle termination events of the child processes
handle_info({'EXIT', Pid, Reason},
            #server_state{sock = Sock, clients = Clients} = S0) ->
    ?LOG_DEBUG("Child process ~p terminated with reason ~p", [Pid, Reason]),
    case client_find(Pid, S0) of
        {ok, {Aid, C0}} ->
            %% shutdown the eNB connection gracefully
            sctp_common:shutdown({Sock, Aid}),
            %% invalidate pid in the client's state; the client entry
            %% itself is removed once the association is gone
            C1 = C0#client_state{pid = undefined},
            S1 = S0#server_state{clients = dict:store(Aid, C1, Clients)},
            {noreply, S1};
        error ->
            {noreply, S0}
    end;

%% Catch-all for unknown messages
handle_info(Info, S) ->
    ?LOG_ERROR("unknown ~p(): ~p", [?FUNCTION_NAME, Info]),
    {noreply, S}.


%% gen_server termination: tear down every eNB connection first (this
%% still needs the socket), then close the listening socket itself.
terminate(Reason, #server_state{sock = ListenSock} = State) ->
    ?LOG_NOTICE("Terminating, reason ~p", [Reason]),
    close_conns(State),
    gen_sctp:close(ListenSock),
    ok.

%% ------------------------------------------------------------------
%% private API
%% ------------------------------------------------------------------

%% Process one SCTP event {FromAddr, FromPort, AncData, Data} and return
%% the updated server state:
%%  * #sctp_assoc_change{} (no ancillary data): comm_up registers a new
%%    client via client_add/4, shutdown_comp / comm_lost remove it via
%%    client_del/2, any other state is only logged;
%%  * payload with a single #sctp_sndrcvinfo{}: counted and forwarded to
%%    the association's handler process, if it still has one;
%%  * anything else: logged at debug level and ignored.
sctp_recv({Addr, Port, [],
           #sctp_assoc_change{state = AssocState,
                              assoc_id = AssocId}},
          #server_state{} = State) ->
    case AssocState of
        comm_lost ->
            ?LOG_NOTICE("eNB connection (id=~p, ~p:~p) lost", [AssocId, Addr, Port]),
            client_del(AssocId, State);
        shutdown_comp ->
            ?LOG_NOTICE("eNB connection (id=~p, ~p:~p) closed", [AssocId, Addr, Port]),
            client_del(AssocId, State);
        comm_up ->
            ?LOG_NOTICE("eNB connection (id=~p, ~p:~p) established", [AssocId, Addr, Port]),
            client_add(AssocId, Addr, Port, State);
        OtherState ->
            ?LOG_NOTICE("eNB connection (id=~p, ~p:~p) state ~p",
                        [AssocId, Addr, Port, OtherState]),
            State
    end;

sctp_recv({Addr, Port,
           [#sctp_sndrcvinfo{assoc_id = AssocId,
                             stream = StreamId,
                             ssn = StreamSeqNum,
                             tsn = TransSeqNum}], Payload},
          #server_state{clients = Clients,
                        handler = Handler} = State) ->
    ?LOG_DEBUG("eNB connection (id=~p, ~p:~p) -> MME: ~p",
               [AssocId, Addr, Port,
                #{tsn => TransSeqNum, sid => StreamId, ssn => StreamSeqNum,
                  len => byte_size(Payload), data => Payload}]),
    s1gw_metrics:ctr_inc(?S1GW_CTR_S1AP_ENB_ALL_RX),
    case dict:find(AssocId, Clients) of
        error ->
            ?LOG_ERROR("eNB connection (id=~p, ~p:~p) is not known to us?!?",
                       [AssocId, Addr, Port]),
            s1gw_metrics:ctr_inc(?S1GW_CTR_S1AP_ENB_ALL_RX_UNKNOWN_ENB);
        {ok, #client_state{pid = undefined}} ->
            ?LOG_NOTICE("eNB connection (id=~p, ~p:~p) -> MME data ignored (no handler)",
                        [AssocId, Addr, Port]);
        {ok, #client_state{pid = HandlerPid}} ->
            Handler:send_data(HandlerPid, Payload)
    end,
    State;

sctp_recv({Addr, Port, AncData, Payload},
          #server_state{} = State) ->
    ?LOG_DEBUG("Unhandled SCTP event (~p:~p): ~p, ~p",
               [Addr, Port, AncData, Payload]),
    State.


%% Register a newly established association Aid: start its handler via
%% Handler:start_link(Aid, Priv) (crashing if that does not return
%% {ok, Pid}), bump the connection gauge and store the peer address/port
%% together with the handler pid.
client_add(Aid, FromAddr, FromPort,
           #server_state{clients = Clients,
                         handler = Handler,
                         priv = Priv} = State) ->
    {ok, HandlerPid} = Handler:start_link(Aid, Priv),
    s1gw_metrics:gauge_inc(?S1GW_GAUGE_S1AP_ENB_NUM_SCTP_CONNECTIONS),
    Client = #client_state{pid = HandlerPid,
                           addr_port = {FromAddr, FromPort}},
    NewClients = dict:store(Aid, Client, Clients),
    State#server_state{clients = NewClients}.


%% Delete an existing client from the list: stop its handler process (if
%% it still has one), decrement the connection gauge and drop the entry.
%% An unknown association ID is logged and leaves the state unchanged.
client_del(Aid,
           #server_state{clients = Clients,
                         handler = Handler} = S) ->
    case dict:find(Aid, Clients) of
        {ok, #client_state{pid = Pid}} ->
            case Pid of
                undefined ->
                    %% the handler already terminated, nothing to stop
                    ok;
                _ ->
                    %% the handler process might be already dead, so we
                    %% tolerate exceptions like noproc or {nodedown,Node},
                    %% but log them instead of swallowing them silently.
                    try Handler:shutdown(Pid)
                    catch
                        Class:Reason ->
                            ?LOG_DEBUG("Failed to shut down handler ~p of eNB connection (id=~p): ~p:~p",
                                       [Pid, Aid, Class, Reason])
                    end
            end,
            s1gw_metrics:gauge_dec(?S1GW_GAUGE_S1AP_ENB_NUM_SCTP_CONNECTIONS),
            S#server_state{clients = dict:erase(Aid, Clients)};
        error ->
            ?LOG_ERROR("eNB connection (id=~p) is not known to us?!?", [Aid]),
            S
    end.


%% Look up the client whose handler process is Pid.  Takes either the
%% server state or a list of {Aid, #client_state{}} entries and returns
%% {ok, {Aid, Client}} for the first matching entry, or 'error'.
client_find(Pid, #server_state{clients = Clients}) ->
    client_find(Pid, dict:to_list(Clients));

client_find(Pid, Entries) when is_list(Entries) ->
    IsOwnedBy = fun({_Aid, #client_state{pid = P}}) when P =:= Pid -> true;
                   ({_Aid, _Client}) -> false
                end,
    case lists:search(IsOwnedBy, Entries) of
        {value, Entry} -> {ok, Entry};
        false -> error
    end.


%% Gracefully terminate all client connections (used on server shutdown):
%% for each association stop its handler process (if any), shut the SCTP
%% association down and decrement the connection gauge.  The gauge must be
%% decremented here because, while terminating, no shutdown_comp/comm_lost
%% events will be processed to do it via client_del/2.
close_conns(#server_state{sock = Sock,
                          clients = Clients,
                          handler = Handler}) ->
    close_conns(Sock, Handler, dict:to_list(Clients)).

close_conns(Sock, Handler, [{Aid, Client} | Clients]) ->
    #client_state{addr_port = {FromAddr, FromPort}, pid = Pid} = Client,
    ?LOG_NOTICE("Terminating eNB connection (id=~p, ~p:~p)", [Aid, FromAddr, FromPort]),
    case Pid of
        undefined ->
            %% the handler already terminated, nothing to stop
            ok;
        _ ->
            %% the handler process might be already dead, so we tolerate
            %% exceptions like noproc or {nodedown,Node}, but log them
            %% instead of swallowing them silently.
            try Handler:shutdown(Pid)
            catch
                Class:Reason ->
                    ?LOG_DEBUG("Failed to shut down handler ~p of eNB connection (id=~p): ~p:~p",
                               [Pid, Aid, Class, Reason])
            end
    end,
    %% shutdown an eNB connection gracefully
    sctp_common:shutdown({Sock, Aid}),
    %% keep the gauge consistent with client_add/client_del
    s1gw_metrics:gauge_dec(?S1GW_GAUGE_S1AP_ENB_NUM_SCTP_CONNECTIONS),
    %% ... and so for the remaining clients
    close_conns(Sock, Handler, Clients);

close_conns(_Sock, _Handler, []) ->
    ok.


%% Turn one {Aid, #client_state{}} entry of the clients dict into the
%% conn_info() map reported by fetch_conn_list/0.
-spec gen_conn_info({gen_sctp:assoc_id(), #client_state{}}) -> conn_info().
gen_conn_info({Aid, #client_state{pid = HandlerPid,
                                  addr_port = {PeerAddr, PeerPort}}}) ->
    #{aid => Aid,
      pid => HandlerPid,
      addr => PeerAddr,
      port => PeerPort}.

%% vim:set ts=4 sw=4 et: