import streamlit as st
import numpy as np
import plotly.express as px
import cv2
from src.error_analysis import ErrorAnalysis, transform_gt_bbox_format
import yaml
import os
from src.confusion_matrix import ConfusionMatrix
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd


def amend_cm_df(cm_df, labels_dict):
    """Relabel a confusion-matrix dataframe with readable class names.

    Row labels become ``"GT - <class>"`` and column labels become
    ``"Pred - <class>"``; an extra ``"background"`` entry is appended after
    the classes taken from ``labels_dict``. All values are cast to ``int``.

    Args:
        cm_df (pd.DataFrame): confusion matrix dataframe with one row and one
            column per class plus one for background.
        labels_dict (dict): class labels; the values are used, in dictionary
            order, as the class names.

    Returns:
        pd.DataFrame: a relabelled, integer-typed copy of the confusion matrix.
    """
    class_names = [*labels_dict.values(), "background"]

    row_labels = [f"GT - {name}" for name in class_names]
    col_labels = [f"Pred - {name}" for name in class_names]

    return cm_df.set_axis(row_labels).set_axis(col_labels, axis=1).astype(int)

def find_top_left_pos(mask):
    """Locate the first maximum of ``mask`` in row-major order.

    For a binary mask this is the left-most set pixel of the top-most row
    that contains any set pixel. A mask whose values are all equal yields the
    origin (all-zero index).

    Args:
        mask (np.ndarray): array to search.

    Returns:
        tuple: index of the first maximum, one entry per dimension of
            ``mask`` (``(row, col)`` for a 2-D mask).
    """
    # argmax over the flattened array returns the first occurrence of the max
    flat_index = np.argmax(mask, axis=None)

    return np.unravel_index(flat_index, mask.shape)


class ImageTool:
    """Render inference images with ground-truth and/or prediction overlays.

    The tool wraps an ``ErrorAnalysis`` object (model and annotations), draws
    bounding boxes or masks with OpenCV, builds a per-image confusion matrix
    and pushes the results onto a Streamlit page.
    """

    def __init__(self, cfg_path="cfg/cfg.yml"):
        """Load the config, build the error-analysis object and cache settings.

        Args:
            cfg_path (str, optional): Path to the YAML config file.
                Defaults to "cfg/cfg.yml".
        """

        # getting the config object; the context manager closes the handle
        # NOTE(review): FullLoader is kept for compatibility. Switch to
        # yaml.safe_load if the config could ever come from an untrusted source.
        with open(cfg_path) as cfg_file:
            self.cfg_obj = yaml.load(cfg_file, Loader=yaml.FullLoader)

        # initialising the model and getting the annotations
        self.ea_obj = ErrorAnalysis(cfg_path)
        self.inference_folder = self.ea_obj.inference_folder
        self.ea_obj.get_annots()
        self.gt_annots = self.ea_obj.gt_dict
        self.all_img = os.listdir(self.inference_folder)

        visual_cfg = self.cfg_obj["visual_tool"]
        ea_cfg = self.cfg_obj["error_analysis"]

        self.ea_obj.model.score_threshold = visual_cfg["conf_threshold"]
        self.ea_obj.model.iou_threshold = visual_cfg["iou_threshold"]

        # for labels: the config maps one way, the lookups here need the
        # inverse (the config's values become the keys)
        self.labels_dict = {v: k for k, v in ea_cfg["labels_dict"].items()}
        self.inference_labels_dict = {
            v: k for k, v in ea_cfg["inference_labels_dict"].items()
        }
        self.idx_base = ea_cfg["idx_base"]

        # for visualisation
        self.bbox_thickness = visual_cfg["bbox_thickness"]
        self.font_scale = visual_cfg["font_scale"]
        self.font_thickness = visual_cfg["font_thickness"]
        self.pred_colour = tuple(visual_cfg["pred_colour"])
        self.gt_colour = tuple(visual_cfg["gt_colour"])

    def _read_image(self, img_fname):
        """Read an image from the inference folder, failing loudly if missing.

        Args:
            img_fname (str): Filename of the image inside the inference folder.

        Returns:
            img (array): image as loaded by ``cv2.imread``.

        Raises:
            FileNotFoundError: if OpenCV could not read the file.
        """
        img_path = os.path.join(self.inference_folder, img_fname)
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        return img

    def _to_zero_based(self, preds, gt_annots, idx_base):
        """Shift class ids by ``idx_base`` without touching the caller's data.

        For "det" both inputs are copied and the class column is shifted (GT
        column 0, prediction last column). For "seg" only the GT labels are
        shifted, matching what the confusion matrix was fed before.

        Args:
            preds (array): inference output of the model on the image
            gt_annots (array): Ground Truth labels of the image
            idx_base (int): value subtracted from the class ids

        Returns:
            tuple: (preds, gt_annots) ready for ``ConfusionMatrix.process_batch``
        """
        task = self.ea_obj.task
        if task == "det":
            # copy first: the in-place subtraction must not leak to the caller
            preds = preds.copy()
            gt_annots = gt_annots.copy()
            gt_annots[:, 0] -= idx_base
            preds[:, -1] -= idx_base
        elif task == "seg":
            # NOTE(review): prediction labels are not shifted for "seg" -
            # confirm this is intended.
            gt_annots = [[gt[0] - idx_base, gt[1]] for gt in gt_annots]
        return preds, gt_annots

    def _blend_masks(self, img, masks, colour):
        """Blend the union of ``masks`` onto ``img`` with 30% opacity.

        Args:
            img (array): image to draw on
            masks (list): masks to merge
            colour (tuple): fill colour for the masked pixels

        Returns:
            array: image with the semi-transparent overlay
        """
        merged = sum(masks)
        # overlapping masks add up beyond 1; clip so the union stays binary
        merged = np.where(merged > 1, 1, merged)
        fill = np.array(colour, dtype="uint8")
        painted = np.where(merged[..., None], fill, img).astype(np.uint8)
        return cv2.addWeighted(img, 0.7, painted, 0.3, 0)

    def _put_mask_label(self, img, mask, text, colour):
        """Write ``text`` next to the first set pixel of ``mask``.

        Args:
            img (array): image to draw on
            mask (array): mask used to position the text
            text (str): label text
            colour (tuple): text colour

        Returns:
            array: image with the label drawn
        """
        coords = find_top_left_pos(mask)
        # coords is (row, col) while OpenCV expects (x, y); plain ints keep
        # OpenCV's argument parsing happy
        origin = (int(coords[1]), int(coords[0]) + 5)
        return cv2.putText(
            img,
            text,
            origin,
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=self.font_scale,
            color=colour,
            thickness=self.font_thickness,
        )

    def show_img(self, img_fname="000000011149.jpg", show_preds=False, show_gt=False):
        """generate img with option to overlay with GT and/or preds

        Args:
            img_fname (str, optional): Filename of the image. Defaults to "000000011149.jpg".
            show_preds (bool, optional): Toggle True to run model to get the preds. Defaults to False.
            show_gt (bool, optional): Toggle True to get the GT labels/boxes. Defaults to False.

        Returns:
            When both ``show_preds`` and ``show_gt`` are True, a list of
                fig (Plotly Figure): image with both overlays
                cm_df (pd.DataFrame): confusion matrix of the pred versus GT
                cm_tpfpfn_df (pd.DataFrame): tp/fp/fn counts
            otherwise only ``fig``.
        """

        img = self._read_image(img_fname)

        labels = {"x": "X", "y": "Y", "color": "Colour"}
        task = self.ea_obj.task
        preds = None
        gt_annots = None

        if show_preds:
            preds = self.get_preds(img_fname)
            if task == "det":
                img = self.draw_pred_bboxes(img, preds)
            elif task == "seg":
                img = self.draw_pred_masks(img, preds)

        if show_gt:
            gt_annots = self.get_gt_annot(img_fname)
            if task == "det":
                # the GT annotations (not the preds) go to the GT drawer
                img = self.draw_gt_bboxes(img, gt_annots)
            elif task == "seg":
                img = self.draw_gt_masks(img, gt_annots)

        # OpenCV images are BGR; reverse the channel axis for plotly
        fig = px.imshow(img[..., ::-1], aspect="equal", labels=labels)

        if show_gt and show_preds:
            cm_df, cm_tpfpfn_df = self.generate_cm_one_image(preds, gt_annots)
            return [fig, cm_df, cm_tpfpfn_df]

        return fig

    def show_img_sbs(self, img_fname="000000011149.jpg"):
        """generate two images with confusion matrix and tp/fp/fn. fig1 is the image with GT overlay, while fig2 is the image with pred overlay.

        Args:
            img_fname (str, optional): Filename of the image. Defaults to "000000011149.jpg".

        Returns:
            list: fig1 - imshow of image with GT overlay
                  fig2 - imshow of image with pred overlay
                  cm_df - confusion matrix dataframe
                  cm_tpfpfn_df - dataframe of tp/fp/fn
        """

        img = self._read_image(img_fname)
        labels = {"x": "X", "y": "Y", "color": "Colour"}

        # each overlay is drawn on its own copy of the image
        img_pred = img.copy()
        img_gt = img.copy()

        preds = self.get_preds(img_fname)
        gt_annots = self.get_gt_annot(img_fname)

        task = self.ea_obj.task
        if task == "det":
            img_pred = self.draw_pred_bboxes(img_pred, preds)
            img_gt = self.draw_gt_bboxes(img_gt, gt_annots)
        elif task == "seg":
            img_pred = self.draw_pred_masks(img_pred, preds)
            img_gt = self.draw_gt_masks(img_gt, gt_annots)

        fig1 = px.imshow(img_gt[..., ::-1], aspect="equal", labels=labels)
        fig2 = px.imshow(img_pred[..., ::-1], aspect="equal", labels=labels)
        # the two figures sit side by side, so the right one drops its y axis
        fig2.update_yaxes(visible=False)

        cm_df, cm_tpfpfn_df = self.generate_cm_one_image(preds, gt_annots)

        return [fig1, fig2, cm_df, cm_tpfpfn_df]

    def generate_cm_one_image(self, preds, gt_annots):
        """Generates confusion matrix between the inference and the Ground Truth of an image

        The inputs are not modified: class ids are shifted on copies.

        Args:
            preds (array): inference output of the model on the image
            gt_annots (array): Ground Truth labels of the image

        Returns:
            cm_df (DataFrame): Confusion matrix dataframe.
            cm_tpfpfn_df (DataFrame): TP/FP/FN dataframe
        """

        ea_cfg = self.cfg_obj["error_analysis"]
        num_classes = len(ea_cfg["labels_dict"])
        idx_base = ea_cfg["idx_base"]

        model = self.ea_obj.model
        cm = ConfusionMatrix(
            num_classes=num_classes,
            CONF_THRESHOLD=model.score_threshold,
            IOU_THRESHOLD=model.iou_threshold,
        )

        preds, gt_annots = self._to_zero_based(preds, gt_annots, idx_base)

        cm.process_batch(preds, gt_annots, task=self.ea_obj.task)

        confusion_matrix_df = cm.return_as_df()
        cm.get_tpfpfn()

        cm_tpfpfn_dict = {
            "True Positive": cm.tp,
            "False Positive": cm.fp,
            "False Negative": cm.fn,
        }

        cm_tpfpfn_df = pd.DataFrame(cm_tpfpfn_dict, index=["Values"]).astype(int)

        # amend df so rows/columns carry readable class names
        confusion_matrix_df = amend_cm_df(confusion_matrix_df, self.labels_dict)

        return confusion_matrix_df, cm_tpfpfn_df

    def get_preds(self, img_fname="000000011149.jpg"):
        """Using the model in the Error Analysis object, run inference to get outputs

        Args:
            img_fname (str): Image filename. Defaults to "000000011149.jpg".

        Returns:
            outputs (array): Inference output of the model on the image. For
                the "det" task the first four columns are scaled by the image
                width/height in place.
        """

        # run inference using the error analysis object per image
        outputs, img_shape = self.ea_obj.generate_inference(img_fname)

        if self.ea_obj.task == "det":
            # converts image coordinates from normalised to pixel values
            # image shape is [Y, X, C] (because Rows are Y)
            height, width = img_shape[0], img_shape[1]
            outputs[:, 0] *= width
            outputs[:, 1] *= height
            outputs[:, 2] *= width
            outputs[:, 3] *= height

        return outputs

    def get_gt_annot(self, img_fname):
        """Retrieve the Ground Truth annotations of the image.

        Args:
            img_fname (str): Image filename

        Returns:
            ground_truth (array): GT labels of the image. For the "det" task
                the boxes are converted with ``transform_gt_bbox_format`` and
                columns 1-4 are scaled by the image width/height.
        """
        # copy so the cached annotations are never modified
        ground_truth = self.gt_annots[img_fname].copy()

        if self.ea_obj.task == "det":
            # the image is only needed for its shape, so it is read only here
            # image shape is [Y, X, C] (because Rows are Y)
            img_shape = self._read_image(img_fname).shape
            ground_truth = transform_gt_bbox_format(
                ground_truth, img_shape, format="coco"
            )
            ground_truth[:, 1] *= img_shape[1]
            ground_truth[:, 2] *= img_shape[0]
            ground_truth[:, 3] *= img_shape[1]
            ground_truth[:, 4] *= img_shape[0]

        return ground_truth

    def draw_pred_masks(self, img_pred, inference_outputs):
        """Overlay the predicted masks and their labels onto img_pred

        Args:
            img_pred (array): image
            inference_outputs (list): model outputs; for each entry, index 0
                is used as the mask and index 2 as the label id (looked up in
                ``inference_labels_dict``)

        Returns:
            img_pred (array): image with mask overlay and label text
        """

        img_pred = self._blend_masks(
            img_pred, [output[0] for output in inference_outputs], self.pred_colour
        )

        for output in inference_outputs:
            img_pred = self._put_mask_label(
                img_pred,
                output[0],
                self.inference_labels_dict[output[2]],
                self.pred_colour,
            )

        return img_pred

    def draw_gt_masks(self, img_gt, gt_outputs):
        """Overlay the ground-truth masks and their labels onto img_gt

        Args:
            img_gt (array): image
            gt_outputs (list): GT annotations; for each entry, index 0 is used
                as the label id (looked up in ``labels_dict``) and index 1 as
                the mask

        Returns:
            img_gt (array): image with mask overlay and label text
        """

        img_gt = self._blend_masks(
            img_gt, [output[1] for output in gt_outputs], self.gt_colour
        )

        for output in gt_outputs:
            img_gt = self._put_mask_label(
                img_gt, output[1], self.labels_dict[output[0]], self.gt_colour
            )

        return img_gt

    def draw_pred_bboxes(self, img_pred, preds):
        """Draws the preds onto the image

        Args:
            img_pred (array): image
            preds (array): model inference outputs; columns 0-3 are used as
                the box corners and column 5 as the class id

        Returns:
            img_pred (array): image with outputs on overlay
        """
        for pred in preds:
            # plain Python ints for OpenCV's point arguments
            x1, y1, x2, y2 = pred[:4].astype(int).tolist()
            class_id = int(pred[5])
            img_pred = cv2.rectangle(
                img_pred,
                (x1, y1),
                (x2, y2),
                color=self.pred_colour,
                thickness=self.bbox_thickness,
            )
            # NOTE(review): the mask variant looks labels up in
            # inference_labels_dict; confirm labels_dict is right here.
            img_pred = cv2.putText(
                img_pred,
                self.labels_dict[class_id],
                (x1 + 5, y1 + 25),
                color=self.pred_colour,
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=self.font_scale,
                thickness=self.font_thickness,
            )
        return img_pred

    def draw_gt_bboxes(self, img_gt, gt_annots, **kwargs):
        """Draws the GT onto the image

        Args:
            img_gt (array): image
            gt_annots (array): GT labels; column 0 is used as the class id and
                columns 1-4 as the box corners
            **kwargs: accepted for backward compatibility and ignored

        Returns:
            img_gt (array): image with GT overlay
        """
        for annot in gt_annots:
            class_id, x1, y1, x2, y2 = annot[:5].astype(int).tolist()
            img_gt = cv2.rectangle(
                img_gt,
                (x1, y1),
                (x2, y2),
                color=self.gt_colour,
                thickness=self.bbox_thickness,
            )
            # the label uses the configured GT colour, like the box itself
            img_gt = cv2.putText(
                img_gt,
                self.labels_dict[class_id],
                (x1 + 5, y1 + 25),
                color=self.gt_colour,
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=self.font_scale,
                thickness=self.font_thickness,
            )
        return img_gt

    def plot_with_preds_gt(self, option, side_by_side=False, plot_type=None):
        """Rules on what plot to generate

        Args:
            option (str): image filename. Toggled on the app itself. See app.py
            side_by_side (bool, optional): Whether to have two plots side by side.
                                Only takes effect when plot_type is not one of
                                "all", "pred" or "gt". Defaults to False.
            plot_type (str, optional): "all" - both GT and pred will be plotted,
                                "pred" - only preds,
                                "gt" - only ground truth
                                None - only image generated (or the
                                side-by-side view if side_by_side = True)
                                Defaults to None.
        """

        if plot_type == "all":
            plot, df, cm_tpfpfn_df = self.show_img(
                option, show_preds=True, show_gt=True
            )
            st.plotly_chart(plot, use_container_width=True)
            st.caption("Blue: Model BBox, Green: GT BBox")

            st.table(df)
            st.table(cm_tpfpfn_df)

        elif plot_type == "pred":
            st.plotly_chart(
                self.show_img(option, show_preds=True), use_container_width=True
            )

        elif plot_type == "gt":
            st.plotly_chart(
                self.show_img(option, show_gt=True), use_container_width=True
            )

        elif side_by_side:

            plot1, plot2, df, cm_tpfpfn_df = self.show_img_sbs(option)
            col1, col2 = st.columns(2)

            with col1:
                col1.subheader("Ground Truth")
                st.plotly_chart(plot1, use_container_width=True)
            with col2:
                col2.subheader("Prediction")
                st.plotly_chart(plot2, use_container_width=True)

            st.table(df)
            st.table(cm_tpfpfn_df)

        else:
            st.plotly_chart(self.show_img(option), use_container_width=True)