import torch


class ImageResizeTransform:
    """
    Transform that resizes images loaded from a dataset
    (BGR data in NCHW channel order, typically uint8) to a format ready to be
    consumed by DensePose training (BGR float32 data in NCHW channel order)
    """

    def __init__(self, min_size: int = 800, max_size: int = 1333):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, images: torch.Tensor) -> torch.Tensor:
        """
        Args:
            images (torch.Tensor): tensor of size [N, 3, H, W] that contains
                BGR data (typically in uint8)

        Returns:
            images (torch.Tensor): tensor of size [N, 3, H1, W1], where H1 and W1
                are chosen to respect the specified min and max sizes while
                preserving the original aspect ratio; the data channels follow
                BGR order and the data type is `torch.float32`
        """
        # Convert to float32 and rescale so that the shorter image side matches
        # self.min_size, unless that would push the longer side beyond
        # self.max_size, in which case the longer side is capped at self.max_size.
        images = images.float()
        min_size = min(images.shape[-2:])
        max_size = max(images.shape[-2:])
        scale = min(self.min_size / min_size, self.max_size / max_size)
        images = torch.nn.functional.interpolate(
            images,
            scale_factor=scale,
            mode="bilinear",
            align_corners=False,
        )
        return images
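

# Illustrative usage sketch (not part of the original module): the batch shape,
# dtype, and size values below are assumptions chosen only to demonstrate how
# ImageResizeTransform rescales a batch of images.
if __name__ == "__main__":
    transform = ImageResizeTransform(min_size=800, max_size=1333)
    # A batch of two 480x640 BGR images in uint8, as typically loaded from a dataset.
    batch = torch.randint(0, 256, (2, 3, 480, 640), dtype=torch.uint8)
    resized = transform(batch)
    # The shorter side (480) is scaled towards 800: scale = min(800 / 480, 1333 / 640)
    # ~= 1.667, so the output is roughly [2, 3, 800, 1066] in torch.float32.
    print(resized.shape, resized.dtype)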