# NOTE(review): removed web-scrape residue ("Shawn001's picture / Upload 53 files /
# c2c125c") that preceded this file's imports and made the module invalid Python.
import random
import os
import math
import mmcv
import torch
import numpy as np
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset
from megatron.data.autoaugment import ImageNetPolicy
from tasks.vision.segmentation.cityscapes import Cityscapes
import tasks.vision.segmentation.transforms as ET
from megatron.data.autoaugment import ImageNetPolicy
from megatron import get_args
from PIL import Image, ImageOps
class VitSegmentationJointTransform():
    """Joint spatial augmentation applied to an (image, mask) pair.

    In training mode this random-crops and randomly flips both inputs with the
    same parameters so pixel/label alignment is preserved; in eval mode the
    pair passes through untouched.
    """

    def __init__(self, train=True, resolution=None):
        self.train = train
        if train:
            # Spatial ops must see image and mask together to stay in sync.
            self.transform0 = ET.RandomSizeAndCrop(resolution)
            self.transform1 = ET.RandomHorizontallyFlip()

    def __call__(self, img, mask):
        if not self.train:
            return img, mask
        img, mask = self.transform0(img, mask)
        return self.transform1(img, mask)
class VitSegmentationImageTransform():
    """Image-only preprocessing pipeline.

    Pipeline: (train only) photometric distortion -> ToTensor -> Normalize
    with args.mean_std -> cast to the configured half precision dtype
    (fp16 or bf16; one of the two must be enabled).
    """

    def __init__(self, train=True, resolution=None):
        args = get_args()
        self.train = train
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16
        self.mean_std = args.mean_std
        ops = [
            T.ToTensor(),
            T.Normalize(*self.mean_std),
            T.ConvertImageDtype(self.data_type),
        ]
        if self.train:
            assert resolution is not None
            # Photometric jitter works on PIL images, so it must run
            # before ToTensor.
            ops.insert(0, ET.PhotoMetricDistortion())
        self.transform = T.Compose(ops)

    def __call__(self, input):
        return self.transform(input)
class VitSegmentationTargetTransform():
    """Convert a segmentation mask (PIL image / array-like) to a LongTensor."""

    def __init__(self, train=True, resolution=None):
        # train/resolution are accepted only for interface parity with the
        # other transform classes; no augmentation is applied to labels here.
        self.train = train

    def __call__(self, input):
        # int32 intermediate, then .long(): cross-entropy losses expect
        # int64 class indices.
        mask = np.array(input, dtype=np.int32)
        return torch.from_numpy(mask).long()
class RandomSeedSegmentationDataset(Dataset):
    """Dataset wrapper that reseeds every RNG per sample.

    Seeding torch/random/numpy with ``curr_seed + idx`` before the transforms
    makes the random augmentations deterministic for a given (epoch, index),
    which keeps data-parallel replicas reproducible. ``set_epoch`` shifts the
    seed so augmentations still vary from epoch to epoch.
    """

    def __init__(self,
                 dataset,
                 joint_transform,
                 image_transform,
                 target_transform):
        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = self.base_seed
        self.dataset = dataset
        self.joint_transform = joint_transform
        self.image_transform = image_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        # 100 * epoch keeps the per-epoch seed ranges disjoint
        # (presumably fine while len(dataset) offsets stay well-behaved).
        self.curr_seed = self.base_seed + 100 * epoch

    def __getitem__(self, idx):
        sample_seed = self.curr_seed + idx
        img, mask = self.dataset[idx]
        # Deterministic per-sample seeding of all RNG sources used by
        # the augmentation pipeline.
        torch.manual_seed(sample_seed)
        random.seed(sample_seed)
        np.random.seed(sample_seed)
        img, mask = self.joint_transform(img, mask)
        return self.image_transform(img), self.target_transform(mask)
def build_cityscapes_train_valid_datasets(data_path, image_size):
    """Build the (train, valid) Cityscapes dataset pair.

    Publishes dataset-specific settings (num_classes, ignore_index,
    color_table, mean_std) onto the global args so downstream model/loss
    code can read them, then wraps each split in a
    RandomSeedSegmentationDataset with the appropriate transforms.
    """
    args = get_args()
    args.num_classes = Cityscapes.num_classes
    args.ignore_index = Cityscapes.ignore_index
    args.color_table = Cityscapes.color_table
    # ImageNet mean/std -- the standard normalization for
    # ImageNet-pretrained backbones.
    args.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def _build_split(split, train):
        # One Cityscapes 'fine'-annotation split plus its transform stack.
        raw = Cityscapes(
            root=data_path[0],
            split=split,
            mode='fine',
            resolution=image_size
        )
        return RandomSeedSegmentationDataset(
            raw,
            joint_transform=VitSegmentationJointTransform(
                train=train, resolution=image_size),
            image_transform=VitSegmentationImageTransform(
                train=train, resolution=image_size),
            target_transform=VitSegmentationTargetTransform(
                train=train, resolution=image_size))

    train_data = _build_split('train', True)
    val_data = _build_split('val', False)
    return train_data, val_data
def build_train_valid_datasets(data_path, image_size):
    """Entry point for dataset construction; Cityscapes is the only
    dataset supported by this module, so dispatch straight to it."""
    return build_cityscapes_train_valid_datasets(data_path, image_size)