File size: 6,438 Bytes
f1dd031
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from typing import Callable, Dict, List, Optional, Sequence, Union

import cv2
import numpy as np
from mmcv.transforms import TRANSFORMS
from mmcv.transforms.utils import cache_random_params
from mmcv.transforms.wrappers import *

# Type of a single transform: either a config dict or a callable that takes
# a results dict and returns a results dict.
Transform = Union[Dict, Callable[[Dict], Dict]]

# Indicator of keys marked by KeyMapper._map_input, meaning the marked keys
# are ignored in KeyMapper._apply_transform so they stay invisible to the
# wrapped transforms.
# There are two possible cases:
# 1. The key is required but missing in results.
# 2. The key is manually set as ... (Ellipsis) in ``mapping``, which means
#    the original value in results should be ignored.
# NOTE(review): ``object()`` creates a brand-new sentinel. If the star import
# from ``mmcv.transforms.wrappers`` also provides an ``IgnoreKey``, this line
# shadows it, while the inherited KeyMapper methods presumably keep comparing
# against mmcv's own object rather than this one. Nothing visible in this
# file reads this name -- confirm whether it is needed.
IgnoreKey = object()

# ``contextlib.nullcontext`` exists on Python >= 3.7. On older interpreters,
# fall back to a minimal generator-based equivalent.
try:
    from contextlib import nullcontext  # type: ignore
except ImportError:
    from contextlib import contextmanager

    @contextmanager  # type: ignore
    def nullcontext(resource=None):
        """Yield ``resource`` unchanged and perform no cleanup on exit."""
        try:
            yield resource
        finally:
            pass


def imdenormalize(img, mean, std, to_bgr=True):
    """Undo per-channel normalization, computing ``img * std + mean``.

    This is the inverse of the normalization ``(img - mean) / std``.

    Args:
        img (np.ndarray): Normalized image. Must not be ``uint8``.
        mean (array-like): Per-channel mean. An ndarray, list or tuple is
            accepted; it is flattened to shape ``(1, C)`` as float64.
        std (array-like): Per-channel standard deviation, same layout as
            ``mean``.
        to_bgr (bool): If True, convert the result from RGB to BGR channel
            order. Default: True.

    Returns:
        np.ndarray: A new denormalized image. The input ``img`` is not
        modified.
    """
    assert img.dtype != np.uint8, "imdenormalize expects a non-uint8 image"
    # np.asarray lets callers pass plain lists/tuples (the bare ``.reshape``
    # call only worked for ndarrays). ndarray input behaves exactly as before.
    mean = np.asarray(mean).reshape(1, -1).astype(np.float64)
    std = np.asarray(std).reshape(1, -1).astype(np.float64)
    # cv2.multiply returns a new array (a copy), so the caller's ``img`` is
    # untouched. The following operations then work in place on that copy.
    img = cv2.multiply(img, std)
    cv2.add(img, mean, img)
    if to_bgr:
        cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img)
    return img


@TRANSFORMS.register_module()
class MasaTransformBroadcaster(KeyMapper):
    """A transform wrapper to apply the wrapped transforms to multiple data
    items. For example, apply Resize to multiple images.

    Processing happens in four steps:

    1. The inputs are gathered with ``mapping``.
    2. Every broadcasting target is split ("scattered") along its sequence
       dimension.
    3. The wrapped transforms run once per position.
    4. The per-position outputs are collated back into lists, and then
       ``remapping`` is applied.

    Args:
        transforms (list[dict | callable]): Sequence of transform object or
            config dict to be wrapped.
        mapping (dict): A dict that defines the input key mapping.
            Note that to apply the transforms to multiple data items, the
            outer keys of the target items should be remapped as a list with
            the standard inner key (the key required by the wrapped
            transform). See the example below and the document of
            ``mmcv.transforms.wrappers.KeyMapper`` for details.
        remapping (dict): A dict that defines the output key mapping.
            The keys and values have the same meanings and rules as in the
            ``mapping``. Default: None.
        auto_remap (bool, optional): If True, an inverse of the mapping will
            be used as the remapping. If auto_remap is not given, it will be
            automatically set True if 'remapping' is not given, and vice
            versa. Default: None.
        allow_nonexist_keys (bool): If False, the outer keys in the mapping
            must exist in the input data, or an exception will be raised.
            Default: False.
        share_random_params (bool): If True, the random transform
            (e.g., RandomFlip) will be conducted in a deterministic way and
            have the same behavior on all data items. For example, to randomly
            flip either both input image and ground-truth image, or none.
            Default: False.

    Example::

        # Flip the two images stored under the outer keys 'img1' and 'img2'
        # with one wrapped transform. Both images are gathered into a list
        # under the inner key 'img'. With auto_remap, the results are written
        # back to 'img1' and 'img2'. share_random_params=True makes both
        # images flip together, or neither.
        dict(
            type='MasaTransformBroadcaster',
            mapping={'img': ['img1', 'img2']},
            auto_remap=True,
            share_random_params=True,
            transforms=[dict(type='RandomFlip', prob=0.5)],
        )
    """

    def __init__(
        self,
        transforms: List[Union[Dict, Callable[[Dict], Dict]]],
        mapping: Optional[Dict] = None,
        remapping: Optional[Dict] = None,
        auto_remap: Optional[bool] = None,
        allow_nonexist_keys: bool = False,
        share_random_params: bool = False,
    ):
        super().__init__(
            transforms, mapping, remapping, auto_remap, allow_nonexist_keys
        )

        self.share_random_params = share_random_params

    def scatter_sequence(self, data: Dict) -> List[Dict]:
        """Scatter the broadcasting targets to a list of inputs of the wrapped
        transforms.

        The broadcasting targets are the keys of ``self.mapping``, or every
        key of ``data`` when no mapping is set. Each target must be a
        sequence, and all targets must have the same length ``N``. The result
        is ``N`` shallow copies of ``data``, where the i-th copy holds the
        i-th element of every target.

        Args:
            data (dict): Inputs already gathered by ``_map_input``.

        Returns:
            list[dict]: One input dict per broadcasting position.

        Raises:
            ValueError: If the targets have inconsistent lengths.
            AssertionError: If a target is not a sequence or the common
                length is zero.
        """
        keys = list(self.mapping.keys()) if self.mapping else list(data.keys())

        # Infer the split number from the first target and check the rest
        # against it. The first iteration is detected by index rather than
        # by ``seq_len`` being truthy. Otherwise an empty first target would
        # look "unset", and a later non-empty target would silently replace
        # it instead of raising.
        seq_len = 0
        key_rep = None
        for idx, key in enumerate(keys):
            value = data[key]
            assert isinstance(value, Sequence), (
                f"Broadcasting target {key!r} must be a sequence, "
                f"got {type(value)}"
            )
            if idx == 0:
                seq_len = len(value)
                key_rep = key
            elif len(value) != seq_len:
                raise ValueError(
                    "Got inconsistent sequence length: "
                    f"{seq_len} ({key_rep}) vs. "
                    f"{len(value)} ({key})"
                )

        assert seq_len > 0, "Fail to get the number of broadcasting targets"

        scatters = []
        for i in range(seq_len):
            scatter = data.copy()
            for key in keys:
                scatter[key] = data[key][i]
            scatters.append(scatter)
        return scatters

    def transform(self, results: Dict) -> Optional[Dict]:
        """Broadcast wrapped transforms to multiple targets.

        Args:
            results (dict): The pipeline results. Updated in place with the
                remapped outputs.

        Returns:
            dict | None: ``results`` after the update. Returns None if the
            wrapped transforms returned None for any position, which is
            mmcv Compose's signal that the sample should be dropped.
        """

        # Apply input remapping
        inputs = self._map_input(results, self.mapping)

        # Scatter sequential inputs into a list
        input_scatters = self.scatter_sequence(inputs)

        # Control random parameter sharing with a context manager
        if self.share_random_params:
            # The context manager :func:`cache_random_params` makes the
            # cacheable methods of the transforms cache their outputs. The
            # random parameters are therefore generated only once and shared
            # by all data items.
            ctx = cache_random_params  # type: ignore
        else:
            ctx = nullcontext  # type: ignore

        with ctx(self.transforms):
            output_scatters = [
                self._apply_transforms(_input) for _input in input_scatters
            ]

        # If the pipeline dropped any position, drop the whole group instead
        # of crashing on ``None`` during collation below.
        if any(_output is None for _output in output_scatters):
            return None

        # Collate output scatters (list of dict -> dict of list)
        outputs = {
            key: [_output[key] for _output in output_scatters]
            for key in output_scatters[0]
        }

        # Apply remapping
        outputs = self._map_output(outputs, self.remapping)

        results.update(outputs)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f"(transforms = {self.transforms}"
        repr_str += f", mapping = {self.mapping}"
        repr_str += f", remapping = {self.remapping}"
        repr_str += f", auto_remap = {self.auto_remap}"
        repr_str += f", allow_nonexist_keys = {self.allow_nonexist_keys}"
        repr_str += f", share_random_params = {self.share_random_params})"
        return repr_str