|
import torch
from torchvision import transforms

# NOTE: keep the order below; the star import can shadow names imported
# before it, and later explicit imports take precedence over it.
from cwm.data.transforms import *

from cwm.data.dataset import ContextAndTargetVideoDataset

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from cwm.data.masking_generator import RotatedTableMaskingGenerator
|
|
|
class DataAugmentationForVideoMAE(object):
    """Frame-group preprocessing pipeline for VideoMAE-style pretraining.

    The composed pipeline applies, in order: ``GroupScale(input_size)``, a
    spatial crop chosen by ``augmentation_type``, ``Stack(roll=False)``,
    ``ToTorchFormatTensor(div=True)`` and ``GroupNormalize`` with timm's
    ImageNet default mean/std.

    Args:
        augmentation_type: ``'multiscale'`` selects ``GroupMultiScaleCrop``;
            ``'center'`` selects ``GroupCenterCrop``.
        input_size: size passed to ``GroupScale`` and to the crop transform.
        augmentation_scales: iterable of scales passed (converted to a list)
            to ``GroupMultiScaleCrop``; unused for ``'center'``.

    Raises:
        ValueError: if ``augmentation_type`` is neither ``'multiscale'`` nor
            ``'center'``.
    """

    def __init__(self, augmentation_type, input_size, augmentation_scales):
        # Validate before building anything. Without this check, an unknown
        # type would leave ``self.train_augmentation`` unset and fail later
        # with an unhelpful AttributeError.
        if augmentation_type not in ('multiscale', 'center'):
            raise ValueError(
                "augmentation_type must be 'multiscale' or 'center', got %r"
                % (augmentation_type,)
            )

        self.scale = GroupScale(input_size)

        if augmentation_type == 'multiscale':
            self.train_augmentation = GroupMultiScaleCrop(input_size, list(augmentation_scales))
        else:  # 'center'
            self.train_augmentation = GroupCenterCrop(input_size)

        normalize = GroupNormalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

        self.transform = transforms.Compose([
            self.scale,
            self.train_augmentation,
            Stack(roll=False),
            ToTorchFormatTensor(div=True),
            normalize,
        ])

    def __call__(self, images):
        """Run the pipeline on ``images`` and return only the processed frames.

        The pipeline output is unpacked as a 2-tuple and its second element is
        discarded. NOTE(review): the Group* transforms presumably expect
        ``images`` to be a ``(frames, label)`` pair; confirm against
        ``cwm.data.transforms``.
        """
        process_data, _ = self.transform(images)
        return process_data

    def __repr__(self):
        # Local name avoids shadowing the builtin ``repr``.
        text = "(DataAugmentationForVideoMAE,\n"
        text += " transform = %s,\n" % str(self.transform)
        text += ")"
        return text
|
|
|
|
|
def build_pretraining_dataset(args):
    """Build the concatenated pretraining dataset described by ``args``.

    One ``ContextAndTargetVideoDataset`` is created per data path. All of them
    share a single ``DataAugmentationForVideoMAE`` transform and a single
    ``RotatedTableMaskingGenerator``, and they are combined with
    ``torch.utils.data.ConcatDataset``.

    Args:
        args: configuration namespace. Attributes read here: ``data_path``,
            ``data_path_list``, ``augmentation_type``, ``input_size``,
            ``augmentation_scales``, ``mask_input_size``, ``mask_ratio``,
            ``tubelet_size``, ``batch_size``, ``mask_type``,
            ``context_frames``, ``target_frames``, ``temporal_units``,
            ``sampling_rate`` and ``context_target_gap``. ``data_path_list``
            is used when it is not ``None``; otherwise the single
            ``data_path`` is used.

    Returns:
        A ``torch.utils.data.ConcatDataset`` over the per-path datasets.

    Raises:
        ValueError: if ``data_path_list`` is given but empty. Failing early
            here avoids building the transform and mask generator only to
            hit ConcatDataset's bare assertion.
    """
    if args.data_path_list is None:
        data_paths = [args.data_path]
    else:
        data_paths = list(args.data_path_list)
    if not data_paths:
        raise ValueError("data_path_list must contain at least one data path")

    data_transform = DataAugmentationForVideoMAE(args.augmentation_type, args.input_size, args.augmentation_scales)

    # A single masking generator instance is shared by every per-path dataset.
    mask_generator = RotatedTableMaskingGenerator(
        input_size=args.mask_input_size,
        mask_ratio=args.mask_ratio,
        tube_length=args.tubelet_size,
        batch_size=args.batch_size,
        mask_type=args.mask_type,
    )

    dataset_list = [
        ContextAndTargetVideoDataset(
            root=None,
            setting=data_path,
            video_ext='mp4',
            is_color=True,
            modality='rgb',
            context_length=args.context_frames,
            target_length=args.target_frames,
            step_units=args.temporal_units,
            new_step=args.sampling_rate,
            context_target_gap=args.context_target_gap,
            transform=data_transform,
            randomize_interframes=False,
            channels_first=True,
            temporal_jitter=False,
            train=True,
            mask_generator=mask_generator,
        )
        for data_path in data_paths
    ]

    return torch.utils.data.ConcatDataset(dataset_list)
|
|