Spaces:
Build error
Build error
import torch | |
from . import transforms as T | |
class DetectionPresetTrain:
    """Build the training-time transform pipeline for object detection.

    A named augmentation policy selects a list of transforms. Every policy
    then finishes with the same tensor conversion: ``T.PILToTensor``
    followed by ``T.ConvertImageDtype(torch.float)``.

    Args:
        data_augmentation: Name of the policy; one of ``"hflip"``, ``"ssd"``
            or ``"ssdlite"``.
        hflip_prob: Passed as ``p`` to ``T.RandomHorizontalFlip``, which is
            part of every policy.
        mean: Converted to a list and passed as ``fill`` to
            ``T.RandomZoomOut``. It is only used by the ``"ssd"`` policy.
            NOTE(review): presumably a per-channel pixel mean on a 0-255
            scale; confirm against the dataset statistics.

    Raises:
        ValueError: If ``data_augmentation`` is not a known policy name.
    """

    def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
        # Only the selected policy's transforms are constructed.
        if data_augmentation == "hflip":
            augmentations = [T.RandomHorizontalFlip(p=hflip_prob)]
        elif data_augmentation == "ssd":
            augmentations = [
                T.RandomPhotometricDistort(),
                T.RandomZoomOut(fill=list(mean)),
                T.RandomIoUCrop(),
                T.RandomHorizontalFlip(p=hflip_prob),
            ]
        elif data_augmentation == "ssdlite":
            augmentations = [
                T.RandomIoUCrop(),
                T.RandomHorizontalFlip(p=hflip_prob),
            ]
        else:
            raise ValueError(
                f'Unknown data augmentation policy "{data_augmentation}"')

        # Shared tail: all policies end by producing a float tensor image.
        to_float_tensor = [T.PILToTensor(), T.ConvertImageDtype(torch.float)]
        self.transforms = T.Compose(augmentations + to_float_tensor)

    def __call__(self, img, target):
        """Run the composed pipeline on ``img`` and ``target``.

        Returns whatever ``self.transforms`` returns. This is presumably an
        ``(image, target)`` pair; confirm against ``T.Compose``.
        """
        return self.transforms(img, target)
class DetectionPresetEval:
    """Build the evaluation-time transform pipeline for object detection.

    No augmentation is applied. The image is converted with the same two
    steps that end every ``DetectionPresetTrain`` policy: ``T.PILToTensor``
    followed by ``T.ConvertImageDtype(torch.float)``. This makes evaluation
    inputs come out in the same dtype and value range as training inputs.
    """

    def __init__(self):
        # Mirror the training conversion exactly instead of relying on a
        # separate ToTensor implementation, which may scale or type the
        # image differently from the pipeline the model was trained with.
        self.transforms = T.Compose(
            [
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float),
            ]
        )

    def __call__(self, img, target):
        """Run the conversion pipeline on ``img`` and ``target``.

        Returns whatever ``self.transforms`` returns. This is presumably an
        ``(image, target)`` pair; confirm against ``T.Compose``.
        """
        return self.transforms(img, target)