File size: 3,656 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os.path as osp
from typing import Optional, Sequence, Tuple, Union

import cv2
import numpy as np

from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.utils.dl_utils import tensor2imgs

# Type of the ``data_batch`` argument received by the hook methods below:
# a dict, tuple or list coming from the dataloader, or None when absent.
DATA_BATCH = Union[dict, tuple, list, None]


# TODO: Due to interface changes, the current class
#  functions incorrectly
@HOOKS.register_module()
class NaiveVisualizationHook(Hook):
    """Show or write the predicted results during the testing process.

    On every ``interval``-th test iteration, each sample of the batch is
    converted back to an image with ``tensor2imgs``, unpadded (if the sample
    carries ``pad_shape``), resized to its original size and handed to
    ``runner.visualizer.add_datasample`` together with the model output.

    Args:
        interval (int): Visualization interval. Defaults to 1.
        draw_gt (bool): Whether to draw the ground truth. Defaults to True.
        draw_pred (bool): Whether to draw the predicted result.
            Defaults to True.
    """
    priority = 'NORMAL'

    def __init__(self,
                 interval: int = 1,
                 draw_gt: bool = True,
                 draw_pred: bool = True):
        self.draw_gt = draw_gt
        self.draw_pred = draw_pred
        self._interval = interval

    def _unpad(self, input: np.ndarray, unpad_shape: Tuple[int,
                                                           int]) -> np.ndarray:
        """Crop the padding away from an image, keeping its top-left corner.

        Args:
            input (np.ndarray): The image to unpad, indexed as
                ``(height, width, ...)``.
            unpad_shape (tuple): The ``(width, height)`` of the image before
                padding. Note the order: width comes first.

        Returns:
            np.ndarray: The top-left ``height x width`` region of ``input``
            (a slice of ``input``, not a copy).
        """
        unpad_width, unpad_height = unpad_shape
        return input[:unpad_height, :unpad_width]

    def before_train(self, runner) -> None:
        """Call add_graph method of visualizer.

        Args:
            runner (Runner): The runner of the training process.
        """
        runner.visualizer.add_graph(runner.model, None)

    def after_test_iter(self,
                        runner,
                        batch_idx: int,
                        data_batch: DATA_BATCH = None,
                        outputs: Optional[Sequence] = None) -> None:
        """Show or write the predicted results of one test iteration.

        Does nothing unless ``batch_idx`` falls on the configured interval,
        or when either ``data_batch`` or ``outputs`` is missing.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (Sequence, optional): Outputs from model.
        """
        if not self.every_n_inner_iters(batch_idx, self._interval):
            return
        # Both arguments default to None; zip() would raise TypeError on
        # None, and there is nothing to visualize without both of them.
        if data_batch is None or outputs is None:
            return
        for data, output in zip(data_batch, outputs):  # type: ignore
            inputs = data['inputs']
            data_sample = data['data_sample']
            # tensor2imgs returns a list of images; only the first is used.
            # NOTE(review): tensor2imgs may expect a batched (N, C, H, W)
            # tensor; confirm the shape of ``data['inputs']`` per sample.
            img = tensor2imgs(inputs,
                              **data_sample.get('img_norm_cfg', dict()))[0]
            # TODO We will implement a function to revert the augmentation
            # in the future.
            # (width, height) order: this is what both ``_unpad`` and the
            # ``dsize`` argument of ``cv2.resize`` expect.
            ori_size = (data_sample.ori_width, data_sample.ori_height)
            if 'pad_shape' in data_sample:
                img = self._unpad(img, data_sample.get('scale', ori_size))
            origin_image = cv2.resize(img, ori_size)
            name = osp.basename(data_sample.img_path)
            runner.visualizer.add_datasample(name, origin_image, data_sample,
                                             output, self.draw_gt,
                                             self.draw_pred)