#!/usr/bin/env python3

# Copyright (C) 2022 Sylvain Munaut and Harald Welte
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.

import abc
import time
from typing import Optional, Tuple

import usb1

from ice40usbtrace.packet_parser import USBPacketParser
from ice40usbtrace.packet import USBPacket

class USBPacketHandler(abc.ABC):
    """Abstract interface for consumers of parsed USB packets.

    Concrete subclasses implement handle_packet(), which the tracer invokes
    for each (non-empty) packet extracted from the capture stream.
    """

    @abc.abstractmethod
    def handle_packet(self, packet: USBPacket):
        """Consume one captured packet; the return value is ignored."""

class USBPacketHandlerPrint(USBPacketHandler):
    """Trivial packet handler: dumps every packet's repr() to stdout."""

    def handle_packet(self, packet: USBPacket):
        """Write repr(packet) followed by a newline to stdout."""
        line = repr(packet)
        print(line)


class ICE40USBTrace:
    """Host-side driver for the iCE40 USB trace device.

    Opens the device identified by USB_VID/USB_PID, claims interface 1,
    keeps several asynchronous bulk IN transfers in flight on endpoint 1,
    and feeds every packet parsed from the received byte stream to a
    USBPacketHandler.

    Typical use::

        ut = ICE40USBTrace()
        ut.start()
        ut.run()    # returns after Ctrl-C, having called stop()
    """

    # Vendor-specific control requests of the trace interface
    IUT_INTF_CAPTURE_STATUS   = 0x10
    IUT_INTF_CAPTURE_START    = 0x12
    IUT_INTF_CAPTURE_STOP     = 0x13
    IUT_INTF_BUFFER_GET_LEVEL = 0x20
    IUT_INTF_BUFFER_FLUSH     = 0x21

    USB_VID = 0x1d50
    USB_PID = 0x617e

    # bmRequestType for vendor requests addressed to an interface
    # (host-to-device; OR in usb1.ENDPOINT_IN for device-to-host requests).
    CTRL_V_IF = usb1.REQUEST_TYPE_VENDOR | usb1.RECIPIENT_INTERFACE

    # Number of bulk IN transfers kept in flight, and the buffer size of each.
    NUM_BULK_IN_XFERS = 4
    BULK_IN_SIZE = 1024

    def __init__(self, context: Optional[usb1.USBContext] = None,
                 packet_handler: Optional[USBPacketHandler] = None):
        """Open the trace device and prepare packet parsing.

        :param context: libusb context to use; a new one is created if None.
        :param packet_handler: receives every parsed packet; defaults to a
            new USBPacketHandlerPrint.
        :raises RuntimeError: if no device with USB_VID/USB_PID is found.
        """
        # Defaults are built per instance, not at class-definition (import)
        # time, so importing this module does not open a libusb context and
        # separate instances do not silently share one.
        self.usb_context = context if context is not None else usb1.USBContext()
        self.usb_handle = self.usb_context.openByVendorIDAndProductID(self.USB_VID, self.USB_PID)
        if self.usb_handle is None:
            raise RuntimeError("iCE40 USB trace device %04x:%04x not found"
                               % (self.USB_VID, self.USB_PID))
        self.usb_handle.claimInterface(1)

        self.data = bytearray()          # received bytes not yet parsed
        self.upp = USBPacketParser()
        self.packet_handler = (packet_handler if packet_handler is not None
                               else USBPacketHandlerPrint())

        self.xfer_in = []                # bulk IN transfers submitted by start()
        self.running = False

    def _blocking_start(self):
        """Submit a blocking/synchronous CTRL transfer to start the tracing."""
        self.usb_handle.controlWrite(self.CTRL_V_IF, self.IUT_INTF_CAPTURE_START, 0, 0, b'')

    def _blocking_stop(self):
        """Submit a blocking/synchronous CTRL transfer to stop the tracing."""
        self.usb_handle.controlWrite(self.CTRL_V_IF, self.IUT_INTF_CAPTURE_STOP, 0, 0, b'')

    def _blocking_flush(self):
        """Submit a blocking/synchronous CTRL transfer to flush the buffers."""
        self.usb_handle.controlWrite(self.CTRL_V_IF, self.IUT_INTF_BUFFER_FLUSH, 0, 0, b'')

    def _capture_status_cb(self, transfer: usb1.USBTransfer):
        """Completion callback for IUT_INTF_CAPTURE_STATUS; prints the level.

        The level is decoded from bytes 2..5 of the reply as a little-endian
        integer (presumably the device's buffer fill level -- confirm against
        firmware).

        :raises Exception: if the transfer failed while capture is running.
        """
        status = transfer.getStatus()
        if status != usb1.TRANSFER_COMPLETED:
            if not self.running:
                # silently ignore during shutdown
                return
            raise Exception("CTRL URB failed with %s" % status)

        # getBuffer() spans the whole requested length; only the first
        # getActualLength() bytes were actually filled by the device.
        data = transfer.getBuffer()[:transfer.getActualLength()]
        if len(data) < 6:
            # Short reply: the level field is absent, don't decode stale bytes.
            return
        level = int.from_bytes(data[2:6], 'little')
        print("Level: %d" % (level))

    def _submit_capture_status(self):
        """Trigger a non-blocking CAPTURE_STATUS CTRL request. Prints Level on completion."""
        transfer = self.usb_handle.getTransfer()
        transfer.setControl(self.CTRL_V_IF | usb1.ENDPOINT_IN, self.IUT_INTF_CAPTURE_STATUS,
                            0, 0, 6, callback=self._capture_status_cb)
        transfer.submit()

    def _bulk_in_cb(self, transfer: usb1.USBTransfer):
        """Completion callback for the bulk IN endpoint.

        Appends the received bytes to self.data, re-submits the transfer and
        dispatches any complete packets.

        :raises Exception: if the transfer failed while capture is running.
        """
        status = transfer.getStatus()
        if status != usb1.TRANSFER_COMPLETED:
            if not self.running:
                # silently ignore during shutdown
                return
            raise Exception("BULK IN URB failed with %s" % status)

        # Only the first getActualLength() bytes belong to this completion; the
        # rest of the buffer is stale, and feeding it to the parser would
        # inject garbage into the packet stream on every short read.
        self.data += transfer.getBuffer()[:transfer.getActualLength()]

        # Re-submit the transfer (the received bytes were copied above).
        transfer.submit()

        # try to process any packets in self.data buffer
        self._process_data()

    def _submit_bulk_in(self):
        """Submit one non-blocking BULK IN request on endpoint 1 and track it."""
        transfer = self.usb_handle.getTransfer()
        transfer.setBulk(1 | usb1.ENDPOINT_IN, self.BULK_IN_SIZE, callback=self._bulk_in_cb)
        transfer.submit()
        self.xfer_in.append(transfer)

    def _process_data(self):
        """Parse and dispatch complete packets at the front of self.data.

        Consumed bytes are removed; a trailing incomplete packet is kept for
        the next call.
        """
        ofs = 0
        while True:
            # Enough for header ?
            if (len(self.data) - ofs) < 4:
                break

            # Parse packet; a non-positive length means it is incomplete.
            packet, plen = self.upp.parse_packet(self.data, ofs)
            if plen <= 0:
                break

            # call packet-handler call-back for each packet
            if packet:
                self.packet_handler.handle_packet(packet)

            ofs += plen

        # Trim consumed bytes in place.
        del self.data[:ofs]

    def start(self):
        """Reset the device, start capturing and submit the bulk IN transfers.

        :raises Exception: if capture is already running.
        """
        if self.running:
            raise Exception("Already running; cannot start again")

        # Best-effort reset: a failing stop/flush must not prevent the start,
        # but only USB errors are ignored (not e.g. KeyboardInterrupt).
        try:
            self._blocking_stop()
        except usb1.USBError:
            pass
        try:
            self._blocking_flush()
        except usb1.USBError:
            pass

        self._blocking_start()
        self.running = True

        # Several transfers in flight so data keeps flowing while one is
        # being handled in its callback.
        for _ in range(self.NUM_BULK_IN_XFERS):
            self._submit_bulk_in()

    def stop(self):
        """Stop capturing, flush the device and cancel pending bulk IN transfers.

        :raises Exception: if capture is not running.
        """
        if not self.running:
            raise Exception("Already stopped; cannot stop again")
        # Cleared first so callbacks for cancelled transfers are ignored.
        self.running = False
        # request device to stop
        self._blocking_stop()
        self._blocking_flush()
        # Cancel only transfers still in flight (usb1 raises when cancelling a
        # non-submitted transfer), then forget them so a later start()/stop()
        # cycle does not touch stale transfers.
        for transfer in self.xfer_in:
            if transfer.isSubmitted():
                transfer.cancel()
        self.xfer_in = []

    def run(self):
        """Process USB events until Ctrl-C, then stop the capture.

        Roughly once per second (checked whenever handleEvents() returns) a
        CAPTURE_STATUS request is submitted.

        :raises Exception: if start() has not been called.
        """
        if not self.running:
            raise Exception("You must call start before calling run")

        stat_time = time.time() + 1.0

        try:
            while True:
                self.usb_context.handleEvents()

                if time.time() > stat_time:
                    stat_time += 1.0
                    self._submit_capture_status()
        except KeyboardInterrupt:
            self.stop()


if __name__ == '__main__':
    # Open the trace device, start capturing and print every packet with the
    # default handler until interrupted with Ctrl-C.
    tracer = ICE40USBTrace()
    tracer.start()
    tracer.run()
