import os

import cv2
import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
import yaml

from src.confusion_matrix import ConfusionMatrix
from src.error_analysis import ErrorAnalysis, transform_gt_bbox_format


def amend_cm_df(cm_df, labels_dict):
    """Helper function to amend the index and column names for readability.

    For example, the index 0, 1, ... becomes "GT - person", ... and
    likewise the columns 0, 1, ... become "Pred - person", etc.

    Args:
        cm_df (pd.DataFrame): confusion matrix dataframe.
        labels_dict (dict): dictionary of the class labels.

    Returns:
        cm_df (pd.DataFrame): confusion matrix dataframe with index and
            column names filled in.
    """
    index_list = list(labels_dict.values())
    index_list.append("background")

    cm_df = cm_df.set_axis([f"GT - {elem}" for elem in index_list])
    cm_df = cm_df.set_axis([f"Pred - {elem}" for elem in index_list], axis=1)
    cm_df = cm_df.astype(int)

    return cm_df


def find_top_left_pos(mask):
    """Gets the top-left position of a binary mask.

    Args:
        mask (np.ndarray): 2-D binary mask.

    Returns:
        tuple: (row, col) of the first nonzero element in row-major order.
    """
    return np.unravel_index(np.argmax(mask, axis=None), mask.shape)
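
# A minimal sketch of the helper above, with hypothetical values: the
# drawing methods below use the returned (row, col) to anchor each mask's
# class label at its top-left corner.
#
#   mask = np.zeros((4, 4))
#   mask[1, 2] = 1
#   find_top_left_pos(mask)  # -> (1, 2)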


class ImageTool:
    def __init__(self, cfg_path="cfg/cfg.yml"):
        # load the config object
        with open(cfg_path) as cfg_file:
            self.cfg_obj = yaml.load(cfg_file, Loader=yaml.FullLoader)

        # initialise the model and get the annotations
        self.ea_obj = ErrorAnalysis(cfg_path)
        self.inference_folder = self.ea_obj.inference_folder
        self.ea_obj.get_annots()
        self.gt_annots = self.ea_obj.gt_dict
        self.all_img = os.listdir(self.inference_folder)
        self.ea_obj.model.score_threshold = self.cfg_obj["visual_tool"]["conf_threshold"]
        self.ea_obj.model.iou_threshold = self.cfg_obj["visual_tool"]["iou_threshold"]

        # for labels
        self.labels_dict = self.cfg_obj["error_analysis"]["labels_dict"]
        self.labels_dict = {v: k for k, v in self.labels_dict.items()}
        self.inference_labels_dict = self.cfg_obj["error_analysis"]["inference_labels_dict"]
        self.inference_labels_dict = {
            v: k for k, v in self.inference_labels_dict.items()
        }
        self.idx_base = self.cfg_obj["error_analysis"]["idx_base"]

        # for visualisation
        self.bbox_thickness = self.cfg_obj["visual_tool"]["bbox_thickness"]
        self.font_scale = self.cfg_obj["visual_tool"]["font_scale"]
        self.font_thickness = self.cfg_obj["visual_tool"]["font_thickness"]
        self.pred_colour = tuple(self.cfg_obj["visual_tool"]["pred_colour"])
        self.gt_colour = tuple(self.cfg_obj["visual_tool"]["gt_colour"])

    def show_img(self, img_fname="000000011149.jpg", show_preds=False, show_gt=False):
        """Generates the image, with the option to overlay it with GT and/or preds.

        Args:
            img_fname (str, optional): Filename of the image.
                Defaults to "000000011149.jpg".
            show_preds (bool, optional): Toggle True to run the model to get
                the preds. Defaults to False.
            show_gt (bool, optional): Toggle True to get the GT labels/boxes.
                Defaults to False.

        Returns:
            fig (Plotly Figure): image with overlays if toggled True.
            cm_df (pd.DataFrame): confusion matrix of the preds versus GT.
            cm_tpfpfn_dict (dict): confusion matrix dictionary of TP/FP/FN.
        """
        # get the image's file path by concatenating with the inference folder
        img = cv2.imread(f"{self.inference_folder}{img_fname}")
        labels = {"x": "X", "y": "Y", "color": "Colour"}

        if show_preds:
            preds = self.get_preds(img_fname)
            if self.ea_obj.task == "det":
                img = self.draw_pred_bboxes(img, preds)
            elif self.ea_obj.task == "seg":
                img = self.draw_pred_masks(img, preds)

        if show_gt:
            gt_annots = self.get_gt_annot(img_fname)
            if self.ea_obj.task == "det":
                # draw the GT boxes here, not the preds
                img = self.draw_gt_bboxes(img, gt_annots)
            elif self.ea_obj.task == "seg":
                img = self.draw_gt_masks(img, gt_annots)

        fig = px.imshow(img[..., ::-1], aspect="equal", labels=labels)

        if show_gt and show_preds:
            cm_df, cm_tpfpfn_dict = self.generate_cm_one_image(preds, gt_annots)
            return [fig, cm_df, cm_tpfpfn_dict]

        return fig

    def show_img_sbs(self, img_fname="000000011149.jpg"):
        """Generates two images with confusion matrix and TP/FP/FN.
        fig1 is the image with the GT overlay, while fig2 is the image with
        the pred overlay.

        Args:
            img_fname (str, optional): Filename of the image.
                Defaults to "000000011149.jpg".

        Returns:
            list: fig1 - imshow of image with GT overlay
                  fig2 - imshow of image with pred overlay
                  cm_df - confusion matrix dataframe
                  cm_tpfpfn_df - confusion matrix dictionary of TP/FP/FN
        """
        # shows the image side by side
        img = cv2.imread(f"{self.inference_folder}{img_fname}")
        labels = {"x": "X", "y": "Y", "color": "Colour"}
        img_pred = img.copy()
        img_gt = img.copy()

        preds = self.get_preds(img_fname)
        gt_annots = self.get_gt_annot(img_fname)

        if self.ea_obj.task == "det":
            img_pred = self.draw_pred_bboxes(img_pred, preds)
            img_gt = self.draw_gt_bboxes(img_gt, gt_annots)
        elif self.ea_obj.task == "seg":
            img_pred = self.draw_pred_masks(img_pred, preds)
            img_gt = self.draw_gt_masks(img_gt, gt_annots)

        fig1 = px.imshow(img_gt[..., ::-1], aspect="equal", labels=labels)
        fig2 = px.imshow(img_pred[..., ::-1], aspect="equal", labels=labels)
        fig2.update_yaxes(visible=False)

        cm_df, cm_tpfpfn_df = self.generate_cm_one_image(preds, gt_annots)

        return [fig1, fig2, cm_df, cm_tpfpfn_df]

    def generate_cm_one_image(self, preds, gt_annots):
        """Generates the confusion matrix between the inference output and
        the Ground Truth of an image.

        Args:
            preds (array): inference output of the model on the image.
            gt_annots (array): Ground Truth labels of the image.

        Returns:
            cm_df (pd.DataFrame): Confusion matrix dataframe.
            cm_tpfpfn_df (pd.DataFrame): TP/FP/FN dataframe.
        """
        num_classes = len(list(self.cfg_obj["error_analysis"]["labels_dict"].keys()))
        idx_base = self.cfg_obj["error_analysis"]["idx_base"]
        conf_threshold, iou_threshold = (
            self.ea_obj.model.score_threshold,
            self.ea_obj.model.iou_threshold,
        )
        cm = ConfusionMatrix(
            num_classes=num_classes,
            CONF_THRESHOLD=conf_threshold,
            IOU_THRESHOLD=iou_threshold,
        )

        # shift the class indices down to a zero base before matching
        if self.ea_obj.task == "det":
            gt_annots[:, 0] -= idx_base
            preds[:, -1] -= idx_base
        elif self.ea_obj.task == "seg":
            gt_annots = [[gt[0] - idx_base, gt[1]] for gt in gt_annots]

        cm.process_batch(preds, gt_annots, task=self.ea_obj.task)
        confusion_matrix_df = cm.return_as_df()
        cm.get_tpfpfn()

        cm_tpfpfn_dict = {
            "True Positive": cm.tp,
            "False Positive": cm.fp,
            "False Negative": cm.fn,
        }
        cm_tpfpfn_df = pd.DataFrame(cm_tpfpfn_dict, index=[0])
        cm_tpfpfn_df = cm_tpfpfn_df.set_axis(["Values"], axis=0)
        cm_tpfpfn_df = cm_tpfpfn_df.astype(int)

        # amend the df's index and column names for readability
        confusion_matrix_df = amend_cm_df(confusion_matrix_df, self.labels_dict)

        return confusion_matrix_df, cm_tpfpfn_df
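
    # NOTE: the layouts below are inferred from how this module indexes its
    # arrays, and are an assumption about the upstream ErrorAnalysis outputs.
    # For the "det" task, each pred row reads [x1, y1, x2, y2, score, class]
    # and each GT row reads [class, x1, y1, x2, y2]; for the "seg" task,
    # each pred entry carries its mask at index 0 and its class at index 2,
    # while each GT entry is a (class, mask) pair.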
    def get_preds(self, img_fname="000000011149.jpg"):
        """Runs inference on the image using the model in the Error Analysis
        object.

        Args:
            img_fname (str): Image filename. Defaults to "000000011149.jpg".

        Returns:
            outputs (array): Inference output of the model on the image.
        """
        # run inference using the error analysis object per image
        outputs, img_shape = self.ea_obj.generate_inference(img_fname)

        if self.ea_obj.task == "det":
            # converts bbox coordinates from normalised to pixel values.
            # image shape is [Y, X, C] (because rows are Y), so X coordinates
            # scale by img_shape[1] and Y coordinates by img_shape[0]
            outputs[:, 0] *= img_shape[1]
            outputs[:, 1] *= img_shape[0]
            outputs[:, 2] *= img_shape[1]
            outputs[:, 3] *= img_shape[0]

        return outputs

    def get_gt_annot(self, img_fname):
        """Retrieves the Ground Truth annotations of the image.

        Args:
            img_fname (str): Image filename.

        Returns:
            ground_truth (array): GT labels of the image.
        """
        ground_truth = self.gt_annots[img_fname].copy()
        img = cv2.imread(f"{self.inference_folder}{img_fname}")

        # converts bbox coordinates from normalised to pixel values.
        # image shape is [Y, X, C] (because rows are Y)
        if self.ea_obj.task == "det":
            img_shape = img.shape
            ground_truth = transform_gt_bbox_format(ground_truth, img_shape, format="coco")
            ground_truth[:, 1] *= img_shape[1]
            ground_truth[:, 2] *= img_shape[0]
            ground_truth[:, 3] *= img_shape[1]
            ground_truth[:, 4] *= img_shape[0]

        return ground_truth

    def draw_pred_masks(self, img_pred, inference_outputs):
        """Overlays the predicted masks onto the image.

        Args:
            img_pred (array): image.
            inference_outputs (array): model inference outputs, each carrying
                a mask.

        Returns:
            img_pred (array): image with the predicted masks on overlay.
        """
        # combine all predicted masks into a single binary mask
        pred_mask = sum([output[0] for output in inference_outputs])
        pred_mask = np.where(pred_mask > 1, 1, pred_mask)

        # tint the masked region with the pred colour, then blend it in
        colour = np.array(self.pred_colour, dtype="uint8")
        masked_img = np.where(pred_mask[..., None], colour, img_pred)
        masked_img = masked_img.astype(np.uint8)
        img_pred = cv2.addWeighted(img_pred, 0.7, masked_img, 0.3, 0)

        def put_text_ina_mask(output, img):
            # label each mask at its top-left corner
            coords = find_top_left_pos(output[0])
            img = cv2.putText(
                img,
                self.inference_labels_dict[output[2]],
                (coords[1], coords[0] + 5),
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=self.font_scale,
                color=self.pred_colour,
                thickness=self.font_thickness,
            )
            return img

        for output in inference_outputs:
            img_pred = put_text_ina_mask(output, img_pred)

        return img_pred

    def draw_gt_masks(self, img_gt, gt_outputs):
        """Overlays the GT masks onto the image.

        Args:
            img_gt (array): image.
            gt_outputs (array): GT labels, each carrying a mask.

        Returns:
            img_gt (array): image with the GT masks on overlay.
        """
        # combine all GT masks into a single binary mask
        gt_mask = sum([output[1] for output in gt_outputs])
        gt_mask = np.where(gt_mask > 1, 1, gt_mask)

        # tint the masked region with the GT colour, then blend it in
        colour = np.array(self.gt_colour, dtype="uint8")
        masked_img = np.where(gt_mask[..., None], colour, img_gt)
        masked_img = masked_img.astype(np.uint8)
        img_gt = cv2.addWeighted(img_gt, 0.7, masked_img, 0.3, 0)

        def put_text_ina_mask(output, img):
            # label each mask at its top-left corner
            coords = find_top_left_pos(output[1])
            img = cv2.putText(
                img,
                self.labels_dict[output[0]],
                (coords[1], coords[0] + 5),
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=self.font_scale,
                color=self.gt_colour,
                thickness=self.font_thickness,
            )
            return img

        for output in gt_outputs:
            img_gt = put_text_ina_mask(output, img_gt)

        return img_gt

    def draw_pred_bboxes(self, img_pred, preds):
        """Draws the pred bboxes onto the image.

        Args:
            img_pred (array): image.
            preds (array): model inference outputs.

        Returns:
            img_pred (array): image with the pred bboxes on overlay.
        """
        for pred in preds:
            pred = pred.astype(int)
            img_pred = cv2.rectangle(
                img_pred,
                (pred[0], pred[1]),
                (pred[2], pred[3]),
                color=self.pred_colour,
                thickness=self.bbox_thickness,
            )
            img_pred = cv2.putText(
                img_pred,
                self.labels_dict[pred[5]],
                (pred[0] + 5, pred[1] + 25),
                color=self.pred_colour,
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=self.font_scale,
                thickness=self.font_thickness,
            )

        return img_pred

    def draw_gt_bboxes(self, img_gt, gt_annots, **kwargs):
        """Draws the GT bboxes onto the image.

        Args:
            img_gt (array): image.
            gt_annots (array): GT labels.

        Returns:
            img_gt (array): image with the GT bboxes on overlay.
        """
        for annot in gt_annots:
            annot = annot.astype(int)
            img_gt = cv2.rectangle(
                img_gt,
                (annot[1], annot[2]),
                (annot[3], annot[4]),
                color=self.gt_colour,
                thickness=self.bbox_thickness,
            )
            img_gt = cv2.putText(
                img_gt,
                self.labels_dict[annot[0]],
                (annot[1] + 5, annot[2] + 25),
                color=self.gt_colour,
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=self.font_scale,
                thickness=self.font_thickness,
            )

        return img_gt

    def plot_with_preds_gt(self, option, side_by_side=False, plot_type=None):
        """Decides which plot(s) to generate.

        Args:
            option (str): image filename. Toggled on the app itself. See app.py.
            side_by_side (bool, optional): Whether to have two plots side by
                side. Defaults to False.
            plot_type (str, optional):
                "all" - both GT and preds will be plotted,
                "pred" - only preds,
                "gt" - only ground truth,
                None - only the image is generated.
                Will be overridden if side_by_side is True.
                Defaults to None.
        """
        if plot_type == "all":
            plot, df, cm_tpfpfn_df = self.show_img(
                option, show_preds=True, show_gt=True
            )
            st.plotly_chart(plot, use_container_width=True)
            st.caption("Blue: Model BBox, Green: GT BBox")
            st.table(df)
            st.table(cm_tpfpfn_df)
        elif plot_type == "pred":
            st.plotly_chart(
                self.show_img(option, show_preds=True), use_container_width=True
            )
        elif plot_type == "gt":
            st.plotly_chart(
                self.show_img(option, show_gt=True), use_container_width=True
            )
        elif side_by_side:
            plot1, plot2, df, cm_tpfpfn_df = self.show_img_sbs(option)
            col1, col2 = st.columns(2)

            with col1:
                col1.subheader("Ground Truth")
                st.plotly_chart(plot1, use_container_width=True)
            with col2:
                col2.subheader("Prediction")
                st.plotly_chart(plot2, use_container_width=True)

            st.table(df)
            st.table(cm_tpfpfn_df)
        else:
            st.plotly_chart(self.show_img(option), use_container_width=True)
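

# A minimal usage sketch, not part of the app itself: app.py is the intended
# entry point, but the class can be exercised directly like this, assuming
# cfg/cfg.yml and the inference folder it points to exist.
if __name__ == "__main__":
    tool = ImageTool()  # reads cfg/cfg.yml by default
    # pick any image found in the inference folder
    fname = tool.all_img[0]
    # render GT and preds side by side, with the per-image confusion matrix
    tool.plot_with_preds_gt(fname, side_by_side=True)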