import argparse
import os

from trex_stl_lib.api import *
from scapy.contrib.gtp import GTP_U_Header, GTPPDUSessionContainer

class STLS1(object):
    """TRex stateless profile generating GTP-U encapsulated UDP traffic.

    One continuous stream is created per outer GTP-U source IP address.
    Each packet is Ether/IP/UDP(2152)/GTP-U/PDU-session-container/IP/UDP/
    payload, and the field engine rewrites the GTP-U TEID on every packet.

    Tunables (parsed in get_streams):
        --num-streams  sizes the TEID range swept by each stream
        --num-addrs    number of outer source IP addresses (one stream each)
        --payload-len  inner payload size in bytes
        --direction    'ul' (default), 'dl', or 'port' (use the direction
                       argument passed to get_streams)
    """

    # Defaults mirror the argparse defaults in get_streams() so that
    # create_stream() also works if it is called before get_streams().
    num_streams = 1
    num_addrs = 1
    payload_len = 1400

    def calc_src_addr(self, addr_str, inc):
        """Return the dotted-quad IPv4 address *inc* positions after *addr_str*.

        The address is treated as a 32-bit integer and the result wraps
        modulo 2**32 (e.g. 255.255.255.255 + 1 -> 0.0.0.0).

        Raises:
            ValueError: if *addr_str* is not four integer octets in 0..255.
        """
        parts = addr_str.split(".")
        if len(parts) != 4:
            raise ValueError("invalid IPv4 address: %r" % (addr_str,))
        num = 0
        for part in parts:
            octet = int(part)
            if not 0 <= octet <= 0xff:
                raise ValueError("invalid IPv4 address: %r" % (addr_str,))
            num = (num << 8) | octet
        num += inc
        # Masking each shifted byte makes over/underflow wrap around.
        return ".".join(str((num >> shift) & 0xff) for shift in (24, 16, 8, 0))


    # One stream per GTPU eNB/CN IP address.
    # Assumption: TEIDs are assigned sequentially (in pairs RAN/CN) over each IP
    # address in the list, then continue again at the start of the list.
    def create_stream (self, direction, addr_idx = 0):
        """Build the continuous GTP-U stream for one outer source address.

        Args:
            direction: 0 for uplink (source 172.16.32.2 + addr_idx, destination
                172.16.31.2, TEIDs starting at 2); any other value for downlink
                (source 172.16.31.2 + addr_idx, destination 172.16.32.2, TEIDs
                starting at 1).
            addr_idx: index of the outer source address; offsets both the
                source IP and the first TEID.

        Returns:
            An STLStream whose field engine rewrites GTP_U_Header.teid.
        """
        uplink = direction == 0
        if uplink:
            src_gtpu_ip = "172.16.32.2"
            dst_gtp_ip = "172.16.31.2"
            start_teid = 0x00000002
        else:
            src_gtpu_ip = "172.16.31.2"
            dst_gtp_ip = "172.16.32.2"
            start_teid = 0x00000001
        src_gtpu_ip = self.calc_src_addr(src_gtpu_ip, addr_idx)
        # UL/DL TEIDs are interleaved in pairs, so consecutive addresses
        # start 2 TEIDs apart.
        start_teid = start_teid + 2*addr_idx

        # Field engine: sweep the TEID per packet; the step skips the TEIDs
        # that belong to the other addresses (and the opposite direction).
        # NOTE(review): max_value - min_value is 2*(num_streams - 1) while the
        # step is 2*num_addrs; when num_addrs > 1 these are not aligned, so
        # confirm the TEIDs produced after wrap-around are the intended ones.
        vm = STLScVmRaw( [ STLVmFlowVar(name="gtpu_teid",
                                        min_value=start_teid,
                                        max_value=start_teid + 2*(self.num_streams - 1),
                                        size=4, op="inc", step=2*self.num_addrs),
                           STLVmWrFlowVar(fv_name="gtpu_teid",
                                          pkt_offset="GTP_U_Header.teid"),
                           # Recompute the "IP"/"UDP" layer checksums in HW
                           # after the TEID rewrite (presumably the outer headers).
                           STLVmFixChecksumHw(l3_offset="IP",l4_offset="UDP",l4_type=CTRexVmInsFixHwCs.L4_TYPE_UDP)
                        ]
                    )

        return STLStream(
            name = "gtpu-%s-%d-%s-%x" % ("ul" if uplink else "dl",
                                         addr_idx, src_gtpu_ip, start_teid),
            packet =
                    STLPktBuilder(
                        pkt = Ether()/IP(src=src_gtpu_ip,dst=dst_gtp_ip,version=4)/
                                UDP(dport=2152,sport=2152)/
                                GTP_U_Header(teid=start_teid)/
                                GTPPDUSessionContainer(type=1,QFI=1)/
                                IP(src="10.45.0.2",dst="192.168.16.152",version=4)/
                                UDP()/
                                (self.payload_len*'x'),
                        vm = vm
                    ),
             mode = STLTXCont())

    @staticmethod
    def _positive_int(value):
        """argparse ``type=`` helper: parse *value* as an int that must be >= 1."""
        number = int(value)
        if number < 1:
            raise argparse.ArgumentTypeError("must be >= 1, got %d" % number)
        return number

    def get_streams (self, direction=0, tunables=None, **kwargs):
        """Parse *tunables* and return the profile's streams.

        Args:
            direction: direction value supplied by the caller; only used when
                ``--direction port`` is given (0 = uplink, otherwise downlink).
            tunables: list of command-line style tunable strings. ``None`` is
                treated as an empty list so argparse never falls back to
                parsing ``sys.argv``.
            **kwargs: accepted for caller compatibility and ignored.

        Returns:
            The stream list produced by ``STLProfile(...).get_streams()``.
        """
        parser = argparse.ArgumentParser(description='Argparser for {}'.format(os.path.basename(__file__)),
                                         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('--num-streams',
                            type=self._positive_int,
                            default=1,
                            help="The number of streams.")
        parser.add_argument('--num-addrs',
                            type=self._positive_int,
                            default=1,
                            help="The number of src IP addresses.")
        parser.add_argument('--payload-len',
                            type=int,
                            default=1400,
                            help="Payload size in bytes.")
        parser.add_argument('--direction',
                            choices=('ul', 'dl', 'port'),
                            default='ul',
                            help="Traffic direction: 'ul'/'dl' always generate "
                                 "uplink/downlink traffic; 'port' uses the "
                                 "direction passed to get_streams().")
        args = parser.parse_args(tunables if tunables is not None else [])
        self.num_streams = args.num_streams
        self.num_addrs = args.num_addrs
        self.payload_len = args.payload_len
        stream_direction = {'ul': 0, 'dl': 1, 'port': direction}[args.direction]
        streams = [self.create_stream(direction=stream_direction, addr_idx=addr_idx)
                   for addr_idx in range(self.num_addrs)]
        return STLProfile(streams).get_streams()

def register():
    """Module-level factory: return a fresh STLS1 profile instance."""
    profile = STLS1()
    return profile
