import itertools

import numpy as np
import operator

from numba.core import types, errors
from numba import prange
from numba.parfors.parfor import internal_prange

from numba.core.typing.templates import (AttributeTemplate, ConcreteTemplate,
                                         AbstractTemplate, infer_global, infer,
                                         infer_getattr, signature,
                                         bound_function, make_callable_template)

from numba.cpython.builtins import get_type_min_value, get_type_max_value

from numba.core.extending import (
    typeof_impl, type_callable, models, register_model, make_attribute_wrapper,
    )


@infer_global(print)
class Print(AbstractTemplate):
    """Typing template for the builtin ``print``.

    Each positional argument must resolve on its own against the
    ``"print_item"`` typing key; the call as a whole is typed as returning
    ``none``.  Keyword arguments are not inspected.
    """

    def generic(self, args, kws):
        for argty in args:
            item_sig = self.context.resolve_function_type(
                "print_item", (argty,), {})
            if item_sig is None:
                raise TypeError("Type %s is not printable." % argty)
            # Printing a single item is expected to be typed as `none`.
            assert item_sig.return_type is types.none
        return signature(types.none, *args)

@infer
class PrintItem(AbstractTemplate):
    """Typing template for the ``"print_item"`` key (one printed value)."""
    key = "print_item"

    def generic(self, args, kws):
        # Unpacking enforces that exactly one argument type was supplied.
        [item] = args
        return signature(types.none, item)


@infer_global(abs)
class Abs(ConcreteTemplate):
    """Typing template for the builtin ``abs``.

    Integer and real inputs keep their type; a complex input is typed as
    returning its ``underlying_float``.
    """
    int_cases = [signature(t, t) for t in sorted(types.signed_domain)]
    uint_cases = [signature(t, t) for t in sorted(types.unsigned_domain)]
    real_cases = [signature(t, t) for t in sorted(types.real_domain)]
    complex_cases = [
        signature(t.underlying_float, t)
        for t in sorted(types.complex_domain)
    ]
    cases = [*int_cases, *uint_cases, *real_cases, *complex_cases]


@infer_global(slice)
class Slice(ConcreteTemplate):
    """Typing template for the builtin ``slice`` constructor.

    One- and two-argument calls produce ``slice2_type``; three-argument
    calls produce ``slice3_type``.  Each argument is either ``intp`` or
    ``none``.  The order of the cases is kept exactly as listed.
    """
    cases = [
        signature(types.slice2_type, *argtys)
        for argtys in (
            (types.intp,),
            (types.none,),
            (types.none, types.none),
            (types.none, types.intp),
            (types.intp, types.none),
            (types.intp, types.intp),
        )
    ] + [
        signature(types.slice3_type, *argtys)
        for argtys in (
            (types.intp, types.intp, types.intp),
            (types.none, types.intp, types.intp),
            (types.intp, types.none, types.intp),
            (types.intp, types.intp, types.none),
            (types.intp, types.none, types.none),
            (types.none, types.intp, types.none),
            (types.none, types.none, types.intp),
            (types.none, types.none, types.none),
        )
    ]


@infer_global(range, typing_key=range)
@infer_global(prange, typing_key=prange)
@infer_global(internal_prange, typing_key=internal_prange)
class Range(ConcreteTemplate):
    """Typing template shared by ``range``, ``prange`` and
    ``internal_prange``.

    For each supported integer type (int32, int64, uint64) the one-, two-
    and three-argument forms map to the matching range-state type.
    """
    cases = [
        signature(state_ty, *((int_ty,) * nargs))
        for state_ty, int_ty in (
            (types.range_state32_type, types.int32),
            (types.range_state64_type, types.int64),
            (types.unsigned_range_state64_type, types.uint64),
        )
        for nargs in (1, 2, 3)
    ]


@infer
class GetIter(AbstractTemplate):
    """Typing template for the ``"getiter"`` key.

    An ``IterableType`` argument is typed as producing its
    ``iterator_type``; any other argument yields no signature.
    """
    key = "getiter"

    def generic(self, args, kws):
        assert not kws
        (iterable,) = args
        if not isinstance(iterable, types.IterableType):
            return None
        return signature(iterable.iterator_type, iterable)


@infer
class IterNext(AbstractTemplate):
    """Typing template for the ``"iternext"`` key.

    An ``IteratorType`` argument is typed as returning a
    ``Pair(yield_type, boolean)``; any other argument yields no signature.
    """
    key = "iternext"

    def generic(self, args, kws):
        assert not kws
        (iterator,) = args
        if not isinstance(iterator, types.IteratorType):
            return None
        result_ty = types.Pair(iterator.yield_type, types.boolean)
        return signature(result_ty, iterator)


@infer
class PairFirst(AbstractTemplate):
    """
    Type the extraction of the first member of a ``types.Pair``
    (typing key ``"pair_first"``).
    """
    key = "pair_first"

    def generic(self, args, kws):
        assert not kws
        (pair_ty,) = args
        if not isinstance(pair_ty, types.Pair):
            return None
        return signature(pair_ty.first_type, pair_ty)


@infer
class PairSecond(AbstractTemplate):
    """
    Type the extraction of the second member of a ``types.Pair``
    (typing key ``"pair_second"``).
    """
    key = "pair_second"

    def generic(self, args, kws):
        assert not kws
        (pair_ty,) = args
        if not isinstance(pair_ty, types.Pair):
            return None
        return signature(pair_ty.second_type, pair_ty)


def choose_result_bitwidth(*inputs):
    """
    Return the widest bitwidth among ``types.intp`` and all *inputs*.
    """
    input_widths = [tp.bitwidth for tp in inputs]
    return max(types.intp.bitwidth, *input_widths)

def choose_result_int(*inputs):
    """
    Pick the integer type resulting from an operation on integer *inputs*,
    following the integer typing NBEP.

    The width comes from ``choose_result_bitwidth``; the result is signed
    as soon as any input is signed.
    """
    bitwidth = choose_result_bitwidth(*inputs)
    signed = False
    for tp in inputs:
        if tp.signed:
            signed = True
            break
    return types.Integer.from_bitwidth(bitwidth, signed)


# The "machine" integer types to take into consideration for operator typing
# (according to the integer typing NBEP)
# Signed machine ints first, then unsigned; each group is de-duplicated
# (via a set) and sorted.
machine_ints = (
    sorted({types.intp, types.int64})
    + sorted({types.uintp, types.uint64})
)

# Explicit integer rules for binary operators; smaller ints will be
# automatically upcast.
# One signature per ordered pair of machine integer types; the result type
# is chosen by choose_result_int().
integer_binop_cases = tuple(
    signature(choose_result_int(lhs, rhs), lhs, rhs)
    for lhs in machine_ints
    for rhs in machine_ints
)


class BinOp(ConcreteTemplate):
    """Base template for arithmetic binary operators.

    Cases: the machine-integer pairs from ``integer_binop_cases``, followed
    by same-type real and same-type complex signatures.
    """
    cases = [
        *integer_binop_cases,
        *(signature(ty, ty, ty) for ty in sorted(types.real_domain)),
        *(signature(ty, ty, ty) for ty in sorted(types.complex_domain)),
    ]


@infer_global(operator.add)
class BinOpAdd(BinOp):
    """Typing for ``operator.add``; cases are inherited from ``BinOp``."""


@infer_global(operator.iadd)
class BinOpAdd(BinOp):
    """Typing for ``operator.iadd``; cases are inherited from ``BinOp``."""
    # NOTE: this class reuses the name of the ``operator.add`` template, so
    # the module-level name ``BinOpAdd`` ends up bound to this class.


@infer_global(operator.sub)
class BinOpSub(BinOp):
    """Typing for ``operator.sub``; cases are inherited from ``BinOp``."""


@infer_global(operator.isub)
class BinOpSub(BinOp):
    """Typing for ``operator.isub``; cases are inherited from ``BinOp``."""
    # NOTE: this class reuses the name of the ``operator.sub`` template, so
    # the module-level name ``BinOpSub`` ends up bound to this class.


@infer_global(operator.mul)
class BinOpMul(BinOp):
    """Typing for ``operator.mul``; cases are inherited from ``BinOp``."""


@infer_global(operator.imul)
class BinOpMul(BinOp):
    """Typing for ``operator.imul``; cases are inherited from ``BinOp``."""
    # NOTE: this class reuses the name of the ``operator.mul`` template, so
    # the module-level name ``BinOpMul`` ends up bound to this class.


@infer_global(operator.mod)
class BinOpMod(ConcreteTemplate):
    """Typing for ``operator.mod``: machine-integer pairs plus same-type
    reals (no complex cases)."""
    cases = [
        *integer_binop_cases,
        *(signature(ty, ty, ty) for ty in sorted(types.real_domain)),
    ]


@infer_global(operator.imod)
class BinOpMod(ConcreteTemplate):
    """Typing for ``operator.imod``: machine-integer pairs plus same-type
    reals (no complex cases)."""
    # NOTE: this class reuses the name of the ``operator.mod`` template, so
    # the module-level name ``BinOpMod`` ends up bound to this class.
    cases = [
        *integer_binop_cases,
        *(signature(ty, ty, ty) for ty in sorted(types.real_domain)),
    ]


@infer_global(operator.truediv)
class BinOpTrueDiv(ConcreteTemplate):
    """Typing for ``operator.truediv``.

    Any pair of machine integers divides to ``float64``; real and complex
    operands of the same type keep that type.
    """
    cases = [
        signature(types.float64, lhs, rhs)
        for lhs in machine_ints
        for rhs in machine_ints
    ]
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.real_domain))
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.complex_domain))


@infer_global(operator.itruediv)
class BinOpTrueDiv(ConcreteTemplate):
    """Typing for ``operator.itruediv``.

    Any pair of machine integers divides to ``float64``; real and complex
    operands of the same type keep that type.
    """
    # NOTE: this class reuses the name of the ``operator.truediv`` template,
    # so the module-level name ``BinOpTrueDiv`` ends up bound to this class.
    cases = [
        signature(types.float64, lhs, rhs)
        for lhs in machine_ints
        for rhs in machine_ints
    ]
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.real_domain))
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.complex_domain))


@infer_global(operator.floordiv)
class BinOpFloorDiv(ConcreteTemplate):
    """Typing for ``operator.floordiv``: machine-integer pairs plus
    same-type reals (no complex cases)."""
    cases = [
        *integer_binop_cases,
        *(signature(ty, ty, ty) for ty in sorted(types.real_domain)),
    ]


@infer_global(operator.ifloordiv)
class BinOpFloorDiv(ConcreteTemplate):
    """Typing for ``operator.ifloordiv``: machine-integer pairs plus
    same-type reals (no complex cases)."""
    # NOTE: this class reuses the name of the ``operator.floordiv``
    # template, so the module-level name ``BinOpFloorDiv`` ends up bound to
    # this class.
    cases = [
        *integer_binop_cases,
        *(signature(ty, ty, ty) for ty in sorted(types.real_domain)),
    ]


@infer_global(divmod)
class DivMod(ConcreteTemplate):
    """Typing for the builtin ``divmod``.

    Both operands must share one type (a machine integer or a real); the
    result is a homogeneous 2-tuple of that type.
    """
    _tys = [*machine_ints, *sorted(types.real_domain)]
    cases = [
        signature(types.UniTuple(operand, 2), operand, operand)
        for operand in _tys
    ]


@infer_global(operator.pow)
class BinOpPower(ConcreteTemplate):
    """Typing for ``operator.pow``.

    Cases, in order: machine-integer pairs, float ** integer for float32
    and float64, same-type reals, same-type complexes.
    """
    cases = [*integer_binop_cases]
    # Dedicated float ** int cases ensure that float32 ** int doesn't go
    # through double-precision computations.
    cases.extend(
        signature(base, base, exponent)
        for base in (types.float32, types.float64)
        for exponent in (types.int32, types.int64, types.uint64)
    )
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.real_domain))
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.complex_domain))


@infer_global(operator.ipow)
class BinOpPower(ConcreteTemplate):
    """Typing for ``operator.ipow``.

    Cases, in order: machine-integer pairs, float ** integer for float32
    and float64, same-type reals, same-type complexes.
    """
    # NOTE: this class reuses the name of the ``operator.pow`` template, so
    # the module-level name ``BinOpPower`` ends up bound to this class.
    cases = [*integer_binop_cases]
    # Dedicated float ** int cases ensure that float32 ** int doesn't go
    # through double-precision computations.
    cases.extend(
        signature(base, base, exponent)
        for base in (types.float32, types.float64)
        for exponent in (types.int32, types.int64, types.uint64)
    )
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.real_domain))
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.complex_domain))


@infer_global(pow)
class PowerBuiltin(BinOpPower):
    """Typing for the builtin ``pow``; cases are inherited from
    ``BinOpPower`` (two-operand form only)."""
    # TODO add 3 operand version


class BitwiseShiftOperation(ConcreteTemplate):
    """Base template for the left and right shift operators."""
    # Only the signedness of the first operand decides the signedness of
    # the operation.  The second operand should always be positive, but it
    # is generally typed as signed anyway because it is often a constant
    # integer.  (See also issue #1995 for right-shifts.)

    # The RHS type is fixed to a 64-bit signed or unsigned int.
    # The implementation will always cast the operands to the width of the
    # result type, which is the widest between the LHS type and (u)intp.
    cases = [
        signature(max(lhs, types.intp), lhs, rhs)
        for lhs in sorted(types.signed_domain)
        for rhs in (types.uint64, types.int64)
    ]
    cases.extend(
        signature(max(lhs, types.uintp), lhs, rhs)
        for lhs in sorted(types.unsigned_domain)
        for rhs in (types.uint64, types.int64)
    )
    unsafe_casting = False


@infer_global(operator.lshift)
class BitwiseLeftShift(BitwiseShiftOperation):
    """Typing for ``operator.lshift``; cases are inherited from
    ``BitwiseShiftOperation``."""

@infer_global(operator.ilshift)
class BitwiseLeftShift(BitwiseShiftOperation):
    """Typing for ``operator.ilshift``; cases are inherited from
    ``BitwiseShiftOperation``."""
    # NOTE: this class reuses the name of the ``operator.lshift`` template,
    # so the module-level name ``BitwiseLeftShift`` ends up bound to this
    # class.


@infer_global(operator.rshift)
class BitwiseRightShift(BitwiseShiftOperation):
    """Typing for ``operator.rshift``; cases are inherited from
    ``BitwiseShiftOperation``."""


@infer_global(operator.irshift)
class BitwiseRightShift(BitwiseShiftOperation):
    """Typing for ``operator.irshift``; cases are inherited from
    ``BitwiseShiftOperation``."""
    # NOTE: this class reuses the name of the ``operator.rshift`` template,
    # so the module-level name ``BitwiseRightShift`` ends up bound to this
    # class.


class BitwiseLogicOperation(BinOp):
    """Base template for ``&``, ``|`` and ``^``.

    Replaces the inherited cases with boolean-on-boolean followed by the
    machine-integer pairs, and disables unsafe casting.
    """
    cases = [
        signature(types.boolean, types.boolean, types.boolean),
        *integer_binop_cases,
    ]
    unsafe_casting = False


@infer_global(operator.and_)
class BitwiseAnd(BitwiseLogicOperation):
    """Typing for ``operator.and_``; cases are inherited from
    ``BitwiseLogicOperation``."""


@infer_global(operator.iand)
class BitwiseAnd(BitwiseLogicOperation):
    """Typing for ``operator.iand``; cases are inherited from
    ``BitwiseLogicOperation``."""
    # NOTE: this class reuses the name of the ``operator.and_`` template, so
    # the module-level name ``BitwiseAnd`` ends up bound to this class.


@infer_global(operator.or_)
class BitwiseOr(BitwiseLogicOperation):
    """Typing for ``operator.or_``; cases are inherited from
    ``BitwiseLogicOperation``."""


@infer_global(operator.ior)
class BitwiseOr(BitwiseLogicOperation):
    """Typing for ``operator.ior``; cases are inherited from
    ``BitwiseLogicOperation``."""
    # NOTE: this class reuses the name of the ``operator.or_`` template, so
    # the module-level name ``BitwiseOr`` ends up bound to this class.


@infer_global(operator.xor)
class BitwiseXor(BitwiseLogicOperation):
    """Typing for ``operator.xor``; cases are inherited from
    ``BitwiseLogicOperation``."""


@infer_global(operator.ixor)
class BitwiseXor(BitwiseLogicOperation):
    """Typing for ``operator.ixor``; cases are inherited from
    ``BitwiseLogicOperation``."""
    # NOTE: this class reuses the name of the ``operator.xor`` template, so
    # the module-level name ``BitwiseXor`` ends up bound to this class.


# Bitwise invert and negate are special: we must not upcast the operand
# for unsigned numbers, as that would change the result.
# (e.g. ~np.uint8(0) == 255 but ~np.uint32(0) == 4294967295).

@infer_global(operator.invert)
class BitwiseInvert(ConcreteTemplate):
    """Typing for ``operator.invert``.

    Cases, in order: boolean -> boolean, then every unsigned and every
    signed integer type mapped through ``choose_result_int``.
    """
    # Inverting a boolean is typed as a boolean: Numba follows the Numpy
    # semantics here, whereas Python returns an int.  This keeps the
    # operator consistent with np.invert() and makes array expressions
    # correct.
    cases = [
        signature(types.boolean, types.boolean),
        *(signature(choose_result_int(ty), ty)
          for ty in sorted(types.unsigned_domain)),
        *(signature(choose_result_int(ty), ty)
          for ty in sorted(types.signed_domain)),
    ]

    unsafe_casting = False


class UnaryOp(ConcreteTemplate):
    """Base template for unary ``+`` and ``-``.

    Integers map through ``choose_result_int`` (unsigned first, then
    signed); reals and complexes keep their type; a boolean operand is
    typed as returning ``intp``.
    """
    cases = [
        *(signature(choose_result_int(ty), ty)
          for ty in sorted(types.unsigned_domain)),
        *(signature(choose_result_int(ty), ty)
          for ty in sorted(types.signed_domain)),
        *(signature(ty, ty) for ty in sorted(types.real_domain)),
        *(signature(ty, ty) for ty in sorted(types.complex_domain)),
        signature(types.intp, types.boolean),
    ]


@infer_global(operator.neg)
class UnaryNegate(UnaryOp):
    """Typing for ``operator.neg``; cases are inherited from ``UnaryOp``."""


@infer_global(operator.pos)
class UnaryPositive(UnaryOp):
    """Typing for ``operator.pos``; cases are inherited from ``UnaryOp``."""


@infer_global(operator.not_)
class UnaryNot(ConcreteTemplate):
    """Typing for ``operator.not_``: every supported scalar operand
    (boolean, signed, unsigned, real, complex) is typed as returning a
    boolean."""
    cases = [signature(types.boolean, types.boolean)] + [
        signature(types.boolean, operand)
        for domain in (
            types.signed_domain,
            types.unsigned_domain,
            types.real_domain,
            types.complex_domain,
        )
        for operand in sorted(domain)
    ]


class OrderedCmpOp(ConcreteTemplate):
    """Base template for ordering comparisons.

    Both operands share one type (boolean, signed, unsigned or real) and
    the result is a boolean.  Complex types are not included.
    """
    cases = [signature(types.boolean, types.boolean, types.boolean)] + [
        signature(types.boolean, operand, operand)
        for domain in (
            types.signed_domain,
            types.unsigned_domain,
            types.real_domain,
        )
        for operand in sorted(domain)
    ]


class UnorderedCmpOp(ConcreteTemplate):
    """Base template for equality comparisons: the ``OrderedCmpOp`` cases
    followed by same-type complex comparisons."""
    cases = [
        *OrderedCmpOp.cases,
        *(signature(types.boolean, ty, ty)
          for ty in sorted(types.complex_domain)),
    ]


@infer_global(operator.lt)
class CmpOpLt(OrderedCmpOp):
    """Typing for ``operator.lt``; cases are inherited from
    ``OrderedCmpOp``."""


@infer_global(operator.le)
class CmpOpLe(OrderedCmpOp):
    """Typing for ``operator.le``; cases are inherited from
    ``OrderedCmpOp``."""


@infer_global(operator.gt)
class CmpOpGt(OrderedCmpOp):
    """Typing for ``operator.gt``; cases are inherited from
    ``OrderedCmpOp``."""


@infer_global(operator.ge)
class CmpOpGe(OrderedCmpOp):
    """Typing for ``operator.ge``; cases are inherited from
    ``OrderedCmpOp``."""


# more specific overloads should be registered first
@infer_global(operator.eq)
class ConstOpEq(AbstractTemplate):
    """Typing for ``operator.eq`` when both operands are literal types;
    the comparison is typed as returning a boolean."""

    def generic(self, args, kws):
        assert not kws
        lhs, rhs = args
        if not isinstance(lhs, types.Literal):
            return None
        if not isinstance(rhs, types.Literal):
            return None
        return signature(types.boolean, lhs, rhs)


@infer_global(operator.ne)
class ConstOpNotEq(ConstOpEq):
    """Typing for ``operator.ne`` on two literal operands; the logic is
    inherited from ``ConstOpEq``."""


@infer_global(operator.eq)
class CmpOpEq(UnorderedCmpOp):
    """Typing for ``operator.eq`` on scalars; cases are inherited from
    ``UnorderedCmpOp``."""


@infer_global(operator.ne)
class CmpOpNe(UnorderedCmpOp):
    """Typing for ``operator.ne`` on scalars; cases are inherited from
    ``UnorderedCmpOp``."""


class TupleCompare(AbstractTemplate):
    """Base template for comparing two tuples.

    The comparison is accepted (and typed as boolean) only if every pair
    of corresponding element types can itself be compared with the same
    operator (``self.key``).
    """

    def generic(self, args, kws):
        lhs, rhs = args
        if not isinstance(lhs, types.BaseTuple):
            return None
        if not isinstance(rhs, types.BaseTuple):
            return None
        for left_elem, right_elem in zip(lhs, rhs):
            # Element-wise comparability check; stop at the first failure.
            elem_sig = self.context.resolve_function_type(
                self.key, (left_elem, right_elem), {})
            if elem_sig is None:
                return None
        return signature(types.boolean, lhs, rhs)


@infer_global(operator.eq)
class TupleEq(TupleCompare):
    """Typing for ``operator.eq`` on tuples (see ``TupleCompare``)."""


@infer_global(operator.ne)
class TupleNe(TupleCompare):
    """Typing for ``operator.ne`` on tuples (see ``TupleCompare``)."""


@infer_global(operator.ge)
class TupleGe(TupleCompare):
    """Typing for ``operator.ge`` on tuples (see ``TupleCompare``)."""


@infer_global(operator.gt)
class TupleGt(TupleCompare):
    """Typing for ``operator.gt`` on tuples (see ``TupleCompare``)."""


@infer_global(operator.le)
class TupleLe(TupleCompare):
    """Typing for ``operator.le`` on tuples (see ``TupleCompare``)."""


@infer_global(operator.lt)
class TupleLt(TupleCompare):
    """Typing for ``operator.lt`` on tuples (see ``TupleCompare``)."""


@infer_global(operator.add)
class TupleAdd(AbstractTemplate):
    """Typing for ``operator.add`` on two plain tuples.

    The result is the tuple type built from the element types of both
    operands, in order.  Named tuples are not handled by this template.
    """

    def generic(self, args, kws):
        if len(args) != 2:
            return None
        left, right = args
        for operand in (left, right):
            if not isinstance(operand, types.BaseTuple):
                return None
        for operand in (left, right):
            if isinstance(operand, types.BaseNamedTuple):
                return None
        joined = types.BaseTuple.from_types(tuple(left) + tuple(right))
        return signature(joined, left, right)


class CmpOpIdentity(AbstractTemplate):
    """Base template for identity comparisons: any two operand types are
    accepted and the result is typed as boolean."""

    def generic(self, args, kws):
        # Unpacking enforces that exactly two operands were supplied.
        left, right = args
        return signature(types.boolean, left, right)


@infer_global(operator.is_)
class CmpOpIs(CmpOpIdentity):
    """Typing for ``operator.is_`` (see ``CmpOpIdentity``)."""


@infer_global(operator.is_not)
class CmpOpIsNot(CmpOpIdentity):
    """Typing for ``operator.is_not`` (see ``CmpOpIdentity``)."""


def normalize_1d_index(index):
    """
    Normalize the *index* type used to index a 1D sequence.

    A slice type is returned unchanged.  An integer type is widened to
    ``intp`` when signed and ``uintp`` when unsigned.  Any other type
    yields ``None``.
    """
    if isinstance(index, types.SliceType):
        return index

    if isinstance(index, types.Integer):
        if index.signed:
            return types.intp
        return types.uintp

    return None


@infer_global(operator.getitem)
class GetItemCPointer(AbstractTemplate):
    """Typing for ``operator.getitem`` on a ``CPointer`` with an integer
    index: the result is the pointer's ``dtype``."""

    def generic(self, args, kws):
        assert not kws
        pointer, index = args
        if not isinstance(pointer, types.CPointer):
            return None
        if not isinstance(index, types.Integer):
            return None
        return signature(pointer.dtype, pointer, normalize_1d_index(index))


@infer_global(operator.setitem)
class SetItemCPointer(AbstractTemplate):
    """Typing for ``operator.setitem`` on a ``CPointer`` with an integer
    index: the stored value is typed as the pointer's ``dtype`` and the
    call returns ``none``."""

    def generic(self, args, kws):
        assert not kws
        pointer, index, _value = args
        if not isinstance(pointer, types.CPointer):
            return None
        if not isinstance(index, types.Integer):
            return None
        index_ty = normalize_1d_index(index)
        return signature(types.none, pointer, index_ty, pointer.dtype)


@infer_global(len)
class Len(AbstractTemplate):
    """Typing for the builtin ``len``.

    Buffers and tuples are typed as returning ``intp``; a range type is
    typed as returning its own ``dtype``.
    """

    def generic(self, args, kws):
        assert not kws
        [container] = args
        if isinstance(container, (types.Buffer, types.BaseTuple)):
            return signature(types.intp, container)
        if isinstance(container, types.RangeType):
            return signature(container.dtype, container)
        return None

@infer_global(tuple)
class TupleConstructor(AbstractTemplate):
    """Typing for the builtin ``tuple`` constructor.

    Handles ``tuple()`` (empty tuple) and ``tuple(t)`` where ``t`` is
    already a tuple type, in which case the type is returned unchanged.
    """

    def generic(self, args, kws):
        assert not kws
        if not args:
            # empty tuple case
            return signature(types.Tuple(()))
        [source] = args
        if not isinstance(source, types.BaseTuple):
            return None
        # tuple as input
        return signature(source, source)


@infer_global(operator.contains)
class Contains(AbstractTemplate):
    """Typing for ``operator.contains`` when the container is a
    ``types.Sequence``; the result is typed as boolean."""

    def generic(self, args, kws):
        assert not kws
        container, item = args

        if not isinstance(container, types.Sequence):
            return None
        return signature(types.boolean, container, item)

@infer_global(operator.truth)
class TupleBool(AbstractTemplate):
    """Typing for ``operator.truth`` on a tuple; the result is typed as
    boolean."""

    def generic(self, args, kws):
        assert not kws
        [operand] = args
        if not isinstance(operand, types.BaseTuple):
            return None
        return signature(types.boolean, operand)


@infer
class StaticGetItemTuple(AbstractTemplate):
    """Typing for ``"static_getitem"`` on tuples.

    A constant integer index selects one element type (out-of-range raises
    ``NumbaIndexError``); a constant slice produces the tuple type of the
    selected element types.
    """
    key = "static_getitem"

    def generic(self, args, kws):
        container, index = args
        if not isinstance(container, types.BaseTuple):
            return None
        if isinstance(index, int):
            try:
                result = container.types[index]
            except IndexError:
                raise errors.NumbaIndexError("tuple index out of range")
        elif isinstance(index, slice):
            result = types.BaseTuple.from_types(container.types[index])
        else:
            return None
        if result is None:
            return None
        return signature(result, *args)


@infer
class StaticGetItemLiteralList(AbstractTemplate):
    """Typing for ``"static_getitem"`` on a ``LiteralList`` with a constant
    integer index: the result is the type stored at that position."""
    key = "static_getitem"

    def generic(self, args, kws):
        container, index = args
        if not isinstance(container, types.LiteralList):
            return None
        if not isinstance(index, int):
            return None
        result = container.types[index]
        if result is None:
            return None
        return signature(result, *args)


@infer
class StaticGetItemLiteralStrKeyDict(AbstractTemplate):
    """Typing for ``"static_getitem"`` on a ``LiteralStrKeyDict`` with a
    constant string key.

    The key's position in ``fields`` selects the result from ``types``; an
    unknown key raises ``NumbaKeyError``.
    """
    key = "static_getitem"

    def generic(self, args, kws):
        container, idx = args
        if not isinstance(container, types.LiteralStrKeyDict):
            return None
        if not isinstance(idx, str):
            return None
        if idx not in container.fields:
            raise errors.NumbaKeyError(f"Key '{idx}' is not in dict.")
        position = container.fields.index(idx)
        result = container.types[position]
        if result is None:
            return None
        return signature(result, *args)

@infer
class StaticGetItemClass(AbstractTemplate):
    """Typing for ``"static_getitem"`` when a Numba type itself is
    subscripted, e.g.:

        var = typed.List.empty_list(float64[::1, :])

    Only simple numerical types (``NumberClass``) are accepted; compound
    types such as records are not supported.
    """
    key = "static_getitem"

    def generic(self, args, kws):
        number_class, index = args
        if not isinstance(number_class, types.NumberClass):
            return None
        # Subscripting the underlying dtype builds the resulting type.
        subscripted = number_class.dtype[index]
        return signature(subscripted, *args)


# Generic implementation for "not in"

@infer
class GenericNotIn(AbstractTemplate):
    """Typing for the ``"not in"`` key, expressed through
    ``operator.contains``.

    The argument order is reversed before delegating to the
    ``operator.contains`` typing, and reversed back in the returned
    signature so that it matches the order the caller supplied.
    """
    key = "not in"

    def generic(self, args, kws):
        swapped = args[::-1]
        sig = self.context.resolve_function_type(operator.contains, swapped,
                                                 kws)
        if sig is None:
            # No ``operator.contains`` typing matched: report "no match"
            # instead of failing on ``None.return_type``.
            return None
        return signature(sig.return_type, *sig.args[::-1])


#-------------------------------------------------------------------------------

@infer_getattr
class MemoryViewAttribute(AttributeTemplate):
    """Attribute typing for ``types.MemoryView`` values."""
    key = types.MemoryView

    # --- boolean-typed attributes ---

    def resolve_contiguous(self, buf):
        return types.boolean

    def resolve_c_contiguous(self, buf):
        return types.boolean

    def resolve_f_contiguous(self, buf):
        return types.boolean

    def resolve_readonly(self, buf):
        return types.boolean

    # --- intp-typed attributes ---

    def resolve_itemsize(self, buf):
        return types.intp

    def resolve_nbytes(self, buf):
        return types.intp

    def resolve_ndim(self, buf):
        return types.intp

    # --- per-dimension attributes: one intp per dimension of the buffer ---

    def resolve_shape(self, buf):
        ndim = buf.ndim
        return types.UniTuple(types.intp, ndim)

    def resolve_strides(self, buf):
        ndim = buf.ndim
        return types.UniTuple(types.intp, ndim)


#-------------------------------------------------------------------------------


@infer_getattr
class BooleanAttribute(AttributeTemplate):
    """Attribute typing for ``types.Boolean`` values."""
    key = types.Boolean

    def resolve___class__(self, ty):
        """``__class__`` is typed as the ``NumberClass`` wrapping *ty*."""
        return types.NumberClass(ty)

    @bound_function("number.item")
    def resolve_item(self, ty, args, kws):
        """``.item()`` with no arguments is typed as returning *ty*."""
        assert not kws
        if args:
            return None
        return signature(ty)


@infer_getattr
class NumberAttribute(AttributeTemplate):
    """Attribute typing for ``types.Number`` values."""
    key = types.Number

    def resolve___class__(self, ty):
        """``__class__`` is typed as the ``NumberClass`` wrapping *ty*."""
        return types.NumberClass(ty)

    def resolve_real(self, ty):
        """``.real`` is the type's ``underlying_float`` when it has one,
        otherwise *ty* itself."""
        try:
            return ty.underlying_float
        except AttributeError:
            return ty

    def resolve_imag(self, ty):
        """``.imag`` is the type's ``underlying_float`` when it has one,
        otherwise *ty* itself."""
        try:
            return ty.underlying_float
        except AttributeError:
            return ty

    @bound_function("complex.conjugate")
    def resolve_conjugate(self, ty, args, kws):
        """``.conjugate()`` takes no arguments and is typed as returning
        *ty*."""
        assert not args
        assert not kws
        return signature(ty)

    @bound_function("number.item")
    def resolve_item(self, ty, args, kws):
        """``.item()`` with no arguments is typed as returning *ty*."""
        assert not kws
        if args:
            return None
        return signature(ty)


@infer_getattr
class NPTimedeltaAttribute(AttributeTemplate):
    """Attribute typing for ``types.NPTimedelta`` values."""
    key = types.NPTimedelta

    def resolve___class__(self, ty):
        """``__class__`` is typed as the ``NumberClass`` wrapping *ty*."""
        class_ty = types.NumberClass(ty)
        return class_ty


@infer_getattr
class NPDatetimeAttribute(AttributeTemplate):
    """Attribute typing for ``types.NPDatetime`` values."""
    key = types.NPDatetime

    def resolve___class__(self, ty):
        """``__class__`` is typed as the ``NumberClass`` wrapping *ty*."""
        class_ty = types.NumberClass(ty)
        return class_ty


@infer_getattr
class SliceAttribute(AttributeTemplate):
    """Attribute typing for ``types.SliceType`` values."""
    key = types.SliceType

    def resolve_start(self, ty):
        return types.intp

    def resolve_stop(self, ty):
        return types.intp

    def resolve_step(self, ty):
        return types.intp

    @bound_function("slice.indices")
    def resolve_indices(self, ty, args, kws):
        """Type ``slice.indices(length)``.

        Exactly one integer argument is required; the result is a 3-tuple
        of ``intp`` and the argument itself is typed as ``intp``.
        Raises ``NumbaTypeError`` on a wrong argument count or a
        non-integer argument.
        """
        assert not kws
        nargs = len(args)
        if nargs != 1:
            raise errors.NumbaTypeError(
                "indices() takes exactly one argument (%d given)" % nargs
            )
        [length_ty] = args
        if isinstance(length_ty, types.Integer):
            return signature(types.UniTuple(types.intp, 3), types.intp)
        raise errors.NumbaTypeError(
            "'%s' object cannot be interpreted as an integer" % length_ty
        )


#-------------------------------------------------------------------------------


@infer_getattr
class NumberClassAttribute(AttributeTemplate):
    """Attribute typing for ``types.NumberClass`` values."""
    key = types.NumberClass

    def resolve___call__(self, classty):
        """
        Type the constructor call of a NumPy number class
        (e.g. calling numpy.int32(...)).
        """
        ty = classty.instance_type

        def typer(val):
            # Array constructor, e.g. np.int32([1, 2]): delegate to the
            # typing of np.array with an explicit dtype.
            if isinstance(val, (types.BaseTuple, types.Sequence)):
                array_fnty = self.context.resolve_value_type(np.array)
                call_sig = array_fnty.get_call_type(
                    self.context, (val, types.DType(ty)), {})
                return call_sig.return_type

            # Scalar constructor, e.g. np.int32(42)
            if isinstance(val, (types.Number, types.Boolean,
                                types.IntEnumMember)):
                return ty

            # Constructor cast from datetime-like, e.g.
            # > np.int64(np.datetime64("2000-01-01"))
            if isinstance(val, (types.NPDatetime, types.NPTimedelta)):
                if ty.bitwidth != 64:
                    msg = (f"Cannot cast {val} to {ty} as {ty} is not 64 bits "
                           "wide.")
                    raise errors.TypingError(msg)
                return ty

            is_array = isinstance(val, types.Array)
            if is_array and val.ndim == 0 and val.dtype == ty:
                # This is 0d array -> scalar degrading
                return ty

            # Everything else is unsupported.
            msg = f"Casting {val} to {ty} directly is unsupported."
            if is_array:
                # array casts are supported a different way.
                msg += f" Try doing '<array>.astype(np.{ty})' instead"
            raise errors.TypingError(msg)

        return types.Function(make_callable_template(key=ty, typer=typer))


@infer_getattr
class TypeRefAttribute(AttributeTemplate):
    """Attribute typing for ``types.TypeRef`` values."""
    key = types.TypeRef

    def resolve___call__(self, classty):
        """
        Type the call of a core number's constructor (e.g. calling int(...)).

        Note:

        This exists because of a limitation of the current type-system
        implementation: there is no higher-order type, i.e. no way of
        passing ``DictType`` itself as opposed to
        ``DictType(key_type, value_type)``.
        """
        ty = classty.instance_type

        if not (isinstance(ty, type) and issubclass(ty, types.Type)):
            return None

        # The typing is redirected to whatever was registered through:
        #   @type_callable(ty)
        #   def typeddict_call(context):
        #        ...
        # For example, see numba/typed/typeddict.py
        #   @type_callable(DictType)
        #   def typeddict_call(context):
        class Redirect:

            def __init__(self, context):
                self.context = context

            def __call__(self, *args, **kwargs):
                resolved = self.context.resolve_function_type(ty, args,
                                                              kwargs)
                # Carry over the Python signature when the resolved
                # signature provides one.
                if hasattr(resolved, "pysig"):
                    self.pysig = resolved.pysig
                return resolved

        redirect = Redirect(self.context)
        return types.Function(make_callable_template(key=ty, typer=redirect))


#------------------------------------------------------------------------------


class MinMaxBase(AbstractTemplate):
    """Shared typing logic for the builtins ``min`` and ``max``."""

    def _unify_minmax(self, tys):
        """Unify *tys* into one type, or return None if any of them is not
        a number, datetime or timedelta type."""
        accepted = (types.Number, types.NPDatetime, types.NPTimedelta)
        if not all(isinstance(ty, accepted) for ty in tys):
            return None
        return self.context.unify_types(*tys)

    def generic(self, args, kws):
        """
        Resolve a min() or max() call.

        Supports either several positional arguments or a single tuple
        argument.  Raises ``TypeError`` for a single empty tuple.
        """
        assert not kws

        if not args:
            return None

        if len(args) > 1:
            # max(*args)
            candidates = args
        elif isinstance(args[0], types.BaseTuple):
            # max(arg) only supported if arg is an iterable
            candidates = list(args[0])
            if not candidates:
                raise TypeError("%s() argument is an empty tuple"
                                % (self.key.__name__,))
        else:
            return None

        unified = self._unify_minmax(candidates)
        if unified is None:
            return None
        return signature(unified, *args)


@infer_global(max)
class Max(MinMaxBase):
    """Typing for the builtin ``max`` (see ``MinMaxBase``)."""


@infer_global(min)
class Min(MinMaxBase):
    """Typing for the builtin ``min`` (see ``MinMaxBase``)."""


@infer_global(round)
class Round(ConcreteTemplate):
    """Typing for the builtin ``round`` on float32 and float64."""
    cases = [
        # round(x): float32 -> intp, float64 -> int64
        signature(types.intp, types.float32),
        signature(types.int64, types.float64),
        # round(x, ndigits): the float type is preserved, ndigits is intp
        *(signature(flt, flt, types.intp)
          for flt in (types.float32, types.float64)),
    ]


#------------------------------------------------------------------------------


@infer_global(bool)
class Bool(AbstractTemplate):
    """Typing for the builtin ``bool``."""

    def generic(self, args, kws):
        assert not kws
        (operand,) = args
        if isinstance(operand, (types.Boolean, types.Number)):
            return signature(types.boolean, operand)
        # XXX bool cannot be typed polymorphically because of the
        # types.Function thing, so every other operand is redirected to the
        # operator.truth intrinsic.
        return self.context.resolve_function_type(operator.truth, args, kws)


@infer_global(int)
class Int(AbstractTemplate):
    """Typing for the builtin ``int``.

    Integers keep their type; floats and booleans become ``intp``;
    ``datetime64[ns]`` and timedeltas become ``int64``.  Datetimes in any
    other unit raise ``NumbaTypeError``.
    """

    def generic(self, args, kws):
        if kws:
            raise errors.NumbaAssertionError('kws not supported')

        (arg,) = args

        if isinstance(arg, types.Integer):
            return signature(arg, arg)

        if isinstance(arg, (types.Float, types.Boolean)):
            return signature(types.intp, arg)

        if isinstance(arg, types.NPDatetime):
            if arg.unit != 'ns':
                raise errors.NumbaTypeError(f"Only datetime64[ns] can be converted, but got datetime64[{arg.unit}]")
            return signature(types.int64, arg)

        if isinstance(arg, types.NPTimedelta):
            return signature(types.int64, arg)

        return None


@infer_global(float)
class Float(AbstractTemplate):
    """Typing for the builtin ``float``.

    String literals and integers become ``float64``; real types keep their
    type.  A non-literal unicode argument raises ``RequireLiteralValue``;
    non-numbers and complex numbers raise ``NumbaTypeError``.
    """

    def generic(self, args, kws):
        assert not kws

        (arg,) = args

        if isinstance(arg, types.UnicodeType):
            raise errors.RequireLiteralValue(
                'argument must be a string literal')

        if isinstance(arg, types.StringLiteral):
            return signature(types.float64, arg)

        if arg not in types.number_domain:
            raise errors.NumbaTypeError("float() only support for numbers")

        if arg in types.complex_domain:
            raise errors.NumbaTypeError("float() does not support complex")

        if arg in types.integer_domain:
            return signature(types.float64, arg)

        if arg in types.real_domain:
            return signature(arg, arg)

        return None


@infer_global(complex)
class Complex(AbstractTemplate):
    """Typing for the builtin ``complex`` with one or two numeric
    arguments.

    The result is ``complex64`` only when every argument is ``float32``;
    otherwise it is ``complex128``.  Non-numeric arguments raise
    ``NumbaTypeError``.
    """

    def generic(self, args, kws):
        assert not kws

        nargs = len(args)

        if nargs == 1:
            (arg,) = args
            if arg not in types.number_domain:
                raise errors.NumbaTypeError(
                    "complex() only support for numbers")
            if arg == types.float32:
                result = types.complex64
            else:
                result = types.complex128
            return signature(result, arg)

        if nargs == 2:
            real, imag = args
            both_numbers = (real in types.number_domain and
                            imag in types.number_domain)
            if not both_numbers:
                raise errors.NumbaTypeError(
                    "complex() only support for numbers")
            if real == imag == types.float32:
                result = types.complex64
            else:
                result = types.complex128
            return signature(result, real, imag)

        return None


#------------------------------------------------------------------------------

@infer_global(enumerate)
class Enumerate(AbstractTemplate):
    """Typing for the builtin ``enumerate``.

    The first argument must be an iterable type; an optional second
    argument (the start value) must be an integer type.
    """

    def generic(self, args, kws):
        assert not kws
        iterable = args[0]
        nargs = len(args)
        if nargs > 1 and not isinstance(args[1], types.Integer):
            raise errors.NumbaTypeError("Only integers supported as start "
                                        "value in enumerate")
        elif nargs > 2:
            # Too many arguments: call the real builtin so that Python
            # raises its own error.
            enumerate(*args)

        if not isinstance(iterable, types.IterableType):
            return None
        return signature(types.EnumerateType(iterable), *args)


@infer_global(zip)
class Zip(AbstractTemplate):
    """Typing for the builtin ``zip``: accepted only when every argument
    is an iterable type."""

    def generic(self, args, kws):
        assert not kws
        for candidate in args:
            if not isinstance(candidate, types.IterableType):
                return None
        return signature(types.ZipType(args), *args)


@infer_global(iter)
class Iter(AbstractTemplate):
    """Typing for the one-argument builtin ``iter``: an iterable type is
    typed as returning its ``iterator_type``."""

    def generic(self, args, kws):
        assert not kws
        if len(args) != 1:
            return None
        iterable = args[0]
        if not isinstance(iterable, types.IterableType):
            return None
        return signature(iterable.iterator_type, *args)


@infer_global(next)
class Next(AbstractTemplate):
    """Typing for the one-argument builtin ``next``: an iterator type is
    typed as returning its ``yield_type``."""

    def generic(self, args, kws):
        assert not kws
        if len(args) != 1:
            return None
        iterator = args[0]
        if not isinstance(iterator, types.IteratorType):
            return None
        return signature(iterator.yield_type, *args)


#------------------------------------------------------------------------------

@infer_global(type)
class TypeBuiltin(AbstractTemplate):
    """Typing for the one-argument builtin ``type``: the result is the
    type of the argument's ``__class__`` attribute."""

    def generic(self, args, kws):
        assert not kws
        if len(args) != 1:
            return None
        # One-argument type() -> return the __class__.
        # The lookup is done on the non-literal version of the argument.
        plain_arg = types.unliteral(args[0])
        class_ty = self.context.resolve_getattr(plain_arg, "__class__")
        if class_ty is None:
            return None
        return signature(class_ty, *args)


#------------------------------------------------------------------------------

@infer_getattr
class OptionalAttribute(AttributeTemplate):
    """Attribute typing for ``types.Optional``: every attribute lookup is
    forwarded to the wrapped type."""
    key = types.Optional

    def generic_resolve(self, optional, attr):
        wrapped = optional.type
        return self.context.resolve_getattr(wrapped, attr)

#------------------------------------------------------------------------------

@infer_getattr
class DeferredAttribute(AttributeTemplate):
    """Attribute typing for ``types.DeferredType``: every attribute lookup
    is forwarded to the type returned by ``deferred.get()``."""
    key = types.DeferredType

    def generic_resolve(self, deferred, attr):
        target = deferred.get()
        return self.context.resolve_getattr(target, attr)

#------------------------------------------------------------------------------

@infer_global(get_type_min_value)
@infer_global(get_type_max_value)
class MinValInfer(AbstractTemplate):
    """Typing shared by ``get_type_min_value`` and ``get_type_max_value``:
    given a ``DType`` or ``NumberClass``, the result is its ``dtype``."""

    def generic(self, args, kws):
        assert not kws
        assert len(args) == 1
        type_arg = args[0]
        if not isinstance(type_arg, (types.DType, types.NumberClass)):
            return None
        return signature(type_arg.dtype, *args)


#------------------------------------------------------------------------------


class IndexValue(object):
    """
    Pair of an index and the value found at that index.

    Attributes are ``index`` and ``value``, set from the constructor
    arguments *ind* and *val*.
    """
    def __init__(self, ind, val):
        self.index = ind
        self.value = val

    def __repr__(self):
        # Numeric members keep the historical fixed-point rendering.  '%f'
        # raises TypeError for anything that is not a real number (e.g. a
        # complex value), so fall back to plain reprs rather than letting
        # repr() itself fail.
        try:
            return 'IndexValue(%f, %f)' % (self.index, self.value)
        except TypeError:
            return 'IndexValue(%r, %r)' % (self.index, self.value)


class IndexValueType(types.Type):
    """Numba type of an ``IndexValue``, parameterized by the type of its
    value (*val_typ*)."""

    def __init__(self, val_typ):
        self.val_typ = val_typ
        type_name = 'IndexValueType({})'.format(val_typ)
        super().__init__(name=type_name)


@typeof_impl.register(IndexValue)
def typeof_index(val, c):
    """Return the ``IndexValueType`` of an ``IndexValue`` instance, based
    on the inferred type of its ``value`` member."""
    value_type = typeof_impl(val.value, c)
    return IndexValueType(value_type)


@type_callable(IndexValue)
def type_index_value(context):
    """Typing of the ``IndexValue(ind, mval)`` constructor call: accepted
    only when the index is typed ``intp`` or ``uintp``."""
    def typer(ind, mval):
        index_is_machine_int = ind == types.intp or ind == types.uintp
        if not index_is_machine_int:
            return None
        return IndexValueType(mval)
    return typer


@register_model(IndexValueType)
class IndexValueModel(models.StructModel):
    """Data model for ``IndexValueType``: a struct with an ``intp`` member
    ``index`` and a ``value`` member of the type's ``val_typ``."""

    def __init__(self, dmm, fe_type):
        index_member = ('index', types.intp)
        value_member = ('value', fe_type.val_typ)
        super().__init__(dmm, fe_type, [index_member, value_member])


# Register attribute wrappers on IndexValueType for the 'index' and 'value'
# struct members, under the same attribute names.
make_attribute_wrapper(IndexValueType, 'index', 'index')
make_attribute_wrapper(IndexValueType, 'value', 'value')
