#!/usr/bin/env python3

# Copyright (C) 2025 by sysmocom - s.f.m.c. GmbH <info@sysmocom.de>
# Author: Vadim Yanitskiy <vyanitskiy@sysmocom.de>
#
# All Rights Reserved
#
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import argparse
import http.client
import json
import logging
import sys
import urllib.parse
import urllib.request

import cmd2
import tabulate

# Logger of this module; its debug messages become visible when the CLI
# is started with -v/--verbose (which lowers the root logger level)
log: logging.Logger = logging.getLogger(__name__)


class RestIface:
    ''' REST interface for OsmoS1GW '''

    HTTPResponse = http.client.HTTPResponse
    RESTResponse = dict | list[dict]

    def __init__(self, host: str, port: int):
        self.url = f'http://{host}:{port}'

    def send_req(self, method: str,
                       path: str = '',
                       data: dict = {}) -> HTTPResponse:
        ''' Send an HTTP request to the given endpoint (path) '''
        req = urllib.request.Request(f'{self.url}/{path}', method=method)
        req.add_header('Accept', 'application/json')
        if data:
            req.add_header('Content-Type', 'application/json')
            req.data = json.dumps(data).encode('utf-8')
        log.debug(f'HTTP {req.method} {req.full_url}')
        return urllib.request.urlopen(req)

    def send_get_req(self, path: str, query: dict = {}) -> HTTPResponse:
        ''' Send an HTTP GET request to the given endpoint (path) '''
        if query:
            path += '?' + urllib.parse.urlencode(query)
        return self.send_req('GET', path)

    def send_post_req(self, path: str, data: dict = {}) -> HTTPResponse:
        ''' Send an HTTP POST request to the given endpoint (path) '''
        return self.send_req('POST', path, data)

    def send_delete_req(self, path: str, data: dict = {}) -> HTTPResponse:
        ''' Send an HTTP DELETE request to the given endpoint (path) '''
        return self.send_req('DELETE', path, data)

    def fetch_spec(self) -> RESTResponse:
        ''' Fetch the OpenAPI specification (JSON) '''
        with self.send_get_req('swagger/spec.json') as f:
            return json.load(f)

    def metrics_list(self, type: str = 'all', path: str = '') -> RESTResponse:
        ''' MetricsList :: Get a list of metrics '''
        query = {'type' : type}
        if path:
            query['path'] = path
        with self.send_get_req('metrics-list', query) as f:
            return json.load(f)

    def pfcp_assoc_state(self) -> RESTResponse:
        ''' PfcpAssocState :: Get the PFCP association state '''
        with self.send_get_req('pfcp/assoc') as f:
            return json.load(f)

    def pfcp_assoc_setup(self) -> RESTResponse:
        ''' PfcpAssocSetup :: Initiate the PFCP Association Setup procedure '''
        with self.send_post_req('pfcp/assoc') as f:
            return json.load(f)

    def pfcp_assoc_release(self) -> RESTResponse:
        ''' PfcpAssocRelease :: Initiate the PFCP Association Release procedure '''
        with self.send_delete_req('pfcp/assoc') as f:
            return json.load(f)

    def pfcp_heartbeat(self) -> RESTResponse:
        ''' PfcpHeartbeat :: Send a PFCP Heartbeat Request to the peer '''
        with self.send_post_req('pfcp/heartbeat') as f:
            return json.load(f)

    def enb_list(self) -> RESTResponse:
        ''' EnbList :: Get a list of eNB connections '''
        with self.send_get_req('enb-list') as f:
            return json.load(f)

    def enb_info(self, enb_id: str) -> RESTResponse:
        ''' EnbInfo :: Get information about a specific eNB '''
        with self.send_get_req(f'enb/{enb_id}') as f:
            return json.load(f)

    def enb_erab_list(self, enb_id: str) -> RESTResponse:
        ''' EnbErabList :: Get E-RAB list for a specific eNB '''
        with self.send_get_req(f'enb/{enb_id}/erab-list') as f:
            return json.load(f)

    def erab_list(self) -> RESTResponse:
        ''' ErabList :: Get E-RAB list for all eNBs '''
        with self.send_get_req('erab-list') as f:
            return json.load(f)

    def erab_info(self, pid: str) -> RESTResponse:
        ''' ErabInfo :: Get information about a specific E-RAB '''
        with self.send_get_req(f'erab/pid:{pid}') as f:
            return json.load(f)


class OsmoS1GWCli(cmd2.Cmd):
    ''' Interactive command shell for the OsmoS1GW REST interface

    Each do_<name>() method implements the CLI command <name>: it queries
    the REST server via RestIface and renders the decoded JSON response,
    mostly as a table built with tabulate (see the 'tablefmt' setting).
    '''

    DESC = 'Interactive CLI for OsmoS1GW'

    # help categories for the commands defined below
    CAT_METRICS = 'Metrics commands'
    CAT_PFCP = 'PFCP related commands'
    CAT_ENB = 'eNB related commands'
    CAT_ERAB = 'E-RAB related commands'

    def __init__(self, argv):
        ''' Initialize the shell

        :param argv: parsed command-line arguments (argparse.Namespace),
                     providing the 'verbose', 'HOST' and 'port' attributes
        '''
        # allow_cli_args=False: sys.argv is parsed by the module-level argparse
        # parser, so cmd2 must not try to interpret it itself
        super().__init__(allow_cli_args=False, include_py=True)

        if argv.verbose > 0:
            logging.root.setLevel(logging.DEBUG)
            self.debug = True  # cmd2 setting: print full exception tracebacks

        self.intro = cmd2.style('Welcome to %s!' % self.DESC, fg=cmd2.Fg.RED)
        # help section for commands that have no explicit category
        self.default_category = 'Built-in commands'
        self.prompt = 'OsmoS1GW# '

        self.tablefmt = 'github' # default table format for tabulate
        # expose 'tablefmt' as a runtime-changeable cmd2 setting
        self.add_settable(cmd2.Settable('tablefmt', str, 'Table format for tabulate', self,
                                        choices=tabulate.tabulate_formats))

        self.iface = RestIface(argv.HOST, argv.port)

    # NOTE: do_*() docstrings are shown by cmd2 as help text, keep them short.

    def do_fetch_openapi_spec(self, opts) -> None:
        ''' Fetch the OpenAPI specification (JSON), dump as text '''
        # no category: listed under self.default_category in 'help'
        spec = self.iface.fetch_spec()
        self.poutput(json.dumps(spec, indent=4))

    @staticmethod
    def metrics_list_item(item: dict) -> dict:
        ''' Generate a table row for the given metric

        :param item: one element of the MetricsList response;
                     missing keys yield None in the corresponding column
        '''
        return {
            'Name': item.get('name'),
            'Type': item.get('type'),
            'Value': item.get('value'),
        }

    metrics_list_parser = cmd2.Cmd2ArgumentParser()
    metrics_list_parser.add_argument('-t', '--type',
                                     type=str, default='all',
                                     choices=('all', 'counter', 'gauge'),
                                     help='Metric type (default: %(default)s)')
    metrics_list_parser.add_argument('PATH',
                                     type=str, default='', nargs='?',
                                     help='Metric path')

    @cmd2.with_argparser(metrics_list_parser)
    @cmd2.with_category(CAT_METRICS)
    def do_metrics_list(self, opts) -> None:
        ''' Get a list of metrics '''
        data = self.iface.metrics_list(opts.type, opts.PATH)
        self.poutput(tabulate.tabulate(map(self.metrics_list_item, data),
                                       headers='keys', tablefmt=self.tablefmt))

    @cmd2.with_category(CAT_PFCP)
    def do_pfcp_assoc_state(self, opts) -> None:
        ''' Get the PFCP association state '''
        data = self.iface.pfcp_assoc_state()
        table = [] # [param, value]
        table.append(['State', data['state']])
        table.append(['Local address', data['laddr']])
        table.append(['Remote address', data['raddr']])
        table.append(['Local Recovery TimeStamp', data['lrts']])
        # the remote Recovery TimeStamp is optional in the response
        if 'rrts' in data:
            table.append(['Remote Recovery TimeStamp', data['rrts']])
        self.poutput(tabulate.tabulate(table,
                                       headers=['Parameter', 'Value'],
                                       tablefmt=self.tablefmt))

    @cmd2.with_category(CAT_PFCP)
    def do_pfcp_assoc_setup(self, opts) -> None:
        ''' Initiate the PFCP Association Setup procedure '''
        # NOTE: RestIface.pfcp_assoc_setup() exists, but this command is not wired up yet
        raise NotImplementedError

    @cmd2.with_category(CAT_PFCP)
    def do_pfcp_assoc_release(self, opts) -> None:
        ''' Initiate the PFCP Association Release procedure '''
        # NOTE: RestIface.pfcp_assoc_release() exists, but this command is not wired up yet
        raise NotImplementedError

    @cmd2.with_category(CAT_PFCP)
    def do_pfcp_heartbeat(self, opts) -> None:
        ''' Send a PFCP Heartbeat Request '''
        data = self.iface.pfcp_heartbeat()
        if data.get('success', False):
            self.poutput('Heartbeat succeeded')
        else:
            # 'message' may be absent: don't let a KeyError mask the actual failure
            reason = data.get('message', '(no details)')
            self.perror(f'Heartbeat failed: {reason}')

    @staticmethod
    def enb_list_item(item: dict) -> dict:
        ''' Generate a table row for the given eNB

        :param item: one eNB entry of the EnbList/EnbInfo response; the
                     address columns are None when the corresponding
                     'enb_saddr' / 'mme_daddr' key is absent
        '''
        enb_addr = lambda item: '{enb_saddr}:{enb_sport} ({enb_sctp_aid})'.format(**item)
        mme_addr = lambda item: '{mme_daddr}:{mme_dport} ({mme_sctp_aid})'.format(**item)
        return {
            'eNB handle': item.get('handle'),
            'PID': item.get('pid'),
            'Global-eNB-ID': item.get('genb_id', '(unknown)'),
            'State': item.get('state'),
            'eNB addr:port (aid)': enb_addr(item) if 'enb_saddr' in item else None,
            'MME addr:port (aid)': mme_addr(item) if 'mme_daddr' in item else None,
            'Uptime (s)': item.get('uptime'),
            '# E-RABs': item.get('erab_count'),
        }

    def enb_list_print(self, items: list[dict]) -> None:
        ''' Print a list of eNBs in tabular form (one row per eNB) '''
        self.poutput(tabulate.tabulate(map(self.enb_list_item, items),
                                       headers='keys', tablefmt=self.tablefmt))

    def enb_info_print(self, item: dict) -> None:
        ''' Print eNB info in tabular form (one row per parameter) '''
        self.poutput(tabulate.tabulate(self.enb_list_item(item).items(),
                                       headers=['Parameter', 'Value'],
                                       tablefmt=self.tablefmt))

    @cmd2.with_category(CAT_ENB)
    def do_enb_list(self, opts) -> None:
        ''' Get a list of eNB connections '''
        data = self.iface.enb_list()
        self.enb_list_print(data)

    @staticmethod
    def gen_enb_id(opts) -> str:
        ''' Generate the EnbId parameter value (for URL)

        :param opts: parsed arguments of a parser set up by add_enb_id_group()
        :returns: '<kind>:<value>', e.g. 'handle:0' or 'genbid:262-42-1337'
        :raises ValueError: if none of the eNB ID options is set
        '''
        if opts.handle is not None:
            return f'handle:{opts.handle}'
        elif opts.pid is not None:
            return f'pid:{opts.pid}'
        elif opts.genbid is not None:
            return f'genbid:{opts.genbid}'
        elif opts.enb_sctp_aid is not None:
            return f'enb-sctp-aid:{opts.enb_sctp_aid}'
        elif opts.mme_sctp_aid is not None:
            return f'mme-sctp-aid:{opts.mme_sctp_aid}'
        # unreachable via the CLI: the argparse group below is required=True
        raise ValueError('no eNB ID given')

    @staticmethod
    def add_enb_id_group(parser):
        ''' Add argparse group for the EnbId parameter

        Exactly one of the eNB ID options must be given (the mutually
        exclusive group is required).  Returns the mutually exclusive group.
        '''
        enb_id_group = parser.add_argument_group('eNB ID')
        enb_id_group = enb_id_group.add_mutually_exclusive_group(required=True)
        enb_id_group.add_argument('-H', '--handle',
                                  type=int,
                                  help='eNB handle (example: 0)')
        enb_id_group.add_argument('-P', '--pid',
                                  type=str,
                                  help='eNB process ID (example: 0.33.1)')
        enb_id_group.add_argument('-G', '--genbid',
                                  type=str,
                                  help='Global-eNB-ID (example: 262-42-1337)')
        enb_id_group.add_argument('--enb-sctp-aid',
                                  type=int, metavar='AID',
                                  help='eNB association identifier (example: 42)')
        enb_id_group.add_argument('--mme-sctp-aid',
                                  type=int, metavar='AID',
                                  help='MME association identifier (example: 42)')
        return enb_id_group

    # add_enb_id_group is called directly in the class body below; calling a
    # staticmethod object like this requires Python >= 3.10
    enb_info_parser = cmd2.Cmd2ArgumentParser()
    add_enb_id_group(enb_info_parser)

    @cmd2.with_argparser(enb_info_parser)
    @cmd2.with_category(CAT_ENB)
    def do_enb_info(self, opts) -> None:
        ''' Get information about a specific eNB '''
        enb_id = self.gen_enb_id(opts)
        data = self.iface.enb_info(enb_id)
        self.enb_info_print(data)

    @staticmethod
    def erab_list_item(item: dict) -> dict:
        ''' Generate a table row for the given E-RAB (brief) '''
        return {
            'PID': item.get('pid'),
            'MME-UE-S1AP-ID': item.get('mme_ue_id'),
            'E-RAB-ID': item.get('erab_id'),
            'State': item.get('state'),
        }

    @classmethod
    def erab_list_item_full(cls, item: dict) -> dict:
        ''' Generate a table row for the given E-RAB (full)

        Extends erab_list_item() with the PFCP SEIDs (as 64-bit hex) and the
        F-TEIDs (as '0x<TEID>@<transport address>'); absent keys yield None.
        '''
        f_teid = lambda params: '0x{teid:08x}@{tla}'.format(**params)
        seid = lambda val: f'0x{val:016x}'
        return {
            **cls.erab_list_item(item),
            'SEID (local)': seid(item.get('pfcp_lseid')) if 'pfcp_lseid' in item else None,
            'SEID (remote)': seid(item.get('pfcp_rseid')) if 'pfcp_rseid' in item else None,
            'U2C F-TEID': f_teid(item.get('f_teid_u2c')) if 'f_teid_u2c' in item else None,
            'C2U F-TEID': f_teid(item.get('f_teid_c2u')) if 'f_teid_c2u' in item else None,
            'A2U F-TEID': f_teid(item.get('f_teid_a2u')) if 'f_teid_a2u' in item else None,
            'U2A F-TEID': f_teid(item.get('f_teid_u2a')) if 'f_teid_u2a' in item else None,
        }

    def erab_list_print(self, items: list[dict], full: bool) -> None:
        ''' Print a list of E-RABs in tabular form (full: include PFCP/F-TEID columns) '''
        func = self.erab_list_item_full if full else self.erab_list_item
        self.poutput(tabulate.tabulate(map(func, items),
                                       headers='keys', tablefmt=self.tablefmt))

    def erab_info_print(self, item: dict) -> None:
        ''' Print E-RAB info in tabular form (one row per parameter) '''
        self.poutput(tabulate.tabulate(self.erab_list_item_full(item).items(),
                                       headers=['Parameter', 'Value'],
                                       tablefmt=self.tablefmt))

    enb_erab_list_parser = cmd2.Cmd2ArgumentParser()
    enb_erab_list_parser.add_argument('-f', '--full',
                                      action='store_true',
                                      help='Print full table (more columns)')
    add_enb_id_group(enb_erab_list_parser)

    @cmd2.with_argparser(enb_erab_list_parser)
    @cmd2.with_category(CAT_ENB)
    def do_enb_erab_list(self, opts) -> None:
        ''' Get E-RAB list for a specific eNB '''
        enb_id = self.gen_enb_id(opts)
        data = self.iface.enb_erab_list(enb_id)
        self.erab_list_print(data, opts.full)

    erab_list_parser = cmd2.Cmd2ArgumentParser()
    erab_list_parser.add_argument('-f', '--full',
                                  action='store_true',
                                  help='Print full table (more columns)')

    @cmd2.with_argparser(erab_list_parser)
    @cmd2.with_category(CAT_ERAB)
    def do_erab_list(self, opts) -> None:
        ''' Get E-RAB list for all eNBs '''
        data = self.iface.erab_list()
        self.erab_list_print(data, opts.full)

    erab_info_parser = cmd2.Cmd2ArgumentParser()
    erab_info_parser.add_argument('-P', '--pid',
                                  type=str, required=True,
                                  help='E-RAB process ID (example: 0.33.1)')

    @cmd2.with_argparser(erab_info_parser)
    @cmd2.with_category(CAT_ERAB)
    def do_erab_info(self, opts) -> None:
        ''' Get information about a specific E-RAB '''
        data = self.iface.erab_info(opts.pid)
        self.erab_info_print(data)


def _make_arg_parser() -> argparse.ArgumentParser:
    ''' Build the command-line parser of osmo-s1gw-cli '''
    parser = argparse.ArgumentParser(prog='osmo-s1gw-cli', description=OsmoS1GWCli.DESC)
    parser.add_argument('-v', '--verbose', action='count', default=0,
                        help='print debug logging')
    parser.add_argument('-p', '--port', metavar='PORT', type=int, default=8080,
                        help='OsmoS1GW REST port (default: %(default)s)')
    parser.add_argument('HOST', type=str, nargs='?', default='localhost',
                        help='OsmoS1GW REST host/address (default: %(default)s)')
    return parser


# module-level command-line parser (kept as a global, as before)
ap = _make_arg_parser()

# NOTE: logging is configured at import time, not only when run as a script
logging.basicConfig(level=logging.INFO,
                    format='\r[%(levelname)s] %(filename)s:%(lineno)d %(message)s')


def main() -> int:
    ''' Parse sys.argv, then run the interactive shell; returns cmdloop()'s result '''
    cli = OsmoS1GWCli(ap.parse_args())
    return cli.cmdloop()


if __name__ == '__main__':
    sys.exit(main())
