Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from typing import Sequence | |
import torch | |
from torchvision import transforms | |
class GaussianBlur(transforms.RandomApply):
    """
    Randomly apply Gaussian Blur to the input image.

    With probability ``p`` the image is blurred by ``transforms.GaussianBlur``,
    whose sigma is drawn uniformly from ``[radius_min, radius_max]`` on every
    call; otherwise the image is returned unchanged.
    """

    def __init__(
        self,
        *,
        p: float = 0.5,
        radius_min: float = 0.1,
        radius_max: float = 2.0,
        kernel_size: int = 9,
    ):
        """
        Args:
            p: Probability of applying the blur (``1.0`` = always blur,
                ``0.0`` = never blur).
            radius_min: Lower bound of the sigma range.
            radius_max: Upper bound of the sigma range.
            kernel_size: Size of the Gaussian kernel; defaults to 9, the value
                previously hard-coded. torchvision requires it to be odd and
                positive.
        """
        # ``transforms.RandomApply`` already interprets ``p`` as the probability
        # of *applying* its transforms (it returns the input untouched when
        # ``p < torch.rand(1)``), so ``p`` is forwarded unchanged. Passing
        # ``1 - p`` here would invert it: ``p=1.0`` would never blur.
        transform = transforms.GaussianBlur(kernel_size=kernel_size, sigma=(radius_min, radius_max))
        super().__init__(transforms=[transform], p=p)
class MaybeToTensor(transforms.ToTensor):
    """
    Variant of ``transforms.ToTensor`` that passes tensors through untouched.

    ``PIL Image`` and ``numpy.ndarray`` inputs are converted by the parent
    class; a ``torch.Tensor`` input is returned as the very same object.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image, numpy.ndarray or torch.tensor): Image to convert.

        Returns:
            Tensor: ``pic`` itself when it is already a tensor, otherwise the
            output of ``transforms.ToTensor``.
        """
        needs_conversion = not isinstance(pic, torch.Tensor)
        return super().__call__(pic) if needs_conversion else pic
# Constant names follow timm's convention: per-channel ImageNet normalization statistics.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
def make_normalize_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    """
    Build a channel-wise ``transforms.Normalize`` step.

    Args:
        mean: Per-channel mean to subtract (ImageNet defaults).
        std: Per-channel standard deviation to divide by (ImageNet defaults).

    Returns:
        transforms.Normalize: The configured normalization transform.
    """
    normalize = transforms.Normalize(mean=mean, std=std)
    return normalize
# This roughly matches torchvision's preset for classification training:
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
def make_classification_train_transform(
    *,
    crop_size: int = 224,
    interpolation=transforms.InterpolationMode.BICUBIC,
    hflip_prob: float = 0.5,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """
    Build the training-time pipeline for image classification.

    Steps, in order: random resized crop, optional random horizontal flip,
    tensor conversion (tensors pass through unchanged), channel-wise
    normalization.

    Args:
        crop_size: Output size passed to ``RandomResizedCrop``.
        interpolation: Interpolation mode used when resizing the crop.
        hflip_prob: Probability of a horizontal flip; the flip step is left
            out entirely when this is ``<= 0``.
        mean: Per-channel mean used for normalization.
        std: Per-channel standard deviation used for normalization.

    Returns:
        transforms.Compose: The composed training transform.
    """
    transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
    # Omit the flip when disabled instead of inserting a no-op step.
    if hflip_prob > 0.0:
        transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
    transforms_list.extend(
        [
            MaybeToTensor(),
            make_normalize_transform(mean=mean, std=std),
        ]
    )
    return transforms.Compose(transforms_list)
# Roughly mirrors torchvision's preset for classification evaluation:
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
def make_classification_eval_transform(
    *,
    resize_size: int = 256,
    interpolation=transforms.InterpolationMode.BICUBIC,
    crop_size: int = 224,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """
    Build the deterministic evaluation pipeline for image classification.

    Steps, in order: resize, center crop, tensor conversion (tensors pass
    through unchanged), channel-wise normalization.

    Args:
        resize_size: Size passed to ``transforms.Resize``.
        interpolation: Interpolation mode used by the resize.
        crop_size: Size of the center crop.
        mean: Per-channel mean used for normalization.
        std: Per-channel standard deviation used for normalization.

    Returns:
        transforms.Compose: The composed evaluation transform.
    """
    resize = transforms.Resize(resize_size, interpolation=interpolation)
    center_crop = transforms.CenterCrop(crop_size)
    to_tensor = MaybeToTensor()
    normalize = make_normalize_transform(mean=mean, std=std)
    return transforms.Compose([resize, center_crop, to_tensor, normalize])