Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
from copy import deepcopy | |
from typing import Dict, List, Optional, Sequence, Tuple, Union | |
import mmcv | |
import mmengine | |
import numpy as np | |
from mmcv.image import imflip | |
from mmcv.transforms import BaseTransform | |
from mmcv.transforms.utils import avoid_cache_randomness, cache_randomness | |
from mmengine import is_list_of | |
from mmengine.dist import get_dist_info | |
from scipy.stats import truncnorm | |
from mmpose.codecs import * # noqa: F401, F403 | |
from mmpose.registry import KEYPOINT_CODECS, TRANSFORMS | |
from mmpose.structures.bbox import bbox_xyxy2cs, flip_bbox | |
from mmpose.structures.keypoint import flip_keypoints | |
from mmpose.utils.typing import MultiConfig | |
try: | |
import albumentations | |
except ImportError: | |
albumentations = None | |
# Scalar numeric type alias used for range-style arguments (e.g. the
# ``contrast_range`` / ``saturation_range`` of ``PhotometricDistortion``).
Number = Union[float, int]
class GetBBoxCenterScale(BaseTransform):
    """Convert bboxes from [x1, y1, x2, y2] (xyxy) to center and scale.

    The center is the coordinates of the bbox center, and the scale is the
    bbox size (w, h) enlarged by the ``padding`` factor. The conversion is
    delegated to ``bbox_xyxy2cs``.

    Required Keys:

        - bbox

    Added Keys:

        - bbox_center
        - bbox_scale

    Args:
        padding (float): The bbox padding scale that will be multiplied to
            ``bbox_scale``. Defaults to 1.25
    """

    def __init__(self, padding: float = 1.25) -> None:
        super().__init__()

        self.padding = padding

    def transform(self, results: Dict) -> Optional[dict]:
        """The transform function of :class:`GetBBoxCenterScale`.

        If both ``bbox_center`` and ``bbox_scale`` are already present in
        ``results`` they are reused, and only ``bbox_scale`` is multiplied
        in place by ``padding`` (a warning is emitted on rank 0). Otherwise
        both are computed from ``results['bbox']`` (xyxy format) with
        ``bbox_xyxy2cs``.

        See ``transform()`` method of :class:`BaseTransform` for details.

        Args:
            results (dict): The result dict

        Returns:
            dict: The result dict.
        """
        if 'bbox_center' in results and 'bbox_scale' in results:
            # Only warn once per job instead of once per process.
            rank, _ = get_dist_info()
            if rank == 0:
                warnings.warn('Use the existing "bbox_center" and "bbox_scale"'
                              '. The padding will still be applied.')
            results['bbox_scale'] *= self.padding
        else:
            center, scale = bbox_xyxy2cs(
                results['bbox'], padding=self.padding)
            results['bbox_center'] = center
            results['bbox_scale'] = scale

        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__ + f'(padding={self.padding})'
        return repr_str
class RandomFlip(BaseTransform):
    """Randomly flip the image, bbox and keypoints.

    Required Keys:

        - img
        - img_shape
        - flip_indices
        - input_size (optional)
        - bbox (optional)
        - bbox_center (optional)
        - keypoints (optional)
        - keypoints_visible (optional)
        - img_mask (optional)

    Modified Keys:

        - img
        - bbox (optional)
        - bbox_center (optional)
        - keypoints (optional)
        - keypoints_visible (optional)
        - img_mask (optional)

    Added Keys:

        - flip
        - flip_direction

    Args:
        prob (float | list[float]): The flipping probability. If a list is
            given, the argument ``direction`` should be a list with the same
            length. And each element in ``prob`` indicates the flipping
            probability of the corresponding one in ``direction``. A scalar
            integer probability (e.g. ``0`` or ``1``) is also accepted and
            converted to float. Defaults to 0.5
        direction (str | list[str]): The flipping direction. Options are
            ``'horizontal'``, ``'vertical'`` and ``'diagonal'``. If a list is
            given, each data sample's flipping direction will be sampled
            from a distribution determined by the argument ``prob``. Defaults
            to ``'horizontal'``.
    """

    def __init__(self,
                 prob: Union[float, List[float]] = 0.5,
                 direction: Union[str, List[str]] = 'horizontal') -> None:
        super().__init__()

        if isinstance(prob, list):
            assert is_list_of(prob, float)
            assert 0 <= sum(prob) <= 1
        elif isinstance(prob, (int, float)) and not isinstance(prob, bool):
            # Normalize scalars (including ints such as 1) to float so that
            # ``_choose_direction`` only has to handle float or list.
            prob = float(prob)
            assert 0 <= prob <= 1
        else:
            # Implicit string concatenation keeps the message free of the
            # source indentation that a backslash continuation would embed.
            raise ValueError('probs must be float or list of float, but '
                             f'got `{type(prob)}`.')
        self.prob = prob

        valid_directions = ['horizontal', 'vertical', 'diagonal']
        if isinstance(direction, str):
            assert direction in valid_directions
        elif isinstance(direction, list):
            assert is_list_of(direction, str)
            assert set(direction).issubset(set(valid_directions))
        else:
            raise ValueError('direction must be either str or list of str, '
                             f'but got `{type(direction)}`.')
        self.direction = direction

        if isinstance(prob, list):
            assert len(prob) == len(self.direction)

    def _choose_direction(self) -> Optional[str]:
        """Choose the flip direction according to ``prob`` and ``direction``.

        Returns:
            str | None: The sampled flip direction, or ``None`` when the
            sample should not be flipped.
        """
        # ``None`` is appended as the "do not flip" outcome.
        if isinstance(self.direction, str):
            direction_list: list = [self.direction, None]
        else:
            direction_list = list(self.direction) + [None]

        if isinstance(self.prob, list):
            non_prob: float = 1 - sum(self.prob)
            prob_list = self.prob + [non_prob]
        else:
            non_prob = 1. - self.prob
            # Split the flip probability evenly over the flip directions
            # (every entry except the trailing non-flip one).
            num_flip_dirs = len(direction_list) - 1
            single_ratio = self.prob / num_flip_dirs
            prob_list = [single_ratio] * num_flip_dirs + [non_prob]

        cur_dir = np.random.choice(direction_list, p=prob_list)

        return cur_dir

    def transform(self, results: dict) -> dict:
        """The transform function of :class:`RandomFlip`.

        See ``transform()`` method of :class:`BaseTransform` for details.

        Args:
            results (dict): The result dict

        Returns:
            dict: The result dict.
        """
        flip_dir = self._choose_direction()

        if flip_dir is None:
            results['flip'] = False
            results['flip_direction'] = None
        else:
            results['flip'] = True
            results['flip_direction'] = flip_dir

            # NOTE(review): ``input_size`` is unpacked as (h, w) here, in the
            # same order as ``img_shape``; confirm that upstream transforms
            # store it in that order.
            h, w = results.get('input_size', results['img_shape'])

            # flip image and mask
            if isinstance(results['img'], list):
                results['img'] = [
                    imflip(img, direction=flip_dir) for img in results['img']
                ]
            else:
                results['img'] = imflip(results['img'], direction=flip_dir)

            if 'img_mask' in results:
                results['img_mask'] = imflip(
                    results['img_mask'], direction=flip_dir)

            # flip bboxes
            if results.get('bbox', None) is not None:
                results['bbox'] = flip_bbox(
                    results['bbox'],
                    image_size=(w, h),
                    bbox_format='xyxy',
                    direction=flip_dir)

            if results.get('bbox_center', None) is not None:
                results['bbox_center'] = flip_bbox(
                    results['bbox_center'],
                    image_size=(w, h),
                    bbox_format='center',
                    direction=flip_dir)

            # flip keypoints
            if results.get('keypoints', None) is not None:
                keypoints, keypoints_visible = flip_keypoints(
                    results['keypoints'],
                    results.get('keypoints_visible', None),
                    image_size=(w, h),
                    flip_indices=results['flip_indices'],
                    direction=flip_dir)

                results['keypoints'] = keypoints
                results['keypoints_visible'] = keypoints_visible

        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'direction={self.direction})'
        return repr_str
class RandomHalfBody(BaseTransform):
    """Data augmentation with half-body transform that keeps only the upper or
    lower body at random.

    Required Keys:

        - keypoints
        - keypoints_visible
        - upper_body_ids
        - lower_body_ids
        - bbox_center
        - bbox_scale

    Modified Keys:

        - bbox_center
        - bbox_scale

    Args:
        min_total_keypoints (int): The minimum required number of total valid
            keypoints of a person to apply half-body transform. Defaults to 9
        min_upper_keypoints (int): The minimum required number of valid
            upper-body keypoints for the upper body to be selectable.
            Defaults to 2
        min_lower_keypoints (int): The minimum required number of valid
            lower-body keypoints for the lower body to be selectable.
            Defaults to 3
        padding (float): The bbox padding scale that will be multiplied to
            ``bbox_scale``. Defaults to 1.5
        prob (float): The probability to apply half-body transform when the
            keypoint number meets the requirement. Defaults to 0.3
        upper_prioritized_prob (float): The probability of choosing the upper
            body when both halves have enough valid keypoints. Defaults to 0.7
    """

    def __init__(self,
                 min_total_keypoints: int = 9,
                 min_upper_keypoints: int = 2,
                 min_lower_keypoints: int = 3,
                 padding: float = 1.5,
                 prob: float = 0.3,
                 upper_prioritized_prob: float = 0.7) -> None:
        super().__init__()
        self.min_total_keypoints = min_total_keypoints
        self.min_upper_keypoints = min_upper_keypoints
        self.min_lower_keypoints = min_lower_keypoints
        self.padding = padding
        self.prob = prob
        self.upper_prioritized_prob = upper_prioritized_prob

    def _get_half_body_bbox(self, keypoints: np.ndarray,
                            half_body_ids: List[int]
                            ) -> Tuple[np.ndarray, np.ndarray]:
        """Get half-body bbox center and scale of a single instance.

        Args:
            keypoints (np.ndarray): Keypoints in shape (K, D). Only the first
                two coordinates (x, y) are used, so D may exceed 2
            half_body_ids (list): The list of half-body keypoint indices

        Returns:
            tuple: A tuple containing half-body bbox center and scale

            - center: Center (x, y) of the bbox
            - scale: Scale (w, h) of the bbox, multiplied by ``padding``
        """
        # Restrict to (x, y) before reducing so that keypoints with extra
        # coordinates (D > 2) do not break the two-value unpacking below.
        selected_xy = keypoints[half_body_ids][:, :2]
        center = selected_xy.mean(axis=0)

        x1, y1 = selected_xy.min(axis=0)
        x2, y2 = selected_xy.max(axis=0)
        w = x2 - x1
        h = y2 - y1

        scale = np.array([w, h], dtype=center.dtype) * self.padding

        return center, scale

    def _random_select_half_body(self, keypoints_visible: np.ndarray,
                                 upper_body_ids: List[int],
                                 lower_body_ids: List[int]
                                 ) -> List[Optional[List[int]]]:
        """Randomly determine whether applying half-body transform and get the
        half-body keypoint indices of each instance.

        Args:
            keypoints_visible (np.ndarray): The visibility of keypoints, one
                entry per instance; entries ``> 0`` count as valid
            upper_body_ids (list): The list of upper body keypoint indices
            lower_body_ids (list): The list of lower body keypoint indices

        Returns:
            list[list[int] | None]: The selected half-body keypoint indices
            of each instance. ``None`` means not applying half-body transform.
        """
        half_body_ids = []

        for visible in keypoints_visible:
            if visible.sum() < self.min_total_keypoints:
                indices = None
            elif np.random.rand() > self.prob:
                indices = None
            else:
                upper_valid_ids = [i for i in upper_body_ids if visible[i] > 0]
                lower_valid_ids = [i for i in lower_body_ids if visible[i] > 0]

                num_upper = len(upper_valid_ids)
                num_lower = len(lower_valid_ids)

                # Always drawn so the random stream does not depend on which
                # branch below is taken.
                prefer_upper = np.random.rand() < self.upper_prioritized_prob
                if (num_upper < self.min_upper_keypoints
                        and num_lower < self.min_lower_keypoints):
                    indices = None
                elif num_lower < self.min_lower_keypoints:
                    indices = upper_valid_ids
                elif num_upper < self.min_upper_keypoints:
                    indices = lower_valid_ids
                else:
                    indices = (
                        upper_valid_ids if prefer_upper else lower_valid_ids)

            half_body_ids.append(indices)

        return half_body_ids

    def transform(self, results: Dict) -> Optional[dict]:
        """The transform function of :class:`RandomHalfBody`.

        See ``transform()`` method of :class:`BaseTransform` for details.

        Args:
            results (dict): The result dict

        Returns:
            dict: The result dict.
        """
        half_body_ids = self._random_select_half_body(
            keypoints_visible=results['keypoints_visible'],
            upper_body_ids=results['upper_body_ids'],
            lower_body_ids=results['lower_body_ids'])

        bbox_center = []
        bbox_scale = []

        for i, indices in enumerate(half_body_ids):
            # ``None`` means "not selected"; an empty list (possible when a
            # min_*_keypoints threshold is 0) has no keypoints to bound, so
            # both fall back to the instance's original bbox.
            if not indices:
                bbox_center.append(results['bbox_center'][i])
                bbox_scale.append(results['bbox_scale'][i])
            else:
                _center, _scale = self._get_half_body_bbox(
                    results['keypoints'][i], indices)
                bbox_center.append(_center)
                bbox_scale.append(_scale)

        results['bbox_center'] = np.stack(bbox_center)
        results['bbox_scale'] = np.stack(bbox_scale)
        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(min_total_keypoints={self.min_total_keypoints}, '
        repr_str += f'min_upper_keypoints={self.min_upper_keypoints}, '
        repr_str += f'min_lower_keypoints={self.min_lower_keypoints}, '
        repr_str += f'padding={self.padding}, '
        repr_str += f'prob={self.prob}, '
        repr_str += f'upper_prioritized_prob={self.upper_prioritized_prob})'
        return repr_str
class RandomBBoxTransform(BaseTransform):
    r"""Randomly shift, resize and rotate the bounding boxes.

    Required Keys:

        - bbox_center
        - bbox_scale

    Modified Keys:

        - bbox_center
        - bbox_scale

    Added Keys:

        - bbox_rotation

    Args:
        shift_factor (float): Randomly shift the bbox in range
            :math:`[-dx, dx]` and :math:`[-dy, dy]` in X and Y directions,
            where :math:`dx(y) = x(y)_scale \cdot shift_factor` in pixels.
            Defaults to 0.16
        shift_prob (float): Probability of applying random shift. Defaults to
            0.3
        scale_factor (Tuple[float, float]): Randomly resize the bbox in range
            :math:`[scale_factor[0], scale_factor[1]]`. Defaults to (0.5, 1.5)
        scale_prob (float): Probability of applying random resizing. Defaults
            to 1.0
        rotate_factor (float): Randomly rotate the bbox in
            :math:`[-rotate_factor, rotate_factor]` in degrees. Defaults
            to 80.0
        rotate_prob (float): Probability of applying random rotation. Defaults
            to 0.6
    """

    def __init__(self,
                 shift_factor: float = 0.16,
                 shift_prob: float = 0.3,
                 scale_factor: Tuple[float, float] = (0.5, 1.5),
                 scale_prob: float = 1.0,
                 rotate_factor: float = 80.0,
                 rotate_prob: float = 0.6) -> None:
        super().__init__()

        self.shift_factor = shift_factor
        self.shift_prob = shift_prob
        self.scale_factor = scale_factor
        self.scale_prob = scale_prob
        self.rotate_factor = rotate_factor
        self.rotate_prob = rotate_prob

    # Must be a staticmethod: it takes no ``self`` but is invoked as
    # ``self._truncnorm(...)``; without the decorator the instance would be
    # bound to ``low`` and every call would fail.
    @staticmethod
    def _truncnorm(low: float = -1.,
                   high: float = 1.,
                   size: tuple = ()) -> np.ndarray:
        """Sample from a standard normal distribution truncated to
        ``[low, high]`` and cast the samples to float32."""
        return truncnorm.rvs(low, high, size=size).astype(np.float32)

    def _get_transform_params(self, num_bboxes: int) -> Tuple:
        """Get random transform parameters.

        Args:
            num_bboxes (int): The number of bboxes

        Returns:
            tuple:
            - offset (np.ndarray): Offset factor of each bbox in shape (n, 2)
            - scale (np.ndarray): Scaling factor of each bbox in shape (n, 1)
            - rotate (np.ndarray): Rotation degree of each bbox in shape (n,)
        """
        # Get shift parameters
        offset = self._truncnorm(size=(num_bboxes, 2)) * self.shift_factor
        offset = np.where(
            np.random.rand(num_bboxes, 1) < self.shift_prob, offset, 0.)

        # Get scaling parameters: map the [-1, 1] truncated sample linearly
        # onto [scale_min, scale_max].
        scale_min, scale_max = self.scale_factor
        mu = (scale_max + scale_min) * 0.5
        sigma = (scale_max - scale_min) * 0.5
        scale = self._truncnorm(size=(num_bboxes, 1)) * sigma + mu
        scale = np.where(
            np.random.rand(num_bboxes, 1) < self.scale_prob, scale, 1.)

        # Get rotation parameters
        rotate = self._truncnorm(size=(num_bboxes, )) * self.rotate_factor
        rotate = np.where(
            np.random.rand(num_bboxes) < self.rotate_prob, rotate, 0.)

        return offset, scale, rotate

    def transform(self, results: Dict) -> Optional[dict]:
        """The transform function of :class:`RandomBBoxTransform`.

        See ``transform()`` method of :class:`BaseTransform` for details.

        Args:
            results (dict): The result dict

        Returns:
            dict: The result dict.
        """
        bbox_scale = results['bbox_scale']
        num_bboxes = bbox_scale.shape[0]

        offset, scale, rotate = self._get_transform_params(num_bboxes)

        # The shift uses the scale *before* resizing: ``bbox_scale`` aliases
        # ``results['bbox_scale']``, which is only rescaled in place below.
        results['bbox_center'] += offset * bbox_scale
        results['bbox_scale'] *= scale
        results['bbox_rotation'] = rotate

        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += f'(shift_prob={self.shift_prob}, '
        repr_str += f'shift_factor={self.shift_factor}, '
        repr_str += f'scale_prob={self.scale_prob}, '
        repr_str += f'scale_factor={self.scale_factor}, '
        repr_str += f'rotate_prob={self.rotate_prob}, '
        repr_str += f'rotate_factor={self.rotate_factor})'
        return repr_str
class Albumentation(BaseTransform):
    """Albumentation augmentation (pixel-level transforms only).

    Adds custom pixel-level transformations from Albumentations library.
    Please visit `https://albumentations.ai/docs/`
    to get more information.

    Note: we only support pixel-level transforms.
    Please visit `https://github.com/albumentations-team/`
    `albumentations#pixel-level-transforms`
    to get more information about pixel-level transforms.

    Required Keys:

        - img

    Modified Keys:

        - img

    Args:
        transforms (List[dict]): A list of Albumentation transforms.
            An example of ``transforms`` is as follows:

            .. code-block:: python

                [
                    dict(
                        type='RandomBrightnessContrast',
                        brightness_limit=[0.1, 0.3],
                        contrast_limit=[0.1, 0.3],
                        p=0.2),
                    dict(type='ChannelShuffle', p=0.1),
                    dict(
                        type='OneOf',
                        transforms=[
                            dict(type='Blur', blur_limit=3, p=1.0),
                            dict(type='MedianBlur', blur_limit=3, p=1.0)
                        ],
                        p=0.1),
                ]

        keymap (dict | None): key mapping from ``input key`` to
            ``albumentation-style key``.
            Defaults to None, which will use {'img': 'image'}.
    """

    def __init__(self,
                 transforms: List[dict],
                 keymap: Optional[dict] = None) -> None:
        super().__init__()

        if albumentations is None:
            raise RuntimeError('albumentations is not installed')

        self.transforms = transforms

        self.aug = albumentations.Compose(
            [self.albu_builder(t) for t in self.transforms])

        # A missing or empty keymap falls back to the default mapping.
        if not keymap:
            self.keymap_to_albu = {
                'img': 'image',
            }
        else:
            self.keymap_to_albu = keymap

    def albu_builder(self, cfg: dict) -> 'albumentations.BasicTransform':
        """Import a module from albumentations.

        It resembles some of :func:`build_from_cfg` logic. Nested
        ``transforms`` entries (e.g. for ``OneOf``) are built recursively.

        Args:
            cfg (dict): Config dict. It should at least contain the key "type".

        Returns:
            albumentations.BasicTransform: The constructed transform object

        Raises:
            RuntimeError: If albumentations is not installed.
            TypeError: If ``cfg['type']`` is not a string.
        """
        assert isinstance(cfg, dict) and 'type' in cfg
        args = cfg.copy()

        obj_type = args.pop('type')
        if mmengine.is_str(obj_type):
            if albumentations is None:
                raise RuntimeError('albumentations is not installed')
            rank, _ = get_dist_info()
            # Anything outside ``augmentations.transforms`` is not a known
            # pixel-level transform; warn (once per job) but still build it.
            if rank == 0 and not hasattr(
                    albumentations.augmentations.transforms, obj_type):
                warnings.warn(
                    f'{obj_type} is not pixel-level transformations. '
                    'Please use with caution.')
            obj_cls = getattr(albumentations, obj_type)
        else:
            raise TypeError(f'type must be a str, but got {type(obj_type)}')

        if 'transforms' in args:
            args['transforms'] = [
                self.albu_builder(transform)
                for transform in args['transforms']
            ]

        return obj_cls(**args)

    def transform(self, results: dict) -> dict:
        """The transform function of :class:`Albumentation` to apply
        albumentations transforms.

        See ``transform()`` method of :class:`BaseTransform` for details.

        Args:
            results (dict): Result dict from the data pipeline.

        Return:
            dict: updated result dict.
        """
        # map result dict to albumentations format
        results_albu = {}
        for k, v in self.keymap_to_albu.items():
            assert k in results, \
                f'The `{k}` is required to perform albumentations transforms'
            results_albu[v] = results[k]

        # Apply albumentations transforms
        results_albu = self.aug(**results_albu)

        # map the albu results back to the original format
        for k, v in self.keymap_to_albu.items():
            results[k] = results_albu[v]

        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
        return repr_str
class PhotometricDistortion(BaseTransform):
    """Apply photometric distortion to image sequentially, every transformation
    is applied with a probability of 0.5. The position of random contrast is in
    second or second to last.

    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)
    8. randomly swap channels

    Required Keys:

        - img

    Modified Keys:

        - img

    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue. The hue offset is sampled uniformly
            from the integers in ``[-hue_delta, hue_delta]``.
    """

    def __init__(self,
                 brightness_delta: int = 32,
                 contrast_range: Sequence[Number] = (0.5, 1.5),
                 saturation_range: Sequence[Number] = (0.5, 1.5),
                 hue_delta: int = 18) -> None:
        super().__init__()
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def _random_flags(self) -> Sequence[Number]:
        """Generate the random flags for subsequent transforms.

        Returns:
            Sequence[Number]: a sequence of numbers that indicate whether to
            do the corresponding transforms.
        """
        # contrast_mode == 0 --> do random contrast first
        # contrast_mode == 1 --> do random contrast last
        contrast_mode = np.random.randint(2)
        # whether to apply brightness distortion
        brightness_flag = np.random.randint(2)
        # whether to apply contrast distortion
        contrast_flag = np.random.randint(2)
        # the mode to convert color from BGR to HSV: 1 -> saturation only,
        # 2 -> hue only, 3 -> both, 0 -> no HSV distortion
        hsv_mode = np.random.randint(4)
        # whether to apply channel swap
        swap_flag = np.random.randint(2)

        # the beta in `self._convert` to be added to image array
        # in brightness distortion
        brightness_beta = np.random.uniform(-self.brightness_delta,
                                            self.brightness_delta)
        # the alpha in `self._convert` to be multiplied to image array
        # in contrast distortion
        contrast_alpha = np.random.uniform(self.contrast_lower,
                                           self.contrast_upper)
        # the alpha in `self._convert` to be multiplied to image array
        # in saturation distortion to hsv-formatted img
        saturation_alpha = np.random.uniform(self.saturation_lower,
                                             self.saturation_upper)
        # delta of hue to add to image array in hue distortion; ``randint``
        # excludes its upper bound, so +1 makes the range symmetric (and keeps
        # ``hue_delta == 0`` valid instead of raising).
        hue_delta = np.random.randint(-self.hue_delta, self.hue_delta + 1)
        # the random permutation of channel order
        swap_channel_order = np.random.permutation(3)

        return (contrast_mode, brightness_flag, contrast_flag, hsv_mode,
                swap_flag, brightness_beta, contrast_alpha, saturation_alpha,
                hue_delta, swap_channel_order)

    def _convert(self,
                 img: np.ndarray,
                 alpha: float = 1,
                 beta: float = 0) -> np.ndarray:
        """Multiple with alpha and add beta with clip.

        Args:
            img (np.ndarray): The image array.
            alpha (float): The random multiplier.
            beta (float): The random offset.

        Returns:
            np.ndarray: The updated image array, clipped to [0, 255] and cast
            to uint8.
        """
        img = img.astype(np.float32) * alpha + beta
        img = np.clip(img, 0, 255)
        return img.astype(np.uint8)

    def transform(self, results: dict) -> dict:
        """The transform function of :class:`PhotometricDistortion` to perform
        photometric distortion on images.

        See ``transform()`` method of :class:`BaseTransform` for details.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict: Result dict with images distorted.
        """

        assert 'img' in results, '`img` is not found in results'
        img = results['img']

        (contrast_mode, brightness_flag, contrast_flag, hsv_mode, swap_flag,
         brightness_beta, contrast_alpha, saturation_alpha, hue_delta,
         swap_channel_order) = self._random_flags()

        # random brightness distortion
        if brightness_flag:
            img = self._convert(img, beta=brightness_beta)

        # contrast_mode == 0 --> do random contrast first
        if contrast_mode == 0 and contrast_flag:
            img = self._convert(img, alpha=contrast_alpha)

        if hsv_mode:
            # random saturation/hue distortion
            img = mmcv.bgr2hsv(img)
            if hsv_mode == 1 or hsv_mode == 3:
                # apply saturation distortion to hsv-formatted img
                img[:, :, 1] = self._convert(
                    img[:, :, 1], alpha=saturation_alpha)
            if hsv_mode == 2 or hsv_mode == 3:
                # apply hue distortion to hsv-formatted img
                hue = img[:, :, 0].astype(int) + hue_delta
                if img.dtype == np.uint8:
                    # OpenCV stores 8-bit hue in [0, 180); wrap around instead
                    # of letting the uint8 assignment overflow (-13 -> 243).
                    hue %= 180
                img[:, :, 0] = hue
            img = mmcv.hsv2bgr(img)

        # contrast_mode == 1 --> do random contrast last
        if contrast_mode == 1 and contrast_flag:
            img = self._convert(img, alpha=contrast_alpha)

        # randomly swap channels
        if swap_flag:
            img = img[..., swap_channel_order]

        results['img'] = img
        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += (f'(brightness_delta={self.brightness_delta}, '
                     f'contrast_range=({self.contrast_lower}, '
                     f'{self.contrast_upper}), '
                     f'saturation_range=({self.saturation_lower}, '
                     f'{self.saturation_upper}), '
                     f'hue_delta={self.hue_delta})')
        return repr_str
class GenerateTarget(BaseTransform):
    """Encode keypoints into Target.

    The generated target is usually the supervision signal of the model
    learning, e.g. heatmaps or regression labels.

    Required Keys:

        - keypoints (or transformed_keypoints)
        - keypoints_visible
        - dataset_keypoint_weights (only if ``use_dataset_keypoint_weights``)

    Added Keys:

        - The keys of the encoded items from the codec will be updated into
          the results, e.g. ``'heatmaps'`` or ``'keypoint_weights'``. See
          the specific codec for more details.

    Args:
        encoder (dict | list[dict]): The codec config for keypoint encoding.
            Both single encoder and multiple encoders (given as a list) are
            supported
        multilevel (bool): Determine the method to handle multiple encoders.
            If ``multilevel==True``, generate multilevel targets from a group
            of encoders of the same type (e.g. multiple :class:`MSRAHeatmap`
            encoders with different sigma values); If ``multilevel==False``,
            generate combined targets from a group of different encoders. This
            argument will have no effect in case of single encoder. Defaults
            to ``False``
        use_dataset_keypoint_weights (bool): Whether use the keypoint weights
            from the dataset meta information. Defaults to ``False``
        target_type (str, deprecated): This argument is deprecated and has no
            effect. Defaults to ``None``
    """

    def __init__(self,
                 encoder: MultiConfig,
                 target_type: Optional[str] = None,
                 multilevel: bool = False,
                 use_dataset_keypoint_weights: bool = False) -> None:
        super().__init__()

        if target_type is not None:
            rank, _ = get_dist_info()
            if rank == 0:
                warnings.warn(
                    'The argument `target_type` is deprecated in'
                    ' GenerateTarget. The target type and encoded '
                    'keys will be determined by encoder(s).',
                    DeprecationWarning)

        self.encoder_cfg = deepcopy(encoder)
        self.multilevel = multilevel
        self.use_dataset_keypoint_weights = use_dataset_keypoint_weights

        if isinstance(self.encoder_cfg, list):
            self.encoder = [
                KEYPOINT_CODECS.build(cfg) for cfg in self.encoder_cfg
            ]
        else:
            assert not self.multilevel, (
                'Need multiple encoder configs if ``multilevel==True``')
            self.encoder = KEYPOINT_CODECS.build(self.encoder_cfg)

    @staticmethod
    def _encode_single(encoder, keypoints, keypoints_visible,
                       results: Dict) -> dict:
        """Run one encoder, forwarding the extra inputs it declares in
        ``encoder.auxiliary_encode_keys`` from ``results``."""
        auxiliary_encode_kwargs = {
            key: results[key]
            for key in encoder.auxiliary_encode_keys
        }
        return encoder.encode(
            keypoints=keypoints,
            keypoints_visible=keypoints_visible,
            **auxiliary_encode_kwargs)

    def transform(self, results: Dict) -> Optional[dict]:
        """The transform function of :class:`GenerateTarget`.

        See ``transform()`` method of :class:`BaseTransform` for details.

        Args:
            results (dict): The result dict

        Returns:
            dict: The result dict, updated with the encoded items.

        Raises:
            ValueError: If neither ``transformed_keypoints`` nor
                ``keypoints`` is available; if multilevel encoders produce
                different keys; or if combined encoders produce overlapping
                keys other than ``keypoint_weights``.
        """

        if results.get('transformed_keypoints', None) is not None:
            # use keypoints transformed by TopdownAffine
            keypoints = results['transformed_keypoints']
        elif results.get('keypoints', None) is not None:
            # use original keypoints
            keypoints = results['keypoints']
        else:
            raise ValueError(
                'GenerateTarget requires \'transformed_keypoints\' or'
                ' \'keypoints\' in the results.')

        keypoints_visible = results['keypoints_visible']

        # Encoded items from the encoder(s) will be updated into the results.
        # Please refer to the document of the specific codec for details about
        # encoded items.
        if not isinstance(self.encoder, list):
            # For single encoding, the encoded items will be directly added
            # into results.
            encoded = self._encode_single(self.encoder, keypoints,
                                          keypoints_visible, results)
        else:
            encoded_list = [
                self._encode_single(_encoder, keypoints, keypoints_visible,
                                    results) for _encoder in self.encoder
            ]

            if self.multilevel:
                # For multilevel encoding, the encoded items from each encoder
                # should have the same keys.
                keys = encoded_list[0].keys()
                if not all(_encoded.keys() == keys
                           for _encoded in encoded_list):
                    raise ValueError(
                        'Encoded items from all encoders must have the same '
                        'keys if ``multilevel==True``.')

                encoded = {
                    k: [_encoded[k] for _encoded in encoded_list]
                    for k in keys
                }

            else:
                # For combined encoding, the encoded items from different
                # encoders should have no overlapping items, except for
                # `keypoint_weights`. Every `keypoint_weights` that is
                # produced is collected, in encoder order, into a list stored
                # under the `keypoint_weights` key (they are NOT multiplied).
                encoded = dict()
                keypoint_weights = []

                for _encoded in encoded_list:
                    for key, value in _encoded.items():
                        if key == 'keypoint_weights':
                            keypoint_weights.append(value)
                        elif key not in encoded:
                            encoded[key] = value
                        else:
                            raise ValueError(
                                f'Overlapping item "{key}" from multiple '
                                'encoders, which is not supported when '
                                '``multilevel==False``')

                if keypoint_weights:
                    encoded['keypoint_weights'] = keypoint_weights

        if self.use_dataset_keypoint_weights and 'keypoint_weights' in encoded:
            dataset_weights = results['dataset_keypoint_weights']
            if isinstance(encoded['keypoint_weights'], list):
                # In-place multiply so each array inside the list is updated.
                for w in encoded['keypoint_weights']:
                    w *= dataset_weights
            else:
                encoded['keypoint_weights'] *= dataset_weights

        results.update(encoded)

        return results

    def __repr__(self) -> str:
        """print the basic information of the transform.

        Returns:
            str: Formatted string.
        """
        repr_str = self.__class__.__name__
        repr_str += (f'(encoder={str(self.encoder_cfg)}, ')
        repr_str += ('use_dataset_keypoint_weights='
                     f'{self.use_dataset_keypoint_weights})')
        return repr_str