import os
import logging
import torch
import cv2
import numpy as np

from typing import List, Optional
from label_studio_ml.utils import get_image_local_path, InMemoryLRUDictCache

logger = logging.getLogger(__name__)

VITH_CHECKPOINT = os.environ.get("VITH_CHECKPOINT")
ONNX_CHECKPOINT = os.environ.get("ONNX_CHECKPOINT")
MOBILESAM_CHECKPOINT = os.environ.get("MOBILESAM_CHECKPOINT", "mobile_sam.pt")
LABEL_STUDIO_ACCESS_TOKEN = os.environ.get("LABEL_STUDIO_ACCESS_TOKEN")
LABEL_STUDIO_HOST = os.environ.get("LABEL_STUDIO_HOST")


class SAMPredictor:
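    """Unified wrapper around SAM-family predictors for the Label Studio ML backend.

    `model_choice` selects the backend:
      * 'ONNX'      - SAM ViT-H image encoder with an exported ONNX mask decoder
      * 'SAM'       - native SAM ViT-H predictor
      * 'MobileSAM' - lightweight MobileSAM (ViT-T) predictor

    Checkpoint paths are read from the VITH_CHECKPOINT, ONNX_CHECKPOINT and
    MOBILESAM_CHECKPOINT environment variables.
    """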

    def __init__(self, model_choice):
        self.model_choice = model_choice

        # cache for embeddings
        # TODO: currently it supports only one image in cache,
        #   since predictor.set_image() should be called each time the new image comes
        #   before making predictions
        #   to extend it to >1 image, we need to store the "active image" state in the cache
        self.cache = InMemoryLRUDictCache(1)

        # use CUDA when available; CPU works but is significantly slower for SAM inference
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.debug(f"Using device {self.device}")

        if model_choice == 'ONNX':
            import onnxruntime
            from segment_anything import sam_model_registry, SamPredictor

            self.model_checkpoint = VITH_CHECKPOINT
            if self.model_checkpoint is None:
                raise FileNotFoundError("VITH_CHECKPOINT is not set: please set it to the path to the SAM checkpoint")
            if ONNX_CHECKPOINT is None:
                raise FileNotFoundError("ONNX_CHECKPOINT is not set: please set it to the path to the ONNX checkpoint")
            logger.info(f"Using ONNX checkpoint {ONNX_CHECKPOINT} and SAM checkpoint {self.model_checkpoint}")

            self.ort = onnxruntime.InferenceSession(ONNX_CHECKPOINT)
            reg_key = "vit_h"

        elif model_choice == 'SAM':
            from segment_anything import SamPredictor, sam_model_registry

            self.model_checkpoint = VITH_CHECKPOINT
            if self.model_checkpoint is None:
                raise FileNotFoundError("VITH_CHECKPOINT is not set: please set it to the path to the SAM checkpoint")

            logger.info(f"Using SAM checkpoint {self.model_checkpoint}")
            reg_key = "vit_h"

        elif model_choice == 'MobileSAM':
            from mobile_sam import SamPredictor, sam_model_registry

            self.model_checkpoint = MOBILESAM_CHECKPOINT
            if not self.model_checkpoint:
                raise FileNotFoundError("MOBILESAM_CHECKPOINT is not set: please set it to the path to the MobileSAM checkpoint")
            logger.info(f"Using MobileSAM checkpoint {self.model_checkpoint}")
            reg_key = 'vit_t'
        else:
            raise ValueError(f"Invalid model choice {model_choice}")

        # build the selected SAM variant from the model registry and move it to the target device
        sam = sam_model_registry[reg_key](checkpoint=self.model_checkpoint)
        sam.to(device=self.device)
        self.predictor = SamPredictor(sam)

    @property
    def model_name(self):
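        """Identifier combining backend choice, checkpoint path, and device."""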
        return f'{self.model_choice}:{self.model_checkpoint}:{self.device}'

    def set_image(self, img_path, calculate_embeddings=True):
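        """Load an image, feed it to the predictor, and cache the result.

        `predictor.set_image()` computes the image embedding and makes the image
        the predictor's active one. The cache stores the image shape and, when
        `calculate_embeddings=True` (required for the ONNX path), the embedding
        itself, so repeated prompts on the same image skip this work.
        """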
        payload = self.cache.get(img_path)
        if payload is None:
            # Get image and embeddings
            logger.debug(f'Payload not found for {img_path} in `IN_MEM_CACHE`: calculating from scratch')
            image_path = get_image_local_path(
                img_path,
                label_studio_access_token=LABEL_STUDIO_ACCESS_TOKEN,
                label_studio_host=LABEL_STUDIO_HOST
            )
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"cv2.imread failed to load image from {image_path}")
            # OpenCV reads BGR; SAM expects RGB
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            self.predictor.set_image(image)
            payload = {'image_shape': image.shape[:2]}
            logger.debug(f'Finished set_image({img_path}) in `IN_MEM_CACHE`: image shape {image.shape[:2]}')
            if calculate_embeddings:
                image_embedding = self.predictor.get_image_embedding().cpu().numpy()
                payload['image_embedding'] = image_embedding
                logger.debug(f'Finished storing embeddings for {img_path} in `IN_MEM_CACHE`: '
                             f'embedding shape {image_embedding.shape}')
            self.cache.put(img_path, payload)
        else:
            logger.debug(f"Using embeddings for {img_path} from `IN_MEM_CACHE`")
        return payload

    def predict_onnx(
        self,
        img_path,
        point_coords: Optional[List[List[float]]] = None,
        point_labels: Optional[List[int]] = None,
        input_box: Optional[List[float]] = None
    ):
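        """Predict a mask using the exported ONNX decoder.

        The decoder consumes precomputed image embeddings plus encoded prompts,
        so embeddings are always computed (and cached) via `set_image` first.
        """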
        # calculate embeddings
        payload = self.set_image(img_path, calculate_embeddings=True)
        image_shape = payload['image_shape']
        image_embedding = payload['image_embedding']

        onnx_point_coords = np.array(point_coords, dtype=np.float32) if point_coords else None
        onnx_point_labels = np.array(point_labels, dtype=np.float32) if point_labels else None
        onnx_box_coords = np.array(input_box, dtype=np.float32).reshape(2, 2) if input_box else None

        onnx_coords, onnx_labels = None, None
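        # The SAM ONNX decoder expects all prompts packed into a single point list:
        # clicks keep their 1 (foreground) / 0 (background) labels, a box is encoded
        # as two corner points with labels 2 and 3, and when no box is present a
        # padding point (0, 0) with label -1 must be appended.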
        if onnx_point_coords is not None and onnx_box_coords is not None:
            # both keypoints and boxes are present
            onnx_coords = np.concatenate([onnx_point_coords, onnx_box_coords], axis=0)[None, :, :]
            onnx_labels = np.concatenate([onnx_point_labels, np.array([2, 3])], axis=0)[None, :].astype(np.float32)

        elif onnx_point_coords is not None:
            # only keypoints are present
            onnx_coords = np.concatenate([onnx_point_coords, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
            onnx_labels = np.concatenate([onnx_point_labels, np.array([-1])], axis=0)[None, :].astype(np.float32)

        elif onnx_box_coords is not None:
            # only boxes are present
            raise NotImplementedError("Boxes without keypoints are not supported yet")
        else:
            raise ValueError("Either point_coords or input_box must be provided")

        # map prompt coordinates into the resized image space expected by the model
        onnx_coords = self.predictor.transform.apply_coords(onnx_coords, image_shape).astype(np.float32)

        # TODO: support mask inputs
        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)

        # a zero flag tells the decoder that mask_input contains no real mask
        onnx_has_mask_input = np.zeros(1, dtype=np.float32)

        ort_inputs = {
            "image_embeddings": image_embedding,
            "point_coords": onnx_coords,
            "point_labels": onnx_labels,
            "mask_input": onnx_mask_input,
            "has_mask_input": onnx_has_mask_input,
            "orig_im_size": np.array(image_shape, dtype=np.float32)
        }

        masks, prob, low_res_logits = self.ort.run(None, ort_inputs)
        masks = masks > self.predictor.model.mask_threshold
        mask = masks[0, 0, :, :].astype(np.uint8)  # each mask has shape [H, W]
        prob = float(prob[0][0])
        # TODO: support the real multimask output as in https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb
        return {
            'masks': [mask],
            'probs': [prob]
        }

    def predict_sam(
        self,
        img_path,
        point_coords: Optional[List[List[float]]] = None,
        point_labels: Optional[List[int]] = None,
        input_box: Optional[List[float]] = None
    ):
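        """Predict a mask with the native (PyTorch) SAM/MobileSAM predictor.

        The image embedding is held inside the predictor after `set_image`,
        so it is not cached separately here.
        """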
        self.set_image(img_path, calculate_embeddings=False)
        point_coords = np.array(point_coords, dtype=np.float32) if point_coords else None
        point_labels = np.array(point_labels, dtype=np.float32) if point_labels else None
        input_box = np.array(input_box, dtype=np.float32) if input_box else None

        masks, probs, logits = self.predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            box=input_box,
            # TODO: support multimask output
            multimask_output=False
        )
        mask = masks[0, :, :].astype(np.uint8)  # each mask has shape [H, W]
        prob = float(probs[0])
        return {
            'masks': [mask],
            'probs': [prob]
        }

    def predict(
        self, img_path: str,
        point_coords: Optional[List[List[float]]] = None,
        point_labels: Optional[List[int]] = None,
        input_box: Optional[List[float]] = None
    ):
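        """Dispatch to the ONNX or native prediction path based on `model_choice`."""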
        if self.model_choice == 'ONNX':
            return self.predict_onnx(img_path, point_coords, point_labels, input_box)
        elif self.model_choice in ('SAM', 'MobileSAM'):
            return self.predict_sam(img_path, point_coords, point_labels, input_box)
        else:
            raise NotImplementedError(f"Model choice {self.model_choice} is not supported yet")