error_analysis_obj_det / src /st_image_tools.py
tappyness1
initial commit
b78b0dc
raw
history blame
10.9 kB
import streamlit as st
import numpy as np
import plotly.express as px
import cv2
from src.error_analysis import ErrorAnalysis, transform_gt_bbox_format
import yaml
import os
from src.confusion_matrix import ConfusionMatrix
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd
def amend_cm_df(cm_df, labels_dict):
    """Relabel a confusion-matrix DataFrame with readable class names.

    The positional row labels (0, 1, ...) become "GT - <label>" and the
    positional column labels become "Pred - <label>". An extra "background"
    entry is appended after the labels taken from ``labels_dict``, so
    ``cm_df`` must have ``len(labels_dict) + 1`` rows and columns.

    Args:
        cm_df (pd.DataFrame): Square confusion matrix to relabel.
        labels_dict (dict): Mapping whose values are the class names, in the
            same order as the rows/columns of ``cm_df``.

    Returns:
        pd.DataFrame: A relabelled copy of ``cm_df`` with its values cast to
        int. The input frame is left unchanged.
    """
    class_names = [*labels_dict.values(), "background"]
    row_names = [f"GT - {name}" for name in class_names]
    col_names = [f"Pred - {name}" for name in class_names]
    relabelled = cm_df.set_axis(row_names, axis=0).set_axis(col_names, axis=1)
    return relabelled.astype(int)
class ImageTool:
    """Draw model predictions and ground truth on images for the Streamlit app.

    Wraps an ``ErrorAnalysis`` object (model inference plus ground-truth
    annotations) and the ``visual_tool`` section of the YAML config. Produces
    Plotly figures with bounding boxes drawn by OpenCV, together with a
    per-image confusion matrix.
    """

    def __init__(self, cfg_path="cfg/cfg.yml"):
        """Load the config, the error-analysis object and the annotations.

        Args:
            cfg_path (str, optional): Path to the YAML configuration file.
                Defaults to "cfg/cfg.yml".
        """
        # initialising the model and getting the annotations
        self.ea_obj = ErrorAnalysis(cfg_path)
        # The context manager guarantees the config file handle is closed.
        # NOTE(review): FullLoader can construct some Python objects; only
        # point this at trusted config files (or switch to yaml.safe_load).
        with open(cfg_path) as cfg_file:
            self.cfg_obj = yaml.load(cfg_file, Loader=yaml.FullLoader)
        self.inference_folder = self.ea_obj.inference_folder
        self.ea_obj.get_annots()
        self.gt_annots = self.ea_obj.gt_dict
        self.all_img = os.listdir(self.inference_folder)

        # for labels
        error_cfg = self.cfg_obj["error_analysis"]
        # The config mapping is inverted so that the class index read from a
        # prediction / annotation row can be used to look up the label text.
        self.labels_dict = {v: k for k, v in error_cfg["labels_dict"].items()}
        self.idx_base = error_cfg["idx_base"]

        # for visualisation
        visual_cfg = self.cfg_obj["visual_tool"]
        self.bbox_thickness = visual_cfg["bbox_thickness"]
        self.font_scale = visual_cfg["font_scale"]
        self.font_thickness = visual_cfg["font_thickness"]
        self.pred_colour = tuple(visual_cfg["pred_colour"])
        self.gt_colour = tuple(visual_cfg["gt_colour"])

    def _image_path(self, img_fname):
        """Return the path of ``img_fname`` inside the inference folder."""
        # os.path.join works whether or not the folder ends with a separator.
        return os.path.join(self.inference_folder, img_fname)

    def _read_image(self, img_fname):
        """Read an image from the inference folder.

        Raises:
            FileNotFoundError: If OpenCV could not read the file.
        """
        img_path = self._image_path(img_fname)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        return img

    @staticmethod
    def _shift_class_ids(boxes, class_col, idx_base):
        """Return a copy of ``boxes`` with ``idx_base`` subtracted from one column."""
        shifted = np.array(boxes, copy=True)
        shifted[:, class_col] -= idx_base
        return shifted

    def show_img(self, img_fname="000000011149.jpg", show_preds=False, show_gt=False):
        """Build a Plotly figure of one image, optionally with bounding boxes.

        Args:
            img_fname (str, optional): Image filename inside the inference
                folder. Defaults to "000000011149.jpg".
            show_preds (bool, optional): Draw the model predictions.
                Defaults to False.
            show_gt (bool, optional): Draw the ground-truth boxes.
                Defaults to False.

        Returns:
            The Plotly figure alone, or, when both ``show_preds`` and
            ``show_gt`` are True, the list ``[fig, cm_df, cm_tpfpfn_df]``
            where the last two items are the confusion-matrix DataFrame and
            the TP/FP/FN DataFrame from ``generate_cm_one_image``.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        img = self._read_image(img_fname)
        labels = {"x": "X", "y": "Y", "color": "Colour"}

        if show_preds:
            preds = self.get_preds(img_fname)
            img = self.draw_pred_bboxes(img, preds)
        if show_gt:
            gt_annots = self.get_gt_annot(img_fname)
            img = self.draw_gt_bboxes(img, gt_annots)

        # Reverse the channel axis: OpenCV loads BGR, Plotly expects RGB.
        fig = px.imshow(img[..., ::-1], aspect="equal", labels=labels)

        if show_gt and show_preds:
            cm_df, cm_tpfpfn_df = self.generate_cm_one_image(preds, gt_annots)
            return [fig, cm_df, cm_tpfpfn_df]
        return fig

    def show_img_sbs(self, img_fname="000000011149.jpg"):
        """Build separate ground-truth and prediction figures for one image.

        Args:
            img_fname (str, optional): Image filename inside the inference
                folder. Defaults to "000000011149.jpg".

        Returns:
            list: ``[fig_gt, fig_pred, cm_df, cm_tpfpfn_df]`` - the
            ground-truth figure, the prediction figure (y axis hidden), the
            confusion-matrix DataFrame and the TP/FP/FN DataFrame.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        # shows the image side by side
        img = self._read_image(img_fname)
        labels = {"x": "X", "y": "Y", "color": "Colour"}

        # Each panel gets its own copy because drawing modifies the image.
        img_pred = img.copy()
        img_gt = img.copy()

        preds = self.get_preds(img_fname)
        img_pred = self.draw_pred_bboxes(img_pred, preds)
        gt_annots = self.get_gt_annot(img_fname)
        img_gt = self.draw_gt_bboxes(img_gt, gt_annots)

        fig1 = px.imshow(img_gt[..., ::-1], aspect="equal", labels=labels)
        fig2 = px.imshow(img_pred[..., ::-1], aspect="equal", labels=labels)
        fig2.update_yaxes(visible=False)

        cm_df, cm_tpfpfn_df = self.generate_cm_one_image(preds, gt_annots)
        return [fig1, fig2, cm_df, cm_tpfpfn_df]

    def generate_cm_one_image(self, preds, gt_annots):
        """Compute the confusion matrix and TP/FP/FN counts for one image.

        The inputs are not modified: the class-id shift by ``idx_base`` is
        applied to copies, so calling this repeatedly with the same arrays
        gives the same result.

        Args:
            preds (np.ndarray): Predictions, one row per box, class id in the
                last column.
            gt_annots (np.ndarray): Ground-truth boxes, one row per box,
                class id in the first column.

        Returns:
            tuple: ``(confusion_matrix_df, cm_tpfpfn_df)`` - the relabelled
            confusion matrix and a one-row DataFrame (index "Values") with
            the columns "True Positive", "False Positive", "False Negative".
        """
        num_classes = len(self.cfg_obj["error_analysis"]["labels_dict"])
        idx_base = self.cfg_obj["error_analysis"]["idx_base"]

        cm = ConfusionMatrix(
            num_classes=num_classes,
            CONF_THRESHOLD=self.ea_obj.model.score_threshold,
            IOU_THRESHOLD=self.ea_obj.model.iou_threshold,
        )

        # Shift class ids on copies so the caller's arrays stay untouched.
        shifted_gt = self._shift_class_ids(gt_annots, 0, idx_base)
        shifted_preds = self._shift_class_ids(preds, -1, idx_base)
        cm.process_batch(shifted_preds, shifted_gt)

        confusion_matrix_df = cm.return_as_df()
        cm.get_tpfpfn()

        cm_tpfpfn_df = pd.DataFrame(
            {
                "True Positive": cm.tp,
                "False Positive": cm.fp,
                "False Negative": cm.fn,
            },
            index=[0],
        )
        cm_tpfpfn_df = cm_tpfpfn_df.set_axis(["Values"], axis=0)
        cm_tpfpfn_df = cm_tpfpfn_df.astype(int)

        # amend df
        confusion_matrix_df = amend_cm_df(confusion_matrix_df, self.labels_dict)
        return confusion_matrix_df, cm_tpfpfn_df

    def get_preds(self, img_fname="000000011149.jpg"):
        """Run inference on one image and scale the boxes to pixel units.

        Args:
            img_fname (str, optional): Image filename passed to the
                error-analysis object. Defaults to "000000011149.jpg".

        Returns:
            np.ndarray: The inference output with columns 0 and 2 multiplied
            by the image width and columns 1 and 3 by the image height.
            The array returned by ``generate_inference`` is scaled in place.
        """
        # run inference using the error analysis object per image
        outputs, img_shape = self.ea_obj.generate_inference(img_fname)

        # converts image coordinates from normalised to pixel values
        # image shape is [Y, X, C] (because rows are Y), so X columns are
        # scaled by img_shape[1] and Y columns by img_shape[0]
        height, width = img_shape[0], img_shape[1]
        outputs[:, 0] *= width
        outputs[:, 1] *= height
        outputs[:, 2] *= width
        outputs[:, 3] *= height
        return outputs

    def get_gt_annot(self, img_fname):
        """Return the ground-truth boxes of one image in pixel units.

        Args:
            img_fname (str): Image filename; key into the ground-truth
                dictionary and file name inside the inference folder.

        Returns:
            np.ndarray: A copy of the stored annotations, passed through
            ``transform_gt_bbox_format(..., format="coco")``, with columns 1
            and 3 multiplied by the image width and columns 2 and 4 by the
            image height. Column 0 is left as is.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        # Work on a copy so the cached annotations are never rescaled.
        ground_truth = self.gt_annots[img_fname].copy()
        img_shape = self._read_image(img_fname).shape
        ground_truth = transform_gt_bbox_format(ground_truth, img_shape, format="coco")

        # converts image coordinates from normalised to pixel values
        # image shape is [Y, X, C] (because rows are Y)
        height, width = img_shape[0], img_shape[1]
        ground_truth[:, 1] *= width
        ground_truth[:, 2] *= height
        ground_truth[:, 3] *= width
        ground_truth[:, 4] *= height
        return ground_truth

    def draw_pred_bboxes(self, img_pred, preds):
        """Draw the predicted boxes and their labels on an image.

        Args:
            img_pred (np.ndarray): Image to draw on.
            preds (np.ndarray): One row per prediction; columns 0-3 are the
                box corners (x1, y1, x2, y2) and column 5 is the class id
                used to look up the label text.

        Returns:
            np.ndarray: The image with boxes and labels drawn in
            ``self.pred_colour``.
        """
        for pred in preds:
            # Truncate to ints, then convert to plain Python ints for OpenCV.
            values = pred.astype(int).tolist()
            x1, y1, x2, y2 = values[0], values[1], values[2], values[3]
            img_pred = cv2.rectangle(
                img_pred,
                (x1, y1),
                (x2, y2),
                color=self.pred_colour,
                thickness=self.bbox_thickness,
            )
            img_pred = cv2.putText(
                img_pred,
                self.labels_dict[values[5]],
                (x1 + 5, y1 + 25),
                color=self.pred_colour,
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=self.font_scale,
                thickness=self.font_thickness,
            )
        return img_pred

    def draw_gt_bboxes(self, img_gt, gt_annots, **kwargs):
        """Draw the ground-truth boxes and their labels on an image.

        Args:
            img_gt (np.ndarray): Image to draw on.
            gt_annots (np.ndarray): One row per annotation; column 0 is the
                class id used to look up the label text and columns 1-4 are
                the box corners (x1, y1, x2, y2).
            **kwargs: Accepted for backward compatibility; currently unused.

        Returns:
            np.ndarray: The image with boxes and labels drawn in
            ``self.gt_colour``.
        """
        for annot in gt_annots:
            # Truncate to ints, then convert to plain Python ints for OpenCV.
            values = annot.astype(int).tolist()
            x1, y1, x2, y2 = values[1], values[2], values[3], values[4]
            img_gt = cv2.rectangle(
                img_gt,
                (x1, y1),
                (x2, y2),
                color=self.gt_colour,
                thickness=self.bbox_thickness,
            )
            img_gt = cv2.putText(
                img_gt,
                self.labels_dict[values[0]],
                (x1 + 5, y1 + 25),
                # Use the configured colour so the label matches its box.
                color=self.gt_colour,
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=self.font_scale,
                thickness=self.font_thickness,
            )
        return img_gt

    def plot_with_preds_gt(self, option, side_by_side=False, plot_type=None):
        """Render the requested plot (and tables) into the Streamlit app.

        ``plot_type`` is checked first; ``side_by_side`` only takes effect
        when ``plot_type`` is not one of "all", "pred" or "gt".

        Args:
            option (str): Image filename. Toggled on the app itself.
            side_by_side (bool, optional): Show ground truth and prediction
                as two plots side by side, followed by the confusion-matrix
                tables. Defaults to False.
            plot_type (str, optional):
                "all" - both GT and predictions on one image, plus tables,
                "pred" - only predictions,
                "gt" - only ground truth,
                None - only the image (unless ``side_by_side`` is True).
                Defaults to None.
        """
        if plot_type == "all":
            plot, df, cm_tpfpfn_df = self.show_img(
                option, show_preds=True, show_gt=True
            )
            st.plotly_chart(plot, use_container_width=True)
            # NOTE(review): the caption text is fixed and does not follow the
            # colours configured under visual_tool.
            st.caption("Blue: Model BBox, Green: GT BBox")
            st.table(df)
            st.table(cm_tpfpfn_df)
        elif plot_type == "pred":
            st.plotly_chart(
                self.show_img(option, show_preds=True), use_container_width=True
            )
        elif plot_type == "gt":
            st.plotly_chart(
                self.show_img(option, show_gt=True), use_container_width=True
            )
        elif side_by_side:
            plot1, plot2, df, cm_tpfpfn_df = self.show_img_sbs(option)
            col1, col2 = st.columns(2)
            with col1:
                col1.subheader("Ground Truth")
                st.plotly_chart(plot1, use_container_width=True)
            with col2:
                col2.subheader("Prediction")
                st.plotly_chart(plot2, use_container_width=True)
            st.table(df)
            st.table(cm_tpfpfn_df)
        else:
            st.plotly_chart(self.show_img(option), use_container_width=True)