"""Custom exceptions for the LLMFoundry."""
from collections.abc import Mapping
from typing import Any, Dict, List

class MissingHuggingFaceURLSplitError(ValueError):
    """Error thrown when there's no split used in HF dataset config.

    The constructor takes no arguments; the message is fixed.
    """

    def __init__(self) -> None:
        message = (
            'When using a HuggingFace dataset from a URL, you must set the '
            '`split` key in the dataset config.'
        )
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (the formatted
        message) into ``__init__``, which accepts no arguments, so the
        round-trip would raise ``TypeError``. Rebuild with no arguments.
        """
        return (type(self), (), self.__dict__)

class NotEnoughDatasetSamplesError(ValueError):
    """Error thrown when there is not enough data to train a model.

    Every constructor argument is stored as an attribute of the same name.

    Args:
        dataset_name: Name of the dataset, as reported in the message.
        split: Dataset split, as reported in the message.
        dataloader_batch_size: Per-device batch size.
        world_size: Number of GPUs the run uses.
        full_dataset_size: Number of samples the dataset has.
        minimum_dataset_size: Minimum number of samples the dataset needs.
    """

    def __init__(
        self,
        dataset_name: str,
        split: str,
        dataloader_batch_size: int,
        world_size: int,
        full_dataset_size: int,
        minimum_dataset_size: int,
    ) -> None:
        self.dataset_name = dataset_name
        self.split = split
        self.dataloader_batch_size = dataloader_batch_size
        self.world_size = world_size
        self.full_dataset_size = full_dataset_size
        self.minimum_dataset_size = minimum_dataset_size
        message = (
            f'Your dataset (name={dataset_name}, split={split}) '
            f'has {full_dataset_size} samples, but your minimum batch size '
            f'is {minimum_dataset_size} because you are running on {world_size} gpus and '
            f'your per device batch size is {dataloader_batch_size}. Please increase the number '
            f'of samples in your dataset to at least {minimum_dataset_size}.'
        )
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (only the formatted
        message) into ``__init__``, which needs six arguments, so the
        round-trip would raise ``TypeError``. Replay the stored constructor
        arguments instead.
        """
        ctor_args = (
            self.dataset_name,
            self.split,
            self.dataloader_batch_size,
            self.world_size,
            self.full_dataset_size,
            self.minimum_dataset_size,
        )
        return (type(self), ctor_args, self.__dict__)

class UnknownExampleTypeError(KeyError):
    """Error thrown when an unknown example type is used in a task.

    Args:
        example: The example whose type was not recognized. Its ``repr`` is
            embedded in the message and the object is kept as ``example``.
    """

    def __init__(self, example: Mapping) -> None:
        # Keep the offending example so handlers can inspect it.
        self.example = example
        message = f'Unknown example type example={example!r}'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (the formatted
        message) into ``__init__`` as if it were the example, which would wrap
        the message in itself. Replay the original example instead.
        """
        return (type(self), (self.example,), self.__dict__)

class TooManyKeysInExampleError(ValueError):
    """Error thrown when a data sample has too many keys.

    Both constructor arguments are stored as attributes of the same name.

    Args:
        desired_keys: The allowed keys, reported as ``allowed_keys`` in the
            message.
        keys: The keys that were provided; their count and contents appear in
            the message.
    """

    def __init__(self, desired_keys: set[str], keys: set[str]) -> None:
        self.desired_keys = desired_keys
        self.keys = keys
        message = (
            f'Data sample has {len(keys)} keys in `allowed_keys`: {desired_keys} '
            f'Please specify exactly one. Provided keys: {keys}'
        )
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (only the formatted
        message) into ``__init__``, which needs two arguments, so the
        round-trip would raise ``TypeError``. Replay the stored arguments.
        """
        return (type(self), (self.desired_keys, self.keys), self.__dict__)

class NotEnoughChatDataError(ValueError):
    """Error thrown when there is not enough chat data to train a model.

    The constructor takes no arguments; the message is fixed.
    """

    def __init__(self) -> None:
        message = 'Chat example must have at least two messages'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (the formatted
        message) into ``__init__``, which accepts no arguments, so the
        round-trip would raise ``TypeError``. Rebuild with no arguments.
        """
        return (type(self), (), self.__dict__)

class ConsecutiveRepeatedChatRolesError(ValueError):
    """Error thrown when there are consecutive repeated chat roles.

    Args:
        repeated_role: The role that was found repeated consecutively; stored
            as ``repeated_role``.
    """

    def __init__(self, repeated_role: str) -> None:
        self.repeated_role = repeated_role
        message = f'Conversation roles must alternate but found {repeated_role} repeated consecutively.'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (the formatted
        message) into ``__init__`` as if it were the role, which would format
        the message a second time. Replay the stored role instead.
        """
        return (type(self), (self.repeated_role,), self.__dict__)

class InvalidLastChatMessageRoleError(ValueError):
    """Error thrown when the last message role in a chat example is invalid.

    Both constructor arguments are stored as attributes of the same name.

    Args:
        last_role: The role found on the last message.
        expected_roles: The roles that would have been accepted.
    """

    def __init__(self, last_role: str, expected_roles: set[str]) -> None:
        self.last_role = last_role
        self.expected_roles = expected_roles
        message = f'Invalid last message role: {last_role}. Expected one of: {expected_roles}'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (only the formatted
        message) into ``__init__``, which needs two arguments, so the
        round-trip would raise ``TypeError``. Replay the stored arguments.
        """
        return (type(self), (self.last_role, self.expected_roles), self.__dict__)

class IncorrectMessageKeyQuantityError(ValueError):
    """Error thrown when a message has an incorrect number of keys.

    Args:
        keys: The keys found in the message; stored as ``keys``. Only their
            count appears in the error text.
    """

    def __init__(self, keys: List[str]) -> None:
        self.keys = keys
        message = f'Expected 2 keys in message, but found {len(keys)}'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (the formatted
        message) into ``__init__`` as if it were the key list, so the rebuilt
        error would report the length of the message string. Replay the stored
        keys instead.
        """
        return (type(self), (self.keys,), self.__dict__)

class InvalidRoleError(ValueError):
    """Error thrown when a role is invalid.

    Both constructor arguments are stored as attributes of the same name.

    Args:
        role: The role that was found.
        valid_roles: The roles that would have been accepted.
    """

    def __init__(self, role: str, valid_roles: set[str]) -> None:
        self.role = role
        self.valid_roles = valid_roles
        message = f'Expected role to be one of {valid_roles} but found: {role}'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (only the formatted
        message) into ``__init__``, which needs two arguments, so the
        round-trip would raise ``TypeError``. Replay the stored arguments.
        """
        return (type(self), (self.role, self.valid_roles), self.__dict__)

class InvalidContentTypeError(TypeError):
    """Error thrown when the content type is invalid.

    Args:
        content_type: The type that was found where a string was expected;
            stored as ``content_type``.
    """

    def __init__(self, content_type: type) -> None:
        self.content_type = content_type
        message = f'Expected content to be a string, but found {content_type}'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (the formatted
        message) into ``__init__`` as if it were the type, which would format
        the message a second time. Replay the stored type instead.
        """
        return (type(self), (self.content_type,), self.__dict__)

class InvalidPromptTypeError(TypeError):
    """Error thrown when the prompt type is invalid.

    Args:
        prompt_type: The type that was found where a string was expected;
            stored as ``prompt_type``.
    """

    def __init__(self, prompt_type: type) -> None:
        self.prompt_type = prompt_type
        message = f'Expected prompt to be a string, but found {prompt_type}'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (the formatted
        message) into ``__init__`` as if it were the type, which would format
        the message a second time. Replay the stored type instead.
        """
        return (type(self), (self.prompt_type,), self.__dict__)

class InvalidResponseTypeError(TypeError):
    """Error thrown when the response type is invalid.

    Args:
        response_type: The type that was found where a string was expected;
            stored as ``response_type``.
    """

    def __init__(self, response_type: type) -> None:
        self.response_type = response_type
        message = f'Expected response to be a string, but found {response_type}'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (the formatted
        message) into ``__init__`` as if it were the type, which would format
        the message a second time. Replay the stored type instead.
        """
        return (type(self), (self.response_type,), self.__dict__)

class InvalidPromptResponseKeysError(ValueError):
    """Error thrown when missing expected prompt and response keys.

    Both constructor arguments are stored as attributes of the same name.

    Args:
        mapping: The mapping that lacks the "prompt" and "response" keys; its
            ``repr`` appears in the message.
        example: The example being processed. It is kept on the error but is
            not part of the message.
    """

    def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]) -> None:
        self.mapping = mapping
        self.example = example
        message = f'Expected mapping={mapping!r} to have keys "prompt" and "response".'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (only the formatted
        message) into ``__init__``, which needs two arguments, so the
        round-trip would raise ``TypeError``. Replay the stored arguments.
        """
        return (type(self), (self.mapping, self.example), self.__dict__)

class InvalidFileExtensionError(FileNotFoundError):
    """Error thrown when a file extension is not a safe extension.

    Both constructor arguments are stored as attributes of the same name.

    Args:
        dataset_name: Reported in the message as the local path of the
            dataset.
        valid_extensions: The extensions that are considered safe.
    """

    def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None:
        self.dataset_name = dataset_name
        self.valid_extensions = valid_extensions
        message = (
            f'safe_load is set to True. No data files with safe extensions {valid_extensions} '
            f'found for dataset at local path {dataset_name}.'
        )
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (only the formatted
        message) into ``__init__``, which needs two arguments, so the
        round-trip would raise ``TypeError``. Replay the stored arguments.
        """
        return (type(self), (self.dataset_name, self.valid_extensions), self.__dict__)

class UnableToProcessPromptResponseError(ValueError):
    """Error thrown when a prompt and response cannot be processed.

    Args:
        input: The data from which no prompt/response could be extracted;
            stored as ``input``.
    """

    # The parameter name shadows the ``input`` builtin, but it is part of the
    # public signature (callers may pass it by keyword), so it is kept.
    def __init__(self, input: Dict) -> None:
        self.input = input
        message = f'Unable to extract prompt/response from {input}'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (the formatted
        message) into ``__init__`` as if it were the input, which would format
        the message a second time. Replay the stored input instead.
        """
        return (type(self), (self.input,), self.__dict__)

class ClusterDoesNotExistError(ValueError):
    """Error thrown when the cluster does not exist.

    Args:
        cluster_id: Id of the cluster, as shown in the message; stored as
            ``cluster_id``.
    """

    def __init__(self, cluster_id: str) -> None:
        self.cluster_id = cluster_id
        message = f'Cluster with id {cluster_id} does not exist. Check cluster id and try again!'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (the formatted
        message) into ``__init__`` as if it were the cluster id, which would
        format the message a second time. Replay the stored id instead.
        """
        return (type(self), (self.cluster_id,), self.__dict__)

class FailedToCreateSQLConnectionError(RuntimeError):
    """Error thrown when client can't sql connect to Databricks.

    The constructor takes no arguments; the message is fixed.
    """

    def __init__(self) -> None:
        message = (
            'Failed to create sql connection to db workspace. '
            'To use sql connect, you need to provide http_path and cluster_id!'
        )
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (the formatted
        message) into ``__init__``, which accepts no arguments, so the
        round-trip would raise ``TypeError``. Rebuild with no arguments.
        """
        return (type(self), (), self.__dict__)

class FailedToConnectToDatabricksError(RuntimeError):
    """Error thrown when the client fails to connect to Databricks.

    The constructor takes no arguments; the message is fixed.
    """

    def __init__(self) -> None:
        message = 'Failed to create databricks connection. Check hostname and access token!'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (the formatted
        message) into ``__init__``, which accepts no arguments, so the
        round-trip would raise ``TypeError``. Rebuild with no arguments.
        """
        return (type(self), (), self.__dict__)

class InputFolderMissingDataError(ValueError):
    """Error thrown when the input folder is missing data.

    Args:
        input_folder: The folder in which no text files were found; stored as
            ``input_folder``.
    """

    def __init__(self, input_folder: str) -> None:
        self.input_folder = input_folder
        message = f'No text files were found at {input_folder}.'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (the formatted
        message) into ``__init__`` as if it were the folder, which would
        format the message a second time. Replay the stored folder instead.
        """
        return (type(self), (self.input_folder,), self.__dict__)

class OutputFolderNotEmptyError(FileExistsError):
    """Error thrown when the output folder is not empty.

    Args:
        output_folder: The folder that is not empty; stored as
            ``output_folder``.
    """

    def __init__(self, output_folder: str) -> None:
        self.output_folder = output_folder
        message = f'{output_folder} is not empty. Please remove or empty it and retry.'
        super().__init__(message)

    def __reduce__(self) -> tuple[type, tuple, dict]:
        """Describe how to rebuild this error when it is pickled or copied.

        The inherited implementation replays ``self.args`` (the formatted
        message) into ``__init__`` as if it were the folder, which would
        format the message a second time. Replay the stored folder instead.
        """
        return (type(self), (self.output_folder,), self.__dict__)