# JohanDL's picture
# initial commit
# f1dd031
from typing import Callable, Dict, List, Optional, Sequence, Union
import cv2
import numpy as np
from mmcv.transforms import TRANSFORMS
from mmcv.transforms.utils import cache_random_params
from mmcv.transforms.wrappers import *
# Define type of transform or transform config
Transform = Union[Dict, Callable[[Dict], Dict]]

# Indicator of keys marked by KeyMapper._map_input, which means ignoring the
# marked keys in KeyMapper._apply_transform so they will be invisible to
# wrapped transforms.
# This can be 2 possible cases:
# 1. The key is required but missing in results
# 2. The key is manually set as ... (Ellipsis) in ``mapping``, which means
#    the original value in results should be ignored
#
# The sentinel is compared by identity, so it has to be the very object that
# mmcv's ``KeyMapper`` (the base class of the wrappers in this module) puts
# into the inputs. Binding a fresh ``object()`` here would shadow mmcv's
# sentinel (which the ``*`` import above may also have brought in), and any
# ``is IgnoreKey`` check in this module would then never match. Only fall
# back to a private sentinel if mmcv does not expose one.
try:
    from mmcv.transforms.wrappers import IgnoreKey
except ImportError:
    IgnoreKey = object()
# ``contextlib.nullcontext`` ships with Python 3.7+. Older interpreters get a
# generator-based stand-in with the same behavior: it hands back the given
# resource on ``__enter__`` and does nothing on exit.
try:
    from contextlib import nullcontext  # type: ignore
except ImportError:
    from contextlib import contextmanager

    @contextmanager  # type: ignore
    def nullcontext(resource=None):
        """Yield ``resource`` unchanged; entering and exiting are no-ops."""
        yield resource
def imdenormalize(img, mean, std, to_bgr=True):
    """Undo per-channel normalization, i.e. compute ``img * std + mean``.

    Args:
        img (np.ndarray): Normalized image. It must not be ``uint8``; a
            normalized image is expected to hold non-integer values.
        mean (array-like): Per-channel mean. Lists and tuples are accepted
            in addition to arrays.
        std (array-like): Per-channel standard deviation. Lists and tuples
            are accepted in addition to arrays.
        to_bgr (bool): If True, convert the result from RGB to BGR channel
            order. Default: True.

    Returns:
        np.ndarray: A new denormalized image. ``img`` itself is not modified,
        because ``cv2.multiply`` writes into a freshly allocated array.
    """
    assert img.dtype != np.uint8, (
        "imdenormalize expects a normalized (non-uint8) image"
    )
    # Reshape to a single float64 row (1, C). Presumably OpenCV treats this
    # row as a per-channel scalar (confirm for C > 4). ``np.asarray``
    # also accepts lists/tuples, which the previous ``.reshape`` call did
    # not. Neither array is written to below, so reusing a view is safe.
    mean = np.asarray(mean, dtype=np.float64).reshape(1, -1)
    std = np.asarray(std, dtype=np.float64).reshape(1, -1)
    img = cv2.multiply(img, std)  # make a copy
    cv2.add(img, mean, img)  # inplace on the copy
    if to_bgr:
        cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img)  # inplace
    return img
@TRANSFORMS.register_module()
class MasaTransformBroadcaster(KeyMapper):
    """A transform wrapper to apply the wrapped transforms to multiple data
    items. For example, apply Resize to multiple images.

    Args:
        transforms (list[dict | callable]): Sequence of transform objects or
            config dicts to be wrapped.
        mapping (dict): A dict that defines the input key mapping.
            Note that to apply the transforms to multiple data items, the
            outer keys of the target items should be remapped as a list with
            the standard inner key (The key required by the wrapped transform).
            See the following example and the document of
            ``mmcv.transforms.wrappers.KeyMapper`` for details.
        remapping (dict): A dict that defines the output key mapping.
            The keys and values have the same meanings and rules as in the
            ``mapping``. Default: None.
        auto_remap (bool, optional): If True, an inverse of the mapping will
            be used as the remapping. If auto_remap is not given, it will be
            automatically set True if 'remapping' is not given, and vice
            versa. Default: None.
        allow_nonexist_keys (bool): If False, the outer keys in the mapping
            must exist in the input data, or an exception will be raised.
            Default: False.
        share_random_params (bool): If True, the random transform
            (e.g., RandomFlip) will be conducted in a deterministic way and
            have the same behavior on all data items. For example, to randomly
            flip either both input image and ground-truth image, or none.
            Default: False.
    """

    def __init__(
        self,
        transforms: List[Union[Dict, Callable[[Dict], Dict]]],
        mapping: Optional[Dict] = None,
        remapping: Optional[Dict] = None,
        auto_remap: Optional[bool] = None,
        allow_nonexist_keys: bool = False,
        share_random_params: bool = False,
    ):
        super().__init__(
            transforms, mapping, remapping, auto_remap, allow_nonexist_keys
        )
        self.share_random_params = share_random_params

    def scatter_sequence(self, data: Dict) -> List[Dict]:
        """Scatter the broadcasting targets to a list of inputs of the wrapped
        transforms.

        The broadcast keys are the keys of ``self.mapping``, or every key of
        ``data`` when no mapping is set. Each of them must hold a
        ``Sequence``, and all of them must have the same non-zero length
        ``N``. The result is ``N`` shallow copies of ``data``. In the i-th
        copy, each broadcast key holds the i-th element of its sequence.
        All other items are shared between the copies.

        Args:
            data (dict): Inputs after input mapping.

        Returns:
            list[dict]: One input dict per broadcast target.

        Raises:
            AssertionError: If a broadcast value is not a ``Sequence``, or if
                the number of targets cannot be inferred (no broadcast keys,
                or all of the sequences are empty).
            ValueError: If the broadcast sequences differ in length. This
                includes an empty sequence mixed with non-empty ones.
        """
        keys = self.mapping.keys() if self.mapping else data.keys()

        # ``None`` (not 0) means "length not inferred yet". With 0 as the
        # sentinel, an empty sequence seen first would be silently replaced by
        # the next key's length, and scattering would then fail with an
        # IndexError instead of reporting the inconsistency.
        seq_len = None
        key_rep = None
        for key in keys:
            value = data[key]
            assert isinstance(value, Sequence), (
                f"Broadcast target {key!r} must be a Sequence, "
                f"got {type(value).__name__}"
            )
            if seq_len is None:
                seq_len = len(value)
                key_rep = key
            elif len(value) != seq_len:
                raise ValueError(
                    "Got inconsistent sequence length: "
                    f"{seq_len} ({key_rep}) vs. "
                    f"{len(value)} ({key})"
                )

        # Fails both when there were no keys (None) and when the length is 0.
        assert seq_len, "Fail to get the number of broadcasting targets"

        scatters = []
        for i in range(seq_len):
            scatter = data.copy()
            for key in keys:
                scatter[key] = data[key][i]
            scatters.append(scatter)
        return scatters

    def transform(self, results: Dict) -> Optional[Dict]:
        """Broadcast wrapped transforms to multiple targets.

        Args:
            results (dict): The pipeline data. It is updated in place.

        Returns:
            dict | None: ``results`` updated with the (remapped) per-target
            outputs, gathered back into lists. Returns ``None`` if the
            wrapped transforms returned ``None`` for any target.
        """
        # Apply input remapping
        inputs = self._map_input(results, self.mapping)
        # Scatter sequential inputs into a list
        input_scatters = self.scatter_sequence(inputs)

        # Control random parameter sharing with a context manager
        if self.share_random_params:
            # The context manager :func:`cache_random_params` makes the
            # cacheable methods of the transforms cache their outputs, so the
            # random parameters are generated only once and shared by all
            # data items.
            ctx = cache_random_params  # type: ignore
        else:
            ctx = nullcontext  # type: ignore

        with ctx(self.transforms):
            output_scatters = [
                self._apply_transforms(_input) for _input in input_scatters
            ]

        # NOTE(review): mmcv-style pipelines appear to use a ``None`` return
        # to mean "drop this sample". Propagate it here instead of crashing
        # on ``None[key]`` below. Confirm against the mmcv version in use.
        if any(_output is None for _output in output_scatters):
            return None

        # Gather the per-target outputs back into lists keyed like the first
        # target's output.
        outputs = {
            key: [_output[key] for _output in output_scatters]
            for key in output_scatters[0]
        }

        # Apply remapping
        outputs = self._map_output(outputs, self.remapping)
        results.update(outputs)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f"(transforms = {self.transforms}"
        repr_str += f", mapping = {self.mapping}"
        repr_str += f", remapping = {self.remapping}"
        repr_str += f", auto_remap = {self.auto_remap}"
        repr_str += f", allow_nonexist_keys = {self.allow_nonexist_keys}"
        repr_str += f", share_random_params = {self.share_random_params})"
        return repr_str