Nutrigenics-chatbot / src /data /transforms.py
OmkarThawakar
initial commit
ed00004
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from src.data.randaugment import RandomAugment
# Shared per-channel normalization applied as the last step of both the
# training and the evaluation pipelines below.
# NOTE(review): these three-channel statistics look like the published
# OpenAI CLIP preprocessing constants -- confirm they match the checkpoint
# this project actually loads.
normalize = transforms.Normalize(
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
)
class transform_train:
    """Callable training-time image augmentation pipeline.

    The pipeline, in order: random resized crop (bicubic), random
    horizontal flip, ``RandomAugment``, conversion to a tensor, and the
    module-level ``normalize`` transform.

    Args:
        image_size: Output size passed to ``RandomResizedCrop``.
        min_scale: Lower bound of the crop ``scale`` range; the upper
            bound is fixed at ``1.0``.
    """

    def __init__(self, image_size=384, min_scale=0.5):
        random_crop = transforms.RandomResizedCrop(
            image_size,
            scale=(min_scale, 1.0),
            interpolation=InterpolationMode.BICUBIC,
        )
        random_flip = transforms.RandomHorizontalFlip()

        # Operations RandomAugment is allowed to draw from.
        augmentation_names = [
            "Identity",
            "AutoContrast",
            "Brightness",
            "Sharpness",
            "Equalize",
            "ShearX",
            "ShearY",
            "TranslateX",
            "TranslateY",
            "Rotate",
        ]
        # NOTE(review): the positional 2 and 5 are presumably the number of
        # ops per image and their magnitude -- confirm against
        # src.data.randaugment.RandomAugment.
        random_augment = RandomAugment(
            2,
            5,
            isPIL=True,
            augs=augmentation_names,
        )

        pipeline = [
            random_crop,
            random_flip,
            random_augment,
            transforms.ToTensor(),
            normalize,
        ]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, img):
        """Apply the training pipeline to ``img`` and return the result."""
        augmented = self.transform(img)
        return augmented
class transform_test(transforms.Compose):
    """Callable evaluation-time image preprocessing pipeline.

    The pipeline, in order: resize to ``(image_size, image_size)`` with
    bicubic interpolation, conversion to a tensor, and the module-level
    ``normalize`` transform. No random augmentation is applied.

    Args:
        image_size: Side length of the square output passed to ``Resize``.
    """

    def __init__(self, image_size=384):
        steps = [
            transforms.Resize(
                (image_size, image_size),
                interpolation=InterpolationMode.BICUBIC,
            ),
            transforms.ToTensor(),
            normalize,
        ]
        # Initialise the Compose base class. Without this the inherited
        # ``transforms`` attribute is never set, and anything relying on it
        # (e.g. ``repr()``/``print()`` of the instance) raises AttributeError.
        super().__init__(steps)
        # Kept for backward compatibility with callers that access
        # ``.transform`` directly; it wraps the same step list as the base.
        self.transform = transforms.Compose(steps)

    def __call__(self, img):
        """Apply the evaluation pipeline to ``img`` and return the result."""
        return self.transform(img)