UnIVAL / tasks /pretrain_tasks /unify_task.py
mshukor
init
26fd00c
raw
history blame
13.7 kB
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
from dataclasses import dataclass, field
import json
import logging
import os
import math
from typing import Optional
from fairseq.tasks import register_task
from fairseq.data import FairseqDataset, iterators
from tasks.ofa_task import OFATask, OFAConfig
from data.pretrain_data.unify_dataset import UnifyDataset
from data.file_dataset import FileDataset
logger = logging.getLogger(__name__)
@dataclass
class UnifyConfig(OFAConfig):
    """Configuration for the ``unify_task`` pretraining task.

    Extends ``OFAConfig`` with options for text masking, the negative-sample
    files, and the optional auxiliary data sources (audio, video, image-text,
    video-text, audio-text, ...) together with the ``*_cnt`` knobs that
    control their ratio in a batch.
    """

    max_image_size: int = field(
        default=512, metadata={"help": ""}
    )
    neg_sample_dir: Optional[str] = field(
        default=None,
        metadata={"help": "negative sample directory, which contains captions (taken from all image-text pairs), "
                          "answers (taken from VQA), "
                          "objects (taken from OpenImages) "},
    )
    code_image_size: int = field(
        default=128, metadata={"help": "the resolution of the generated image in the image infilling task"}
    )

    # --- text masking options ---
    pretrain_seed: int = field(
        default=7,
        metadata={"help": "pretrain seed"},
    )
    mask_ratio: float = field(
        default=0.3,
        metadata={"help": "fraction of words/subwords that will be masked"},
    )
    random_ratio: float = field(
        default=0.0,
        metadata={"help": "instead of using [MASK], use random token this often"},
    )
    keep_ratio: float = field(
        default=0.0,
        metadata={"help": "instead of using [MASK], keep original token this often"},
    )
    mask_length: str = field(
        default="span-poisson",
        metadata={"help": "mask length to choose ['subword', 'word', 'span-poisson']"},
    )
    poisson_lambda: float = field(
        default=3.0,
        metadata={"help": "lambda of the Poisson distribution used to sample span lengths "
                          "when mask_length is 'span-poisson'"},
    )
    replace_length: int = field(
        default=1,
        metadata={"help": "when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)"},
    )

    # --- negative caption files for the video / audio modalities ---
    neg_captions_video: Optional[str] = field(
        default=None,
        metadata={"help": "negative sample file .txt, which contains captions (taken from all video-text pairs), "},
    )
    neg_captions_audio: Optional[str] = field(
        default=None,
        metadata={"help": "negative sample file .txt, which contains captions (taken from all audio-text pairs), "},
    )

    # --- audio-only / video-only data ---
    audio_selected_cols: Optional[str] = field(
        default=None,
        metadata={"help": "audio data selected cols"},
    )
    audio_data: Optional[str] = field(
        default=None,
        metadata={"help": "audio data"},
    )
    audio_cnt: int = field(
        default=2,
        metadata={"help": "to control the ratio of audio examples in the batch"},
    )
    video_selected_cols: Optional[str] = field(
        default=None,
        metadata={"help": "video data selected cols"},
    )
    video_data: Optional[str] = field(
        default=None,
        metadata={"help": "video data"},
    )
    video_cnt: int = field(
        default=2,
        metadata={"help": "to control the ratio of video examples in the batch"},
    )

    # --- image-text data ---
    image_text_data: Optional[str] = field(
        default=None,
        metadata={"help": "cc12m data"},
    )
    image_text_cnt: int = field(
        default=2,
        metadata={"help": "to control the ratio of image-text examples in the batch"},
    )
    # NOTE(review): this help text looks copy-pasted from image_text_cnt;
    # confirm what "other data" refers to before rewording it.
    other_data_cnt: int = field(
        default=8,
        metadata={"help": "to control the ratio of image-text examples in the batch"},
    )

    # --- init data ---
    init_image_text_data: Optional[str] = field(
        default=None,
        metadata={"help": "init data"},
    )
    init_dataset_epoch: int = field(
        default=3,
        metadata={"help": "to switch from the init dataset"},
    )
    init_text_data: Optional[str] = field(
        default=None,
        metadata={"help": "init data"},
    )

    # --- image-text VQA / grounding data ---
    image_text_vqa_data: Optional[str] = field(
        default=None,
        metadata={"help": "vqa data"},
    )
    image_text_vqa_cnt: int = field(
        default=2,
        metadata={"help": "to control the ratio of image-text VQA examples in the batch"},
    )
    image_text_ground_data: Optional[str] = field(
        default=None,
        metadata={"help": "image-text grounding data"},
    )
    image_text_ground_cnt: int = field(
        default=2,
        metadata={"help": "to control the ratio of image-text grounding examples in the batch"},
    )

    # --- video / audio data paired with text ---
    only_video_data: Optional[str] = field(
        default=None,
        metadata={"help": "only video data"},
    )
    only_audio_data: Optional[str] = field(
        default=None,
        metadata={"help": "only audio data"},
    )
    video_text_data: Optional[str] = field(
        default=None,
        metadata={"help": "video-text data"},
    )
    video_text_cnt: int = field(
        default=2,
        metadata={"help": "to control the ratio of video-text examples in the batch"},
    )
    audio_text_data: Optional[str] = field(
        default=None,
        metadata={"help": "audio-text data"},
    )
    audio_text_cnt: int = field(
        default=2,
        metadata={"help": "to control the ratio of audio-text examples in the batch"},
    )
    audio_with_video: bool = field(
        default=False,
        metadata={"help": "audio_with_video"},
    )
@register_task("unify_task", dataclass=UnifyConfig)
class UnifyTask(OFATask):
    """Pretraining task that builds a ``UnifyDataset`` from several data sources.

    On construction the task reads the negative-sample resources stored under
    ``cfg.neg_sample_dir`` and opens a ``FileDataset`` for every optional
    auxiliary data source that is set in the config.
    """

    def __init__(self, cfg: UnifyConfig, src_dict, tgt_dict):
        """Load negative-sample resources and the optional auxiliary datasets.

        Args:
            cfg: task configuration; ``cfg.neg_sample_dir`` must point to a
                directory containing ``type2ans.json``, ``object.txt`` and
                ``all_captions.txt``.
            src_dict: source dictionary, forwarded to ``OFATask``.
            tgt_dict: target dictionary, forwarded to ``OFATask``.
        """
        super().__init__(cfg, src_dict, tgt_dict)

        neg_sample_dir = self.cfg.neg_sample_dir

        # Mapping answer-type -> list of answers, and its inverse.
        with open(os.path.join(neg_sample_dir, 'type2ans.json')) as f:
            self.type2ans_dict = json.load(f)
        self.ans2type_dict = {}
        for ans_type, answer_list in self.type2ans_dict.items():
            # Answers of type 'other' are not entered in the inverse mapping.
            if ans_type == 'other':
                continue
            for answer in answer_list:
                self.ans2type_dict[answer] = ans_type

        self.all_object_list = self._read_nonempty_lines(
            os.path.join(neg_sample_dir, 'object.txt')
        )

        neg_captions_path = os.path.join(neg_sample_dir, 'all_captions.txt')
        self.all_caption_list = self._read_nonempty_lines(neg_captions_path)

        # Video / audio negatives fall back to the generic caption file when
        # no modality-specific file is configured.
        if self.cfg.neg_captions_video is not None:
            neg_captions_video = self.cfg.neg_captions_video
        else:
            neg_captions_video = neg_captions_path

        if self.cfg.neg_captions_audio is not None:
            neg_captions_audio = self.cfg.neg_captions_audio
        else:
            neg_captions_audio = neg_captions_path

        print("Reading negative video captions:", neg_captions_video)
        self.all_caption_video_list = self._read_nonempty_lines(neg_captions_video)

        print("Reading negative audio captions:", neg_captions_audio)
        self.all_caption_audio_list = self._read_nonempty_lines(neg_captions_audio)

        # Optional auxiliary datasets; each stays None unless configured.
        self.audio_dataset = None
        self.video_dataset = None
        self.image_text_dataset = None
        self.video_text_dataset = None
        self.audio_text_dataset = None
        self.init_image_text_dataset = None
        self.init_text_dataset = None
        self.image_text_ground_dataset = None
        self.image_text_vqa_dataset = None

        if self.cfg.audio_data is not None:
            print("Load audio data")
            self.audio_dataset = FileDataset(self.cfg.audio_data, self.cfg.audio_selected_cols)

        if self.cfg.video_data is not None:
            print("Load video data")
            self.video_dataset = FileDataset(self.cfg.video_data, self.cfg.video_selected_cols)

        if self.cfg.image_text_data is not None:
            print("Load Image Text data")
            self.image_text_dataset = FileDataset(self.cfg.image_text_data, self.cfg.selected_cols)

        if self.cfg.init_image_text_data is not None:
            print("Load Init Image Text data")
            self.init_image_text_dataset = FileDataset(self.cfg.init_image_text_data, self.cfg.selected_cols)

        if self.cfg.init_text_data is not None:
            print("Load Init Text data")
            self.init_text_dataset = FileDataset(self.cfg.init_text_data, self.cfg.text_selected_cols)

        if self.cfg.image_text_vqa_data is not None:
            print("Load Image Text VQA data")
            self.image_text_vqa_dataset = FileDataset(self.cfg.image_text_vqa_data, self.cfg.selected_cols)

        if self.cfg.image_text_ground_data is not None:
            print("Load Image Text Ground data")
            self.image_text_ground_dataset = FileDataset(self.cfg.image_text_ground_data, self.cfg.selected_cols)

        if self.cfg.video_text_data is not None:
            print("Load Video Text data")
            self.video_text_dataset = FileDataset(self.cfg.video_text_data, self.cfg.selected_cols)

        if self.cfg.audio_text_data is not None:
            print("Load Audio Text data")
            self.audio_text_dataset = FileDataset(self.cfg.audio_text_data, self.cfg.selected_cols)

    @staticmethod
    def _read_nonempty_lines(path):
        """Return the stripped, non-empty lines of the text file at ``path``."""
        with open(path) as f:
            stripped_rows = (row.strip() for row in f)
            return [row for row in stripped_rows if row != '']

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Build the ``UnifyDataset`` for ``split`` and store it in ``self.datasets``.

        ``cfg.data`` is a comma-separated list of paths; the path used is
        chosen round-robin from ``epoch`` (1-based). ``combine`` and
        ``kwargs`` are accepted but not used here.
        """
        paths = self.cfg.data.split(',')
        assert len(paths) > 0

        file_path = paths[(epoch - 1) % (len(paths))]
        dataset = FileDataset(file_path, self.cfg.selected_cols)

        self.datasets[split] = UnifyDataset(
            split,
            dataset,
            self.bpe,
            self.src_dict,
            self.tgt_dict,
            max_src_length=self.cfg.max_src_length,
            max_tgt_length=self.cfg.max_tgt_length,
            seed=self.cfg.pretrain_seed,
            code_dict_size=self.cfg.code_dict_size,
            num_bins=self.cfg.num_bins,
            patch_image_size=self.cfg.patch_image_size,
            code_image_size=self.cfg.code_image_size,
            all_object_list=self.all_object_list,
            all_caption_list=self.all_caption_list,
            type2ans_dict=self.type2ans_dict,
            ans2type_dict=self.ans2type_dict,
            max_image_size=self.cfg.max_image_size,
            mask_ratio=self.cfg.mask_ratio,
            random_ratio=self.cfg.random_ratio,
            keep_ratio=self.cfg.keep_ratio,
            mask_length=self.cfg.mask_length,
            poisson_lambda=self.cfg.poisson_lambda,
            replace_length=self.cfg.replace_length,
            read_from_img_path=self.cfg.read_from_img_path,
            image_dir=self.cfg.image_dir,
            no_image_transform=self.cfg.no_image_transform,
            patch_frame_size=self.cfg.patch_frame_size,
            num_frames=self.cfg.num_frames,
            all_caption_video_list=self.all_caption_video_list,
            all_caption_audio_list=self.all_caption_audio_list,
            max_audio_len=self.cfg.max_audio_len,
            audio_dataset=self.audio_dataset,
            video_dataset=self.video_dataset,
            video_cnt=self.cfg.video_cnt,
            audio_cnt=self.cfg.audio_cnt,
            image_text_dataset=self.image_text_dataset,
            image_text_cnt=self.cfg.image_text_cnt,
            other_data_cnt=self.cfg.other_data_cnt,
            init_image_text_dataset=self.init_image_text_dataset,
            init_dataset_epoch=self.cfg.init_dataset_epoch,
            init_text_dataset=self.init_text_dataset,
            image_text_vqa_dataset=self.image_text_vqa_dataset,
            image_text_vqa_cnt=self.cfg.image_text_vqa_cnt,
            image_text_ground_dataset=self.image_text_ground_dataset,
            image_text_ground_cnt=self.cfg.image_text_ground_cnt,
            only_video_data=self.cfg.only_video_data,
            only_audio_data=self.cfg.only_audio_data,
            video_text_dataset=self.video_text_dataset,
            video_text_cnt=self.cfg.video_text_cnt,
            audio_text_dataset=self.audio_text_dataset,
            audio_text_cnt=self.cfg.audio_text_cnt,
            audio_with_video=self.cfg.audio_with_video,
            sample_rate=self.cfg.sample_rate,
        )

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        """Return an ``EpochBatchIterator`` over fixed-size sequential batches.

        Batches are consecutive index ranges of length ``max_sentences``
        (the last one may be shorter). ``max_sentences`` must be an integer:
        despite its ``None`` default it is used directly as the batch size.

        ``max_tokens``, ``max_positions``, ``ignore_invalid_inputs``,
        ``required_batch_size_multiple``, ``shard_id`` and
        ``disable_iterator_cache`` are accepted for interface compatibility
        but are not used; ``num_shards`` is only used to compute the expected
        number of batches.
        """
        assert isinstance(dataset, FairseqDataset)

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        # create mini-batches with given size constraints
        num_rows = len(dataset)
        batch_sampler = [
            list(range(start, min(start + max_sentences, num_rows)))
            for start in range(0, num_rows, max_sentences)
        ]

        # Number of batches derived from the total row count split over
        # ``num_shards``; if this dataset yields fewer, one extra single-index
        # batch is appended.
        # NOTE(review): presumably this keeps all shards at the same number of
        # steps per epoch -- confirm against FileDataset's sharding.
        total_row_count = dataset.dataset.get_total_row_count()
        num_batches = math.ceil(math.ceil(total_row_count / num_shards) / max_sentences)
        if len(batch_sampler) < num_batches:
            batch_sampler.append([1])

        # return a reusable iterator; sharding arguments are fixed to a single
        # shard here instead of forwarding ``num_shards`` / ``shard_id``.
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=1,
            shard_id=0,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size
        )

        return epoch_iter