|
import random |
|
import os |
|
import math |
|
import mmcv |
|
import torch |
|
import numpy as np |
|
import torchvision.transforms as T |
|
from torchvision import datasets |
|
from torch.utils.data import Dataset |
|
from megatron.data.autoaugment import ImageNetPolicy |
|
from tasks.vision.segmentation.cityscapes import Cityscapes |
|
import tasks.vision.segmentation.transforms as ET |
|
from megatron.data.autoaugment import ImageNetPolicy |
|
from megatron import get_args |
|
from PIL import Image, ImageOps |
|
|
|
|
|
class VitSegmentationJointTransform():
    """Paired image/mask spatial augmentation for ViT segmentation.

    In training mode, applies a random resize-and-crop followed by a
    random horizontal flip to the image and its mask together (so the
    two stay aligned). In eval mode it is a no-op pass-through.
    """

    def __init__(self, train=True, resolution=None):
        self.train = train
        if train:
            # Spatial transforms are only needed for training-time
            # augmentation; eval uses the inputs as-is.
            self.transform0 = ET.RandomSizeAndCrop(resolution)
            self.transform1 = ET.RandomHorizontallyFlip()

    def __call__(self, img, mask):
        if not self.train:
            return img, mask
        img, mask = self.transform0(img, mask)
        return self.transform1(img, mask)
|
|
|
|
|
class VitSegmentationImageTransform():
    """Per-image (non-spatial) transform pipeline for ViT segmentation.

    Training prepends a photometric distortion to the standard
    to-tensor / normalize / dtype-cast chain; eval skips the distortion.
    The output tensor dtype matches the configured mixed precision
    (fp16 or bf16).
    """

    def __init__(self, train=True, resolution=None):
        args = get_args()
        self.train = train
        # Mixed precision is required here; choose the matching dtype.
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16
        self.mean_std = args.mean_std

        ops = []
        if train:
            assert resolution is not None
            ops.append(ET.PhotoMetricDistortion())
        ops += [
            T.ToTensor(),
            T.Normalize(*self.mean_std),
            T.ConvertImageDtype(self.data_type),
        ]
        self.transform = T.Compose(ops)

    def __call__(self, input):
        return self.transform(input)
|
|
|
|
|
class VitSegmentationTargetTransform():
    """Convert a segmentation mask (e.g. a PIL image) to a long tensor."""

    def __init__(self, train=True, resolution=None):
        # train/resolution are accepted for interface symmetry with the
        # image transform; they do not affect the conversion.
        self.train = train

    def __call__(self, input):
        as_int32 = np.array(input, dtype=np.int32)
        return torch.from_numpy(as_int32).long()
|
|
|
|
|
class RandomSeedSegmentationDataset(Dataset):
    """Dataset wrapper that re-seeds all RNGs before each sample's transforms.

    Seeding torch / random / numpy with ``base_seed + 100 * epoch + idx``
    prior to the (randomized) transforms makes augmentation deterministic
    for a given (epoch, index) pair, so repeated reads of the same sample
    in the same epoch produce identical augmentations.
    """

    def __init__(self,
                 dataset,
                 joint_transform,
                 image_transform,
                 target_transform):
        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = self.base_seed
        self.dataset = dataset
        self.joint_transform = joint_transform
        self.image_transform = image_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        # Shift the seed range each epoch so augmentations vary across
        # epochs while staying reproducible.
        self.curr_seed = self.base_seed + 100 * epoch

    def __getitem__(self, idx):
        sample_seed = self.curr_seed + idx
        img, mask = self.dataset[idx]

        # Make every RNG the transforms might touch deterministic for
        # this (epoch, idx) pair.
        torch.manual_seed(sample_seed)
        random.seed(sample_seed)
        np.random.seed(sample_seed)

        img, mask = self.joint_transform(img, mask)
        return self.image_transform(img), self.target_transform(mask)
|
|
|
|
|
def build_cityscapes_train_valid_datasets(data_path, image_size):
    """Build the Cityscapes train and validation segmentation datasets.

    Also publishes dataset-wide constants (class count, ignore index,
    color table, normalization mean/std) on the global args namespace
    for downstream consumers.

    Args:
        data_path: sequence whose first entry is the Cityscapes root dir.
        image_size: target resolution passed to the dataset/transforms.

    Returns:
        (train_data, val_data) tuple of RandomSeedSegmentationDataset.
    """
    args = get_args()
    args.num_classes = Cityscapes.num_classes
    args.ignore_index = Cityscapes.ignore_index
    args.color_table = Cityscapes.color_table
    # Normalization stats consumed by VitSegmentationImageTransform.
    args.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def _make_split(split, train):
        # One 'fine'-annotation Cityscapes split, wrapped with the
        # per-sample-seeded transform pipeline.
        raw = Cityscapes(
            root=data_path[0],
            split=split,
            mode='fine',
            resolution=image_size
        )
        return RandomSeedSegmentationDataset(
            raw,
            joint_transform=VitSegmentationJointTransform(
                train=train, resolution=image_size),
            image_transform=VitSegmentationImageTransform(
                train=train, resolution=image_size),
            target_transform=VitSegmentationTargetTransform(
                train=train, resolution=image_size))

    train_data = _make_split('train', True)
    val_data = _make_split('val', False)

    return train_data, val_data
|
|
|
|
|
def build_train_valid_datasets(data_path, image_size):
    """Public entry point: delegate to the Cityscapes dataset builder.

    Args:
        data_path: sequence whose first entry is the dataset root dir.
        image_size: target resolution for the dataset/transforms.

    Returns:
        (train_data, val_data) tuple of datasets.
    """
    return build_cityscapes_train_valid_datasets(data_path, image_size)
|
|