import copy
import json
import sys
import warnings
from collections import defaultdict, namedtuple
from dataclasses import (MISSING,
                         fields,
                         is_dataclass  # type: ignore
                         )
from datetime import datetime, timezone
from decimal import Decimal
from enum import Enum
from typing import (Any, Collection, Mapping, Union, get_type_hints,
                    Tuple, TypeVar, Type)
from uuid import UUID

from typing_inspect import is_union_type  # type: ignore

from dataclasses_json import cfg
from dataclasses_json.utils import (_get_type_cons, _get_type_origin,
                                    _handle_undefined_parameters_safe,
                                    _is_collection, _is_mapping, _is_new_type,
                                    _is_optional, _isinstance_safe,
                                    _get_type_arg_param,
                                    _get_type_args, _is_counter,
                                    _NO_ARGS,
                                    _issubclass_safe, _is_tuple)

# Union of the Python types that are treated as already JSON-compatible.
# Its ``__args__`` tuple is used for isinstance checks in _encode_json_type.
Json = Union[dict, list, str, int, float, bool, None]

# Names of the per-field configuration options. The order matters: it defines
# the positional field order of FieldOverride below.
confs = ['encoder', 'decoder', 'mm_field', 'letter_case', 'exclude']
# Resolved configuration for a single dataclass field, one attribute per
# entry in ``confs``; _user_overrides_or_exts fills unset options with None.
FieldOverride = namedtuple('FieldOverride', confs)  # type: ignore


class _ExtendedEncoder(json.JSONEncoder):
    """JSON encoder with fallbacks for a handful of non-JSON Python types."""

    def default(self, o) -> Json:
        """Return a JSON-compatible stand-in for ``o``.

        Conversions, checked in this order:

        * mapping collections -> ``dict``; any other collection -> ``list``
        * ``datetime`` -> the float returned by ``o.timestamp()``
        * ``UUID`` -> ``str(o)``
        * ``Enum`` -> ``o.value``
        * ``Decimal`` -> ``str(o)``

        Anything else is handed to ``json.JSONEncoder.default``.
        """
        if _isinstance_safe(o, Collection):
            # Mappings are collections too, so they are told apart here.
            if _isinstance_safe(o, Mapping):
                return dict(o)
            return list(o)
        if _isinstance_safe(o, datetime):
            return o.timestamp()
        if _isinstance_safe(o, UUID):
            return str(o)
        # Enum is tested before Decimal, matching the documented order above.
        if _isinstance_safe(o, Enum):
            return o.value
        if _isinstance_safe(o, Decimal):
            return str(o)
        return json.JSONEncoder.default(self, o)


def _user_overrides_or_exts(cls):
    """Build the effective per-field configuration for dataclass ``cls``.

    Three layers are merged for every field, later layers winning:

    1. global encoders/decoders/mm_fields registered in
       ``cfg.global_config`` for the field's declared type,
    2. the class-level ``dataclass_json_config`` mapping (if present and
       not None),
    3. the field's own ``metadata['dataclasses_json']`` mapping.

    Returns a dict mapping each field name to a ``FieldOverride``; options
    that no layer provides are None.
    """
    global_config = cfg.global_config
    registries = (
        ('encoder', global_config.encoders),
        ('decoder', global_config.decoders),
        ('mm_field', global_config.mm_fields),
    )
    cls_fields = fields(cls)

    # A missing attribute and an explicit None both mean "no class config".
    cls_config = getattr(cls, 'dataclass_json_config', None)
    if cls_config is None:
        cls_config = {}

    overrides = {}
    for field in cls_fields:
        # layer 1: global registrations keyed by the field's declared type
        field_config = {
            conf_name: registry[field.type]
            for conf_name, registry in registries
            if field.type in registry
        }
        # layer 2: class-level configuration
        field_config.update(cls_config)
        # layer 3: field-level configuration
        field_config.update(field.metadata.get('dataclasses_json', {}))
        overrides[field.name] = FieldOverride(
            *(field_config.get(conf_name) for conf_name in confs))
    return overrides


def _encode_json_type(value, default=_ExtendedEncoder().default):
    """Recursively convert ``value`` into JSON-compatible Python objects.

    Lists and dicts are walked (dict keys are left untouched, only values
    are converted); other members of the ``Json`` union are returned as-is;
    everything else is passed to ``default``.

    :param value: the object to convert.
    :param default: fallback callable for values that are not one of the
        ``Json`` types. It is forwarded to nested list items and dict
        values, so a caller-supplied fallback applies at every depth.
    :return: the converted value.
    """
    if isinstance(value, list):
        return [_encode_json_type(item, default) for item in value]
    if isinstance(value, dict):
        return {key: _encode_json_type(item, default)
                for key, item in value.items()}
    if isinstance(value, Json.__args__):  # type: ignore
        return value
    return default(value)


def _encode_overrides(kvs, overrides, encode_json=False):
    override_kvs = {}
    for k, v in kvs.items():
        if k in overrides:
            exclude = overrides[k].exclude
            # If the exclude predicate returns true, the key should be
            #  excluded from encoding, so skip the rest of the loop
            if exclude and exclude(v):
                continue
            letter_case = overrides[k].letter_case
            original_key = k
            k = letter_case(k) if letter_case is not None else k
            if k in override_kvs:
                raise ValueError(
                    f"Multiple fields map to the same JSON "
                    f"key after letter case encoding: {k}"
                )

            encoder = overrides[original_key].encoder
            v = encoder(v) if encoder is not None else v

        if encode_json:
            v = _encode_json_type(v)
        override_kvs[k] = v
    return override_kvs


def _decode_letter_case_overrides(field_names, overrides):
    """Override letter case of field names for encode/decode"""
    names = {}
    for field_name in field_names:
        field_override = overrides.get(field_name)
        if field_override is not None:
            letter_case = field_override.letter_case
            if letter_case is not None:
                names[letter_case(field_name)] = field_name
    return names


def _decode_dataclass(cls, kvs, infer_missing):
    """Construct an instance of dataclass ``cls`` from the mapping ``kvs``.

    :param cls: the target dataclass.
    :param kvs: decoded JSON object (or an existing ``cls`` instance, which
        is returned unchanged).
    :param infer_missing: when true, a None ``kvs`` is treated as empty and
        fields that are absent and have no default are set to None.
    :return: ``cls(**init_kwargs)`` built from the decoded field values.
    :raises KeyError: when a constructor field is absent from ``kvs`` and
        neither a default nor ``infer_missing`` supplies a value.
    """
    if _isinstance_safe(kvs, cls):
        return kvs

    overrides = _user_overrides_or_exts(cls)
    if kvs is None and infer_missing:
        kvs = {}
    cls_fields = fields(cls)

    # Translate letter-cased keys back to the dataclass field names.
    decode_names = _decode_letter_case_overrides(
        [field.name for field in cls_fields], overrides)
    kvs = {decode_names.get(key, key): value for key, value in kvs.items()}

    # Fill in absent fields from their defaults (or None when inferring).
    missing_fields = {field for field in cls_fields if field.name not in kvs}
    for field in missing_fields:
        if field.default is not MISSING:
            kvs[field.name] = field.default
        elif field.default_factory is not MISSING:
            kvs[field.name] = field.default_factory()
        elif infer_missing:
            kvs[field.name] = None

    # Perform undefined parameter action
    kvs = _handle_undefined_parameters_safe(cls, kvs, usage="from")

    init_kwargs = {}
    types = get_type_hints(cls)
    for field in cls_fields:
        # Fields with init=False are not constructor arguments, so they
        # must not end up in init_kwargs.
        if not field.init:
            continue

        name = field.name
        field_value = kvs[name]
        field_type = types[name]

        if field_value is None:
            if not _is_optional(field_type):
                warning = (
                    f"value of non-optional type {name} detected "
                    f"when decoding {cls.__name__}"
                )
                if infer_missing:
                    message = (
                        f"Missing {warning} and was defaulted to None by "
                        f"infer_missing=True. "
                        f"Set infer_missing=False (the default) to prevent "
                        f"this behavior."
                    )
                else:
                    message = f"'NoneType' object {warning}."
                warnings.warn(message, RuntimeWarning)
            init_kwargs[name] = field_value
            continue

        # Unwrap (possibly nested) NewType declarations to the real type.
        while _is_new_type(field_type):
            field_type = field_type.__supertype__

        decoder = overrides[name].decoder if name in overrides else None
        if decoder is not None:
            # FIXME hack: a value that already has exactly the declared
            # type is passed through without calling the decoder
            if field_type is type(field_value):
                decoded = field_value
            else:
                decoded = decoder(field_value)
        elif is_dataclass(field_type):
            # FIXME this is a band-aid to deal with the value already being
            # serialized when handling nested marshmallow schema
            # proper fix is to investigate the marshmallow schema generation
            # code
            if is_dataclass(field_value):
                decoded = field_value
            else:
                decoded = _decode_dataclass(field_type, field_value,
                                            infer_missing)
        elif _is_supported_generic(field_type) and field_type != str:
            decoded = _decode_generic(field_type, field_value, infer_missing)
        else:
            decoded = _support_extended_types(field_type, field_value)
        init_kwargs[name] = decoded

    return cls(**init_kwargs)


def _support_extended_types(field_type, field_value):
    """Coerce ``field_value`` to ``field_type`` for a few well-known types.

    * ``datetime``: a non-datetime value is read as a timestamp and turned
      into a datetime carrying the machine's local timezone.
    * ``Decimal`` / ``UUID``: built from the value unless it already is one.
    * ``int`` / ``float`` / ``str`` / ``bool``: ``field_type(field_value)``
      unless the value already is an instance of ``field_type``.

    Values of any other type are returned unchanged.
    """
    if _issubclass_safe(field_type, datetime):
        # FIXME this is a hack to deal with mm already decoding
        # the issue is we want to leverage mm fields' missing argument
        # but need this for the object creation hook
        if isinstance(field_value, datetime):
            return field_value
        local_tz = datetime.now(timezone.utc).astimezone().tzinfo
        return datetime.fromtimestamp(field_value, tz=local_tz)

    if _issubclass_safe(field_type, Decimal):
        if isinstance(field_value, Decimal):
            return field_value
        return Decimal(field_value)

    if _issubclass_safe(field_type, UUID):
        if isinstance(field_value, UUID):
            return field_value
        return UUID(field_value)

    if _issubclass_safe(field_type, (int, float, str, bool)):
        if isinstance(field_value, field_type):
            return field_value
        return field_type(field_value)

    return field_value


def _is_supported_generic(type_):
    """Tell whether ``type_`` should be decoded by ``_decode_generic``.

    True (truthy) for non-``str`` collection types, Optional types, Union
    types and Enum subclasses; False for the ``_NO_ARGS`` sentinel.
    """
    if type_ is _NO_ARGS:
        return False
    is_enum = _issubclass_safe(type_, Enum)
    # str is a collection as well, but must not be decoded element-wise
    non_str_collection = (not _issubclass_safe(type_, str)
                          and _is_collection(type_))
    return (non_str_collection
            or _is_optional(type_)
            or is_union_type(type_)
            or is_enum)


def _decode_generic(type_, value, infer_missing):
    """Decode ``value`` according to the generic/Enum annotation ``type_``.

    :param type_: an Enum subclass, a collection type, an Optional or a
        Union annotation.
    :param value: the raw decoded-JSON value; None is returned unchanged.
    :param infer_missing: forwarded to nested dataclass decoding.
    :return: the decoded value:

        * Enum: ``type_(value)``;
        * collection: items (and mapping keys) are decoded recursively and
          the result is materialized with the type's constructor;
        * Optional[X]: decoded as X;
        * other Union: a plain ``dict`` value is tried against each
          dataclass member of the Union (unless ``dict`` itself is a
          member); any other value is assumed to be decoded already.

    A warning is emitted when a dict value matches none of the Union's
    dataclass members; the value is then returned unchanged.
    """
    if value is None:
        return value

    if _issubclass_safe(type_, Enum):
        # Convert to an Enum using the type as a constructor.
        # Assumes a direct match is found.
        return type_(value)

    # FIXME this is a hack to fix a deeper underlying issue. A refactor is due.
    if _is_collection(type_):
        if _is_mapping(type_) and not _is_counter(type_):
            k_type, v_type = _get_type_args(type_, (Any, Any))
            # a mapping type has `.keys()` and `.values()`
            # (see collections.abc)
            ks = _decode_dict_keys(k_type, value.keys(), infer_missing)
            vs = _decode_items(v_type, value.values(), infer_missing)
            xs = zip(ks, vs)
        elif _is_tuple(type_):
            types = _get_type_args(type_)
            if Ellipsis in types:
                # homogeneous Tuple[X, ...]: every item is decoded as X
                xs = _decode_items(types[0], value, infer_missing)
            else:
                # fixed-shape tuple: one type per position
                xs = _decode_items(types or _NO_ARGS, value, infer_missing)
        elif _is_counter(type_):
            # only the keys are decoded; the counts are kept as they are
            decoded_keys = _decode_items(_get_type_arg_param(type_, 0),
                                         value.keys(), infer_missing)
            xs = dict(zip(decoded_keys, value.values()))
        else:
            xs = _decode_items(_get_type_arg_param(type_, 0), value,
                               infer_missing)

        # get the constructor if using corresponding generic type in `typing`
        # otherwise fallback on constructing using type_ itself
        materialize_type = type_
        try:
            materialize_type = _get_type_cons(type_)
        except (TypeError, AttributeError):
            pass
        return materialize_type(xs)

    # Optional or Union
    _args = _get_type_args(type_)
    if _args is _NO_ARGS:
        # Any, just accept
        return value

    if _is_optional(type_) and len(_args) == 2:  # Optional
        type_arg = _get_type_arg_param(type_, 0)
        if is_dataclass(type_arg) or is_dataclass(value):
            return _decode_dataclass(type_arg, value, infer_missing)
        if _is_supported_generic(type_arg):
            return _decode_generic(type_arg, value, infer_missing)
        return _support_extended_types(type_arg, value)

    # Union (already decoded or try to decode a dataclass)
    res = value  # assume already decoded
    if type(value) is dict and dict not in _args:
        for type_option in _args:
            if is_dataclass(type_option):
                try:
                    res = _decode_dataclass(type_option, value,
                                            infer_missing)
                    break
                except (KeyError, ValueError, AttributeError):
                    continue
        else:
            # The loop finished without a successful decode. for/else is
            # used instead of comparing the result to the input, so that
            # a user-defined __eq__ cannot influence the outcome.
            warnings.warn(
                f"Failed to decode {value} Union dataclasses. "
                f"Expected Union to include a matching dataclass and it didn't."
            )
    return res


def _decode_dict_keys(key_type, xs, infer_missing):
    """
    JSON object keys are always strs, so after the usual item decoding the
    keys are additionally converted back into the user's chosen python type.

    Returns a lazy ``map`` object over the decoded keys.
    """
    # A None key type is odd but valid. An unparameterized key (Any) is
    # treated the same way, as is a TypeVar:
    # Issue #341 and PR #346:
    #   On Python 3.7 and Python 3.8 "unbound" dicts report their key type
    #   parameter as TypeVar('KT').
    if key_type is None or key_type == Any or isinstance(key_type, TypeVar):
        # nothing to convert: use the identity both as item type and as
        # final conversion
        decode_function = key_type = (lambda x: x)
    # A nested python dict that has tuples for keys, e.g.
    # Dict[Tuple[int], int]: key_type is typing.Tuple[int], which is not
    # usable as a conversion function, so plain tuple is used for map().
    #
    # Note: _get_type_origin() will return typing.Tuple for python
    # 3.6 and tuple for 3.7 and higher.
    elif _get_type_origin(key_type) in {tuple, Tuple}:
        decode_function = tuple
    else:
        decode_function = key_type

    decoded_items = _decode_items(key_type, xs, infer_missing)
    return map(decode_function, decoded_items)


def _decode_items(type_args, xs, infer_missing):
    """
    Decode every element of ``xs`` according to ``type_args``.

    This is a tricky situation where we need to check both the annotated
    type info (which is usually a type from `typing`) and check the
    value's type directly using `type()`.

    If the type_arg is a generic we can use the annotated type, but if the
    type_arg is a typevar we need to extract the reified type information
    hence the check of `is_dataclass(xs)`

    :param type_args: a single type applied to every element, or a
        collection of types applied position by position (fixed-shape
        tuples). A plain string is treated as an unresolved self-reference.
    :param xs: the sized iterable of raw items.
    :param infer_missing: forwarded to nested dataclass decoding.
    :return: a list with the decoded items.
    :raises TypeError: when ``type_args`` is a collection of types whose
        length differs from ``len(xs)``.
    """
    def _decode_item(type_arg, x):
        if is_dataclass(type_arg) or is_dataclass(xs):
            return _decode_dataclass(type_arg, x, infer_missing)
        if _is_supported_generic(type_arg):
            return _decode_generic(type_arg, x, infer_missing)
        return x

    def handle_pep0673(pre_0673_hint: str) -> Union[Type, str]:
        # Iterate over a snapshot: getattr() on a module may trigger lazy
        # imports that add entries to sys.modules, and mutating a dict
        # while iterating over it raises RuntimeError.
        for module in list(sys.modules.values()):
            maybe_resolved = getattr(module, pre_0673_hint, None)
            if maybe_resolved:
                return maybe_resolved

        warnings.warn(f"Could not resolve self-reference for type {pre_0673_hint}, "
                      f"decoded type might be incorrect or decode might fail altogether.")
        return pre_0673_hint

    # Before https://peps.python.org/pep-0673 (3.11+) self-type hints are simply strings
    if sys.version_info < (3, 11) and type(type_args) is str:
        type_args = handle_pep0673(type_args)

    # A str is a Collection too; an unresolved string hint must not be
    # mistaken for a per-position collection of types (it would be zipped
    # character by character or fail the length check below).
    if (_isinstance_safe(type_args, Collection)
            and not isinstance(type_args, str)
            and not _issubclass_safe(type_args, Enum)):
        if len(type_args) == len(xs):
            return [_decode_item(type_arg, x)
                    for type_arg, x in zip(type_args, xs)]
        raise TypeError(f"Number of types specified in the collection type {str(type_args)} "
                        f"does not match number of elements in the collection. In case you are working with tuples"
                        f"take a look at this document "
                        f"docs.python.org/3/library/typing.html#annotating-tuples.")
    return [_decode_item(type_args, x) for x in xs]


def _asdict(obj, encode_json=False):
    """
    A re-implementation of `asdict` (based on the original in the `dataclasses`
    source) to support arbitrary Collection and Mapping types.

    :param obj: the object to convert. Dataclasses become dicts with the
        field overrides applied, mappings become dicts (keys and values
        converted recursively), other collections except str/bytes/Enum
        become lists, and anything else is deep-copied.
    :param encode_json: forwarded to nested calls and to
        ``_encode_overrides``.
    :return: the converted structure.
    """
    if is_dataclass(obj):
        # Computed once and reused for both the per-field encoder check and
        # the final override application.
        overrides = _user_overrides_or_exts(obj)
        result = {}
        for field in fields(obj):
            value = getattr(obj, field.name)
            # A field with a custom encoder is handed to that encoder raw;
            # everything else is converted recursively first.
            if not overrides[field.name].encoder:
                value = _asdict(value, encode_json=encode_json)
            result[field.name] = value

        result = _handle_undefined_parameters_safe(cls=obj, kvs=result,
                                                   usage="to")
        return _encode_overrides(dict(result), overrides,
                                 encode_json=encode_json)
    elif isinstance(obj, Mapping):
        return {_asdict(k, encode_json=encode_json):
                _asdict(v, encode_json=encode_json)
                for k, v in obj.items()}
    # enum.IntFlag and enum.Flag are regarded as collections in Python 3.11, thus a check against Enum is needed
    elif isinstance(obj, Collection) and not isinstance(obj, (str, bytes, Enum)):
        return [_asdict(v, encode_json=encode_json) for v in obj]
    else:
        return copy.deepcopy(obj)