|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field |
|
import json |
|
import logging |
|
import os |
|
import math |
|
from typing import Optional |
|
from fairseq.tasks import register_task |
|
from fairseq.data import FairseqDataset, iterators |
|
|
|
from tasks.ofa_task import OFATask, OFAConfig |
|
from data.pretrain_data.unify_dataset import UnifyDataset |
|
from data.file_dataset import FileDataset |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@dataclass |
|
class UnifyConfig(OFAConfig): |
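    """Options for the unified multimodal (image / video / audio + text) pretraining task."""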
|
max_image_size: int = field( |
|
        default=512, metadata={"help": "max image size for normalization"}
|
) |
|
neg_sample_dir: Optional[str] = field( |
|
default=None, |
|
        metadata={"help": "negative sample directory, which contains captions (taken from all image-text pairs), "
                          "answers (taken from VQA), "
                          "objects (taken from OpenImages)"},
|
) |
|
code_image_size: int = field( |
|
default=128, metadata={"help": "the resolution of the generated image in the image infilling task"} |
|
) |
|
|
|
pretrain_seed: int = field( |
|
default=7, |
|
metadata={"help": "pretrain seed"}, |
|
) |
|
|
|
mask_ratio: float = field( |
|
default=0.3, |
|
metadata={"help": "fraction of words/subwords that will be masked"}, |
|
) |
|
random_ratio: float = field( |
|
default=0.0, |
|
metadata={"help": "instead of using [MASK], use random token this often"}, |
|
) |
|
keep_ratio: float = field( |
|
default=0.0, |
|
metadata={"help": "instead of using [MASK], keep original token this often"}, |
|
) |
|
mask_length: str = field( |
|
default="span-poisson", |
|
metadata={"help": "mask length to choose ['subword', 'word', 'span-poisson']"}, |
|
) |
|
poisson_lambda: float = field( |
|
default=3.0, |
|
        metadata={"help": "lambda for the Poisson distribution used to sample mask span lengths"},
|
) |
|
replace_length: int = field( |
|
default=1, |
|
metadata={"help": "when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)"}, |
|
) |
|
|
|
|
|
neg_captions_video: Optional[str] = field( |
|
default=None, |
|
        metadata={"help": "negative sample file (.txt) containing captions taken from all video-text pairs"},
|
) |
|
|
|
neg_captions_audio: Optional[str] = field( |
|
default=None, |
|
        metadata={"help": "negative sample file (.txt) containing captions taken from all audio-text pairs"},
|
) |
|
|
|
audio_selected_cols: Optional[str] = field( |
|
default=None, |
|
        metadata={"help": "audio data selected cols"},
|
) |
|
|
|
audio_data: Optional[str] = field( |
|
default=None, |
|
        metadata={"help": "audio data"},
|
) |
|
audio_cnt: int = field( |
|
default=2, |
|
        metadata={"help": "to control the ratio of audio examples in the batch"},
|
) |
|
|
|
video_selected_cols: Optional[str] = field( |
|
default=None, |
|
metadata={"help": "video data selected cols"}, |
|
) |
|
|
|
video_data: Optional[str] = field( |
|
default=None, |
|
metadata={"help": "video data"}, |
|
) |
|
|
|
video_cnt: int = field( |
|
default=2, |
|
metadata={"help": "to control the ratio of video examples in the batch"}, |
|
) |
|
|
|
image_text_data: Optional[str] = field( |
|
default=None, |
|
        metadata={"help": "image-text data (e.g. CC12M)"},
|
) |
|
|
|
image_text_cnt: int = field( |
|
default=2, |
|
metadata={"help": "to control the ratio of image-text examples in the batch"}, |
|
) |
|
|
|
other_data_cnt: int = field( |
|
default=8, |
|
        metadata={"help": "to control the ratio of other data examples in the batch"},
|
) |
|
|
|
|
|
init_image_text_data: Optional[str] = field( |
|
default=None, |
|
        metadata={"help": "init image-text data"},
|
) |
|
|
|
init_dataset_epoch: int = field( |
|
default=3, |
|
        metadata={"help": "epoch at which to switch from the init dataset"},
|
) |
|
|
|
init_text_data: Optional[str] = field( |
|
default=None, |
|
        metadata={"help": "init text data"},
|
) |
|
|
|
|
|
image_text_vqa_data: Optional[str] = field( |
|
default=None, |
|
metadata={"help": "vqa data"}, |
|
) |
|
|
|
image_text_vqa_cnt: int = field( |
|
default=2, |
|
        metadata={"help": "to control the ratio of VQA examples in the batch"},
|
) |
|
|
|
|
|
image_text_ground_data: Optional[str] = field( |
|
default=None, |
|
        metadata={"help": "image-text grounding data"},
|
) |
|
|
|
image_text_ground_cnt: int = field( |
|
default=2, |
|
        metadata={"help": "to control the ratio of grounding examples in the batch"},
|
) |
|
|
|
only_video_data: Optional[str] = field( |
|
default=None, |
|
metadata={"help": "only video data"}, |
|
) |
|
only_audio_data: Optional[str] = field( |
|
default=None, |
|
        metadata={"help": "only audio data"},
|
) |
|
|
|
video_text_data: Optional[str] = field( |
|
default=None, |
|
        metadata={"help": "video-text data"},
|
) |
|
|
|
video_text_cnt: int = field( |
|
default=2, |
|
        metadata={"help": "to control the ratio of video-text examples in the batch"},
|
) |
|
|
|
audio_text_data: Optional[str] = field( |
|
default=None, |
|
        metadata={"help": "audio-text data"},
|
) |
|
|
|
audio_text_cnt: int = field( |
|
default=2, |
|
        metadata={"help": "to control the ratio of audio-text examples in the batch"},
|
) |
|
|
|
audio_with_video: bool = field( |
|
default=False, |
|
        metadata={"help": "whether to load audio together with video"},
|
) |
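
    # When this dataclass is exposed through fairseq's config system, each field
    # above becomes a --kebab-case command-line flag (e.g. --neg-sample-dir,
    # --video-cnt); the default= values apply when a flag is not given.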
|
|
|
|
|
@register_task("unify_task", dataclass=UnifyConfig) |
|
class UnifyTask(OFATask): |
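    """Unified pretraining task that mixes image-text, video, audio and text-only data."""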
|
def __init__(self, cfg: UnifyConfig, src_dict, tgt_dict): |
|
super().__init__(cfg, src_dict, tgt_dict) |
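
        # Negative-sampling resources: an answer-type map for VQA answers, object
        # names (from OpenImages), and caption pools for image, video and audio text.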
|
|
|
        with open(os.path.join(self.cfg.neg_sample_dir, 'type2ans.json')) as f:
            self.type2ans_dict = json.load(f)
        self.ans2type_dict = {}
        for ans_type, answer_list in self.type2ans_dict.items():
            if ans_type == 'other':
                continue
            for answer in answer_list:
                self.ans2type_dict[answer] = ans_type
|
|
|
        with open(os.path.join(self.cfg.neg_sample_dir, 'object.txt')) as f:
            self.all_object_list = [row.strip() for row in f if row.strip() != '']
        neg_captions_path = os.path.join(self.cfg.neg_sample_dir, 'all_captions.txt')
        with open(neg_captions_path) as f:
            self.all_caption_list = [row.strip() for row in f if row.strip() != '']
|
|
|
if self.cfg.neg_captions_video is not None: |
|
neg_captions_video = self.cfg.neg_captions_video |
|
else: |
|
neg_captions_video = neg_captions_path |
|
|
|
if self.cfg.neg_captions_audio is not None: |
|
neg_captions_audio = self.cfg.neg_captions_audio |
|
else: |
|
neg_captions_audio = neg_captions_path |
|
|
|
        logger.info("Reading negative video captions: %s", neg_captions_video)
        with open(neg_captions_video) as f:
            self.all_caption_video_list = [row.strip() for row in f if row.strip() != '']
        logger.info("Reading negative audio captions: %s", neg_captions_audio)
        with open(neg_captions_audio) as f:
            self.all_caption_audio_list = [row.strip() for row in f if row.strip() != '']
|
|
|
|
|
|
|
|
|
        # Optional side datasets (opened below when configured); each is mixed into
        # the unified batches at the ratio given by its corresponding *_cnt option.
        self.audio_dataset = None
|
self.video_dataset = None |
|
self.image_text_dataset = None |
|
|
|
self.video_text_dataset = None |
|
|
|
self.audio_text_dataset = None |
|
|
|
self.init_image_text_dataset = None |
|
|
|
self.init_text_dataset = None |
|
|
|
self.image_text_ground_dataset = None |
|
self.image_text_vqa_dataset = None |
|
|
|
|
|
|
|
if self.cfg.audio_data is not None: |
|
            logger.info("Load audio data")
|
self.audio_dataset = FileDataset(self.cfg.audio_data, self.cfg.audio_selected_cols) |
|
|
|
if self.cfg.video_data is not None: |
|
            logger.info("Load video data")
|
self.video_dataset = FileDataset(self.cfg.video_data, self.cfg.video_selected_cols) |
|
|
|
if self.cfg.image_text_data is not None: |
|
            logger.info("Load Image Text data")
|
self.image_text_dataset = FileDataset(self.cfg.image_text_data, self.cfg.selected_cols) |
|
|
|
if self.cfg.init_image_text_data is not None: |
|
            logger.info("Load Init Image Text data")
|
self.init_image_text_dataset = FileDataset(self.cfg.init_image_text_data, self.cfg.selected_cols) |
|
|
|
if self.cfg.init_text_data is not None: |
|
            logger.info("Load Init Text data")
|
self.init_text_dataset = FileDataset(self.cfg.init_text_data, self.cfg.text_selected_cols) |
|
|
|
if self.cfg.image_text_vqa_data is not None: |
|
            logger.info("Load Image Text VQA data")
|
self.image_text_vqa_dataset = FileDataset(self.cfg.image_text_vqa_data, self.cfg.selected_cols) |
|
|
|
if self.cfg.image_text_ground_data is not None: |
|
            logger.info("Load Image Text Grounding data")
|
self.image_text_ground_dataset = FileDataset(self.cfg.image_text_ground_data, self.cfg.selected_cols) |
|
|
|
|
|
if self.cfg.video_text_data is not None: |
|
            logger.info("Load Video Text data")
|
self.video_text_dataset = FileDataset(self.cfg.video_text_data, self.cfg.selected_cols) |
|
|
|
if self.cfg.audio_text_data is not None: |
|
            logger.info("Load Audio Text data")
|
self.audio_text_dataset = FileDataset(self.cfg.audio_text_data, self.cfg.selected_cols) |
|
|
|
|
|
def load_dataset(self, split, epoch=1, combine=False, **kwargs): |
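        """Build the UnifyDataset for ``split``.

        ``cfg.data`` may hold several comma-separated files; epoch ``e`` uses
        ``paths[(e - 1) % len(paths)]``, cycling through them round-robin.
        """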
|
paths = self.cfg.data.split(',') |
|
assert len(paths) > 0 |
|
|
|
file_path = paths[(epoch - 1) % (len(paths))] |
|
dataset = FileDataset(file_path, self.cfg.selected_cols) |
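
        # UnifyDataset combines this main file with every optional side dataset and
        # the masking / negative-sampling settings configured above.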
|
|
|
|
|
|
|
self.datasets[split] = UnifyDataset( |
|
split, |
|
dataset, |
|
self.bpe, |
|
self.src_dict, |
|
self.tgt_dict, |
|
max_src_length=self.cfg.max_src_length, |
|
max_tgt_length=self.cfg.max_tgt_length, |
|
seed=self.cfg.pretrain_seed, |
|
code_dict_size=self.cfg.code_dict_size, |
|
num_bins=self.cfg.num_bins, |
|
patch_image_size=self.cfg.patch_image_size, |
|
code_image_size=self.cfg.code_image_size, |
|
all_object_list=self.all_object_list, |
|
all_caption_list=self.all_caption_list, |
|
type2ans_dict=self.type2ans_dict, |
|
ans2type_dict=self.ans2type_dict, |
|
max_image_size=self.cfg.max_image_size, |
|
mask_ratio=self.cfg.mask_ratio, |
|
random_ratio=self.cfg.random_ratio, |
|
keep_ratio=self.cfg.keep_ratio, |
|
mask_length=self.cfg.mask_length, |
|
poisson_lambda=self.cfg.poisson_lambda, |
|
replace_length=self.cfg.replace_length, |
|
read_from_img_path=self.cfg.read_from_img_path, |
|
image_dir=self.cfg.image_dir, |
|
no_image_transform=self.cfg.no_image_transform, |
|
patch_frame_size=self.cfg.patch_frame_size, |
|
num_frames=self.cfg.num_frames, |
|
all_caption_video_list=self.all_caption_video_list, |
|
all_caption_audio_list=self.all_caption_audio_list, |
|
max_audio_len=self.cfg.max_audio_len, |
|
audio_dataset=self.audio_dataset, |
|
video_dataset=self.video_dataset, |
|
video_cnt=self.cfg.video_cnt, |
|
audio_cnt=self.cfg.audio_cnt, |
|
image_text_dataset=self.image_text_dataset, |
|
image_text_cnt=self.cfg.image_text_cnt, |
|
other_data_cnt=self.cfg.other_data_cnt, |
|
init_image_text_dataset=self.init_image_text_dataset, |
|
init_dataset_epoch=self.cfg.init_dataset_epoch, |
|
init_text_dataset=self.init_text_dataset, |
|
image_text_vqa_dataset=self.image_text_vqa_dataset, |
|
image_text_vqa_cnt=self.cfg.image_text_vqa_cnt, |
|
image_text_ground_dataset=self.image_text_ground_dataset, |
|
image_text_ground_cnt=self.cfg.image_text_ground_cnt, |
|
only_video_data=self.cfg.only_video_data, |
|
only_audio_data=self.cfg.only_audio_data, |
|
video_text_dataset=self.video_text_dataset, |
|
video_text_cnt=self.cfg.video_text_cnt, |
|
audio_text_dataset=self.audio_text_dataset, |
|
audio_text_cnt=self.cfg.audio_text_cnt, |
|
audio_with_video=self.cfg.audio_with_video, |
|
sample_rate=self.cfg.sample_rate, |
|
) |
|
|
|
def get_batch_iterator( |
|
self, |
|
dataset, |
|
max_tokens=None, |
|
max_sentences=None, |
|
max_positions=None, |
|
ignore_invalid_inputs=False, |
|
required_batch_size_multiple=1, |
|
seed=1, |
|
num_shards=1, |
|
shard_id=0, |
|
num_workers=0, |
|
epoch=1, |
|
data_buffer_size=0, |
|
disable_iterator_cache=False, |
|
): |
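        """Build an EpochBatchIterator over sequential fixed-size batches.

        Sharding is assumed to already be handled by the underlying FileDataset,
        which is why the iterator itself is created with num_shards=1, shard_id=0.
        """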
|
assert isinstance(dataset, FairseqDataset) |
|
|
|
|
|
dataset.set_epoch(epoch) |
|
|
|
|
|
        # Fixed-size, in-order batches of ``max_sentences`` samples; no length-based
        # bucketing or shuffling is done at this level.
        batch_sampler = [
            [j for j in range(i, min(i + max_sentences, len(dataset)))]
            for i in range(0, len(dataset), max_sentences)
        ]
|
        # If this shard produced fewer batches than expected from the global row
        # count, append an extra single-sample batch so that every shard yields
        # the same number of batches per epoch.
        total_row_count = dataset.dataset.get_total_row_count()
        num_batches = math.ceil(math.ceil(total_row_count / num_shards) / max_sentences)
        if len(batch_sampler) < num_batches:
            batch_sampler.append([1])
|
|
|
|
|
epoch_iter = iterators.EpochBatchIterator( |
|
dataset=dataset, |
|
collate_fn=dataset.collater, |
|
batch_sampler=batch_sampler, |
|
seed=seed, |
|
num_shards=1, |
|
shard_id=0, |
|
num_workers=num_workers, |
|
epoch=epoch, |
|
buffer_size=data_buffer_size |
|
) |
|
|
|
return epoch_iter |
|
|