YOLO / utils /dataargument.py
henry000's picture
✨ [Add] A instance of dataaugments
d8aafaa
raw
history blame
775 Bytes
import torch
from torchvision.transforms import functional as TF
class Compose:
"""Composes several transforms together."""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, boxes):
for t in self.transforms:
image, boxes = t(image, boxes)
return image, boxes
class RandomHorizontalFlip:
"""Randomly horizontally flips the image along with the bounding boxes."""
def __init__(self, p=0.5):
self.p = p
def __call__(self, image, boxes):
if torch.rand(1) < self.p:
image = TF.hflip(image)
# Assuming boxes are in the format [cls, xmin, ymin, xmax, ymax]
boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
return image, boxes