|
from torchvision import transforms as T |
|
try: |
|
from torchvision.transforms import InterpolationMode |
|
BICUBIC = InterpolationMode.BICUBIC |
|
except ImportError: |
|
from PIL import Image |
|
BICUBIC = Image.BICUBIC |
|
|
|
|
|
def clip_transforms(mode='train', img_size=224, flip_prob=0.5): |
|
assert mode in ['train', 'test', 'val'] |
|
min_size = img_size |
|
max_size = img_size |
|
|
|
|
|
|
|
normalize_transform = T.Normalize( |
|
mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) |
|
) |
|
|
|
if mode == 'train': |
|
transform = T.Compose( |
|
[ |
|
T.Resize(max_size, BICUBIC), |
|
T.RandomCrop(min_size), |
|
T.RandomHorizontalFlip(flip_prob), |
|
T.ToTensor(), |
|
normalize_transform, |
|
] |
|
) |
|
else: |
|
transform = T.Compose( |
|
[ |
|
T.Resize(max_size, BICUBIC), |
|
T.CenterCrop(min_size), |
|
T.ToTensor(), |
|
normalize_transform, |
|
] |
|
) |
|
return transform |
|
|