File size: 13,888 Bytes
a083fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Visualizer for results of prediction."""

from __future__ import annotations

import logging as log
import time
from typing import TYPE_CHECKING, NamedTuple

import cv2
import numpy as np
from model_api.performance_metrics import put_highlighted_text

from .vis_utils import ColorPalette

if TYPE_CHECKING:
    from demo_package.streamer import BaseStreamer
    from model_api.models.utils import (
        ClassificationResult,
        DetectionResult,
        InstanceSegmentationResult,
        SegmentedObject,
    )


class BaseVisualizer:
    """Base class for visualizers.

    Holds the display settings shared by all concrete visualizers and implements
    showing a frame, polling for the quit key and throttling video playback.
    Subclasses must implement :meth:`draw`.
    """

    def __init__(
        self,
        window_name: str | None = None,
        no_show: bool = False,
        delay: int | None = None,
        output: str = "./outputs",
    ) -> None:
        """Initialize the settings shared by all visualizers.

        Args:
            window_name (str | None): The name of the display window. ``None`` selects "Window".
            no_show (bool): When True, nothing is displayed and key polling is skipped. Defaults to False.
            delay (int | None): Delay in milliseconds passed to ``cv2.waitKey``. ``None`` selects 1.
            output (str): The output directory. Defaults to "./outputs".

        Returns:
            None
        """
        self.window_name = "Window" if window_name is None else window_name
        self.no_show = no_show
        self.delay = 1 if delay is None else delay
        self.output = output

    def draw(
        self,
        frame: np.ndarray,
        predictions: NamedTuple,
    ) -> np.ndarray:
        """Draw annotations on the image.

        Args:
            frame: Input image
            predictions: Annotations to be drawn on the input image

        Returns:
            Output image with annotations.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError

    def show(self, image: np.ndarray) -> None:
        """Show result image in the visualizer window unless ``no_show`` is set.

        Args:
            image (np.ndarray): Image to be shown.
        """
        if not self.no_show:
            cv2.imshow(self.window_name, image)

    def is_quit(self) -> bool:
        """Check whether the user pressed "q" to quit.

        Waits up to ``self.delay`` milliseconds for a key press. Always returns
        False when ``no_show`` is set, since no window exists to receive keys.
        """
        if self.no_show:
            return False

        return ord("q") == cv2.waitKey(self.delay)

    def video_delay(self, elapsed_time: float, streamer: BaseStreamer) -> None:
        """Sleep so that video playback does not run faster than the source FPS.

        Only applies when a window is shown and the streamer type mentions "VIDEO".

        Args:
            elapsed_time (float): Time in seconds spent on the current frame.
            streamer (BaseStreamer): Streamer object providing ``get_type()`` and ``fps()``.
        """
        if self.no_show:
            return
        if "VIDEO" not in str(streamer.get_type()):
            return
        fps_num = streamer.fps()
        # Some sources report an FPS of 0 (or nothing) when it is unknown; there is
        # no target frame time to honour then, and 1 / 0 would raise.
        if not fps_num or fps_num <= 0:
            return
        orig_frame_time = 1 / fps_num
        if elapsed_time < orig_frame_time:
            time.sleep(orig_frame_time - elapsed_time)


class ClassificationVisualizer(BaseVisualizer):
    """Visualize the predicted classification labels by drawing the annotations on the input image.

    The top labels are drawn as a right-aligned "Label: Score:" table that ends
    near the bottom edge of the frame.

    Example:
        >>> predictions = inference_model.predict(frame)
        >>> output = visualizer.draw(frame, predictions)
        >>> visualizer.show(output)
    """

    def draw(
        self,
        frame: np.ndarray,
        predictions: ClassificationResult,
    ) -> np.ndarray:
        """Draw classification annotations on the image.

        Args:
            frame: Input image. It is drawn on in place and returned.
            predictions: Classification result whose ``top_labels`` are drawn. Each
                entry is unpacked as ``(idx, class_label, score)``.

        Returns:
            Output image with annotations (the same array as ``frame``).
        """
        top_labels = predictions.top_labels
        if not any(top_labels):
            log.warning("There are no predictions.")
            return frame

        font = cv2.FONT_HERSHEY_COMPLEX
        font_scale = 0.7
        thickness = 2
        text_color = (255, 0, 0)
        frame_width = frame.shape[1]

        def _put_right_aligned(text: str, y: int) -> None:
            """Draw ``text`` with its right edge on the frame's right border at baseline ``y``."""
            text_width = cv2.getTextSize(text, font, font_scale, thickness)[0][0]
            put_highlighted_text(
                frame,
                text,
                (frame_width - text_width, y),
                font,
                font_scale,
                text_color,
                thickness,
            )

        # The height of one text line, measured on the first label, sets the row spacing.
        label_height = cv2.getTextSize(top_labels[0][1], font, font_scale, thickness)[0][1]
        # Reserve about 1.5 line heights per row, plus one line for the header, above the bottom edge.
        offset_y = frame.shape[0] - label_height * (int(1.5 * len(top_labels)) + 1)
        if offset_y < 0:
            offset_y = label_height
            log.warning("Too much labels to display on this frame, some will be omitted")

        _put_right_aligned("Label:     Score:", offset_y)

        row_step = int(label_height * 1.5)
        for idx, class_label, score in top_labels:
            offset_y += row_step
            _put_right_aligned(f"{idx}. {class_label}    {score:.2f}", offset_y)
        return frame


class SemanticSegmentationVisualizer(BaseVisualizer):
    """Visualize the predicted segmentation labels by drawing the annotations on the input image.

    Example:
        >>> masks = inference_model.predict(frame)
        >>> output = visualizer.draw(frame, masks)
        >>> visualizer.show(output)
    """

    def __init__(
        self,
        labels: list[str],
        window_name: str | None = None,
        no_show: bool = False,
        delay: int | None = None,
        output: str = "./outputs",
    ) -> None:
        """Semantic segmentation visualizer.

        Draws the segmentation masks on the input image.

        Parameters:
            labels (List[str]): List of labels; one palette color is allocated per label.
            window_name (str | None): Name of the window (default is None).
            no_show (bool): When True, nothing is displayed (default is False).
            delay (int | None): Delay in milliseconds (default is None).
            output (str): Output path (default is "./outputs").

        Returns:
            None
        """
        super().__init__(window_name, no_show, delay, output)
        self.color_palette = ColorPalette(len(labels)).to_numpy_array()
        self.color_map = self._create_color_map()

    def _create_color_map(self) -> np.ndarray:
        """Build the 256-entry lookup table used by ``cv2.LUT`` to color class indices.

        Entries covered by the palette get the palette color (channel order reversed).
        The remaining entries get random colors. They come from a local, fixed-seed
        generator, so the map is reproducible and the global NumPy RNG is not touched.

        Returns:
            A ``(256, 1, 3)`` uint8 array.
        """
        # A uint8 class map can only address 256 LUT entries; extra palette colors are unusable.
        classes = self.color_palette[:256, ::-1]  # RGB to BGR
        classes_num = len(classes)
        color_map = np.zeros((256, 1, 3), dtype=np.uint8)
        color_map[:classes_num, 0, :] = classes
        rng = np.random.default_rng(0)
        color_map[classes_num:, 0, :] = rng.integers(0, 256, size=(256 - classes_num, 3), dtype=np.uint8)
        return color_map

    def _apply_color_map(self, input_2d_mask: np.ndarray) -> np.ndarray:
        """Map each pixel value of a 2-D mask to its color from ``self.color_map``."""
        # cv2.LUT works per channel, so replicate the single-channel mask into 3 channels first.
        input_3d = cv2.merge([input_2d_mask, input_2d_mask, input_2d_mask])
        return cv2.LUT(input_3d.astype(np.uint8), self.color_map)

    def draw(self, frame: np.ndarray, masks: SegmentedObject) -> np.ndarray:
        """Draw segmentation annotations on the image.

        Args:
            frame: Input image.
            masks: Segmentation result. Its ``resultImage`` is used as a 2-D map of
                per-pixel class indices (presumably the same height/width as ``frame``).

        Returns:
            A new image: a 50/50 blend of ``frame`` and the colorized class map.
        """
        class_map = masks.resultImage
        colored = self._apply_color_map(class_map)
        return cv2.addWeighted(frame, 0.5, colored, 0.5, 0)


class ObjectDetectionVisualizer(BaseVisualizer):
    """Visualizes object detection annotations on an input image."""

    def __init__(
        self,
        labels: list[str],
        window_name: str | None = None,
        no_show: bool = False,
        delay: int | None = None,
        output: str = "./outputs",
    ) -> None:
        """Object detection visualizer.

        Draws the object detection annotations on the input image.

        Parameters:
            labels (List[str]): The list of labels, indexed by class id.
            window_name (str | None): The name of the window. Defaults to None.
            no_show (bool): When True, nothing is displayed. Defaults to False.
            delay (int | None): The delay in milliseconds. Defaults to None.
            output (str): The output directory. Defaults to "./outputs".

        Returns:
            None
        """
        super().__init__(window_name, no_show, delay, output)
        self.labels = labels
        self.color_palette = ColorPalette(len(labels))

    def _get_label(self, class_id: int) -> str:
        """Return the name of ``class_id``, or ``"#<id>"`` when no name is known for it."""
        if self.labels and 0 <= class_id < len(self.labels):
            return self.labels[class_id]
        return f"#{class_id}"

    def draw(
        self,
        frame: np.ndarray,
        predictions: DetectionResult,
    ) -> np.ndarray:
        """Draw object detection annotations on the image.

        For every detection, a box and a "<label> <score%>" caption are drawn in
        the class color.

        Args:
            frame: Input image. It is drawn on in place.
            predictions: Detection result whose ``objects`` are drawn.

        Returns:
            Output image with annotations (the same array as ``frame``).
        """
        for detection in predictions.objects:
            class_id = int(detection.id)
            color = self.color_palette[class_id]
            det_label = self._get_label(class_id)
            # cv2 drawing functions require integer pixel coordinates.
            xmin, ymin, xmax, ymax = (
                int(coord) for coord in (detection.xmin, detection.ymin, detection.xmax, detection.ymax)
            )
            cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), color, 2)
            cv2.putText(
                frame,
                f"{det_label} {detection.score:.1%}",
                (xmin, ymin - 7),
                cv2.FONT_HERSHEY_COMPLEX,
                0.6,
                color,
                1,
            )

        return frame


class InstanceSegmentationVisualizer(BaseVisualizer):
    """Visualizes Instance Segmentation annotations on an input image."""

    def __init__(
        self,
        labels: list[str],
        window_name: str | None = None,
        no_show: bool = False,
        delay: int | None = None,
        output: str = "./outputs",
    ) -> None:
        """Instance segmentation visualizer.

        Draws the instance segmentation annotations on the input image.

        Args:
            labels (List[str]): The list of labels. When empty, an 80-color palette is used.
            window_name (str | None): The name of the window. Defaults to None.
            no_show (bool): When True, nothing is displayed. Defaults to False.
            delay (int | None): The delay in milliseconds. Defaults to None.
            output (str): The path to the output directory. Defaults to "./outputs".

        Returns:
            None
        """
        super().__init__(window_name, no_show, delay, output)
        self.labels = labels
        colors_num = len(labels) if labels else 80
        # Draw per-instance bounding boxes in addition to masks and labels.
        self.show_boxes = False
        # Append the confidence score to each label text.
        self.show_scores = True
        self.palette = ColorPalette(colors_num)

    def draw(
        self,
        frame: np.ndarray,
        predictions: InstanceSegmentationResult,
    ) -> np.ndarray:
        """Draw the instance segmentation results on the input frame.

        Masks are always drawn. Boxes are drawn only when ``self.show_boxes`` is set.
        Labels (with scores when ``self.show_scores`` is set) are always drawn.

        Args:
            frame: np.ndarray - The input frame. It is copied, not modified.
            predictions: InstanceSegmentationResult - The instance segmentation results to be drawn.

        Returns:
            np.ndarray - A copy of the frame with the instance segmentation results drawn on it.
        """
        result = frame.copy()
        output_objects = predictions.segmentedObjects
        bboxes = [[output.xmin, output.ymin, output.xmax, output.ymax] for output in output_objects]
        scores = [output.score for output in output_objects]
        masks = [output.mask for output in output_objects]
        label_names = [output.str_label for output in output_objects]

        result = self._overlay_masks(result, masks)
        if self.show_boxes:
            class_ids = [int(output.id) for output in output_objects]
            result = self._overlay_boxes(result, bboxes, class_ids)
        return self._overlay_labels(result, bboxes, label_names, scores)

    def _overlay_masks(self, image: np.ndarray, masks: list[np.ndarray]) -> np.ndarray:
        """Blend colored instance masks into ``image`` in place and outline them.

        Each mask gets ``self.palette[i]`` for its position ``i`` in ``masks``.
        Only the first contour found in each mask is outlined.
        """
        segments_image = image.copy()
        aggregated_mask = np.zeros(image.shape[:2], dtype=np.uint8)
        aggregated_colored_mask = np.zeros(image.shape, dtype=np.uint8)
        all_contours = []

        for i, mask in enumerate(masks):
            # [-2] picks the contour list for both the 2- and 3-value findContours return signatures.
            contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
            if contours:
                all_contours.append(contours[0])

            mask_color = self.palette[i]
            cv2.bitwise_or(aggregated_mask, mask, dst=aggregated_mask)
            cv2.bitwise_or(aggregated_colored_mask, mask_color, dst=aggregated_colored_mask, mask=mask)

        # Fill the area occupied by all instances with a colored instances mask image
        cv2.bitwise_and(segments_image, (0, 0, 0), dst=segments_image, mask=aggregated_mask)
        cv2.bitwise_or(segments_image, aggregated_colored_mask, dst=segments_image, mask=aggregated_mask)

        cv2.addWeighted(image, 0.5, segments_image, 0.5, 0, dst=image)
        cv2.drawContours(image, all_contours, -1, (0, 0, 0))
        return image

    def _overlay_boxes(self, image: np.ndarray, boxes: list[np.ndarray], classes: list[int]) -> np.ndarray:
        """Draw one rectangle per ``(xmin, ymin, xmax, ymax)`` box, colored by its class id."""
        for box, class_id in zip(boxes, classes):
            color = self.palette[class_id]
            # cv2.rectangle needs integer point tuples.
            xmin, ymin, xmax, ymax = (int(coord) for coord in box[:4])
            image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
        return image

    def _overlay_labels(
        self,
        image: np.ndarray,
        boxes: list[np.ndarray],
        classes: list[str],
        scores: list[float],
    ) -> np.ndarray:
        """Write each label (and score, if enabled) in white at the top-left corner of its box."""
        template = "{}: {:.2f}" if self.show_scores else "{}"

        for box, score, label in zip(boxes, scores, classes):
            text = template.format(label, score)
            # getTextSize returns ((width, height), baseline).
            text_height = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0][1]
            # Put the baseline one text-height below the box top so the label sits just inside the box.
            origin = (int(box[0]), int(box[1]) + text_height)
            cv2.putText(
                image,
                text,
                origin,
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (255, 255, 255),
                1,
            )
        return image