from torch.utils.data import DataLoader, Dataset, Sampler
from pathlib import Path
from collections import defaultdict
import json
import random
from multiprocessing import Pool
import h5py
import pickle
import math
from tqdm import tqdm
import torch
import numpy as np
from copy import deepcopy
from PIL import Image

from torch.utils.data.distributed import DistributedSampler

# from transformers import T5TokenizerFast, BartTokenizer
# from tokenization import VLT5TokenizerFast

# from vis_encoder import _transform
# from vqa_raw_data import augmentation_transform

from dataset.randaugment import RandomAugment

from torch import nn
from torchvision import transforms

import os
import re

# project_dir = Path(__file__).resolve().parent.parent  # VLT5
# workspace_dir = project_dir.parent
# dataset_dir = workspace_dir.joinpath('datasets/').resolve()
# coco_dir = dataset_dir.joinpath('COCO')
# vg_dir = dataset_dir.joinpath('VG')
# coco_img_dir = coco_dir.joinpath('images/')
# coco_feature_dir = coco_dir.joinpath('features')
class COCOCaptionFineTuneDataset(Dataset):
    def __init__(self, split='karpathy_train', raw_dataset=None, rank=-1, topk=-1, verbose=True, args=None, mode='train',
                 data_dir='/data/mshukor/data'):
        super().__init__()

        self.raw_dataset = raw_dataset
        self.topk = topk
        self.verbose = verbose
        self.args = args
        self.args.BUTD100 = False

        self.mode = mode

        dataset_dir = Path(data_dir)
        coco_dir = dataset_dir.joinpath('COCO')
        vg_dir = dataset_dir.joinpath('VG')
        coco_img_dir = coco_dir.joinpath('images/')
        coco_feature_dir = coco_dir.joinpath('features')

        # Loading datasets to data
        self.source = split
        if self.verbose:
            print('Data source: ', self.source)

        normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.image_size, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                                                  'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
            transforms.ToTensor(),
            normalize,
        ])

        self.test_transform = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])
        # if self.args.tokenizer is None:
        #     self.args.tokenizer = self.args.backbone

        # if 't5' in self.args.tokenizer:
        #     if self.args.use_vision:
        #         self.tokenizer = VLT5TokenizerFast.from_pretrained(
        #             args.backbone,
        #             # max_length=self.args.max_text_length,
        #             do_lower_case=self.args.do_lower_case)
        #     else:
        #         self.tokenizer = T5TokenizerFast.from_pretrained(
        #             args.backbone,
        #             # max_length=self.args.max_text_length,
        #             do_lower_case=self.args.do_lower_case)
        # elif 'bart' in self.args.tokenizer:
        #     self.tokenizer = BartTokenizer.from_pretrained(
        #         args.backbone,
        #         # max_length=self.args.max_text_length,
        #         do_lower_case=self.args.do_lower_case)

        #     additional_special_tokens = [f'<extra_id_{i}>' for i in range(100-1, -1, -1)] + \
        #         [f'<vis_extra_id_{i}>' for i in range(100-1, -1, -1)]
        #     special_tokens_dict = {'additional_special_tokens': additional_special_tokens}
        #     num_added_toks = self.tokenizer.add_special_tokens(special_tokens_dict)

        # if self.args.oscar_tags:
        #     # Load VG Classes
        #     vg_classes = []
        #     with open(vg_dir.joinpath('objects_vocab.txt')) as f:
        #         for obj in f.readlines():
        #             vg_classes.append(obj.split(',')[0].lower().strip())
        #     self.vg_classes = vg_classes
        data_info_path = dataset_dir.joinpath('COCO/dataset_coco.json')
        with open(data_info_path) as f:
            karpathy_data = json.load(f)

        split_rename = {
            'train': 'train',
            'restval': 'train',
            'val': 'val',
            'test': 'test'
        }

        n_images = 0

        data = []
        for datum in karpathy_data['images']:
            re_split = split_rename[datum['split']]
            if re_split != self.source.split('_')[-1]:
                continue

            if re_split == 'train':
                # one training example per reference caption
                for d in datum['sentences']:
                    # if self.args.BUTD100:
                    #     img_id = str(int(datum['filename'].split('.')[0].split('_')[-1]))
                    # else:
                    img_id = datum['filename'].split('.')[0]

                    new_datum = {
                        'img_id': img_id,
                        'sent': d['raw'].strip(),
                        'targets': [s['raw'].strip() for s in datum['sentences']],
                        'is_train': True,
                    }
                    data.append(new_datum)
            else:
                # one evaluation example per image, with all reference captions as targets
                # if self.args.BUTD100:
                #     img_id = str(
                #         int(datum['filename'].split('.')[0].split('_')[-1]))
                # else:
                img_id = datum['filename'].split('.')[0]

                new_datum = {
                    'img_id': img_id,
                    # 'sent': d['raw'],
                    'targets': [s['raw'].strip() for s in datum['sentences']],
                    'is_train': False,
                }
                data.append(new_datum)

            n_images += 1

        if self.verbose:
            print(f"{self.source} has {n_images} images")
            print(f"Loaded {len(data)} data from", split)

        # self.n_gpus = torch.cuda.device_count()
        # self.rank = rank

        if isinstance(self.topk, float) and (0 < self.topk <= 1):
            used_samples = int(self.topk * len(data))
            data = random.sample(data, used_samples)
            if self.verbose:
                print(f"Use only {len(data)} data")
        elif self.topk > 0:
            data = data[:int(self.topk)]
            if self.verbose:
                print(f"Use only {len(data)} data")

        self.data = data

        if self.verbose:
            print("# all sentences:", len(self.data))

        # self.n_boxes = args.n_boxes
        self.image_size = self.args.image_size

        if mode == "train" and self.args.use_data_augmentation:
            self.transform = self.train_transform
        else:
            self.transform = self.test_transform

        self.source_to_h5 = {}

        # if self.args.max_n_boxes == 36:
        self.source_to_h5.update({
            'train2014': coco_img_dir.joinpath('train2014'),
            'val2014': coco_img_dir.joinpath('val2014'),
        })
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):

        out_dict = {}
        out_dict['args'] = self.args

        datum = self.data[idx]

        ###### Image ######
        # if self.args.use_vision:
        img_id = datum['img_id']
        out_dict['img_id'] = img_id

        if self.args.BUTD100:
            source = self.source
        else:
            # Karpathy image ids look like 'COCO_train2014_...' / 'COCO_val2014_...'
            if 'train' in img_id:
                source = 'train2014'
            elif 'val' in img_id:
                source = 'val2014'

        path = self.source_to_h5[source].joinpath(f"{img_id}.jpg")

        image = Image.open(path).convert('RGB')

        out_dict["image"] = self.transform(image)

        # out_dict['n_boxes'] = self.args.n_boxes

        ###### Text #####
        # if self.args.no_prefix:
        #     input_text = ''
        #     input_ids = []
        # else:
        #     if self.args.prefix is None:
        #         prefix = f'{self.args.prompt}'
        #     elif self.args.prefix == 'span':
        #         prefix = "span prediction:"
        #     elif self.args.prefix == 'denoise':
        #         prefix = "denoise text: <mask>"
        #     elif self.args.prefix == 'mask':
        #         if 'bart' in self.args.tokenizer:
        #             prefix = "<mask>"

        #     input_tokens = [prefix]

        #     # if self.args.oscar_tags:
        #     #     prefix = f'describe image with tags:'
        #     #     input_tokens = [prefix]
        #     #     obj_ids = f[f'{img_id}/obj_id'][()]
        #     #     for obj_id in obj_ids:
        #     #         obj = self.vg_classes[obj_id]
        #     #         if obj not in input_tokens:
        #     #             input_tokens.append(obj)

        #     input_text = ' '.join(input_tokens)

        #     # if 't5' in self.args.tokenizer:
        #     #     input_ids = self.tokenizer.encode(
        #     #         input_text,
        #     #         max_length=self.args.max_text_length, truncation=True)
        #     # elif 'bart' in self.args.tokenizer:
        #     #     input_ids = self.tokenizer.encode(
        #     #         input_text,
        #     #         max_length=self.args.max_text_length, truncation=True)
        #     # else:
        #     #     input_ids = self.tokenizer.convert_tokens_to_ids(
        #     #         self.tokenizer.tokenize(input_text)[:self.args.max_text_length - 1] + ['[SEP]'])

        # out_dict['input_text'] = input_text
        # out_dict['input_ids'] = torch.LongTensor(input_ids)
        # out_dict['input_length'] = len(input_ids)

        if datum['is_train']:
            sent = datum['sent'].strip()

            # if 't5' in self.args.tokenizer:
            #     target_ids = self.tokenizer.encode(sent, max_length=self.args.gen_max_length, truncation=True)
            # elif 'bart' in self.args.tokenizer:
            #     target_ids = self.tokenizer.encode(sent, max_length=self.args.gen_max_length, truncation=True)

            # assert len(target_ids) <= self.args.gen_max_length, len(target_ids)

            out_dict['sent'] = sent
            # out_dict['target_ids'] = torch.LongTensor(target_ids)
            # out_dict['target_length'] = len(target_ids)

        if 'targets' in datum:
            out_dict['targets'] = datum['targets']

        return out_dict
    def collate_fn(self, batch):
        batch_entry = {}

        B = len(batch)

        # S_W_L = max(entry['input_length'] for entry in batch)
        # input_ids = torch.ones(B, S_W_L, dtype=torch.long) * self.tokenizer.pad_token_id

        # if self.args.no_prefix:
        #     assert input_ids.size() == (B, 0)

        # if self.args.use_vision:
        #     pass

        # 'target_ids' is only produced when a tokenizer is configured (currently disabled above)
        if 'target_ids' in batch[0]:
            T_W_L = max(entry['target_length'] for entry in batch)
            target_ids = torch.ones(B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id

        # sentences = []

        targets = []
        img_ids = []
        img_paths = []
        input_text = []
        images = []
        sents = []

        for i, entry in enumerate(batch):
            # input_ids[i, :entry['input_length']] = entry['input_ids']

            # if self.args.use_vision:
            #     n_boxes = entry['n_boxes']
            images.append(entry['image'])

            img_ids.append(entry['img_id'])
            # img_paths.append(entry['img_path'])

            if 'target_ids' in entry:
                target_ids[i, :entry['target_length']] = entry['target_ids']

            # if 'input_text' in entry:
            #     input_text.append(entry['input_text'])

            # sentences.append(entry['sent'])

            if 'targets' in entry:
                targets.append(entry['targets'])
            if 'sent' in entry:
                sents.append(entry['sent'])

        # batch_entry['input_ids'] = input_ids
        # if 'target_ids' in batch[0]:
        #     word_mask = target_ids != self.tokenizer.pad_token_id
        #     target_ids[~word_mask] = -100
        #     batch_entry['target_ids'] = target_ids

        # if self.args.use_vision:
        batch_entry['images'] = torch.stack(images)
        batch_entry['img_id'] = img_ids
        batch_entry['img_paths'] = img_paths

        if len(sents) > 0:
            batch_entry['sent'] = sents

        # batch_entry['sent'] = sentences

        # batch_entry['input_text'] = input_text

        batch_entry['targets'] = targets

        batch_entry['task'] = 'caption'

        return batch_entry
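
# For reference, collate_fn above yields a plain dict per batch; with batch size B and
# image_size S it looks roughly like the sketch below (keys taken from the code above,
# shapes are the expected ones, not asserted anywhere):
#
#   {
#       'images':    FloatTensor of shape (B, 3, S, S),
#       'img_id':    list of B image ids,
#       'img_paths': []  (kept for interface compatibility, never filled here),
#       'sent':      list of B raw captions (train examples only),
#       'targets':   list of B lists of reference captions,
#       'task':      'caption',
#   }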
def pre_caption(caption, max_words):
    caption = re.sub(
        r"([,.'!?\"()*#:;~])",
        '',
        caption.lower(),
    ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')

    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n')
    caption = caption.strip(' ')

    # truncate caption
    caption_words = caption.split(' ')
    if len(caption_words) > max_words:
        caption = ' '.join(caption_words[:max_words])

    return caption
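
# Illustrative example of pre_caption (the input string is made up, not from any dataset):
#
#   pre_caption("A man, with his dog -- running fast!", 5)
#   # punctuation stripped, '-' and '/' mapped to spaces, whitespace collapsed,
#   # then truncated to 5 words -> "a man with his dog"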
class CCDataset(Dataset):
    def __init__(self, split='CC', raw_dataset=None, rank=-1, topk=-1, verbose=True, args=None, mode='train',
                 data_dir='/data/mshukor/data', config_dir='/data/mshukor/data/cc3m.json', max_words=30):
        super().__init__()

        self.raw_dataset = raw_dataset
        self.topk = topk
        self.verbose = verbose
        self.args = args

        self.mode = mode

        data = []
        ann_files = [config_dir]
        ann_file = []
        for p in ann_files:
            ann_file.append(os.path.join(args.data_json_dir, p))

        for f in ann_file:
            tmp = json.load(open(f, 'r'))
            data += tmp
            print('size of', f, len(tmp))
        print(len(data))

        self.max_words = max_words

        # re-root the annotated image paths under data_dir
        for e in data:
            e['image'] = os.path.join(data_dir, '/'.join(e['image'].split('/')[4:]))

        # Loading datasets to data
        self.source = split
        if self.verbose:
            print('Data source: ', self.source)

        normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.image_size, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                                                  'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
            transforms.ToTensor(),
            normalize,
        ])

        self.test_transform = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

        if isinstance(self.topk, float) and (0 < self.topk <= 1):
            used_samples = int(self.topk * len(data))
            data = random.sample(data, used_samples)
            if self.verbose:
                print(f"Use only {len(data)} data")
        elif self.topk > 0:
            data = data[:int(self.topk)]
            if self.verbose:
                print(f"Use only {len(data)} data")

        self.data = data

        if self.verbose:
            print("# all sentences:", len(self.data))

        self.image_size = self.args.image_size

        if mode == "train" and self.args.use_data_augmentation:
            self.transform = self.train_transform
        else:
            self.transform = self.test_transform
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):

        out_dict = {}
        out_dict['args'] = self.args

        datum = self.data[idx]

        if isinstance(datum['caption'], list):
            caption = pre_caption(random.choice(datum['caption']), self.max_words)
        else:
            caption = pre_caption(datum['caption'], self.max_words)

        ###### Image ######
        image = Image.open(datum['image']).convert('RGB')
        img_id = datum['image'].split('/')[-1].split('.')[0]

        out_dict['img_id'] = img_id
        out_dict["image"] = self.transform(image)

        out_dict['sent'] = caption
        out_dict['targets'] = caption

        return out_dict
    def collate_fn(self, batch):
        batch_entry = {}

        B = len(batch)

        # 'target_ids' is only produced when a tokenizer is configured (not the case here)
        if 'target_ids' in batch[0]:
            T_W_L = max(entry['target_length'] for entry in batch)
            target_ids = torch.ones(B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id

        targets = []
        img_ids = []
        img_paths = []
        input_text = []
        images = []
        sents = []

        for i, entry in enumerate(batch):
            images.append(entry['image'])
            img_ids.append(entry['img_id'])

            if 'target_ids' in entry:
                target_ids[i, :entry['target_length']] = entry['target_ids']

            if 'targets' in entry:
                targets.append(entry['targets'])
            if 'sent' in entry:
                sents.append(entry['sent'])

        # if self.args.use_vision:
        batch_entry['images'] = torch.stack(images)
        batch_entry['img_id'] = img_ids
        batch_entry['img_paths'] = img_paths

        if len(sents) > 0:
            batch_entry['sent'] = sents

        batch_entry['targets'] = targets

        batch_entry['task'] = 'caption'

        return batch_entry
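
# Note on the annotation format assumed by CCDataset (inferred from __init__/__getitem__
# above, not an official schema): each entry of the JSON file (e.g. cc3m.json) is expected
# to look roughly like
#
#   {"image": "/root/a/b/images/cc3m/train/0001.jpg",
#    "caption": "a dog runs across a field"}          # 'caption' may also be a list of strings
#
# __init__ keeps e['image'].split('/')[4:] and re-roots it under data_dir, so the path above
# becomes <data_dir>/images/cc3m/train/0001.jpg.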
def get_loader(args, split='train', mode='train',
               batch_size=32, workers=4, distributed=False, gpu=0,
               topk=-1, data_dir='/data/mshukor/data', local_rank=None, world_size=None, verbose=False, config_dir=None):

    # if 'mscoco' in split:
    #     verbose = (gpu == 0)

    if 'CC' in split:
        dataset = CCDataset(split, data_dir=data_dir, mode=mode, topk=topk, args=args, verbose=verbose, rank=gpu, config_dir=config_dir)
    else:
        dataset = COCOCaptionFineTuneDataset(
            split,
            # raw_dataset=_dset,
            rank=gpu,
            topk=topk,
            verbose=verbose,
            args=args,
            mode=mode, data_dir=data_dir)

    if distributed and mode == 'train':
        train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=local_rank)
        # train_sampler = DistributedSampler(dataset)
        # train_sampler = RandomNonreplacmentSampler(dataset, dataset.n_iter)
    else:
        train_sampler = None

    if mode == 'train':
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=(train_sampler is None),
            num_workers=workers, pin_memory=True, sampler=train_sampler,
            collate_fn=dataset.collate_fn)
    else:
        loader = DataLoader(
            dataset,
            batch_size=batch_size, shuffle=False,
            num_workers=workers, pin_memory=True,
            sampler=None,
            collate_fn=dataset.collate_fn,
            drop_last=False)

    if verbose:
        loader.evaluator = COCOCaptionEvaluator()

    loader.task = 'caption'

    return loader
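
# Minimal usage sketch for get_loader (illustrative only; the SimpleNamespace below stands
# in for the project's real argument object, so the attribute set is an assumption based on
# what this file actually reads: image_size, use_data_augmentation, data_json_dir):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(image_size=224, use_data_augmentation=True,
#                          data_json_dir='/data/mshukor/data')
#   train_loader = get_loader(args, split='karpathy_train', mode='train',
#                             batch_size=32, workers=4, topk=-1,
#                             data_dir='/data/mshukor/data')
#   batch = next(iter(train_loader))
#   print(batch['images'].shape, batch['sent'][0])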
class COCOCaptionEvaluator:
    def __init__(self):
        import language_evaluation
        self.evaluator = language_evaluation.CocoEvaluator(verbose=False)

    def evaluate(self, predicts, answers):
        results = self.evaluator.run_evaluation(predicts, answers)

        return results
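
# Minimal usage sketch for the evaluator (illustrative; requires the `language_evaluation`
# package, and the exact metric keys in `results` depend on that package; the captions below
# are made up):
#
#   evaluator = COCOCaptionEvaluator()
#   predicts = ['a dog runs across a field']
#   answers = [['a dog running in a field', 'a brown dog runs on the grass']]
#   results = evaluator.evaluate(predicts, answers)
#   print(results)   # BLEU / CIDEr style scores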