/* GTP-U ECHO implementation for osmo-upf */

#include <stdint.h>
#include <errno.h>

#include <osmocom/core/endian.h>
#include <osmocom/core/socket.h>
#include <osmocom/core/msgb.h>

#include <osmocom/upf/upf.h>
#include <osmocom/upf/upf_gtp.h>

/* UDP port used for GTP-U */
#define GTP1U_PORT	2152

/* GTPv1-U message types (3GPP TS 29.281), written in hex with the decimal value alongside.
 * "ERRR_IND" (Error Indication) keeps its historical spelling because that name is used by the receive path. */
enum gtp1u_msgt {
	GTP1U_MSGTYPE_ECHO_REQ = 0x01,		/* 1 */
	GTP1U_MSGTYPE_ECHO_RSP = 0x02,		/* 2 */
	GTP1U_MSGTYPE_ERRR_IND = 0x1a,		/* 26 */
	GTP1U_MSGTYPE_SUPP_EXT_HDR_NOTIF = 0x1f,	/* 31 */
	GTP1U_MSGTYPE_END_MARKER = 0xfe,	/* 254 */
	GTP1U_MSGTYPE_GPDU = 0xff,		/* 255 */
};

/* GTPv1-U information element identifiers used in this file. */
enum gtp1u_iei {
	GTP1U_IEI_RECOVERY = 0x0e,	/* 14 */
};

/* 3GPP TS 29.281 */
/* GTPv1-U header as laid out on the wire. The struct is packed, so there is no padding and
 * sizeof(struct gtp1u_hdr) == 12: 8 mandatory bytes (flags, msg_type, length, tei) followed by the 4 bytes of
 * 'ext' (sequence number, N-PDU number, next extension header type).
 * Multi-byte fields are stored exactly as on the wire. In this file, 'length' and 'tei' are accessed through
 * osmo_load16be()/osmo_store16be()/osmo_load32be(), while 'ext.seq_nr' is copied without byte order conversion.
 */
struct gtp1u_hdr {
	/* First octet: flag bits. The declaration order depends on host endianness; the big endian variant is
	 * generated from the little endian one, so both describe the same octet layout. */
#if OSMO_IS_LITTLE_ENDIAN
	uint8_t pn:1, /*< N-PDU Number flag */
		s:1, /*< Sequence number flag */
		e:1, /*< Extension header flag */
		spare:1,
		pt:1, /*< Protocol Type: GTP=1, GTP'=0 */
		version:3; /*< Version: 1 */
#elif OSMO_IS_BIG_ENDIAN
/* auto-generated from the little endian part above (libosmocore/contrib/struct_endianness.py) */
	uint8_t version:3, pt:1, spare:1, e:1, s:1, pn:1;
#endif
	/* One of enum gtp1u_msgt. */
	uint8_t msg_type;
	/* Number of bytes following the first 8 header bytes, i.e. counted from data1 (big endian on the wire). */
	uint16_t length;
	uint32_t tei; /*< 05 - 08 Tunnel Endpoint ID */
	union {
		/* Zero-length position marker at offset 8: the start of what 'length' counts. */
		uint8_t data1[0];
		/* Optional header part. The ECHO messages handled in this file set/expect the sequence number. */
		struct {
			uint16_t seq_nr;
			uint8_t n_pdu_nr;
			uint8_t next_ext_type;
		} ext;
	};
	/* Zero-length position marker at offset 12, where the IEs (e.g. Recovery) of the ECHO messages start. */
	uint8_t data2[0];
} __attribute__((packed));

static int tx_echo_resp(struct upf_gtp_dev *dev, const struct osmo_sockaddr *remote, uint16_t seq_nr);

/* Handle a received GTPv1-U Echo Request: log it and answer with an Echo Response.
 * dev: GTP device the request was received on.
 * remote: peer that sent the request; the response is sent back to this address.
 * rx_h: received message; the caller guarantees at least sizeof(*rx_h) readable bytes.
 * msg_len: number of bytes of the message at rx_h.
 * Returns the result of tx_echo_resp(): 0 on success, a negative errno on failure.
 */
static int rx_echo_req(struct upf_gtp_dev *dev, const struct osmo_sockaddr *remote, const struct gtp1u_hdr *rx_h,
		       size_t msg_len)
{
	uint16_t seq_nr;
	uint8_t recovery_count = 0;

	/* A Recovery IE (IEI + 1 value byte) right after the header is optional here; it is only logged. */
	if (msg_len >= (sizeof(*rx_h) + 2) && rx_h->data2[0] == GTP1U_IEI_RECOVERY)
		recovery_count = rx_h->data2[1];

	/* Log the sequence number field itself (not the S flag), decoded from network byte order. */
	seq_nr = osmo_load16be(&rx_h->ext.seq_nr);
	LOG_GTP_DEV(dev, LOGL_INFO, "<- %s: rx GTPv1-U Echo Request: seq_nr=%u recovery_count=%u\n",
		    osmo_sockaddr_to_str(remote), seq_nr, recovery_count);

	/* tx_echo_resp() copies its seq_nr argument verbatim into the response header. Passing the raw field
	 * therefore echoes the request's sequence number bytes back unchanged. */
	return tx_echo_resp(dev, remote, rx_h->ext.seq_nr);
}

static void rx_echo_resp(struct upf_gtp_dev *dev, const struct osmo_sockaddr *remote, const struct gtp1u_hdr *rx_h,
			 size_t msg_len)
{
	if (msg_len < (sizeof(*rx_h) + 2)) {
		LOG_GTP_DEV(dev, LOGL_ERROR,
			    "<- %s: rx GTPv1-U Echo Response, but message is too short (%zu < %zu)\n",
			    osmo_sockaddr_to_str_c(OTC_SELECT, remote), msg_len, (sizeof(*rx_h) + 2));
		return;
	}

	uint8_t recovery_count = rx_h->data2[1];
	LOG_GTP_DEV(dev, LOGL_INFO, "<- %s: rx GTPv1-U Echo Response: seq_nr=%u recovery_count=%u\n",
		    osmo_sockaddr_to_str(remote), rx_h->ext.seq_nr, recovery_count);
}

static int tx_echo_resp(struct upf_gtp_dev *dev, const struct osmo_sockaddr *remote, uint16_t seq_nr)
{
	struct msgb *msg;
	struct gtp1u_hdr *tx_h;
	int rc;

	msg = msgb_alloc_headroom(1024, 128, "GTPv1-U-echo-resp");
	tx_h = (void *)msgb_put(msg, sizeof(*tx_h));

	*tx_h = (struct gtp1u_hdr){
		/* 3GPP TS 29.281 5.1 defines that the ECHO REQ & RESP shall contain a sequence nr */
		.s = 1,
		.pt = 1,
		.version = 1,
		.msg_type = GTP1U_MSGTYPE_ECHO_RSP,
		.ext = {
			.seq_nr = seq_nr,
		},
	};

	OSMO_ASSERT(msg->tail == tx_h->data2);

	/* ECHO RESPONSE shall contain a recovery counter */
	msgb_put_u8(msg, GTP1U_IEI_RECOVERY);
	msgb_put_u8(msg, g_upf->tunend.recovery_count);

	osmo_store16be(msg->tail - tx_h->data1, &tx_h->length);

	rc = sendto(dev->gtpv1.ofd.fd, msgb_data(msg), msgb_length(msg), 0, &remote->u.sa, sizeof(*remote));
	if (rc < 0) {
		rc = -errno;
		LOG_GTP_DEV(dev, LOGL_ERROR, "-> %s: tx GTPv1-U Echo Response: sendto(len=%d): %s\n",
			    osmo_sockaddr_to_str(remote), msgb_length(msg), strerror(-rc));
	} else {
		LOG_GTP_DEV(dev, LOGL_INFO, "-> %s: tx GTPv1-U Echo Response: seq_nr=%u recovery_count=%u\n",
			    osmo_sockaddr_to_str(remote), seq_nr, g_upf->tunend.recovery_count);
		rc = 0;
	}
	msgb_free(msg);
	return rc;
}

/* Send a GTPv1-U Echo Request carrying the local Recovery counter.
 * dev: GTP device whose GTPv1-U socket is used for sending.
 * remote: destination address.
 * seq_nr: sequence number, copied verbatim into the request header.
 *         NOTE(review): no byte order conversion is applied here; confirm callers do not rely on network order.
 * Returns 0 on success, or a negative errno from sendto().
 */
int upf_gtpu_echo_req_tx(struct upf_gtp_dev *dev, const struct osmo_sockaddr *remote, uint16_t seq_nr)
{
	struct gtp1u_hdr *tx_h;
	int rc;
	/* Full 12 byte header followed by the 2 byte Recovery IE (IEI + value). */
	uint8_t msgbuf[sizeof(struct gtp1u_hdr) + 2];

	tx_h = (void *)msgbuf;
	*tx_h = (struct gtp1u_hdr){
		/* 3GPP TS 29.281 5.1 defines that the ECHO REQ & RESP shall contain a sequence nr */
		.s = 1,
		.pt = 1,
		.version = 1,
		.msg_type = GTP1U_MSGTYPE_ECHO_REQ,
		.ext = {
			.seq_nr = seq_nr,
		},
	};

	/* ECHO REQUEST shall contain a recovery counter */
	tx_h->data2[0] = GTP1U_IEI_RECOVERY;
	tx_h->data2[1] = g_upf->tunend.recovery_count;

	/* The length field counts every byte after the first 8 header bytes. */
	osmo_store16be(sizeof(msgbuf) - offsetof(struct gtp1u_hdr, data1), &tx_h->length);

	rc = sendto(dev->gtpv1.ofd.fd, msgbuf, sizeof(msgbuf), 0, &remote->u.sa, sizeof(*remote));
	if (rc < 0) {
		rc = -errno;
		LOG_GTP_DEV(dev, LOGL_ERROR, "GTP1-U sendto(len=%zu, to=%s): %s\n", sizeof(msgbuf),
			    osmo_sockaddr_to_str(remote), strerror(-rc));
		return rc;
	}

	/* Only report a transmitted request when sendto() actually succeeded. "->" marks tx, as in the rest of
	 * this file. */
	LOG_GTP_DEV(dev, LOGL_INFO, "-> %s: tx GTP1-U Echo Request: seq_nr=%u recovery_count=%u\n",
		    osmo_sockaddr_to_str(remote), seq_nr, g_upf->tunend.recovery_count);
	return 0;
}

/* osmo_fd read callback for the GTPv1-U socket: receive one datagram and dispatch it by GTP message type.
 * Echo Requests are answered and Echo Responses are logged. Error Indications and G-PDUs are logged as
 * unsupported. Any other message type is rejected.
 * ofd: the GTPv1-U socket; ofd->data is the struct upf_gtp_dev (as set up by upf_gtpu_echo_setup()).
 * what: unused.
 * Returns 0 for logged/ignored messages, the result of rx_echo_req() for Echo Requests, or -1 on receive errors,
 * malformed headers or unsupported message types.
 */
int upf_gtpu_read_cb(struct osmo_fd *ofd, unsigned int what)
{
	struct upf_gtp_dev *dev = ofd->data;

	ssize_t sz;
	uint8_t buf[4096];
	struct osmo_sockaddr remote;
	socklen_t remote_len = sizeof(remote);
	const struct gtp1u_hdr *h;
	uint16_t h_length;
	size_t msg_len;

	if ((sz = recvfrom(dev->gtpv1.ofd.fd, buf, sizeof(buf), 0, &remote.u.sa, &remote_len)) < 0) {
		LOG_GTP_DEV(dev, LOGL_ERROR, "recvfrom() failed: %s\n", strerror(errno));
		return -1;
	}
	if (sz == 0) {
		LOG_GTP_DEV(dev, LOGL_ERROR, "recvfrom() yields zero bytes\n");
		return -1;
	}

	/* A GTPv1-U header of size 8 is valid, but this code expects to handle only ECHO REQUEST messages. These are
	 * required to have a sequence number, hence this check here consciously uses the full sizeof(*h) == 12. */
	if ((size_t)sz < sizeof(*h)) {
		LOG_GTP_DEV(dev, LOGL_ERROR,
			    "<- %s: rx GTPv1-U packet smaller than the GTPv1-U header + sequence nr: %zd < %zu\n",
			    osmo_sockaddr_to_str(&remote), sz, sizeof(*h));
		return -1;
	}

	h = (const struct gtp1u_hdr *)buf;
	if (h->version != 1) {
		LOG_GTP_DEV(dev, LOGL_ERROR, "<- %s: rx GTPv1-U v%u: only GTP version 1 supported\n",
			    osmo_sockaddr_to_str(&remote), h->version);
		return -1;
	}

	/* Total GTP message size: the 8 mandatory header bytes plus what the length field declares. */
	h_length = osmo_load16be(&h->length);
	msg_len = offsetof(struct gtp1u_hdr, data1) + h_length;
	if (msg_len > (size_t)sz) {
		LOG_GTP_DEV(dev, LOGL_ERROR, "<- %s: rx GTPv1-U: header + h.length = %zu > received bytes = %zd\n",
			    osmo_sockaddr_to_str(&remote), msg_len, sz);
		return -1;
	}

	/* Pass the GTP-declared message length, not the datagram size, so that bytes received beyond the end of
	 * the GTP message are never parsed as IEs. */
	switch (h->msg_type) {
	case GTP1U_MSGTYPE_ECHO_REQ:
		return rx_echo_req(dev, &remote, h, msg_len);
	case GTP1U_MSGTYPE_ECHO_RSP:
		rx_echo_resp(dev, &remote, h, msg_len);
		return 0;
	case GTP1U_MSGTYPE_ERRR_IND:
		/* 3GPP TS 29.281 7.3.1: Log "Tunnel Endpoint Identifier Data I" and "GTP-U Peer Address" */
		LOG_GTP_DEV(dev, LOGL_NOTICE, "%s rx: GTPv1-U Error Indication not supported\n",
			    osmo_sockaddr_to_str(&remote));
		return 0;
	case GTP1U_MSGTYPE_GPDU:
		LOG_GTP_DEV(dev, LOGL_NOTICE, "%s rx: GTPv1-U PDU TEID=0x%08x over slow path not supported\n",
			    osmo_sockaddr_to_str(&remote), osmo_load32be(&h->tei));
		return 0;
	default:
		LOG_GTP_DEV(dev, LOGL_ERROR, "%s rx: GTPv1-U message type %u not supported\n",
			    osmo_sockaddr_to_str(&remote), h->msg_type);
		return -1;
	}
}

int upf_gtpu_echo_setup(struct upf_gtp_dev *dev)
{
	if (dev->gtpv1.ofd.fd == -1) {
		LOGP(DGTP, LOGL_ERROR, "Cannot setup GTPv1-U ECHO: socket not initialized\n");
		return -EINVAL;
	}

	/* the caller should already have osmo_fd_register()ed when setting up the socket. */
	OSMO_ASSERT(osmo_fd_is_registered(&dev->gtpv1.ofd));
	/* make sure there is no cb yet that this would be replacing. */
	OSMO_ASSERT(dev->gtpv1.ofd.cb == NULL);

	dev->gtpv1.ofd.cb = upf_gtpu_read_cb;
	dev->gtpv1.ofd.data = dev;
	return 0;
}