from trex.astf.api import *
from trex.astf.tunnels_topo import TunnelsTopo
import argparse

def calc_src_addr(addr_str, inc):
    """Return the IPv4 address that lies *inc* addresses after *addr_str*.

    The first four dot-separated fields are packed big-endian into a single
    integer (OR-ed together), *inc* is added, and the sum is unpacked back into
    four octets. Every octet is masked to 8 bits, so the result wraps around
    modulo 2**32 instead of overflowing, e.g. '255.255.255.255' + 1 gives
    '0.0.0.0' and '10.0.0.0' - 1 gives '9.255.255.255'.

    :param addr_str: dotted-quad IPv4 address, e.g. '193.0.0.0'.
    :param inc: signed number of addresses to move by.
    :return: the shifted address as a dotted-quad string.
    """
    octets = addr_str.split(".")
    shifts = (24, 16, 8, 0)

    value = 0
    for pos, shift in enumerate(shifts):
        value |= int(octets[pos]) << shift
    value += inc

    for pos, shift in enumerate(shifts):
        octets[pos] = str((value >> shift) & 0xff)
    return ".".join(octets)


def add_tun(topo, num_streams, num_addrs, addr_offset, start_teid, tun_idx):
    """Add the GTP-U tunnel context for tunnel *tun_idx* of a segment to *topo*.

    The segment's *num_streams* streams are split evenly over *num_addrs*
    tunnels. Any remainder (``num_streams % num_addrs``) is not configured.
    Tunnel *tun_idx* gets:

    * inner client IPs starting at ``193.<addr_offset + tun_idx>.0.0``, one per
      stream (``num_streams // num_addrs`` addresses, inclusive range);
    * outer GTP-U source IP ``172.16.32.2`` advanced by
      ``addr_offset + tun_idx``, destination ``172.16.32.1``, sport 2152;
    * ``initial_teid = start_teid + tun_idx`` and ``teid_jump = num_addrs``.
      Assuming TRex gives the k-th client ``initial_teid + k * teid_jump``,
      the tunnels of one segment interleave over the contiguous TEID block
      ``[start_teid, start_teid + num_addrs * streams_per_tun)``.

    :param topo: object providing ``add_tunnel_ctx`` (a TunnelsTopo).
    :param num_streams: total number of streams in this segment.
    :param num_addrs: number of tunnels the streams are split across.
    :param addr_offset: address index of this segment's first tunnel.
    :param start_teid: first TEID of this segment.
    :param tun_idx: 0-based index of this tunnel within the segment.
    :raises ValueError: if num_addrs < 1, if there is less than one stream per
        tunnel, if one tunnel would need more than 65536 client addresses, or
        if ``addr_offset + tun_idx`` is outside 0..255.
    """
    if num_addrs < 1:
        raise ValueError("num_addrs must be >= 1, got %r" % (num_addrs,))
    streams_per_tun = num_streams // num_addrs
    if streams_per_tun < 1:
        raise ValueError("need at least one stream per tunnel: num_streams=%r, num_addrs=%r"
                         % (num_streams, num_addrs))
    if streams_per_tun > 0x10000:
        # More clients than fit in 193.X.0.0/16 would spill into the next tunnel's range.
        raise ValueError("too many streams per tunnel (%d > 65536)" % streams_per_tun)

    addr_idx = addr_offset + tun_idx
    if not 0 <= addr_idx <= 255:
        # addr_idx becomes the second octet of 193.X.0.0, so it must be a valid octet.
        raise ValueError("addr_offset + tun_idx must be in 0..255, got %d" % addr_idx)

    src_gtpu_ip = calc_src_addr('172.16.32.2', addr_idx)
    src_start = '193.%d.0.0' % addr_idx
    # src_start..src_end is inclusive (TRex examples use 16.0.0.0-16.0.0.255 for
    # 256 clients), so the last client is streams_per_tun - 1 past the first.
    # Using +streams_per_tun configured one extra client whose TEID presumably
    # fell into the block reserved for the next segment.
    src_end = calc_src_addr(src_start, streams_per_tun - 1)
    topo.add_tunnel_ctx(
        src_start = src_start,
        src_end = src_end,
        initial_teid = start_teid + tun_idx,
        teid_jump = num_addrs,
        sport = 2152,
        version = 4,
        tunnel_type = 1,
        src_ip = src_gtpu_ip,
        dst_ip = '172.16.32.1',
        activate = True
    )


def get_topo(**kwargs):
    """Build the GTP-U tunnels topology for one traffic segment.

    Tunables are keyword arguments whose keys contain dashes:

    * ``num-streams`` (default 1): total streams in this segment.
    * ``num-addrs`` (default 1): number of tunnels the streams are split over.
    * ``addr-offset`` (default 0): added to every tunnel's address index;
      presumably the number of tunnels created by earlier segments.
    * ``start-teid``: first TEID of this segment. Defaults to
      ``1 + (num_streams // num_addrs) * addr_offset``, which assumes every
      earlier segment used the same number of streams per tunnel.

    Tunable values given as strings (e.g. from a command line) are converted
    with ``int(value, 0)``, so decimal and ``0x``-prefixed hex are accepted.

    :return: a TunnelsTopo with one tunnel context per tunnel index.
    :raises ValueError: if num-addrs < 1, if num-streams < num-addrs, or if a
        string tunable is not a valid integer literal.
    """
    print("get_topo params: %r" % (kwargs))

    def _int_param(name, default):
        # Convert string tunables; leave non-string values exactly as passed.
        value = kwargs.get(name, default)
        return int(value, 0) if isinstance(value, str) else value

    num_streams = _int_param('num-streams', 1)
    num_addrs = _int_param('num-addrs', 1)
    addr_offset = _int_param('addr-offset', 0)
    # Validate before the start-teid default, which divides by num_addrs.
    if num_addrs < 1:
        raise ValueError("num-addrs must be >= 1, got %r" % (num_addrs,))
    streams_per_tun = num_streams // num_addrs
    if streams_per_tun < 1:
        # Zero streams per tunnel would make every segment's default start TEID 1.
        raise ValueError("num-streams (%r) must be >= num-addrs (%r)" % (num_streams, num_addrs))
    # NOTE Assume previous segments have same amount of streams as this one:
    start_teid = _int_param('start-teid', 0x00000001 + streams_per_tun * addr_offset)
    topo = TunnelsTopo()

    for tun_idx in range(num_addrs):
        add_tun(topo, num_streams, num_addrs, addr_offset, start_teid, tun_idx)
    return topo
