Spaces:
Build error
Build error
from dataset.randaugment import RandomAugment | |
from torch.utils.data import DataLoader | |
from .vqa import vqa_dataset | |
import torch | |
from torch import nn | |
from torchvision import transforms | |
from PIL import Image | |
def create_dataset(dataset, config, data_dir='/data/mshukor/data'):
    """Build the train and test datasets, with their image transforms, for ``dataset``.

    Args:
        dataset: Name of the dataset to build. Only ``'vqa'`` is supported.
        config: Mapping providing ``'image_res'`` (size passed to the crop and
            resize transforms) and, for VQA, ``'train_file'``, ``'test_file'``,
            ``'vqa_root'``, ``'vg_root'`` and ``'answer_list'``.
        data_dir: Not read by this function; kept so existing callers that
            pass it keep working.

    Returns:
        ``(train_dataset, vqa_test_dataset)``, both built with ``vqa_dataset``.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """
    # Validate before building any transforms. An unsupported name used to
    # fall through and return ``None``, which only failed later when the
    # caller tried to unpack the result.
    if dataset != 'vqa':
        raise ValueError(
            "Unsupported dataset {!r}; only 'vqa' is available.".format(dataset))

    image_res = config['image_res']
    normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                     (0.26862954, 0.26130258, 0.27577711))
    # Operation names handed to RandomAugment; the positional arguments (2, 7)
    # are passed through unchanged -- see dataset.randaugment for their meaning.
    augment_ops = ['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                   'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']

    # Training: random crop covering 50-100% of the image area, random
    # horizontal flip, RandomAugment on the PIL image, then tensor + normalize.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(image_res, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
        transforms.RandomHorizontalFlip(),
        RandomAugment(2, 7, isPIL=True, augs=augment_ops),
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation: deterministic bicubic resize to an image_res x image_res square.
    test_transform = transforms.Compose([
        transforms.Resize((image_res, image_res), interpolation=Image.BICUBIC),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = vqa_dataset(config['train_file'], train_transform,
                                config['vqa_root'], config['vg_root'], split='train')
    vqa_test_dataset = vqa_dataset(config['test_file'], test_transform,
                                   config['vqa_root'], config['vg_root'], split='test',
                                   answer_list=config['answer_list'])
    return train_dataset, vqa_test_dataset
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Wrap each dataset in a ``DataLoader`` configured for training or evaluation.

    Every argument is a per-dataset sequence; entry ``i`` of each one
    configures ``datasets[i]``.

    Args:
        datasets: Datasets to wrap.
        samplers: Sampler per dataset, or ``None`` for the DataLoader default.
        batch_size: Batch size per dataset.
        num_workers: Worker-process count per dataset.
        is_trains: Whether each dataset is used for training. Training loaders
            shuffle (only when no sampler is given) and drop the last
            incomplete batch; evaluation loaders do neither.
        collate_fns: Collate function per dataset, or ``None`` for the default.

    Returns:
        A list of ``DataLoader`` objects, one per dataset, in input order.
        All loaders use ``pin_memory=True``.

    Raises:
        ValueError: If the per-dataset arguments do not all have the same length.
    """
    columns = {
        'datasets': list(datasets),
        'samplers': list(samplers),
        'batch_size': list(batch_size),
        'num_workers': list(num_workers),
        'is_trains': list(is_trains),
        'collate_fns': list(collate_fns),
    }
    # zip() would silently drop trailing entries on a mismatch, losing loaders
    # without any error; reject inconsistent inputs instead.
    lengths = {name: len(values) for name, values in columns.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(
            'create_loader: per-dataset arguments have mismatched lengths: {}'.format(lengths))

    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(*columns.values()):
        if is_train:
            # PyTorch's DataLoader treats `sampler` and `shuffle` as mutually
            # exclusive, so only shuffle when no sampler was supplied.
            shuffle = sampler is None
            drop_last = True
        else:
            shuffle = False
            drop_last = False
        loaders.append(DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=True,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
        ))
    return loaders
def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Create one ``DistributedSampler`` per dataset for the current process.

    Args:
        datasets: Datasets to shard across processes.
        shuffles: Per-dataset flags forwarded as each sampler's ``shuffle``.
        num_tasks: Total number of replicas, forwarded as ``num_replicas``.
        global_rank: Rank of this process, forwarded as ``rank``.

    Returns:
        A list of samplers built by pairing ``datasets`` with ``shuffles``
        via ``zip`` (so it stops at the shorter of the two).
    """
    return [
        torch.utils.data.DistributedSampler(ds, num_replicas=num_tasks, rank=global_rank, shuffle=flag)
        for ds, flag in zip(datasets, shuffles)
    ]