"""Generic String Encoding Rules (GSER), as specified in RFC 3641.

Only encoding is implemented; decoding is not supported.
"""

import time
import binascii
import math
from copy import copy
import datetime

from . import BaseType, format_bytes, ErrorWithLocation
from . import EncodeError
from . import DecodeError
from . import compiler
from . import format_or
from . import utc_time_from_datetime
from . import generalized_time_from_datetime
from .compiler import enum_values_as_dict


class Type(BaseType):
    """Base class of every GSER type in this module.

    Concrete subclasses override encode() to turn a Python value into its
    GSER text.
    """

    def encode(self, data, _separator, _indent):
        # Abstract hook; each concrete type provides its own encoding.
        raise NotImplementedError('To be implemented by subclasses.')

    def set_size_range(self, minimum, maximum, has_extension_marker):
        # Size constraints are accepted but not used by this encoder.
        return None


class MembersType(Type):
    """Shared GSER encoder for SEQUENCE and SET values.

    Members present in the data dict are written in definition order as
    ``{ name value, name value }``. Absent members are skipped when they are
    OPTIONAL or have a DEFAULT; otherwise EncodeError is raised.
    """

    def __init__(self, name, members, type_name):
        super(MembersType, self).__init__(name, type_name)
        self.members = members

    def encode(self, data, separator, indent):
        inner_separator = separator + ' ' * indent
        parts = []

        for member in self.members:
            member_name = member.name

            if member_name not in data:
                # Missing members are only allowed when OPTIONAL or DEFAULT.
                if not member.optional and not member.has_default():
                    raise EncodeError(
                        "{} member '{}' not found in {}.".format(
                            self.__class__.__name__,
                            member_name,
                            data))

                continue

            try:
                value = member.encode(data[member_name],
                                      inner_separator,
                                      indent)
            except ErrorWithLocation as e:
                # Record which member the failure came from.
                e.add_location(member)
                raise e

            parts.append(u'{}{} {}'.format(inner_separator,
                                           member_name,
                                           value))

        return separator.join(['{' + ','.join(parts), '}'])

    def __repr__(self):
        members = ', '.join(repr(member) for member in self.members)

        return '{}({}, [{}])'.format(self.__class__.__name__,
                                     self.name,
                                     members)


class ArrayType(Type):
    """Shared GSER encoder for SEQUENCE OF and SET OF values.

    Every element is encoded with ``element_type`` and the results are
    written as ``{ value, value }``.
    """

    def __init__(self, name, type_name, element_type):
        super(ArrayType, self).__init__(name, type_name)
        self.element_type = element_type

    def encode(self, data, separator, indent):
        inner_separator = separator + ' ' * indent
        encode_element = self.element_type.encode
        parts = [
            u'{}{}'.format(inner_separator,
                           encode_element(element, inner_separator, indent))
            for element in data
        ]

        return separator.join(['{' + ','.join(parts), '}'])

    def __repr__(self):
        class_name = self.__class__.__name__

        return '{}({}, {})'.format(class_name, self.name, self.element_type)


class Boolean(Type):
    """BOOLEAN, written as ``TRUE`` or ``FALSE``."""

    def __init__(self, name):
        super(Boolean, self).__init__(name, 'BOOLEAN')

    def encode(self, data, _separator, _indent):
        # Any truthy value maps to TRUE, everything else to FALSE.
        if data:
            return 'TRUE'

        return 'FALSE'


class Integer(Type):
    """INTEGER, written using its str() text."""

    def __init__(self, name):
        super(Integer, self).__init__(name, 'INTEGER')

    def encode(self, data, _separator, _indent):
        encoded = str(data)

        return encoded


class Real(Type):
    """REAL, written as a GSER RealValue.

    Zero is written as ``0``. Infinities are written as ``PLUS-INFINITY`` or
    ``MINUS-INFINITY``. Other finite numbers are written as
    ``<mantissa>E<exponent>``.
    """

    def __init__(self, name):
        super(Real, self).__init__(name, 'REAL')

    def encode(self, data, _separator, _indent):
        """Return the GSER text for `data`.

        Raises EncodeError for NaN.
        """

        if data == float('inf'):
            return 'PLUS-INFINITY'
        elif data == float('-inf'):
            return 'MINUS-INFINITY'
        elif math.isnan(data):
            raise EncodeError('Cannot encode floating point number NaN.')
        elif data == 0.0:
            return '0'

        # str() switches to scientific notation for very large or very small
        # floats (e.g. '1e+20', '1.5e-07'). Blindly appending 'E0' produced
        # invalid text such as '1e+20E0'. Split out the exponent instead and
        # re-emit it without '+' or leading zeros ('1E20', '1.5E-7').
        text = '{}'.format(data)
        mantissa, marker, exponent = text.upper().partition('E')

        if marker:
            return '{}E{}'.format(mantissa, int(exponent))

        return '{}E0'.format(text)


class Null(Type):
    """NULL, which always encodes to the literal ``NULL``."""

    def __init__(self, name):
        super(Null, self).__init__(name, 'NULL')

    def encode(self, _data, _separator, _indent):
        # The value carries no information and is ignored.
        return 'NULL'


class BitString(Type):
    """BIT STRING, written as a quoted binary string such as ``'0110'B``.

    The value is indexed as ``data[0]`` (the bytes) and ``data[1]`` (the
    number of bits). Bits are taken starting from the most significant bit
    of the first byte.
    """

    def __init__(self, name):
        super(BitString, self).__init__(name, 'BIT STRING')

    def encode(self, data, _separator, _indent):
        # Render each byte as exactly eight binary digits. Unlike parsing the
        # hexlified bytes with int(..., 16), this also handles an empty value
        # (b''), which previously raised ValueError.
        bits = ''.join('{:08b}'.format(byte) for byte in bytearray(data[0]))

        return "'{}'B".format(bits[:data[1]])


class OctetString(Type):
    """OCTET STRING, written as upper-case hex digits in ``'...'H`` form."""

    def __init__(self, name):
        super(OctetString, self).__init__(name, 'OCTET STRING')

    def encode(self, data, _separator, _indent):
        hex_digits = format_bytes(data).upper()

        return "'{}'H".format(hex_digits)


class ObjectIdentifier(Type):
    """OBJECT IDENTIFIER, emitted exactly as given."""

    def __init__(self, name):
        super(ObjectIdentifier, self).__init__(name, 'OBJECT IDENTIFIER')

    def encode(self, data, _separator, _indent):
        # NOTE(review): presumably `data` is already a dotted string such as
        # '1.2.3'; it is returned unchanged.
        encoded = data

        return encoded


class Enumerated(Type):
    """ENUMERATED, written as the identifier of the enumeration value.

    `data_to_value` maps each accepted input value to the text to emit. With
    `numeric` set, it is enum_values_as_dict(values) as-is (presumably
    number -> name; TODO confirm). Otherwise it maps each name to itself.
    """

    def __init__(self, name, values, numeric):
        super(Enumerated, self).__init__(name, 'ENUMERATED')

        if numeric:
            self.data_to_value = enum_values_as_dict(values)
        else:
            self.data_to_value = {
                v: v for v in enum_values_as_dict(values).values()
            }

    def encode(self, data, _separator, _indent):
        """Return the identifier for `data`.

        Raises EncodeError for an unknown value, instead of leaking a bare
        KeyError, consistent with Choice.encode().
        """

        try:
            return self.data_to_value[data]
        except KeyError:
            valid = [str(key) for key in sorted(self.data_to_value)]

            raise EncodeError(
                "Expected enumeration value {}, but got '{}'.".format(
                    format_or(valid),
                    data))


class Sequence(MembersType):
    """SEQUENCE; encoding is inherited from MembersType."""

    def __init__(self, name, members):
        super(Sequence, self).__init__(name=name,
                                       members=members,
                                       type_name='SEQUENCE')


class SequenceOf(ArrayType):
    """SEQUENCE OF; encoding is inherited from ArrayType."""

    def __init__(self, name, element_type):
        super(SequenceOf, self).__init__(name=name,
                                         type_name='SEQUENCE OF',
                                         element_type=element_type)


class Set(MembersType):
    """SET; encoding is inherited from MembersType."""

    def __init__(self, name, members):
        super(Set, self).__init__(name=name,
                                  members=members,
                                  type_name='SET')


class SetOf(ArrayType):
    """SET OF; encoding is inherited from ArrayType."""

    def __init__(self, name, element_type):
        super(SetOf, self).__init__(name=name,
                                    type_name='SET OF',
                                    element_type=element_type)


class Choice(Type):
    """CHOICE, written as ``identifier : value``.

    The value is indexed as ``data[0]`` (the chosen member name) and
    ``data[1]`` (that member's value).
    """

    def __init__(self, name, members):
        super(Choice, self).__init__(name, 'CHOICE')
        self.members = members
        self.name_to_member = dict(
            (member.name, member) for member in self.members)

    def format_names(self):
        return format_or(sorted(member.name for member in self.members))

    def encode(self, data, separator, indent):
        try:
            member = self.name_to_member[data[0]]
        except KeyError:
            raise EncodeError(
                "Expected choice {}, but got '{}'.".format(
                    self.format_names(),
                    data[0]))

        try:
            encoded_value = member.encode(data[1], separator, indent)
        except ErrorWithLocation as error:
            # Record which alternative the failure came from.
            error.add_location(member)
            raise error

        return u'{} : {}'.format(data[0], encoded_value)

    def __repr__(self):
        members = ', '.join(repr(member) for member in self.members)

        return 'Choice({}, [{}])'.format(self.name, members)


class UTF8String(Type):
    """UTF8String, written as a double-quoted GSER StringValue."""

    def __init__(self, name):
        super(UTF8String, self).__init__(name, 'UTF8String')

    def encode(self, data, _separator, _indent):
        # Embedded double quotes must be doubled (RFC 3641 StringValue).
        # Otherwise a '"' in the value ends the string early.
        text = u'{}'.format(data)

        return u'"{}"'.format(text.replace(u'"', u'""'))


class NumericString(Type):
    """NumericString, written as a double-quoted GSER StringValue."""

    def __init__(self, name):
        super(NumericString, self).__init__(name, 'NumericString')

    def encode(self, data, _separator, _indent):
        # Embedded double quotes must be doubled (RFC 3641 StringValue).
        # Otherwise a '"' in the value ends the string early.
        text = u'{}'.format(data)

        return u'"{}"'.format(text.replace(u'"', u'""'))


class PrintableString(Type):
    """PrintableString, written as a double-quoted GSER StringValue."""

    def __init__(self, name):
        super(PrintableString, self).__init__(name, 'PrintableString')

    def encode(self, data, _separator, _indent):
        # Embedded double quotes must be doubled (RFC 3641 StringValue).
        # Otherwise a '"' in the value ends the string early.
        text = u'{}'.format(data)

        return u'"{}"'.format(text.replace(u'"', u'""'))


class IA5String(Type):
    """IA5String, written as a double-quoted GSER StringValue."""

    def __init__(self, name):
        super(IA5String, self).__init__(name, 'IA5String')

    def encode(self, data, _separator, _indent):
        # Embedded double quotes must be doubled (RFC 3641 StringValue).
        # Otherwise a '"' in the value ends the string early.
        text = u'{}'.format(data)

        return u'"{}"'.format(text.replace(u'"', u'""'))


class VisibleString(Type):
    """VisibleString, written as a double-quoted GSER StringValue."""

    def __init__(self, name):
        super(VisibleString, self).__init__(name, 'VisibleString')

    def encode(self, data, _separator, _indent):
        # Embedded double quotes must be doubled (RFC 3641 StringValue).
        # Otherwise a '"' in the value ends the string early.
        text = u'{}'.format(data)

        return u'"{}"'.format(text.replace(u'"', u'""'))


class GeneralString(Type):
    """GeneralString, written as a double-quoted GSER StringValue."""

    def __init__(self, name):
        super(GeneralString, self).__init__(name, 'GeneralString')

    def encode(self, data, _separator, _indent):
        # Embedded double quotes must be doubled (RFC 3641 StringValue).
        # Otherwise a '"' in the value ends the string early.
        text = u'{}'.format(data)

        return u'"{}"'.format(text.replace(u'"', u'""'))


class BMPString(Type):
    """BMPString, written as a double-quoted GSER StringValue."""

    def __init__(self, name):
        super(BMPString, self).__init__(name, 'BMPString')

    def encode(self, data, _separator, _indent):
        # Embedded double quotes must be doubled (RFC 3641 StringValue).
        # Otherwise a '"' in the value ends the string early.
        text = u'{}'.format(data)

        return u'"{}"'.format(text.replace(u'"', u'""'))


class GraphicString(Type):
    """GraphicString, written as a double-quoted GSER StringValue."""

    def __init__(self, name):
        super(GraphicString, self).__init__(name, 'GraphicString')

    def encode(self, data, _separator, _indent):
        # Embedded double quotes must be doubled (RFC 3641 StringValue).
        # Otherwise a '"' in the value ends the string early.
        text = u'{}'.format(data)

        return u'"{}"'.format(text.replace(u'"', u'""'))


class UniversalString(Type):
    """UniversalString, written as a double-quoted GSER StringValue."""

    def __init__(self, name):
        super(UniversalString, self).__init__(name, 'UniversalString')

    def encode(self, data, _separator, _indent):
        # Embedded double quotes must be doubled (RFC 3641 StringValue).
        # Otherwise a '"' in the value ends the string early.
        text = u'{}'.format(data)

        return u'"{}"'.format(text.replace(u'"', u'""'))


class TeletexString(Type):
    """TeletexString, written as a double-quoted GSER StringValue."""

    def __init__(self, name):
        super(TeletexString, self).__init__(name, 'TeletexString')

    def encode(self, data, _separator, _indent):
        # Embedded double quotes must be doubled (RFC 3641 StringValue).
        # Otherwise a '"' in the value ends the string early.
        text = u'{}'.format(data)

        return u'"{}"'.format(text.replace(u'"', u'""'))


class ObjectDescriptor(GraphicString):
    """ObjectDescriptor; encoded exactly like GraphicString."""


class UTCTime(Type):
    """UTCTime, written as the quoted text from utc_time_from_datetime()."""

    def __init__(self, name):
        super(UTCTime, self).__init__(name, 'UTCTime')

    def encode(self, data, _separator, _indent):
        time_text = utc_time_from_datetime(data)

        return u'"{}"'.format(time_text)


class GeneralizedTime(Type):
    """GeneralizedTime, written as the quoted text from
    generalized_time_from_datetime()."""

    def __init__(self, name):
        super(GeneralizedTime, self).__init__(name, 'GeneralizedTime')

    def encode(self, data, _separator, _indent):
        time_text = generalized_time_from_datetime(data)

        return u'"{}"'.format(time_text)


class Date(Type):
    """DATE, written as the quoted str() of the value."""

    def __init__(self, name):
        super(Date, self).__init__(name, 'DATE')

    def encode(self, data, _separator, _indent):
        date_text = str(data)

        return u'"{}"'.format(date_text)


class TimeOfDay(Type):
    """TIME-OF-DAY, written as the quoted str() of the value."""

    def __init__(self, name):
        super(TimeOfDay, self).__init__(name, 'TIME-OF-DAY')

    def encode(self, data, _separator, _indent):
        time_text = str(data)

        return u'"{}"'.format(time_text)


class DateTime(Type):
    """DATE-TIME, written as the quoted str() of the value, with every
    space replaced by 'T'."""

    def __init__(self, name):
        super(DateTime, self).__init__(name, 'DATE-TIME')

    def encode(self, data, _separator, _indent):
        # str(datetime) separates date and time with a space; use 'T'.
        date_time_text = 'T'.join(str(data).split(' '))

        return u'"{}"'.format(date_time_text)


class Any(Type):
    """ANY / ANY DEFINED BY, written as upper-case hex in ``'...'H`` form."""

    def __init__(self, name):
        super(Any, self).__init__(name, 'ANY')

    def encode(self, data, _separator, _indent):
        return "'{}'H".format(format_bytes(data).upper())


class Recursive(compiler.Recursive, Type):
    """Placeholder for a reference to a type that cannot be resolved yet.

    Compiler.compile_type() creates one when the referenced type name is in
    the compiler's types_backtrace (presumably a type currently being
    compiled, i.e. a recursive reference). Encoding is delegated to `inner`
    once set_inner_type() has been called.
    """

    def __init__(self, name, type_name, module_name):
        super(Recursive, self).__init__(name, 'RECURSIVE')
        # Overwrites whatever the base __init__ stored for the type name
        # (presumably 'RECURSIVE'; TODO confirm in BaseType) with the name
        # of the referenced type.
        self.type_name = type_name
        self.module_name = module_name
        # Resolved type; encode() raises AttributeError while this is None.
        self.inner = None

    def set_inner_type(self, inner):
        # Keep a shallow copy of the resolved type (presumably called by the
        # compiler once the referenced type is compiled).
        self.inner = copy(inner)

    def encode(self, data, separator, indent):
        return self.inner.encode(data, separator, indent)


class CompiledType(compiler.CompiledType):
    """Compiled top-level type that encodes values as GSER value assignments.

    The value reference name is the lower-cased type name, so a value of type
    ``Foo`` is written as ``foo Foo ::= <value>``.
    """

    def __init__(self, type_name, compiled_type):
        super(CompiledType, self).__init__(compiled_type)
        self._value_name = type_name.lower()
        self._value_type = type_name

    def encode(self, data, indent=None):
        """Return the UTF-8 encoded GSER value assignment for `data`.

        With `indent` None, everything is written on one line. Otherwise
        nested members start on new lines, indented by `indent` extra spaces
        per nesting level.
        """

        if indent is None:
            separator, step = ' ', 0
        else:
            separator, step = '\n', indent

        try:
            value = self._type.encode(data, separator, step)
        except ErrorWithLocation as e:
            # Record the top-level type as the outermost location.
            e.add_location(self._type)
            raise e

        assignment = u'{} {} ::= {}'.format(self._value_name,
                                            self._value_type,
                                            value.lstrip(' '))

        return assignment.encode('utf-8')

    def decode(self, data):
        raise NotImplementedError('GSER decoding is not implemented.')


class Compiler(compiler.Compiler):
    """Compiles parsed ASN.1 type descriptors into GSER type objects."""

    def process_type(self, type_name, type_descriptor, module_name):
        return CompiledType(type_name,
                            self.compile_type(type_name,
                                              type_descriptor,
                                              module_name))

    def compile_type(self, name, type_descriptor, module_name):
        """Return the GSER type object for `type_descriptor`.

        Built-in ASN.1 types map directly to classes in this module. Any other
        type name is a Recursive placeholder when it is in types_backtrace,
        and is otherwise compiled as a user-defined type.
        """

        module_name = self.get_module_name(type_descriptor, module_name)
        type_name = type_descriptor['type']

        # Built-in types whose constructor takes only the name.
        simple_types = {
            'INTEGER': Integer,
            'REAL': Real,
            'BOOLEAN': Boolean,
            'OBJECT IDENTIFIER': ObjectIdentifier,
            'OCTET STRING': OctetString,
            'TeletexString': TeletexString,
            'NumericString': NumericString,
            'PrintableString': PrintableString,
            'IA5String': IA5String,
            'VisibleString': VisibleString,
            'GeneralString': GeneralString,
            'UTF8String': UTF8String,
            'BMPString': BMPString,
            'GraphicString': GraphicString,
            'UTCTime': UTCTime,
            'UniversalString': UniversalString,
            'GeneralizedTime': GeneralizedTime,
            'DATE': Date,
            'TIME-OF-DAY': TimeOfDay,
            'DATE-TIME': DateTime,
            'BIT STRING': BitString,
            'ANY': Any,
            'ANY DEFINED BY': Any,
            'NULL': Null,
            'ObjectDescriptor': ObjectDescriptor,
        }
        # Constructed types built from a list of compiled members.
        members_types = {
            'SEQUENCE': Sequence,
            'SET': Set,
            'CHOICE': Choice,
        }
        # Constructed types built from a single compiled element type.
        element_types = {
            'SEQUENCE OF': SequenceOf,
            'SET OF': SetOf,
        }

        if type_name in simple_types:
            return simple_types[type_name](name)

        if type_name in members_types:
            members, _ = self.compile_members(type_descriptor['members'],
                                              module_name)

            return members_types[type_name](name, members)

        if type_name in element_types:
            element = self.compile_type('',
                                        type_descriptor['element'],
                                        module_name)

            return element_types[type_name](name, element)

        if type_name == 'ENUMERATED':
            return Enumerated(name,
                              self.get_enum_values(type_descriptor,
                                                   module_name),
                              self._numeric_enums)

        if type_name == 'EXTERNAL':
            members, _ = self.compile_members(
                self.external_type_descriptor()['members'],
                module_name)

            return Sequence(name, members)

        # Not a built-in type: a recursive reference or a user-defined type.
        if type_name in self.types_backtrace:
            placeholder = Recursive(name, type_name, module_name)
            self.recursive_types.append(placeholder)

            return placeholder

        return self.compile_user_type(name, type_name, module_name)


def compile_dict(specification, numeric_enums=False):
    """Compile a parsed ASN.1 specification for GSER encoding.

    Returns the result of Compiler.process() for this codec.
    """

    gser_compiler = Compiler(specification, numeric_enums)

    return gser_compiler.process()


def decode_full_length(_data):
    """Always raise DecodeError; this codec cannot decode lengths."""

    message = 'Decode length is not supported for this codec.'

    raise DecodeError(message)
