import string
from typing import List, Union, Tuple

import nltk
import torch
from numpy.typing import NDArray

from .type_aliases import DEVICE_TYPE, ENCODER_DEVICE_TYPE, NumSentencesType, EmbeddingSlicesType


def get_gpu(gpu: DEVICE_TYPE) -> ENCODER_DEVICE_TYPE:
    """
    Resolve a user-supplied device specification to concrete device identifier(s).

    In the following, an output of ``0`` means CUDA device 0.

    Args:
        gpu (Union[bool, str, int, List[Union[str, int]]]): Device specification:
            - bool: ``True`` resolves to 0 if CUDA is available, otherwise "cpu".
              ``False`` always resolves to "cpu".
            - str: One of "cpu", "gpu" or "cuda" (case-insensitive). Resolves to 0 if
              CUDA is available and the string is not "cpu", otherwise "cpu".
            - int: A GPU index satisfying ``0 <= index < torch.cuda.device_count()``.
              Any other integer raises ``ValueError``; in particular, when
              ``torch.cuda.device_count()`` is 0 every integer is rejected.
            - list: Each element is resolved with the rules above.

    Returns:
        Union[str, int, List[Union[str, int]]]:
            - For a non-list input: "cpu" or an integer GPU index.
            - For a list input: the resolved devices in input order, with repeated
              integer devices dropped (first occurrence kept) and "cpu" entries kept
              as they are. If exactly one device remains it is returned unwrapped
              instead of as a one-element list. An empty list yields an empty list.

    Raises:
        ValueError: If an item is not a bool, str or int.
        ValueError: If a string item is not one of ["cpu", "gpu", "cuda"].
        ValueError: If an integer item is negative or not smaller than the number of
            available GPUs.

    Notes:
        CUDA availability is read from ``torch.cuda.is_available()`` and the number of
        GPUs from ``torch.cuda.device_count()``, once per call.
    """
    gpu_available = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count()
    correct_strs = ["cpu", "gpu", "cuda"]

    def _get_single_device(gpu_item):
        """Resolve one bool/str/int item to "cpu" or an integer GPU index."""
        # bool is a subclass of int, so it has to be handled before the int branch.
        if isinstance(gpu_item, bool):
            return 0 if gpu_item and gpu_available else "cpu"
        if isinstance(gpu_item, str):
            lowered = gpu_item.lower()
            if lowered not in correct_strs:
                raise ValueError(f"Wrong gpu type: {gpu_item}. Should be one of {correct_strs}")
            return 0 if lowered != "cpu" and gpu_available else "cpu"
        if isinstance(gpu_item, int):
            # Reject negative indices as well as indices past the last device.
            if not 0 <= gpu_item < gpu_count:
                raise ValueError(
                    f"There are {gpu_count} GPUs available. Provide a valid GPU index. You provided: {gpu_item}"
                )
            return gpu_item if gpu_available else "cpu"
        raise ValueError(f"Invalid gpu type: {type(gpu_item)}. Must be bool, str, or int.")

    if not isinstance(gpu, list):
        return _get_single_device(gpu)

    seen_indices = set()
    result = []
    for item in gpu:
        device = _get_single_device(item)
        if isinstance(device, int):
            # Only integer (GPU) devices are de-duplicated; "cpu" entries pass through.
            if device in seen_indices:
                continue
            seen_indices.add(device)
        result.append(device)
    return result[0] if len(result) == 1 else result


def prep_sentences(sentences: List[str]) -> List[str]:
    """
    Clean a list of sentences.

    Each sentence is stripped of leading and trailing whitespace. Sentences that are
    empty, or that contain nothing but ASCII punctuation (``string.punctuation``) and
    whitespace, are dropped. Kept sentences retain their punctuation.

    Args:
        sentences (List[str]): A list of sentences to be processed.

    Returns:
        List[str]: The stripped sentences that survived filtering, in input order.

    Raises:
        ValueError: If no sentence survives filtering.

    Example:
        >>> prep_sentences(["Hello, world!", " This is a test. ", "!!!"])
        ['Hello, world!', 'This is a test.']

        >>> prep_sentences(["!!!", "..."])
        Traceback (most recent call last):
            ...
        ValueError: Document can't be empty.
    """
    # The deletion table is identical for every sentence, so build it once.
    drop_punctuation = str.maketrans("", "", string.punctuation)

    stripped = (sent.strip() for sent in sentences)
    # The punctuation-free copy is only used to decide whether to keep the sentence.
    out = [sent for sent in stripped if sent.translate(drop_punctuation).strip()]

    if not out:
        raise ValueError("Document can't be empty.")
    return out


def tokenize_and_prep_document(document: Union[str, List[str]], tokenize: bool) -> List[str]:
    """
    Turn a document into a list of cleaned sentences.

    If `tokenize` is True, `document` is split into sentences with NLTK's sentence
    tokenizer and the resulting sentences are passed to `prep_sentences`. If
    `tokenize` is False, `document` is passed to `prep_sentences` as it is.

    Args:
        document (Union[str, List[str]]): The document to be processed. Either a single string (entire document)
            or a list of strings (list of sentences).
        tokenize (bool): If True, `document` must be a str and is sentence-tokenized first.
                         If False, `document` must already be a collection of sentences (not a str).

    Returns:
        List[str]: A list of cleaned sentences.

    Raises:
        TypeError: If `tokenize` is True and `document` is not a str, or if `tokenize` is False
            and `document` is a str.
        ValueError: If the resulting list of sentences is empty after processing.

    Example:
        >>> tokenize_and_prep_document("Hello, world! This is a test.", True)
        ['Hello, world!', 'This is a test.']

        >>> tokenize_and_prep_document(["Hello, world!", "This is a test."], False)
        ['Hello, world!', 'This is a test.']

        >>> tokenize_and_prep_document("!!! ...", True)
        Traceback (most recent call last):
            ...
        ValueError: Document can't be empty.

    Note: Only the following two combinations are accepted.
        tokenize=True -> document: str
        tokenize=False -> document: List[str]
    """
    if tokenize:
        if not isinstance(document, str):
            raise TypeError(
                f"With tokenize=True, document must be a str, got {type(document).__name__}."
            )
        return prep_sentences(nltk.tokenize.sent_tokenize(document))

    # A str would be iterated character by character below, silently yielding
    # one "sentence" per letter, so reject it explicitly.
    if isinstance(document, str):
        raise TypeError(
            "With tokenize=False, document must be a list of sentences, got str. "
            "Pass tokenize=True to split a raw string into sentences."
        )
    return prep_sentences(document)


def flatten_list(nested_list: list) -> list:
    """
    Flatten an arbitrarily deeply nested list into a single flat list.

    Only elements that are `list` instances are descended into; every other
    element (including tuples and strings) is kept as a single item. The
    left-to-right order of the leaf elements is preserved.

    Parameters:
        nested_list (list): The nested list to flatten.

    Returns:
        list: A new flat list with all non-list elements of `nested_list`.
    """

    def _leaves(items):
        # Depth-first walk that yields leaves in their original order.
        for element in items:
            if not isinstance(element, list):
                yield element
            else:
                yield from _leaves(element)

    return list(_leaves(nested_list))


def is_nested_list_of_type(lst_obj, element_type, depth: int) -> Tuple[bool, str]:
    """
    Check if the given object is a list nested exactly `depth` levels deep whose leaves match a type.

    Args:
    - lst_obj: The object to check, expected to be a list (depth > 0) or a single element (depth == 0).
    - element_type: The type each leaf element should match. Anything accepted as the second
      argument of `isinstance` works, including a tuple of types.
    - depth (int): The depth of nesting to check. Must be non-negative.

    Returns:
    - Tuple[bool, str]: A tuple containing:
        - A boolean indicating if lst_obj is a nested list of the specified type at the given depth.
        - A string containing an error message if the check fails, or an empty string if the check passes.

    Raises:
    - ValueError: If depth is negative.

    Example:
    ```python
    is_nested_list_of_type("test", str, 0)  # Returns (True, "")
    is_nested_list_of_type([1, 2, 3], str, 0)  # Returns (False, "Element is of type list, expected type str.")
    is_nested_list_of_type(["apple", "banana"], str, 1)  # Returns (True, "")
    is_nested_list_of_type([[1, 2], [3, 4]], int, 2)  # Returns (True, "")
    is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)
    # Returns (False, "Element at index 1 has incorrect type.\\nGiven Element at index 1: ['a', 'b']"
    #                 "\\nElement is of type str, expected type int.")
    is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3)  # Returns (True, "")
    is_nested_list_of_type([1, 2.5], (int, str), 1)  # Returns (False, "... expected type int or str.")
    ```

    Explanation:
    - If `depth` is 0, `lst_obj` itself is checked with `isinstance(lst_obj, element_type)`.
    - If `depth` is greater than 0, `lst_obj` must be a list and every item is checked recursively
      with `depth - 1`. The check stops at the first failing item.
    - On failure at the outermost level, the message names the index of the offending top-level
      item, shows that item, and appends the underlying reason. An empty list passes at any
      positive depth.
    """
    orig_depth = depth

    def _type_name(e_type) -> str:
        # isinstance() accepts a tuple of types, which has no __name__ of its own.
        if isinstance(e_type, tuple):
            return " or ".join(_type_name(t) for t in e_type)
        return getattr(e_type, "__name__", repr(e_type))

    def _is_nested_list_of_type(lst_o, e_type, d) -> Tuple[bool, str]:
        if d < 0:
            raise ValueError("Depth can't be negative")

        if d == 0:
            if isinstance(lst_o, e_type):
                return True, ""
            return False, f"Element is of type {type(lst_o).__name__}, expected type {_type_name(e_type)}."

        if not isinstance(lst_o, list):
            return False, f"Object is not a list but {type(lst_o)}."

        for i, item in enumerate(lst_o):
            is_valid, err = _is_nested_list_of_type(item, e_type, d - 1)
            if not is_valid:
                # Only the outermost level adds the index/element context;
                # inner levels pass the underlying reason up unchanged.
                msg = (f"Element at index {i} has incorrect type.\nGiven Element at index {i}: {lst_o[i]}"
                       f"\n{err}") if d == orig_depth else err
                return False, msg
        return True, ""

    return _is_nested_list_of_type(lst_obj, element_type, depth)


def slice_embeddings(embeddings: NDArray, num_sentences: NumSentencesType) -> EmbeddingSlicesType:
    """
    Cut `embeddings` into consecutive segments whose lengths are given by `num_sentences`.

    Segments are taken front to back: each one starts where the previous one ended.

    Args:
    - embeddings (np.ndarray): The array of embeddings to be sliced along its first axis.
    - num_sentences (Union[List[int], List[List[int]]]):
        - If a list of integers: the number of embeddings in each segment.
        - If a list of lists of integers: groups of segment lengths; the running
          offset carries over from one group to the next.

    Returns:
    - For a flat `num_sentences`: a list with one slice of `embeddings` per count.
    - For a nested `num_sentences`: a list of such lists, one per inner list.

    Raises:
    - TypeError: If `num_sentences` is neither a list of ints nor a list of lists of ints.

    Example Usage:

    ```python
    embeddings = np.random.rand(10, 5)
    slice_embeddings(embeddings, [3, 2, 5])
    # -> [embeddings[:3], embeddings[3:5], embeddings[5:10]]

    slice_embeddings(embeddings, [[2, 1], [3, 4]])
    # -> [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:10]]]

    slice_embeddings(embeddings, "invalid")  # Raises a TypeError
    ```
    """

    def _only_ints(values) -> bool:
        """Return True if every element of `values` is an int."""
        return all(isinstance(value, int) for value in values)

    def _take_segments(offset: int, counts: List[int]):
        """Slice one segment per count starting at `offset`; return (segments, next offset)."""
        segments = []
        for count in counts:
            stop = offset + count
            segments.append(embeddings[offset:stop])
            offset = stop
        return segments, offset

    if not isinstance(num_sentences, list):
        raise TypeError(f"Incorrect Type for {num_sentences=}")

    # Flat case (also covers the empty list).
    if _only_ints(num_sentences):
        segments, _ = _take_segments(0, num_sentences)
        return segments

    # Nested case: every entry must itself be a list of ints.
    if all(isinstance(group, list) and _only_ints(group) for group in num_sentences):
        grouped_segments = []
        cursor = 0
        for group in num_sentences:
            segments, cursor = _take_segments(cursor, group)
            grouped_segments.append(segments)
        return grouped_segments

    raise TypeError(f"Incorrect Type for {num_sentences=}")