from torchvision import transforms
import numpy as np
import torch


def one_d_image_train_aug(to_3_channels=False):
    """Training transform for single-channel (MNIST-like) images."""
    # MNIST statistics; the tuples must match the number of output channels,
    # otherwise Normalize fails on single-channel tensors.
    channels = 3 if to_3_channels else 1
    mean, std = (0.1307,) * channels, (0.3081,) * channels
    return transforms.Compose([
        transforms.Resize(32),
        # transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Lambda((lambda x: torch.cat([x] * 3)) if to_3_channels else (lambda x: x)),
        transforms.Normalize(mean, std)
    ])


def one_d_image_test_aug(to_3_channels=False):
    """Test transform for single-channel (MNIST-like) images."""
    channels = 3 if to_3_channels else 1
    mean, std = (0.1307,) * channels, (0.3081,) * channels
    return transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Lambda((lambda x: torch.cat([x] * 3)) if to_3_channels else (lambda x: x)),
        transforms.Normalize(mean, std)
    ])


def cifar_like_image_train_aug(mean=None, std=None):
    """Training transform for 32x32 RGB (CIFAR-like) images."""
    if mean is None or std is None:
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    return transforms.Compose([
        # NOTE: upscale to 40 first; otherwise RandomCrop(32) may keep only a
        # small region of a larger source image.
        transforms.Resize(40),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])


def cifar_like_image_test_aug(mean=None, std=None):
    """Test transform for 32x32 RGB (CIFAR-like) images."""
    if mean is None or std is None:
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    return transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])


def imagenet_like_image_train_aug():
    """Training transform for 224x224 RGB (ImageNet-like) images."""
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    return transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224), padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])


def imagenet_like_image_test_aug():
    """Test transform for 224x224 RGB (ImageNet-like) images."""
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])


def cityscapes_like_image_train_aug():
    """Training transform for Cityscapes-like images.

    No random augmentation is applied, so the image stays aligned with the
    label map produced by cityscapes_like_label_aug().
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def cityscapes_like_image_test_aug():
    """Test transform for Cityscapes-like images."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def cityscapes_like_label_aug():
    """Transform for Cityscapes-like segmentation label maps."""
    return transforms.Compose([
        # Nearest-neighbor resizing keeps class IDs intact; bilinear would blend them.
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.NEAREST),
        transforms.Lambda(lambda x: torch.from_numpy(np.array(x)).long())
    ])


def pil_image_to_tensor(img_size=224):
    """Resize a PIL image and convert it to a float tensor in [0, 1]."""
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor()
    ])
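

# -----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one way to wire
# the train/test transform pairs above into torchvision datasets and standard
# DataLoaders. The dataset root, batch sizes, and worker counts below are
# assumptions chosen only for this example.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision import datasets

    # MNIST is single-channel; to_3_channels=True replicates the channel so the
    # images can feed a standard RGB backbone.
    train_set = datasets.MNIST(root="./data", train=True, download=True,
                               transform=one_d_image_train_aug(to_3_channels=True))
    test_set = datasets.MNIST(root="./data", train=False, download=True,
                              transform=one_d_image_test_aug(to_3_channels=True))

    train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=2)

    images, labels = next(iter(train_loader))
    print(images.shape)  # torch.Size([128, 3, 32, 32])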