# Copyright (C) 2022 Harald Welte
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.

import abc

class USBPacketHdr:
    """Representation of an iCE40-usbtrace packet header.

    The header is a 32-bit big-endian word which from_bytes() splits up as:

      bits 31..28  pid  (4 bit)
      bit  27      ok   (flag; packets without it are rendered as '[BAD]')
      bits 26..16  dat  (11 bit)
      bits 15..0   ts   (16 bit)

    frame_no, last_ts, frame_ts and expanded_ts are not part of the header
    itself; they stay None until augment() has been called.
    """
    def __init__(self, ts: int, pid: int, ok: bool, dat: int):
        self.ts = ts
        self.pid = pid
        self.ok = ok
        self.dat = dat
        # metadata not in header itself; filled in by augment()
        self.frame_no = None
        self.last_ts = None
        self.frame_ts = None
        self.expanded_ts = None

    @classmethod
    def from_bytes(cls, buffer:bytes) -> 'USBPacketHdr':
        """Construct an instance from a binary header.

        Only the first four bytes of 'buffer' are looked at; they are
        interpreted as one big-endian 32-bit word.
        """
        hdr     = int.from_bytes(buffer[0:4], 'big')
        hdr_ts  = hdr & 0xffff
        hdr_pid = hdr >> 28
        hdr_ok  = bool(hdr & (1 << 27))
        hdr_dat = (hdr >> 16) & 0x7ff
        return cls(hdr_ts, hdr_pid, hdr_ok, hdr_dat)

    def augment(self, frame_no:int, last_ts:int, frame_ts:int):
        """Augment the header with metadata.

        frame_no and frame_ts are stored as given.  last_ts is the base
        value the 16-bit 'ts' field gets added to in order to form the
        expanded timestamp.

        Returns the expanded timestamp (last_ts + ts + 1).
        """
        self.frame_no = frame_no
        self.last_ts = last_ts
        self.frame_ts = frame_ts
        # NOTE(review): the '+ 1' presumably compensates for the tracer
        # storing the delta minus one -- confirm against the gateware.
        self.expanded_ts = last_ts + self.ts + 1
        return self.expanded_ts

    def __repr__(self):
        bad = '' if self.ok else ' [BAD]'
        if self.expanded_ts is None or self.frame_no is None or self.frame_ts is None:
            # augment() was not called yet: only the raw header fields exist,
            # so show those instead of failing inside repr().
            return '[ts=%d pid=0x%x dat=0x%03x]%s' % (self.ts, self.pid, self.dat, bad)
        return '[%12d %5d+%5d]%s' % (self.expanded_ts, self.frame_no, self.expanded_ts - self.frame_ts,
                                     bad)


class USBPacket(abc.ABC):
    """Abstract parent of all traced USB packet classes.

    Concrete sub-classes have to provide a 'name' class attribute (used
    when printing) and implement raw_bytes().
    """
    def __init__(self, hdr: USBPacketHdr):
        self.hdr = hdr

    def __repr__(self):
        # header, packet type name, then the type-specific details
        parts = (repr(self.hdr), self.name, self._repr())
        return '%s %s %s' % parts

    def _repr(self):
        """Type-specific part of the textual representation.

        Empty by default; packet types carrying extra information override it.
        """
        return ''

    def _pid_byte(self):
        """Build the single PID byte: PID in the low nibble, its 4-bit
        complement in the high nibble."""
        pid = self.hdr.pid
        inverted = pid ^ 0xf
        value = pid | (inverted << 4)
        return value.to_bytes(1, 'little')

    @abc.abstractmethod
    def raw_bytes(self):
        """Returns the raw bytes for the packet as they would have been on the bus."""


class USBPacket_DATA(USBPacket):
    """Common parent of the traced DATA0 and DATA1 packet classes.

    Keeps the payload handed to the constructor next to the trace header.
    """
    def __init__(self, hdr: USBPacketHdr, payload: bytes):
        super().__init__(hdr)
        self.data = payload

    def _repr(self):
        """Render the payload as lower-case hex digits, one '%02x' per byte."""
        return ''.join('%02x' % octet for octet in self.data)

    def raw_bytes(self):
        """Return the PID byte followed by the stored payload."""
        pid_byte = self._pid_byte()
        return pid_byte + self.data

class USBPacket_DATA0(USBPacket_DATA):
    """Traced DATA0 packet; only sets the name used when printing."""
    name = 'DATA0'

class USBPacket_DATA1(USBPacket_DATA):
    """Traced DATA1 packet; only sets the name used when printing."""
    name = 'DATA1'


class USBPacket_Handshake(USBPacket):
    """Common parent of the traced handshake packet classes (ACK, NAK, STALL)."""
    def raw_bytes(self):
        """Return the packet's raw bytes, which is the PID byte alone."""
        pid_only = self._pid_byte()
        return pid_only

class USBPacket_ACK(USBPacket_Handshake):
    """Traced ACK handshake packet; only sets the name used when printing."""
    name = 'ACK'

class USBPacket_NAK(USBPacket_Handshake):
    """Traced NAK handshake packet; only sets the name used when printing."""
    name = 'NAK'

class USBPacket_STALL(USBPacket_Handshake):
    """Traced STALL handshake packet; only sets the name used when printing."""
    name = 'STALL'


class USBPacket_Token(USBPacket):
    """Common parent of the traced token packet classes (OUT, IN, SETUP).

    Splits the header's 'dat' field into a device address and an endpoint.
    """
    def __init__(self, hdr: USBPacketHdr):
        super().__init__(hdr)
        # low 7 bits are the address, everything above is the endpoint
        endpoint, address = divmod(hdr.dat, 1 << 7)
        self.addr = address
        self.ep = endpoint

    def _repr(self):
        """Show the decoded address (decimal) and endpoint (hex)."""
        fields = (self.addr, self.ep)
        return 'ADDR=%d EP=%02x' % fields

    def raw_bytes(self):
        """Return the PID byte followed by 'dat' as two little-endian bytes."""
        dat_bytes = self.hdr.dat.to_bytes(2, 'little')
        return self._pid_byte() + dat_bytes

class USBPacket_OUT(USBPacket_Token):
    """Traced OUT token packet; only sets the name used when printing."""
    name = 'OUT'

class USBPacket_IN(USBPacket_Token):
    """Traced IN token packet; only sets the name used when printing."""
    name = 'IN'

class USBPacket_SETUP(USBPacket_Token):
    """Traced SETUP token packet; only sets the name used when printing."""
    name = 'SETUP'


class USBPacket_SOF(USBPacket):
    """Class for a traced USB SOF packet.

    The header's 'dat' field is taken as-is as the frame number.
    """
    # The docstring has to be the first statement of the class body,
    # otherwise it is a plain string expression and __doc__ stays None.
    name = 'SOF'

    def __init__(self, hdr: USBPacketHdr):
        self.frame = hdr.dat
        super().__init__(hdr)

    def _repr(self):
        """Show the frame number in decimal."""
        return 'frame=%d' % (self.frame)

    def raw_bytes(self):
        """Return the PID byte followed by 'dat' as two little-endian bytes."""
        # NOTE(review): 'dat' is emitted as a plain 16-bit value; presumably the
        # CRC5 bits of the on-bus packet are not reproduced here -- confirm.
        return self._pid_byte() + self.hdr.dat.to_bytes(2, 'little')
