import torch from torchvision.transforms import functional as TF class Compose: """Composes several transforms together.""" def __init__(self, transforms): self.transforms = transforms def __call__(self, image, boxes): for t in self.transforms: image, boxes = t(image, boxes) return image, boxes class RandomHorizontalFlip: """Randomly horizontally flips the image along with the bounding boxes.""" def __init__(self, p=0.5): self.p = p def __call__(self, image, boxes): if torch.rand(1) < self.p: image = TF.hflip(image) # Assuming boxes are in the format [cls, xmin, ymin, xmax, ymax] boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]] return image, boxes class RandomVerticalFlip: """Randomly vertically flips the image along with the bounding boxes.""" def __init__(self, p=0.5): self.p = p def __call__(self, image, boxes): if torch.rand(1) < self.p: image = TF.vflip(image) # Assuming boxes are in the format [cls, xmin, ymin, xmax, ymax] boxes[:, [2, 4]] = 1 - boxes[:, [2, 4]] return image, boxes