import random import os import math import mmcv import torch import numpy as np import torchvision.transforms as T from torchvision import datasets from torch.utils.data import Dataset from megatron.data.autoaugment import ImageNetPolicy from tasks.vision.segmentation.cityscapes import Cityscapes import tasks.vision.segmentation.transforms as ET from megatron.data.autoaugment import ImageNetPolicy from megatron import get_args from PIL import Image, ImageOps class VitSegmentationJointTransform(): def __init__(self, train=True, resolution=None): self.train = train if self.train: self.transform0 = ET.RandomSizeAndCrop(resolution) self.transform1 = ET.RandomHorizontallyFlip() def __call__(self, img, mask): if self.train: img, mask = self.transform0(img, mask) img, mask = self.transform1(img, mask) return img, mask class VitSegmentationImageTransform(): def __init__(self, train=True, resolution=None): args = get_args() self.train = train assert args.fp16 or args.bf16 self.data_type = torch.half if args.fp16 else torch.bfloat16 self.mean_std = args.mean_std if self.train: assert resolution is not None self.transform = T.Compose([ ET.PhotoMetricDistortion(), T.ToTensor(), T.Normalize(*self.mean_std), T.ConvertImageDtype(self.data_type) ]) else: self.transform = T.Compose([ T.ToTensor(), T.Normalize(*self.mean_std), T.ConvertImageDtype(self.data_type) ]) def __call__(self, input): output = self.transform(input) return output class VitSegmentationTargetTransform(): def __init__(self, train=True, resolution=None): self.train = train def __call__(self, input): output = torch.from_numpy(np.array(input, dtype=np.int32)).long() return output class RandomSeedSegmentationDataset(Dataset): def __init__(self, dataset, joint_transform, image_transform, target_transform): args = get_args() self.base_seed = args.seed self.curr_seed = self.base_seed self.dataset = dataset self.joint_transform = joint_transform self.image_transform = image_transform self.target_transform = target_transform def __len__(self): return len(self.dataset) def set_epoch(self, epoch): self.curr_seed = self.base_seed + 100 * epoch def __getitem__(self, idx): seed = idx + self.curr_seed img, mask = self.dataset[idx] torch.manual_seed(seed) random.seed(seed) np.random.seed(seed) img, mask = self.joint_transform(img, mask) img = self.image_transform(img) mask = self.target_transform(mask) return img, mask def build_cityscapes_train_valid_datasets(data_path, image_size): args = get_args() args.num_classes = Cityscapes.num_classes args.ignore_index = Cityscapes.ignore_index args.color_table = Cityscapes.color_table args.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) train_joint_transform = \ VitSegmentationJointTransform(train=True, resolution=image_size) val_joint_transform = \ VitSegmentationJointTransform(train=False, resolution=image_size) train_image_transform = \ VitSegmentationImageTransform(train=True, resolution=image_size) val_image_transform = \ VitSegmentationImageTransform(train=False, resolution=image_size) train_target_transform = \ VitSegmentationTargetTransform(train=True, resolution=image_size) val_target_transform = \ VitSegmentationTargetTransform(train=False, resolution=image_size) # training dataset train_data = Cityscapes( root=data_path[0], split='train', mode='fine', resolution=image_size ) train_data = RandomSeedSegmentationDataset( train_data, joint_transform=train_joint_transform, image_transform=train_image_transform, target_transform=train_target_transform) # validation dataset val_data = Cityscapes( root=data_path[0], split='val', mode='fine', resolution=image_size ) val_data = RandomSeedSegmentationDataset( val_data, joint_transform=val_joint_transform, image_transform=val_image_transform, target_transform=val_target_transform) return train_data, val_data def build_train_valid_datasets(data_path, image_size): return build_cityscapes_train_valid_datasets(data_path, image_size)