""" |
|
make a general fairseq task for MM pretraining. |
|
""" |

import random

from fairseq.tasks import LegacyFairseqTask, register_task

from .task import Task
from .retritask import RetriTask
from ..datasets import FairseqMMDataset
from .. import utils

@register_task("mmtask") |
|
class FairseqMMTask(LegacyFairseqTask): |
|
@staticmethod |
|
def add_args(parser): |
|
|
|
|
|
parser.add_argument( |
|
"taskconfig", |
|
metavar="FILE", |
|
help=("taskconfig to load all configurations" "outside fairseq parser."), |
|
) |

    @classmethod
    def setup_task(cls, args, **kwargs):
        return FairseqMMTask(args)

    def __init__(self, args):
        super().__init__(args)
        # load the task config referenced by ``args.taskconfig`` and build the
        # underlying MM task together with its dataset, model and loss.
        config = utils.load_config(args)
        self.mmtask = Task.config_task(config)
        self.mmtask.build_dataset()
        self.mmtask.build_model()
        self.mmtask.build_loss()

    def load_dataset(self, split, **kwargs):
        # map fairseq split names onto the datasets built by the MM task.
        split_map = {
            "train": self.mmtask.train_data,
            "valid": self.mmtask.val_data,
            "test": self.mmtask.test_data,
        }
        if split not in split_map:
            raise ValueError(f"unknown split type: {split}")
        if split_map[split] is not None:
            self.datasets[split] = FairseqMMDataset(split_map[split])

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
        skip_remainder_batch=False,
        grouped_shuffling=False,
        update_epoch_batch_itr=False,
    ):
        # seed python's RNG with the epoch so per-epoch randomness is reproducible.
        random.seed(epoch)
        # for retrieval-based tasks, refresh the retrieval candidates once
        # ``retri_epoch`` is reached, building the retrieval dataloader on first use.
        if dataset.mmdataset.split == "train" and isinstance(self.mmtask, RetriTask):
            if epoch >= self.mmtask.config.retri_epoch:
                if not hasattr(self.mmtask, "retri_dataloader"):
                    self.mmtask.build_dataloader()
                self.mmtask.retrive_candidates(epoch)

        # pass the iterator options by keyword so they cannot be misbound
        # against the base class signature.
        return super().get_batch_iterator(
            dataset,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            max_positions=max_positions,
            ignore_invalid_inputs=ignore_invalid_inputs,
            required_batch_size_multiple=required_batch_size_multiple,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
            data_buffer_size=data_buffer_size,
            disable_iterator_cache=disable_iterator_cache,
            skip_remainder_batch=skip_remainder_batch,
            grouped_shuffling=grouped_shuffling,
            update_epoch_batch_itr=update_epoch_batch_itr,
        )

    @property
    def source_dictionary(self):
        return None

    @property
    def target_dictionary(self):
        return None