File size: 15,908 Bytes
26364eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264e65b
26364eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264e65b
26364eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
import streamlit as st
import numpy as np
import plotly.express as px
import cv2
from src.error_analysis import ErrorAnalysis, transform_gt_bbox_format
import yaml
import os
from src.confusion_matrix import ConfusionMatrix
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd


def amend_cm_df(cm_df, labels_dict):
    """Relabel confusion-matrix axes with human-readable class names.

    The raw matrix uses positional indices (0, 1, ...); this rewrites rows
    as "GT - <class>" and columns as "Pred - <class>", appending the extra
    "background" class used for unmatched boxes.

    Args:
        cm_df (pd.DataFrame): confusion matrix dataframe.
        labels_dict (dict): mapping of class index -> class name.

    Returns:
        pd.DataFrame: relabelled confusion matrix with integer counts.
    """

    class_names = list(labels_dict.values()) + ["background"]

    relabelled = (
        cm_df.set_axis([f"GT - {name}" for name in class_names])
        .set_axis([f"Pred - {name}" for name in class_names], axis=1)
        .astype(int)
    )

    return relabelled

def find_top_left_pos(mask):
    """Return the (row, col) position of the mask's first maximal element.

    ``np.argmax`` scans the flattened array in C (row-major) order, so for
    a binary mask this yields the top-most, then left-most, "on" pixel —
    a convenient anchor point for placing a text label.

    Args:
        mask (np.ndarray): 2D array (e.g. a binary segmentation mask).

    Returns:
        tuple: (row, col) index of the first occurrence of the maximum.
    """

    flat_index = np.argmax(mask, axis=None)
    return np.unravel_index(flat_index, mask.shape)


class ImageTool:
    """Streamlit visualisation helper for error analysis.

    Renders inference images with ground-truth and/or prediction overlays
    (bounding boxes for detection, masks for segmentation) and computes a
    per-image confusion matrix via the project's ConfusionMatrix class.
    """

    def __init__(self, cfg_path="cfg/cfg.yml"):
        """Load the config, set up the ErrorAnalysis object and cache labels.

        Args:
            cfg_path (str, optional): path to the YAML config file.
                Defaults to "cfg/cfg.yml".
        """

        # getting the config object
        # context manager ensures the file handle is closed (the original
        # open() call was never closed)
        with open(cfg_path) as cfg_file:
            self.cfg_obj = yaml.load(cfg_file, Loader=yaml.FullLoader)

        # initialising the model and getting the annotations
        self.ea_obj = ErrorAnalysis(cfg_path)
        self.inference_folder = self.ea_obj.inference_folder
        self.ea_obj.get_annots()
        self.gt_annots = self.ea_obj.gt_dict
        self.all_img = os.listdir(self.inference_folder)
        self.ea_obj.model.score_threshold = self.cfg_obj["visual_tool"]["conf_threshold"]
        self.ea_obj.model.iou_threshold = self.cfg_obj["visual_tool"]["iou_threshold"]

        # for labels: config stores {name: idx}; invert to {idx: name} so
        # drawing code can look up a class name from a numeric prediction
        self.labels_dict = self.cfg_obj["error_analysis"]["labels_dict"]
        self.labels_dict = {v: k for k, v in self.labels_dict.items()}
        self.inference_labels_dict = self.cfg_obj["error_analysis"]["inference_labels_dict"]
        self.inference_labels_dict = {v: k for k, v in self.inference_labels_dict.items()}
        self.idx_base = self.cfg_obj["error_analysis"]["idx_base"]

        # for visualisation
        self.bbox_thickness = self.cfg_obj["visual_tool"]["bbox_thickness"]
        self.font_scale = self.cfg_obj["visual_tool"]["font_scale"]
        self.font_thickness = self.cfg_obj["visual_tool"]["font_thickness"]
        self.pred_colour = tuple(self.cfg_obj["visual_tool"]["pred_colour"])
        self.gt_colour = tuple(self.cfg_obj["visual_tool"]["gt_colour"])

    def show_img(self, img_fname="000000011149.jpg", show_preds=False, show_gt=False):
        """generate img with option to overlay with GT and/or preds

        Args:
            img_fname (str, optional): Filename of the image. Defaults to "000000011149.jpg".
            show_preds (bool, optional): Toggle True to run model to get the preds. Defaults to False.
            show_gt (bool, optional): Toggle True to get the GT labels/boxes. Defaults to False.

        Returns:
            fig (Plotly Figure): image with overlays if toggled True.
            If both show_gt and show_preds are True, returns instead a list of
            [fig, cm_df, cm_tpfpfn_dict]:
                cm_df (pd.DataFrame): confusion matrix of the pred versus GT
                cm_tpfpfn_dict (Dict): confusion matrix dictionary of tp/fp/fn
        """

        # get the image's file path. Concatenates with the folder in question
        img = cv2.imread(f"{self.inference_folder}{img_fname}")

        labels = {"x": "X", "y": "Y", "color": "Colour"}

        if show_preds:

            preds = self.get_preds(img_fname)
            if self.ea_obj.task == "det":
                img = self.draw_pred_bboxes(img, preds)
            elif self.ea_obj.task == "seg":
                img = self.draw_pred_masks(img, preds)

        if show_gt:

            gt_annots = self.get_gt_annot(img_fname)

            if self.ea_obj.task == "det":
                # BUGFIX: was draw_gt_bboxes(img, preds) — drew predictions
                # in place of ground truth, and raised NameError when
                # show_gt=True with show_preds=False (preds undefined)
                img = self.draw_gt_bboxes(img, gt_annots)
            elif self.ea_obj.task == "seg":
                img = self.draw_gt_masks(img, gt_annots)

        # [..., ::-1] converts OpenCV's BGR channel order to RGB for plotly
        fig = px.imshow(img[..., ::-1], aspect="equal", labels=labels)

        if show_gt and show_preds:

            cm_df, cm_tpfpfn_dict = self.generate_cm_one_image(preds, gt_annots)
            return [fig, cm_df, cm_tpfpfn_dict]

        return fig

    def show_img_sbs(self, img_fname="000000011149.jpg"):
        """generate two images with confusion matrix and tp/fp/fn. fig1 is image with GT overlay, while fig2 is the image with pred overlay.

        Args:
            img_fname (str, optional): Filename of the image. Defaults to "000000011149.jpg".

        Returns:
            list: fig1 - imshow of image with GT overlay
                  fig2 - imshow of image with pred overlay
                  cm_df - confusion matrix dataframe
                  cm_tpfpfn_df - confusion matrix dictionary of tp/fp/fn
        """

        # shows the image side by side
        img = cv2.imread(f"{self.inference_folder}{img_fname}")
        labels = {"x": "X", "y": "Y", "color": "Colour"}

        # separate copies so pred and GT overlays don't draw on the same image
        img_pred = img.copy()
        img_gt = img.copy()

        preds = self.get_preds(img_fname)

        gt_annots = self.get_gt_annot(img_fname)

        if self.ea_obj.task == 'det':
            img_pred = self.draw_pred_bboxes(img_pred, preds)
            img_gt = self.draw_gt_bboxes(img_gt, gt_annots)

        elif self.ea_obj.task == 'seg':
            img_pred = self.draw_pred_masks(img_pred, preds)
            img_gt = self.draw_gt_masks(img_gt, gt_annots)

        # BGR -> RGB for plotly display
        fig1 = px.imshow(img_gt[..., ::-1], aspect="equal", labels=labels)
        fig2 = px.imshow(img_pred[..., ::-1], aspect="equal", labels=labels)
        fig2.update_yaxes(visible=False)

        cm_df, cm_tpfpfn_df = self.generate_cm_one_image(preds, gt_annots)

        return [fig1, fig2, cm_df, cm_tpfpfn_df]

    def generate_cm_one_image(self, preds, gt_annots):
        """Generates confusion matrix between the inference and the Ground Truth of an image

        Args:
            preds (array): inference output of the model on the image
            gt_annots (array): Ground Truth labels of the image

        Returns:
            cm_df (DataFrame): Confusion matrix dataframe.
            cm_tpfpfn_df (DataFrame): TP/FP/FN dataframe
        """

        num_classes = len(list(self.cfg_obj["error_analysis"]["labels_dict"].keys()))
        idx_base = self.cfg_obj["error_analysis"]["idx_base"]

        conf_threshold, iou_threshold = (
            self.ea_obj.model.score_threshold,
            self.ea_obj.model.iou_threshold,
        )
        cm = ConfusionMatrix(
            num_classes=num_classes,
            CONF_THRESHOLD=conf_threshold,
            IOU_THRESHOLD=iou_threshold,
        )
        # shift class indices to be 0-based for the confusion matrix
        # NOTE(review): for 'det' this mutates the caller's arrays in place —
        # callers must not reuse preds/gt_annots after this call; confirm
        # before refactoring call sites
        if self.ea_obj.task == 'det':
            gt_annots[:, 0] -= idx_base
            preds[:, -1] -= idx_base
        elif self.ea_obj.task == 'seg':
            gt_annots = [[gt[0] - idx_base, gt[1]] for gt in gt_annots]

        cm.process_batch(preds, gt_annots, task = self.ea_obj.task)

        confusion_matrix_df = cm.return_as_df()
        cm.get_tpfpfn()

        cm_tpfpfn_dict = {
            "True Positive": cm.tp,
            "False Positive": cm.fp,
            "False Negative": cm.fn,
        }

        cm_tpfpfn_df = pd.DataFrame(cm_tpfpfn_dict, index=[0])
        cm_tpfpfn_df = cm_tpfpfn_df.set_axis(["Values"], axis=0)
        cm_tpfpfn_df = cm_tpfpfn_df.astype(int)

        # amend df for readable GT/Pred axis labels
        confusion_matrix_df = amend_cm_df(confusion_matrix_df, self.labels_dict)

        return confusion_matrix_df, cm_tpfpfn_df

    def get_preds(self, img_fname="000000011149.jpg"):
        """Using the model in the Error Analysis object, run inference to get outputs

        Args:
            img_fname (str): Image filename. Defaults to "000000011149.jpg".

        Returns:
            outputs (array): Inference output of the model on the image
        """

        # run inference using the error analysis object per image
        outputs, img_shape = self.ea_obj.generate_inference(img_fname)

        if self.ea_obj.task == 'det':
            # converts image coordinates from normalised to integer values
            # image shape is [Y, X, C] (because Rows are Y)
            # So don't get confused!
            outputs[:, 0] *= img_shape[1]
            outputs[:, 1] *= img_shape[0]
            outputs[:, 2] *= img_shape[1]
            outputs[:, 3] *= img_shape[0]

        return outputs

    def get_gt_annot(self, img_fname):
        """Retrieve the Ground Truth annotations of the image.

        Args:
            img_fname (str): Image filename

        Returns:
            ground_truth (array): GT labels of the image
        """
        # copy so the cached annotations in self.gt_annots stay untouched
        ground_truth = self.gt_annots[img_fname].copy()
        img = cv2.imread(f"{self.inference_folder}{img_fname}")

        # converts image coordinates from normalised to integer values
        # image shape is [Y, X, C] (because Rows are Y)
        # So don't get confused!
        if self.ea_obj.task == 'det':
            img_shape = img.shape
            ground_truth = transform_gt_bbox_format(ground_truth, img_shape, format="coco")
            ground_truth[:, 1] *= img_shape[1]
            ground_truth[:, 2] *= img_shape[0]
            ground_truth[:, 3] *= img_shape[1]
            ground_truth[:, 4] *= img_shape[0]

        return ground_truth

    def draw_pred_masks(self, img_pred, inference_outputs):
        """Overlay predicted segmentation masks (plus class labels) onto the image.

        Args:
            img_pred (array): image to draw on (BGR, as loaded by cv2).
            inference_outputs (list): per-instance outputs; output[0] is the
                binary mask, output[2] the predicted class index.

        Returns:
            img_pred (array): image blended with the coloured mask overlay.
        """

        # union of all instance masks, clipped back to a binary {0,1} mask
        pred_mask = sum([output[0] for output in inference_outputs])
        pred_mask = np.where(pred_mask > 1, 1, pred_mask)
        colour = np.array(self.pred_colour, dtype='uint8')
        # paint the mask area in the pred colour, then alpha-blend (70/30)
        masked_img = np.where(pred_mask[...,None], colour, img_pred)
        masked_img = masked_img.astype(np.uint8)

        img_pred = cv2.addWeighted(img_pred, 0.7, masked_img, 0.3, 0)

        def put_text_ina_mask(output, img):
            # anchor the class label at the mask's top-left "on" pixel
            coords = find_top_left_pos(output[0])

            img = cv2.putText(img, self.inference_labels_dict[output[2]], (coords[1], coords[0] + 5), fontFace = cv2.FONT_HERSHEY_SIMPLEX, fontScale = self.font_scale,
                        color = self.pred_colour, thickness = self.font_thickness)

            return img

        for output in inference_outputs:
            img_pred = put_text_ina_mask(output, img_pred)

        return img_pred

    def draw_gt_masks(self, img_gt, gt_outputs):
        """Overlay ground-truth segmentation masks (plus class labels) onto the image.

        Args:
            img_gt (array): image to draw on (BGR, as loaded by cv2).
            gt_outputs (list): per-instance GT entries; output[0] is the class
                index, output[1] the binary mask.

        Returns:
            img_gt (array): image blended with the coloured mask overlay.
        """

        # union of all GT masks, clipped back to a binary {0,1} mask
        gt_mask = sum([output[1] for output in gt_outputs])
        gt_mask = np.where(gt_mask > 1, 1, gt_mask)
        colour = np.array(self.gt_colour, dtype='uint8')
        masked_img = np.where(gt_mask[...,None], colour, img_gt)

        def put_text_ina_mask(output, img):
            # anchor the class label at the mask's top-left "on" pixel
            coords = find_top_left_pos(output[1])

            img = cv2.putText(img, self.labels_dict[output[0]], (coords[1], coords[0] + 5), fontFace = cv2.FONT_HERSHEY_SIMPLEX, fontScale = self.font_scale,
                        color = self.gt_colour, thickness = self.font_thickness)

            return img

        img_gt = cv2.addWeighted(img_gt, 0.7, masked_img, 0.3,0)

        for output in gt_outputs:
            img_gt = put_text_ina_mask(output, img_gt)

        return img_gt

    def draw_pred_bboxes(self, img_pred, preds):
        """Draws the preds onto the image

        Args:
            img_pred (array): image
            preds (array): model inference outputs; per row the first four
                values are x1, y1, x2, y2 and index 5 is the class index

        Returns:
            img_pred (array): image with outputs on overlay
        """
        for pred in preds:
            pred = pred.astype(int)
            img_pred = cv2.rectangle(
                img_pred,
                (pred[0], pred[1]),
                (pred[2], pred[3]),
                color=self.pred_colour,
                thickness=self.bbox_thickness,
            )
            # NOTE(review): uses labels_dict here while draw_pred_masks uses
            # inference_labels_dict for predictions — confirm which mapping
            # matches the detection model's class indices
            img_pred = cv2.putText(
                img_pred,
                self.labels_dict[pred[5]],
                (pred[0] + 5, pred[1] + 25),
                color=self.pred_colour,
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=self.font_scale,
                thickness=self.font_thickness,
            )
        return img_pred

    def draw_gt_bboxes(self, img_gt, gt_annots, **kwargs):
        """Draws the GT onto the image

        Args:
            img_gt (array): image
            gt_annots (array): GT labels; per row index 0 is the class index
                and indices 1-4 are x1, y1, x2, y2

        Returns:
            img_gt (array): image with GT overlay
        """
        for annot in gt_annots:
            annot = annot.astype(int)
            img_gt = cv2.rectangle(
                img_gt,
                (annot[1], annot[2]),
                (annot[3], annot[4]),
                color=self.gt_colour,
                thickness=self.bbox_thickness,
            )
            img_gt = cv2.putText(
                img_gt,
                self.labels_dict[annot[0]],
                (annot[1] + 5, annot[2] + 25),
                # use the configured GT colour (was hard-coded (0, 255, 0),
                # inconsistent with the rectangle above)
                color=self.gt_colour,
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=self.font_scale,
                thickness=self.font_thickness,
            )
        return img_gt

    def plot_with_preds_gt(self, option, side_by_side=False, plot_type=None):
        """Rules on what plot to generate

        Args:
            option (_string_): image filename. Toggled on the app itself. See app.py
            side_by_side (bool, optional): Whether to have two plots side by side.
                                            Defaults to False.
            plot_type (_type_, optional): "all" - both GT and pred will be plotted,
                                "pred" - only preds,
                                "gt" - only ground truth
                                None - only image generated
                                Will be overridden if side_by_side = True
                                Defaults to None.
        """

        if plot_type == "all":
            plot, df, cm_tpfpfn_df = self.show_img(
                option, show_preds=True, show_gt=True
            )
            st.plotly_chart(plot, use_container_width=True)
            st.caption("Blue: Model BBox, Green: GT BBox")

            st.table(df)
            st.table(cm_tpfpfn_df)

        elif plot_type == "pred":
            st.plotly_chart(
                self.show_img(option, show_preds=True), use_container_width=True
            )

        elif plot_type == "gt":
            st.plotly_chart(
                self.show_img(option, show_gt=True), use_container_width=True
            )

        elif side_by_side:

            plot1, plot2, df, cm_tpfpfn_df = self.show_img_sbs(option)
            col1, col2 = st.columns(2)

            with col1:
                col1.subheader("Ground Truth")
                st.plotly_chart(plot1, use_container_width=True)
            with col2:
                col2.subheader("Prediction")
                st.plotly_chart(plot2, use_container_width=True)

            st.table(df)
            st.table(cm_tpfpfn_df)

        else:
            st.plotly_chart(self.show_img(option), use_container_width=True)