diff --git a/fairseq/docs/models.rst b/fairseq/docs/models.rst new file mode 100644 index 0000000000000000000000000000000000000000..054622d587c3b7f01f17f442919140755acd8f9e --- /dev/null +++ b/fairseq/docs/models.rst @@ -0,0 +1,104 @@ +.. role:: hidden + :class: hidden-section + +.. module:: fairseq.models + +.. _Models: + +Models +====== + +A Model defines the neural network's ``forward()`` method and encapsulates all +of the learnable parameters in the network. Each model also provides a set of +named *architectures* that define the precise network configuration (e.g., +embedding dimension, number of layers, etc.). + +Both the model type and architecture are selected via the ``--arch`` +command-line argument. Once selected, a model may expose additional command-line +arguments for further configuration. + +.. note:: + + All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends + :class:`torch.nn.Module`. Thus any fairseq Model can be used as a + stand-alone Module in other PyTorch code. + + +Convolutional Neural Networks (CNN) +----------------------------------- + +.. module:: fairseq.models.fconv +.. autoclass:: fairseq.models.fconv.FConvModel + :members: +.. autoclass:: fairseq.models.fconv.FConvEncoder + :members: + :undoc-members: +.. autoclass:: fairseq.models.fconv.FConvDecoder + :members: + + +Long Short-Term Memory (LSTM) networks +-------------------------------------- + +.. module:: fairseq.models.lstm +.. autoclass:: fairseq.models.lstm.LSTMModel + :members: +.. autoclass:: fairseq.models.lstm.LSTMEncoder + :members: +.. autoclass:: fairseq.models.lstm.LSTMDecoder + :members: + + +Transformer (self-attention) networks +------------------------------------- + +.. module:: fairseq.models.transformer +.. autoclass:: fairseq.models.transformer.TransformerModel + :members: +.. autoclass:: fairseq.models.transformer.TransformerEncoder + :members: +.. autoclass:: fairseq.models.transformer.TransformerEncoderLayer + :members: +.. autoclass:: fairseq.models.transformer.TransformerDecoder + :members: +.. autoclass:: fairseq.models.transformer.TransformerDecoderLayer + :members: + + +Adding new models +----------------- + +.. currentmodule:: fairseq.models +.. autofunction:: fairseq.models.register_model +.. autofunction:: fairseq.models.register_model_architecture +.. autoclass:: fairseq.models.BaseFairseqModel + :members: + :undoc-members: +.. autoclass:: fairseq.models.FairseqEncoderDecoderModel + :members: + :undoc-members: +.. autoclass:: fairseq.models.FairseqEncoderModel + :members: + :undoc-members: +.. autoclass:: fairseq.models.FairseqLanguageModel + :members: + :undoc-members: +.. autoclass:: fairseq.models.FairseqMultiModel + :members: + :undoc-members: +.. autoclass:: fairseq.models.FairseqEncoder + :members: +.. autoclass:: fairseq.models.CompositeEncoder + :members: +.. autoclass:: fairseq.models.FairseqDecoder + :members: + + +.. _Incremental decoding: + +Incremental decoding +-------------------- + +.. 
autoclass:: fairseq.models.FairseqIncrementalDecoder + :members: + :undoc-members: diff --git a/fairseq/docs/tutorial_classifying_names.rst b/fairseq/docs/tutorial_classifying_names.rst new file mode 100644 index 0000000000000000000000000000000000000000..de099f08f548d4fb92829922337a4b7f4ceacc9b --- /dev/null +++ b/fairseq/docs/tutorial_classifying_names.rst @@ -0,0 +1,415 @@ +Tutorial: Classifying Names with a Character-Level RNN +====================================================== + +In this tutorial we will extend fairseq to support *classification* tasks. In +particular we will re-implement the PyTorch tutorial for `Classifying Names with +a Character-Level RNN `_ +in fairseq. It is recommended to quickly skim that tutorial before beginning +this one. + +This tutorial covers: + +1. **Preprocessing the data** to create dictionaries. +2. **Registering a new Model** that encodes an input sentence with a simple RNN + and predicts the output label. +3. **Registering a new Task** that loads our dictionaries and dataset. +4. **Training the Model** using the existing command-line tools. +5. **Writing an evaluation script** that imports fairseq and allows us to + interactively evaluate our model on new inputs. + + +1. Preprocessing the data +------------------------- + +The original tutorial provides raw data, but we'll work with a modified version +of the data that is already tokenized into characters and split into separate +train, valid and test sets. + +Download and extract the data from here: +`tutorial_names.tar.gz `_ + +Once extracted, let's preprocess the data using the :ref:`fairseq-preprocess` +command-line tool to create the dictionaries. While this tool is primarily +intended for sequence-to-sequence problems, we're able to reuse it here by +treating the label as a "target" sequence of length 1. We'll also output the +preprocessed files in "raw" format using the ``--dataset-impl`` option to +enhance readability: + +.. code-block:: console + + > fairseq-preprocess \ + --trainpref names/train --validpref names/valid --testpref names/test \ + --source-lang input --target-lang label \ + --destdir names-bin --dataset-impl raw + +After running the above command you should see a new directory, +:file:`names-bin/`, containing the dictionaries for *inputs* and *labels*. + + +2. Registering a new Model +-------------------------- + +Next we'll register a new model in fairseq that will encode an input sentence +with a simple RNN and predict the output label. Compared to the original PyTorch +tutorial, our version will also work with batches of data and GPU Tensors. + +First let's copy the simple RNN module implemented in the `PyTorch tutorial +`_. +Create a new file named :file:`fairseq/models/rnn_classifier.py` with the +following contents:: + + import torch + import torch.nn as nn + + class RNN(nn.Module): + + def __init__(self, input_size, hidden_size, output_size): + super(RNN, self).__init__() + + self.hidden_size = hidden_size + + self.i2h = nn.Linear(input_size + hidden_size, hidden_size) + self.i2o = nn.Linear(input_size + hidden_size, output_size) + self.softmax = nn.LogSoftmax(dim=1) + + def forward(self, input, hidden): + combined = torch.cat((input, hidden), 1) + hidden = self.i2h(combined) + output = self.i2o(combined) + output = self.softmax(output) + return output, hidden + + def initHidden(self): + return torch.zeros(1, self.hidden_size) + +We must also *register* this model with fairseq using the +:func:`~fairseq.models.register_model` function decorator. 
Once the model is +registered we'll be able to use it with the existing :ref:`Command-line Tools`. + +All registered models must implement the :class:`~fairseq.models.BaseFairseqModel` +interface, so we'll create a small wrapper class in the same file and register +it in fairseq with the name ``'rnn_classifier'``:: + + from fairseq.models import BaseFairseqModel, register_model + + # Note: the register_model "decorator" should immediately precede the + # definition of the Model class. + + @register_model('rnn_classifier') + class FairseqRNNClassifier(BaseFairseqModel): + + @staticmethod + def add_args(parser): + # Models can override this method to add new command-line arguments. + # Here we'll add a new command-line argument to configure the + # dimensionality of the hidden state. + parser.add_argument( + '--hidden-dim', type=int, metavar='N', + help='dimensionality of the hidden state', + ) + + @classmethod + def build_model(cls, args, task): + # Fairseq initializes models by calling the ``build_model()`` + # function. This provides more flexibility, since the returned model + # instance can be of a different type than the one that was called. + # In this case we'll just return a FairseqRNNClassifier instance. + + # Initialize our RNN module + rnn = RNN( + # We'll define the Task in the next section, but for now just + # notice that the task holds the dictionaries for the "source" + # (i.e., the input sentence) and "target" (i.e., the label). + input_size=len(task.source_dictionary), + hidden_size=args.hidden_dim, + output_size=len(task.target_dictionary), + ) + + # Return the wrapped version of the module + return FairseqRNNClassifier( + rnn=rnn, + input_vocab=task.source_dictionary, + ) + + def __init__(self, rnn, input_vocab): + super(FairseqRNNClassifier, self).__init__() + + self.rnn = rnn + self.input_vocab = input_vocab + + # The RNN module in the tutorial expects one-hot inputs, so we can + # precompute the identity matrix to help convert from indices to + # one-hot vectors. We register it as a buffer so that it is moved to + # the GPU when ``cuda()`` is called. + self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab))) + + def forward(self, src_tokens, src_lengths): + # The inputs to the ``forward()`` function are determined by the + # Task, and in particular the ``'net_input'`` key in each + # mini-batch. We'll define the Task in the next section, but for + # now just know that *src_tokens* has shape `(batch, src_len)` and + # *src_lengths* has shape `(batch)`. + bsz, max_src_len = src_tokens.size() + + # Initialize the RNN hidden state. Compared to the original PyTorch + # tutorial we'll also handle batched inputs and work on the GPU. + hidden = self.rnn.initHidden() + hidden = hidden.repeat(bsz, 1) # expand for batched inputs + hidden = hidden.to(src_tokens.device) # move to GPU + + for i in range(max_src_len): + # WARNING: The inputs have padding, so we should mask those + # elements here so that padding doesn't affect the results. + # This is left as an exercise for the reader. The padding symbol + # is given by ``self.input_vocab.pad()`` and the unpadded length + # of each input is given by *src_lengths*. + + # One-hot encode a batch of input characters. + input = self.one_hot_inputs[src_tokens[:, i].long()] + + # Feed the input to our RNN. + output, hidden = self.rnn(input, hidden) + + # Return the final output state for making a prediction + return output + +Finally let's define a *named architecture* with the configuration for our +model. 
This is done with the :func:`~fairseq.models.register_model_architecture` +function decorator. Thereafter this named architecture can be used with the +``--arch`` command-line argument, e.g., ``--arch pytorch_tutorial_rnn``:: + + from fairseq.models import register_model_architecture + + # The first argument to ``register_model_architecture()`` should be the name + # of the model we registered above (i.e., 'rnn_classifier'). The function we + # register here should take a single argument *args* and modify it in-place + # to match the desired architecture. + + @register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn') + def pytorch_tutorial_rnn(args): + # We use ``getattr()`` to prioritize arguments that are explicitly given + # on the command-line, so that the defaults defined below are only used + # when no other value has been specified. + args.hidden_dim = getattr(args, 'hidden_dim', 128) + + +3. Registering a new Task +------------------------- + +Now we'll register a new :class:`~fairseq.tasks.FairseqTask` that will load our +dictionaries and dataset. Tasks can also control how the data is batched into +mini-batches, but in this tutorial we'll reuse the batching provided by +:class:`fairseq.data.LanguagePairDataset`. + +Create a new file named :file:`fairseq/tasks/simple_classification.py` with the +following contents:: + + import os + import torch + + from fairseq.data import Dictionary, LanguagePairDataset + from fairseq.tasks import LegacyFairseqTask, register_task + + + @register_task('simple_classification') + class SimpleClassificationTask(LegacyFairseqTask): + + @staticmethod + def add_args(parser): + # Add some command-line arguments for specifying where the data is + # located and the maximum supported input length. + parser.add_argument('data', metavar='FILE', + help='file prefix for data') + parser.add_argument('--max-positions', default=1024, type=int, + help='max input length') + + @classmethod + def setup_task(cls, args, **kwargs): + # Here we can perform any setup required for the task. This may include + # loading Dictionaries, initializing shared Embedding layers, etc. + # In this case we'll just load the Dictionaries. + input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt')) + label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt')) + print('| [input] dictionary: {} types'.format(len(input_vocab))) + print('| [label] dictionary: {} types'.format(len(label_vocab))) + + return SimpleClassificationTask(args, input_vocab, label_vocab) + + def __init__(self, args, input_vocab, label_vocab): + super().__init__(args) + self.input_vocab = input_vocab + self.label_vocab = label_vocab + + def load_dataset(self, split, **kwargs): + """Load a given dataset split (e.g., train, valid, test).""" + + prefix = os.path.join(self.args.data, '{}.input-label'.format(split)) + + # Read input sentences. + sentences, lengths = [], [] + with open(prefix + '.input', encoding='utf-8') as file: + for line in file: + sentence = line.strip() + + # Tokenize the sentence, splitting on spaces + tokens = self.input_vocab.encode_line( + sentence, add_if_not_exist=False, + ) + + sentences.append(tokens) + lengths.append(tokens.numel()) + + # Read labels. + labels = [] + with open(prefix + '.label', encoding='utf-8') as file: + for line in file: + label = line.strip() + labels.append( + # Convert label to a numeric ID. 
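+                    # (Dictionary.add_symbol() returns the label's index, adding
+                    # the label to the dictionary first if it isn't already there.)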
+ torch.LongTensor([self.label_vocab.add_symbol(label)]) + ) + + assert len(sentences) == len(labels) + print('| {} {} {} examples'.format(self.args.data, split, len(sentences))) + + # We reuse LanguagePairDataset since classification can be modeled as a + # sequence-to-sequence task where the target sequence has length 1. + self.datasets[split] = LanguagePairDataset( + src=sentences, + src_sizes=lengths, + src_dict=self.input_vocab, + tgt=labels, + tgt_sizes=torch.ones(len(labels)), # targets have length 1 + tgt_dict=self.label_vocab, + left_pad_source=False, + # Since our target is a single class label, there's no need for + # teacher forcing. If we set this to ``True`` then our Model's + # ``forward()`` method would receive an additional argument called + # *prev_output_tokens* that would contain a shifted version of the + # target sequence. + input_feeding=False, + ) + + def max_positions(self): + """Return the max input length allowed by the task.""" + # The source should be less than *args.max_positions* and the "target" + # has max length 1. + return (self.args.max_positions, 1) + + @property + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary`.""" + return self.input_vocab + + @property + def target_dictionary(self): + """Return the target :class:`~fairseq.data.Dictionary`.""" + return self.label_vocab + + # We could override this method if we wanted more control over how batches + # are constructed, but it's not necessary for this tutorial since we can + # reuse the batching provided by LanguagePairDataset. + # + # def get_batch_iterator( + # self, dataset, max_tokens=None, max_sentences=None, max_positions=None, + # ignore_invalid_inputs=False, required_batch_size_multiple=1, + # seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1, + # data_buffer_size=0, disable_iterator_cache=False, + # ): + # (...) + + +4. Training the Model +--------------------- + +Now we're ready to train the model. We can use the existing :ref:`fairseq-train` +command-line tool for this, making sure to specify our new Task (``--task +simple_classification``) and Model architecture (``--arch +pytorch_tutorial_rnn``): + +.. note:: + + You can also configure the dimensionality of the hidden state by passing the + ``--hidden-dim`` argument to :ref:`fairseq-train`. + +.. code-block:: console + + > fairseq-train names-bin \ + --task simple_classification \ + --arch pytorch_tutorial_rnn \ + --optimizer adam --lr 0.001 --lr-shrink 0.5 \ + --max-tokens 1000 + (...) + | epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21 + | epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208 + | done training in 31.6 seconds + +The model files should appear in the :file:`checkpoints/` directory. + + +5. Writing an evaluation script +------------------------------- + +Finally we can write a short script to evaluate our model on new inputs. 
Create +a new file named :file:`eval_classifier.py` with the following contents:: + + from fairseq import checkpoint_utils, data, options, tasks + + # Parse command-line arguments for generation + parser = options.get_generation_parser(default_task='simple_classification') + args = options.parse_args_and_arch(parser) + + # Setup task + task = tasks.setup_task(args) + + # Load model + print('| loading model from {}'.format(args.path)) + models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task) + model = models[0] + + while True: + sentence = input('\nInput: ') + + # Tokenize into characters + chars = ' '.join(list(sentence.strip())) + tokens = task.source_dictionary.encode_line( + chars, add_if_not_exist=False, + ) + + # Build mini-batch to feed to the model + batch = data.language_pair_dataset.collate( + samples=[{'id': -1, 'source': tokens}], # bsz = 1 + pad_idx=task.source_dictionary.pad(), + eos_idx=task.source_dictionary.eos(), + left_pad_source=False, + input_feeding=False, + ) + + # Feed batch to the model and get predictions + preds = model(**batch['net_input']) + + # Print top 3 predictions and their log-probabilities + top_scores, top_labels = preds[0].topk(k=3) + for score, label_idx in zip(top_scores, top_labels): + label_name = task.target_dictionary.string([label_idx]) + print('({:.2f})\t{}'.format(score, label_name)) + +Now we can evaluate our model interactively. Note that we have included the +original data path (:file:`names-bin/`) so that the dictionaries can be loaded: + +.. code-block:: console + + > python eval_classifier.py names-bin --path checkpoints/checkpoint_best.pt + | [input] dictionary: 64 types + | [label] dictionary: 24 types + | loading model from checkpoints/checkpoint_best.pt + + Input: Satoshi + (-0.61) Japanese + (-1.20) Arabic + (-2.86) Italian + + Input: Sinbad + (-0.30) Arabic + (-1.76) English + (-4.08) Russian diff --git a/fairseq/examples/MMPT/.gitignore b/fairseq/examples/MMPT/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..70a255dc918bf4a4242dd7ea5ded63e6848264db --- /dev/null +++ b/fairseq/examples/MMPT/.gitignore @@ -0,0 +1,139 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ +runs +data +pretrained_models +projects/mmfusion_* +log_test +third-party +python_log +slurm_snapshot_code +lightning_logs +demos diff --git a/fairseq/examples/MMPT/README.md b/fairseq/examples/MMPT/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4a84819d9ddc992170296958a9431ff0538a85be --- /dev/null +++ b/fairseq/examples/MMPT/README.md @@ -0,0 +1,166 @@ +# VideoCLIP and VLM + +You just find this toolkit for multimodal video understanding! It contains implementation of two recent multi-modal video understanding papers [VideoCLIP](https://arxiv.org/pdf/2109.14084.pdf) (EMNLP, 2021) and [VLM](https://aclanthology.org/2021.findings-acl.370.pdf) (ACL Findings, 2021), along with high-performance toolkits that are typically lacking in existing codebase. The toolkit is desigend to contain generic performance-tuned components that can be potentially adapted to other frameworks (we initially use fairseq). + +VideoCLIP is a contrastive learning model for zero-shot transfer to retrieval/classification/sequence labeling style tasks. + + + +VLM is a masked language model style pre-training using only one encoder with masked modality model (MMM) for retrieval/generation/sequence labeling style tasks. + + + +### News +[Oct. 2021] Initial release of implementation for the following papers: +[VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding](https://arxiv.org/pdf/2109.14084.pdf) (Xu et. al., EMNLP 2021) +[VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding](https://aclanthology.org/2021.findings-acl.370.pdf) (Xu et. al., ACL Findings 2021) + + +### Installation +We aim to minimize the dependency of this repo on other packages. +We use fairseq as the main trainer (no models/datasets dependency on fairseq. We will support other trainer in future): +``` +git clone https://github.com/pytorch/fairseq +cd fairseq +pip install -e . # also optionally follow fairseq README for apex installation for fp16 training. +export MKL_THREADING_LAYER=GNU # fairseq may need this for numpy. +``` + +Then install this toolkit: +``` +cd examples/MMPT # MMPT can be in any folder, not necessarily under fairseq/examples. +pip install -e . +``` + +The code is developed under Python=3.8.8, Pytorch=1.8, cuda=11.0 with fairseq=1.0.0a0+af0389f and tested under Python=3.8.8 pytorch=1.9 cuda=11.0 fairseq=1.0.0a0+8e7bc73 during code release. +Most models require `transformers==3.4` for API compatibility `pip install transformers==3.4`. +In addition, some downstream tasks may need `conda install pandas`. + + +### Usage +#### Download Checkpoints +We use pre-trained [S3D](https://github.com/antoine77340/S3D_HowTo100M) for video feature extraction. Please place the models as `pretrained_models/s3d_dict.npy` and `pretrained_models/s3d_howto100m.pth`. 
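+
+If it helps, here is a minimal sketch for fetching the released pre-training checkpoints into the directory layout assumed by the configs (the URLs are the ones given in the next paragraph; the S3D weights must still be obtained separately from the S3D_HowTo100M repository):
+
+```python
+import os
+import torch.hub
+
+# Download the released checkpoints to the paths the project configs expect.
+checkpoints = {
+    "runs/retri/videoclip/checkpoint_best.pt":
+        "https://dl.fbaipublicfiles.com/MMPT/retri/videoclip/checkpoint_best.pt",
+    "runs/mtm/vlm/checkpoint_best.pt":
+        "https://dl.fbaipublicfiles.com/MMPT/mtm/vlm/checkpoint_best.pt",
+}
+for dst, url in checkpoints.items():
+    os.makedirs(os.path.dirname(dst), exist_ok=True)
+    if not os.path.exists(dst):
+        torch.hub.download_url_to_file(url, dst)
+```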
+ +Download VideoCLIP checkpoint `https://dl.fbaipublicfiles.com/MMPT/retri/videoclip/checkpoint_best.pt` to `runs/retri/videoclip` or VLM checkpoint `https://dl.fbaipublicfiles.com/MMPT/mtm/vlm/checkpoint_best.pt` to `runs/mtm/vlm`. + +#### Demo of Inference +run `python locallaunch.py projects/retri/videoclip.yaml --dryrun` to get all `.yaml`s for VideoCLIP. + +```python +import torch + +from mmpt.models import MMPTModel + + +model, tokenizer, aligner = MMPTModel.from_pretrained( + "projects/retri/videoclip/how2.yaml") + +model.eval() + + +# B, T, FPS, H, W, C (VideoCLIP is trained on 30 fps of s3d) +video_frames = torch.randn(1, 2, 30, 224, 224, 3) +caps, cmasks = aligner._build_text_seq( + tokenizer("some text", add_special_tokens=False)["input_ids"] +) + +caps, cmasks = caps[None, :], cmasks[None, :] # bsz=1 + +with torch.no_grad(): + output = model(video_frames, caps, cmasks, return_score=True) +print(output["score"]) # dot-product +``` + +#### Data Preparation +See [dataset](DATASET.md) for each dataset. + +#### Global Config for Training Pipeline +We organize a global config file for a training/testing pipeline under projects (see a detailed [explanation](CONFIG.md)). For example, VideoCLIP in `projects/retri/videoclip.yaml` and VLM is in `projects/mtm/vlm.yaml`. + +We wrap all cmds into `locallaunch.py` and `mmpt_cli/localjob.py`. You can check concrete cmds by `--dryrun` and then drop it for actual run. + +First, run `python locallaunch.py projects/retri/videoclip.yaml --dryrun` will generate configs for all configs of pre-training, zero-shot evaluation, fine-tuning and testing, for VideoCLIP under `projects/retri/videoclip`. + +Then each (either training or evaluation) process will be configed by a concrete config file (we save all complex arguments into the concrete config file for reproducibility, including fairseq args). For example, run zero-shot evaluation on youcook, +``` +python locallaunch.py projects/retri/videoclip/test_youcook_zs.yaml --jobtype local_predict # zero-shot evaluation. +python locallaunch.py projects/retri/videoclip/youcook_videoclip.yaml --jobtype local_single --dryrun # fine-tuning: use --dryrun to check cmds and drop it to make an actual run; local_small will run on two gpus (as in paper). +python locallaunch.py projects/retri/videoclip/test_youcook_videoclip.yaml --jobtype local_predict # testing on fine-tuned model. +``` + +Pretraining can be run as: +``` +python locallaunch.py projects/retri/videoclip/how2.yaml --jobtype local_single --dryrun # check then drop dryrun; paper is ran on local_big as 8 gpus. +``` +You may need to change `--jobtype`, check/extend `LocalJob` in `mmpt_cli/localjob.py` for multi-gpu/multi-node pre-training. + +The detailed instructions of pretraining and fine-tuning can be found at [pretraining instruction](pretraining.md) and [finetuning instruction](endtask.md). + + +### Development +Several components of this toolkit can be re-used for future research (and also our ongoing research). + +#### Framework Wrapper +We currently only support fairseq, but most components can be easily fit into other frameworks like huggingface. This repo is a `--user-dir` of fairseq with fairseq wrapper. For example, `mmpt/tasks` includes a `FairseqMMTTask`, which manages `mmpt/datasets` with `FairseqDataset`, `mmpt/models` with `FairseqModel`, `mmpt/losses` with `FairseqCriterion`. + +#### Processors +**Multi**modal research introduces the complexity on modality alignment from different input sources to losses. 
Inspired by [MMF](https://github.com/facebookresearch/mmf), this toolkit leverages `mmpt/processors` to handle various needs of data preprocessing and loading, **alleviating** the needs of multiple `torch.data.utils.Dataset` (that can be tricky for ablation study). +Processors can also be decoupled from `torch.data.utils.Dataset` for offline preprocessing instead of on-the-fly data preprocessing. + +We decouple a `mmpt.MMDataset` as 3 types of processors: `MetaProcessor`, `VideoProcessor`, `TextProcessor` and `Aligner`. They can be configed in `dataset` field of a config file (e.g., see `projects/task/how2.yaml`). +`MetaProcessor` is used to load the meta data about a dataset, aka, all video_ids of how2 dataset. +`VideoProcessor` is used to load the video features about a dataset. For example, S3D features for each second of a video. +`TextProcessor` is used to load the text (feature). For example, BERT pre-tokenized text clips for how2 dataset (with `start`s, `end`s of timestamps and `cap` for `token_ids`). +`Aligner` is the core class for different baselines that prepares the training data. For example, sampling a clip, masking tokens for MLM, etc. + +#### Performance-tuned Components +To speed up pre-training, this toolkit uses sharded features stored in mmaped numpy, backed by `ShardedTensor` in `mmpt/utils/shardedtensor.py` (adopted from MARGE paper). This reduces the loads of IO for multi-GPU training without loading all features for a video into the memory each time and `ShardedTensor` ensure features are stored in continuous disk space for near random access. This is used for both How2 video features and texts in `mmpt/processors/how2processor.py`. + + +### Citation +If this codebase is useful for your work, please cite the following papers: + +```BibTeX +@inproceedings{xu-etal-2021-videoclip, + title = "{VideoCLIP}: Contrastive Pre-training for\\Zero-shot Video-Text Understanding", + author = "Xu, Hu and + Ghosh, Gargi and + Huang, Po-Yao and + Okhonko, Dmytro and + Aghajanyan, Armen and + Metze, Florian and + Zettlemoyer, Luke and + Feichtenhofer, Christoph", + booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP)", + month = nov, + year = "2021", + address = "Online", + publisher = "Association for Computational Linguistics", +} + +@inproceedings{xu-etal-2021-vlm, + title = "{VLM}: Task-agnostic Video-Language Model Pre-training for Video Understanding", + author = "Xu, Hu and + Ghosh, Gargi and + Huang, Po-Yao and + Arora, Prahal and + Aminzadeh, Masoumeh and + Feichtenhofer, Christoph and + Metze, Florian and + Zettlemoyer, Luke", + booktitle = "Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021", + month = aug, + year = "2021", + address = "Online", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2021.findings-acl.370", + doi = "10.18653/v1/2021.findings-acl.370", + pages = "4227--4239", +} +``` + +### Bug Reports +This repo is in its initial stage, welcome bug reports to huxu@fb.com + +### Copyright +The majority of Multimodal Pre-training (MMPT) is licensed under CC-BY-NC, however portions of the project are available under separate license terms: Evaluation Codes/Models: Howto100M and HuggingFace Transformers are licensed under the Apache2.0 license; COIN and NLG-eval are licensed under the MIT license; CrossTask is licensed under the BSD-3; DiDeMo is licensed under the BSD-2 license. 
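+
+### Example: Composing Processors into an MMDataset
+
+To make the processor decomposition described under `#### Processors` above concrete, here is a minimal, self-contained sketch of how the four roles plug into `mmpt.datasets.mmdataset.MMDataset`. The toy processors below are hypothetical stand-ins; the real implementations live in `mmpt/processors` and are selected via the `dataset` field of a project config.
+
+```python
+import torch
+
+from mmpt.datasets.mmdataset import MMDataset
+
+
+class ToyMetaProcessor:
+    """Maps an index to a (video_id, text_id) pair; here both ids coincide."""
+    split = "train"
+
+    def __init__(self, video_ids):
+        self.video_ids = video_ids
+
+    def __len__(self):
+        return len(self.video_ids)
+
+    def __getitem__(self, idx):
+        return self.video_ids[idx], self.video_ids[idx]
+
+
+class ToyVideoProcessor:
+    """Loads per-second video features for a video_id (random features here)."""
+    def __call__(self, video_id):
+        return torch.randn(16, 512)
+
+
+class ToyTextProcessor:
+    """Loads (pre-tokenized) text clips for a text_id."""
+    def __call__(self, text_id):
+        return {"start": [0.0], "end": [16.0], "cap": [[101, 102]]}
+
+
+class ToyAligner:
+    """Combines the video and text features into one training example."""
+    def __call__(self, video_id, video_feature, text_feature):
+        return {
+            "video_id": video_id,
+            "vfeats": video_feature,
+            "caps": torch.LongTensor(text_feature["cap"][0]),
+        }
+
+
+dataset = MMDataset(
+    ToyMetaProcessor(["vid0", "vid1"]),
+    ToyVideoProcessor(),
+    ToyTextProcessor(),
+    ToyAligner(),
+)
+print(dataset[0]["vfeats"].shape)  # torch.Size([16, 512])
+```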
diff --git a/fairseq/examples/MMPT/endtask.md b/fairseq/examples/MMPT/endtask.md new file mode 100644 index 0000000000000000000000000000000000000000..7690955327283ded4b37857a4a7b78463e0eb0f8 --- /dev/null +++ b/fairseq/examples/MMPT/endtask.md @@ -0,0 +1,41 @@ +# Zero-shot Transfer and Finetuning + +(If you are new to the ideas of `mmpt.processors`, see [README](README.md) first.) +All finetuning datasets (specifically `processors`) are defined in `mmpt.processors.dsprocessor`. +Given the complexity of different types of finetuning tasks, each task may have their own meta/video/text/aligner processors and `mmpt/evaluators/{Predictor,Metric}`. + +### Tasks + +Currently, we support 5 end datasets: `MSRVTT`, `Youcook`, `COIN`, `Crosstask` and `DiDeMo` with the following tasks: +text-video retrieval: `MSRVTT`, `Youcook`, `DiDeMo`; +video captioning: `Youcook`; +Video Question and Answering: `MSRVTT-QA`. + +To add your own dataset, you can specify the corresponding processors and config them in the `dataset` field of a config file, such as `projects/task/vtt.yaml`. + +### Zero-shot Transfer (no Training) +Zero-shot transfer will run the pre-trained model (e.g., VideoCLIP) directly on testing data. Configs with pattern: `projects/task/*_zs_*.yaml` are dedicated for zero-shot transfer. + +### Fine-tuning + +The training of a downstream task is similar to pretraining, execept you may need to specify the `restore_file` in `fairseq.checkpoint` and reset optimizers, see `projects/task/ft.yaml` that is included by `projects/task/vtt.yaml`. + +We typically do finetuning on 2 gpus (`local_small`). + +### Testing +For each finetuning dataset, you may need to specify a testing config, similar to `projects/task/test_vtt.yaml`. + +We define `mmpt.evaluators.Predictor` for different types of prediction. For example, `MSRVTT` and `Youcook` are video-retrieval tasks and expecting to use `RetrievalPredictor`. You may need to define your new type of predictors and specify that in `predictor` field of a testing config. + +Each task may also have their own metric for evaluation. This can be created in `mmpt.evaluators.Metric` and specified in the `metric` field of a testing config. + +Launching a testing is as simple as training by specifying the path of a testing config: +```python locallaunch.py projects/mfmmlm/test_vtt.yaml``` +Testing will be launched locally by default since prediction is computationally less expensive. + +### Third-party Libraries +We list the following finetuning tasks that require third-party libraries. + +Youcook captioning: `https://github.com/Maluuba/nlg-eval` + +CrossTask: `https://github.com/DmZhukov/CrossTask`'s `dp` under `third-party/CrossTask` (`python setup.py build_ext --inplace`) diff --git a/fairseq/examples/MMPT/mmpt/datasets/fairseqmmdataset.py b/fairseq/examples/MMPT/mmpt/datasets/fairseqmmdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..02c49141db69c44663bd438b947c268d06f8aa2b --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/datasets/fairseqmmdataset.py @@ -0,0 +1,57 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +TODO (huxu): fairseq wrapper class for all dataset you defined: mostly MMDataset. 
+""" + +from collections import OrderedDict + +from torch.utils.data import Dataset +from torch.utils.data.dataloader import default_collate +from fairseq.data import FairseqDataset, data_utils + + +class FairseqMMDataset(FairseqDataset): + """ + A wrapper class for MMDataset for fairseq. + """ + + def __init__(self, mmdataset): + if not isinstance(mmdataset, Dataset): + raise TypeError("mmdataset must be of type `torch.utils.data.dataset`.") + self.mmdataset = mmdataset + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + def __getitem__(self, idx): + with data_utils.numpy_seed(43211, self.epoch, idx): + return self.mmdataset[idx] + + def __len__(self): + return len(self.mmdataset) + + def collater(self, samples): + if hasattr(self.mmdataset, "collator"): + return self.mmdataset.collator(samples) + if len(samples) == 0: + return {} + if isinstance(samples[0], dict): + batch = OrderedDict() + for key in samples[0]: + if samples[0][key] is not None: + batch[key] = default_collate([sample[key] for sample in samples]) + return batch + else: + return default_collate(samples) + + def size(self, index): + """dummy implementation: we don't use --max-tokens""" + return 1 + + def num_tokens(self, index): + """dummy implementation: we don't use --max-tokens""" + return 1 diff --git a/fairseq/examples/MMPT/mmpt/datasets/mmdataset.py b/fairseq/examples/MMPT/mmpt/datasets/mmdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3d07283f917a430a8d9b1226c8fa6ab71450e8a9 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/datasets/mmdataset.py @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from collections import OrderedDict + +from torch.utils.data import Dataset +from torch.utils.data.dataloader import default_collate + +from ..utils import set_seed + + +class MMDataset(Dataset): + """ + A generic multi-modal dataset. + Args: + `meta_processor`: a meta processor, + handling loading meta data and return video_id and text_id. + `video_processor`: a video processor, + handling e.g., decoding, loading .np files. + `text_processor`: a text processor, + handling e.g., tokenization. + `aligner`: combine the video and text feature + as one training example. + """ + + def __init__( + self, + meta_processor, + video_processor, + text_processor, + align_processor, + ): + self.split = meta_processor.split + self.meta_processor = meta_processor + self.video_processor = video_processor + self.text_processor = text_processor + self.align_processor = align_processor + + def __len__(self): + return len(self.meta_processor) + + def __getitem__(self, idx): + if self.split == "test": + set_seed(idx) + video_id, text_id = self.meta_processor[idx] + video_feature = self.video_processor(video_id) + text_feature = self.text_processor(text_id) + output = self.align_processor(video_id, video_feature, text_feature) + # TODO (huxu): the following is for debug purpose. + output.update({"idx": idx}) + return output + + def collater(self, samples): + """This collator is deprecated. + set self.collator = MMDataset.collater. + see collator in FairseqMMDataset. 
+ """ + + if len(samples) == 0: + return {} + if isinstance(samples[0], dict): + batch = OrderedDict() + for key in samples[0]: + if samples[0][key] is not None: + batch[key] = default_collate( + [sample[key] for sample in samples]) + # if torch.is_tensor(batch[key]): + # print(key, batch[key].size()) + # else: + # print(key, len(batch[key])) + return batch + else: + return default_collate(samples) + + def print_example(self, output): + print("[one example]", output["video_id"]) + if ( + hasattr(self.align_processor, "subsampling") + and self.align_processor.subsampling is not None + and self.align_processor.subsampling > 1 + ): + for key in output: + if torch.is_tensor(output[key]): + output[key] = output[key][0] + + # search tokenizer to translate ids back. + tokenizer = None + if hasattr(self.text_processor, "tokenizer"): + tokenizer = self.text_processor.tokenizer + elif hasattr(self.align_processor, "tokenizer"): + tokenizer = self.align_processor.tokenizer + if tokenizer is not None: + caps = output["caps"].tolist() + if isinstance(caps[0], list): + caps = caps[0] + print("caps", tokenizer.decode(caps)) + print("caps", tokenizer.convert_ids_to_tokens(caps)) + + for key, value in output.items(): + if torch.is_tensor(value): + if len(value.size()) >= 3: # attention_mask. + print(key, value.size()) + print(key, "first", value[0, :, :]) + print(key, "last", value[-1, :, :]) + else: + print(key, value) + print("[end of one example]") diff --git a/fairseq/examples/MMPT/mmpt/evaluators/__init__.py b/fairseq/examples/MMPT/mmpt/evaluators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d06b9d7974db251143025124468f48cd230e89a --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/evaluators/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from .metric import * +from .evaluator import * + + +# experimental. +try: + from .expmetric import * +except ImportError: + pass diff --git a/fairseq/examples/MMPT/mmpt/evaluators/evaluator.py b/fairseq/examples/MMPT/mmpt/evaluators/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..94d9c5ec9a6e84434dbced8d5647754c5a571570 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/evaluators/evaluator.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import os +import glob +import numpy as np + +from . import metric as metric_path +from . import predictor as predictor_path + + +class Evaluator(object): + """ + perform evaluation on a single (downstream) task. + make this both offline and online. + TODO(huxu) saving evaluation results. 
+ """ + + def __init__(self, config, eval_dataloader=None): + if config.metric is None: + raise ValueError("config.metric is", config.metric) + metric_cls = getattr(metric_path, config.metric) + self.metric = metric_cls(config) + if config.predictor is None: + raise ValueError("config.predictor is", config.predictor) + predictor_cls = getattr(predictor_path, config.predictor) + self.predictor = predictor_cls(config) + self.eval_dataloader = eval_dataloader + + def __call__(self): + try: + print(self.predictor.pred_dir) + for pred_file in glob.glob( + self.predictor.pred_dir + "/*_merged.npy"): + outputs = np.load(pred_file) + results = self.metric.compute_metrics(outputs) + self.metric.print_computed_metrics(results) + + outputs = np.load(os.path.join( + self.predictor.pred_dir, "merged.npy")) + results = self.metric.compute_metrics(outputs) + return {"results": results, "metric": self.metric} + except FileNotFoundError: + print("\n[missing]", self.predictor.pred_dir) + return {} + + def evaluate(self, model, eval_dataloader=None, output_file="merged"): + if eval_dataloader is None: + eval_dataloader = self.eval_dataloader + outputs = self.predictor.predict_loop( + model, eval_dataloader, output_file) + results = self.metric.compute_metrics(**outputs) + return results diff --git a/fairseq/examples/MMPT/mmpt/models/transformermodel.py b/fairseq/examples/MMPT/mmpt/models/transformermodel.py new file mode 100644 index 0000000000000000000000000000000000000000..6acc419f09edbd5c9007c4d33d517d59b2f79b77 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/models/transformermodel.py @@ -0,0 +1,734 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Facebook, Inc. 
All Rights Reserved + +import torch + +from torch import nn + +try: + from transformers.modeling_bert import ( + BertPreTrainedModel, + BertModel, + BertEncoder, + BertPredictionHeadTransform, + ) +except ImportError: + pass + +from ..modules import VideoTokenMLP, MMBertEmbeddings + + +# --------------- fine-tuning models --------------- +class MMBertForJoint(BertPreTrainedModel): + """A BertModel with isolated attention mask to separate modality.""" + + def __init__(self, config): + super().__init__(config) + self.videomlp = VideoTokenMLP(config) + self.bert = MMBertModel(config) + self.init_weights() + + def forward( + self, + input_ids=None, + input_video_embeds=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + next_sentence_label=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + separate_forward_split=None, + ): + return_dict = ( + return_dict if return_dict is not None + else self.config.use_return_dict + ) + video_tokens = self.videomlp(input_video_embeds) + + outputs = self.bert( + input_ids, + video_tokens, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + separate_forward_split=separate_forward_split, + ) + + return outputs + + +class MMBertForTokenClassification(BertPreTrainedModel): + """A BertModel similar to MMJointUni, with extra wrapper layer + to be fine-tuned from other pretrained MMFusion model.""" + + def __init__(self, config): + super().__init__(config) + self.videomlp = VideoTokenMLP(config) + self.bert = MMBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # TODO(huxu): 779 is the number of classes for COIN: move to config? 
+ self.classifier = nn.Linear(config.hidden_size, 779) + self.init_weights() + + def forward( + self, + input_ids=None, + input_video_embeds=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + next_sentence_label=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + separate_forward_split=None, + ): + return_dict = ( + return_dict if return_dict is not None + else self.config.use_return_dict + ) + + video_tokens = self.videomlp(input_video_embeds) + outputs = self.bert( + input_ids, + video_tokens, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + separate_forward_split=separate_forward_split, + ) + + return (self.classifier(outputs[0]),) + + +# ------------ pre-training models ---------------- + +class MMBertForEncoder(BertPreTrainedModel): + """A BertModel for Contrastive Learning.""" + def __init__(self, config): + super().__init__(config) + self.videomlp = VideoTokenMLP(config) + self.bert = MMBertModel(config) + self.init_weights() + + def forward( + self, + input_ids=None, + input_video_embeds=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = ( + return_dict if return_dict is not None + else self.config.use_return_dict + ) + if input_video_embeds is not None: + video_tokens = self.videomlp(input_video_embeds) + else: + video_tokens = None + + outputs = self.bert( + input_ids, + video_tokens, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + return outputs + + +class MMBertForMFMMLM(BertPreTrainedModel): + """A BertModel with shared prediction head on MFM-MLM.""" + def __init__(self, config): + super().__init__(config) + self.videomlp = VideoTokenMLP(config) + self.bert = MMBertModel(config) + self.cls = MFMMLMHead(config) + self.hidden_size = config.hidden_size + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def forward( + self, + input_ids=None, + input_video_embeds=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + masked_frame_labels=None, + target_video_hidden_states=None, + non_masked_frame_mask=None, + masked_lm_labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = ( + return_dict if return_dict is not None + else self.config.use_return_dict + ) + if input_video_embeds is not None: + video_tokens = self.videomlp(input_video_embeds) + else: + video_tokens = None + + if target_video_hidden_states is not None: + target_video_hidden_states = self.videomlp( + target_video_hidden_states) + + non_masked_frame_hidden_states = video_tokens.masked_select( + non_masked_frame_mask.unsqueeze(-1) + ).view(-1, self.hidden_size) + + outputs = self.bert( + input_ids, + video_tokens, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + 
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + mfm_scores, prediction_scores = None, None + if masked_frame_labels is not None and masked_lm_labels is not None: + # split the sequence. + text_offset = masked_frame_labels.size(1) + 1 # [CLS] + video_sequence_output = sequence_output[ + :, 1:text_offset + ] # remove [SEP] as not in video_label. + text_sequence_output = torch.cat( + [sequence_output[:, :1], sequence_output[:, text_offset:]], + dim=1 + ) + + hidden_size = video_sequence_output.size(-1) + selected_video_output = video_sequence_output.masked_select( + masked_frame_labels.unsqueeze(-1) + ).view(-1, hidden_size) + + # only compute select tokens to training to speed up. + hidden_size = text_sequence_output.size(-1) + # masked_lm_labels = masked_lm_labels.reshape(-1) + labels_mask = masked_lm_labels != -100 + + selected_text_output = text_sequence_output.masked_select( + labels_mask.unsqueeze(-1) + ).view(-1, hidden_size) + mfm_scores, prediction_scores = self.cls( + selected_video_output, + target_video_hidden_states, + non_masked_frame_hidden_states, + selected_text_output, + ) + + output = ( + mfm_scores, + prediction_scores, + ) + outputs + return output + + +class BertMFMMLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly + # resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward( + self, + video_hidden_states=None, + target_video_hidden_states=None, + non_masked_frame_hidden_states=None, + text_hidden_states=None, + ): + video_logits, text_logits = None, None + if video_hidden_states is not None: + video_hidden_states = self.transform(video_hidden_states) + non_masked_frame_logits = torch.mm( + video_hidden_states, + non_masked_frame_hidden_states.transpose(1, 0) + ) + masked_frame_logits = torch.bmm( + video_hidden_states.unsqueeze(1), + target_video_hidden_states.unsqueeze(-1), + ).squeeze(-1) + video_logits = torch.cat( + [masked_frame_logits, non_masked_frame_logits], dim=1 + ) + + if text_hidden_states is not None: + text_hidden_states = self.transform(text_hidden_states) + text_logits = self.decoder(text_hidden_states) + return video_logits, text_logits + + +class MFMMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertMFMMLMPredictionHead(config) + + def forward( + self, + video_hidden_states=None, + target_video_hidden_states=None, + non_masked_frame_hidden_states=None, + text_hidden_states=None, + ): + video_logits, text_logits = self.predictions( + video_hidden_states, + target_video_hidden_states, + non_masked_frame_hidden_states, + text_hidden_states, + ) + return video_logits, text_logits + + +class MMBertForMTM(MMBertForMFMMLM): + def __init__(self, config): + BertPreTrainedModel.__init__(self, config) + self.videomlp = VideoTokenMLP(config) + self.bert = MMBertModel(config) + self.cls = MTMHead(config) + self.hidden_size = config.hidden_size + self.init_weights() + + +class BertMTMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = 
BertPredictionHeadTransform(config) + self.decoder = nn.Linear( + config.hidden_size, config.vocab_size, bias=False) + + def forward( + self, + video_hidden_states=None, + target_video_hidden_states=None, + non_masked_frame_hidden_states=None, + text_hidden_states=None, + ): + non_masked_frame_hidden_states = non_masked_frame_hidden_states.transpose(1, 0) + video_logits, text_logits = None, None + if video_hidden_states is not None: + video_hidden_states = self.transform(video_hidden_states) + + masked_frame_logits = torch.bmm( + video_hidden_states.unsqueeze(1), + target_video_hidden_states.unsqueeze(-1), + ).squeeze(-1) + + non_masked_frame_logits = torch.mm( + video_hidden_states, + non_masked_frame_hidden_states + ) + video_on_vocab_logits = self.decoder(video_hidden_states) + video_logits = torch.cat([ + masked_frame_logits, + non_masked_frame_logits, + video_on_vocab_logits], dim=1) + + if text_hidden_states is not None: + text_hidden_states = self.transform(text_hidden_states) + # text first so label does not need to be shifted. + text_on_vocab_logits = self.decoder(text_hidden_states) + text_on_video_logits = torch.mm( + text_hidden_states, + non_masked_frame_hidden_states + ) + text_logits = torch.cat([ + text_on_vocab_logits, + text_on_video_logits + ], dim=1) + + return video_logits, text_logits + + +class MTMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertMTMPredictionHead(config) + + def forward( + self, + video_hidden_states=None, + target_video_hidden_states=None, + non_masked_frame_hidden_states=None, + text_hidden_states=None, + ): + video_logits, text_logits = self.predictions( + video_hidden_states, + target_video_hidden_states, + non_masked_frame_hidden_states, + text_hidden_states, + ) + return video_logits, text_logits + + +class MMBertModel(BertModel): + """MMBertModel has MMBertEmbedding to support video tokens.""" + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + # overwrite embedding + self.embeddings = MMBertEmbeddings(config) + self.encoder = MultiLayerAttentionMaskBertEncoder(config) + self.init_weights() + + def forward( + self, + input_ids=None, + input_video_embeds=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + separate_forward_split=None, + ): + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None + else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids " + "and inputs_embeds at the same time" + ) + elif input_ids is not None: + if input_video_embeds is not None: + input_shape = ( + input_ids.size(0), + input_ids.size(1) + input_video_embeds.size(1), + ) + else: + input_shape = ( + input_ids.size(0), + input_ids.size(1), + ) + elif inputs_embeds is not None: + if input_video_embeds is not None: + input_shape = ( + inputs_embeds.size(0), + inputs_embeds.size(1) + input_video_embeds.size(1), + ) + else: + input_shape = ( + input_ids.size(0), + input_ids.size(1), + ) + else: + raise ValueError( + "You have to specify either input_ids 
or inputs_embeds") + + device = input_ids.device if input_ids is not None \ + else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions + # [batch_size, from_seq_length, to_seq_length] + # ourselves in which case + # we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = \ + self.get_extended_attention_mask( + attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to + # [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = ( + encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or + # [num_hidden_layers x num_heads] + # and head_mask is converted to shape + # [num_hidden_layers x batch x num_heads x seq_length x seq_length] + + head_mask = self.get_head_mask( + head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids, + input_video_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + + if separate_forward_split is not None: + split_embedding_output = \ + embedding_output[:, :separate_forward_split] + split_extended_attention_mask = extended_attention_mask[ + :, :, :, :separate_forward_split, :separate_forward_split + ] + split_encoder_outputs = self.encoder( + split_embedding_output, + attention_mask=split_extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + assert ( + len(split_encoder_outputs) <= 2 + ), "we do not support merge on attention for now." + encoder_outputs = [] + encoder_outputs.append([split_encoder_outputs[0]]) + if len(split_encoder_outputs) == 2: + encoder_outputs.append([]) + for _all_hidden_states in split_encoder_outputs[1]: + encoder_outputs[-1].append([_all_hidden_states]) + + split_embedding_output = \ + embedding_output[:, separate_forward_split:] + split_extended_attention_mask = extended_attention_mask[ + :, :, :, separate_forward_split:, separate_forward_split: + ] + + split_encoder_outputs = self.encoder( + split_embedding_output, + attention_mask=split_extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + assert ( + len(split_encoder_outputs) <= 2 + ), "we do not support merge on attention for now." 
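+            # Stitch the two halves back together: concatenate the final hidden
+            # states (and, if returned, each layer's hidden states) along the
+            # sequence dimension.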
+ encoder_outputs[0].append(split_encoder_outputs[0]) + encoder_outputs[0] = torch.cat(encoder_outputs[0], dim=1) + if len(split_encoder_outputs) == 2: + for layer_idx, _all_hidden_states in enumerate( + split_encoder_outputs[1] + ): + encoder_outputs[1][layer_idx].append(_all_hidden_states) + encoder_outputs[1][layer_idx] = torch.cat( + encoder_outputs[1][layer_idx], dim=1 + ) + encoder_outputs = tuple(encoder_outputs) + else: + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + return (sequence_output, pooled_output) + encoder_outputs[1:] + + def get_extended_attention_mask(self, attention_mask, input_shape, device): + """This is borrowed from `modeling_utils.py` with the support of + multi-layer attention masks. + The second dim is expected to be number of layers. + See `MMAttentionMaskProcessor`. + Makes broadcastable attention and causal masks so that future + and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, + zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, \ + with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions + # [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable + # to all heads. 
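+        # a 4-D mask is interpreted as a per-layer mask of shape
+        # (batch_size, num_layers, from_seq_len, to_seq_len); a singleton head
+        # dimension is inserted below so MultiLayerAttentionMaskBertEncoder can
+        # slice one broadcastable (batch_size, 1, from_seq_len, to_seq_len)
+        # mask per layer.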
+ if attention_mask.dim() == 4: + extended_attention_mask = attention_mask[:, :, None, :, :] + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) \ + * -10000.0 + return extended_attention_mask + else: + return super().get_extended_attention_mask( + attention_mask, input_shape, device + ) + + +class MultiLayerAttentionMaskBertEncoder(BertEncoder): + """extend BertEncoder with the capability of + multiple layers of attention mask.""" + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + ): + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_attention_mask = ( + attention_mask[:, i, :, :, :] + if attention_mask.dim() == 5 + else attention_mask + ) + + if getattr(self.config, "gradient_checkpointing", False): + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + layer_attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + layer_attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return tuple( + v + for v in [hidden_states, all_hidden_states, all_attentions] + if v is not None + ) diff --git a/fairseq/examples/MMPT/mmpt/modules/__init__.py b/fairseq/examples/MMPT/mmpt/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c78594c21d7ccdcb2ab11c918dd211adce27c14 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/modules/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from .mm import * + +try: + from .expmm import * +except ImportError: + pass diff --git a/fairseq/examples/MMPT/mmpt/modules/mm.py b/fairseq/examples/MMPT/mmpt/modules/mm.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9777371a5e7afc235a8b4e6c164b2a35258eb6 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/modules/mm.py @@ -0,0 +1,145 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Facebook, Inc. All Rights Reserved + + +import torch + +from torch import nn + +try: + from transformers.modeling_bert import ( + BertEmbeddings, + ACT2FN, + ) +except ImportError: + pass + + +class VideoTokenMLP(nn.Module): + def __init__(self, config): + super().__init__() + input_dim = config.input_dim if hasattr(config, "input_dim") else 512 + self.linear1 = nn.Linear(input_dim, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size) + self.activation = ACT2FN[config.hidden_act] + self.linear2 = nn.Linear(config.hidden_size, config.hidden_size) + + def forward(self, hidden_states): + hidden_states = self.linear1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.linear2(hidden_states) + return hidden_states + + +class MMBertEmbeddings(BertEmbeddings): + def __init__(self, config): + super().__init__(config) + self.max_video_len = config.max_video_len + if hasattr(config, "use_seg_emb") and config.use_seg_emb: + """the original VLM paper uses seg_embeddings for temporal space. + although not used it changed the randomness of initialization. + we keep it for reproducibility. + """ + self.seg_embeddings = nn.Embedding(256, config.hidden_size) + + def forward( + self, + input_ids, + input_video_embeds, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + ): + input_tensor = input_ids if input_ids is not None else inputs_embeds + if input_video_embeds is not None: + input_shape = ( + input_tensor.size(0), + input_tensor.size(1) + input_video_embeds.size(1), + ) + else: + input_shape = (input_tensor.size(0), input_tensor.size(1)) + + if position_ids is None: + """ + Auto skip position embeddings for text only case. + use cases: + (1) action localization and segmentation: + feed in len-1 dummy video token needs text part to + skip input_video_embeds.size(1) for the right + position_ids for video [SEP] and rest text tokens. + (2) MMFusionShare for two forward passings: + in `forward_text`: input_video_embeds is None. + need to skip video [SEP] token. + + # video_len + 1: [CLS] + video_embed + # self.max_video_len + 1: [SEP] for video. + # self.max_video_len + 2: [SEP] for video. + # self.max_video_len + input_ids.size(1): rest for text. + """ + if input_video_embeds is not None: + video_len = input_video_embeds.size(1) + starting_offset = self.max_video_len + 1 # video [SEP] + ending_offset = self.max_video_len + input_ids.size(1) + else: + video_len = 0 + starting_offset = self.max_video_len + 2 # first text token. + ending_offset = self.max_video_len + input_ids.size(1) + 1 + position_ids = torch.cat([ + self.position_ids[:, :video_len + 1], + self.position_ids[:, starting_offset:ending_offset] + ], dim=1) + + if token_type_ids is None: + token_type_ids = torch.zeros( + input_shape, dtype=torch.long, device=self.position_ids.device + ) + + """ + the format of input_ids is [CLS] [SEP] caption [SEP] padding. + the goal is to build [CLS] video tokens [SEP] caption [SEP] . 
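+        an illustrative, token-level example of the splice performed below:
+            input_ids:          [CLS] [SEP] w1 w2 [SEP] [PAD]
+            input_video_embeds:       v1 v2 v3
+            result:             [CLS] v1 v2 v3 [SEP] w1 w2 [SEP] [PAD]
+        i.e. the video embeddings are inserted right after the [CLS] embedding.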
+ """ + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + if input_video_embeds is not None: + inputs_mm_embeds = torch.cat([ + inputs_embeds[:, :1], input_video_embeds, inputs_embeds[:, 1:] + ], dim=1) + else: + # text only for `MMFusionShare`. + inputs_mm_embeds = inputs_embeds + + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = inputs_mm_embeds + position_embeddings + embeddings += token_type_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class AlignHead(nn.Module): + """this will load pre-trained weights for NSP, which is desirable.""" + + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, dropout_pooled_output): + logits = self.seq_relationship(dropout_pooled_output) + return logits diff --git a/fairseq/examples/MMPT/mmpt/modules/retri.py b/fairseq/examples/MMPT/mmpt/modules/retri.py new file mode 100644 index 0000000000000000000000000000000000000000..d1b288f8e52308eeaefcb4411cb3d23e543f37d0 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/modules/retri.py @@ -0,0 +1,429 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import os +import numpy as np +import pickle +import time + +try: + import faiss +except ImportError: + pass + +from collections import defaultdict + +from ..utils import get_local_rank, print_on_rank0 + + +class VectorRetriever(object): + """ + How2 Video Retriver. + Reference usage of FAISS: + https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py + """ + + def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train): + if db_type == "flatl2": + quantizer = faiss.IndexFlatL2(hidden_size) # the other index + self.db = faiss.IndexIVFFlat( + quantizer, hidden_size, cent, faiss.METRIC_L2) + elif db_type == "pq": + self.db = faiss.index_factory( + hidden_size, f"IVF{cent}_HNSW32,PQ32" + ) + else: + raise ValueError("unknown type of db", db_type) + self.train_thres = cent * examples_per_cent_to_train + self.train_cache = [] + self.train_len = 0 + self.videoid_to_vectoridx = {} + self.vectoridx_to_videoid = None + self.make_direct_maps_done = False + + def make_direct_maps(self): + faiss.downcast_index(self.db).make_direct_map() + + def __len__(self): + return self.db.ntotal + + def save(self, out_dir): + faiss.write_index( + self.db, + os.path.join(out_dir, "faiss_idx") + ) + with open( + os.path.join( + out_dir, "videoid_to_vectoridx.pkl"), + "wb") as fw: + pickle.dump( + self.videoid_to_vectoridx, fw, + protocol=pickle.HIGHEST_PROTOCOL + ) + + def load(self, out_dir): + fn = os.path.join(out_dir, "faiss_idx") + self.db = faiss.read_index(fn) + with open( + os.path.join(out_dir, "videoid_to_vectoridx.pkl"), "rb") as fr: + self.videoid_to_vectoridx = pickle.load(fr) + + def add(self, hidden_states, video_ids, last=False): + assert len(hidden_states) == len(video_ids), "{}, {}".format( + str(len(hidden_states)), str(len(video_ids))) + assert len(hidden_states.shape) == 2 + assert hidden_states.dtype == np.float32 + + valid_idx = [] + for idx, video_id in enumerate(video_ids): + if video_id not in self.videoid_to_vectoridx: + valid_idx.append(idx) + self.videoid_to_vectoridx[video_id] = \ + 
len(self.videoid_to_vectoridx) + + hidden_states = hidden_states[valid_idx] + if not self.db.is_trained: + self.train_cache.append(hidden_states) + self.train_len += hidden_states.shape[0] + if self.train_len < self.train_thres: + return + self.finalize_training() + else: + self.db.add(hidden_states) + + def finalize_training(self): + hidden_states = np.concatenate(self.train_cache, axis=0) + del self.train_cache + local_rank = get_local_rank() + if local_rank == 0: + start = time.time() + print("training db on", self.train_thres, "/", self.train_len) + self.db.train(hidden_states[:self.train_thres]) + if local_rank == 0: + print("training db for", time.time() - start) + self.db.add(hidden_states) + + def search( + self, + query_hidden_states, + orig_dist, + ): + if len(self.videoid_to_vectoridx) != self.db.ntotal: + raise ValueError( + "cannot search: size mismatch in-between index and db", + len(self.videoid_to_vectoridx), + self.db.ntotal + ) + + if self.vectoridx_to_videoid is None: + self.vectoridx_to_videoid = { + self.videoid_to_vectoridx[videoid]: videoid + for videoid in self.videoid_to_vectoridx + } + assert len(self.vectoridx_to_videoid) \ + == len(self.videoid_to_vectoridx) + + # MultilingualFaissDataset uses the following; not sure the purpose. + # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10) + queried_dist, index = self.db.search(query_hidden_states, 1) + queried_dist, index = queried_dist[:, 0], index[:, 0] + + outputs = np.array( + [self.vectoridx_to_videoid[_index] + if _index != -1 else (-1, -1, -1) for _index in index], + dtype=np.int32) + outputs[queried_dist <= orig_dist] = -1 + return outputs + + def search_by_video_ids( + self, + video_ids, + retri_factor + ): + if len(self.videoid_to_vectoridx) != self.db.ntotal: + raise ValueError( + len(self.videoid_to_vectoridx), + self.db.ntotal + ) + + if not self.make_direct_maps_done: + self.make_direct_maps() + + if self.vectoridx_to_videoid is None: + self.vectoridx_to_videoid = { + self.videoid_to_vectoridx[videoid]: videoid + for videoid in self.videoid_to_vectoridx + } + assert len(self.vectoridx_to_videoid) \ + == len(self.videoid_to_vectoridx) + + query_hidden_states = [] + vector_ids = [] + for video_id in video_ids: + vector_id = self.videoid_to_vectoridx[video_id] + vector_ids.append(vector_id) + query_hidden_state = self.db.reconstruct(vector_id) + query_hidden_states.append(query_hidden_state) + query_hidden_states = np.stack(query_hidden_states) + + # MultilingualFaissDataset uses the following; not sure the reason. + # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10) + _, index = self.db.search(query_hidden_states, retri_factor) + outputs = [] + for sample_idx, sample in enumerate(index): + # the first video_id is always the video itself. + cands = [video_ids[sample_idx]] + for vector_idx in sample: + if vector_idx >= 0 \ + and vector_ids[sample_idx] != vector_idx: + cands.append( + self.vectoridx_to_videoid[vector_idx] + ) + outputs.append(cands) + return outputs + + +class VectorRetrieverDM(VectorRetriever): + """ + with direct map. + How2 Video Retriver. 
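+    the direct map is built lazily via `make_direct_maps()` so that
+    `self.db.reconstruct(vector_id)`, needed by `search_by_video_ids`,
+    is available on the underlying IVF index.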
+ Reference usage of FAISS: + https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py + """ + + def __init__( + self, + hidden_size, + cent, + db_type, + examples_per_cent_to_train + ): + super().__init__( + hidden_size, cent, db_type, examples_per_cent_to_train) + self.make_direct_maps_done = False + + def make_direct_maps(self): + faiss.downcast_index(self.db).make_direct_map() + self.make_direct_maps_done = True + + def search( + self, + query_hidden_states, + orig_dist, + ): + if len(self.videoid_to_vectoridx) != self.db.ntotal: + raise ValueError( + len(self.videoid_to_vectoridx), + self.db.ntotal + ) + + if not self.make_direct_maps_done: + self.make_direct_maps() + if self.vectoridx_to_videoid is None: + self.vectoridx_to_videoid = { + self.videoid_to_vectoridx[videoid]: videoid + for videoid in self.videoid_to_vectoridx + } + assert len(self.vectoridx_to_videoid) \ + == len(self.videoid_to_vectoridx) + + # MultilingualFaissDataset uses the following; not sure the reason. + # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10) + queried_dist, index = self.db.search(query_hidden_states, 1) + outputs = [] + for sample_idx, sample in enumerate(index): + # and queried_dist[sample_idx] < thres \ + if sample >= 0 \ + and queried_dist[sample_idx] < orig_dist[sample_idx]: + outputs.append(self.vectoridx_to_videoid[sample]) + else: + outputs.append(None) + return outputs + + def search_by_video_ids( + self, + video_ids, + retri_factor=8 + ): + if len(self.videoid_to_vectoridx) != self.db.ntotal: + raise ValueError( + len(self.videoid_to_vectoridx), + self.db.ntotal + ) + + if not self.make_direct_maps_done: + self.make_direct_maps() + if self.vectoridx_to_videoid is None: + self.vectoridx_to_videoid = { + self.videoid_to_vectoridx[videoid]: videoid + for videoid in self.videoid_to_vectoridx + } + assert len(self.vectoridx_to_videoid) \ + == len(self.videoid_to_vectoridx) + + query_hidden_states = [] + vector_ids = [] + for video_id in video_ids: + vector_id = self.videoid_to_vectoridx[video_id] + vector_ids.append(vector_id) + query_hidden_state = self.db.reconstruct(vector_id) + query_hidden_states.append(query_hidden_state) + query_hidden_states = np.stack(query_hidden_states) + + # MultilingualFaissDataset uses the following; not sure the reason. + # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10) + _, index = self.db.search(query_hidden_states, retri_factor) + outputs = [] + for sample_idx, sample in enumerate(index): + # the first video_id is always the video itself. + cands = [video_ids[sample_idx]] + for vector_idx in sample: + if vector_idx >= 0 \ + and vector_ids[sample_idx] != vector_idx: + cands.append( + self.vectoridx_to_videoid[vector_idx] + ) + outputs.append(cands) + return outputs + + +class MMVectorRetriever(VectorRetrieverDM): + """ + multimodal vector retriver: + text retrieve video or video retrieve text. 
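+    internally one FAISS index is kept per modality, and `add()` expects
+    paired hidden states of shape (batch, 2, hidden_size) with the video
+    vector at index 0 and the text vector at index 1 (see `MMClipVectorPool`).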
+ """ + + def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train): + super().__init__( + hidden_size, cent, db_type, examples_per_cent_to_train) + video_db = self.db + super().__init__( + hidden_size, cent, db_type, examples_per_cent_to_train) + text_db = self.db + self.db = {"video": video_db, "text": text_db} + self.video_to_videoid = defaultdict(list) + + def __len__(self): + assert self.db["video"].ntotal == self.db["text"].ntotal + return self.db["video"].ntotal + + def make_direct_maps(self): + faiss.downcast_index(self.db["video"]).make_direct_map() + faiss.downcast_index(self.db["text"]).make_direct_map() + + def save(self, out_dir): + faiss.write_index( + self.db["video"], + os.path.join(out_dir, "video_faiss_idx") + ) + faiss.write_index( + self.db["text"], + os.path.join(out_dir, "text_faiss_idx") + ) + + with open( + os.path.join( + out_dir, "videoid_to_vectoridx.pkl"), + "wb") as fw: + pickle.dump( + self.videoid_to_vectoridx, fw, + protocol=pickle.HIGHEST_PROTOCOL + ) + + def load(self, out_dir): + fn = os.path.join(out_dir, "video_faiss_idx") + video_db = faiss.read_index(fn) + fn = os.path.join(out_dir, "text_faiss_idx") + text_db = faiss.read_index(fn) + self.db = {"video": video_db, "text": text_db} + with open( + os.path.join(out_dir, "videoid_to_vectoridx.pkl"), "rb") as fr: + self.videoid_to_vectoridx = pickle.load(fr) + self.video_to_videoid = defaultdict(list) + + def add(self, hidden_states, video_ids): + """hidden_states is a pair `(video, text)`""" + assert len(hidden_states) == len(video_ids), "{}, {}".format( + str(len(hidden_states)), str(len(video_ids))) + assert len(hidden_states.shape) == 3 + assert len(self.video_to_videoid) == 0 + + valid_idx = [] + for idx, video_id in enumerate(video_ids): + if video_id not in self.videoid_to_vectoridx: + valid_idx.append(idx) + self.videoid_to_vectoridx[video_id] = \ + len(self.videoid_to_vectoridx) + + batch_size = hidden_states.shape[0] + hidden_states = hidden_states[valid_idx] + + hidden_states = np.transpose(hidden_states, (1, 0, 2)).copy() + if not self.db["video"].is_trained: + self.train_cache.append(hidden_states) + train_len = batch_size * len(self.train_cache) + if train_len < self.train_thres: + return + + hidden_states = np.concatenate(self.train_cache, axis=1) + del self.train_cache + self.db["video"].train(hidden_states[0, :self.train_thres]) + self.db["text"].train(hidden_states[1, :self.train_thres]) + self.db["video"].add(hidden_states[0]) + self.db["text"].add(hidden_states[1]) + + def get_clips_by_video_id(self, video_id): + if not self.video_to_videoid: + for video_id, video_clip, text_clip in self.videoid_to_vectoridx: + self.video_to_videoid[video_id].append( + (video_id, video_clip, text_clip)) + return self.video_to_videoid[video_id] + + def search( + self, + video_ids, + target_modality, + retri_factor=8 + ): + if len(self.videoid_to_vectoridx) != len(self): + raise ValueError( + len(self.videoid_to_vectoridx), + len(self) + ) + + if not self.make_direct_maps_done: + self.make_direct_maps() + if self.vectoridx_to_videoid is None: + self.vectoridx_to_videoid = { + self.videoid_to_vectoridx[videoid]: videoid + for videoid in self.videoid_to_vectoridx + } + assert len(self.vectoridx_to_videoid) \ + == len(self.videoid_to_vectoridx) + + src_modality = "text" if target_modality == "video" else "video" + + query_hidden_states = [] + vector_ids = [] + for video_id in video_ids: + vector_id = self.videoid_to_vectoridx[video_id] + vector_ids.append(vector_id) + query_hidden_state = 
self.db[src_modality].reconstruct(vector_id) + query_hidden_states.append(query_hidden_state) + query_hidden_states = np.stack(query_hidden_states) + + # MultilingualFaissDataset uses the following; not sure the reason. + # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10) + _, index = self.db[target_modality].search( + query_hidden_states, retri_factor) + outputs = [] + for sample_idx, sample in enumerate(index): + cands = [] + for vector_idx in sample: + if vector_idx >= 0: + cands.append( + self.vectoridx_to_videoid[vector_idx] + ) + outputs.append(cands) + return outputs diff --git a/fairseq/examples/MMPT/mmpt/modules/vectorpool.py b/fairseq/examples/MMPT/mmpt/modules/vectorpool.py new file mode 100644 index 0000000000000000000000000000000000000000..d2b23d2da888c714d7605d508c558fbad34e709d --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/modules/vectorpool.py @@ -0,0 +1,246 @@ +# Copyright (c) Facebook, Inc. All Rights Reserved + +import torch +import os +import numpy as np +import pickle + +from . import retri +from ..utils import get_local_rank + + +class VectorPool(object): + """ + Base class of retrieval space. + """ + + def __init__(self, config): + from transformers import AutoConfig + self.hidden_size = AutoConfig.from_pretrained( + config.dataset.bert_name).hidden_size + self.retriever_cls = getattr(retri, config.retriever_cls) + + def __call__(self, sample, **kwargs): + raise NotImplementedError + + def build_retriver( + self, + retriever_cls=None, + hidden_size=None, + centroids=512, + db_type="flatl2", + examples_per_cent_to_train=48 + ): + + """merge results from multiple gpus and return a retriver..""" + self.retriver = retriever_cls( + hidden_size, centroids, db_type, examples_per_cent_to_train) + return self.retriver + + def __repr__(self): + if hasattr(self, "retriver"): + retriver_name = str(len(self.retriver)) + else: + retriver_name = "no retriver field yet" + return self.__class__.__name__ \ + + "(" + retriver_name + ")" + + +class VideoVectorPool(VectorPool): + """ + average clips of a video as video representation. + """ + def __init__(self, config): + super().__init__(config) + self.build_retriver(self.retriever_cls, self.hidden_size) + + def __call__(self, sample, subsampling, **kwargs): + hidden_states = ( + sample["pooled_video"] + sample["pooled_text"]) / 2. + hidden_states = hidden_states.view( + -1, subsampling, + hidden_states.size(-1)) + hidden_states = torch.mean(hidden_states, dim=1) + hidden_states = hidden_states.cpu().detach().numpy() + video_ids = [] + for offset_idx, video_id in enumerate(sample["video_id"]): + if isinstance(video_id, tuple) and len(video_id) == 3: + # a sharded video_id. + video_id = video_id[0] + video_ids.append(video_id) + assert len(video_ids) == len(hidden_states) + self.retriver.add( + hidden_states.astype("float32"), + video_ids + ) + + +class DistributedVectorPool(VectorPool): + """ + support sync of multiple gpus/nodes. 
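+    each rank accumulates (hidden_state, video_id) pairs locally; `save()`
+    writes them to per-rank files under `<save_dir>/retri`, and
+    `build_retriver()` saves, barriers, then loads every rank's shard so
+    each process ends up with a full copy of the index.
+    rough usage sketch (the config and sample loop are placeholders):
+        pool = DistributedVideoVectorPool(config)
+        for sample in samples:
+            pool(sample, subsampling)
+        retriver = pool.build_retriver()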
+ """ + def __init__(self, config): + super().__init__(config) + self.out_dir = os.path.join( + config.fairseq.checkpoint.save_dir, + "retri") + os.makedirs(self.out_dir, exist_ok=True) + self.hidden_states = [] + self.video_ids = [] + + def build_retriver( + self, + retriever_cls=None, + hidden_size=None, + centroids=4096, + db_type="flatl2", + examples_per_cent_to_train=48 + ): + if retriever_cls is None: + retriever_cls = self.retriever_cls + if hidden_size is None: + hidden_size = self.hidden_size + """merge results from multiple gpus and return a retriver..""" + if torch.distributed.is_initialized(): + self.save() + # sync saving. + torch.distributed.barrier() + world_size = torch.distributed.get_world_size() + else: + world_size = 1 + self.retriver = retriever_cls( + hidden_size, centroids, db_type, examples_per_cent_to_train) + # each gpu process has its own retriever. + for local_rank in range(world_size): + if get_local_rank() == 0: + print("load local_rank", local_rank) + hidden_states, video_ids = self.load(local_rank) + hidden_states = hidden_states.astype("float32") + self.retriver.add(hidden_states, video_ids) + return self.retriver + + def load(self, local_rank): + hidden_states = np.load( + os.path.join( + self.out_dir, + "hidden_state" + str(local_rank) + ".npy" + ) + ) + + with open( + os.path.join( + self.out_dir, "video_id" + str(local_rank) + ".pkl"), + "rb") as fr: + video_ids = pickle.load(fr) + return hidden_states, video_ids + + def save(self): + hidden_states = np.vstack(self.hidden_states) + assert len(hidden_states) == len(self.video_ids), "{}, {}".format( + len(hidden_states), + len(self.video_ids) + ) + local_rank = torch.distributed.get_rank() \ + if torch.distributed.is_initialized() else 0 + + np.save( + os.path.join( + self.out_dir, + "hidden_state" + str(local_rank) + ".npy"), + hidden_states) + + with open( + os.path.join( + self.out_dir, + "video_id" + str(local_rank) + ".pkl"), + "wb") as fw: + pickle.dump( + self.video_ids, + fw, + protocol=pickle.HIGHEST_PROTOCOL + ) + + +class DistributedVideoVectorPool(DistributedVectorPool): + """ + average clips of a video as video representation. + """ + def __call__(self, sample, subsampling, **kwargs): + hidden_states = ( + sample["pooled_video"] + sample["pooled_text"]) / 2. + hidden_states = hidden_states.view( + -1, subsampling, + hidden_states.size(-1)) + hidden_states = torch.mean(hidden_states, dim=1) + hidden_states = hidden_states.cpu().detach().numpy() + video_ids = [] + for offset_idx, video_id in enumerate(sample["video_id"]): + if isinstance(video_id, tuple) and len(video_id) == 3: + # a sharded video_id. + video_id = video_id[0] + video_ids.append(video_id) + assert len(video_ids) == len(hidden_states) + self.hidden_states.append(hidden_states) + self.video_ids.extend(video_ids) + + +# ------------ the following are deprecated -------------- + +class TextClipVectorPool(VectorPool): + def __init__(self, config): + from transformers import AutoConfig + hidden_size = AutoConfig.from_pretrained( + config.dataset.bert_name).hidden_size + retriever_cls = getattr(retri, config.retriever_cls) + self.build_retriver(retriever_cls, hidden_size) + + def __call__(self, sample, **kwargs): + clip_meta = sample["clip_meta"].cpu() + assert torch.all(torch.le(clip_meta[:, 4], clip_meta[:, 5])) + text_meta = [tuple(item.tolist()) for item in clip_meta[:, 3:]] + + if hasattr(self, "retriver"): + # build_retriver is called. 
+ self.retriver.add( + sample["pooled_text"].cpu().numpy().astype("float32"), + text_meta + ) + else: + raise NotImplementedError + + +class MMClipVectorPool(VectorPool): + """ + Multimodal Clip-level vector pool. + """ + def __init__(self, out_dir): + """use hidden_states to store `(video, text)`.""" + """use video_ids to store `(video_id, start, end)`.""" + super().__init__(out_dir) + + def __call__(self, sample, **kwargs): + pooled_video = sample["pooled_video"].cpu().unsqueeze(1).numpy() + pooled_text = sample["pooled_text"].cpu().unsqueeze(1).numpy() + + self.hidden_states.append( + np.concatenate([pooled_video, pooled_text], axis=1) + ) + + video_starts = sample["video_start"].cpu() + video_ends = sample["video_end"].cpu() + assert torch.all(torch.le(video_starts, video_ends)) + + text_starts = sample["text_start"].cpu() + text_ends = sample["text_end"].cpu() + assert torch.all(torch.le(text_starts, text_ends)) + subsample_size = sample["pooled_video"].size(0) // len(sample["video_id"]) + video_ids = [video_id for video_id in sample["video_id"] + for _ in range(subsample_size) + ] + for video_id, video_start, video_end, text_start, text_end in zip( + video_ids, video_starts, video_ends, text_starts, text_ends): + self.video_ids.append(( + video_id, + (int(video_start), int(video_end)), + (int(text_start), int(text_end)) + )) diff --git a/fairseq/examples/MMPT/mmpt/processors/dedupprocessor.py b/fairseq/examples/MMPT/mmpt/processors/dedupprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..8a1ad402cd890516457c188910ec265a47fc6e8e --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/processors/dedupprocessor.py @@ -0,0 +1,242 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import random +import json +import pickle +from tqdm import tqdm +import os +import numpy as np + + +class CaptionDedupProcessor(object): + """remove overlapping of caption sentences(clip). + Some statistics: + caption: + {'t_clip_len': 246.6448431320854, + 'video_len': 281.09174795676245, + 'clip_tps': 0.8841283727427481, + 'video_tps': 0.7821156477732097, + 'min_clip_len': 0.0, + 'max_clip_len': 398.3, + 'mean_clip_len': 3.196580003006861, + 'num_clip': 77.15897706301081} + + raw_caption: + {'t_clip_len': 238.95908778424115, + 'video_len': 267.5914859862507, + 'clip_tps': 2.4941363624267963, + 'video_tps': 2.258989769647173, + 'min_clip_len': 0.0, + 'max_clip_len': 398.3, + 'mean_clip_len': 3.0537954186814265, + 'num_clip': 78.24986779481756} + """ + + def __init__(self, pkl_file): + with open(pkl_file, "rb") as fd: + self.data = pickle.load(fd) + self.stat = { + "t_clip_len": [], + "video_len": [], + "clip_tps": [], + "video_tps": [], + "clip_len": [], + } + + def __call__(self): + for idx, video_id in enumerate(tqdm(self.data)): + caption = json.loads(self.data[video_id]) + caption = self._dedup(caption) + if idx < 4096: # for the first 4096 examples, compute the statistics. 
+ self.save_stat(video_id, caption) + self.data[video_id] = json.dumps(caption) + self.print_stat() + + def single(self, video_id): + caption = json.loads(self.data[video_id]) + for clip_idx, (start, end, text) in enumerate( + zip(caption["start"], caption["end"], caption["text"]) + ): + print(start, end, text) + print("@" * 100) + caption = self._dedup(caption) + for clip_idx, (start, end, text) in enumerate( + zip(caption["start"], caption["end"], caption["text"]) + ): + print(start, end, text) + print("#" * 100) + self.save_stat(video_id, caption) + self.print_stat() + + def finalize(self, tgt_fn): + with open(tgt_fn, "wb") as fw: + pickle.dump(self.data, fw, pickle.HIGHEST_PROTOCOL) + + def save_stat(self, video_id, caption): + video_fn = os.path.join( + "data/feat/feat_how2_s3d", video_id + ".npy" + ) + if os.path.isfile(video_fn): + with open(video_fn, "rb", 1) as fr: # 24 is the buffer size. buffered + version = np.lib.format.read_magic(fr) + shape, fortran, dtype = np.lib.format._read_array_header(fr, version) + video_len = shape[0] + + t_clip_len = 0.0 + t_tokens = 0 + for idx, (start, end, text) in enumerate( + zip(caption["start"], caption["end"], caption["text"]) + ): + clip_len = ( + (end - max(caption["end"][idx - 1], start)) + if idx > 0 + else end - start + ) + t_clip_len += clip_len + t_tokens += len(text.split(" ")) + self.stat["clip_len"].append(clip_len) + self.stat["t_clip_len"].append(t_clip_len) + self.stat["video_len"].append(video_len) + self.stat["clip_tps"].append(t_tokens / t_clip_len) + self.stat["video_tps"].append(t_tokens / video_len) + + def print_stat(self): + result = { + "t_clip_len": np.mean(self.stat["t_clip_len"]), + "video_len": np.mean(self.stat["video_len"]), + "clip_tps": np.mean(self.stat["clip_tps"]), + "video_tps": np.mean(self.stat["video_tps"]), + "min_clip_len": min(self.stat["clip_len"]), + "max_clip_len": max(self.stat["clip_len"]), + "mean_clip_len": np.mean(self.stat["clip_len"]), + "num_clip": len(self.stat["clip_len"]) / len(self.stat["video_tps"]), + } + print(result) + + def _dedup(self, caption): + def random_merge(end_idx, start, end, text, starts, ends, texts): + if random.random() > 0.5: + # print(clip_idx, "[PARTIAL INTO PREV]", end_idx) + # overlapped part goes to the end of previous. + ends[-1] = max(ends[-1], start) # ? + rest_text = text[end_idx:].strip() + if rest_text: + starts.append(max(ends[-1], start)) + ends.append(max(end, starts[-1])) + texts.append(rest_text) + else: # goes to the beginning of the current. + # strip the previous. 
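+                    # at this point texts[-1] ends with text[:end_idx], so
+                    # trimming the last end_idx characters off the previous
+                    # caption removes the overlap and the current clip keeps
+                    # its full text.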
+ left_text = texts[-1][:-end_idx].strip() + if left_text: + # print(clip_idx, "[PREV PARTIAL INTO CUR]", end_idx) + ends[-1] = min(ends[-1], start) + texts[-1] = left_text + else: + # print(clip_idx, "[PREV LEFT NOTHING ALL INTO CUR]", end_idx) + starts.pop(-1) + ends.pop(-1) + texts.pop(-1) + starts.append(start) + ends.append(end) + texts.append(text) + + starts, ends, texts = [], [], [] + for clip_idx, (start, end, text) in enumerate( + zip(caption["start"], caption["end"], caption["text"]) + ): + if not isinstance(text, str): + continue + text = text.replace("\n", " ").strip() + if len(text) == 0: + continue + starts.append(start) + ends.append(end) + texts.append(text) + break + + for clip_idx, (start, end, text) in enumerate( + zip( + caption["start"][clip_idx + 1:], + caption["end"][clip_idx + 1:], + caption["text"][clip_idx + 1:], + ) + ): + if not isinstance(text, str): + continue + text = text.replace("\n", " ").strip() + if len(text) == 0: + continue + + # print(clip_idx, texts[-5:]) + # print(clip_idx, start, end, text) + if texts[-1].endswith(text): # subset of prev caption -> merge + # print(clip_idx, "[MERGE INTO PREV]") + ends[-1] = max(ends[-1], end) + elif text.startswith(texts[-1]): # superset of prev caption -> merge + # print(clip_idx, "[PREV MERGE INTO CUR]") + texts[-1] = text + starts[-1] = min(starts[-1], start) + ends[-1] = max(ends[-1], end) + else: # overlapping or non-overlapping. + for end_idx in range(1, len(text) + 1): + if texts[-1].endswith(text[:end_idx]): + random_merge(end_idx, start, end, text, starts, ends, texts) + break + else: + starts.append(start) + ends.append(end) + texts.append(text) + + assert (ends[-1] + 0.001) >= starts[-1] and len( + texts[-1] + ) > 0, "{} {} {} <- {} {} {}, {} {} {}".format( + str(starts[-1]), + str(ends[-1]), + texts[-1], + caption["start"][clip_idx - 1], + caption["end"][clip_idx - 1], + caption["text"][clip_idx - 1], + str(start), + str(end), + text, + ) + + return {"start": starts, "end": ends, "text": texts} + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="dedup how2 caption") + parser.add_argument('--how2dir', default="data/how2") + args = parser.parse_args() + + raw_caption_json = os.path.join(args.how2dir, "raw_caption.json") + raw_caption_pickle = os.path.join(args.how2dir, "raw_caption.pkl") + raw_caption_dedup_pickle = os.path.join(args.how2dir, "raw_caption_dedup.pkl") + + def convert_to_pickle(src_fn, tgt_fn): + with open(src_fn) as fd: + captions = json.load(fd) + + for video_id in captions: + captions[video_id] = json.dumps(captions[video_id]) + + with open(tgt_fn, "wb") as fw: + pickle.dump(captions, fw, pickle.HIGHEST_PROTOCOL) + + if not os.path.isfile(raw_caption_pickle): + convert_to_pickle(raw_caption_json, raw_caption_pickle) + + deduper = CaptionDedupProcessor(raw_caption_pickle) + deduper() + deduper.finalize(raw_caption_dedup_pickle) + + """ + # demo + deduper = CaptionDedupProcessor("data/how2/raw_caption.pkl") + deduper.single("HfIeQ9pzL5U") + """ diff --git a/fairseq/examples/MMPT/mmpt/processors/how2processor.py b/fairseq/examples/MMPT/mmpt/processors/how2processor.py new file mode 100644 index 0000000000000000000000000000000000000000..bed2168b1df28babc7a12e81b6bc31d36d73bc99 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/processors/how2processor.py @@ -0,0 +1,887 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright (c) Facebook, Inc. All Rights Reserved + + +import torch +import math +import pickle +import random +import os +import numpy as np + +from collections import deque +from typing import Optional, Tuple, List +from .processor import ( + Processor, + MetaProcessor, + TextProcessor, + Aligner, + MMAttentionMask2DProcessor +) + +from ..utils import ShardedTensor + + +class How2MetaProcessor(MetaProcessor): + def __init__(self, config): + super().__init__(config) + path = self._get_split_path(config) + with open(path) as fd: + self.data = [line.strip() for line in fd] + + def __getitem__(self, idx): + video_id = self.data[idx] + return video_id, video_id + + +class ShardedHow2MetaProcessor(How2MetaProcessor): + def __init__(self, config): + super().__init__(config) + self.split = str(config.split) + self.vfeat_dir = config.vfeat_dir + self._init_shard() + + def _init_shard(self): + if self.split == "train": + meta_fn = os.path.join(self.vfeat_dir, "train" + "_meta.pkl") + with open(meta_fn, "rb") as fr: + meta = pickle.load(fr) + elif self.split == "valid": + meta_fn = os.path.join(self.vfeat_dir, "val" + "_meta.pkl") + with open(meta_fn, "rb") as fr: + meta = pickle.load(fr) + elif self.split == "test": + print("use how2 val as test.") + meta_fn = os.path.join(self.vfeat_dir, "val" + "_meta.pkl") + with open(meta_fn, "rb") as fr: + meta = pickle.load(fr) + else: + raise ValueError("unsupported for MetaProcessor:", self.split) + video_id_to_shard = {} + for shard_id in meta: + for video_idx, video_id in enumerate(meta[shard_id]): + video_id_to_shard[video_id] = (shard_id, video_idx) + self.video_id_to_shard = video_id_to_shard + + def __getitem__(self, idx): + video_id, video_id = super().__getitem__(idx) + shard_id, shard_idx = self.video_id_to_shard[video_id] + meta = (video_id, idx, shard_id, shard_idx) + return meta, meta + + +class ShardedVideoProcessor(Processor): + """ + mmaped shards of numpy video features. 
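+    the `video_id` argument of `__call__` is the meta tuple
+    `(video_id, idx, shard_id, shard_idx)` emitted by
+    `ShardedHow2MetaProcessor`: `shard_id` selects the shard file and
+    `shard_idx` the clip features within it.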
+ """ + + def __init__(self, config): + self.split = str(config.split) + self.vfeat_dir = config.vfeat_dir + + def __call__(self, video_id): + _, _, shard_id, video_idx = video_id + if self.split == "train": + shard = ShardedTensor.load( + os.path.join(self.vfeat_dir, "train" + "_" + str(shard_id)), + "r" + ) + elif self.split == "valid": + shard = ShardedTensor.load( + os.path.join(self.vfeat_dir, "val" + "_" + str(shard_id)), + "r" + ) + elif self.split == "test": + shard = ShardedTensor.load( + os.path.join(self.vfeat_dir, "val" + "_" + str(shard_id)), + "r" + ) + else: + raise ValueError("unknown split", self.split) + feat = shard[video_idx] + return feat + + +class ShardedTextProcessor(Processor): + def __init__(self, config): + self.tfeat_dir = str(config.tfeat_dir) + self.split = str(config.split) + + def __call__(self, video_id): + _, _, shard_id, shard_idx = video_id + if self.split == "train": + target_path = self.tfeat_dir + "train" + "_" + str(shard_id) + elif self.split == "valid": + target_path = self.tfeat_dir + "val" + "_" + str(shard_id) + elif self.split == "test": + target_path = self.tfeat_dir + "val" + "_" + str(shard_id) + else: + raise ValueError("unknown split", self.split) + + startend = ShardedTensor.load( + target_path + ".startends", "r")[shard_idx] + cap_ids = ShardedTensor.load( + target_path + ".caps_ids", "r")[shard_idx] + cap = [] + for clip_idx in range(len(cap_ids)): + clip = cap_ids[clip_idx] + cap.append(clip[clip != -1].tolist()) + start, end = startend[:, 0].tolist(), startend[:, 1].tolist() + return {"start": start, "end": end, "cap": cap} + + +class FixedLenAligner(Aligner): + """ + In the model we assume text is on the left (closer to BERT formulation) + and video is on the right. + We fix the total length of text + video. + max_video_len is in number of secs. + max_text_len is in number of tokens. + + special tokens formats: + we use the format [CLS] [SEP] text tokens [SEP] [PAD] ... + [CLS] will be splitted out into: + [CLS] video tokens [SEP] text tokens [SEP] [PAD] ... + token_type_ids will be generated by the model (for now). + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + so each sequence owns a [SEP] token for no-ops. + """ + + def __init__(self, config): + super().__init__(config) + self.text_clip_sampler = TextClipSamplingProcessor( + self.max_len - self.max_video_len - 3 + ) + """ + decide subsampling: + `config.subsampling` will change batch_size in trainer. + `config.clip_per_video` (used by RetriTask) doesn't + change batch_size in trainer. 
+ """ + subsampling = config.subsampling \ + if config.subsampling is not None else None + if config.clip_per_video is not None: + subsampling = config.clip_per_video + self.subsampling = subsampling + + def _get_text_maxlen(self): + # use max text len + return self.text_clip_sampler.max_text_len + + def __call__(self, video_id, video_feature, text_feature): + from transformers import default_data_collator + video_idx = video_id[1] + if self.subsampling is not None and self.subsampling >= 1: + batch = [] + for _ in range(self.subsampling): + centerclip_idx = random.randint( + 0, len(text_feature["start"]) - 1) + batch.append( + self.sampling( + video_idx, + video_feature, + text_feature, + centerclip_idx, + self._get_text_maxlen() + )) + batch = self.batch_post_processing(batch, video_feature) + batch = default_data_collator(batch) + else: + raise ValueError( + "dataset.subsampling must be >= 1 for efficient video loading.") + batch = self.sampling(video_idx, video_feature, text_feature) + batch = self.batch_post_processing(batch, video_feature) + + batch["video_id"] = video_id if isinstance(video_id, str) \ + else video_id[0] + # e2e: make sure frame ids is into tensor. + assert torch.is_tensor(batch["vfeats"]) + return batch + + def sampling( + self, + video_idx, + video_feature, + text_feature, + centerclip_idx=None, + sampled_max_text_len=None, + ): + text_clip_indexs = self.text_clip_sampler( + text_feature, centerclip_idx, + sampled_max_text_len + ) + if isinstance(video_feature, np.ndarray): + video_len = len(video_feature) + else: + video_len = math.ceil(text_feature["end"][-1]) + + video_end = min( + math.ceil(text_feature["end"][text_clip_indexs[-1]]), + video_len + ) + video_start = max( + min( + math.floor(text_feature["start"][text_clip_indexs[0]]), + video_end), + 0 + ) + + video_clips = {"start": [video_start], "end": [video_end]} + + # tensorize. 
+ vfeats, vmasks = self._build_video_seq( + video_feature, video_clips + ) + caps, cmasks = self._build_text_seq( + text_feature, text_clip_indexs + ) + + text_start = text_clip_indexs[0] + text_end = text_clip_indexs[-1] + 1 + + return { + "caps": caps, + "cmasks": cmasks, + "vfeats": vfeats, + "vmasks": vmasks, + "video_start": video_start, + "video_end": video_end, + "text_start": text_start, + "text_end": text_end, + } + + +class VariedLenAligner(FixedLenAligner): + def __init__(self, config): + super().__init__(config) + self.sampled_min_len = config.sampled_min_len + self.sampled_max_len = config.sampled_max_len + + def _get_text_maxlen(self): + return random.randint(self.sampled_min_len, self.sampled_max_len) + + +class StartClipAligner(VariedLenAligner): + def sampling( + self, + video_idx, + video_feature, + text_feature, + centerclip_idx=None, + sampled_max_text_len=None, + ): + return super().sampling( + video_idx, video_feature, text_feature, 0) + + +class OverlappedAligner(VariedLenAligner): + """video clip and text clip has overlappings + but may not be the same start/end.""" + def __init__(self, config): + super().__init__(config) + self.sampled_video_min_len = config.sampled_video_min_len + self.sampled_video_max_len = config.sampled_video_max_len + + self.video_clip_sampler = VideoClipSamplingProcessor() + + def _get_video_maxlen(self): + return random.randint( + self.sampled_video_min_len, self.sampled_video_max_len) + + def sampling( + self, + video_idx, + video_feature, + text_feature, + centerclip_idx=None, + sampled_max_text_len=None, + ): + text_clip_indexs = self.text_clip_sampler( + text_feature, centerclip_idx, + sampled_max_text_len + ) + if isinstance(video_feature, np.ndarray): + video_len = len(video_feature) + else: + video_len = math.ceil(text_feature["end"][-1]) + low = math.floor(text_feature["start"][text_clip_indexs[0]]) + high = math.ceil(text_feature["end"][text_clip_indexs[-1]]) + if low < high: + center = random.randint(low, high) + else: + center = int((low + high) // 2) + center = max(0, min(video_feature.shape[0] - 1, center)) + + assert 0 <= center < video_feature.shape[0] + + video_clips = self.video_clip_sampler( + video_len, self._get_video_maxlen(), center + ) + video_start = video_clips["start"][0] + video_end = video_clips["end"][0] + + # tensorize. + vfeats, vmasks = self._build_video_seq( + video_feature, video_clips + ) + caps, cmasks = self._build_text_seq( + text_feature, text_clip_indexs + ) + + text_start = text_clip_indexs[0] + text_end = text_clip_indexs[-1] + 1 + + return { + "caps": caps, + "cmasks": cmasks, + "vfeats": vfeats, + "vmasks": vmasks, + "video_start": video_start, + "video_end": video_end, + "text_start": text_start, + "text_end": text_end, + } + + +class MFMMLMAligner(FixedLenAligner): + """ + `FixedLenAligner` with Masked Language Model and Masked Frame Model. 
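+    with probability `mm_prob`, one modality is masked as a whole (the text
+    side via `mm_type`, or every video frame) while the other is left
+    unmasked; otherwise both sides fall back to independent random masking
+    (`mlm_probability` / `mfm_probability`, 0.15 by default).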
+ """ + + def __init__(self, config): + super().__init__(config) + keep_prob = config.keep_prob if config.keep_prob is not None else 1.0 + self.text_clip_sampler = TextClipSamplingProcessor( + self.max_len - self.max_video_len - 3, keep_prob + ) + self.sampled_min_len = config.sampled_min_len + self.sampled_max_len = config.sampled_max_len + self.masked_token_sampler = TextMaskingProcessor(config) + self.mm_type = config.mm_type \ + if config.mm_type is not None else "full" + self.attnmasker = MMAttentionMask2DProcessor() \ + if self.mm_type == "textgen" else None + self.masked_frame_sampler = FrameMaskingProcessor(config) + self.lazy_vfeat_mask = ( + False if config.lazy_vfeat_mask is None else config.lazy_vfeat_mask + ) + self.mm_prob = config.mm_prob if config.mm_prob is not None else 0. + + def __call__(self, video_id, video_feature, text_feature): + from transformers import default_data_collator + if self.subsampling is not None and self.subsampling > 1: + batch = [] + for _ in range(self.subsampling): + centerclip_idx = random.randint( + 0, len(text_feature["start"]) - 1) + sampled_max_text_len = random.randint( + self.sampled_min_len, self.sampled_max_len + ) + batch.append( + self.sampling( + video_id, + video_feature, + text_feature, + centerclip_idx, + sampled_max_text_len, + ) + ) + batch = self.batch_post_processing(batch, video_feature) + batch = default_data_collator(batch) + else: + batch = self.sampling(video_id, video_feature, text_feature) + batch = self.batch_post_processing(batch, video_feature) + batch["video_id"] = video_id if isinstance(video_id, str) \ + else video_id[0] + return batch + + def sampling( + self, + video_id, + video_feature, + text_feature, + centerclip_idx=None, + sampled_max_text_len=None, + ): + output = FixedLenAligner.sampling(self, + video_id, video_feature, text_feature, + centerclip_idx, sampled_max_text_len) + + masking_text, masking_video = None, None + if random.random() < self.mm_prob: + if random.random() > 0.5: + masking_text, masking_video = self.mm_type, "no" + else: + masking_text, masking_video = "no", "full" + video_feats = output["vfeats"] if not self.lazy_vfeat_mask else None + video_label = self.masked_frame_sampler( + output["vmasks"], masking_video, vfeats=video_feats) + caps, text_label = self.masked_token_sampler( + output["caps"], masking_text) + + output.update({ + "caps": caps, + "video_label": video_label, + "text_label": text_label, + }) + + if self.attnmasker is not None: + attention_mask = self.attnmasker( + output["vmasks"], output["cmasks"], masking_text) + output.update({ + "attention_mask": attention_mask + }) + return output + + +class FrameMaskingProcessor(Processor): + def __init__(self, config): + self.mfm_probability = 0.15 + if config.mfm_probability is not None: + self.mfm_probability = config.mfm_probability + + def __call__(self, vmasks, modality_masking=None, vfeats=None): + """ + We perform lazy masking to save data transfer time. + It only generates video_labels by default and MFM model + will do actualy masking. + Return: `video_label` is a binary mask. + """ + video_label = vmasks.clone() + if modality_masking is not None: + if modality_masking == "full": + probability_matrix = torch.full(video_label.shape, 1.) + elif modality_masking == "no": + probability_matrix = torch.full(video_label.shape, 0.) + elif modality_masking == "inverse": + probability_matrix = torch.full( + video_label.shape, 1. 
- self.mfm_probability) + else: + raise ValueError("unknown modality masking.", modality_masking) + else: + probability_matrix = torch.full( + video_label.shape, self.mfm_probability) + masked_indices = torch.bernoulli(probability_matrix).bool() + # We only compute loss on masked tokens + video_label[~masked_indices] = 0 + if vfeats is not None: + vfeats[video_label, :] = 0.0 + return video_label + + +class TextGenerationProcessor(Processor): + def __init__(self, tokenizer): + self.bos_token_id = tokenizer.bos_token_id + self.pad_token_id = tokenizer.pad_token_id + + def __call__(self, inputs): + labels = inputs.clone() + # [CLS] [SEP] for video + labels[:2] = -100 + # keep [SEP] for text. + pad_mask = labels == self.pad_token_id + labels[pad_mask] = -100 + inputs[2:] = torch.cat([ + torch.LongTensor([self.bos_token_id]), + inputs[2:-1]]) + inputs[pad_mask] = self.pad_token_id + assert len(inputs) == len(labels) + return inputs, labels + + +class TextMaskingProcessor(Processor): + def __init__(self, config): + """this function is borrowed from + `transformers/data/data_collator.DataCollatorForLanguageModeling`""" + self.mlm_probability = 0.15 + if config.mlm_probability is not None: + self.mlm_probability = config.mlm_probability + self.bert_name = config.bert_name + # [CLS] is used as bos_token and [SEP] is used as eos_token. + # https://huggingface.co/transformers/master/model_doc/bertgeneration.html + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + self.bert_name, bos_token="[CLS]", eos_token="[SEP]") + self.textgen = TextGenerationProcessor(self.tokenizer) + + def __call__( + self, inputs: torch.Tensor, + modality_masking=None, + special_tokens_mask: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + expand modality_masking into + None: traditional bert masking. + "no": no masking. + "full": all [MASK] token for generation. + "gen": autoregressive generation. + """ + """ + Prepare masked tokens inputs/labels for masked language modeling: + 80% MASK, 10% random, 10% original. + """ + labels = inputs.clone() + # We sample a few tokens in each sequence for MLM training + # (with probability `self.mlm_probability`) + if modality_masking is not None: + if modality_masking == "full": + probability_matrix = torch.full(labels.shape, 1.) + elif modality_masking == "no": + probability_matrix = torch.full(labels.shape, 0.) + elif modality_masking.startswith("textgen"): + # [CLS] [SEP] ... + inputs, labels = self.textgen(inputs) + if "mask" not in modality_masking: + return inputs, labels + inputs = self.mask_input(inputs, special_tokens_mask) + return inputs, labels + elif modality_masking == "mask": + inputs = self.mask_input(inputs, special_tokens_mask) + labels = torch.full(inputs.shape, -100) + return inputs, labels + elif modality_masking == "inverse": + probability_matrix = torch.full(labels.shape, 1. 
- self.mlm_probability) + else: + raise ValueError("unknown modality masking.", modality_masking) + else: + probability_matrix = torch.full(labels.shape, self.mlm_probability) + + if special_tokens_mask is None: + special_tokens_mask = self.get_special_tokens_mask( + labels.tolist(), already_has_special_tokens=True + ) + special_tokens_mask = torch.tensor( + special_tokens_mask, dtype=torch.bool) + else: + special_tokens_mask = special_tokens_mask.bool() + + probability_matrix.masked_fill_(special_tokens_mask, value=0.0) + masked_indices = torch.bernoulli(probability_matrix).bool() + labels[~masked_indices] = -100 # We only compute loss on masked tokens + + # 80% of the time, + # we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = ( + torch.bernoulli( + torch.full(labels.shape, 0.8)).bool() & masked_indices + ) + inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids( + self.tokenizer.mask_token + ) + + # 10% of the time, we replace masked input tokens with random word + indices_random = ( + torch.bernoulli(torch.full(labels.shape, 0.5)).bool() + & masked_indices + & ~indices_replaced + ) + random_words = torch.randint( + len(self.tokenizer), labels.shape, dtype=torch.long + ) + inputs[indices_random] = random_words[indices_random] + + # The rest of the time (10% of the time) we keep the masked input + # tokens unchanged + return inputs, labels + + def mask_input(self, inputs, special_tokens_mask=None): + # the following is new with masked autoregressive. + probability_matrix = torch.full( + inputs.shape, self.mlm_probability) + if special_tokens_mask is None: + special_tokens_mask = self.get_special_tokens_mask( + inputs.tolist(), already_has_special_tokens=True + ) + special_tokens_mask = torch.tensor( + special_tokens_mask, dtype=torch.bool) + else: + special_tokens_mask = special_tokens_mask.bool() + probability_matrix.masked_fill_(special_tokens_mask, value=0.0) + masked_indices = torch.bernoulli(probability_matrix).bool() + indices_replaced = ( + torch.bernoulli( + torch.full(inputs.shape, 0.8)).bool() & masked_indices + ) + inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids( + self.tokenizer.mask_token + ) + + # 10% of the time, we replace masked input tokens with random word + indices_random = ( + torch.bernoulli(torch.full(inputs.shape, 0.5)).bool() + & masked_indices + & ~indices_replaced + ) + random_words = torch.randint( + len(self.tokenizer), inputs.shape, dtype=torch.long + ) + inputs[indices_random] = random_words[indices_random] + return inputs + + def get_special_tokens_mask( + self, token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False + ) -> List[int]: + """ + Note: the version from transformers do not consider pad + as special tokens. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if" + "the provided sequence of " + "ids is already formated with special tokens " + "for the model." + ) + return list(map(lambda x: 1 if x in [ + self.tokenizer.sep_token_id, + self.tokenizer.cls_token_id, + self.tokenizer.pad_token_id] else 0, token_ids_0)) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + +class TextClipSamplingProcessor(Processor): + def __init__(self, max_text_len, keep_prob=1.0): + self.max_text_len = max_text_len + self.max_video_len = 256 # always hold. 
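+        # keep_prob < 1.0 lets the sampler occasionally jump over a
+        # neighboring clip while growing the window around the center clip
+        # (see __call__); MFMMLMAligner wires this up from config.keep_prob.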
+ self.keep_prob = keep_prob + + def __call__( + self, + text_feature, + centerclip_idx=None, + sampled_max_text_len=None, + sampled_max_video_len=None, + ): + # Let's use all caps for now and see if 256 can cover all of them. + if sampled_max_text_len is not None: + max_text_len = sampled_max_text_len + else: + max_text_len = self.max_text_len + if sampled_max_video_len is not None: + max_video_len = sampled_max_video_len + else: + max_video_len = self.max_video_len + + t_num_clips = len(text_feature["start"]) + + if centerclip_idx is None: + centerclip_idx = random.randint(0, t_num_clips - 1) + + start_idx, end_idx = centerclip_idx, centerclip_idx + 1 + text_clip_indexs = deque() + text_clip_indexs.append(start_idx) + text_len = len(text_feature["cap"][start_idx]) + + video_len = max( + 0, + text_feature["end"][start_idx] + - text_feature["start"][start_idx], + ) + + while ( + (start_idx > 0 or end_idx < t_num_clips) + and text_len < max_text_len + and video_len < max_video_len + ): + if random.random() > 0.5 and end_idx < t_num_clips: + # skip the next one? + if random.random() > self.keep_prob and (end_idx + 1) < t_num_clips: + end_idx = end_idx + 1 + text_clip_indexs.append(end_idx) + text_len += len(text_feature["cap"][end_idx]) + end_idx += 1 + elif start_idx > 0: + if random.random() > self.keep_prob and (start_idx - 1) > 0: + start_idx = start_idx - 1 + start_idx -= 1 + text_clip_indexs.insert(0, start_idx) + text_len += len(text_feature["cap"][start_idx]) + else: + if end_idx < t_num_clips: + if random.random() > self.keep_prob and (end_idx + 1) < t_num_clips: + end_idx = end_idx + 1 + text_clip_indexs.append(end_idx) + text_len += len(text_feature["cap"][end_idx]) + end_idx += 1 + else: + return text_clip_indexs + video_len = max( + 0, + text_feature["end"][text_clip_indexs[-1]] + - text_feature["start"][text_clip_indexs[0]], + ) + return text_clip_indexs + + +class VideoClipSamplingProcessor(Processor): + def __call__(self, video_len, max_video_len, center): + """ + `video_len`: length of the video. + `max_video_len`: maximum video tokens allowd in a sequence. + `center`: initial starting index. + """ + assert center >= 0 and center < video_len + t_clip_len = 0 + start, end = center, center + while (start > 0 or end < video_len) and t_clip_len < max_video_len: + # decide the direction to grow. + if start <= 0: + end += 1 + elif end >= video_len: + start -= 1 + elif random.random() > 0.5: + end += 1 + else: + start -= 1 + t_clip_len += 1 + return {"start": [start], "end": [end]} + + +class How2MILNCEAligner(FixedLenAligner): + """reference: `antoine77340/MIL-NCE_HowTo100M/video_loader.py`""" + + def __init__(self, config): + super().__init__(config) + self.num_candidates = 4 + self.min_time = 5.0 + self.num_sec = 3.2 + # self.num_sec = self.num_frames / float(self.fps) num_frames=16 / fps = 5 + # self.num_frames = 16 + + def sampling( + self, + video_id, + video_feature, + text_feature, + centerclip_idx=None, # will be ignored. + sampled_max_text_len=None # will be ignored. 
+ ): + text, start, end = self._get_text(text_feature) + video = self._get_video(video_feature, start, end) + + vfeats = torch.zeros((self.max_video_len, video_feature.shape[1])) + vmasks = torch.zeros((self.max_video_len,), dtype=torch.bool) + vfeats[: video.shape[0]] = torch.from_numpy(np.array(video)) + vmasks[: video.shape[0]] = 1 + + caps, cmasks = [], [] + for words in text: + cap, cmask = self._build_text_seq(text_feature, words) + caps.append(cap) + cmasks.append(cmask) + caps = torch.stack(caps) + cmasks = torch.stack(cmasks) + # video of shape: (video_len) + # text of shape (num_candidates, max_text_len) + + return { + "caps": caps, + "cmasks": cmasks, + "vfeats": vfeats, + "vmasks": vmasks, + # "video_id": video_id, + } + + def _get_video(self, video_feature, start, end): + start_seek = random.randint(start, int(max(start, end - self.num_sec))) + # duration = self.num_sec + 0.1 + return video_feature[start_seek : int(start_seek + self.num_sec)] + + def _get_text(self, cap): + ind = random.randint(0, len(cap["start"]) - 1) + if self.num_candidates == 1: + words = [ind] + else: + words = [] + cap_start = self._find_nearest_candidates(cap, ind) + for i in range(self.num_candidates): + words.append([max(0, min(len(cap["cap"]) - 1, cap_start + i))]) + + start, end = cap["start"][ind], cap["end"][ind] + # TODO: May need to be improved for edge cases. + # expand the min time. + if end - start < self.min_time: + diff = self.min_time - end + start + start = max(0, start - diff / 2) + end = start + self.min_time + return words, int(start), int(end) + + def _find_nearest_candidates(self, caption, ind): + """find the range of the clips.""" + start, end = ind, ind + #diff = caption["end"][end] - caption["start"][start] + n_candidate = 1 + while n_candidate < self.num_candidates: + # the first clip + if start == 0: + return 0 + # we add () in the following condition to fix the bug. + elif end == (len(caption["start"]) - 1): + return start - (self.num_candidates - n_candidate) + elif (caption["end"][end] - caption["start"][start - 1]) < ( + caption["end"][end + 1] - caption["start"][start] + ): + start -= 1 + else: + end += 1 + n_candidate += 1 + return start + + +class PKLJSONStrTextProcessor(TextProcessor): + """`caption.json` from howto100m are preprocessed as a + dict `[video_id, json_str]`. + Json parsing tokenization are conducted on-the-fly and cached into dict. + """ + + def __init__(self, config, max_clip_text_len=96): + print("[Warning] PKLJSONStrTextProcessor is slow for num_workers > 0.") + self.caption_pkl_path = str(config.caption_pkl_path) + with open(self.caption_pkl_path, "rb") as fd: + self.data = pickle.load(fd) + self.max_clip_text_len = max_clip_text_len + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + str(config.bert_name), use_fast=config.use_fast + ) + + def __call__(self, video_id): + caption = self.data[video_id] + if isinstance(caption, str): + import json + caption = json.loads(caption) + cap = [] + for clip_idx, text_clip in enumerate(caption["text"]): + clip_ids = [] + if isinstance(text_clip, str): + clip_ids = self.tokenizer( + text_clip[: self.max_clip_text_len], + add_special_tokens=False + )["input_ids"] + cap.append(clip_ids) + caption["cap"] = cap + caption.pop("text") # save space. 
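+        # cache the tokenized clips back into the in-memory dict so that
+        # later accesses to this video_id skip JSON parsing and tokenization.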
+ self.data[video_id] = caption + return caption diff --git a/fairseq/examples/MMPT/mmpt/processors/models/s3dg.py b/fairseq/examples/MMPT/mmpt/processors/models/s3dg.py new file mode 100644 index 0000000000000000000000000000000000000000..6c7a691e33551807ef258a54282884a409408691 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/processors/models/s3dg.py @@ -0,0 +1,336 @@ +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Contains a PyTorch definition for Gated Separable 3D network (S3D-G) +with a text module for computing joint text-video embedding from raw text +and video input. The following code will enable you to load the HowTo100M +pretrained S3D Text-Video model from: + A. Miech, J.-B. Alayrac, L. Smaira, I. Laptev, J. Sivic and A. Zisserman, + End-to-End Learning of Visual Representations from Uncurated Instructional Videos. + https://arxiv.org/abs/1912.06430. + +S3D-G was proposed by: + S. Xie, C. Sun, J. Huang, Z. Tu and K. Murphy, + Rethinking Spatiotemporal Feature Learning For Video Understanding. + https://arxiv.org/abs/1712.04851. + Tensorflow code: https://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py + +The S3D architecture was slightly modified with a space to depth trick for TPU +optimization. +""" + +import torch as th +import torch.nn.functional as F +import torch.nn as nn +import os +import numpy as np +import re + + +class InceptionBlock(nn.Module): + def __init__( + self, + input_dim, + num_outputs_0_0a, + num_outputs_1_0a, + num_outputs_1_0b, + num_outputs_2_0a, + num_outputs_2_0b, + num_outputs_3_0b, + gating=True, + ): + super(InceptionBlock, self).__init__() + self.conv_b0 = STConv3D(input_dim, num_outputs_0_0a, [1, 1, 1]) + self.conv_b1_a = STConv3D(input_dim, num_outputs_1_0a, [1, 1, 1]) + self.conv_b1_b = STConv3D( + num_outputs_1_0a, num_outputs_1_0b, [3, 3, 3], padding=1, separable=True + ) + self.conv_b2_a = STConv3D(input_dim, num_outputs_2_0a, [1, 1, 1]) + self.conv_b2_b = STConv3D( + num_outputs_2_0a, num_outputs_2_0b, [3, 3, 3], padding=1, separable=True + ) + self.maxpool_b3 = th.nn.MaxPool3d((3, 3, 3), stride=1, padding=1) + self.conv_b3_b = STConv3D(input_dim, num_outputs_3_0b, [1, 1, 1]) + self.gating = gating + self.output_dim = ( + num_outputs_0_0a + num_outputs_1_0b + num_outputs_2_0b + num_outputs_3_0b + ) + if gating: + self.gating_b0 = SelfGating(num_outputs_0_0a) + self.gating_b1 = SelfGating(num_outputs_1_0b) + self.gating_b2 = SelfGating(num_outputs_2_0b) + self.gating_b3 = SelfGating(num_outputs_3_0b) + + def forward(self, input): + """Inception block + """ + b0 = self.conv_b0(input) + b1 = self.conv_b1_a(input) + b1 = self.conv_b1_b(b1) + b2 = self.conv_b2_a(input) + b2 = self.conv_b2_b(b2) + b3 = self.maxpool_b3(input) + b3 = self.conv_b3_b(b3) + if self.gating: + b0 = self.gating_b0(b0) + b1 = self.gating_b1(b1) + b2 = self.gating_b2(b2) + b3 = self.gating_b3(b3) + return th.cat((b0, b1, b2, b3), dim=1) + + +class SelfGating(nn.Module): + def __init__(self, input_dim): + super(SelfGating, self).__init__() + self.fc = nn.Linear(input_dim, input_dim) + + def forward(self, input_tensor): + """Feature gating as used in S3D-G. 
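+        Averages the input over its spatiotemporal dimensions, passes the
+        result through a linear layer and a sigmoid, and rescales the input
+        by the resulting per-channel gate.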
+ """ + spatiotemporal_average = th.mean(input_tensor, dim=[2, 3, 4]) + weights = self.fc(spatiotemporal_average) + weights = th.sigmoid(weights) + return weights[:, :, None, None, None] * input_tensor + + +class STConv3D(nn.Module): + def __init__( + self, input_dim, output_dim, kernel_size, stride=1, padding=0, separable=False + ): + super(STConv3D, self).__init__() + self.separable = separable + self.relu = nn.ReLU(inplace=True) + assert len(kernel_size) == 3 + if separable and kernel_size[0] != 1: + spatial_kernel_size = [1, kernel_size[1], kernel_size[2]] + temporal_kernel_size = [kernel_size[0], 1, 1] + if isinstance(stride, list) and len(stride) == 3: + spatial_stride = [1, stride[1], stride[2]] + temporal_stride = [stride[0], 1, 1] + else: + spatial_stride = [1, stride, stride] + temporal_stride = [stride, 1, 1] + if isinstance(padding, list) and len(padding) == 3: + spatial_padding = [0, padding[1], padding[2]] + temporal_padding = [padding[0], 0, 0] + else: + spatial_padding = [0, padding, padding] + temporal_padding = [padding, 0, 0] + if separable: + self.conv1 = nn.Conv3d( + input_dim, + output_dim, + kernel_size=spatial_kernel_size, + stride=spatial_stride, + padding=spatial_padding, + bias=False, + ) + self.bn1 = nn.BatchNorm3d(output_dim) + self.conv2 = nn.Conv3d( + output_dim, + output_dim, + kernel_size=temporal_kernel_size, + stride=temporal_stride, + padding=temporal_padding, + bias=False, + ) + self.bn2 = nn.BatchNorm3d(output_dim) + else: + self.conv1 = nn.Conv3d( + input_dim, + output_dim, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=False, + ) + self.bn1 = nn.BatchNorm3d(output_dim) + + def forward(self, input): + out = self.relu(self.bn1(self.conv1(input))) + if self.separable: + out = self.relu(self.bn2(self.conv2(out))) + return out + + +class MaxPool3dTFPadding(th.nn.Module): + def __init__(self, kernel_size, stride=None, padding="SAME"): + super(MaxPool3dTFPadding, self).__init__() + if padding == "SAME": + padding_shape = self._get_padding_shape(kernel_size, stride) + self.padding_shape = padding_shape + self.pad = th.nn.ConstantPad3d(padding_shape, 0) + self.pool = th.nn.MaxPool3d(kernel_size, stride, ceil_mode=True) + + def _get_padding_shape(self, filter_shape, stride): + def _pad_top_bottom(filter_dim, stride_val): + pad_along = max(filter_dim - stride_val, 0) + pad_top = pad_along // 2 + pad_bottom = pad_along - pad_top + return pad_top, pad_bottom + + padding_shape = [] + for filter_dim, stride_val in zip(filter_shape, stride): + pad_top, pad_bottom = _pad_top_bottom(filter_dim, stride_val) + padding_shape.append(pad_top) + padding_shape.append(pad_bottom) + depth_top = padding_shape.pop(0) + depth_bottom = padding_shape.pop(0) + padding_shape.append(depth_top) + padding_shape.append(depth_bottom) + return tuple(padding_shape) + + def forward(self, inp): + inp = self.pad(inp) + out = self.pool(inp) + return out + + +class Sentence_Embedding(nn.Module): + def __init__( + self, + embd_dim, + num_embeddings=66250, + word_embedding_dim=300, + token_to_word_path="dict.npy", + max_words=16, + output_dim=2048, + ): + super(Sentence_Embedding, self).__init__() + self.word_embd = nn.Embedding(num_embeddings, word_embedding_dim) + self.fc1 = nn.Linear(word_embedding_dim, output_dim) + self.fc2 = nn.Linear(output_dim, embd_dim) + self.word_to_token = {} + self.max_words = max_words + token_to_word = np.load(token_to_word_path) + for i, t in enumerate(token_to_word): + self.word_to_token[t] = i + 1 + + def _zero_pad_tensor_token(self, 
tensor, size): + if len(tensor) >= size: + return tensor[:size] + else: + zero = th.zeros(size - len(tensor)).long() + return th.cat((tensor, zero), dim=0) + + def _split_text(self, sentence): + w = re.findall(r"[\w']+", str(sentence)) + return w + + def _words_to_token(self, words): + words = [ + self.word_to_token[word] for word in words if word in self.word_to_token + ] + if words: + we = self._zero_pad_tensor_token(th.LongTensor(words), self.max_words) + return we + else: + return th.zeros(self.max_words).long() + + def _words_to_ids(self, x): + split_x = [self._words_to_token(self._split_text(sent.lower())) for sent in x] + return th.stack(split_x, dim=0) + + def forward(self, x): + x = self._words_to_ids(x) + x = self.word_embd(x) + x = F.relu(self.fc1(x)) + x = th.max(x, dim=1)[0] + x = self.fc2(x) + return {'text_embedding': x} + + +class S3D(nn.Module): + def __init__(self, dict_path, num_classes=512, gating=True, space_to_depth=True): + super(S3D, self).__init__() + self.num_classes = num_classes + self.gating = gating + self.space_to_depth = space_to_depth + if space_to_depth: + self.conv1 = STConv3D( + 24, 64, [2, 4, 4], stride=1, padding=(1, 2, 2), separable=False + ) + else: + self.conv1 = STConv3D( + 3, 64, [3, 7, 7], stride=2, padding=(1, 3, 3), separable=False + ) + self.conv_2b = STConv3D(64, 64, [1, 1, 1], separable=False) + self.conv_2c = STConv3D(64, 192, [3, 3, 3], padding=1, separable=True) + self.gating = SelfGating(192) + self.maxpool_2a = MaxPool3dTFPadding( + kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME" + ) + self.maxpool_3a = MaxPool3dTFPadding( + kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME" + ) + self.mixed_3b = InceptionBlock(192, 64, 96, 128, 16, 32, 32) + self.mixed_3c = InceptionBlock( + self.mixed_3b.output_dim, 128, 128, 192, 32, 96, 64 + ) + self.maxpool_4a = MaxPool3dTFPadding( + kernel_size=(3, 3, 3), stride=(2, 2, 2), padding="SAME" + ) + self.mixed_4b = InceptionBlock( + self.mixed_3c.output_dim, 192, 96, 208, 16, 48, 64 + ) + self.mixed_4c = InceptionBlock( + self.mixed_4b.output_dim, 160, 112, 224, 24, 64, 64 + ) + self.mixed_4d = InceptionBlock( + self.mixed_4c.output_dim, 128, 128, 256, 24, 64, 64 + ) + self.mixed_4e = InceptionBlock( + self.mixed_4d.output_dim, 112, 144, 288, 32, 64, 64 + ) + self.mixed_4f = InceptionBlock( + self.mixed_4e.output_dim, 256, 160, 320, 32, 128, 128 + ) + self.maxpool_5a = self.maxPool3d_5a_2x2 = MaxPool3dTFPadding( + kernel_size=(2, 2, 2), stride=(2, 2, 2), padding="SAME" + ) + self.mixed_5b = InceptionBlock( + self.mixed_4f.output_dim, 256, 160, 320, 32, 128, 128 + ) + self.mixed_5c = InceptionBlock( + self.mixed_5b.output_dim, 384, 192, 384, 48, 128, 128 + ) + self.fc = nn.Linear(self.mixed_5c.output_dim, num_classes) + self.text_module = Sentence_Embedding(num_classes, + token_to_word_path=dict_path) + + def _space_to_depth(self, input): + """3D space to depth trick for TPU optimization. 
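+        Folds each 2x2x2 spatiotemporal block into the channel dimension, so a
+        (B, C, T, H, W) tensor becomes (B, 8 * C, T // 2, H // 2, W // 2).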
+ """ + B, C, T, H, W = input.shape + input = input.view(B, C, T // 2, 2, H // 2, 2, W // 2, 2) + input = input.permute(0, 3, 5, 7, 1, 2, 4, 6) + input = input.contiguous().view(B, 8 * C, T // 2, H // 2, W // 2) + return input + + def forward(self, inputs): + """Defines the S3DG base architecture.""" + if self.space_to_depth: + inputs = self._space_to_depth(inputs) + net = self.conv1(inputs) + if self.space_to_depth: + # we need to replicate 'SAME' tensorflow padding + net = net[:, :, 1:, 1:, 1:] + net = self.maxpool_2a(net) + net = self.conv_2b(net) + net = self.conv_2c(net) + if self.gating: + net = self.gating(net) + net = self.maxpool_3a(net) + net = self.mixed_3b(net) + net = self.mixed_3c(net) + net = self.maxpool_4a(net) + net = self.mixed_4b(net) + net = self.mixed_4c(net) + net = self.mixed_4d(net) + net = self.mixed_4e(net) + net = self.mixed_4f(net) + net = self.maxpool_5a(net) + net = self.mixed_5b(net) + net = self.mixed_5c(net) + net = th.mean(net, dim=[2, 3, 4]) + return {'video_embedding': self.fc(net), 'mixed_5c': net} diff --git a/fairseq/examples/MMPT/mmpt/processors/processor.py b/fairseq/examples/MMPT/mmpt/processors/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..98edb051f16efef81fba98b0b2f6befbad09f2d4 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/processors/processor.py @@ -0,0 +1,274 @@ +# Copyright (c) Facebook, Inc. All Rights Reserved + +import numpy as np +import os +import torch + + +class Processor(object): + """ + A generic processor for video (codec, feature etc.) and text. + """ + + def __call__(self, **kwargs): + raise NotImplementedError + + +class MetaProcessor(Processor): + """ + A meta processor is expected to load the metadata of a dataset: + (e.g., video_ids, or captions). + You must implement the `__getitem__` (meta datasets are rather diverse.). + """ + + def __init__(self, config): + self.split = config.split + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + raise NotImplementedError + + def _get_split_path(self, config): + splits = { + "train": config.train_path, + "valid": config.val_path, + "test": config.test_path, + } + if config.split is not None: + return splits[config.split] + return config.train_path + + +class TextProcessor(Processor): + """ + A generic Text processor: rename this as `withTokenizer`. + tokenize a string of text on-the-fly. + Warning: mostly used for end tasks. + (on-the-fly tokenization is slow for how2.) + TODO(huxu): move this class as a subclass. + """ + + def __init__(self, config): + self.bert_name = str(config.bert_name) + self.use_fast = config.use_fast + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + self.bert_name, use_fast=self.use_fast + ) + + def __call__(self, text_id): + caption = self.tokenizer(text_id, add_special_tokens=False) + return caption["input_ids"] + + +class VideoProcessor(Processor): + """ + A generic video processor: load a numpy video tokens by default. + """ + + def __init__(self, config): + self.vfeat_dir = config.vfeat_dir + + def __call__(self, video_fn): + if isinstance(video_fn, tuple): + video_fn = video_fn[0] + assert isinstance(video_fn, str) + video_fn = os.path.join(self.vfeat_dir, video_fn + ".npy") + feat = np.load(video_fn) + return feat + + +class Aligner(object): + """ + An alignprocessor align video and text and output a dict of tensors (for a model). 
+ """ + def __init__(self, config): + """__init__ needs to be light weight for more workers/threads.""" + self.split = config.split + self.max_video_len = config.max_video_len + self.max_len = config.max_len + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained( + str(config.bert_name), use_fast=config.use_fast + ) + self.cls_token_id = tokenizer.cls_token_id + self.sep_token_id = tokenizer.sep_token_id + self.pad_token_id = tokenizer.pad_token_id + self.mask_token_id = tokenizer.mask_token_id + + def __call__(self, video_id, video_feature, text_feature): + raise NotImplementedError + + def _build_video_seq(self, video_feature, video_clips=None): + """ + `video_feature`: available video tokens. + `video_clips`: video clip sequence to build. + """ + if not isinstance(video_feature, np.ndarray): + raise ValueError( + "unsupported type of video_feature", type(video_feature) + ) + + if video_clips is None: + # this is borrowed from DSAligner + video_start = 0 + video_end = min(len(video_feature), self.max_video_len) + # the whole sequence is a single clip. + video_clips = {"start": [video_start], "end": [video_end]} + + vfeats = np.zeros( + (self.max_video_len, video_feature.shape[1]), dtype=np.float32 + ) + vmasks = torch.zeros((self.max_video_len,), dtype=torch.bool) + video_len = 0 + for start, end in zip(video_clips["start"], video_clips["end"]): + clip_len = min(self.max_video_len - video_len, (end - start)) + if clip_len > 0: + vfeats[video_len: video_len + clip_len] = video_feature[ + start: start + clip_len + ] + vmasks[video_len: video_len + clip_len] = 1 + video_len += clip_len + vfeats = torch.from_numpy(vfeats) + + return vfeats, vmasks + + def _build_text_seq(self, text_feature, text_clip_indexs=None): + """ + `text_feature`: all available clips. + `text_clip_indexes`: clip sequence to build. + """ + if text_clip_indexs is None: + text_clip_indexs = [0] + + full_caps = [] + if isinstance(text_feature, dict): + for clip_idx in text_clip_indexs: + full_caps.extend(text_feature["cap"][clip_idx]) + else: + full_caps = text_feature + max_text_len = self.max_len - self.max_video_len - 3 + full_caps = full_caps[:max_text_len] + full_caps = ( + [self.cls_token_id, self.sep_token_id] + full_caps + [self.sep_token_id] + ) + text_pad_len = self.max_len - len(full_caps) - self.max_video_len + padded_full_caps = full_caps + [self.pad_token_id] * text_pad_len + caps = torch.LongTensor(padded_full_caps) + cmasks = torch.zeros((len(padded_full_caps),), dtype=torch.bool) + cmasks[: len(full_caps)] = 1 + + return caps, cmasks + + def batch_post_processing(self, batch, video_feature): + return batch + + +class MMAttentionMask2DProcessor(Processor): + """text generation requires 2d mask + that is harder to generate by GPU at this stage.""" + + def __call__(self, vmask, cmask, mtype): + if mtype == "textgen": + return self._build_textgeneration_mask(vmask, cmask) + elif mtype == "videogen": + return self._build_videogeneration_mask(vmask, cmask) + else: + return self._build_mm_mask(vmask, cmask) + + def _build_mm_mask(self, vmask, cmask): + mask_1d = torch.cat([cmask[:1], vmask, cmask[1:]], dim=0) + return mask_1d[None, :].repeat(mask_1d.size(0), 1) + + def _build_videogeneration_mask(self, vmask, cmask): + # cls_mask is only about text otherwise it will leak generation. + cls_text_mask = torch.cat([ + # [CLS] + torch.ones( + (1,), dtype=torch.bool, device=cmask.device), + # video tokens and [SEP] for video. 
+ torch.zeros( + (vmask.size(0) + 1,), dtype=torch.bool, device=cmask.device), + cmask[2:] + ], dim=0) + + # concat horizontially. + video_len = int(vmask.sum()) + video_masks = torch.cat([ + # [CLS] + torch.ones( + (video_len, 1), dtype=torch.bool, device=cmask.device + ), + torch.tril( + torch.ones( + (video_len, video_len), + dtype=torch.bool, device=cmask.device)), + # video_padding + torch.zeros( + (video_len, vmask.size(0) - video_len), + dtype=torch.bool, device=cmask.device + ), + # [SEP] for video (unused). + torch.zeros( + (video_len, 1), dtype=torch.bool, device=cmask.device + ), + cmask[2:].unsqueeze(0).repeat(video_len, 1) + ], dim=1) + + text_masks = cls_text_mask[None, :].repeat( + cmask.size(0) - 2, 1) + video_padding_masks = cls_text_mask[None, :].repeat( + vmask.size(0) - video_len, 1) + + return torch.cat([ + cls_text_mask[None, :], + video_masks, + video_padding_masks, + torch.cat([cmask[:1], vmask, cmask[1:]], dim=0)[None,:], + text_masks + ], dim=0) + + def _build_textgeneration_mask(self, vmask, cmask): + # cls_mask is only about video otherwise it will leak generation. + cls_video_mask = torch.cat([ + # [CLS] + torch.ones( + (1,), dtype=torch.bool, device=cmask.device), + vmask, + # [SEP] + torch.ones((1,), dtype=torch.bool, device=cmask.device), + torch.zeros( + (cmask.size(0)-2,), dtype=torch.bool, device=cmask.device) + ], dim=0) + + # concat horizontially. + text_len = int(cmask[2:].sum()) + text_masks = torch.cat([ + # [CLS] + torch.ones( + (text_len, 1), dtype=torch.bool, device=cmask.device + ), + vmask.unsqueeze(0).repeat(text_len, 1), + # [SEP] for video. + torch.ones( + (text_len, 1), dtype=torch.bool, device=cmask.device + ), + torch.tril( + torch.ones( + (text_len, text_len), + dtype=torch.bool, device=cmask.device)), + # padding. + torch.zeros( + (text_len, cmask.size(0) - text_len - 2), + dtype=torch.bool, device=cmask.device + ) + ], dim=1) + + cls_video_masks = cls_video_mask[None, :].repeat( + vmask.size(0) + 2, 1) + text_padding_masks = cls_video_mask[None, :].repeat( + cmask.size(0) - text_len - 2, 1) + return torch.cat([ + cls_video_masks, text_masks, text_padding_masks], dim=0) diff --git a/fairseq/examples/MMPT/mmpt/tasks/fairseqmmtask.py b/fairseq/examples/MMPT/mmpt/tasks/fairseqmmtask.py new file mode 100644 index 0000000000000000000000000000000000000000..f6b6115a39b2342be5513edd53016187ab91eb01 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/tasks/fairseqmmtask.py @@ -0,0 +1,104 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +make a general fairseq task for MM pretraining. +""" + +import random + +from fairseq.tasks import LegacyFairseqTask, register_task + +from .task import Task +from .retritask import RetriTask +from ..datasets import FairseqMMDataset +from .. import utils + + +@register_task("mmtask") +class FairseqMMTask(LegacyFairseqTask): + @staticmethod + def add_args(parser): + # Add some command-line arguments for specifying where the data is + # located and the maximum supported input length. 
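+        # In practice only a single positional `taskconfig` (a .yaml file) is
+        # exposed here; all remaining options live inside that config and are
+        # parsed by mmpt.utils.load_config.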
+ parser.add_argument( + "taskconfig", + metavar="FILE", + help=("taskconfig to load all configurations" "outside fairseq parser."), + ) + + @classmethod + def setup_task(cls, args, **kwargs): + return FairseqMMTask(args) + + def __init__(self, args): + super().__init__(args) + config = utils.load_config(args) + self.mmtask = Task.config_task(config) + self.mmtask.build_dataset() + self.mmtask.build_model() + self.mmtask.build_loss() + + def load_dataset(self, split, **kwargs): + split_map = { + "train": self.mmtask.train_data, + "valid": self.mmtask.val_data, + "test": self.mmtask.test_data, + } + if split not in split_map: + raise ValueError("unknown split type.") + if split_map[split] is not None: + self.datasets[split] = FairseqMMDataset(split_map[split]) + + def get_batch_iterator( + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, + skip_remainder_batch=False, + grouped_shuffling=False, + update_epoch_batch_itr=False, + ): + random.seed(epoch) + if dataset.mmdataset.split == "train" and isinstance(self.mmtask, RetriTask): + if epoch >= self.mmtask.config.retri_epoch: + if not hasattr(self.mmtask, "retri_dataloader"): + self.mmtask.build_dataloader() + self.mmtask.retrive_candidates(epoch) + + return super().get_batch_iterator( + dataset, + max_tokens, + max_sentences, + max_positions, + ignore_invalid_inputs, + required_batch_size_multiple, + seed, + num_shards, + shard_id, + num_workers, + epoch, + data_buffer_size, + disable_iterator_cache, + grouped_shuffling, + update_epoch_batch_itr, + ) + + @property + def source_dictionary(self): + return None + + @property + def target_dictionary(self): + return None diff --git a/fairseq/examples/MMPT/mmpt/tasks/milncetask.py b/fairseq/examples/MMPT/mmpt/tasks/milncetask.py new file mode 100644 index 0000000000000000000000000000000000000000..61b6ab0597f9f3a78bbcf1474613630b20c5a874 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/tasks/milncetask.py @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from .task import Task + + +class MILNCETask(Task): + def reshape_subsample(self, sample): + if ( + hasattr(self.config.dataset, "subsampling") + and self.config.dataset.subsampling is not None + and self.config.dataset.subsampling > 1 + ): + for key in sample: + if torch.is_tensor(sample[key]): + tensor = self.flat_subsample(sample[key]) + if key in ["caps", "cmasks"]: + size = tensor.size() + batch_size = size[0] * size[1] + expanded_size = (batch_size,) + size[2:] + tensor = tensor.view(expanded_size) + sample[key] = tensor + return sample diff --git a/fairseq/examples/MMPT/mmpt/tasks/retritask.py b/fairseq/examples/MMPT/mmpt/tasks/retritask.py new file mode 100644 index 0000000000000000000000000000000000000000..b43f20fddb31f29b210b1c726e5f6ccaad04bcf0 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/tasks/retritask.py @@ -0,0 +1,253 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
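+# Retrieval-augmented training: after `retri_epoch`, RetriTask runs the
+# current model over the training videos, indexes the resulting vectors, and
+# refreshes the candidate lists used to build batches (see retrive_candidates).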
+import os +import torch +import pickle +import random + +from tqdm import tqdm +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from ..processors import ( + ShardedHow2MetaProcessor, + ShardedVideoProcessor, + ShardedTextProcessor, + VariedLenAligner, +) + +from ..datasets import MMDataset +from .task import Task +from ..modules import vectorpool +from ..evaluators.predictor import Predictor +from ..utils import set_seed, get_local_rank, get_world_size + + +class RetriTask(Task): + """abstract class for task with retrival.""" + + def reshape_subsample(self, sample): + for key in sample: + if torch.is_tensor(sample[key]): + sample[key] = self.flat_subsample(sample[key]) + return sample + + def flat_subsample(self, tensor): + if tensor.size(0) == 1: + tensor = tensor.squeeze(0) + return tensor + + def build_dataloader(self): + """called by `get_batch_iterator` in fairseqmmtask. """ + # TODO: hard-code dataloader for retri for now and configurable in .yaml. + # reuse the `train.lst`. + self.config.dataset.split = "train" + meta_processor = ShardedHow2MetaProcessor(self.config.dataset) + video_processor = ShardedVideoProcessor(self.config.dataset) + text_processor = ShardedTextProcessor(self.config.dataset) + + aligner = VariedLenAligner(self.config.dataset) + aligner.subsampling = self.config.dataset.clip_per_video + + self.retri_data = MMDataset( + meta_processor, video_processor, text_processor, aligner + ) + + retri_sampler = DistributedSampler(self.retri_data) + infer_scale = 16 + batch_size = self.config.dataset.num_video_per_batch \ + * infer_scale + + self.retri_dataloader = DataLoader( + self.retri_data, + collate_fn=self.retri_data.collater, + batch_size=batch_size, + shuffle=False, + sampler=retri_sampler, + num_workers=self.config.fairseq.dataset.num_workers + ) + return self.retri_dataloader + + def retrive_candidates(self, epoch, dataloader=None): + if get_local_rank() == 0: + print("running retrieval model.") + out_dir = os.path.join( + self.config.fairseq.checkpoint.save_dir, "retri") + os.makedirs(out_dir, exist_ok=True) + + if not os.path.isfile( + os.path.join( + out_dir, "batched_e" + str(epoch) + "_videos0.pkl") + ): + if dataloader is None: + dataloader = self.retri_dataloader + + self.model.eval() + self.model.is_train = False + + assert self.retri_data.meta_processor.data == \ + self.train_data.meta_processor.data # video_ids not mutated. + + self._retri_predict(epoch, dataloader) + + self.model.train() + self.model.is_train = True + + torch.distributed.barrier() + output = self._retri_sync(epoch, out_dir) + torch.distributed.barrier() + self.train_data.meta_processor.set_candidates(output) + return output + + +class VideoRetriTask(RetriTask): + """RetriTask on video level.""" + + def reshape_subsample(self, sample): + if ( + hasattr(self.config.dataset, "clip_per_video") + and self.config.dataset.clip_per_video is not None + and self.config.dataset.clip_per_video > 1 + ): + for key in sample: + if torch.is_tensor(sample[key]): + sample[key] = self.flat_subsample(sample[key]) + return sample + + def flat_subsample(self, tensor): + if tensor.size(0) == 1: + tensor = tensor.squeeze(0) + return Task.flat_subsample(self, tensor) + + def _retri_predict(self, epoch, dataloader): + set_seed(epoch) + # save for retrival. + predictor = VideoPredictor(self.config) + predictor.predict_loop( + self.model, dataloader) + set_seed(epoch) # get the same text clips. + # retrival. 
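+        # second stage: sample anchor videos and query the index built by
+        # VideoPredictor above to form batches of retrieved neighbors.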
+ retri_predictor = VideoRetriPredictor( + self.config) + retri_predictor.predict_loop( + self.model, predictor.vecpool.retriver, epoch) + del predictor + del retri_predictor + + def _retri_sync(self, epoch, out_dir): + # gpu do the same merge. + batched_videos = [] + for local_rank in range(get_world_size()): + fn = os.path.join( + out_dir, + "batched_e" + str(epoch) + "_videos" + str(local_rank) + ".pkl") + with open(fn, "rb") as fr: + batched_videos.extend(pickle.load(fr)) + print( + "[INFO] batched_videos", + len(batched_videos), len(batched_videos[0])) + return batched_videos + + +class VideoPredictor(Predictor): + def __init__(self, config): + vectorpool_cls = getattr(vectorpool, config.vectorpool_cls) + self.vecpool = vectorpool_cls(config) + + def predict_loop( + self, + model, + dataloader, + early_stop=-1, + ): + with torch.no_grad(): + if get_local_rank() == 0: + dataloader = tqdm(dataloader) + for batch_idx, batch in enumerate(dataloader): + if batch_idx == early_stop: + break + self(batch, model) + return self.finalize() + + def __call__(self, sample, model, **kwargs): + param = next(model.parameters()) + dtype = param.dtype + device = param.device + subsample = sample["vfeats"].size(1) + sample = self.to_ctx(sample, device, dtype) + for key in sample: + if torch.is_tensor(sample[key]): + size = sample[key].size() + if len(size) >= 2: + batch_size = size[0] * size[1] + expanded_size = ( + (batch_size,) + size[2:] if len(size) > 2 + else (batch_size,) + ) + sample[key] = sample[key].view(expanded_size) + + outputs = model(**sample) + sample.update(outputs) + self.vecpool(sample, subsample) + + def finalize(self): + print("[INFO]", self.vecpool) + if not self.vecpool.retriver.db.is_trained: + self.vecpool.retriver.finalize_training() + return self.vecpool.retriver + + +class VideoRetriPredictor(Predictor): + """ + Online Retrieval Predictor for Clips (used by RetriTask). + TODO: merge this with VisPredictor? + """ + + def __init__(self, config): + self.pred_dir = os.path.join( + config.fairseq.checkpoint.save_dir, + "retri") + self.num_cands = config.num_cands + self.num_video_per_batch = config.dataset.num_video_per_batch + + def predict_loop( + self, + model, + retriver, + epoch, + early_stop=-1 + ): + # a fake loop that only try to recover video vector + # from video_id. + batched_videos = [] + # obtain available video_ids. + video_ids = list(retriver.videoid_to_vectoridx.keys()) + + dataloader = random.sample( + video_ids, + len(video_ids) // self.num_video_per_batch + ) + + if get_local_rank() == 0: + dataloader = tqdm(dataloader) + for batch_idx, batch in enumerate(dataloader): + # batch is one video id. + if batch_idx == early_stop: + break + video_ids = retriver.search_by_video_ids( + [batch], self.num_cands)[0] + if len(video_ids) > self.num_video_per_batch: + # we moved the center to make cluster robust. 
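+                # keep the batch size fixed by subsampling the candidates.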
+ video_ids = random.sample(video_ids, self.num_video_per_batch) + batched_videos.append(video_ids) + return self.finalize(batched_videos, epoch) + + def finalize(self, batched_videos, epoch): + fn = os.path.join( + self.pred_dir, + "batched_e" + str(epoch) + "_videos" + str(get_local_rank()) + ".pkl") + with open(fn, "wb") as fw: + pickle.dump(batched_videos, fw, pickle.HIGHEST_PROTOCOL) + return batched_videos diff --git a/fairseq/examples/MMPT/mmpt/tasks/task.py b/fairseq/examples/MMPT/mmpt/tasks/task.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb50f24df49878d5b430f9df5ba1ffb1cf30e32 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/tasks/task.py @@ -0,0 +1,184 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch + +from .. import tasks +from .. import models +from .. import losses +from ..datasets import MMDataset +from .. import processors + + +class Task(object): + """ + A task refers to one generic training task (e.g., training one model). + """ + + @classmethod + def config_task(cls, config): + """ + determine whether to load a hard-coded task or config from a generic one. + via if a task string is available in config. + """ + if config.task is not None: + # TODO (huxu): expand the search scope. + task_cls = getattr(tasks, config.task) + return task_cls(config) + else: + return Task(config) + + def __init__(self, config): + self.config = config + self.train_data = None + self.val_data = None + self.test_data = None + + self.model = None + self.loss_fn = None + self.eval_fn = None + + def build_dataset(self): + """TODO (huxu): move processor breakdown to MMDataset.""" + """fill-in `self.train_data`, `self.val_data` and `self.test_data`.""" + + meta_processor_cls = getattr( + processors, self.config.dataset.meta_processor) + video_processor_cls = getattr( + processors, self.config.dataset.video_processor) + text_processor_cls = getattr( + processors, self.config.dataset.text_processor) + aligner_cls = getattr( + processors, self.config.dataset.aligner) + + if self.config.dataset.train_path is not None: + self.config.dataset.split = "train" + # may be used by meta processor. + # meta_processor controls different dataset. + meta_processor = meta_processor_cls(self.config.dataset) + video_processor = video_processor_cls(self.config.dataset) + text_processor = text_processor_cls(self.config.dataset) + aligner = aligner_cls(self.config.dataset) + self.train_data = MMDataset( + meta_processor, video_processor, text_processor, aligner + ) + print("train_len", len(self.train_data)) + output = self.train_data[0] + self.train_data.print_example(output) + if self.config.dataset.val_path is not None: + self.config.dataset.split = "valid" + # may be used by meta processor. + meta_processor = meta_processor_cls(self.config.dataset) + video_processor = video_processor_cls(self.config.dataset) + text_processor = text_processor_cls(self.config.dataset) + aligner = aligner_cls(self.config.dataset) + self.val_data = MMDataset( + meta_processor, video_processor, text_processor, aligner + ) + print("val_len", len(self.val_data)) + output = self.val_data[0] + self.val_data.print_example(output) + + if self.config.dataset.split == "test": + # the following is run via lauching fairseq-validate. 
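+            # note: the `aligner` built in the train/valid branch above is
+            # reused here, so this branch assumes one of those branches ran.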
+ meta_processor = meta_processor_cls(self.config.dataset) + video_processor = video_processor_cls(self.config.dataset) + text_processor = text_processor_cls(self.config.dataset) + + self.test_data = MMDataset( + meta_processor, video_processor, text_processor, aligner + ) + print("test_len", len(self.test_data)) + output = self.test_data[0] + self.test_data.print_example(output) + + def build_model(self, checkpoint=None): + if self.model is None: + model_cls = getattr(models, self.config.model.model_cls) + self.model = model_cls(self.config) + if checkpoint is not None: + self.load_checkpoint(checkpoint) + return self.model + + def load_checkpoint(self, checkpoint): + if self.model is None: + raise ValueError("model is not initialized.") + state_dict = torch.load(checkpoint) + state_dict = self._trim_state_dict(state_dict) + self.model.load_state_dict(state_dict, strict=False) + # if it's a fp16 model, turn it back. + if next(self.model.parameters()).dtype == torch.float16: + self.model = self.model.float() + return self.model + + def _trim_state_dict(self, state_dict): + from collections import OrderedDict + + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + if "model" in state_dict: # fairseq checkpoint format. + state_dict = state_dict["model"] + ret_state_dict = OrderedDict() + for ( + key, + value, + ) in state_dict.items(): + # remove fairseq wrapper since this is a task. + if key.startswith("mmmodel"): + key = key[len("mmmodel."):] + ret_state_dict[key] = value + return ret_state_dict + + def build_loss(self): + if self.loss_fn is None and self.config.loss is not None: + loss_cls = getattr(losses, self.config.loss.loss_cls) + self.loss_fn = loss_cls() + return self.loss_fn + + def flat_subsample(self, tensor): + size = tensor.size() + if len(size) >= 2: + batch_size = size[0] * size[1] + expanded_size = ( + (batch_size,) + size[2:] if len(size) > 2 + else (batch_size,) + ) + tensor = tensor.view(expanded_size) + return tensor + + def reshape_subsample(self, sample): + if ( + hasattr(self.config.dataset, "subsampling") + and self.config.dataset.subsampling is not None + and self.config.dataset.subsampling > 1 + ): + for key in sample: + if torch.is_tensor(sample[key]): + sample[key] = self.flat_subsample(sample[key]) + return sample + + def __call__(self, model, sample): + loss = None + loss_scalar = float("inf") + + sample = self.reshape_subsample(sample) + outputs = self.model(**sample) + sample.update(outputs) + if self.loss_fn is not None: + loss = self.loss_fn(**sample) + loss_scalar = loss.item() + + batch_size = sample["caps"].size(0) + sample_size = 1 + return { + "loss": loss, + "loss_scalar": loss_scalar, + "max_len": self.config.dataset.max_len, + "batch_size": batch_size, + "sample_size": sample_size, + } + + def build_dataloader(self): + """only used for trainer that lacks building loaders.""" + raise NotImplementedError diff --git a/fairseq/examples/MMPT/mmpt/tasks/vlmtask.py b/fairseq/examples/MMPT/mmpt/tasks/vlmtask.py new file mode 100644 index 0000000000000000000000000000000000000000..57dc4c91705fdb1292f2f2accbb42acb993eb6aa --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/tasks/vlmtask.py @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch + +from .task import Task + + +class VLMTask(Task): + """A VLM task for reproducibility. + the collator split subsamples into two sub-batches. 
+ This has should have no logic changes. + but changed the randomness in frame masking. + """ + + def flat_subsample(self, tensor): + size = tensor.size() + if len(size) >= 2: + batch_size = size[0] * (size[1] // 2) + expanded_size = ( + (batch_size, 2) + size[2:] if len(size) > 2 + else (batch_size, 2) + ) + tensor = tensor.view(expanded_size) + tensor = torch.cat([tensor[:, 0], tensor[:, 1]], dim=0) + return tensor diff --git a/fairseq/examples/MMPT/mmpt/utils/__init__.py b/fairseq/examples/MMPT/mmpt/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2429ee3757353e768f71b27d129eb3ca3bcbec73 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/utils/__init__.py @@ -0,0 +1,68 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import random +import numpy as np +import torch + +from .shardedtensor import * +from .load_config import * + + +def set_seed(seed=43211): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if torch.backends.cudnn.enabled: + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + +def get_world_size(): + if torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + else: + world_size = 1 + return world_size + + +def get_local_rank(): + return torch.distributed.get_rank() \ + if torch.distributed.is_initialized() else 0 + + +def print_on_rank0(func): + local_rank = get_local_rank() + if local_rank == 0: + print("[INFO]", func) + + +class RetriMeter(object): + """ + Statistics on whether retrieval yields a better pair. + """ + def __init__(self, freq=1024): + self.freq = freq + self.total = 0 + self.replace = 0 + self.updates = 0 + + def __call__(self, data): + if isinstance(data, np.ndarray): + self.replace += data.shape[0] - int((data[:, 0] == -1).sum()) + self.total += data.shape[0] + elif torch.is_tensor(data): + self.replace += int(data.sum()) + self.total += data.size(0) + else: + raise ValueError("unsupported RetriMeter data type.", type(data)) + + self.updates += 1 + if get_local_rank() == 0 and self.updates % self.freq == 0: + print("[INFO]", self) + + def __repr__(self): + return "RetriMeter (" + str(self.replace / self.total) \ + + "/" + str(self.replace) + "/" + str(self.total) + ")" diff --git a/fairseq/examples/MMPT/mmpt/utils/load_config.py b/fairseq/examples/MMPT/mmpt/utils/load_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ede4f94117b197e5bdb551c5140604c6e47c91fb --- /dev/null +++ b/fairseq/examples/MMPT/mmpt/utils/load_config.py @@ -0,0 +1,81 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
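+# Configs are plain OmegaConf YAML files; `recursive_config` below lets a
+# config pull in a base file via an `includes:` key, e.g.
+#   includes: projects/mfmmlm.yaml
+# with the child file's values merged on top of the included one.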
+import os +import omegaconf +from omegaconf import OmegaConf + + +def load_config(args=None, config_file=None, overwrite_fairseq=False): + """TODO (huxu): move fairseq overwrite to another function.""" + if args is not None: + config_file = args.taskconfig + config = recursive_config(config_file) + + if config.dataset.subsampling is not None: + batch_size = config.fairseq.dataset.batch_size // config.dataset.subsampling + print( + "adjusting batch_size to {} due to subsampling {}.".format( + batch_size, config.dataset.subsampling + ) + ) + config.fairseq.dataset.batch_size = batch_size + + is_test = config.dataset.split is not None and config.dataset.split == "test" + if not is_test: + if ( + config.fairseq.checkpoint is None + or config.fairseq.checkpoint.save_dir is None + ): + raise ValueError("fairseq save_dir or save_path must be specified.") + + save_dir = config.fairseq.checkpoint.save_dir + os.makedirs(save_dir, exist_ok=True) + if config.fairseq.common.tensorboard_logdir is not None: + tb_run_dir = suffix_rundir( + save_dir, config.fairseq.common.tensorboard_logdir + ) + config.fairseq.common.tensorboard_logdir = tb_run_dir + print( + "update tensorboard_logdir as", config.fairseq.common.tensorboard_logdir + ) + os.makedirs(save_dir, exist_ok=True) + OmegaConf.save(config=config, f=os.path.join(save_dir, "config.yaml")) + + if overwrite_fairseq and config.fairseq is not None and args is not None: + # flatten fields. + for group in config.fairseq: + for field in config.fairseq[group]: + print("overwrite args." + field, "as", config.fairseq[group][field]) + setattr(args, field, config.fairseq[group][field]) + return config + + +def recursive_config(config_path): + """allows for stacking of configs in any depth.""" + config = OmegaConf.load(config_path) + if config.includes is not None: + includes = config.includes + config.pop("includes") + base_config = recursive_config(includes) + config = OmegaConf.merge(base_config, config) + return config + + +def suffix_rundir(save_dir, run_dir): + max_id = -1 + for search_dir in os.listdir(save_dir): + if search_dir.startswith(run_dir): + splits = search_dir.split("_") + cur_id = int(splits[1]) if len(splits) > 1 else 0 + max_id = max(max_id, cur_id) + return os.path.join(save_dir, run_dir + "_" + str(max_id + 1)) + + +def overwrite_dir(config, replace, basedir): + for key in config: + if isinstance(config[key], str) and config[key].startswith(basedir): + config[key] = config[key].replace(basedir, replace) + if isinstance(config[key], omegaconf.dictconfig.DictConfig): + overwrite_dir(config[key], replace, basedir) diff --git a/fairseq/examples/MMPT/mmpt_cli/localjob.py b/fairseq/examples/MMPT/mmpt_cli/localjob.py new file mode 100644 index 0000000000000000000000000000000000000000..2675d3511a9ca700185d2d7b853dc56ad70c638c --- /dev/null +++ b/fairseq/examples/MMPT/mmpt_cli/localjob.py @@ -0,0 +1,117 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
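+# LocalJob expands a task .yaml into a fairseq-train (or mmpt_cli/predict.py)
+# command line: the flattened `fairseq.*` fields become `--key value` flags.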
+import os + +from mmpt.utils import recursive_config + + +class BaseJob(object): + def __init__(self, yaml_file, dryrun=False): + self.yaml_file = yaml_file + self.config = recursive_config(yaml_file) + self.dryrun = dryrun + + def submit(self, **kwargs): + raise NotImplementedError + + def _normalize_cmd(self, cmd_list): + cmd_list = list(cmd_list) + yaml_index = cmd_list.index("[yaml]") + cmd_list[yaml_index] = self.yaml_file + return cmd_list + + +class LocalJob(BaseJob): + + CMD_CONFIG = { + "local_single": [ + "fairseq-train", "[yaml]", "--user-dir", "mmpt", + "--task", "mmtask", "--arch", "mmarch", + "--criterion", "mmloss", + ], + "local_small": [ + "fairseq-train", "[yaml]", "--user-dir", "mmpt", + "--task", "mmtask", "--arch", "mmarch", + "--criterion", "mmloss", + "--distributed-world-size", "2" + ], + "local_big": [ + "fairseq-train", "[yaml]", "--user-dir", "mmpt", + "--task", "mmtask", "--arch", "mmarch", + "--criterion", "mmloss", + "--distributed-world-size", "8" + ], + "local_predict": ["python", "mmpt_cli/predict.py", "[yaml]"], + } + + def __init__(self, yaml_file, job_type=None, dryrun=False): + super().__init__(yaml_file, dryrun) + if job_type is None: + self.job_type = "local_single" + if self.config.task_type is not None: + self.job_type = self.config.task_type + else: + self.job_type = job_type + if self.job_type in ["local_single", "local_small"]: + if self.config.fairseq.dataset.batch_size > 32: + print("decreasing batch_size to 32 for local testing?") + + def submit(self): + cmd_list = self._normalize_cmd(LocalJob.CMD_CONFIG[self.job_type]) + if "predict" not in self.job_type: + # append fairseq args. + from mmpt.utils import load_config + + config = load_config(config_file=self.yaml_file) + for field in config.fairseq: + for key in config.fairseq[field]: + if key in ["fp16", "reset_optimizer", "reset_dataloader", "reset_meters"]: # a list of binary flag. + param = ["--" + key.replace("_", "-")] + else: + if key == "lr": + value = str(config.fairseq[field][key][0]) + elif key == "adam_betas": + value = "'"+str(config.fairseq[field][key])+"'" + else: + value = str(config.fairseq[field][key]) + param = [ + "--" + key.replace("_", "-"), + value + ] + cmd_list.extend(param) + + print("launching", " ".join(cmd_list)) + if not self.dryrun: + os.system(" ".join(cmd_list)) + return JobStatus("12345678") + + +class JobStatus(object): + def __init__(self, job_id): + self.job_id = job_id + + def __repr__(self): + return self.job_id + + def __str__(self): + return self.job_id + + def done(self): + return False + + def running(self): + return False + + def result(self): + if self.done(): + return "{} is done.".format(self.job_id) + else: + return "{} is running.".format(self.job_id) + + def stderr(self): + return self.result() + + def stdout(self): + return self.result() diff --git a/fairseq/examples/MMPT/mmpt_cli/predict.py b/fairseq/examples/MMPT/mmpt_cli/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..4071e196d211f7b11170db2e7e35b716d3deeb69 --- /dev/null +++ b/fairseq/examples/MMPT/mmpt_cli/predict.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
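+# Stand-alone evaluation entry point, launched as
+#   python mmpt_cli/predict.py <taskconfig.yaml>
+# (this is what LocalJob's "local_predict" command runs).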
+import os +import glob +import argparse +import pprint +import omegaconf + +from omegaconf import OmegaConf +from torch.utils.data import DataLoader + +from mmpt.utils import load_config, set_seed +from mmpt.evaluators import Evaluator +from mmpt.evaluators import predictor as predictor_path +from mmpt.tasks import Task +from mmpt import processors +from mmpt.datasets import MMDataset + + +def get_dataloader(config): + meta_processor_cls = getattr(processors, config.dataset.meta_processor) + video_processor_cls = getattr(processors, config.dataset.video_processor) + text_processor_cls = getattr(processors, config.dataset.text_processor) + aligner_cls = getattr(processors, config.dataset.aligner) + + meta_processor = meta_processor_cls(config.dataset) + video_processor = video_processor_cls(config.dataset) + text_processor = text_processor_cls(config.dataset) + aligner = aligner_cls(config.dataset) + + test_data = MMDataset( + meta_processor, + video_processor, + text_processor, + aligner, + ) + print("test_len", len(test_data)) + output = test_data[0] + test_data.print_example(output) + + test_dataloader = DataLoader( + test_data, + batch_size=config.fairseq.dataset.batch_size, + shuffle=False, + num_workers=6, + collate_fn=test_data.collater, + ) + return test_dataloader + + +def main(args): + config = load_config(args) + + if isinstance(config, omegaconf.dictconfig.DictConfig): + print(OmegaConf.to_yaml(config)) + else: + pp = pprint.PrettyPrinter(indent=4) + pp.print(config) + + mmtask = Task.config_task(config) + mmtask.build_model() + + test_dataloader = get_dataloader(config) + checkpoint_search_path = os.path.dirname(config.eval.save_path) + results = [] + + prefix = os.path.basename(args.taskconfig) + if prefix.startswith("test"): + # loop all checkpoint for datasets without validation set. + if "best" not in config.fairseq.common_eval.path: + print("eval each epoch.") + for checkpoint in glob.glob(checkpoint_search_path + "/checkpoint*"): + model = mmtask.load_checkpoint(checkpoint) + ckpt = os.path.basename(checkpoint) + evaluator = Evaluator(config) + output = evaluator.evaluate( + model, test_dataloader, ckpt + "_merged") + results.append((checkpoint, output)) + # use the one specified by the config lastly. + model = mmtask.load_checkpoint(config.fairseq.common_eval.path) + evaluator = Evaluator(config) + output = evaluator.evaluate(model, test_dataloader) + results.append((config.fairseq.common_eval.path, output)) + + best_result = None + best_metric = 0. 
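+        # pick the checkpoint whose `best_metric` score is highest among all
+        # evaluated checkpoints and report its full metrics.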
+ for checkpoint, result in results: + print(checkpoint) + evaluator.metric.print_computed_metrics(result) + best_score = evaluator.metric.best_metric(result) + if best_score > best_metric: + best_result = (checkpoint, result) + best_metric = best_score + print("best results:") + print(best_result[0]) + evaluator.metric.print_computed_metrics(best_result[1]) + + elif prefix.startswith("vis"): + model = mmtask.load_checkpoint(config.fairseq.common_eval.path) + predictor_cls = getattr(predictor_path, config.predictor) + predictor = predictor_cls(config) + predictor.predict_loop(model, test_dataloader, mmtask, None) + else: + raise ValueError("unknown prefix of the config file", args.taskconfig) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("taskconfig", type=str) + args = parser.parse_args() + main(args) diff --git a/fairseq/examples/MMPT/projects/mtm/mmfusionmtm.yaml b/fairseq/examples/MMPT/projects/mtm/mmfusionmtm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..337d66a2aadcff86da033d6656bdc570a20c32e6 --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/mmfusionmtm.yaml @@ -0,0 +1,19 @@ +includes: projects/mfmmlm.yaml +project_dir: mtm/mmfusionmtm +task_group: + pretrain: + task: VLMTask # reproducible + dataset: + aligner: MFMMLMAligner + model: + use_seg_emb: True # reproducible + model_cls: MMFusionMTM + mm_encoder_cls: MMBertForMFMMLM + loss: + loss_cls: MTM + finetune: + model: + use_seg_emb: True # reproducible + test: + model: + use_seg_emb: True # reproducible diff --git a/fairseq/examples/MMPT/projects/mtm/vlm/coin.yaml b/fairseq/examples/MMPT/projects/mtm/vlm/coin.yaml new file mode 100644 index 0000000000000000000000000000000000000000..48fd64a5f4996a03b412c418e83c4672b7252e8d --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm/coin.yaml @@ -0,0 +1,47 @@ +dataset: + video_processor: VideoProcessor + bert_name: bert-base-uncased + meta_processor: COINActionSegmentationMetaProcessor + train_path: data/coin/COIN.json + val_path: data/coin/COIN.json + vfeat_dir: data/feat/feat_coin_s3d + text_processor: COINActionSegmentationTextProcessor + aligner: COINActionSegmentationAligner + num_iso_layer: 12 + sliding_window: 8 + sliding_window_size: 32 + max_video_len: 32 + max_len: 96 +fairseq: + common: + tensorboard_logdir: run + log_interval: 1000 + fp16: true + dataset: + num_workers: 4 + batch_size: 1 + optimization: + lr: + - 5.0e-05 + clip_norm: 2.0 + optimizer: adam + adam_betas: (0.9, 0.98) + lr_scheduler: polynomial_decay + total_num_update: 1000000 + warmup_updates: 122 + weight_decay: 0.0 + ddp_backend: no_c10d + max_epoch: 8 + checkpoint: + restore_file: runs/mtm/vlm/checkpoint_best.pt + reset_optimizer: true + reset_dataloader: true + reset_meters: true + save_dir: runs/mtm/vlm/coin +task_type: sweep_big +model: + model_cls: MMFusionActionSegmentation + mm_encoder_cls: MMBertForTokenClassification + use_seg_emb: true +loss: + loss_cls: CrossEntropy diff --git a/fairseq/examples/MMPT/projects/mtm/vlm/how2.yaml b/fairseq/examples/MMPT/projects/mtm/vlm/how2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ca40ad815d9393bc9dc4d248cf7ec8810f1011f --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm/how2.yaml @@ -0,0 +1,55 @@ +dataset: + video_processor: ShardedVideoProcessor + bert_name: bert-base-uncased + meta_processor: ShardedHow2MetaProcessor + train_path: data/how2/how2_s3d_train.lst + val_path: data/how2/how2_s3d_val.lst + vfeat_dir: 
data/feat/feat_how2_s3d_shard_small + text_processor: ShardedTextProcessor + tfeat_dir: data/feat/feat_how2_s3d_shard_small/raw_caption_dedup.bert-base-uncased. + aligner: MFMMLMAligner + subsampling: 32 + sampled_min_len: 8 + sampled_max_len: 64 + max_video_len: 32 + max_len: 96 + lazy_vfeat_mask: true + mfm_probability: 0.15 + mlm_probability: 0.15 + mm_prob: 0.5 +fairseq: + common: + tensorboard_logdir: run + log_interval: 1000 + fp16: true + dataset: + num_workers: 4 + batch_size: 256 + optimization: + lr: + - 5.0e-05 + clip_norm: 2.0 + optimizer: adam + adam_betas: (0.9, 0.98) + lr_scheduler: polynomial_decay + total_num_update: 1000000 + warmup_updates: 1000 + weight_decay: 0.0 + ddp_backend: no_c10d + max_epoch: 15 + checkpoint: + save_dir: runs/mtm/vlm + save_interval_updates: 1024 + keep_interval_updates: 2 + keep_last_epochs: 30 +task_type: sweep_big +slurm_config: big +eval: + save_path: runs/mtm/vlm +model: + model_cls: MMFusionMTM + mm_encoder_cls: MMBertForMFMMLM + use_seg_emb: true +loss: + loss_cls: MTM +task: VLMTask diff --git a/fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask.yaml b/fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d159847875f74050fbae5f0c223fa6d56bbd3d38 --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask.yaml @@ -0,0 +1,38 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: CrossTaskVideoProcessor + aligner: CrossTaskAligner + bert_name: bert-base-uncased + meta_processor: CrossTaskMetaProcessor + test_path: data/crosstask/crosstask_release/videos_val.csv + train_csv_path: data/crosstask/crosstask_release/videos.csv + val_path: data/crosstask/crosstask_release/videos_val.csv + val_csv_path: data/crosstask/crosstask_release/videos_val.csv + primary_path: data/crosstask/crosstask_release/tasks_primary.txt + related_path: data/crosstask/crosstask_release/tasks_related.txt + vfeat_dir: data/feat/feat_crosstask_s3d + annotation_path: data/crosstask/crosstask_release/annotations + n_train: 30 + text_processor: CrossTaskTextProcessor + num_iso_layer: 12 + sliding_window: 16 + sliding_window_size: 32 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 1 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/mtm/vlm/crosstask/checkpoint_best.pt +model: + model_cls: MMFusionActionLocalization + mm_encoder_cls: MMBertForJoint + use_seg_emb: true +eval: + save_path: runs/mtm/vlm/crosstask/eval +metric: CrossTaskMetric +predictor: CrossTaskPredictor diff --git a/fairseq/examples/MMPT/projects/mtm/vlm/test_vtt.yaml b/fairseq/examples/MMPT/projects/mtm/vlm/test_vtt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a41557df6af496c755b0b32345a736b1cdc01ce4 --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm/test_vtt.yaml @@ -0,0 +1,29 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: VideoProcessor + aligner: DSAligner + bert_name: bert-base-uncased + meta_processor: MSRVTTMetaProcessor + test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv + vfeat_dir: data/feat/feat_vtt_s3d + text_processor: MSRVTTTextProcessor + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 256 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/mtm/vlm/vtt/checkpoint_last.pt +model: + model_cls: MMFusionJoint + mm_encoder_cls: MMBertForJoint + use_seg_emb: true +eval: + save_path: runs/mtm/vlm/vtt/eval 
+metric: RetrievalMetric +predictor: RetrievalPredictor diff --git a/fairseq/examples/MMPT/projects/mtm/vlm/test_vttqa.yaml b/fairseq/examples/MMPT/projects/mtm/vlm/test_vttqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..abf3309f7072c88a6b7c72e2540d4950d0e55575 --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm/test_vttqa.yaml @@ -0,0 +1,29 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: VideoProcessor + aligner: MSRVTTQAAligner + bert_name: bert-base-uncased + meta_processor: MSRVTTQAMetaProcessor + test_path: data/msrvtt-qa/MSR_MC_test.csv + vfeat_dir: data/feat/feat_vtt_s3d + text_processor: MSRVTTQATextProcessor + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 256 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/mtm/vlm/vttqa/checkpoint_last.pt +model: + model_cls: MMFusionJoint + mm_encoder_cls: MMBertForJoint + use_seg_emb: true +eval: + save_path: runs/mtm/vlm/vttqa/eval +metric: QAMetric +predictor: QAPredictor diff --git a/fairseq/examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml b/fairseq/examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b2595d7c3c633d0faac5e4891457067659f7fafc --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml @@ -0,0 +1,32 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: YoucookVideoProcessor + aligner: DSNLGAligner + bert_name: bert-base-uncased + meta_processor: YoucookNLGMetaProcessor + test_path: data/youcook/val_list.txt + trainval_annotation: data/youcook/youcookii_annotations_trainval.json + vfeat_dir: data/feat/feat_youcook_s3d + text_processor: NLGTextProcessor + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 256 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/mtm/vlm/youcookcap/checkpoint_best.pt +model: + model_cls: MMFusionNLG + mm_encoder_cls: MMBertForNLG + max_decode_length: 24 + use_seg_emb: true +eval: + save_path: runs/mtm/vlm/youcookcap/eval +metric: NLGMetric +predictor: NLGPredictor +gen_param: + num_beams: 5 diff --git a/fairseq/examples/MMPT/projects/mtm/vlm/vtt.yaml b/fairseq/examples/MMPT/projects/mtm/vlm/vtt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c6c5b1ab40cbd2256919b42ea6b6a5fa52fe41a1 --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm/vtt.yaml @@ -0,0 +1,49 @@ +dataset: + video_processor: VideoProcessor + bert_name: bert-base-uncased + meta_processor: MSRVTTMetaProcessor + train_path: data/msrvtt/MSRVTT_train.csv + jsfusion_path: data/msrvtt/MSRVTT_JSFUSION_test.csv + full_test_path: data/msrvtt/MSRVTT_FULL_test.csv + dup: 20 + val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv + vfeat_dir: data/feat/feat_vtt_s3d + text_processor: MSRVTTTextProcessor + json_path: data/msrvtt/MSRVTT_data.json + aligner: DSAligner + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + common: + tensorboard_logdir: run + log_interval: 1000 + fp16: true + dataset: + num_workers: 4 + batch_size: 256 + optimization: + lr: + - 5.0e-05 + clip_norm: 2.0 + optimizer: adam + adam_betas: (0.9, 0.98) + lr_scheduler: polynomial_decay + total_num_update: 1000000 + warmup_updates: 122 + weight_decay: 0.0 + ddp_backend: no_c10d + max_epoch: 10 + checkpoint: + restore_file: runs/mtm/vlm/checkpoint_best.pt + reset_optimizer: true + reset_dataloader: true + reset_meters: true + save_dir: runs/mtm/vlm/vtt 
+task_type: sweep_small +model: + model_cls: MMFusionJoint + mm_encoder_cls: MMBertForJoint + use_seg_emb: true +loss: + loss_cls: T2VContraLoss diff --git a/fairseq/examples/MMPT/projects/mtm/vlm/vttqa.yaml b/fairseq/examples/MMPT/projects/mtm/vlm/vttqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0a440c7dd2f5b324d57367456d22ca5aab8f397f --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm/vttqa.yaml @@ -0,0 +1,47 @@ +dataset: + video_processor: VideoProcessor + bert_name: bert-base-uncased + meta_processor: MSRVTTMetaProcessor + train_path: data/msrvtt/MSRVTT_train.csv + dup: 20 + val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv + vfeat_dir: data/feat/feat_vtt_s3d + text_processor: MSRVTTTextProcessor + json_path: data/msrvtt/MSRVTT_data.json + aligner: DSAligner + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + common: + tensorboard_logdir: run + log_interval: 1000 + fp16: true + dataset: + num_workers: 4 + batch_size: 128 + optimization: + lr: + - 5.0e-05 + clip_norm: 2.0 + optimizer: adam + adam_betas: (0.9, 0.98) + lr_scheduler: polynomial_decay + total_num_update: 1000000 + warmup_updates: 122 + weight_decay: 0.0 + ddp_backend: no_c10d + max_epoch: 5 + checkpoint: + restore_file: runs/mtm/vlm/checkpoint_best.pt + reset_optimizer: true + reset_dataloader: true + reset_meters: true + save_dir: runs/mtm/vlm/vttqa +task_type: sweep_small +model: + model_cls: MMFusionJoint + mm_encoder_cls: MMBertForJoint + use_seg_emb: true +loss: + loss_cls: V2TContraLoss diff --git a/fairseq/examples/MMPT/projects/mtm/vlm/youcook.yaml b/fairseq/examples/MMPT/projects/mtm/vlm/youcook.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ee82b81b8b14a9f689e73bac22f21b780640984 --- /dev/null +++ b/fairseq/examples/MMPT/projects/mtm/vlm/youcook.yaml @@ -0,0 +1,47 @@ +dataset: + video_processor: YoucookVideoProcessor + bert_name: bert-base-uncased + meta_processor: YoucookMetaProcessor + train_path: data/youcook/youcook_train.pkl + val_path: data/youcook/youcook_val.pkl + trainval_annotation: data/youcook/youcookii_annotations_trainval.json + use_annotation_text: true + vfeat_dir: data/feat/feat_youcook_s3d + text_processor: TextProcessor + aligner: DSAligner + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + common: + tensorboard_logdir: run + log_interval: 1000 + fp16: true + dataset: + num_workers: 4 + batch_size: 128 + optimization: + lr: + - 5.0e-05 + clip_norm: 2.0 + optimizer: adam + adam_betas: (0.9, 0.98) + lr_scheduler: polynomial_decay + total_num_update: 1000000 + warmup_updates: 122 + weight_decay: 0.0 + ddp_backend: no_c10d + max_epoch: 10 + checkpoint: + restore_file: runs/mtm/vlm/checkpoint_best.pt + reset_optimizer: true + reset_dataloader: true + reset_meters: true + save_dir: runs/mtm/vlm/youcook +task_type: sweep_small +model: + model_cls: MMFusionJoint + mm_encoder_cls: MMBertForJoint + use_seg_emb: true +loss: + loss_cls: T2VContraLoss diff --git a/fairseq/examples/MMPT/projects/retri/videoclip.yaml b/fairseq/examples/MMPT/projects/retri/videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..afd040ab050554aa0ca9c65ab4cf36cf6f3155dd --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip.yaml @@ -0,0 +1,10 @@ +includes: projects/retri/videoretri.yaml +project_dir: retri/videoclip +task_group: + pretrain: + model: + model_cls: MMFusionSeparate + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 diff 
--git a/fairseq/examples/MMPT/projects/retri/videoclip/coin_videoclip.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/coin_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aaed5e47f62d421e5c434a2f539437831ae387db --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/coin_videoclip.yaml @@ -0,0 +1,49 @@ +dataset: + video_processor: VideoProcessor + bert_name: bert-base-uncased + meta_processor: COINActionSegmentationMetaProcessor + train_path: data/coin/COIN.json + val_path: data/coin/COIN.json + vfeat_dir: data/feat/feat_coin_s3d + text_processor: COINActionSegmentationTextProcessor + aligner: COINActionSegmentationAligner + num_iso_layer: 12 + sliding_window: 8 + sliding_window_size: 32 + max_video_len: 32 + max_len: 96 +fairseq: + common: + tensorboard_logdir: run + log_interval: 1000 + fp16: true + dataset: + num_workers: 4 + batch_size: 1 + optimization: + lr: + - 5.0e-05 + clip_norm: 2.0 + optimizer: adam + adam_betas: (0.9, 0.98) + lr_scheduler: polynomial_decay + total_num_update: 1000000 + warmup_updates: 122 + weight_decay: 0.0 + ddp_backend: no_c10d + max_epoch: 8 + checkpoint: + restore_file: runs/retri/videoclip/checkpoint_best.pt + reset_optimizer: true + reset_dataloader: true + reset_meters: true + save_dir: runs/retri/videoclip/coin +task_type: sweep_big +model: + model_cls: MMFusionSeparateActionSegmentation + mm_encoder_cls: null + video_encoder_cls: MMBertForTokenClassification + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +loss: + loss_cls: CrossEntropy diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/crosstask_videoclip.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/crosstask_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..758601e3593aa3960fee996dcfdfd71972f9f068 --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/crosstask_videoclip.yaml @@ -0,0 +1,55 @@ +dataset: + video_processor: CrossTaskVideoProcessor + bert_name: bert-base-uncased + meta_processor: CrossTaskMetaProcessor + train_path: data/crosstask/crosstask_release/videos.csv + train_csv_path: data/crosstask/crosstask_release/videos.csv + val_path: data/crosstask/crosstask_release/videos_val.csv + val_csv_path: data/crosstask/crosstask_release/videos_val.csv + primary_path: data/crosstask/crosstask_release/tasks_primary.txt + related_path: data/crosstask/crosstask_release/tasks_related.txt + vfeat_dir: data/feat/feat_crosstask_s3d + annotation_path: data/crosstask/crosstask_release/annotations + n_train: 30 + text_processor: CrossTaskTextProcessor + aligner: CrossTaskAligner + num_iso_layer: 12 + sliding_window: 16 + sliding_window_size: 32 + max_video_len: 32 + max_len: 96 +fairseq: + common: + tensorboard_logdir: run + log_interval: 1000 + fp16: true + dataset: + num_workers: 4 + batch_size: 1 + optimization: + lr: + - 5.0e-05 + clip_norm: 2.0 + optimizer: adam + adam_betas: (0.9, 0.98) + lr_scheduler: polynomial_decay + total_num_update: 1000000 + warmup_updates: 122 + weight_decay: 0.0 + ddp_backend: no_c10d + max_epoch: 5 + checkpoint: + restore_file: runs/retri/videoclip/checkpoint_best.pt + reset_optimizer: true + reset_dataloader: true + reset_meters: true + save_dir: runs/retri/videoclip/crosstask +task_type: sweep_small +model: + model_cls: MMFusionSeparateActionLocalization + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +loss: + loss_cls: BCE diff --git 
a/fairseq/examples/MMPT/projects/retri/videoclip/how2.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/how2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b49581e878390cd8861926988a58a892b5a38606 --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/how2.yaml @@ -0,0 +1,65 @@ +dataset: + video_processor: ShardedVideoRetriVideoProcessor + bert_name: bert-base-uncased + meta_processor: ShardedHow2VideoRetriMetaProcessor + train_path: data/how2/how2_s3d_train.lst + val_path: data/how2/how2_s3d_val.lst + vfeat_dir: data/feat/feat_how2_s3d_shard_small + text_processor: ShardedVideoRetriTextProcessor + tfeat_dir: data/feat/feat_how2_s3d_shard_small/raw_caption_dedup.bert-base-uncased. + aligner: VideoRetriOverlappedAligner + subsampling: 1 + sampled_min_len: 8 + sampled_max_len: 64 + max_video_len: 32 + max_len: 96 + lazy_vfeat_mask: true + mfm_probability: 0.15 + mlm_probability: 0.15 + mm_prob: 0.5 + sampled_video_min_len: 3 + sampled_video_max_len: 32 + num_video_per_batch: 32 + clip_per_video: 16 +fairseq: + common: + tensorboard_logdir: run + log_interval: 1000 + fp16: true + dataset: + num_workers: 4 + batch_size: 1 + optimization: + lr: + - 5.0e-05 + clip_norm: 2.0 + optimizer: adam + adam_betas: (0.9, 0.98) + lr_scheduler: polynomial_decay + total_num_update: 1000000 + warmup_updates: 1000 + weight_decay: 0.0 + ddp_backend: no_c10d + max_epoch: 25 + checkpoint: + save_dir: runs/retri/videoclip + save_interval_updates: 1024 + keep_interval_updates: 2 + keep_last_epochs: 30 +task_type: sweep_big +slurm_config: big +eval: + save_path: runs/retri/videoclip +model: + model_cls: MMFusionSeparate + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +loss: + loss_cls: MMContraLoss +task: VideoRetriTask +retri_epoch: 1 +vectorpool_cls: VideoVectorPool +retriever_cls: VectorRetriever +num_cands: 64 diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/test_coin_videoclip.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/test_coin_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..409906203c22d670c361d011d1035874961d4b5b --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/test_coin_videoclip.yaml @@ -0,0 +1,33 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: VideoProcessor + aligner: COINActionSegmentationAligner + bert_name: bert-base-uncased + test_path: data/coin/COIN.json + meta_processor: COINActionSegmentationMetaProcessor + vfeat_dir: data/feat/feat_coin_s3d + text_processor: COINActionSegmentationTextProcessor + num_iso_layer: 12 + sliding_window: 16 + sliding_window_size: 32 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 1 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/retri/videoclip/coin/checkpoint_best.pt +model: + model_cls: MMFusionSeparateActionSegmentation + mm_encoder_cls: null + video_encoder_cls: MMBertForTokenClassification + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/retri/videoclip/coin/eval +metric: COINActionSegmentationMetric +predictor: COINPredictor diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/test_coin_zs.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/test_coin_zs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b33739c7b6d7c4d935c0e4d78354a02f79e21d91 --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/test_coin_zs.yaml @@ 
-0,0 +1,33 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: VideoProcessor + aligner: COINActionSegmentationAligner + bert_name: bert-base-uncased + test_path: data/coin/COIN.json + meta_processor: COINActionSegmentationMetaProcessor + vfeat_dir: data/feat/feat_coin_s3d + text_processor: COINActionSegmentationTextProcessor + num_iso_layer: 12 + sliding_window: 16 + sliding_window_size: 32 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 1 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/retri/videoclip/checkpoint_best.pt +model: + model_cls: MMFusionSeparate + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/retri/videoclip/coin_zs/eval +metric: COINActionSegmentationMetric +predictor: COINZSPredictor diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_videoclip.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e82f54fbe52defa233f1b561415737924dd6ba9c --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_videoclip.yaml @@ -0,0 +1,40 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: CrossTaskVideoProcessor + aligner: CrossTaskAligner + bert_name: bert-base-uncased + meta_processor: CrossTaskMetaProcessor + test_path: data/crosstask/crosstask_release/videos_val.csv + train_csv_path: data/crosstask/crosstask_release/videos.csv + val_path: data/crosstask/crosstask_release/videos_val.csv + val_csv_path: data/crosstask/crosstask_release/videos_val.csv + primary_path: data/crosstask/crosstask_release/tasks_primary.txt + related_path: data/crosstask/crosstask_release/tasks_related.txt + vfeat_dir: data/feat/feat_crosstask_s3d + annotation_path: data/crosstask/crosstask_release/annotations + n_train: 30 + text_processor: CrossTaskTextProcessor + num_iso_layer: 12 + sliding_window: 16 + sliding_window_size: 32 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 1 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/retri/videoclip/crosstask/checkpoint_best.pt +model: + model_cls: MMFusionSeparateActionLocalization + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/retri/videoclip/crosstask/eval +metric: CrossTaskMetric +predictor: CrossTaskPredictor diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_zs_videoclip.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_zs_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6fc357cc1f5285229a508a95dc31c2aae17c977d --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_zs_videoclip.yaml @@ -0,0 +1,40 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: CrossTaskVideoProcessor + aligner: CrossTaskAligner + bert_name: bert-base-uncased + meta_processor: CrossTaskMetaProcessor + test_path: data/crosstask/crosstask_release/videos_val.csv + train_csv_path: data/crosstask/crosstask_release/videos.csv + val_path: data/crosstask/crosstask_release/videos_val.csv + val_csv_path: data/crosstask/crosstask_release/videos_val.csv + primary_path: data/crosstask/crosstask_release/tasks_primary.txt + related_path: 
data/crosstask/crosstask_release/tasks_related.txt + vfeat_dir: data/feat/feat_crosstask_s3d + annotation_path: data/crosstask/crosstask_release/annotations + n_train: 30 + text_processor: CrossTaskTextProcessor + num_iso_layer: 12 + sliding_window: 16 + sliding_window_size: 32 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 1 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/retri/videoclip/checkpoint_best.pt +model: + model_cls: MMFusionSeparateActionLocalization + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/retri/videoclip/crosstask_zs/eval +metric: CrossTaskMetric +predictor: CrossTaskPredictor diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_videoclip.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..19321ad5f489121381b1b78506b38ba3705b37e4 --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_videoclip.yaml @@ -0,0 +1,31 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: VideoProcessor + aligner: DSAligner + bert_name: bert-base-uncased + meta_processor: MSRVTTMetaProcessor + test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv + vfeat_dir: data/feat/feat_vtt_s3d + text_processor: MSRVTTTextProcessor + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 256 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/retri/videoclip/vtt/checkpoint_last.pt +model: + model_cls: MMFusionSeparate + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/retri/videoclip/vtt/eval +metric: RetrievalMetric +predictor: RetrievalPredictor diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_zs.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_zs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d149fa3960294f4d9dcac9c4f4f18631eb622435 --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_zs.yaml @@ -0,0 +1,31 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: VideoProcessor + aligner: DSAligner + bert_name: bert-base-uncased + meta_processor: MSRVTTMetaProcessor + test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv + vfeat_dir: data/feat/feat_vtt_s3d + text_processor: MSRVTTTextProcessor + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 256 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/retri/videoclip/checkpoint_best.pt +model: + model_cls: MMFusionSeparate + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/retri/videoclip/vtt_zs/eval +metric: RetrievalMetric +predictor: RetrievalPredictor diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_videoclip.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..295aeedbb0f29a28ad9065da0a18e82ac98b6683 --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_videoclip.yaml @@ -0,0 +1,31 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: VideoProcessor + aligner: MSRVTTQAAligner + bert_name: 
bert-base-uncased + meta_processor: MSRVTTQAMetaProcessor + test_path: data/msrvtt-qa/MSR_MC_test.csv + vfeat_dir: data/feat/feat_vtt_s3d + text_processor: MSRVTTQATextProcessor + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 256 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/retri/videoclip/vttqa/checkpoint_last.pt +model: + model_cls: MMFusionSeparate + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/retri/videoclip/vttqa/eval +metric: QAMetric +predictor: QAPredictor diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_zs.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_zs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7a876c822ae38cadc03c69544e64025d516e8cca --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_zs.yaml @@ -0,0 +1,31 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: VideoProcessor + aligner: MSRVTTQAAligner + bert_name: bert-base-uncased + meta_processor: MSRVTTQAMetaProcessor + test_path: data/msrvtt-qa/MSR_MC_test.csv + vfeat_dir: data/feat/feat_vtt_s3d + text_processor: MSRVTTQATextProcessor + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 256 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/retri/videoclip/checkpoint_best.pt +model: + model_cls: MMFusionSeparate + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/retri/videoclip/vttqa_zs/eval +metric: QAMetric +predictor: QAPredictor diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_videoclip.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..86a4ab203e097ec156ce02d88ce9c1b2374cec9f --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_videoclip.yaml @@ -0,0 +1,33 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: YoucookVideoProcessor + aligner: DSAligner + bert_name: bert-base-uncased + meta_processor: YoucookMetaProcessor + test_path: data/youcook/youcook_val.pkl + trainval_annotation: data/youcook/youcookii_annotations_trainval.json + use_annotation_text: true + vfeat_dir: data/feat/feat_youcook_s3d + text_processor: TextProcessor + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 256 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/retri/videoclip/youcook/checkpoint_last.pt +model: + model_cls: MMFusionSeparate + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/retri/videoclip/youcook/eval +metric: RetrievalMetric +predictor: RetrievalPredictor diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_zs.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_zs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd2941708be0baacf9e12ca757e0de40676e570d --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_zs.yaml @@ -0,0 +1,33 @@ +slurm_config: big +task_type: local_predict +dataset: + split: test + video_processor: YoucookVideoProcessor + aligner: DSAligner + bert_name: 
bert-base-uncased + meta_processor: YoucookMetaProcessor + test_path: data/youcook/youcook_val.pkl + trainval_annotation: data/youcook/youcookii_annotations_trainval.json + use_annotation_text: true + vfeat_dir: data/feat/feat_youcook_s3d + text_processor: TextProcessor + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + dataset: + batch_size: 256 + valid_subset: test + num_workers: 2 + common_eval: + path: runs/retri/videoclip/checkpoint_best.pt +model: + model_cls: MMFusionSeparate + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/retri/videoclip/youcook_zs/eval +metric: RetrievalMetric +predictor: RetrievalPredictor diff --git a/fairseq/examples/MMPT/projects/retri/videoclip/vtt_videoclip.yaml b/fairseq/examples/MMPT/projects/retri/videoclip/vtt_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d8b4079ac233af17ee16cfed7f3e2316e0e236c5 --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoclip/vtt_videoclip.yaml @@ -0,0 +1,51 @@ +dataset: + video_processor: VideoProcessor + bert_name: bert-base-uncased + meta_processor: MSRVTTMetaProcessor + train_path: data/msrvtt/MSRVTT_train.csv + jsfusion_path: data/msrvtt/MSRVTT_JSFUSION_test.csv + full_test_path: data/msrvtt/MSRVTT_FULL_test.csv + dup: 20 + val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv + vfeat_dir: data/feat/feat_vtt_s3d + text_processor: MSRVTTTextProcessor + json_path: data/msrvtt/MSRVTT_data.json + aligner: DSAligner + num_iso_layer: 12 + max_video_len: 32 + max_len: 96 +fairseq: + common: + tensorboard_logdir: run + log_interval: 1000 + fp16: true + dataset: + num_workers: 4 + batch_size: 224 + optimization: + lr: + - 5.0e-05 + clip_norm: 2.0 + optimizer: adam + adam_betas: (0.9, 0.98) + lr_scheduler: polynomial_decay + total_num_update: 1000000 + warmup_updates: 122 + weight_decay: 0.0 + ddp_backend: no_c10d + max_epoch: 10 + checkpoint: + restore_file: runs/retri/videoclip/checkpoint_best.pt + reset_optimizer: true + reset_dataloader: true + reset_meters: true + save_dir: runs/retri/videoclip/vtt +task_type: sweep_small +model: + model_cls: MMFusionSeparate + mm_encoder_cls: null + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +loss: + loss_cls: T2VContraLoss diff --git a/fairseq/examples/MMPT/projects/retri/videoretri.yaml b/fairseq/examples/MMPT/projects/retri/videoretri.yaml new file mode 100644 index 0000000000000000000000000000000000000000..969e1fb2793d5381e52188c4ad78aea240698969 --- /dev/null +++ b/fairseq/examples/MMPT/projects/retri/videoretri.yaml @@ -0,0 +1,51 @@ +includes: projects/mfmmlm.yaml +project_dir: retri/videoretri +run_task: + - how2.yaml +task_group: + pretrain: + task: VideoRetriTask + retri_epoch: 1 + vectorpool_cls: VideoVectorPool + retriever_cls: VectorRetriever + num_cands: 64 + dataset: + train_path: data/how2/how2_s3d_train.lst + meta_processor: ShardedHow2VideoRetriMetaProcessor + video_processor: ShardedVideoRetriVideoProcessor + text_processor: ShardedVideoRetriTextProcessor + aligner: VideoRetriOverlappedAligner + sampled_video_min_len: 3 + sampled_video_max_len: 32 + sampled_min_len: 8 + sampled_max_len: 64 + num_video_per_batch: 32 + # do not use subsampling as it changes fairseq batch_size. 
+ subsampling: 1 # disable subsampling + clip_per_video: 16 + fairseq: + dataset: + batch_size: 1 + optimization: + max_epoch: 25 + model: + model_cls: MMFusionShare + mm_encoder_cls: MMBertForEncoder + loss: + loss_cls: MMContraLoss + finetune: + task_list: [vtt_videoclip.yaml, youcook_videoclip.yaml, vttqa_videoclip.yaml, crosstask_videoclip.yaml, coin_videoclip.yaml] + test: + task_list: + - test_youcook_zs.yaml + - test_vtt_zs.yaml + - test_vttqa_zs.yaml + - test_crosstask_zs_videoclip.yaml + - test_coin_zs.yaml + - test_didemo_zs.yaml + - test_youcook_videoclip.yaml + - test_vtt_videoclip.yaml + - test_vttqa_videoclip.yaml + - test_crosstask_videoclip.yaml + - test_coin_videoclip.yaml + diff --git a/fairseq/examples/MMPT/projects/task/crosstask.yaml b/fairseq/examples/MMPT/projects/task/crosstask.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cb4dbb0cb4a07020d1fbba9a810eb30fd81ae96d --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/crosstask.yaml @@ -0,0 +1,31 @@ +includes: projects/task/ft.yaml +dataset: + meta_processor: CrossTaskMetaProcessor + train_path: data/crosstask/crosstask_release/videos.csv # dummy + train_csv_path: data/crosstask/crosstask_release/videos.csv + val_path: data/crosstask/crosstask_release/videos_val.csv # dummy + val_csv_path: data/crosstask/crosstask_release/videos_val.csv + primary_path: data/crosstask/crosstask_release/tasks_primary.txt + related_path: data/crosstask/crosstask_release/tasks_related.txt + vfeat_dir: data/feat/feat_crosstask_s3d + annotation_path: data/crosstask/crosstask_release/annotations + n_train: 30 + video_processor: CrossTaskVideoProcessor + text_processor: CrossTaskTextProcessor + aligner: CrossTaskAligner + num_iso_layer: 12 + sliding_window: 16 + sliding_window_size: 32 +model: + model_cls: MMFusionActionLocalization + mm_encoder_cls: MMBertForJoint +loss: + loss_cls: BCE +fairseq: + dataset: + batch_size: 1 + optimization: + max_epoch: 5 + checkpoint: + save_dir: runs/task/crosstask + restore_file: runs/task/checkpoint11.pt # for VLM diff --git a/fairseq/examples/MMPT/projects/task/crosstask_videoclip.yaml b/fairseq/examples/MMPT/projects/task/crosstask_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6ec613c07fcacfa6e6d3c5fc78aaefcb2b33eff5 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/crosstask_videoclip.yaml @@ -0,0 +1,10 @@ +includes: projects/task/crosstask.yaml +model: + model_cls: MMFusionSeparateActionLocalization + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel # dummy, not used. + num_hidden_video_layers: 6 +fairseq: + checkpoint: + restore_file: runs/task/checkpoint_best.pt # overwrite the default of VLM. diff --git a/fairseq/examples/MMPT/projects/task/default.yaml b/fairseq/examples/MMPT/projects/task/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..087fef71a4b94d81b74c5e9ef647db189626ff36 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/default.yaml @@ -0,0 +1,20 @@ +# this yaml cannot be run alone. you must use `how2.yaml`, `vtt.yaml` etc for training. +dataset: + video_processor: VideoProcessor + bert_name: bert-base-uncased +fairseq: + common: + tensorboard_logdir: run + log_interval: 1000 + dataset: + num_workers: 4 + optimization: + lr: [ 0.00005 ] + clip_norm: 2.0 + optimizer: adam + adam_betas: (0.9, 0.98) + lr_scheduler: polynomial_decay + total_num_update: 1000000 # backward compatible on fairseq 1.0.0a0+af0389f for reproducibility. 
+ warmup_updates: 1000 + weight_decay: 0.0 + ddp_backend: no_c10d diff --git a/fairseq/examples/MMPT/projects/task/ft.yaml b/fairseq/examples/MMPT/projects/task/ft.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c93b8a73ea93e00da8a0ee76751c9e1bd34ab8a2 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/ft.yaml @@ -0,0 +1,13 @@ +includes: projects/task/default.yaml +# all derived config will be run by fairseq-train. +task_type: sweep_small +fairseq: + optimization: + warmup_updates: 122 # copied from roberta glue: https://github.com/pytorch/fairseq/blob/master/examples/roberta/README.glue.md + checkpoint: + # save_interval_updates: 512 + # borrowed from Roberta script. + restore_file: runs/task/checkpoint_best.pt + reset_optimizer: True + reset_dataloader: True + reset_meters: True diff --git a/fairseq/examples/MMPT/projects/task/how2.yaml b/fairseq/examples/MMPT/projects/task/how2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..094dd04bfcfdb96699ac80066f874fca954dc56a --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/how2.yaml @@ -0,0 +1,22 @@ +includes: projects/task/default.yaml +task_type: sweep_big +slurm_config: big +dataset: + meta_processor: ShardedHow2MetaProcessor + train_path: data/how2/how2_s3d_train.lst + val_path: data/how2/how2_s3d_val.lst + video_processor: ShardedVideoProcessor + vfeat_dir: data/feat/feat_how2_s3d_shard_small + text_processor: ShardedTextProcessor + tfeat_dir: data/feat/feat_how2_s3d_shard_small/raw_caption_dedup.bert-base-uncased. + aligner: FixedLenAligner +# disable direct running of this yaml +eval: + save_path: runs/task +fairseq: + checkpoint: + save_dir: runs/task + save_interval_updates: 1024 + keep_interval_updates: 2 + keep_last_epochs: 30 + diff --git a/fairseq/examples/MMPT/projects/task/test_coin.yaml b/fairseq/examples/MMPT/projects/task/test_coin.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d919df7c2a69de33debd03a6d9b1a513da39618 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_coin.yaml @@ -0,0 +1,24 @@ +includes: projects/task/test.yaml +dataset: + split: test + test_path: data/coin/COIN.json + meta_processor: COINActionSegmentationMetaProcessor + vfeat_dir: data/feat/feat_coin_s3d + video_processor: VideoProcessor + text_processor: COINActionSegmentationTextProcessor + aligner: COINActionSegmentationAligner + num_iso_layer: 12 + sliding_window: 16 + sliding_window_size: 32 +model: + model_cls: MMFusionActionSegmentation + mm_encoder_cls: MMBertForTokenClassification +eval: + save_path: runs/task/coin/eval +fairseq: + dataset: + batch_size: 1 + common_eval: + path: runs/task/coin/checkpoint_best.pt +metric: COINActionSegmentationMetric +predictor: COINPredictor diff --git a/fairseq/examples/MMPT/projects/task/test_coin_videoclip.yaml b/fairseq/examples/MMPT/projects/task/test_coin_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b41f5bc4890405993217859c22653fa77a0dd4c3 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_coin_videoclip.yaml @@ -0,0 +1,7 @@ +includes: projects/task/test_coin.yaml +model: + model_cls: MMFusionSeparateActionSegmentation + mm_encoder_cls: + video_encoder_cls: MMBertForTokenClassification + text_encoder_cls: BertModel # dummy, not used. 
+ num_hidden_video_layers: 6 diff --git a/fairseq/examples/MMPT/projects/task/test_coin_zs.yaml b/fairseq/examples/MMPT/projects/task/test_coin_zs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5d19b09f1dee1922ffddf5bf030135cb46016dc1 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_coin_zs.yaml @@ -0,0 +1,13 @@ +includes: projects/task/test_coin.yaml +model: + model_cls: MMFusionSeparate + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/task/coin_zs/eval +fairseq: + common_eval: + path: runs/task/checkpoint_best.pt +predictor: COINZSPredictor diff --git a/fairseq/examples/MMPT/projects/task/test_crosstask.yaml b/fairseq/examples/MMPT/projects/task/test_crosstask.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6dd778e30be543a51997f16e326aa29f6e49e973 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_crosstask.yaml @@ -0,0 +1,32 @@ +includes: projects/task/test.yaml +dataset: + split: test + meta_processor: CrossTaskMetaProcessor + test_path: data/crosstask/crosstask_release/videos_val.csv + train_csv_path: data/crosstask/crosstask_release/videos.csv + val_path: data/crosstask/crosstask_release/videos_val.csv # dummy + val_csv_path: data/crosstask/crosstask_release/videos_val.csv + primary_path: data/crosstask/crosstask_release/tasks_primary.txt + related_path: data/crosstask/crosstask_release/tasks_related.txt + vfeat_dir: data/feat/feat_crosstask_s3d + annotation_path: data/crosstask/crosstask_release/annotations + n_train: 30 + video_processor: CrossTaskVideoProcessor + text_processor: CrossTaskTextProcessor + aligner: CrossTaskAligner + num_iso_layer: 12 + sliding_window: 16 + sliding_window_size: 32 +model: + model_cls: MMFusionActionLocalization + mm_encoder_cls: MMBertForJoint +eval: + save_path: runs/task/crosstask/eval +fairseq: + # read code and find what is the checkpoint arg. + dataset: + batch_size: 1 + common_eval: + path: runs/task/crosstask/checkpoint_best.pt +metric: CrossTaskMetric +predictor: CrossTaskPredictor diff --git a/fairseq/examples/MMPT/projects/task/test_crosstask_videoclip.yaml b/fairseq/examples/MMPT/projects/task/test_crosstask_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..df12535d231ffe6d1759f77f161fe48d0833f4e4 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_crosstask_videoclip.yaml @@ -0,0 +1,7 @@ +includes: projects/task/test_crosstask.yaml +model: + model_cls: MMFusionSeparateActionLocalization + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel # dummy, not used. 
+ num_hidden_video_layers: 6 diff --git a/fairseq/examples/MMPT/projects/task/test_crosstask_zs.yaml b/fairseq/examples/MMPT/projects/task/test_crosstask_zs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..19386e495b37ab2570cc1d763277193b28309814 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_crosstask_zs.yaml @@ -0,0 +1,32 @@ +includes: projects/task/test.yaml +dataset: + split: test + meta_processor: CrossTaskMetaProcessor + test_path: data/crosstask/crosstask_release/videos_val.csv + train_csv_path: data/crosstask/crosstask_release/videos.csv + val_path: data/crosstask/crosstask_release/videos_val.csv # dummy + val_csv_path: data/crosstask/crosstask_release/videos_val.csv + primary_path: data/crosstask/crosstask_release/tasks_primary.txt + related_path: data/crosstask/crosstask_release/tasks_related.txt + vfeat_dir: data/feat/feat_crosstask_s3d + annotation_path: data/crosstask/crosstask_release/annotations + n_train: 30 + video_processor: CrossTaskVideoProcessor + text_processor: CrossTaskTextProcessor + aligner: CrossTaskAligner + num_iso_layer: 12 + sliding_window: 16 + sliding_window_size: 32 +model: + model_cls: MMFusionActionLocalization + mm_encoder_cls: MMBertForJoint +eval: + save_path: runs/task/crosstask_zs/eval +fairseq: + # read code and find what is the checkpoint arg. + dataset: + batch_size: 1 + common_eval: + path: runs/task/checkpoint_best.pt # load the best from how2 on ACL submission: runs/task/checkpoint11.pt +metric: CrossTaskMetric +predictor: CrossTaskPredictor diff --git a/fairseq/examples/MMPT/projects/task/test_crosstask_zs_videoclip.yaml b/fairseq/examples/MMPT/projects/task/test_crosstask_zs_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7f0198276f74f592413e99fb10672431543a8f67 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_crosstask_zs_videoclip.yaml @@ -0,0 +1,7 @@ +includes: projects/task/test_crosstask_zs.yaml +model: + model_cls: MMFusionSeparateActionLocalization + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel # dummy, not used. + num_hidden_video_layers: 6 diff --git a/fairseq/examples/MMPT/projects/task/test_didemo_zs.yaml b/fairseq/examples/MMPT/projects/task/test_didemo_zs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4b53dca71e4a0e77793483ee971379f0476366f3 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_didemo_zs.yaml @@ -0,0 +1,23 @@ +includes: projects/task/test.yaml +dataset: + meta_processor: DiDeMoMetaProcessor + test_path: data/didemo/test_data.json + video_processor: VideoProcessor + vfeat_dir: data/feat/feat_didemo_s3d + text_processor: DiDeMoTextProcessor + aligner: DiDeMoAligner + num_iso_layer: 12 +model: + model_cls: MMFusionSeparate + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/task/didemo_zs/eval +fairseq: + # read code and find what is the checkpoint arg. 
+ common_eval: + path: runs/task/checkpoint_best.pt +metric: DiDeMoMetric +predictor: DiDeMoPredictor diff --git a/fairseq/examples/MMPT/projects/task/test_vtt_videoclip.yaml b/fairseq/examples/MMPT/projects/task/test_vtt_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cb6564394c8de5edebe41d9a90e8233d2ef90edf --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_vtt_videoclip.yaml @@ -0,0 +1,8 @@ +includes: projects/task/test_vtt.yaml +model: + model_cls: MMFusionSeparate + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 + diff --git a/fairseq/examples/MMPT/projects/task/test_vtt_zs.yaml b/fairseq/examples/MMPT/projects/task/test_vtt_zs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..57340924b434338815743d0d2c326737ccfe3818 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_vtt_zs.yaml @@ -0,0 +1,13 @@ +includes: projects/task/test_vtt.yaml +model: + model_cls: MMFusionSeparate + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/task/vtt_zs/eval +fairseq: + # read code and find what is the checkpoint arg. + common_eval: + path: runs/task/checkpoint_best.pt diff --git a/fairseq/examples/MMPT/projects/task/test_vttqa.yaml b/fairseq/examples/MMPT/projects/task/test_vttqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ddf813c535c0ecd6e6fd81ba9013a88f8ec4296b --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_vttqa.yaml @@ -0,0 +1,20 @@ +includes: projects/task/test.yaml +dataset: + meta_processor: MSRVTTQAMetaProcessor + test_path: data/msrvtt-qa/MSR_MC_test.csv + video_processor: VideoProcessor + vfeat_dir: data/feat/feat_vtt_s3d + text_processor: MSRVTTQATextProcessor + aligner: MSRVTTQAAligner + num_iso_layer: 12 +model: + model_cls: MMFusionJoint + mm_encoder_cls: MMBertForJoint +eval: + save_path: runs/task/vttqa/eval +fairseq: + # read code and find what is the checkpoint arg. + common_eval: + path: runs/task/vttqa/checkpoint_last.pt +metric: QAMetric +predictor: QAPredictor diff --git a/fairseq/examples/MMPT/projects/task/test_vttqa_videoclip.yaml b/fairseq/examples/MMPT/projects/task/test_vttqa_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..32a41e861c2e71642d4e79f8cf9dc314a5a3f621 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_vttqa_videoclip.yaml @@ -0,0 +1,8 @@ +includes: projects/task/test_vttqa.yaml +model: + model_cls: MMFusionSeparate + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 + diff --git a/fairseq/examples/MMPT/projects/task/test_vttqa_zs.yaml b/fairseq/examples/MMPT/projects/task/test_vttqa_zs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e0e29d20729445258f5d07348fd1799f7feecf4 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_vttqa_zs.yaml @@ -0,0 +1,13 @@ +includes: projects/task/test_vttqa.yaml +model: + model_cls: MMFusionSeparate + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/task/vttqa_zs/eval +fairseq: + # read code and find what is the checkpoint arg. 
+ common_eval: + path: runs/task/checkpoint_best.pt diff --git a/fairseq/examples/MMPT/projects/task/test_youcook_videoclip.yaml b/fairseq/examples/MMPT/projects/task/test_youcook_videoclip.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b85ea434749009f50eea057146403691bb48c4a9 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_youcook_videoclip.yaml @@ -0,0 +1,8 @@ +includes: projects/task/test_youcook.yaml +model: + model_cls: MMFusionSeparate + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 + diff --git a/fairseq/examples/MMPT/projects/task/test_youcook_zs.yaml b/fairseq/examples/MMPT/projects/task/test_youcook_zs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0a5875bea4eaca1e46c65dfdc50b586550cc6f36 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/test_youcook_zs.yaml @@ -0,0 +1,13 @@ +includes: projects/task/test_youcook.yaml +model: + model_cls: MMFusionSeparate + mm_encoder_cls: + video_encoder_cls: MMBertForEncoder + text_encoder_cls: BertModel + num_hidden_video_layers: 6 +eval: + save_path: runs/task/youcook_zs/eval +fairseq: + # read code and find what is the checkpoint arg. + common_eval: + path: runs/task/checkpoint_best.pt diff --git a/fairseq/examples/MMPT/projects/task/vttqa.yaml b/fairseq/examples/MMPT/projects/task/vttqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..56d578eff0104e5dcb4cea8720303100a9964db9 --- /dev/null +++ b/fairseq/examples/MMPT/projects/task/vttqa.yaml @@ -0,0 +1,23 @@ +includes: projects/task/ft.yaml +dataset: + meta_processor: MSRVTTMetaProcessor + train_path: data/msrvtt/MSRVTT_train.csv + dup: 20 + val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv + vfeat_dir: data/feat/feat_vtt_s3d + text_processor: MSRVTTTextProcessor + json_path: data/msrvtt/MSRVTT_data.json + aligner: DSAligner + num_iso_layer: 12 +model: + model_cls: MMFusionJoint + mm_encoder_cls: MMBertForJoint +loss: + loss_cls: V2TContraLoss +fairseq: + dataset: + batch_size: 128 + optimization: + max_epoch: 5 + checkpoint: + save_dir: runs/task/vttqa diff --git a/fairseq/examples/__init__.py b/fairseq/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44bb24ae614941f23fea29c56d60167650c39bcb --- /dev/null +++ b/fairseq/examples/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +try: + from fairseq.version import __version__ # noqa +except ImportError: + pass
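The YAML files added above are consumed by the prediction entry point at the top of this diff, which takes a single task-config path as its positional argument and dispatches on the file-name prefix ("test" runs the Evaluator loop, "vis" runs a Predictor). As a minimal sketch of how one of these configs can be inspected on its own, assuming OmegaConf is installed and the working directory is the MMPT example root; the project's own mmpt.utils.load_config used by the script may perform additional processing (for example resolving the includes: chains seen in the project-level configs), which plain OmegaConf.load does not:

    from omegaconf import OmegaConf

    # Load one of the zero-shot retrieval test configs added in this diff.
    config = OmegaConf.load("projects/retri/videoclip/test_vtt_zs.yaml")

    # Fields the prediction script reads before evaluating.
    print(config.fairseq.common_eval.path)  # runs/retri/videoclip/checkpoint_best.pt
    print(config.metric, config.predictor)  # RetrievalMetric RetrievalPredictor
    print(config.eval.save_path)            # runs/retri/videoclip/vtt_zs/eval

In the actual workflow these configs are passed to the script directly rather than loaded by hand; the sketch only illustrates the structure (checkpoint path, metric, predictor, save path) that the script expects to find.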