import torch from . import transforms as T class DetectionPresetTrain: def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)): if data_augmentation == "hflip": self.transforms = T.Compose( [ T.RandomHorizontalFlip(p=hflip_prob), T.PILToTensor(), T.ConvertImageDtype(torch.float), ] ) elif data_augmentation == "ssd": self.transforms = T.Compose( [ T.RandomPhotometricDistort(), T.RandomZoomOut(fill=list(mean)), T.RandomIoUCrop(), T.RandomHorizontalFlip(p=hflip_prob), T.PILToTensor(), T.ConvertImageDtype(torch.float), ] ) elif data_augmentation == "ssdlite": self.transforms = T.Compose( [ T.RandomIoUCrop(), T.RandomHorizontalFlip(p=hflip_prob), T.PILToTensor(), T.ConvertImageDtype(torch.float), ] ) else: raise ValueError( f'Unknown data augmentation policy "{data_augmentation}"') def __call__(self, img, target): return self.transforms(img, target) class DetectionPresetEval: def __init__(self): self.transforms = T.ToTensor() def __call__(self, img, target): return self.transforms(img, target)