File size: 1,445 Bytes
da32488
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Copyright (c) Facebook, Inc. and its affiliates.

import torch


class ImageResizeTransform:
    """
    Resize a batch of dataset images into the form DensePose training consumes.

    Input: BGR data in NCHW channel order (typically uint8).
    Output: BGR float32 data in NCHW channel order.
    """

    def __init__(self, min_size: int = 800, max_size: int = 1333):
        # Desired length of the shorter spatial side after resizing.
        self.min_size = min_size
        # Cap on the length of the longer spatial side after resizing.
        self.max_size = max_size

    def __call__(self, images: torch.Tensor) -> torch.Tensor:
        """
        Rescale ``images`` by a single factor applied to both spatial axes.

        The factor brings the shorter side to ``self.min_size``. If that would
        push the longer side past ``self.max_size``, the factor instead brings
        the longer side to ``self.max_size``. The factor may be greater than 1,
        in which case the images are upsampled.

        Args:
            images (torch.Tensor): tensor of size [N, 3, H, W] holding BGR data
                (typically uint8).

        Returns:
            torch.Tensor: tensor of size [N, 3, H1, W1] in BGR order with dtype
                ``torch.float32``. H1 and W1 are H and W scaled by the same
                factor, so the aspect ratio is preserved up to rounding.
        """
        as_float = images.float()
        spatial_dims = as_float.shape[-2:]
        shorter_side = min(spatial_dims)
        longer_side = max(spatial_dims)
        # Take the smaller of the two candidate factors so that both size
        # constraints hold at once.
        factor = min(
            self.min_size / shorter_side,
            self.max_size / longer_side,
        )
        # A scalar scale_factor scales H and W identically.
        return torch.nn.functional.interpolate(
            as_float,
            scale_factor=factor,
            mode="bilinear",
            align_corners=False,
        )