#!/usr/bin/env python3

import usb
import time

# USB packet identifiers (4-bit PID values as they appear on the bus)

# Token packets
PID_OUT = 0b0001
PID_IN = 0b1001
PID_SOF = 0b0101
PID_SETUP = 0b1101

# Data packets
PID_DATA0 = 0b0011
PID_DATA1 = 0b1011

# Handshake packets
PID_ACK = 0b0010
PID_NAK = 0b1010
PID_STALL = 0b1110

# Lookup table: PID value -> printable name
PID = {
    PID_OUT: "OUT",
    PID_IN: "IN",
    PID_SOF: "SOF",
    PID_SETUP: "SETUP",
    PID_DATA0: "DATA0",
    PID_DATA1: "DATA1",
    PID_ACK: "ACK",
    PID_NAK: "NAK",
    PID_STALL: "STALL",
}


class ICE40USBTracer:
    """Reader and decoder for the capture stream of an iCE40 USB tracer.

    The device is controlled with vendor control requests (wIndex 0) and the
    captured records are read from IN endpoint 0x81.  Every decoded record is
    printed to stdout together with a running timestamp and the position
    relative to the last SOF seen.
    """

    # Vendor request codes (bRequest)
    IUT_INTF_CAPTURE_STATUS   = 0x10
    IUT_INTF_CAPTURE_START    = 0x12
    IUT_INTF_CAPTURE_STOP     = 0x13
    IUT_INTF_BUFFER_GET_LEVEL = 0x20
    IUT_INTF_BUFFER_FLUSH     = 0x21

    # USB vendor / product ID the tracer is looked up by
    _USB_VID = 0x1d50
    _USB_PID = 0x617e

    # bmRequestType: vendor request, interface recipient
    _REQ_OUT = 0x41     # host -> device
    _REQ_IN  = 0xc1     # device -> host

    # Capture data endpoint and size of a single read request
    _EP_DATA   = 0x81
    _READ_SIZE = 1024

    # Size of the header that starts every capture record
    _HDR_LEN = 4

    def __init__(self):
        """Look up the tracer device and reset the decoder state.

        The lookup result is stored unchecked: ``self.dev`` is ``None`` when
        no matching device is attached, which :meth:`run` reports.
        """

        # Device
        self.dev = usb.core.find(
            idVendor=self._USB_VID, idProduct=self._USB_PID
        )

        # Initial state
        self.data = bytearray()     # Received bytes not yet decoded
        self.ts = 0                 # Running timestamp (sum of record deltas)
        self.frame_ts = 0           # Value of 'ts' at the last SOF
        self.frame_no = 0           # Frame number of the last SOF

    def _dev_start(self):
        """Send the CAPTURE_START vendor request."""
        self.dev.ctrl_transfer(
            self._REQ_OUT, self.IUT_INTF_CAPTURE_START, 0, 0
        )

    def _dev_stop(self):
        """Send the CAPTURE_STOP vendor request."""
        self.dev.ctrl_transfer(
            self._REQ_OUT, self.IUT_INTF_CAPTURE_STOP, 0, 0
        )

    def _dev_flush(self):
        """Send the BUFFER_FLUSH vendor request."""
        self.dev.ctrl_transfer(
            self._REQ_OUT, self.IUT_INTF_BUFFER_FLUSH, 0, 0
        )

    def _dev_status(self):
        """Read the 6-byte CAPTURE_STATUS reply and return one of its fields.

        Returns:
            Bytes 2..5 of the reply decoded as a little-endian integer.
            NOTE(review): run() prints this as "Level", presumably the
            capture buffer fill level - confirm against the firmware.
        """
        data = self.dev.ctrl_transfer(
            self._REQ_IN, self.IUT_INTF_CAPTURE_STATUS, 0, 0, 6
        )
        return int.from_bytes(data[2:6], 'little')

    def _parse_packet(self, buffer, ofs):
        """Decode and print the capture record starting at ``buffer[ofs]``.

        A record starts with a 4-byte big-endian header:

        * bits 31..28 : PID (0 is used for non-packet records)
        * bit  27     : OK flag, records without it are printed as ``[BAD]``
        * bits 26..16 : 11-bit field: payload length for DATA0/DATA1,
                        address (low 7 bits) + endpoint for tokens,
                        frame number for SOF
        * bits 15..0  : timestamp delta to the previous record, minus one

        DATA0/DATA1 records are followed by the payload bytes.

        Args:
            buffer: Byte sequence holding the received capture data.
            ofs: Offset of the record header inside ``buffer``.

        Returns:
            Number of bytes consumed by the record, or 0 if the record is
            not completely available yet.  When 0 is returned no decoder
            state has been modified, so the same record can safely be
            parsed again once more data has arrived.
        """

        # Read off header
        hdr = int.from_bytes(buffer[ofs:ofs + self._HDR_LEN], 'big')

        hdr_ts  = hdr & 0xffff
        hdr_pid = hdr >> 28
        hdr_ok  = bool(hdr & (1 << 27))
        hdr_dat = (hdr >> 16) & 0x7ff

        # Only data packets carry a payload after the header
        is_data = hdr_pid in (PID_DATA0, PID_DATA1)
        plen = self._HDR_LEN + (hdr_dat if is_data else 0)

        # Is it all here ?  This must be decided before the timestamp is
        # advanced: an incomplete record is parsed a second time later and
        # its delta must not be accounted for twice.
        if (len(buffer) - ofs) < plen:
            return 0

        # Timestamp
        self.ts += hdr_ts + 1

        # Default values
        extra = ""

        # What PID ?
        if is_data:
            # Data as hex
            payload = buffer[ofs + self._HDR_LEN:ofs + plen]
            extra = ' '.join('%02x' % x for x in payload)

        elif hdr_pid in (PID_ACK, PID_NAK, PID_STALL):
            # Nothing to do, packet is just the PID byte
            pass

        elif hdr_pid in (PID_OUT, PID_IN, PID_SETUP):
            # Extra info
            extra = "addr=%d ep=%d" % (hdr_dat & 0x7f, hdr_dat >> 7)

        elif hdr_pid == PID_SOF:
            # Extra info
            extra = "frame=%d delta_ts=%d" % (hdr_dat, self.ts - self.frame_ts)

            # Save new time
            self.frame_ts = self.ts
            self.frame_no = hdr_dat

        elif hdr_pid == 0:
            # Don't print TS overflow packets
            # FIXME: Line state reported in hdr_dat, could be used to
            #        find bus resets / idle / ...
            if hdr_ok:
                return plen

        # Print
        print("[%12d %5d+%5d] %5s%s%s" % (
            self.ts,
            self.frame_no, self.ts - self.frame_ts,
            PID.get(hdr_pid, "!WTF!"), " " if hdr_ok else " [BAD] ", extra)
        )

        return plen

    def _process_pending(self):
        """Decode every complete record in ``self.data`` and drop them."""
        ofs = 0

        try:
            # Enough for header ?
            while (len(self.data) - ofs) >= self._HDR_LEN:
                # Parse packet
                plen = self._parse_packet(self.data, ofs)
                if plen <= 0:
                    break

                ofs += plen
        finally:
            # Trim, even if interrupted (Ctrl-C), so records that were
            # already decoded are never decoded a second time
            del self.data[:ofs]

    def run(self):
        """Capture and print traffic until interrupted.

        The first Ctrl-C stops the capture; reading then continues until the
        device has nothing left to send (read timeout).  A second Ctrl-C
        flushes the device buffer and exits immediately.

        Raises:
            RuntimeError: No tracer device was found.
            usb.core.USBError: A USB transfer other than the initial
                best-effort reset failed.
        """
        if self.dev is None:
            raise RuntimeError(
                "ICE40 USB tracer not found (USB ID %04x:%04x)"
                % (self._USB_VID, self._USB_PID)
            )

        # Perform reset. Best effort: the device may refuse these requests
        # depending on its current state, so USB errors are ignored here
        # (and only USB errors, so Ctrl-C still works).
        for reset_step in (self._dev_stop, self._dev_flush):
            try:
                reset_step()
            except usb.core.USBError:
                pass

        # Start running
        self._dev_start()
        running = True

        # Monotonic clock so wall-clock adjustments can't stall the stats
        stat_time = time.monotonic() + 1.0

        # Run in a loop
        while True:
            try:
                # Grab new chunk
                try:
                    chunk = self.dev.read(self._EP_DATA, self._READ_SIZE)
                except usb.core.USBTimeoutError:
                    # No data: keep polling while capturing, and once the
                    # capture was stopped this means everything was drained
                    if running:
                        continue
                    break

                self.data += chunk

                # Process
                self._process_pending()

                # Time to print stat ?
                now = time.monotonic()
                if now > stat_time:
                    stat_time = now + 1
                    print("Level: %d" % (self._dev_status(),))

            except KeyboardInterrupt:
                if not running:
                    # Second interrupt: discard what is left and leave
                    self._dev_flush()
                    break

                # First interrupt: stop capture, keep draining the buffer
                self._dev_stop()
                running = False


if __name__ == "__main__":
    # Script entry point: capture and print traffic until interrupted
    tracer = ICE40USBTracer()
    tracer.run()
