from trex.astf.api import *
import argparse

# Message payloads handed to ASTFProgram.send_msg further down: 1400
# characters each, the same figure the programs pass as udp_mtu.
ul_pkt = 'x' * 1400  # sent by the client program
dl_pkt = 'y' * 1400  # sent by the server program

class Prof1():
    """ASTF profile with one message-based (``stream=False``) template per address index.

    ``get_profile`` parses the tunables, builds a single client program and a
    single server program, and reuses them for ``--num-addrs`` templates. Each
    template gets its own client IP range and its own server port.
    """

    def __init__(self):
        pass

    def calc_src_addr(self, addr_str, inc):
        """Return the dotted-quad IPv4 address ``inc`` positions after ``addr_str``.

        The address is treated as a 32-bit number; only the low 32 bits of the
        sum are kept, so the result wraps around instead of overflowing.

        :param addr_str: IPv4 address as ``"a.b.c.d"``.
        :param inc: integer offset to add (may be zero or negative).
        :returns: the resulting address as a dotted-quad string.
        :raises ValueError: if ``addr_str`` does not have exactly four
            dot-separated integer fields.
        """
        octets = [int(part) for part in addr_str.split(".")]
        if len(octets) != 4:
            raise ValueError("expected a dotted-quad IPv4 address, got %r" % (addr_str,))
        value = 0
        for octet in octets:
            value = (value << 8) | octet
        value += inc
        # Masking every byte drops anything above bit 31 (wrap-around).
        return ".".join(str((value >> shift) & 0xff) for shift in (24, 16, 8, 0))

    def create_ip_gen(self, prog_c, prog_s, addr_idx):
        """Build the IP generator for address index ``addr_idx``.

        Client range: starts at ``193.<addr_offset + addr_idx>.0.0`` and spans
        ``num_streams // num_addrs`` consecutive addresses.
        Server range: starts at ``48.<addr_offset>.0.1`` and spans
        ``num_addrs`` consecutive addresses (it does not depend on
        ``addr_idx``). Both ranges use the ``"seq"`` distribution.

        :param prog_c: unused; kept so the signature stays unchanged.
        :param prog_s: unused; kept so the signature stays unchanged.
        :param addr_idx: zero-based index of the client address block.
        :returns: an ``ASTFIPGen`` combining both distributions.
        """
        addrs_per_block = self.num_streams // self.num_addrs
        start_ms_addr = "193.%d.0.0" % (self.addr_offset + addr_idx)
        end_ms_addr = self.calc_src_addr(start_ms_addr, addrs_per_block - 1)
        start_srv_addr = "48.%d.0.1" % (self.addr_offset)
        end_srv_addr = self.calc_src_addr(start_srv_addr, self.num_addrs - 1)
        ip_gen_c = ASTFIPGenDist(ip_range=[start_ms_addr, end_ms_addr], distribution="seq")
        ip_gen_s = ASTFIPGenDist(ip_range=[start_srv_addr, end_srv_addr], distribution="seq")
        return ASTFIPGen(glob=ASTFIPGenGlobal(ip_offset="1.0.0.0"),
                         dist_client=ip_gen_c,
                         dist_server=ip_gen_s)

    def create_template(self, prog_c, prog_s, addr_idx):
        """Build the client/server template pair for address index ``addr_idx``.

        The server port is ``80 + addr_idx``; the server template is associated
        by that port, and the template is placed in the ``'latency'`` template
        group.

        :param prog_c: client ``ASTFProgram``.
        :param prog_s: server ``ASTFProgram``.
        :param addr_idx: zero-based index selecting client range and port.
        :returns: the assembled ``ASTFTemplate``.
        """
        ip_gen = self.create_ip_gen(prog_c, prog_s, addr_idx)
        srv_port = 80 + addr_idx
        temp_c = ASTFTCPClientTemplate(program=prog_c, ip_gen=ip_gen, port=srv_port)
        temp_s = ASTFTCPServerTemplate(program=prog_s,
                                       assoc=ASTFAssociation(rules=ASTFAssociationRule(port=srv_port)))
        return ASTFTemplate(client_template=temp_c, server_template=temp_s, tg_name='latency')

    def _check_tunables(self, args):
        """Reject tunable combinations that cannot produce valid address ranges."""
        if args.num_addrs < 1:
            # Would otherwise divide by zero when sizing the client range.
            raise ValueError("--num-addrs must be at least 1, got %d" % args.num_addrs)
        if args.num_streams < args.num_addrs:
            # num_streams // num_addrs would be 0, putting the end of the
            # client range one address *before* its start.
            raise ValueError("--num-streams (%d) must be >= --num-addrs (%d)"
                             % (args.num_streams, args.num_addrs))
        last_octet = args.addr_offset + args.num_addrs - 1
        if args.addr_offset < 0 or last_octet > 255:
            # addr_offset + addr_idx is formatted directly into the second octet.
            raise ValueError("--addr-offset (%d) plus --num-addrs (%d) must keep the "
                             "second address octet within 0..255"
                             % (args.addr_offset, args.num_addrs))

    def _build_programs(self, do_ul, do_dl):
        """Build the (client, server) program pair for the enabled directions."""
        # client commands
        prog_c = ASTFProgram(stream=False, udp_mtu=1400, addon='latency')
        if do_ul:
            for _ in range(self.num_pkts_ul):
                prog_c.send_msg(ul_pkt)
        if do_dl:
            prog_c.set_keepalive_msg(10000)
            prog_c.recv_msg(self.num_pkts_dl)

        # server commands mirror the client: receive what it sends, send what it receives
        prog_s = ASTFProgram(stream=False, udp_mtu=1400, addon='latency')
        if do_ul:
            prog_s.set_keepalive_msg(10000)
            prog_s.recv_msg(self.num_pkts_ul)
        if do_dl:
            for _ in range(self.num_pkts_dl):
                prog_s.send_msg(dl_pkt)
        return prog_c, prog_s

    def get_profile(self, tunables, **kwargs):
        """Parse ``tunables`` and return the resulting ``ASTFProfile``.

        :param tunables: list of command-line style strings, e.g.
            ``["--num-streams", "8", "--num-addrs", "2", "--dir", "uldl"]``.
        :param kwargs: accepted and ignored.
        :returns: an ``ASTFProfile`` holding ``--num-addrs`` templates.
        :raises ValueError: if ``--num-addrs`` is below 1, ``--num-streams`` is
            smaller than ``--num-addrs``, or the address offset would push the
            second address octet outside 0..255.
        """
        # Imported here so the method does not depend on ``os`` happening to
        # be re-exported by the module's star import.
        import os

        parser = argparse.ArgumentParser(description='Argparser for {}'.format(os.path.basename(__file__)),
                                         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('--num-streams',
                            type=int,
                            default=1,
                            help="The number of streams.")
        parser.add_argument('--num-addrs',
                            type=int,
                            default=1,
                            help="The number of src IP addresses.")
        parser.add_argument('--addr-offset',
                            type=int,
                            default=0,
                            help="Start index assigning --num-addrs addresses.")
        parser.add_argument('--num-pkts-ul',
                            type=int,
                            default=1,
                            help="The number of UL Packets to transmit per flow.")
        parser.add_argument('--num-pkts-dl',
                            type=int,
                            default=1,
                            help="The number of DL Packets to transmit per flow.")
        parser.add_argument('--dir',
                            type=str,
                            default='ul',
                            # A tuple keeps the order in help/error output stable.
                            choices=('ul', 'dl', 'uldl'),
                            help="Traffic direction(s) to generate: UL, DL, or both.")
        args = parser.parse_args(tunables)
        # Validate before any ASTF object is created so bad input fails early.
        self._check_tunables(args)

        self.num_streams = args.num_streams
        self.num_addrs = args.num_addrs
        self.addr_offset = args.addr_offset
        self.num_pkts_ul = args.num_pkts_ul
        self.num_pkts_dl = args.num_pkts_dl
        # 'uldl' contains both substrings, so it enables both directions.
        do_ul = "ul" in args.dir
        do_dl = "dl" in args.dir

        prog_c, prog_s = self._build_programs(do_ul, do_dl)

        # Not really used, but must be passed to ASTFProfile:
        default_ip_gen = self.create_ip_gen(prog_c, prog_s, 0)

        templates = [self.create_template(prog_c, prog_s, addr_idx=addr_idx)
                     for addr_idx in range(self.num_addrs)]

        return ASTFProfile(default_ip_gen=default_ip_gen, templates=templates)

def register():
    """Create and return a fresh ``Prof1`` instance.

    NOTE(review): presumably the hook the TRex profile loader looks up by
    name -- confirm against the loader.
    """
    profile = Prof1()
    return profile

