from typing import List, Optional, Tuple

import numpy as np

from inference.core.entities.responses.inference import (
    InferenceResponseImage,
    Keypoint,
    KeypointsDetectionInferenceResponse,
    KeypointsPrediction,
)
from inference.core.exceptions import ModelArtefactError
from inference.core.models.object_detection_base import (
    ObjectDetectionBaseOnnxRoboflowInferenceModel,
)
from inference.core.models.types import PreprocessReturnMetadata
from inference.core.models.utils.keypoints import model_keypoints_to_response
from inference.core.models.utils.validate import (
    get_num_classes_from_model_prediction_shape,
)
from inference.core.nms import w_np_non_max_suppression
from inference.core.utils.postprocess import post_process_bboxes, post_process_keypoints

# Default post-processing parameters used as keyword defaults by `postprocess`.
DEFAULT_CONFIDENCE = 0.4  # confidence threshold handed to NMS
DEFAULT_IOU_THRESH = 0.3  # IoU threshold handed to NMS
DEFAULT_CLASS_AGNOSTIC_NMS = False
DEFAULT_MAX_DETECTIONS = 300
# Historical misspelling (lowercase "l") kept as an alias so existing
# references to the old name keep working.
DEFAUlT_MAX_DETECTIONS = DEFAULT_MAX_DETECTIONS
DEFAULT_MAX_CANDIDATES = 3000


class KeypointsDetectionBaseOnnxRoboflowInferenceModel(
    ObjectDetectionBaseOnnxRoboflowInferenceModel
):
    """Roboflow ONNX keypoints detection model.

    Builds on the object detection ONNX base class and adds keypoint-specific
    post-processing and response construction.
    """

    task_type = "keypoint-detection"

    def __init__(self, model_id: str, *args, **kwargs):
        """Initializes the model by delegating to the base class constructor.

        Args:
            model_id (str): Identifier of the model, forwarded to the base class.
        """
        super().__init__(model_id, *args, **kwargs)

    def get_infer_bucket_file_list(self) -> list:
        """Returns the list of files to be downloaded from the inference bucket for ONNX model.

        Returns:
            list: A list of filenames specific to ONNX keypoints models.
        """
        return ["environment.json", "class_names.txt", "keypoints_metadata.json"]

    def postprocess(
        self,
        predictions: Tuple[np.ndarray],
        preproc_return_metadata: PreprocessReturnMetadata,
        class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS,
        confidence: float = DEFAULT_CONFIDENCE,
        iou_threshold: float = DEFAULT_IOU_THRESH,
        max_candidates: int = DEFAULT_MAX_CANDIDATES,
        max_detections: int = DEFAUlT_MAX_DETECTIONS,
        return_image_dims: bool = False,
        **kwargs,
    ) -> List[KeypointsDetectionInferenceResponse]:
        """Postprocesses the keypoints detection predictions.

        Runs non-max suppression, rescales boxes and keypoints using the
        preprocessing metadata, and builds the response objects.

        Args:
            predictions (Tuple[np.ndarray]): Raw model outputs; only the first element is used.
            preproc_return_metadata (PreprocessReturnMetadata): Preprocessing metadata; the
                "img_dims" and "disable_preproc_static_crop" entries are read.
            class_agnostic_nms (bool): Whether to apply class-agnostic non-max suppression. Default is False.
            confidence (float): Confidence threshold for filtering detections. Default is 0.4.
            iou_threshold (float): IoU threshold for non-max suppression. Default is 0.3.
            max_candidates (int): Maximum number of candidate detections. Default is 3000.
            max_detections (int): Maximum number of final detections. Default is 300.
            return_image_dims (bool): Accepted for interface compatibility; not used here.
            **kwargs: Forwarded to `make_response` (e.g. `class_filter`, `request`).

        Returns:
            List[KeypointsDetectionInferenceResponse]: The post-processed predictions.
        """
        predictions = predictions[0]
        number_of_classes = len(self.get_class_names)
        # Everything on the last axis beyond the 5 leading values and one
        # score per class is treated as keypoint data (the NMS helper's
        # `num_masks` argument carries that count).
        num_masks = predictions.shape[2] - 5 - number_of_classes
        predictions = w_np_non_max_suppression(
            predictions,
            conf_thresh=confidence,
            iou_thresh=iou_threshold,
            class_agnostic=class_agnostic_nms,
            max_detections=max_detections,
            max_candidate_detections=max_candidates,
            num_masks=num_masks,
        )

        infer_shape = (self.img_size_h, self.img_size_w)
        img_dims = preproc_return_metadata["img_dims"]
        disable_static_crop = preproc_return_metadata["disable_preproc_static_crop"]
        predictions = post_process_bboxes(
            predictions=predictions,
            infer_shape=infer_shape,
            img_dims=img_dims,
            preproc=self.preproc,
            resize_method=self.resize_method,
            disable_preproc_static_crop=disable_static_crop,
        )
        predictions = post_process_keypoints(
            predictions=predictions,
            # Negative index: the keypoint values are the trailing
            # `num_masks` entries of every prediction.
            keypoints_start_index=-num_masks,
            infer_shape=infer_shape,
            img_dims=img_dims,
            preproc=self.preproc,
            resize_method=self.resize_method,
            disable_preproc_static_crop=disable_static_crop,
        )
        return self.make_response(predictions, img_dims, **kwargs)

    def make_response(
        self,
        predictions: List[List[float]],
        img_dims: List[Tuple[int, int]],
        class_filter: Optional[List[str]] = None,
        *args,
        **kwargs,
    ) -> List[KeypointsDetectionInferenceResponse]:
        """Constructs keypoints detection response objects based on predictions.

        Each prediction is read as: indices 0-3 box corners (x1, y1, x2, y2),
        index 4 confidence, index 6 class id, indices 7+ keypoint values.

        Args:
            predictions (List[List[float]]): Per-image lists of predictions.
            img_dims (List[Tuple[int, int]]): Dimensions of the images as (height, width).
                A dict containing an "img_dims" key is also accepted and unwrapped.
            class_filter (Optional[List[str]]): A list of class names to keep, if provided.
            **kwargs: If a "request" entry is present, its `keypoint_confidence`
                attribute is used as the keypoint confidence threshold (0.0 otherwise).

        Returns:
            List[KeypointsDetectionInferenceResponse]: A list of response objects containing keypoints detection predictions.
        """
        if isinstance(img_dims, dict) and "img_dims" in img_dims:
            img_dims = img_dims["img_dims"]
        keypoint_confidence_threshold = 0.0
        if "request" in kwargs:
            keypoint_confidence_threshold = kwargs["request"].keypoint_confidence

        responses = []
        for ind, batch_predictions in enumerate(predictions):
            image_predictions = []
            for pred in batch_predictions:
                class_id = int(pred[6])
                class_name = self.class_names[class_id]
                if class_filter and class_name not in class_filter:
                    continue
                image_predictions.append(
                    KeypointsPrediction(
                        # Passing args as a dictionary here since one of the args is 'class' (a reserved keyword in Python)
                        **{
                            "x": (pred[0] + pred[2]) / 2,
                            "y": (pred[1] + pred[3]) / 2,
                            "width": pred[2] - pred[0],
                            "height": pred[3] - pred[1],
                            "confidence": pred[4],
                            "class": class_name,
                            "class_id": class_id,
                            "keypoints": model_keypoints_to_response(
                                keypoints_metadata=self.keypoints_metadata,
                                keypoints=pred[7:],
                                # Use the same class id as the label above.
                                # Indexing by the number of classes only hits
                                # index 6 for two-class models and otherwise
                                # reads a score or a keypoint value.
                                predicted_object_class_id=class_id,
                                keypoint_confidence_threshold=keypoint_confidence_threshold,
                            ),
                        }
                    )
                )
            responses.append(
                KeypointsDetectionInferenceResponse(
                    predictions=image_predictions,
                    image=InferenceResponseImage(
                        width=img_dims[ind][1], height=img_dims[ind][0]
                    ),
                )
            )
        return responses

    def keypoints_count(self) -> int:
        """Returns the number of keypoints predicted by the model.

        Not implemented in this base class; concrete subclasses are expected
        to override it.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def validate_model_classes(self) -> None:
        """Checks the class count implied by the model output against `self.num_classes`.

        The class count is derived from the last dimension of the model output
        shape and the keypoints count via
        `get_num_classes_from_model_prediction_shape`.

        Raises:
            ValueError: If the derived class count differs from `self.num_classes`.
        """
        num_keypoints = self.keypoints_count()
        output_shape = self.get_model_output_shape()
        num_classes = get_num_classes_from_model_prediction_shape(
            len_prediction=output_shape[2], keypoints=num_keypoints
        )
        if num_classes != self.num_classes:
            raise ValueError(
                f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})"
            )