# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def flip_heatmaps(heatmaps: Tensor,
                  flip_indices: Optional[List[int]] = None,
                  flip_mode: str = 'heatmap',
                  shift_heatmap: bool = True):
    """Flip heatmaps for test-time augmentation.

    The result is always built from the copy returned by ``Tensor.flip``,
    so the in-place steps below never write into the caller's tensor.

    Args:
        heatmaps (Tensor): The heatmaps to flip. Should be a tensor in shape
            [B, C, H, W]
        flip_indices (List[int], optional): The indices of each keypoint's
            symmetric keypoint. If ``None``, channels are not swapped.
            Defaults to ``None``
        flip_mode (str): Specify the flipping mode. Options are:

            - ``'heatmap'``: horizontally flip the heatmaps and swap heatmaps
                of symmetric keypoints according to ``flip_indices``
            - ``'udp_combined'``: similar to ``'heatmap'`` mode but the
                channels are treated as groups of 3 per keypoint and the
                second channel of each group (x_offset) is negated
            - ``'offset'``: horizontally flip the offset fields and swap
                heatmaps of symmetric keypoints according to
                ``flip_indices``. The channels are treated as groups of 2
                per keypoint and the first channel of each group (x_offset)
                is negated
        shift_heatmap (bool): Shift the flipped heatmaps to align with the
            original heatmaps and improve accuracy. Defaults to ``True``

    Returns:
        Tensor: flipped heatmaps in shape [B, C, H, W]

    Raises:
        ValueError: If ``flip_mode`` is not one of the supported options.
    """

    if flip_mode == 'heatmap':
        heatmaps = heatmaps.flip(-1)
        if flip_indices is not None:
            assert len(flip_indices) == heatmaps.shape[1]
            heatmaps = heatmaps[:, flip_indices]
    elif flip_mode == 'udp_combined':
        B, C, H, W = heatmaps.shape
        # Group the channels as (keypoint, 3) so keypoints can be swapped
        # as a unit and the x-offset channel addressed directly.
        heatmaps = heatmaps.view(B, C // 3, 3, H, W)
        heatmaps = heatmaps.flip(-1)
        if flip_indices is not None:
            assert len(flip_indices) == C // 3
            heatmaps = heatmaps[:, flip_indices]
        # A horizontal flip reverses the sign of the x offsets.
        heatmaps[:, :, 1] = -heatmaps[:, :, 1]
        heatmaps = heatmaps.view(B, C, H, W)
    elif flip_mode == 'offset':
        B, C, H, W = heatmaps.shape
        # Group the channels as (keypoint, 2); index 0 is the x offset.
        heatmaps = heatmaps.view(B, C // 2, -1, H, W)
        heatmaps = heatmaps.flip(-1)
        if flip_indices is not None:
            assert len(flip_indices) == C // 2
            heatmaps = heatmaps[:, flip_indices]
        heatmaps[:, :, 0] = -heatmaps[:, :, 0]
        heatmaps = heatmaps.view(B, C, H, W)
    else:
        raise ValueError(f'Invalid flip_mode value "{flip_mode}"')

    if shift_heatmap:
        # clone data to avoid unexpected in-place operation when using CPU
        heatmaps[..., 1:] = heatmaps[..., :-1].clone()

    return heatmaps


def flip_vectors(x_labels: Tensor, y_labels: Tensor, flip_indices: List[int]):
    """Flip instance-level labels in specific axis for test-time augmentation.

    Both label tensors have their keypoints swapped according to
    ``flip_indices``; only ``x_labels`` is additionally reversed along its
    last dimension.

    Args:
        x_labels (Tensor): The vector labels in x-axis to flip. Should be
            a tensor in shape [B, C, Wx]
        y_labels (Tensor): The vector labels in y-axis to flip. Should be
            a tensor in shape [B, C, Wy]
        flip_indices (List[int]): The indices of each keypoint's symmetric
            keypoint

    Returns:
        Tuple[Tensor, Tensor]: The flipped ``x_labels`` and ``y_labels``,
        in the same shapes as the inputs.
    """
    assert x_labels.ndim == 3 and y_labels.ndim == 3
    assert len(flip_indices) == x_labels.shape[1] and len(
        flip_indices) == y_labels.shape[1]
    x_labels = x_labels[:, flip_indices].flip(-1)
    y_labels = y_labels[:, flip_indices]

    return x_labels, y_labels


def flip_coordinates(coords: Tensor, flip_indices: List[int],
                     shift_coords: bool, input_size: Tuple[int, int]):
    """Flip normalized coordinates for test-time augmentation.

    The input tensor is left unmodified; a new tensor is returned.

    Args:
        coords (Tensor): The coordinates to flip. Should be a tensor in shape
            [B, K, D]. Index 0 of the last dimension is treated as the
            normalized x coordinate
        flip_indices (List[int]): The indices of each keypoint's symmetric
            keypoint
        shift_coords (bool): Shift the flipped coordinates to align with the
            original coordinates and improve accuracy. This argument is
            required (the signature has no default)
        input_size (Tuple[int, int]): The size of input image in [w, h]

    Returns:
        Tensor: The flipped coordinates in shape [B, K, D]
    """
    assert coords.ndim == 3
    assert len(flip_indices) == coords.shape[1]

    # Swap symmetric keypoints first: indexing with a list produces a new
    # tensor, so the in-place x updates below cannot write into the caller's
    # data. The element-wise x transform commutes with the keypoint
    # permutation, so the returned values are unchanged.
    coords = coords[:, flip_indices]
    coords[:, :, 0] = 1.0 - coords[:, :, 0]

    if shift_coords:
        img_width = input_size[0]
        # Shift x by 1 / img_width (one pixel in normalized units).
        coords[:, :, 0] -= 1.0 / img_width

    return coords


def aggregate_heatmaps(heatmaps: List[Tensor],
                       size: Optional[Tuple[int, int]] = None,
                       align_corners: bool = False,
                       mode: str = 'average'):
    """Aggregate multiple heatmaps.

    The input list is not modified; resized heatmaps are collected in a new
    list.

    Args:
        heatmaps (List[Tensor]): Multiple heatmaps to aggregate. Each should
            be in shape (B, C, H, W). Must not be empty
        size (Tuple[int, int], optional): The target size in (w, h). All
            heatmaps will be resized to the target size. If not given, the
            first heatmap tensor's width and height will be used as the
            target size. Defaults to ``None``
        align_corners (bool): Whether align corners when resizing heatmaps.
            Defaults to ``False``
        mode (str): Aggregation mode in one of the following:

            - ``'average'``: Get average of heatmaps. All heatmaps must have
                the same channel number
            - ``'concat'``: Concatenate the heatmaps at the channel dim

    Returns:
        Tensor: The aggregated heatmaps. In ``'average'`` mode the shape is
        (B, C, h, w); in ``'concat'`` mode the channel dimension is the sum
        of the channel numbers of all inputs.

    Raises:
        ValueError: If ``mode`` is invalid or ``heatmaps`` is empty.
    """

    if mode not in {'average', 'concat'}:
        raise ValueError(f'Invalid aggregation mode `{mode}`')

    if not heatmaps:
        raise ValueError('`heatmaps` should contain at least one tensor')

    if size is None:
        h, w = heatmaps[0].shape[2:4]
    else:
        # ``size`` is given as (w, h) while tensors are laid out as (H, W).
        w, h = size

    reference_shape = heatmaps[0].shape
    resized = []
    for _heatmaps in heatmaps:
        assert _heatmaps.ndim == 4
        if mode == 'average':
            # Averaging needs matching batch size and channel number.
            assert _heatmaps.shape[:2] == reference_shape[:2]
        else:
            # Concatenation along channels only needs matching batch size.
            assert _heatmaps.shape[0] == reference_shape[0]

        if tuple(_heatmaps.shape[2:4]) != (h, w):
            _heatmaps = F.interpolate(
                _heatmaps,
                size=(h, w),
                mode='bilinear',
                align_corners=align_corners)
        resized.append(_heatmaps)

    if mode == 'average':
        return sum(resized).div(len(resized))

    # mode == 'concat' (validated at the top of the function)
    return torch.cat(resized, dim=1)