File size: 775 Bytes
d8aafaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from torchvision.transforms import functional as TF


class Compose:
    """Chain several (image, boxes) transforms into a single callable.

    Args:
        transforms: Iterable of callables. Each is invoked as
            ``transform(image, boxes)`` and must return a two-item
            ``(image, boxes)`` result, which is fed to the next one.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, boxes):
        """Apply every transform in order and return the final pair."""
        current_image, current_boxes = image, boxes
        for transform in self.transforms:
            # Each step consumes the previous step's output.
            current_image, current_boxes = transform(
                current_image, current_boxes
            )
        return current_image, current_boxes


class RandomHorizontalFlip:
    """Randomly flip the image horizontally and mirror its bounding boxes.

    Boxes are assumed to be rows of ``[cls, xmin, ymin, xmax, ymax]``. The
    ``1 - x`` mirroring used here is only meaningful when the x coordinates
    are normalized to ``[0, 1]``.

    Args:
        p: Probability of applying the flip.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image, boxes):
        """Return ``(image, boxes)``, flipped together with probability ``p``.

        The caller's ``boxes`` object is never modified: when a flip happens
        the mirrored coordinates are written to a copy, so annotations cached
        by a dataset are not corrupted across repeated accesses. When no flip
        happens, or there are no boxes, the inputs are returned unchanged.
        """
        if torch.rand(1) < self.p:
            image = TF.hflip(image)
            # Images without annotations (possibly a 1-D empty tensor) have
            # nothing to mirror; the 2-D indexing below would fail on them.
            if len(boxes) > 0:
                # Work on a copy so the input annotations stay intact.
                if isinstance(boxes, torch.Tensor):
                    boxes = boxes.clone()
                else:
                    boxes = boxes.copy()
                # Mirroring swaps the roles of the edges: the new xmin is
                # 1 - old xmax and the new xmax is 1 - old xmin. The indexed
                # right-hand side is a copy, so the swap is safe.
                boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
        return image, boxes