unit_test / uniperceiver /datasets /custom_transforms.py
herrius's picture
Upload 259 files
32b542e
from torchvision import transforms as T
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
from PIL import Image
BICUBIC = Image.BICUBIC
def clip_transforms(mode='train', img_size=224, flip_prob=0.5):
    """Build the image preprocessing pipeline for the given dataset split.

    Every pipeline resizes with bicubic interpolation to ``img_size``, crops
    to ``img_size``, converts to a tensor and normalizes per channel. The
    only difference between splits is the crop/augmentation stage:

    * ``'train'``: random crop followed by a random horizontal flip.
    * ``'test'`` / ``'val'``: deterministic center crop, no flip.

    Args:
        mode: One of ``'train'``, ``'test'`` or ``'val'``.
        img_size: Size passed to both the resize and the crop transforms.
        flip_prob: Probability handed to ``T.RandomHorizontalFlip``; only
            used when ``mode == 'train'``.

    Returns:
        A ``T.Compose`` chaining the transforms described above.

    Raises:
        AssertionError: If ``mode`` is not one of the accepted values
            (the check is an ``assert``, so it is skipped under ``-O``).
    """
    assert mode in ('train', 'test', 'val')

    # NOTE(review): these look like the usual ImageNet channel statistics
    # rather than CLIP-specific ones -- confirm this is intended.
    normalize = T.Normalize(
        mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
    )

    # Shared head of the pipeline: resize to the target size.
    pipeline = [T.Resize(img_size, BICUBIC)]

    # Split-specific stage: stochastic augmentation for training,
    # a fixed center crop for evaluation.
    if mode == 'train':
        pipeline.append(T.RandomCrop(img_size))
        pipeline.append(T.RandomHorizontalFlip(flip_prob))
    else:
        pipeline.append(T.CenterCrop(img_size))

    # Shared tail of the pipeline: tensor conversion, then normalization.
    pipeline.append(T.ToTensor())
    pipeline.append(normalize)

    return T.Compose(pipeline)