Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- fairseq/docs/models.rst +104 -0
- fairseq/docs/tutorial_classifying_names.rst +415 -0
- fairseq/examples/MMPT/.gitignore +139 -0
- fairseq/examples/MMPT/README.md +166 -0
- fairseq/examples/MMPT/endtask.md +41 -0
- fairseq/examples/MMPT/mmpt/datasets/fairseqmmdataset.py +57 -0
- fairseq/examples/MMPT/mmpt/datasets/mmdataset.py +111 -0
- fairseq/examples/MMPT/mmpt/evaluators/__init__.py +13 -0
- fairseq/examples/MMPT/mmpt/evaluators/evaluator.py +54 -0
- fairseq/examples/MMPT/mmpt/models/transformermodel.py +734 -0
- fairseq/examples/MMPT/mmpt/modules/__init__.py +10 -0
- fairseq/examples/MMPT/mmpt/modules/mm.py +145 -0
- fairseq/examples/MMPT/mmpt/modules/retri.py +429 -0
- fairseq/examples/MMPT/mmpt/modules/vectorpool.py +246 -0
- fairseq/examples/MMPT/mmpt/processors/dedupprocessor.py +242 -0
- fairseq/examples/MMPT/mmpt/processors/how2processor.py +887 -0
- fairseq/examples/MMPT/mmpt/processors/models/s3dg.py +336 -0
- fairseq/examples/MMPT/mmpt/processors/processor.py +274 -0
- fairseq/examples/MMPT/mmpt/tasks/fairseqmmtask.py +104 -0
- fairseq/examples/MMPT/mmpt/tasks/milncetask.py +27 -0
- fairseq/examples/MMPT/mmpt/tasks/retritask.py +253 -0
- fairseq/examples/MMPT/mmpt/tasks/task.py +184 -0
- fairseq/examples/MMPT/mmpt/tasks/vlmtask.py +27 -0
- fairseq/examples/MMPT/mmpt/utils/__init__.py +68 -0
- fairseq/examples/MMPT/mmpt/utils/load_config.py +81 -0
- fairseq/examples/MMPT/mmpt_cli/localjob.py +117 -0
- fairseq/examples/MMPT/mmpt_cli/predict.py +113 -0
- fairseq/examples/MMPT/projects/mtm/mmfusionmtm.yaml +19 -0
- fairseq/examples/MMPT/projects/mtm/vlm/coin.yaml +47 -0
- fairseq/examples/MMPT/projects/mtm/vlm/how2.yaml +55 -0
- fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask.yaml +38 -0
- fairseq/examples/MMPT/projects/mtm/vlm/test_vtt.yaml +29 -0
- fairseq/examples/MMPT/projects/mtm/vlm/test_vttqa.yaml +29 -0
- fairseq/examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml +32 -0
- fairseq/examples/MMPT/projects/mtm/vlm/vtt.yaml +49 -0
- fairseq/examples/MMPT/projects/mtm/vlm/vttqa.yaml +47 -0
- fairseq/examples/MMPT/projects/mtm/vlm/youcook.yaml +47 -0
- fairseq/examples/MMPT/projects/retri/videoclip.yaml +10 -0
- fairseq/examples/MMPT/projects/retri/videoclip/coin_videoclip.yaml +49 -0
- fairseq/examples/MMPT/projects/retri/videoclip/crosstask_videoclip.yaml +55 -0
- fairseq/examples/MMPT/projects/retri/videoclip/how2.yaml +65 -0
- fairseq/examples/MMPT/projects/retri/videoclip/test_coin_videoclip.yaml +33 -0
- fairseq/examples/MMPT/projects/retri/videoclip/test_coin_zs.yaml +33 -0
- fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_videoclip.yaml +40 -0
- fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_zs_videoclip.yaml +40 -0
- fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_videoclip.yaml +31 -0
- fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_zs.yaml +31 -0
- fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_videoclip.yaml +31 -0
- fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_zs.yaml +31 -0
- fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_videoclip.yaml +33 -0
fairseq/docs/models.rst
ADDED
@@ -0,0 +1,104 @@
.. role:: hidden
    :class: hidden-section

.. module:: fairseq.models

.. _Models:

Models
======

A Model defines the neural network's ``forward()`` method and encapsulates all
of the learnable parameters in the network. Each model also provides a set of
named *architectures* that define the precise network configuration (e.g.,
embedding dimension, number of layers, etc.).

Both the model type and architecture are selected via the ``--arch``
command-line argument. Once selected, a model may expose additional command-line
arguments for further configuration.

.. note::

    All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends
    :class:`torch.nn.Module`. Thus any fairseq Model can be used as a
    stand-alone Module in other PyTorch code.


Convolutional Neural Networks (CNN)
-----------------------------------

.. module:: fairseq.models.fconv
.. autoclass:: fairseq.models.fconv.FConvModel
    :members:
.. autoclass:: fairseq.models.fconv.FConvEncoder
    :members:
    :undoc-members:
.. autoclass:: fairseq.models.fconv.FConvDecoder
    :members:


Long Short-Term Memory (LSTM) networks
--------------------------------------

.. module:: fairseq.models.lstm
.. autoclass:: fairseq.models.lstm.LSTMModel
    :members:
.. autoclass:: fairseq.models.lstm.LSTMEncoder
    :members:
.. autoclass:: fairseq.models.lstm.LSTMDecoder
    :members:


Transformer (self-attention) networks
-------------------------------------

.. module:: fairseq.models.transformer
.. autoclass:: fairseq.models.transformer.TransformerModel
    :members:
.. autoclass:: fairseq.models.transformer.TransformerEncoder
    :members:
.. autoclass:: fairseq.models.transformer.TransformerEncoderLayer
    :members:
.. autoclass:: fairseq.models.transformer.TransformerDecoder
    :members:
.. autoclass:: fairseq.models.transformer.TransformerDecoderLayer
    :members:


Adding new models
-----------------

.. currentmodule:: fairseq.models
.. autofunction:: fairseq.models.register_model
.. autofunction:: fairseq.models.register_model_architecture
.. autoclass:: fairseq.models.BaseFairseqModel
    :members:
    :undoc-members:
.. autoclass:: fairseq.models.FairseqEncoderDecoderModel
    :members:
    :undoc-members:
.. autoclass:: fairseq.models.FairseqEncoderModel
    :members:
    :undoc-members:
.. autoclass:: fairseq.models.FairseqLanguageModel
    :members:
    :undoc-members:
.. autoclass:: fairseq.models.FairseqMultiModel
    :members:
    :undoc-members:
.. autoclass:: fairseq.models.FairseqEncoder
    :members:
.. autoclass:: fairseq.models.CompositeEncoder
    :members:
.. autoclass:: fairseq.models.FairseqDecoder
    :members:


.. _Incremental decoding:

Incremental decoding
--------------------

.. autoclass:: fairseq.models.FairseqIncrementalDecoder
    :members:
    :undoc-members:
fairseq/docs/tutorial_classifying_names.rst
ADDED
@@ -0,0 +1,415 @@
Tutorial: Classifying Names with a Character-Level RNN
======================================================

In this tutorial we will extend fairseq to support *classification* tasks. In
particular we will re-implement the PyTorch tutorial for `Classifying Names with
a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`_
in fairseq. It is recommended to quickly skim that tutorial before beginning
this one.

This tutorial covers:

1. **Preprocessing the data** to create dictionaries.
2. **Registering a new Model** that encodes an input sentence with a simple RNN
   and predicts the output label.
3. **Registering a new Task** that loads our dictionaries and dataset.
4. **Training the Model** using the existing command-line tools.
5. **Writing an evaluation script** that imports fairseq and allows us to
   interactively evaluate our model on new inputs.


1. Preprocessing the data
-------------------------

The original tutorial provides raw data, but we'll work with a modified version
of the data that is already tokenized into characters and split into separate
train, valid and test sets.

Download and extract the data from here:
`tutorial_names.tar.gz <https://dl.fbaipublicfiles.com/fairseq/data/tutorial_names.tar.gz>`_

Once extracted, let's preprocess the data using the :ref:`fairseq-preprocess`
command-line tool to create the dictionaries. While this tool is primarily
intended for sequence-to-sequence problems, we're able to reuse it here by
treating the label as a "target" sequence of length 1. We'll also output the
preprocessed files in "raw" format using the ``--dataset-impl`` option to
enhance readability:

.. code-block:: console

  > fairseq-preprocess \
    --trainpref names/train --validpref names/valid --testpref names/test \
    --source-lang input --target-lang label \
    --destdir names-bin --dataset-impl raw

After running the above command you should see a new directory,
:file:`names-bin/`, containing the dictionaries for *inputs* and *labels*.


2. Registering a new Model
--------------------------

Next we'll register a new model in fairseq that will encode an input sentence
with a simple RNN and predict the output label. Compared to the original PyTorch
tutorial, our version will also work with batches of data and GPU Tensors.

First let's copy the simple RNN module implemented in the `PyTorch tutorial
<https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#creating-the-network>`_.
Create a new file named :file:`fairseq/models/rnn_classifier.py` with the
following contents::

  import torch
  import torch.nn as nn

  class RNN(nn.Module):

      def __init__(self, input_size, hidden_size, output_size):
          super(RNN, self).__init__()

          self.hidden_size = hidden_size

          self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
          self.i2o = nn.Linear(input_size + hidden_size, output_size)
          self.softmax = nn.LogSoftmax(dim=1)

      def forward(self, input, hidden):
          combined = torch.cat((input, hidden), 1)
          hidden = self.i2h(combined)
          output = self.i2o(combined)
          output = self.softmax(output)
          return output, hidden

      def initHidden(self):
          return torch.zeros(1, self.hidden_size)

We must also *register* this model with fairseq using the
:func:`~fairseq.models.register_model` function decorator. Once the model is
registered we'll be able to use it with the existing :ref:`Command-line Tools`.

All registered models must implement the :class:`~fairseq.models.BaseFairseqModel`
interface, so we'll create a small wrapper class in the same file and register
it in fairseq with the name ``'rnn_classifier'``::

  from fairseq.models import BaseFairseqModel, register_model

  # Note: the register_model "decorator" should immediately precede the
  # definition of the Model class.

  @register_model('rnn_classifier')
  class FairseqRNNClassifier(BaseFairseqModel):

      @staticmethod
      def add_args(parser):
          # Models can override this method to add new command-line arguments.
          # Here we'll add a new command-line argument to configure the
          # dimensionality of the hidden state.
          parser.add_argument(
              '--hidden-dim', type=int, metavar='N',
              help='dimensionality of the hidden state',
          )

      @classmethod
      def build_model(cls, args, task):
          # Fairseq initializes models by calling the ``build_model()``
          # function. This provides more flexibility, since the returned model
          # instance can be of a different type than the one that was called.
          # In this case we'll just return a FairseqRNNClassifier instance.

          # Initialize our RNN module
          rnn = RNN(
              # We'll define the Task in the next section, but for now just
              # notice that the task holds the dictionaries for the "source"
              # (i.e., the input sentence) and "target" (i.e., the label).
              input_size=len(task.source_dictionary),
              hidden_size=args.hidden_dim,
              output_size=len(task.target_dictionary),
          )

          # Return the wrapped version of the module
          return FairseqRNNClassifier(
              rnn=rnn,
              input_vocab=task.source_dictionary,
          )

      def __init__(self, rnn, input_vocab):
          super(FairseqRNNClassifier, self).__init__()

          self.rnn = rnn
          self.input_vocab = input_vocab

          # The RNN module in the tutorial expects one-hot inputs, so we can
          # precompute the identity matrix to help convert from indices to
          # one-hot vectors. We register it as a buffer so that it is moved to
          # the GPU when ``cuda()`` is called.
          self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))

      def forward(self, src_tokens, src_lengths):
          # The inputs to the ``forward()`` function are determined by the
          # Task, and in particular the ``'net_input'`` key in each
          # mini-batch. We'll define the Task in the next section, but for
          # now just know that *src_tokens* has shape `(batch, src_len)` and
          # *src_lengths* has shape `(batch)`.
          bsz, max_src_len = src_tokens.size()

          # Initialize the RNN hidden state. Compared to the original PyTorch
          # tutorial we'll also handle batched inputs and work on the GPU.
          hidden = self.rnn.initHidden()
          hidden = hidden.repeat(bsz, 1)  # expand for batched inputs
          hidden = hidden.to(src_tokens.device)  # move to GPU

          for i in range(max_src_len):
              # WARNING: The inputs have padding, so we should mask those
              # elements here so that padding doesn't affect the results.
              # This is left as an exercise for the reader. The padding symbol
              # is given by ``self.input_vocab.pad()`` and the unpadded length
              # of each input is given by *src_lengths*.

              # One-hot encode a batch of input characters.
              input = self.one_hot_inputs[src_tokens[:, i].long()]

              # Feed the input to our RNN.
              output, hidden = self.rnn(input, hidden)

          # Return the final output state for making a prediction
          return output

Finally let's define a *named architecture* with the configuration for our
model. This is done with the :func:`~fairseq.models.register_model_architecture`
function decorator. Thereafter this named architecture can be used with the
``--arch`` command-line argument, e.g., ``--arch pytorch_tutorial_rnn``::

  from fairseq.models import register_model_architecture

  # The first argument to ``register_model_architecture()`` should be the name
  # of the model we registered above (i.e., 'rnn_classifier'). The function we
  # register here should take a single argument *args* and modify it in-place
  # to match the desired architecture.

  @register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
  def pytorch_tutorial_rnn(args):
      # We use ``getattr()`` to prioritize arguments that are explicitly given
      # on the command-line, so that the defaults defined below are only used
      # when no other value has been specified.
      args.hidden_dim = getattr(args, 'hidden_dim', 128)


3. Registering a new Task
-------------------------

Now we'll register a new :class:`~fairseq.tasks.FairseqTask` that will load our
dictionaries and dataset. Tasks can also control how the data is batched into
mini-batches, but in this tutorial we'll reuse the batching provided by
:class:`fairseq.data.LanguagePairDataset`.

Create a new file named :file:`fairseq/tasks/simple_classification.py` with the
following contents::

  import os
  import torch

  from fairseq.data import Dictionary, LanguagePairDataset
  from fairseq.tasks import LegacyFairseqTask, register_task


  @register_task('simple_classification')
  class SimpleClassificationTask(LegacyFairseqTask):

      @staticmethod
      def add_args(parser):
          # Add some command-line arguments for specifying where the data is
          # located and the maximum supported input length.
          parser.add_argument('data', metavar='FILE',
                              help='file prefix for data')
          parser.add_argument('--max-positions', default=1024, type=int,
                              help='max input length')

      @classmethod
      def setup_task(cls, args, **kwargs):
          # Here we can perform any setup required for the task. This may include
          # loading Dictionaries, initializing shared Embedding layers, etc.
          # In this case we'll just load the Dictionaries.
          input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
          label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
          print('| [input] dictionary: {} types'.format(len(input_vocab)))
          print('| [label] dictionary: {} types'.format(len(label_vocab)))

          return SimpleClassificationTask(args, input_vocab, label_vocab)

      def __init__(self, args, input_vocab, label_vocab):
          super().__init__(args)
          self.input_vocab = input_vocab
          self.label_vocab = label_vocab

      def load_dataset(self, split, **kwargs):
          """Load a given dataset split (e.g., train, valid, test)."""

          prefix = os.path.join(self.args.data, '{}.input-label'.format(split))

          # Read input sentences.
          sentences, lengths = [], []
          with open(prefix + '.input', encoding='utf-8') as file:
              for line in file:
                  sentence = line.strip()

                  # Tokenize the sentence, splitting on spaces
                  tokens = self.input_vocab.encode_line(
                      sentence, add_if_not_exist=False,
                  )

                  sentences.append(tokens)
                  lengths.append(tokens.numel())

          # Read labels.
          labels = []
          with open(prefix + '.label', encoding='utf-8') as file:
              for line in file:
                  label = line.strip()
                  labels.append(
                      # Convert label to a numeric ID.
                      torch.LongTensor([self.label_vocab.add_symbol(label)])
                  )

          assert len(sentences) == len(labels)
          print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))

          # We reuse LanguagePairDataset since classification can be modeled as a
          # sequence-to-sequence task where the target sequence has length 1.
          self.datasets[split] = LanguagePairDataset(
              src=sentences,
              src_sizes=lengths,
              src_dict=self.input_vocab,
              tgt=labels,
              tgt_sizes=torch.ones(len(labels)),  # targets have length 1
              tgt_dict=self.label_vocab,
              left_pad_source=False,
              # Since our target is a single class label, there's no need for
              # teacher forcing. If we set this to ``True`` then our Model's
              # ``forward()`` method would receive an additional argument called
              # *prev_output_tokens* that would contain a shifted version of the
              # target sequence.
              input_feeding=False,
          )

      def max_positions(self):
          """Return the max input length allowed by the task."""
          # The source should be less than *args.max_positions* and the "target"
          # has max length 1.
          return (self.args.max_positions, 1)

      @property
      def source_dictionary(self):
          """Return the source :class:`~fairseq.data.Dictionary`."""
          return self.input_vocab

      @property
      def target_dictionary(self):
          """Return the target :class:`~fairseq.data.Dictionary`."""
          return self.label_vocab

      # We could override this method if we wanted more control over how batches
      # are constructed, but it's not necessary for this tutorial since we can
      # reuse the batching provided by LanguagePairDataset.
      #
      # def get_batch_iterator(
      #     self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
      #     ignore_invalid_inputs=False, required_batch_size_multiple=1,
      #     seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
      #     data_buffer_size=0, disable_iterator_cache=False,
      # ):
      #     (...)


4. Training the Model
---------------------

Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
command-line tool for this, making sure to specify our new Task (``--task
simple_classification``) and Model architecture (``--arch
pytorch_tutorial_rnn``):

.. note::

    You can also configure the dimensionality of the hidden state by passing the
    ``--hidden-dim`` argument to :ref:`fairseq-train`.

.. code-block:: console

  > fairseq-train names-bin \
    --task simple_classification \
    --arch pytorch_tutorial_rnn \
    --optimizer adam --lr 0.001 --lr-shrink 0.5 \
    --max-tokens 1000
  (...)
  | epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21
  | epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208
  | done training in 31.6 seconds

The model files should appear in the :file:`checkpoints/` directory.


5. Writing an evaluation script
-------------------------------

Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classifier.py` with the following contents::

  from fairseq import checkpoint_utils, data, options, tasks

  # Parse command-line arguments for generation
  parser = options.get_generation_parser(default_task='simple_classification')
  args = options.parse_args_and_arch(parser)

  # Setup task
  task = tasks.setup_task(args)

  # Load model
  print('| loading model from {}'.format(args.path))
  models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
  model = models[0]

  while True:
      sentence = input('\nInput: ')

      # Tokenize into characters
      chars = ' '.join(list(sentence.strip()))
      tokens = task.source_dictionary.encode_line(
          chars, add_if_not_exist=False,
      )

      # Build mini-batch to feed to the model
      batch = data.language_pair_dataset.collate(
          samples=[{'id': -1, 'source': tokens}],  # bsz = 1
          pad_idx=task.source_dictionary.pad(),
          eos_idx=task.source_dictionary.eos(),
          left_pad_source=False,
          input_feeding=False,
      )

      # Feed batch to the model and get predictions
      preds = model(**batch['net_input'])

      # Print top 3 predictions and their log-probabilities
      top_scores, top_labels = preds[0].topk(k=3)
      for score, label_idx in zip(top_scores, top_labels):
          label_name = task.target_dictionary.string([label_idx])
          print('({:.2f})\t{}'.format(score, label_name))

Now we can evaluate our model interactively. Note that we have included the
original data path (:file:`names-bin/`) so that the dictionaries can be loaded:

.. code-block:: console

  > python eval_classifier.py names-bin --path checkpoints/checkpoint_best.pt
  | [input] dictionary: 64 types
  | [label] dictionary: 24 types
  | loading model from checkpoints/checkpoint_best.pt

  Input: Satoshi
  (-0.61) Japanese
  (-1.20) Arabic
  (-2.86) Italian

  Input: Sinbad
  (-0.30) Arabic
  (-1.76) English
  (-4.08) Russian
fairseq/examples/MMPT/.gitignore
ADDED
@@ -0,0 +1,139 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
runs
data
pretrained_models
projects/mmfusion_*
log_test
third-party
python_log
slurm_snapshot_code
lightning_logs
demos
fairseq/examples/MMPT/README.md
ADDED
@@ -0,0 +1,166 @@
# VideoCLIP and VLM

You have just found a toolkit for multimodal video understanding! It contains implementations of two recent multi-modal video understanding papers, [VideoCLIP](https://arxiv.org/pdf/2109.14084.pdf) (EMNLP, 2021) and [VLM](https://aclanthology.org/2021.findings-acl.370.pdf) (ACL Findings, 2021), along with high-performance toolkits that are typically lacking in existing codebases. The toolkit is designed around generic, performance-tuned components that can potentially be adapted to other frameworks (we initially use fairseq).

VideoCLIP is a contrastive learning model for zero-shot transfer to retrieval/classification/sequence-labeling style tasks.

<img src="videoclip.png" width="350" class="center">

VLM is a masked-language-model style pre-training approach that uses a single encoder with a masked modality model (MMM) for retrieval/generation/sequence-labeling style tasks.

<img src="vlm.png" width="350" class="center">

### News
[Oct. 2021] Initial release of implementations for the following papers:
[VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding](https://arxiv.org/pdf/2109.14084.pdf) (Xu et al., EMNLP 2021)
[VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding](https://aclanthology.org/2021.findings-acl.370.pdf) (Xu et al., ACL Findings 2021)


### Installation
We aim to minimize this repo's dependencies on other packages.
We use fairseq as the main trainer (there is no models/datasets dependency on fairseq; we will support other trainers in the future):
```
git clone https://github.com/pytorch/fairseq
cd fairseq
pip install -e .  # also optionally follow fairseq README for apex installation for fp16 training.
export MKL_THREADING_LAYER=GNU  # fairseq may need this for numpy.
```

Then install this toolkit:
```
cd examples/MMPT  # MMPT can be in any folder, not necessarily under fairseq/examples.
pip install -e .
```

The code was developed under Python=3.8.8, PyTorch=1.8, CUDA=11.0 with fairseq=1.0.0a0+af0389f, and tested under Python=3.8.8, PyTorch=1.9, CUDA=11.0, fairseq=1.0.0a0+8e7bc73 at code release.
Most models require `transformers==3.4` for API compatibility: `pip install transformers==3.4`.
In addition, some downstream tasks may need `conda install pandas`.


### Usage
#### Download Checkpoints
We use pre-trained [S3D](https://github.com/antoine77340/S3D_HowTo100M) for video feature extraction. Please place the models as `pretrained_models/s3d_dict.npy` and `pretrained_models/s3d_howto100m.pth`.

Download the VideoCLIP checkpoint `https://dl.fbaipublicfiles.com/MMPT/retri/videoclip/checkpoint_best.pt` to `runs/retri/videoclip`, or the VLM checkpoint `https://dl.fbaipublicfiles.com/MMPT/mtm/vlm/checkpoint_best.pt` to `runs/mtm/vlm`.

#### Demo of Inference
Run `python locallaunch.py projects/retri/videoclip.yaml --dryrun` to generate all `.yaml` configs for VideoCLIP.

```python
import torch

from mmpt.models import MMPTModel


model, tokenizer, aligner = MMPTModel.from_pretrained(
    "projects/retri/videoclip/how2.yaml")

model.eval()


# B, T, FPS, H, W, C (VideoCLIP is trained on 30 fps of s3d)
video_frames = torch.randn(1, 2, 30, 224, 224, 3)
caps, cmasks = aligner._build_text_seq(
    tokenizer("some text", add_special_tokens=False)["input_ids"]
)

caps, cmasks = caps[None, :], cmasks[None, :]  # bsz=1

with torch.no_grad():
    output = model(video_frames, caps, cmasks, return_score=True)
print(output["score"])  # dot-product
```

#### Data Preparation
See [dataset](DATASET.md) for each dataset.

#### Global Config for Training Pipeline
We organize a global config file for each training/testing pipeline under projects (see a detailed [explanation](CONFIG.md)). For example, VideoCLIP is in `projects/retri/videoclip.yaml` and VLM is in `projects/mtm/vlm.yaml`.

We wrap all commands into `locallaunch.py` and `mmpt_cli/localjob.py`. You can inspect the concrete commands with `--dryrun` and then drop that flag for an actual run.

First, running `python locallaunch.py projects/retri/videoclip.yaml --dryrun` generates the configs for pre-training, zero-shot evaluation, fine-tuning and testing of VideoCLIP under `projects/retri/videoclip`.

Each process (training or evaluation) is then configured by a concrete config file (we save all complex arguments, including fairseq args, into the concrete config file for reproducibility). For example, to run zero-shot evaluation on Youcook:
```
python locallaunch.py projects/retri/videoclip/test_youcook_zs.yaml --jobtype local_predict  # zero-shot evaluation.
python locallaunch.py projects/retri/videoclip/youcook_videoclip.yaml --jobtype local_single --dryrun  # fine-tuning: use --dryrun to check cmds and drop it to make an actual run; local_small will run on two gpus (as in paper).
python locallaunch.py projects/retri/videoclip/test_youcook_videoclip.yaml --jobtype local_predict  # testing on fine-tuned model.
```

Pretraining can be run as:
```
python locallaunch.py projects/retri/videoclip/how2.yaml --jobtype local_single --dryrun  # check the cmds, then drop --dryrun; the paper's runs use local_big with 8 gpus.
```
You may need to change `--jobtype`; check/extend `LocalJob` in `mmpt_cli/localjob.py` for multi-GPU/multi-node pre-training.

Detailed instructions for pretraining and fine-tuning can be found in the [pretraining instructions](pretraining.md) and [finetuning instructions](endtask.md).


### Development
Several components of this toolkit can be re-used for future research (and also our ongoing research).

#### Framework Wrapper
We currently only support fairseq, but most components can easily be fit into other frameworks such as HuggingFace. This repo is a `--user-dir` of fairseq with a fairseq wrapper. For example, `mmpt/tasks` includes a `FairseqMMTTask`, which manages `mmpt/datasets` with `FairseqDataset`, `mmpt/models` with `FairseqModel`, and `mmpt/losses` with `FairseqCriterion`.

#### Processors
**Multi**modal research introduces complexity in modality alignment all the way from the different input sources to the losses. Inspired by [MMF](https://github.com/facebookresearch/mmf), this toolkit leverages `mmpt/processors` to handle the various needs of data preprocessing and loading, **alleviating** the need for multiple `torch.utils.data.Dataset` classes (which can be tricky for ablation studies).
Processors can also be decoupled from `torch.utils.data.Dataset` for offline preprocessing instead of on-the-fly data preprocessing.

We decouple `mmpt.MMDataset` into `MetaProcessor`, `VideoProcessor`, `TextProcessor` and `Aligner` processors. They can be configured in the `dataset` field of a config file (e.g., see `projects/task/how2.yaml`).
`MetaProcessor` loads the metadata of a dataset, e.g., all video_ids of the how2 dataset.
`VideoProcessor` loads the video features of a dataset, e.g., S3D features for each second of a video.
`TextProcessor` loads the text (features), e.g., BERT pre-tokenized text clips for the how2 dataset (with `start`/`end` timestamps and `cap` for `token_ids`).
`Aligner` is the core class that prepares the training data for the different baselines, e.g., sampling a clip, masking tokens for MLM, etc.

#### Performance-tuned Components
To speed up pre-training, this toolkit uses sharded features stored in memory-mapped numpy arrays, backed by `ShardedTensor` in `mmpt/utils/shardedtensor.py` (adopted from the MARGE paper). This reduces the IO load of multi-GPU training without loading all features of a video into memory each time, and `ShardedTensor` ensures features are stored in contiguous disk space for near-random access. This is used for both How2 video features and texts in `mmpt/processors/how2processor.py`.


### Citation
If this codebase is useful for your work, please cite the following papers:

```BibTeX
@inproceedings{xu-etal-2021-videoclip,
    title = "{VideoCLIP}: Contrastive Pre-training for\\Zero-shot Video-Text Understanding",
    author = "Xu, Hu and
      Ghosh, Gargi and
      Huang, Po-Yao and
      Okhonko, Dmytro and
      Aghajanyan, Armen and
      Metze, Florian and
      Zettlemoyer, Luke and
      Feichtenhofer, Christoph",
    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
    month = nov,
    year = "2021",
    address = "Online",
    publisher = "Association for Computational Linguistics",
}

@inproceedings{xu-etal-2021-vlm,
    title = "{VLM}: Task-agnostic Video-Language Model Pre-training for Video Understanding",
    author = "Xu, Hu and
      Ghosh, Gargi and
      Huang, Po-Yao and
      Arora, Prahal and
      Aminzadeh, Masoumeh and
      Feichtenhofer, Christoph and
      Metze, Florian and
      Zettlemoyer, Luke",
    booktitle = "Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021",
    month = aug,
    year = "2021",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2021.findings-acl.370",
    doi = "10.18653/v1/2021.findings-acl.370",
    pages = "4227--4239",
}
```

### Bug Reports
This repo is in its initial stage; bug reports are welcome at [email protected]

### Copyright
The majority of Multimodal Pre-training (MMPT) is licensed under CC-BY-NC; however, portions of the project are available under separate license terms: Evaluation Codes/Models: Howto100M and HuggingFace Transformers are licensed under the Apache 2.0 license; COIN and NLG-eval are licensed under the MIT license; CrossTask is licensed under BSD-3; DiDeMo is licensed under the BSD-2 license.
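
The inference demo above scores a single (video, text) pair. The following is a minimal, hypothetical sketch, not part of the repo, that reuses exactly the API shown in the README demo (`MMPTModel.from_pretrained`, `aligner._build_text_seq`, `return_score=True`) to rank several candidate captions against one clip; the candidate strings and the random video tensor are placeholders.

```python
# Hypothetical sketch: rank candidate captions for one clip with the same
# MMPTModel API shown in the README demo above.
import torch
from mmpt.models import MMPTModel

model, tokenizer, aligner = MMPTModel.from_pretrained(
    "projects/retri/videoclip/how2.yaml")
model.eval()

video_frames = torch.randn(1, 2, 30, 224, 224, 3)  # B, T, FPS, H, W, C
candidates = ["someone is cooking pasta", "a person repairs a bike"]  # placeholders

scores = []
with torch.no_grad():
    for text in candidates:
        caps, cmasks = aligner._build_text_seq(
            tokenizer(text, add_special_tokens=False)["input_ids"])
        caps, cmasks = caps[None, :], cmasks[None, :]  # bsz=1
        output = model(video_frames, caps, cmasks, return_score=True)
        scores.append(output["score"].item())  # dot-product score per candidate

best = max(range(len(candidates)), key=lambda i: scores[i])
print(candidates[best], scores[best])
```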
fairseq/examples/MMPT/endtask.md
ADDED
@@ -0,0 +1,41 @@
# Zero-shot Transfer and Finetuning

(If you are new to the ideas of `mmpt.processors`, see the [README](README.md) first.)
All finetuning datasets (specifically their `processors`) are defined in `mmpt.processors.dsprocessor`.
Given the complexity of the different types of finetuning tasks, each task may have its own meta/video/text/aligner processors and `mmpt/evaluators/{Predictor,Metric}`.

### Tasks

Currently, we support 5 end datasets: `MSRVTT`, `Youcook`, `COIN`, `Crosstask` and `DiDeMo`, with the following tasks:
text-video retrieval: `MSRVTT`, `Youcook`, `DiDeMo`;
video captioning: `Youcook`;
video question answering: `MSRVTT-QA`.

To add your own dataset, specify the corresponding processors and configure them in the `dataset` field of a config file, such as `projects/task/vtt.yaml`.

### Zero-shot Transfer (no Training)
Zero-shot transfer runs the pre-trained model (e.g., VideoCLIP) directly on testing data. Configs with the pattern `projects/task/*_zs_*.yaml` are dedicated to zero-shot transfer.

### Fine-tuning

Training a downstream task is similar to pretraining, except that you may need to specify the `restore_file` in `fairseq.checkpoint` and reset the optimizers; see `projects/task/ft.yaml`, which is included by `projects/task/vtt.yaml`.

We typically do finetuning on 2 GPUs (`local_small`).

### Testing
For each finetuning dataset, you may need to specify a testing config, similar to `projects/task/test_vtt.yaml`.

We define `mmpt.evaluators.Predictor` for different types of prediction. For example, `MSRVTT` and `Youcook` are video-retrieval tasks and are expected to use `RetrievalPredictor`. You may need to define your own type of predictor and specify it in the `predictor` field of a testing config.

Each task may also have its own metric for evaluation. This can be created in `mmpt.evaluators.Metric` and specified in the `metric` field of a testing config.

Launching a test is as simple as launching training: specify the path of a testing config:
```python locallaunch.py projects/mfmmlm/test_vtt.yaml```
Testing is launched locally by default since prediction is computationally less expensive.

### Third-party Libraries
The following finetuning tasks require third-party libraries.

Youcook captioning: `https://github.com/Maluuba/nlg-eval`

CrossTask: `https://github.com/DmZhukov/CrossTask`'s `dp` under `third-party/CrossTask` (`python setup.py build_ext --inplace`)
fairseq/examples/MMPT/mmpt/datasets/fairseqmmdataset.py
ADDED
@@ -0,0 +1,57 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
TODO (huxu): fairseq wrapper class for all dataset you defined: mostly MMDataset.
"""

from collections import OrderedDict

from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
from fairseq.data import FairseqDataset, data_utils


class FairseqMMDataset(FairseqDataset):
    """
    A wrapper class for MMDataset for fairseq.
    """

    def __init__(self, mmdataset):
        if not isinstance(mmdataset, Dataset):
            raise TypeError("mmdataset must be of type `torch.utils.data.dataset`.")
        self.mmdataset = mmdataset

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, idx):
        with data_utils.numpy_seed(43211, self.epoch, idx):
            return self.mmdataset[idx]

    def __len__(self):
        return len(self.mmdataset)

    def collater(self, samples):
        if hasattr(self.mmdataset, "collator"):
            return self.mmdataset.collator(samples)
        if len(samples) == 0:
            return {}
        if isinstance(samples[0], dict):
            batch = OrderedDict()
            for key in samples[0]:
                if samples[0][key] is not None:
                    batch[key] = default_collate([sample[key] for sample in samples])
            return batch
        else:
            return default_collate(samples)

    def size(self, index):
        """dummy implementation: we don't use --max-tokens"""
        return 1

    def num_tokens(self, index):
        """dummy implementation: we don't use --max-tokens"""
        return 1
fairseq/examples/MMPT/mmpt/datasets/mmdataset.py
ADDED
@@ -0,0 +1,111 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from collections import OrderedDict

from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate

from ..utils import set_seed


class MMDataset(Dataset):
    """
    A generic multi-modal dataset.
    Args:
        `meta_processor`: a meta processor,
            handling loading meta data and return video_id and text_id.
        `video_processor`: a video processor,
            handling e.g., decoding, loading .np files.
        `text_processor`: a text processor,
            handling e.g., tokenization.
        `aligner`: combine the video and text feature
            as one training example.
    """

    def __init__(
        self,
        meta_processor,
        video_processor,
        text_processor,
        align_processor,
    ):
        self.split = meta_processor.split
        self.meta_processor = meta_processor
        self.video_processor = video_processor
        self.text_processor = text_processor
        self.align_processor = align_processor

    def __len__(self):
        return len(self.meta_processor)

    def __getitem__(self, idx):
        if self.split == "test":
            set_seed(idx)
        video_id, text_id = self.meta_processor[idx]
        video_feature = self.video_processor(video_id)
        text_feature = self.text_processor(text_id)
        output = self.align_processor(video_id, video_feature, text_feature)
        # TODO (huxu): the following is for debug purpose.
        output.update({"idx": idx})
        return output

    def collater(self, samples):
        """This collator is deprecated.
        set self.collator = MMDataset.collater.
        see collator in FairseqMMDataset.
        """

        if len(samples) == 0:
            return {}
        if isinstance(samples[0], dict):
            batch = OrderedDict()
            for key in samples[0]:
                if samples[0][key] is not None:
                    batch[key] = default_collate(
                        [sample[key] for sample in samples])
                # if torch.is_tensor(batch[key]):
                #     print(key, batch[key].size())
                # else:
                #     print(key, len(batch[key]))
            return batch
        else:
            return default_collate(samples)

    def print_example(self, output):
        print("[one example]", output["video_id"])
        if (
            hasattr(self.align_processor, "subsampling")
            and self.align_processor.subsampling is not None
            and self.align_processor.subsampling > 1
        ):
            for key in output:
                if torch.is_tensor(output[key]):
                    output[key] = output[key][0]

        # search tokenizer to translate ids back.
        tokenizer = None
        if hasattr(self.text_processor, "tokenizer"):
            tokenizer = self.text_processor.tokenizer
        elif hasattr(self.align_processor, "tokenizer"):
            tokenizer = self.align_processor.tokenizer
        if tokenizer is not None:
            caps = output["caps"].tolist()
            if isinstance(caps[0], list):
                caps = caps[0]
            print("caps", tokenizer.decode(caps))
            print("caps", tokenizer.convert_ids_to_tokens(caps))

        for key, value in output.items():
            if torch.is_tensor(value):
                if len(value.size()) >= 3:  # attention_mask.
                    print(key, value.size())
                    print(key, "first", value[0, :, :])
                    print(key, "last", value[-1, :, :])
                else:
                    print(key, value)
        print("[end of one example]")
fairseq/examples/MMPT/mmpt/evaluators/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .metric import *
from .evaluator import *


# experimental.
try:
    from .expmetric import *
except ImportError:
    pass
fairseq/examples/MMPT/mmpt/evaluators/evaluator.py
ADDED
@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import glob
import numpy as np

from . import metric as metric_path
from . import predictor as predictor_path


class Evaluator(object):
    """
    perform evaluation on a single (downstream) task.
    make this both offline and online.
    TODO(huxu) saving evaluation results.
    """

    def __init__(self, config, eval_dataloader=None):
        if config.metric is None:
            raise ValueError("config.metric is", config.metric)
        metric_cls = getattr(metric_path, config.metric)
        self.metric = metric_cls(config)
        if config.predictor is None:
            raise ValueError("config.predictor is", config.predictor)
        predictor_cls = getattr(predictor_path, config.predictor)
        self.predictor = predictor_cls(config)
        self.eval_dataloader = eval_dataloader

    def __call__(self):
        try:
            print(self.predictor.pred_dir)
            for pred_file in glob.glob(
                    self.predictor.pred_dir + "/*_merged.npy"):
                outputs = np.load(pred_file)
                results = self.metric.compute_metrics(outputs)
                self.metric.print_computed_metrics(results)

            outputs = np.load(os.path.join(
                self.predictor.pred_dir, "merged.npy"))
            results = self.metric.compute_metrics(outputs)
            return {"results": results, "metric": self.metric}
        except FileNotFoundError:
            print("\n[missing]", self.predictor.pred_dir)
            return {}

    def evaluate(self, model, eval_dataloader=None, output_file="merged"):
        if eval_dataloader is None:
            eval_dataloader = self.eval_dataloader
        outputs = self.predictor.predict_loop(
            model, eval_dataloader, output_file)
        results = self.metric.compute_metrics(**outputs)
        return results
fairseq/examples/MMPT/mmpt/models/transformermodel.py
ADDED
@@ -0,0 +1,734 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. All Rights Reserved

import torch

from torch import nn

try:
    from transformers.modeling_bert import (
        BertPreTrainedModel,
        BertModel,
        BertEncoder,
        BertPredictionHeadTransform,
    )
except ImportError:
    pass

from ..modules import VideoTokenMLP, MMBertEmbeddings


# --------------- fine-tuning models ---------------
class MMBertForJoint(BertPreTrainedModel):
    """A BertModel with isolated attention mask to separate modality."""

    def __init__(self, config):
        super().__init__(config)
        self.videomlp = VideoTokenMLP(config)
        self.bert = MMBertModel(config)
|
43 |
+
self.init_weights()
|
44 |
+
|
45 |
+
def forward(
|
46 |
+
self,
|
47 |
+
input_ids=None,
|
48 |
+
input_video_embeds=None,
|
49 |
+
attention_mask=None,
|
50 |
+
token_type_ids=None,
|
51 |
+
position_ids=None,
|
52 |
+
head_mask=None,
|
53 |
+
inputs_embeds=None,
|
54 |
+
next_sentence_label=None,
|
55 |
+
output_attentions=None,
|
56 |
+
output_hidden_states=None,
|
57 |
+
return_dict=None,
|
58 |
+
separate_forward_split=None,
|
59 |
+
):
|
60 |
+
return_dict = (
|
61 |
+
return_dict if return_dict is not None
|
62 |
+
else self.config.use_return_dict
|
63 |
+
)
|
64 |
+
video_tokens = self.videomlp(input_video_embeds)
|
65 |
+
|
66 |
+
outputs = self.bert(
|
67 |
+
input_ids,
|
68 |
+
video_tokens,
|
69 |
+
attention_mask=attention_mask,
|
70 |
+
token_type_ids=token_type_ids,
|
71 |
+
position_ids=position_ids,
|
72 |
+
head_mask=head_mask,
|
73 |
+
inputs_embeds=inputs_embeds,
|
74 |
+
output_attentions=output_attentions,
|
75 |
+
output_hidden_states=output_hidden_states,
|
76 |
+
return_dict=return_dict,
|
77 |
+
separate_forward_split=separate_forward_split,
|
78 |
+
)
|
79 |
+
|
80 |
+
return outputs
|
81 |
+
|
82 |
+
|
83 |
+
class MMBertForTokenClassification(BertPreTrainedModel):
|
84 |
+
"""A BertModel similar to MMJointUni, with extra wrapper layer
|
85 |
+
to be fine-tuned from other pretrained MMFusion model."""
|
86 |
+
|
87 |
+
def __init__(self, config):
|
88 |
+
super().__init__(config)
|
89 |
+
self.videomlp = VideoTokenMLP(config)
|
90 |
+
self.bert = MMBertModel(config)
|
91 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
92 |
+
# TODO(huxu): 779 is the number of classes for COIN: move to config?
|
93 |
+
self.classifier = nn.Linear(config.hidden_size, 779)
|
94 |
+
self.init_weights()
|
95 |
+
|
96 |
+
def forward(
|
97 |
+
self,
|
98 |
+
input_ids=None,
|
99 |
+
input_video_embeds=None,
|
100 |
+
attention_mask=None,
|
101 |
+
token_type_ids=None,
|
102 |
+
position_ids=None,
|
103 |
+
head_mask=None,
|
104 |
+
inputs_embeds=None,
|
105 |
+
next_sentence_label=None,
|
106 |
+
output_attentions=None,
|
107 |
+
output_hidden_states=None,
|
108 |
+
return_dict=None,
|
109 |
+
separate_forward_split=None,
|
110 |
+
):
|
111 |
+
return_dict = (
|
112 |
+
return_dict if return_dict is not None
|
113 |
+
else self.config.use_return_dict
|
114 |
+
)
|
115 |
+
|
116 |
+
video_tokens = self.videomlp(input_video_embeds)
|
117 |
+
outputs = self.bert(
|
118 |
+
input_ids,
|
119 |
+
video_tokens,
|
120 |
+
attention_mask=attention_mask,
|
121 |
+
token_type_ids=token_type_ids,
|
122 |
+
position_ids=position_ids,
|
123 |
+
head_mask=head_mask,
|
124 |
+
inputs_embeds=inputs_embeds,
|
125 |
+
output_attentions=output_attentions,
|
126 |
+
output_hidden_states=output_hidden_states,
|
127 |
+
return_dict=return_dict,
|
128 |
+
separate_forward_split=separate_forward_split,
|
129 |
+
)
|
130 |
+
|
131 |
+
return (self.classifier(outputs[0]),)
|
132 |
+
|
133 |
+
|
134 |
+
# ------------ pre-training models ----------------
|
135 |
+
|
136 |
+
class MMBertForEncoder(BertPreTrainedModel):
|
137 |
+
"""A BertModel for Contrastive Learning."""
|
138 |
+
def __init__(self, config):
|
139 |
+
super().__init__(config)
|
140 |
+
self.videomlp = VideoTokenMLP(config)
|
141 |
+
self.bert = MMBertModel(config)
|
142 |
+
self.init_weights()
|
143 |
+
|
144 |
+
def forward(
|
145 |
+
self,
|
146 |
+
input_ids=None,
|
147 |
+
input_video_embeds=None,
|
148 |
+
attention_mask=None,
|
149 |
+
token_type_ids=None,
|
150 |
+
position_ids=None,
|
151 |
+
head_mask=None,
|
152 |
+
inputs_embeds=None,
|
153 |
+
output_attentions=None,
|
154 |
+
output_hidden_states=None,
|
155 |
+
return_dict=None,
|
156 |
+
):
|
157 |
+
return_dict = (
|
158 |
+
return_dict if return_dict is not None
|
159 |
+
else self.config.use_return_dict
|
160 |
+
)
|
161 |
+
if input_video_embeds is not None:
|
162 |
+
video_tokens = self.videomlp(input_video_embeds)
|
163 |
+
else:
|
164 |
+
video_tokens = None
|
165 |
+
|
166 |
+
outputs = self.bert(
|
167 |
+
input_ids,
|
168 |
+
video_tokens,
|
169 |
+
attention_mask=attention_mask,
|
170 |
+
token_type_ids=token_type_ids,
|
171 |
+
position_ids=position_ids,
|
172 |
+
head_mask=head_mask,
|
173 |
+
inputs_embeds=inputs_embeds,
|
174 |
+
output_attentions=output_attentions,
|
175 |
+
output_hidden_states=output_hidden_states,
|
176 |
+
return_dict=return_dict,
|
177 |
+
)
|
178 |
+
return outputs
|
179 |
+
|
180 |
+
|
181 |
+
class MMBertForMFMMLM(BertPreTrainedModel):
|
182 |
+
"""A BertModel with shared prediction head on MFM-MLM."""
|
183 |
+
def __init__(self, config):
|
184 |
+
super().__init__(config)
|
185 |
+
self.videomlp = VideoTokenMLP(config)
|
186 |
+
self.bert = MMBertModel(config)
|
187 |
+
self.cls = MFMMLMHead(config)
|
188 |
+
self.hidden_size = config.hidden_size
|
189 |
+
self.init_weights()
|
190 |
+
|
191 |
+
def get_output_embeddings(self):
|
192 |
+
return self.cls.predictions.decoder
|
193 |
+
|
194 |
+
def forward(
|
195 |
+
self,
|
196 |
+
input_ids=None,
|
197 |
+
input_video_embeds=None,
|
198 |
+
attention_mask=None,
|
199 |
+
token_type_ids=None,
|
200 |
+
position_ids=None,
|
201 |
+
head_mask=None,
|
202 |
+
inputs_embeds=None,
|
203 |
+
masked_frame_labels=None,
|
204 |
+
target_video_hidden_states=None,
|
205 |
+
non_masked_frame_mask=None,
|
206 |
+
masked_lm_labels=None,
|
207 |
+
output_attentions=None,
|
208 |
+
output_hidden_states=None,
|
209 |
+
return_dict=None,
|
210 |
+
):
|
211 |
+
return_dict = (
|
212 |
+
return_dict if return_dict is not None
|
213 |
+
else self.config.use_return_dict
|
214 |
+
)
|
215 |
+
if input_video_embeds is not None:
|
216 |
+
video_tokens = self.videomlp(input_video_embeds)
|
217 |
+
else:
|
218 |
+
video_tokens = None
|
219 |
+
|
220 |
+
if target_video_hidden_states is not None:
|
221 |
+
target_video_hidden_states = self.videomlp(
|
222 |
+
target_video_hidden_states)
|
223 |
+
|
224 |
+
non_masked_frame_hidden_states = video_tokens.masked_select(
|
225 |
+
non_masked_frame_mask.unsqueeze(-1)
|
226 |
+
).view(-1, self.hidden_size)
|
227 |
+
|
228 |
+
outputs = self.bert(
|
229 |
+
input_ids,
|
230 |
+
video_tokens,
|
231 |
+
attention_mask=attention_mask,
|
232 |
+
token_type_ids=token_type_ids,
|
233 |
+
position_ids=position_ids,
|
234 |
+
head_mask=head_mask,
|
235 |
+
inputs_embeds=inputs_embeds,
|
236 |
+
output_attentions=output_attentions,
|
237 |
+
output_hidden_states=output_hidden_states,
|
238 |
+
return_dict=return_dict,
|
239 |
+
)
|
240 |
+
|
241 |
+
sequence_output = outputs[0]
|
242 |
+
|
243 |
+
mfm_scores, prediction_scores = None, None
|
244 |
+
if masked_frame_labels is not None and masked_lm_labels is not None:
|
245 |
+
# split the sequence.
|
246 |
+
text_offset = masked_frame_labels.size(1) + 1 # [CLS]
|
247 |
+
video_sequence_output = sequence_output[
|
248 |
+
:, 1:text_offset
|
249 |
+
] # remove [SEP] as not in video_label.
|
250 |
+
text_sequence_output = torch.cat(
|
251 |
+
[sequence_output[:, :1], sequence_output[:, text_offset:]],
|
252 |
+
dim=1
|
253 |
+
)
|
254 |
+
|
255 |
+
hidden_size = video_sequence_output.size(-1)
|
256 |
+
selected_video_output = video_sequence_output.masked_select(
|
257 |
+
masked_frame_labels.unsqueeze(-1)
|
258 |
+
).view(-1, hidden_size)
|
259 |
+
|
260 |
+
# only compute select tokens to training to speed up.
|
261 |
+
hidden_size = text_sequence_output.size(-1)
|
262 |
+
# masked_lm_labels = masked_lm_labels.reshape(-1)
|
263 |
+
labels_mask = masked_lm_labels != -100
|
264 |
+
|
265 |
+
selected_text_output = text_sequence_output.masked_select(
|
266 |
+
labels_mask.unsqueeze(-1)
|
267 |
+
).view(-1, hidden_size)
|
268 |
+
mfm_scores, prediction_scores = self.cls(
|
269 |
+
selected_video_output,
|
270 |
+
target_video_hidden_states,
|
271 |
+
non_masked_frame_hidden_states,
|
272 |
+
selected_text_output,
|
273 |
+
)
|
274 |
+
|
275 |
+
output = (
|
276 |
+
mfm_scores,
|
277 |
+
prediction_scores,
|
278 |
+
) + outputs
|
279 |
+
return output
|
280 |
+
|
281 |
+
|
282 |
+
class BertMFMMLMPredictionHead(nn.Module):
|
283 |
+
def __init__(self, config):
|
284 |
+
super().__init__()
|
285 |
+
self.transform = BertPredictionHeadTransform(config)
|
286 |
+
# The output weights are the same as the input embeddings, but there is
|
287 |
+
# an output-only bias for each token.
|
288 |
+
self.decoder = nn.Linear(
|
289 |
+
config.hidden_size, config.vocab_size, bias=False)
|
290 |
+
|
291 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
292 |
+
|
293 |
+
# Need a link between the two variables so that the bias is correctly
|
294 |
+
# resized with `resize_token_embeddings`
|
295 |
+
self.decoder.bias = self.bias
|
296 |
+
|
297 |
+
def forward(
|
298 |
+
self,
|
299 |
+
video_hidden_states=None,
|
300 |
+
target_video_hidden_states=None,
|
301 |
+
non_masked_frame_hidden_states=None,
|
302 |
+
text_hidden_states=None,
|
303 |
+
):
|
304 |
+
video_logits, text_logits = None, None
|
305 |
+
if video_hidden_states is not None:
|
306 |
+
video_hidden_states = self.transform(video_hidden_states)
|
307 |
+
non_masked_frame_logits = torch.mm(
|
308 |
+
video_hidden_states,
|
309 |
+
non_masked_frame_hidden_states.transpose(1, 0)
|
310 |
+
)
|
311 |
+
masked_frame_logits = torch.bmm(
|
312 |
+
video_hidden_states.unsqueeze(1),
|
313 |
+
target_video_hidden_states.unsqueeze(-1),
|
314 |
+
).squeeze(-1)
|
315 |
+
video_logits = torch.cat(
|
316 |
+
[masked_frame_logits, non_masked_frame_logits], dim=1
|
317 |
+
)
|
318 |
+
|
319 |
+
if text_hidden_states is not None:
|
320 |
+
text_hidden_states = self.transform(text_hidden_states)
|
321 |
+
text_logits = self.decoder(text_hidden_states)
|
322 |
+
return video_logits, text_logits
|
323 |
+
|
324 |
+
|
325 |
+
class MFMMLMHead(nn.Module):
|
326 |
+
def __init__(self, config):
|
327 |
+
super().__init__()
|
328 |
+
self.predictions = BertMFMMLMPredictionHead(config)
|
329 |
+
|
330 |
+
def forward(
|
331 |
+
self,
|
332 |
+
video_hidden_states=None,
|
333 |
+
target_video_hidden_states=None,
|
334 |
+
non_masked_frame_hidden_states=None,
|
335 |
+
text_hidden_states=None,
|
336 |
+
):
|
337 |
+
video_logits, text_logits = self.predictions(
|
338 |
+
video_hidden_states,
|
339 |
+
target_video_hidden_states,
|
340 |
+
non_masked_frame_hidden_states,
|
341 |
+
text_hidden_states,
|
342 |
+
)
|
343 |
+
return video_logits, text_logits
|
344 |
+
|
345 |
+
|
346 |
+
class MMBertForMTM(MMBertForMFMMLM):
|
347 |
+
def __init__(self, config):
|
348 |
+
BertPreTrainedModel.__init__(self, config)
|
349 |
+
self.videomlp = VideoTokenMLP(config)
|
350 |
+
self.bert = MMBertModel(config)
|
351 |
+
self.cls = MTMHead(config)
|
352 |
+
self.hidden_size = config.hidden_size
|
353 |
+
self.init_weights()
|
354 |
+
|
355 |
+
|
356 |
+
class BertMTMPredictionHead(nn.Module):
|
357 |
+
def __init__(self, config):
|
358 |
+
super().__init__()
|
359 |
+
self.transform = BertPredictionHeadTransform(config)
|
360 |
+
self.decoder = nn.Linear(
|
361 |
+
config.hidden_size, config.vocab_size, bias=False)
|
362 |
+
|
363 |
+
def forward(
|
364 |
+
self,
|
365 |
+
video_hidden_states=None,
|
366 |
+
target_video_hidden_states=None,
|
367 |
+
non_masked_frame_hidden_states=None,
|
368 |
+
text_hidden_states=None,
|
369 |
+
):
|
370 |
+
non_masked_frame_hidden_states = non_masked_frame_hidden_states.transpose(1, 0)
|
371 |
+
video_logits, text_logits = None, None
|
372 |
+
if video_hidden_states is not None:
|
373 |
+
video_hidden_states = self.transform(video_hidden_states)
|
374 |
+
|
375 |
+
masked_frame_logits = torch.bmm(
|
376 |
+
video_hidden_states.unsqueeze(1),
|
377 |
+
target_video_hidden_states.unsqueeze(-1),
|
378 |
+
).squeeze(-1)
|
379 |
+
|
380 |
+
non_masked_frame_logits = torch.mm(
|
381 |
+
video_hidden_states,
|
382 |
+
non_masked_frame_hidden_states
|
383 |
+
)
|
384 |
+
video_on_vocab_logits = self.decoder(video_hidden_states)
|
385 |
+
video_logits = torch.cat([
|
386 |
+
masked_frame_logits,
|
387 |
+
non_masked_frame_logits,
|
388 |
+
video_on_vocab_logits], dim=1)
|
389 |
+
|
390 |
+
if text_hidden_states is not None:
|
391 |
+
text_hidden_states = self.transform(text_hidden_states)
|
392 |
+
# text first so label does not need to be shifted.
|
393 |
+
text_on_vocab_logits = self.decoder(text_hidden_states)
|
394 |
+
text_on_video_logits = torch.mm(
|
395 |
+
text_hidden_states,
|
396 |
+
non_masked_frame_hidden_states
|
397 |
+
)
|
398 |
+
text_logits = torch.cat([
|
399 |
+
text_on_vocab_logits,
|
400 |
+
text_on_video_logits
|
401 |
+
], dim=1)
|
402 |
+
|
403 |
+
return video_logits, text_logits
|
404 |
+
|
405 |
+
|
406 |
+
class MTMHead(nn.Module):
|
407 |
+
def __init__(self, config):
|
408 |
+
super().__init__()
|
409 |
+
self.predictions = BertMTMPredictionHead(config)
|
410 |
+
|
411 |
+
def forward(
|
412 |
+
self,
|
413 |
+
video_hidden_states=None,
|
414 |
+
target_video_hidden_states=None,
|
415 |
+
non_masked_frame_hidden_states=None,
|
416 |
+
text_hidden_states=None,
|
417 |
+
):
|
418 |
+
video_logits, text_logits = self.predictions(
|
419 |
+
video_hidden_states,
|
420 |
+
target_video_hidden_states,
|
421 |
+
non_masked_frame_hidden_states,
|
422 |
+
text_hidden_states,
|
423 |
+
)
|
424 |
+
return video_logits, text_logits
|
425 |
+
|
426 |
+
|
427 |
+
class MMBertModel(BertModel):
|
428 |
+
"""MMBertModel has MMBertEmbedding to support video tokens."""
|
429 |
+
|
430 |
+
def __init__(self, config, add_pooling_layer=True):
|
431 |
+
super().__init__(config)
|
432 |
+
# overwrite embedding
|
433 |
+
self.embeddings = MMBertEmbeddings(config)
|
434 |
+
self.encoder = MultiLayerAttentionMaskBertEncoder(config)
|
435 |
+
self.init_weights()
|
436 |
+
|
437 |
+
def forward(
|
438 |
+
self,
|
439 |
+
input_ids=None,
|
440 |
+
input_video_embeds=None,
|
441 |
+
attention_mask=None,
|
442 |
+
token_type_ids=None,
|
443 |
+
position_ids=None,
|
444 |
+
head_mask=None,
|
445 |
+
inputs_embeds=None,
|
446 |
+
encoder_hidden_states=None,
|
447 |
+
encoder_attention_mask=None,
|
448 |
+
output_attentions=None,
|
449 |
+
output_hidden_states=None,
|
450 |
+
return_dict=None,
|
451 |
+
separate_forward_split=None,
|
452 |
+
):
|
453 |
+
output_attentions = (
|
454 |
+
output_attentions
|
455 |
+
if output_attentions is not None
|
456 |
+
else self.config.output_attentions
|
457 |
+
)
|
458 |
+
output_hidden_states = (
|
459 |
+
output_hidden_states
|
460 |
+
if output_hidden_states is not None
|
461 |
+
else self.config.output_hidden_states
|
462 |
+
)
|
463 |
+
return_dict = (
|
464 |
+
return_dict if return_dict is not None
|
465 |
+
else self.config.use_return_dict
|
466 |
+
)
|
467 |
+
|
468 |
+
if input_ids is not None and inputs_embeds is not None:
|
469 |
+
raise ValueError(
|
470 |
+
"You cannot specify both input_ids "
|
471 |
+
"and inputs_embeds at the same time"
|
472 |
+
)
|
473 |
+
elif input_ids is not None:
|
474 |
+
if input_video_embeds is not None:
|
475 |
+
input_shape = (
|
476 |
+
input_ids.size(0),
|
477 |
+
input_ids.size(1) + input_video_embeds.size(1),
|
478 |
+
)
|
479 |
+
else:
|
480 |
+
input_shape = (
|
481 |
+
input_ids.size(0),
|
482 |
+
input_ids.size(1),
|
483 |
+
)
|
484 |
+
elif inputs_embeds is not None:
|
485 |
+
if input_video_embeds is not None:
|
486 |
+
input_shape = (
|
487 |
+
inputs_embeds.size(0),
|
488 |
+
inputs_embeds.size(1) + input_video_embeds.size(1),
|
489 |
+
)
|
490 |
+
else:
|
491 |
+
input_shape = (
|
492 |
+
input_ids.size(0),
|
493 |
+
input_ids.size(1),
|
494 |
+
)
|
495 |
+
else:
|
496 |
+
raise ValueError(
|
497 |
+
"You have to specify either input_ids or inputs_embeds")
|
498 |
+
|
499 |
+
device = input_ids.device if input_ids is not None \
|
500 |
+
else inputs_embeds.device
|
501 |
+
|
502 |
+
if attention_mask is None:
|
503 |
+
attention_mask = torch.ones(input_shape, device=device)
|
504 |
+
if token_type_ids is None:
|
505 |
+
token_type_ids = torch.zeros(
|
506 |
+
input_shape, dtype=torch.long, device=device)
|
507 |
+
|
508 |
+
# We can provide a self-attention mask of dimensions
|
509 |
+
# [batch_size, from_seq_length, to_seq_length]
|
510 |
+
# ourselves in which case
|
511 |
+
# we just need to make it broadcastable to all heads.
|
512 |
+
extended_attention_mask: torch.Tensor = \
|
513 |
+
self.get_extended_attention_mask(
|
514 |
+
attention_mask, input_shape, device)
|
515 |
+
|
516 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
517 |
+
# we need to make broadcastable to
|
518 |
+
# [batch_size, num_heads, seq_length, seq_length]
|
519 |
+
if self.config.is_decoder and encoder_hidden_states is not None:
|
520 |
+
(
|
521 |
+
encoder_batch_size,
|
522 |
+
encoder_sequence_length,
|
523 |
+
_,
|
524 |
+
) = encoder_hidden_states.size()
|
525 |
+
encoder_hidden_shape = (
|
526 |
+
encoder_batch_size, encoder_sequence_length)
|
527 |
+
if encoder_attention_mask is None:
|
528 |
+
encoder_attention_mask = torch.ones(
|
529 |
+
encoder_hidden_shape, device=device)
|
530 |
+
encoder_extended_attention_mask = self.invert_attention_mask(
|
531 |
+
encoder_attention_mask
|
532 |
+
)
|
533 |
+
else:
|
534 |
+
encoder_extended_attention_mask = None
|
535 |
+
|
536 |
+
# Prepare head mask if needed
|
537 |
+
# 1.0 in head_mask indicate we keep the head
|
538 |
+
# attention_probs has shape bsz x n_heads x N x N
|
539 |
+
# input head_mask has shape [num_heads] or
|
540 |
+
# [num_hidden_layers x num_heads]
|
541 |
+
# and head_mask is converted to shape
|
542 |
+
# [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
543 |
+
|
544 |
+
head_mask = self.get_head_mask(
|
545 |
+
head_mask, self.config.num_hidden_layers)
|
546 |
+
|
547 |
+
embedding_output = self.embeddings(
|
548 |
+
input_ids,
|
549 |
+
input_video_embeds,
|
550 |
+
position_ids=position_ids,
|
551 |
+
token_type_ids=token_type_ids,
|
552 |
+
inputs_embeds=inputs_embeds,
|
553 |
+
)
|
554 |
+
|
555 |
+
if separate_forward_split is not None:
|
556 |
+
split_embedding_output = \
|
557 |
+
embedding_output[:, :separate_forward_split]
|
558 |
+
split_extended_attention_mask = extended_attention_mask[
|
559 |
+
:, :, :, :separate_forward_split, :separate_forward_split
|
560 |
+
]
|
561 |
+
split_encoder_outputs = self.encoder(
|
562 |
+
split_embedding_output,
|
563 |
+
attention_mask=split_extended_attention_mask,
|
564 |
+
head_mask=head_mask,
|
565 |
+
encoder_hidden_states=encoder_hidden_states,
|
566 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
567 |
+
output_attentions=output_attentions,
|
568 |
+
output_hidden_states=output_hidden_states,
|
569 |
+
return_dict=return_dict,
|
570 |
+
)
|
571 |
+
assert (
|
572 |
+
len(split_encoder_outputs) <= 2
|
573 |
+
), "we do not support merge on attention for now."
|
574 |
+
encoder_outputs = []
|
575 |
+
encoder_outputs.append([split_encoder_outputs[0]])
|
576 |
+
if len(split_encoder_outputs) == 2:
|
577 |
+
encoder_outputs.append([])
|
578 |
+
for _all_hidden_states in split_encoder_outputs[1]:
|
579 |
+
encoder_outputs[-1].append([_all_hidden_states])
|
580 |
+
|
581 |
+
split_embedding_output = \
|
582 |
+
embedding_output[:, separate_forward_split:]
|
583 |
+
split_extended_attention_mask = extended_attention_mask[
|
584 |
+
:, :, :, separate_forward_split:, separate_forward_split:
|
585 |
+
]
|
586 |
+
|
587 |
+
split_encoder_outputs = self.encoder(
|
588 |
+
split_embedding_output,
|
589 |
+
attention_mask=split_extended_attention_mask,
|
590 |
+
head_mask=head_mask,
|
591 |
+
encoder_hidden_states=encoder_hidden_states,
|
592 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
593 |
+
output_attentions=output_attentions,
|
594 |
+
output_hidden_states=output_hidden_states,
|
595 |
+
return_dict=return_dict,
|
596 |
+
)
|
597 |
+
|
598 |
+
assert (
|
599 |
+
len(split_encoder_outputs) <= 2
|
600 |
+
), "we do not support merge on attention for now."
|
601 |
+
encoder_outputs[0].append(split_encoder_outputs[0])
|
602 |
+
encoder_outputs[0] = torch.cat(encoder_outputs[0], dim=1)
|
603 |
+
if len(split_encoder_outputs) == 2:
|
604 |
+
for layer_idx, _all_hidden_states in enumerate(
|
605 |
+
split_encoder_outputs[1]
|
606 |
+
):
|
607 |
+
encoder_outputs[1][layer_idx].append(_all_hidden_states)
|
608 |
+
encoder_outputs[1][layer_idx] = torch.cat(
|
609 |
+
encoder_outputs[1][layer_idx], dim=1
|
610 |
+
)
|
611 |
+
encoder_outputs = tuple(encoder_outputs)
|
612 |
+
else:
|
613 |
+
encoder_outputs = self.encoder(
|
614 |
+
embedding_output,
|
615 |
+
attention_mask=extended_attention_mask,
|
616 |
+
head_mask=head_mask,
|
617 |
+
encoder_hidden_states=encoder_hidden_states,
|
618 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
619 |
+
output_attentions=output_attentions,
|
620 |
+
output_hidden_states=output_hidden_states,
|
621 |
+
return_dict=return_dict,
|
622 |
+
)
|
623 |
+
|
624 |
+
sequence_output = encoder_outputs[0]
|
625 |
+
pooled_output = (
|
626 |
+
self.pooler(sequence_output) if self.pooler is not None else None
|
627 |
+
)
|
628 |
+
|
629 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
630 |
+
|
631 |
+
def get_extended_attention_mask(self, attention_mask, input_shape, device):
|
632 |
+
"""This is borrowed from `modeling_utils.py` with the support of
|
633 |
+
multi-layer attention masks.
|
634 |
+
The second dim is expected to be number of layers.
|
635 |
+
See `MMAttentionMaskProcessor`.
|
636 |
+
Makes broadcastable attention and causal masks so that future
|
637 |
+
and masked tokens are ignored.
|
638 |
+
|
639 |
+
Arguments:
|
640 |
+
attention_mask (:obj:`torch.Tensor`):
|
641 |
+
Mask with ones indicating tokens to attend to,
|
642 |
+
zeros for tokens to ignore.
|
643 |
+
input_shape (:obj:`Tuple[int]`):
|
644 |
+
The shape of the input to the model.
|
645 |
+
device: (:obj:`torch.device`):
|
646 |
+
The device of the input to the model.
|
647 |
+
|
648 |
+
Returns:
|
649 |
+
:obj:`torch.Tensor` The extended attention mask, \
|
650 |
+
with a the same dtype as :obj:`attention_mask.dtype`.
|
651 |
+
"""
|
652 |
+
# We can provide a self-attention mask of dimensions
|
653 |
+
# [batch_size, from_seq_length, to_seq_length]
|
654 |
+
# ourselves in which case we just need to make it broadcastable
|
655 |
+
# to all heads.
|
656 |
+
if attention_mask.dim() == 4:
|
657 |
+
extended_attention_mask = attention_mask[:, :, None, :, :]
|
658 |
+
extended_attention_mask = extended_attention_mask.to(
|
659 |
+
dtype=self.dtype
|
660 |
+
) # fp16 compatibility
|
661 |
+
extended_attention_mask = (1.0 - extended_attention_mask) \
|
662 |
+
* -10000.0
|
663 |
+
return extended_attention_mask
|
664 |
+
else:
|
665 |
+
return super().get_extended_attention_mask(
|
666 |
+
attention_mask, input_shape, device
|
667 |
+
)
|
668 |
+
|
669 |
+
|
670 |
+
class MultiLayerAttentionMaskBertEncoder(BertEncoder):
|
671 |
+
"""extend BertEncoder with the capability of
|
672 |
+
multiple layers of attention mask."""
|
673 |
+
|
674 |
+
def forward(
|
675 |
+
self,
|
676 |
+
hidden_states,
|
677 |
+
attention_mask=None,
|
678 |
+
head_mask=None,
|
679 |
+
encoder_hidden_states=None,
|
680 |
+
encoder_attention_mask=None,
|
681 |
+
output_attentions=False,
|
682 |
+
output_hidden_states=False,
|
683 |
+
return_dict=False,
|
684 |
+
):
|
685 |
+
all_hidden_states = () if output_hidden_states else None
|
686 |
+
all_attentions = () if output_attentions else None
|
687 |
+
for i, layer_module in enumerate(self.layer):
|
688 |
+
if output_hidden_states:
|
689 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
690 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
691 |
+
|
692 |
+
layer_attention_mask = (
|
693 |
+
attention_mask[:, i, :, :, :]
|
694 |
+
if attention_mask.dim() == 5
|
695 |
+
else attention_mask
|
696 |
+
)
|
697 |
+
|
698 |
+
if getattr(self.config, "gradient_checkpointing", False):
|
699 |
+
|
700 |
+
def create_custom_forward(module):
|
701 |
+
def custom_forward(*inputs):
|
702 |
+
return module(*inputs, output_attentions)
|
703 |
+
|
704 |
+
return custom_forward
|
705 |
+
|
706 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
707 |
+
create_custom_forward(layer_module),
|
708 |
+
hidden_states,
|
709 |
+
layer_attention_mask,
|
710 |
+
layer_head_mask,
|
711 |
+
encoder_hidden_states,
|
712 |
+
encoder_attention_mask,
|
713 |
+
)
|
714 |
+
else:
|
715 |
+
layer_outputs = layer_module(
|
716 |
+
hidden_states,
|
717 |
+
layer_attention_mask,
|
718 |
+
layer_head_mask,
|
719 |
+
encoder_hidden_states,
|
720 |
+
encoder_attention_mask,
|
721 |
+
output_attentions,
|
722 |
+
)
|
723 |
+
hidden_states = layer_outputs[0]
|
724 |
+
if output_attentions:
|
725 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
726 |
+
|
727 |
+
if output_hidden_states:
|
728 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
729 |
+
|
730 |
+
return tuple(
|
731 |
+
v
|
732 |
+
for v in [hidden_states, all_hidden_states, all_attentions]
|
733 |
+
if v is not None
|
734 |
+
)
|
fairseq/examples/MMPT/mmpt/modules/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
from .mm import *
|
6 |
+
|
7 |
+
try:
|
8 |
+
from .expmm import *
|
9 |
+
except ImportError:
|
10 |
+
pass
|
fairseq/examples/MMPT/mmpt/modules/mm.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
# Copyright (c) Facebook, Inc. All Rights Reserved
|
17 |
+
|
18 |
+
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from torch import nn
|
22 |
+
|
23 |
+
try:
|
24 |
+
from transformers.modeling_bert import (
|
25 |
+
BertEmbeddings,
|
26 |
+
ACT2FN,
|
27 |
+
)
|
28 |
+
except ImportError:
|
29 |
+
pass
|
30 |
+
|
31 |
+
|
32 |
+
class VideoTokenMLP(nn.Module):
|
33 |
+
def __init__(self, config):
|
34 |
+
super().__init__()
|
35 |
+
input_dim = config.input_dim if hasattr(config, "input_dim") else 512
|
36 |
+
self.linear1 = nn.Linear(input_dim, config.hidden_size)
|
37 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size)
|
38 |
+
self.activation = ACT2FN[config.hidden_act]
|
39 |
+
self.linear2 = nn.Linear(config.hidden_size, config.hidden_size)
|
40 |
+
|
41 |
+
def forward(self, hidden_states):
|
42 |
+
hidden_states = self.linear1(hidden_states)
|
43 |
+
hidden_states = self.activation(hidden_states)
|
44 |
+
hidden_states = self.LayerNorm(hidden_states)
|
45 |
+
hidden_states = self.linear2(hidden_states)
|
46 |
+
return hidden_states
|
47 |
+
|
48 |
+
|
49 |
+
class MMBertEmbeddings(BertEmbeddings):
|
50 |
+
def __init__(self, config):
|
51 |
+
super().__init__(config)
|
52 |
+
self.max_video_len = config.max_video_len
|
53 |
+
if hasattr(config, "use_seg_emb") and config.use_seg_emb:
|
54 |
+
"""the original VLM paper uses seg_embeddings for temporal space.
|
55 |
+
although not used it changed the randomness of initialization.
|
56 |
+
we keep it for reproducibility.
|
57 |
+
"""
|
58 |
+
self.seg_embeddings = nn.Embedding(256, config.hidden_size)
|
59 |
+
|
60 |
+
def forward(
|
61 |
+
self,
|
62 |
+
input_ids,
|
63 |
+
input_video_embeds,
|
64 |
+
token_type_ids=None,
|
65 |
+
position_ids=None,
|
66 |
+
inputs_embeds=None,
|
67 |
+
):
|
68 |
+
input_tensor = input_ids if input_ids is not None else inputs_embeds
|
69 |
+
if input_video_embeds is not None:
|
70 |
+
input_shape = (
|
71 |
+
input_tensor.size(0),
|
72 |
+
input_tensor.size(1) + input_video_embeds.size(1),
|
73 |
+
)
|
74 |
+
else:
|
75 |
+
input_shape = (input_tensor.size(0), input_tensor.size(1))
|
76 |
+
|
77 |
+
if position_ids is None:
|
78 |
+
"""
|
79 |
+
Auto skip position embeddings for text only case.
|
80 |
+
use cases:
|
81 |
+
(1) action localization and segmentation:
|
82 |
+
feed in len-1 dummy video token needs text part to
|
83 |
+
skip input_video_embeds.size(1) for the right
|
84 |
+
position_ids for video [SEP] and rest text tokens.
|
85 |
+
(2) MMFusionShare for two forward passings:
|
86 |
+
in `forward_text`: input_video_embeds is None.
|
87 |
+
need to skip video [SEP] token.
|
88 |
+
|
89 |
+
# video_len + 1: [CLS] + video_embed
|
90 |
+
# self.max_video_len + 1: [SEP] for video.
|
91 |
+
# self.max_video_len + 2: [SEP] for video.
|
92 |
+
# self.max_video_len + input_ids.size(1): rest for text.
|
93 |
+
"""
|
94 |
+
if input_video_embeds is not None:
|
95 |
+
video_len = input_video_embeds.size(1)
|
96 |
+
starting_offset = self.max_video_len + 1 # video [SEP]
|
97 |
+
ending_offset = self.max_video_len + input_ids.size(1)
|
98 |
+
else:
|
99 |
+
video_len = 0
|
100 |
+
starting_offset = self.max_video_len + 2 # first text token.
|
101 |
+
ending_offset = self.max_video_len + input_ids.size(1) + 1
|
102 |
+
position_ids = torch.cat([
|
103 |
+
self.position_ids[:, :video_len + 1],
|
104 |
+
self.position_ids[:, starting_offset:ending_offset]
|
105 |
+
], dim=1)
|
106 |
+
|
107 |
+
if token_type_ids is None:
|
108 |
+
token_type_ids = torch.zeros(
|
109 |
+
input_shape, dtype=torch.long, device=self.position_ids.device
|
110 |
+
)
|
111 |
+
|
112 |
+
"""
|
113 |
+
the format of input_ids is [CLS] [SEP] caption [SEP] padding.
|
114 |
+
the goal is to build [CLS] video tokens [SEP] caption [SEP] .
|
115 |
+
"""
|
116 |
+
if inputs_embeds is None:
|
117 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
118 |
+
if input_video_embeds is not None:
|
119 |
+
inputs_mm_embeds = torch.cat([
|
120 |
+
inputs_embeds[:, :1], input_video_embeds, inputs_embeds[:, 1:]
|
121 |
+
], dim=1)
|
122 |
+
else:
|
123 |
+
# text only for `MMFusionShare`.
|
124 |
+
inputs_mm_embeds = inputs_embeds
|
125 |
+
|
126 |
+
position_embeddings = self.position_embeddings(position_ids)
|
127 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
128 |
+
embeddings = inputs_mm_embeds + position_embeddings
|
129 |
+
embeddings += token_type_embeddings
|
130 |
+
|
131 |
+
embeddings = self.LayerNorm(embeddings)
|
132 |
+
embeddings = self.dropout(embeddings)
|
133 |
+
return embeddings
|
134 |
+
|
135 |
+
|
136 |
+
class AlignHead(nn.Module):
|
137 |
+
"""this will load pre-trained weights for NSP, which is desirable."""
|
138 |
+
|
139 |
+
def __init__(self, config):
|
140 |
+
super().__init__()
|
141 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
142 |
+
|
143 |
+
def forward(self, dropout_pooled_output):
|
144 |
+
logits = self.seq_relationship(dropout_pooled_output)
|
145 |
+
return logits
|
fairseq/examples/MMPT/mmpt/modules/retri.py
ADDED
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
import os
|
6 |
+
import numpy as np
|
7 |
+
import pickle
|
8 |
+
import time
|
9 |
+
|
10 |
+
try:
|
11 |
+
import faiss
|
12 |
+
except ImportError:
|
13 |
+
pass
|
14 |
+
|
15 |
+
from collections import defaultdict
|
16 |
+
|
17 |
+
from ..utils import get_local_rank, print_on_rank0
|
18 |
+
|
19 |
+
|
20 |
+
class VectorRetriever(object):
|
21 |
+
"""
|
22 |
+
How2 Video Retriver.
|
23 |
+
Reference usage of FAISS:
|
24 |
+
https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train):
|
28 |
+
if db_type == "flatl2":
|
29 |
+
quantizer = faiss.IndexFlatL2(hidden_size) # the other index
|
30 |
+
self.db = faiss.IndexIVFFlat(
|
31 |
+
quantizer, hidden_size, cent, faiss.METRIC_L2)
|
32 |
+
elif db_type == "pq":
|
33 |
+
self.db = faiss.index_factory(
|
34 |
+
hidden_size, f"IVF{cent}_HNSW32,PQ32"
|
35 |
+
)
|
36 |
+
else:
|
37 |
+
raise ValueError("unknown type of db", db_type)
|
38 |
+
self.train_thres = cent * examples_per_cent_to_train
|
39 |
+
self.train_cache = []
|
40 |
+
self.train_len = 0
|
41 |
+
self.videoid_to_vectoridx = {}
|
42 |
+
self.vectoridx_to_videoid = None
|
43 |
+
self.make_direct_maps_done = False
|
44 |
+
|
45 |
+
def make_direct_maps(self):
|
46 |
+
faiss.downcast_index(self.db).make_direct_map()
|
47 |
+
|
48 |
+
def __len__(self):
|
49 |
+
return self.db.ntotal
|
50 |
+
|
51 |
+
def save(self, out_dir):
|
52 |
+
faiss.write_index(
|
53 |
+
self.db,
|
54 |
+
os.path.join(out_dir, "faiss_idx")
|
55 |
+
)
|
56 |
+
with open(
|
57 |
+
os.path.join(
|
58 |
+
out_dir, "videoid_to_vectoridx.pkl"),
|
59 |
+
"wb") as fw:
|
60 |
+
pickle.dump(
|
61 |
+
self.videoid_to_vectoridx, fw,
|
62 |
+
protocol=pickle.HIGHEST_PROTOCOL
|
63 |
+
)
|
64 |
+
|
65 |
+
def load(self, out_dir):
|
66 |
+
fn = os.path.join(out_dir, "faiss_idx")
|
67 |
+
self.db = faiss.read_index(fn)
|
68 |
+
with open(
|
69 |
+
os.path.join(out_dir, "videoid_to_vectoridx.pkl"), "rb") as fr:
|
70 |
+
self.videoid_to_vectoridx = pickle.load(fr)
|
71 |
+
|
72 |
+
def add(self, hidden_states, video_ids, last=False):
|
73 |
+
assert len(hidden_states) == len(video_ids), "{}, {}".format(
|
74 |
+
str(len(hidden_states)), str(len(video_ids)))
|
75 |
+
assert len(hidden_states.shape) == 2
|
76 |
+
assert hidden_states.dtype == np.float32
|
77 |
+
|
78 |
+
valid_idx = []
|
79 |
+
for idx, video_id in enumerate(video_ids):
|
80 |
+
if video_id not in self.videoid_to_vectoridx:
|
81 |
+
valid_idx.append(idx)
|
82 |
+
self.videoid_to_vectoridx[video_id] = \
|
83 |
+
len(self.videoid_to_vectoridx)
|
84 |
+
|
85 |
+
hidden_states = hidden_states[valid_idx]
|
86 |
+
if not self.db.is_trained:
|
87 |
+
self.train_cache.append(hidden_states)
|
88 |
+
self.train_len += hidden_states.shape[0]
|
89 |
+
if self.train_len < self.train_thres:
|
90 |
+
return
|
91 |
+
self.finalize_training()
|
92 |
+
else:
|
93 |
+
self.db.add(hidden_states)
|
94 |
+
|
95 |
+
def finalize_training(self):
|
96 |
+
hidden_states = np.concatenate(self.train_cache, axis=0)
|
97 |
+
del self.train_cache
|
98 |
+
local_rank = get_local_rank()
|
99 |
+
if local_rank == 0:
|
100 |
+
start = time.time()
|
101 |
+
print("training db on", self.train_thres, "/", self.train_len)
|
102 |
+
self.db.train(hidden_states[:self.train_thres])
|
103 |
+
if local_rank == 0:
|
104 |
+
print("training db for", time.time() - start)
|
105 |
+
self.db.add(hidden_states)
|
106 |
+
|
107 |
+
def search(
|
108 |
+
self,
|
109 |
+
query_hidden_states,
|
110 |
+
orig_dist,
|
111 |
+
):
|
112 |
+
if len(self.videoid_to_vectoridx) != self.db.ntotal:
|
113 |
+
raise ValueError(
|
114 |
+
"cannot search: size mismatch in-between index and db",
|
115 |
+
len(self.videoid_to_vectoridx),
|
116 |
+
self.db.ntotal
|
117 |
+
)
|
118 |
+
|
119 |
+
if self.vectoridx_to_videoid is None:
|
120 |
+
self.vectoridx_to_videoid = {
|
121 |
+
self.videoid_to_vectoridx[videoid]: videoid
|
122 |
+
for videoid in self.videoid_to_vectoridx
|
123 |
+
}
|
124 |
+
assert len(self.vectoridx_to_videoid) \
|
125 |
+
== len(self.videoid_to_vectoridx)
|
126 |
+
|
127 |
+
# MultilingualFaissDataset uses the following; not sure the purpose.
|
128 |
+
# faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
|
129 |
+
queried_dist, index = self.db.search(query_hidden_states, 1)
|
130 |
+
queried_dist, index = queried_dist[:, 0], index[:, 0]
|
131 |
+
|
132 |
+
outputs = np.array(
|
133 |
+
[self.vectoridx_to_videoid[_index]
|
134 |
+
if _index != -1 else (-1, -1, -1) for _index in index],
|
135 |
+
dtype=np.int32)
|
136 |
+
outputs[queried_dist <= orig_dist] = -1
|
137 |
+
return outputs
|
138 |
+
|
139 |
+
def search_by_video_ids(
|
140 |
+
self,
|
141 |
+
video_ids,
|
142 |
+
retri_factor
|
143 |
+
):
|
144 |
+
if len(self.videoid_to_vectoridx) != self.db.ntotal:
|
145 |
+
raise ValueError(
|
146 |
+
len(self.videoid_to_vectoridx),
|
147 |
+
self.db.ntotal
|
148 |
+
)
|
149 |
+
|
150 |
+
if not self.make_direct_maps_done:
|
151 |
+
self.make_direct_maps()
|
152 |
+
|
153 |
+
if self.vectoridx_to_videoid is None:
|
154 |
+
self.vectoridx_to_videoid = {
|
155 |
+
self.videoid_to_vectoridx[videoid]: videoid
|
156 |
+
for videoid in self.videoid_to_vectoridx
|
157 |
+
}
|
158 |
+
assert len(self.vectoridx_to_videoid) \
|
159 |
+
== len(self.videoid_to_vectoridx)
|
160 |
+
|
161 |
+
query_hidden_states = []
|
162 |
+
vector_ids = []
|
163 |
+
for video_id in video_ids:
|
164 |
+
vector_id = self.videoid_to_vectoridx[video_id]
|
165 |
+
vector_ids.append(vector_id)
|
166 |
+
query_hidden_state = self.db.reconstruct(vector_id)
|
167 |
+
query_hidden_states.append(query_hidden_state)
|
168 |
+
query_hidden_states = np.stack(query_hidden_states)
|
169 |
+
|
170 |
+
# MultilingualFaissDataset uses the following; not sure the reason.
|
171 |
+
# faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
|
172 |
+
_, index = self.db.search(query_hidden_states, retri_factor)
|
173 |
+
outputs = []
|
174 |
+
for sample_idx, sample in enumerate(index):
|
175 |
+
# the first video_id is always the video itself.
|
176 |
+
cands = [video_ids[sample_idx]]
|
177 |
+
for vector_idx in sample:
|
178 |
+
if vector_idx >= 0 \
|
179 |
+
and vector_ids[sample_idx] != vector_idx:
|
180 |
+
cands.append(
|
181 |
+
self.vectoridx_to_videoid[vector_idx]
|
182 |
+
)
|
183 |
+
outputs.append(cands)
|
184 |
+
return outputs
|
185 |
+
|
186 |
+
|
187 |
+
class VectorRetrieverDM(VectorRetriever):
|
188 |
+
"""
|
189 |
+
with direct map.
|
190 |
+
How2 Video Retriver.
|
191 |
+
Reference usage of FAISS:
|
192 |
+
https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py
|
193 |
+
"""
|
194 |
+
|
195 |
+
def __init__(
|
196 |
+
self,
|
197 |
+
hidden_size,
|
198 |
+
cent,
|
199 |
+
db_type,
|
200 |
+
examples_per_cent_to_train
|
201 |
+
):
|
202 |
+
super().__init__(
|
203 |
+
hidden_size, cent, db_type, examples_per_cent_to_train)
|
204 |
+
self.make_direct_maps_done = False
|
205 |
+
|
206 |
+
def make_direct_maps(self):
|
207 |
+
faiss.downcast_index(self.db).make_direct_map()
|
208 |
+
self.make_direct_maps_done = True
|
209 |
+
|
210 |
+
def search(
|
211 |
+
self,
|
212 |
+
query_hidden_states,
|
213 |
+
orig_dist,
|
214 |
+
):
|
215 |
+
if len(self.videoid_to_vectoridx) != self.db.ntotal:
|
216 |
+
raise ValueError(
|
217 |
+
len(self.videoid_to_vectoridx),
|
218 |
+
self.db.ntotal
|
219 |
+
)
|
220 |
+
|
221 |
+
if not self.make_direct_maps_done:
|
222 |
+
self.make_direct_maps()
|
223 |
+
if self.vectoridx_to_videoid is None:
|
224 |
+
self.vectoridx_to_videoid = {
|
225 |
+
self.videoid_to_vectoridx[videoid]: videoid
|
226 |
+
for videoid in self.videoid_to_vectoridx
|
227 |
+
}
|
228 |
+
assert len(self.vectoridx_to_videoid) \
|
229 |
+
== len(self.videoid_to_vectoridx)
|
230 |
+
|
231 |
+
# MultilingualFaissDataset uses the following; not sure the reason.
|
232 |
+
# faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
|
233 |
+
queried_dist, index = self.db.search(query_hidden_states, 1)
|
234 |
+
outputs = []
|
235 |
+
for sample_idx, sample in enumerate(index):
|
236 |
+
# and queried_dist[sample_idx] < thres \
|
237 |
+
if sample >= 0 \
|
238 |
+
and queried_dist[sample_idx] < orig_dist[sample_idx]:
|
239 |
+
outputs.append(self.vectoridx_to_videoid[sample])
|
240 |
+
else:
|
241 |
+
outputs.append(None)
|
242 |
+
return outputs
|
243 |
+
|
244 |
+
def search_by_video_ids(
|
245 |
+
self,
|
246 |
+
video_ids,
|
247 |
+
retri_factor=8
|
248 |
+
):
|
249 |
+
if len(self.videoid_to_vectoridx) != self.db.ntotal:
|
250 |
+
raise ValueError(
|
251 |
+
len(self.videoid_to_vectoridx),
|
252 |
+
self.db.ntotal
|
253 |
+
)
|
254 |
+
|
255 |
+
if not self.make_direct_maps_done:
|
256 |
+
self.make_direct_maps()
|
257 |
+
if self.vectoridx_to_videoid is None:
|
258 |
+
self.vectoridx_to_videoid = {
|
259 |
+
self.videoid_to_vectoridx[videoid]: videoid
|
260 |
+
for videoid in self.videoid_to_vectoridx
|
261 |
+
}
|
262 |
+
assert len(self.vectoridx_to_videoid) \
|
263 |
+
== len(self.videoid_to_vectoridx)
|
264 |
+
|
265 |
+
query_hidden_states = []
|
266 |
+
vector_ids = []
|
267 |
+
for video_id in video_ids:
|
268 |
+
vector_id = self.videoid_to_vectoridx[video_id]
|
269 |
+
vector_ids.append(vector_id)
|
270 |
+
query_hidden_state = self.db.reconstruct(vector_id)
|
271 |
+
query_hidden_states.append(query_hidden_state)
|
272 |
+
query_hidden_states = np.stack(query_hidden_states)
|
273 |
+
|
274 |
+
# MultilingualFaissDataset uses the following; not sure the reason.
|
275 |
+
# faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
|
276 |
+
_, index = self.db.search(query_hidden_states, retri_factor)
|
277 |
+
outputs = []
|
278 |
+
for sample_idx, sample in enumerate(index):
|
279 |
+
# the first video_id is always the video itself.
|
280 |
+
cands = [video_ids[sample_idx]]
|
281 |
+
for vector_idx in sample:
|
282 |
+
if vector_idx >= 0 \
|
283 |
+
and vector_ids[sample_idx] != vector_idx:
|
284 |
+
cands.append(
|
285 |
+
self.vectoridx_to_videoid[vector_idx]
|
286 |
+
)
|
287 |
+
outputs.append(cands)
|
288 |
+
return outputs
|
289 |
+
|
290 |
+
|
291 |
+
class MMVectorRetriever(VectorRetrieverDM):
|
292 |
+
"""
|
293 |
+
multimodal vector retriver:
|
294 |
+
text retrieve video or video retrieve text.
|
295 |
+
"""
|
296 |
+
|
297 |
+
def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train):
|
298 |
+
super().__init__(
|
299 |
+
hidden_size, cent, db_type, examples_per_cent_to_train)
|
300 |
+
video_db = self.db
|
301 |
+
super().__init__(
|
302 |
+
hidden_size, cent, db_type, examples_per_cent_to_train)
|
303 |
+
text_db = self.db
|
304 |
+
self.db = {"video": video_db, "text": text_db}
|
305 |
+
self.video_to_videoid = defaultdict(list)
|
306 |
+
|
307 |
+
def __len__(self):
|
308 |
+
assert self.db["video"].ntotal == self.db["text"].ntotal
|
309 |
+
return self.db["video"].ntotal
|
310 |
+
|
311 |
+
def make_direct_maps(self):
|
312 |
+
faiss.downcast_index(self.db["video"]).make_direct_map()
|
313 |
+
faiss.downcast_index(self.db["text"]).make_direct_map()
|
314 |
+
|
315 |
+
def save(self, out_dir):
|
316 |
+
faiss.write_index(
|
317 |
+
self.db["video"],
|
318 |
+
os.path.join(out_dir, "video_faiss_idx")
|
319 |
+
)
|
320 |
+
faiss.write_index(
|
321 |
+
self.db["text"],
|
322 |
+
os.path.join(out_dir, "text_faiss_idx")
|
323 |
+
)
|
324 |
+
|
325 |
+
with open(
|
326 |
+
os.path.join(
|
327 |
+
out_dir, "videoid_to_vectoridx.pkl"),
|
328 |
+
"wb") as fw:
|
329 |
+
pickle.dump(
|
330 |
+
self.videoid_to_vectoridx, fw,
|
331 |
+
protocol=pickle.HIGHEST_PROTOCOL
|
332 |
+
)
|
333 |
+
|
334 |
+
def load(self, out_dir):
|
335 |
+
fn = os.path.join(out_dir, "video_faiss_idx")
|
336 |
+
video_db = faiss.read_index(fn)
|
337 |
+
fn = os.path.join(out_dir, "text_faiss_idx")
|
338 |
+
text_db = faiss.read_index(fn)
|
339 |
+
self.db = {"video": video_db, "text": text_db}
|
340 |
+
with open(
|
341 |
+
os.path.join(out_dir, "videoid_to_vectoridx.pkl"), "rb") as fr:
|
342 |
+
self.videoid_to_vectoridx = pickle.load(fr)
|
343 |
+
self.video_to_videoid = defaultdict(list)
|
344 |
+
|
345 |
+
def add(self, hidden_states, video_ids):
|
346 |
+
"""hidden_states is a pair `(video, text)`"""
|
347 |
+
assert len(hidden_states) == len(video_ids), "{}, {}".format(
|
348 |
+
str(len(hidden_states)), str(len(video_ids)))
|
349 |
+
assert len(hidden_states.shape) == 3
|
350 |
+
assert len(self.video_to_videoid) == 0
|
351 |
+
|
352 |
+
valid_idx = []
|
353 |
+
for idx, video_id in enumerate(video_ids):
|
354 |
+
if video_id not in self.videoid_to_vectoridx:
|
355 |
+
valid_idx.append(idx)
|
356 |
+
self.videoid_to_vectoridx[video_id] = \
|
357 |
+
len(self.videoid_to_vectoridx)
|
358 |
+
|
359 |
+
batch_size = hidden_states.shape[0]
|
360 |
+
hidden_states = hidden_states[valid_idx]
|
361 |
+
|
362 |
+
hidden_states = np.transpose(hidden_states, (1, 0, 2)).copy()
|
363 |
+
if not self.db["video"].is_trained:
|
364 |
+
self.train_cache.append(hidden_states)
|
365 |
+
train_len = batch_size * len(self.train_cache)
|
366 |
+
if train_len < self.train_thres:
|
367 |
+
return
|
368 |
+
|
369 |
+
hidden_states = np.concatenate(self.train_cache, axis=1)
|
370 |
+
del self.train_cache
|
371 |
+
self.db["video"].train(hidden_states[0, :self.train_thres])
|
372 |
+
self.db["text"].train(hidden_states[1, :self.train_thres])
|
373 |
+
self.db["video"].add(hidden_states[0])
|
374 |
+
self.db["text"].add(hidden_states[1])
|
375 |
+
|
376 |
+
def get_clips_by_video_id(self, video_id):
|
377 |
+
if not self.video_to_videoid:
|
378 |
+
for video_id, video_clip, text_clip in self.videoid_to_vectoridx:
|
379 |
+
self.video_to_videoid[video_id].append(
|
380 |
+
(video_id, video_clip, text_clip))
|
381 |
+
return self.video_to_videoid[video_id]
|
382 |
+
|
383 |
+
def search(
|
384 |
+
self,
|
385 |
+
video_ids,
|
386 |
+
target_modality,
|
387 |
+
retri_factor=8
|
388 |
+
):
|
389 |
+
if len(self.videoid_to_vectoridx) != len(self):
|
390 |
+
raise ValueError(
|
391 |
+
len(self.videoid_to_vectoridx),
|
392 |
+
len(self)
|
393 |
+
)
|
394 |
+
|
395 |
+
if not self.make_direct_maps_done:
|
396 |
+
self.make_direct_maps()
|
397 |
+
if self.vectoridx_to_videoid is None:
|
398 |
+
self.vectoridx_to_videoid = {
|
399 |
+
self.videoid_to_vectoridx[videoid]: videoid
|
400 |
+
for videoid in self.videoid_to_vectoridx
|
401 |
+
}
|
402 |
+
assert len(self.vectoridx_to_videoid) \
|
403 |
+
== len(self.videoid_to_vectoridx)
|
404 |
+
|
405 |
+
src_modality = "text" if target_modality == "video" else "video"
|
406 |
+
|
407 |
+
query_hidden_states = []
|
408 |
+
vector_ids = []
|
409 |
+
for video_id in video_ids:
|
410 |
+
vector_id = self.videoid_to_vectoridx[video_id]
|
411 |
+
vector_ids.append(vector_id)
|
412 |
+
query_hidden_state = self.db[src_modality].reconstruct(vector_id)
|
413 |
+
query_hidden_states.append(query_hidden_state)
|
414 |
+
query_hidden_states = np.stack(query_hidden_states)
|
415 |
+
|
416 |
+
# MultilingualFaissDataset uses the following; not sure the reason.
|
417 |
+
# faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
|
418 |
+
_, index = self.db[target_modality].search(
|
419 |
+
query_hidden_states, retri_factor)
|
420 |
+
outputs = []
|
421 |
+
for sample_idx, sample in enumerate(index):
|
422 |
+
cands = []
|
423 |
+
for vector_idx in sample:
|
424 |
+
if vector_idx >= 0:
|
425 |
+
cands.append(
|
426 |
+
self.vectoridx_to_videoid[vector_idx]
|
427 |
+
)
|
428 |
+
outputs.append(cands)
|
429 |
+
return outputs
|
fairseq/examples/MMPT/mmpt/modules/vectorpool.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Facebook, Inc. All Rights Reserved


import torch
import os
import numpy as np
import pickle

from . import retri
from ..utils import get_local_rank


class VectorPool(object):
    """
    Base class of retrieval space.
    """

    def __init__(self, config):
        from transformers import AutoConfig
        self.hidden_size = AutoConfig.from_pretrained(
            config.dataset.bert_name).hidden_size
        self.retriever_cls = getattr(retri, config.retriever_cls)

    def __call__(self, sample, **kwargs):
        raise NotImplementedError

    def build_retriver(
        self,
        retriever_cls=None,
        hidden_size=None,
        centroids=512,
        db_type="flatl2",
        examples_per_cent_to_train=48
    ):

        """merge results from multiple gpus and return a retriver.."""
        self.retriver = retriever_cls(
            hidden_size, centroids, db_type, examples_per_cent_to_train)
        return self.retriver

    def __repr__(self):
        if hasattr(self, "retriver"):
            retriver_name = str(len(self.retriver))
        else:
            retriver_name = "no retriver field yet"
        return self.__class__.__name__ \
            + "(" + retriver_name + ")"


class VideoVectorPool(VectorPool):
    """
    average clips of a video as video representation.
    """
    def __init__(self, config):
        super().__init__(config)
        self.build_retriver(self.retriever_cls, self.hidden_size)

    def __call__(self, sample, subsampling, **kwargs):
        hidden_states = (
            sample["pooled_video"] + sample["pooled_text"]) / 2.
        hidden_states = hidden_states.view(
            -1, subsampling,
            hidden_states.size(-1))
        hidden_states = torch.mean(hidden_states, dim=1)
        hidden_states = hidden_states.cpu().detach().numpy()
        video_ids = []
        for offset_idx, video_id in enumerate(sample["video_id"]):
            if isinstance(video_id, tuple) and len(video_id) == 3:
                # a sharded video_id.
                video_id = video_id[0]
            video_ids.append(video_id)
        assert len(video_ids) == len(hidden_states)
        self.retriver.add(
            hidden_states.astype("float32"),
            video_ids
        )


class DistributedVectorPool(VectorPool):
    """
    support sync of multiple gpus/nodes.
    """
    def __init__(self, config):
        super().__init__(config)
        self.out_dir = os.path.join(
            config.fairseq.checkpoint.save_dir,
            "retri")
        os.makedirs(self.out_dir, exist_ok=True)
        self.hidden_states = []
        self.video_ids = []

    def build_retriver(
        self,
        retriever_cls=None,
        hidden_size=None,
        centroids=4096,
        db_type="flatl2",
        examples_per_cent_to_train=48
    ):
        if retriever_cls is None:
            retriever_cls = self.retriever_cls
        if hidden_size is None:
            hidden_size = self.hidden_size
        """merge results from multiple gpus and return a retriver.."""
        if torch.distributed.is_initialized():
            self.save()
            # sync saving.
            torch.distributed.barrier()
            world_size = torch.distributed.get_world_size()
        else:
            world_size = 1
        self.retriver = retriever_cls(
            hidden_size, centroids, db_type, examples_per_cent_to_train)
        # each gpu process has its own retriever.
        for local_rank in range(world_size):
            if get_local_rank() == 0:
                print("load local_rank", local_rank)
            hidden_states, video_ids = self.load(local_rank)
            hidden_states = hidden_states.astype("float32")
            self.retriver.add(hidden_states, video_ids)
        return self.retriver

    def load(self, local_rank):
        hidden_states = np.load(
            os.path.join(
                self.out_dir,
                "hidden_state" + str(local_rank) + ".npy"
            )
        )

        with open(
            os.path.join(
                self.out_dir, "video_id" + str(local_rank) + ".pkl"),
                "rb") as fr:
            video_ids = pickle.load(fr)
        return hidden_states, video_ids

    def save(self):
        hidden_states = np.vstack(self.hidden_states)
        assert len(hidden_states) == len(self.video_ids), "{}, {}".format(
            len(hidden_states),
            len(self.video_ids)
        )
        local_rank = torch.distributed.get_rank() \
            if torch.distributed.is_initialized() else 0

        np.save(
            os.path.join(
                self.out_dir,
                "hidden_state" + str(local_rank) + ".npy"),
            hidden_states)

        with open(
            os.path.join(
                self.out_dir,
                "video_id" + str(local_rank) + ".pkl"),
                "wb") as fw:
            pickle.dump(
                self.video_ids,
                fw,
                protocol=pickle.HIGHEST_PROTOCOL
            )


class DistributedVideoVectorPool(DistributedVectorPool):
    """
    average clips of a video as video representation.
    """
    def __call__(self, sample, subsampling, **kwargs):
        hidden_states = (
            sample["pooled_video"] + sample["pooled_text"]) / 2.
        hidden_states = hidden_states.view(
            -1, subsampling,
            hidden_states.size(-1))
        hidden_states = torch.mean(hidden_states, dim=1)
        hidden_states = hidden_states.cpu().detach().numpy()
        video_ids = []
        for offset_idx, video_id in enumerate(sample["video_id"]):
            if isinstance(video_id, tuple) and len(video_id) == 3:
                # a sharded video_id.
                video_id = video_id[0]
            video_ids.append(video_id)
        assert len(video_ids) == len(hidden_states)
        self.hidden_states.append(hidden_states)
        self.video_ids.extend(video_ids)


# ------------ the following are deprecated --------------

class TextClipVectorPool(VectorPool):
    def __init__(self, config):
        from transformers import AutoConfig
        hidden_size = AutoConfig.from_pretrained(
            config.dataset.bert_name).hidden_size
        retriever_cls = getattr(retri, config.retriever_cls)
        self.build_retriver(retriever_cls, hidden_size)

    def __call__(self, sample, **kwargs):
        clip_meta = sample["clip_meta"].cpu()
        assert torch.all(torch.le(clip_meta[:, 4], clip_meta[:, 5]))
        text_meta = [tuple(item.tolist()) for item in clip_meta[:, 3:]]

        if hasattr(self, "retriver"):
            # build_retriver is called.
            self.retriver.add(
                sample["pooled_text"].cpu().numpy().astype("float32"),
                text_meta
            )
        else:
            raise NotImplementedError


class MMClipVectorPool(VectorPool):
    """
    Multimodal Clip-level vector pool.
    """
    def __init__(self, out_dir):
        """use hidden_states to store `(video, text)`."""
        """use video_ids to store `(video_id, start, end)`."""
        super().__init__(out_dir)

    def __call__(self, sample, **kwargs):
        pooled_video = sample["pooled_video"].cpu().unsqueeze(1).numpy()
        pooled_text = sample["pooled_text"].cpu().unsqueeze(1).numpy()

        self.hidden_states.append(
            np.concatenate([pooled_video, pooled_text], axis=1)
        )

        video_starts = sample["video_start"].cpu()
        video_ends = sample["video_end"].cpu()
        assert torch.all(torch.le(video_starts, video_ends))

        text_starts = sample["text_start"].cpu()
        text_ends = sample["text_end"].cpu()
        assert torch.all(torch.le(text_starts, text_ends))
        subsample_size = sample["pooled_video"].size(0) // len(sample["video_id"])
        video_ids = [video_id for video_id in sample["video_id"]
            for _ in range(subsample_size)
        ]
        for video_id, video_start, video_end, text_start, text_end in zip(
                video_ids, video_starts, video_ends, text_starts, text_ends):
            self.video_ids.append((
                video_id,
                (int(video_start), int(video_end)),
                (int(text_start), int(text_end))
            ))

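Editor's note: the snippet below is a minimal, illustrative sketch of how the clip-averaging in VideoVectorPool above feeds a retriever. DummyRetriever and the toy tensors are stand-ins invented for this sketch, not part of the diff (the real retriever classes live in mmpt/modules/retri.py).

import torch


class DummyRetriever:
    """Stand-in with the same add()/len() surface the pools above expect."""
    def __init__(self):
        self.vectors, self.ids = [], []

    def add(self, hidden_states, video_ids):
        self.vectors.append(hidden_states)
        self.ids.extend(video_ids)

    def __len__(self):
        return len(self.ids)


# Two videos, each subsampled into 2 clips: features are averaged across
# modalities, then across clips, before being indexed per video_id.
subsampling = 2
sample = {
    "pooled_video": torch.randn(4, 768),
    "pooled_text": torch.randn(4, 768),
    "video_id": ["vid_a", "vid_b"],
}
hidden = (sample["pooled_video"] + sample["pooled_text"]) / 2.
hidden = hidden.view(-1, subsampling, hidden.size(-1)).mean(dim=1)

retriever = DummyRetriever()
retriever.add(hidden.detach().numpy().astype("float32"), sample["video_id"])
print(len(retriever))  # 2 videos indexed
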
fairseq/examples/MMPT/mmpt/processors/dedupprocessor.py
ADDED
@@ -0,0 +1,242 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import random
import json
import pickle
from tqdm import tqdm
import os
import numpy as np


class CaptionDedupProcessor(object):
    """remove overlapping of caption sentences(clip).
    Some statistics:
    caption:
    {'t_clip_len': 246.6448431320854,
    'video_len': 281.09174795676245,
    'clip_tps': 0.8841283727427481,
    'video_tps': 0.7821156477732097,
    'min_clip_len': 0.0,
    'max_clip_len': 398.3,
    'mean_clip_len': 3.196580003006861,
    'num_clip': 77.15897706301081}

    raw_caption:
    {'t_clip_len': 238.95908778424115,
    'video_len': 267.5914859862507,
    'clip_tps': 2.4941363624267963,
    'video_tps': 2.258989769647173,
    'min_clip_len': 0.0,
    'max_clip_len': 398.3,
    'mean_clip_len': 3.0537954186814265,
    'num_clip': 78.24986779481756}
    """

    def __init__(self, pkl_file):
        with open(pkl_file, "rb") as fd:
            self.data = pickle.load(fd)
        self.stat = {
            "t_clip_len": [],
            "video_len": [],
            "clip_tps": [],
            "video_tps": [],
            "clip_len": [],
        }

    def __call__(self):
        for idx, video_id in enumerate(tqdm(self.data)):
            caption = json.loads(self.data[video_id])
            caption = self._dedup(caption)
            if idx < 4096:  # for the first 4096 examples, compute the statistics.
                self.save_stat(video_id, caption)
            self.data[video_id] = json.dumps(caption)
        self.print_stat()

    def single(self, video_id):
        caption = json.loads(self.data[video_id])
        for clip_idx, (start, end, text) in enumerate(
            zip(caption["start"], caption["end"], caption["text"])
        ):
            print(start, end, text)
        print("@" * 100)
        caption = self._dedup(caption)
        for clip_idx, (start, end, text) in enumerate(
            zip(caption["start"], caption["end"], caption["text"])
        ):
            print(start, end, text)
        print("#" * 100)
        self.save_stat(video_id, caption)
        self.print_stat()

    def finalize(self, tgt_fn):
        with open(tgt_fn, "wb") as fw:
            pickle.dump(self.data, fw, pickle.HIGHEST_PROTOCOL)

    def save_stat(self, video_id, caption):
        video_fn = os.path.join(
            "data/feat/feat_how2_s3d", video_id + ".npy"
        )
        if os.path.isfile(video_fn):
            with open(video_fn, "rb", 1) as fr:  # 24 is the buffer size. buffered
                version = np.lib.format.read_magic(fr)
                shape, fortran, dtype = np.lib.format._read_array_header(fr, version)
                video_len = shape[0]

            t_clip_len = 0.0
            t_tokens = 0
            for idx, (start, end, text) in enumerate(
                zip(caption["start"], caption["end"], caption["text"])
            ):
                clip_len = (
                    (end - max(caption["end"][idx - 1], start))
                    if idx > 0
                    else end - start
                )
                t_clip_len += clip_len
                t_tokens += len(text.split(" "))
                self.stat["clip_len"].append(clip_len)
            self.stat["t_clip_len"].append(t_clip_len)
            self.stat["video_len"].append(video_len)
            self.stat["clip_tps"].append(t_tokens / t_clip_len)
            self.stat["video_tps"].append(t_tokens / video_len)

    def print_stat(self):
        result = {
            "t_clip_len": np.mean(self.stat["t_clip_len"]),
            "video_len": np.mean(self.stat["video_len"]),
            "clip_tps": np.mean(self.stat["clip_tps"]),
            "video_tps": np.mean(self.stat["video_tps"]),
            "min_clip_len": min(self.stat["clip_len"]),
            "max_clip_len": max(self.stat["clip_len"]),
            "mean_clip_len": np.mean(self.stat["clip_len"]),
            "num_clip": len(self.stat["clip_len"]) / len(self.stat["video_tps"]),
        }
        print(result)

    def _dedup(self, caption):
        def random_merge(end_idx, start, end, text, starts, ends, texts):
            if random.random() > 0.5:
                # print(clip_idx, "[PARTIAL INTO PREV]", end_idx)
                # overlapped part goes to the end of previous.
                ends[-1] = max(ends[-1], start)  # ?
                rest_text = text[end_idx:].strip()
                if rest_text:
                    starts.append(max(ends[-1], start))
                    ends.append(max(end, starts[-1]))
                    texts.append(rest_text)
            else:  # goes to the beginning of the current.
                # strip the previous.
                left_text = texts[-1][:-end_idx].strip()
                if left_text:
                    # print(clip_idx, "[PREV PARTIAL INTO CUR]", end_idx)
                    ends[-1] = min(ends[-1], start)
                    texts[-1] = left_text
                else:
                    # print(clip_idx, "[PREV LEFT NOTHING ALL INTO CUR]", end_idx)
                    starts.pop(-1)
                    ends.pop(-1)
                    texts.pop(-1)
                starts.append(start)
                ends.append(end)
                texts.append(text)

        starts, ends, texts = [], [], []
        for clip_idx, (start, end, text) in enumerate(
            zip(caption["start"], caption["end"], caption["text"])
        ):
            if not isinstance(text, str):
                continue
            text = text.replace("\n", " ").strip()
            if len(text) == 0:
                continue
            starts.append(start)
            ends.append(end)
            texts.append(text)
            break

        for clip_idx, (start, end, text) in enumerate(
            zip(
                caption["start"][clip_idx + 1:],
                caption["end"][clip_idx + 1:],
                caption["text"][clip_idx + 1:],
            )
        ):
            if not isinstance(text, str):
                continue
            text = text.replace("\n", " ").strip()
            if len(text) == 0:
                continue

            # print(clip_idx, texts[-5:])
            # print(clip_idx, start, end, text)
            if texts[-1].endswith(text):  # subset of prev caption -> merge
                # print(clip_idx, "[MERGE INTO PREV]")
                ends[-1] = max(ends[-1], end)
            elif text.startswith(texts[-1]):  # superset of prev caption -> merge
                # print(clip_idx, "[PREV MERGE INTO CUR]")
                texts[-1] = text
                starts[-1] = min(starts[-1], start)
                ends[-1] = max(ends[-1], end)
            else:  # overlapping or non-overlapping.
                for end_idx in range(1, len(text) + 1):
                    if texts[-1].endswith(text[:end_idx]):
                        random_merge(end_idx, start, end, text, starts, ends, texts)
                        break
                else:
                    starts.append(start)
                    ends.append(end)
                    texts.append(text)

            assert (ends[-1] + 0.001) >= starts[-1] and len(
                texts[-1]
            ) > 0, "{} {} {} <- {} {} {}, {} {} {}".format(
                str(starts[-1]),
                str(ends[-1]),
                texts[-1],
                caption["start"][clip_idx - 1],
                caption["end"][clip_idx - 1],
                caption["text"][clip_idx - 1],
                str(start),
                str(end),
                text,
            )

        return {"start": starts, "end": ends, "text": texts}


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="dedup how2 caption")
    parser.add_argument('--how2dir', default="data/how2")
    args = parser.parse_args()

    raw_caption_json = os.path.join(args.how2dir, "raw_caption.json")
    raw_caption_pickle = os.path.join(args.how2dir, "raw_caption.pkl")
    raw_caption_dedup_pickle = os.path.join(args.how2dir, "raw_caption_dedup.pkl")

    def convert_to_pickle(src_fn, tgt_fn):
        with open(src_fn) as fd:
            captions = json.load(fd)

        for video_id in captions:
            captions[video_id] = json.dumps(captions[video_id])

        with open(tgt_fn, "wb") as fw:
            pickle.dump(captions, fw, pickle.HIGHEST_PROTOCOL)

    if not os.path.isfile(raw_caption_pickle):
        convert_to_pickle(raw_caption_json, raw_caption_pickle)

    deduper = CaptionDedupProcessor(raw_caption_pickle)
    deduper()
    deduper.finalize(raw_caption_dedup_pickle)

    """
    # demo
    deduper = CaptionDedupProcessor("data/how2/raw_caption.pkl")
    deduper.single("HfIeQ9pzL5U")
    """

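Editor's note: a small illustrative run of the three merge cases `_dedup` distinguishes (suffix, prefix, partial overlap). The toy caption dict is fabricated for this sketch, `__new__` is used only to skip the pickle-loading constructor, and the output varies because partial overlaps are merged randomly. It assumes the dedupprocessor.py module above is importable under the repo layout.

# Illustrative only; not part of the diff.
from mmpt.processors.dedupprocessor import CaptionDedupProcessor

proc = CaptionDedupProcessor.__new__(CaptionDedupProcessor)  # skip pickle loading
toy_caption = {
    "start": [0.0, 2.0, 4.0],
    "end": [3.0, 5.0, 7.0],
    "text": [
        "first you boil the water",      # kept as-is
        "you boil the water",            # suffix of the previous -> merged into it
        "the water then add the pasta",  # partial overlap -> randomly split/merged
    ],
}
print(proc._dedup(toy_caption))
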
fairseq/examples/MMPT/mmpt/processors/how2processor.py
ADDED
@@ -0,0 +1,887 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. All Rights Reserved


import torch
import math
import pickle
import random
import os
import numpy as np

from collections import deque
from typing import Optional, Tuple, List
from .processor import (
    Processor,
    MetaProcessor,
    TextProcessor,
    Aligner,
    MMAttentionMask2DProcessor
)

from ..utils import ShardedTensor


class How2MetaProcessor(MetaProcessor):
    def __init__(self, config):
        super().__init__(config)
        path = self._get_split_path(config)
        with open(path) as fd:
            self.data = [line.strip() for line in fd]

    def __getitem__(self, idx):
        video_id = self.data[idx]
        return video_id, video_id


class ShardedHow2MetaProcessor(How2MetaProcessor):
    def __init__(self, config):
        super().__init__(config)
        self.split = str(config.split)
        self.vfeat_dir = config.vfeat_dir
        self._init_shard()

    def _init_shard(self):
        if self.split == "train":
            meta_fn = os.path.join(self.vfeat_dir, "train" + "_meta.pkl")
            with open(meta_fn, "rb") as fr:
                meta = pickle.load(fr)
        elif self.split == "valid":
            meta_fn = os.path.join(self.vfeat_dir, "val" + "_meta.pkl")
            with open(meta_fn, "rb") as fr:
                meta = pickle.load(fr)
        elif self.split == "test":
            print("use how2 val as test.")
            meta_fn = os.path.join(self.vfeat_dir, "val" + "_meta.pkl")
            with open(meta_fn, "rb") as fr:
                meta = pickle.load(fr)
        else:
            raise ValueError("unsupported for MetaProcessor:", self.split)
        video_id_to_shard = {}
        for shard_id in meta:
            for video_idx, video_id in enumerate(meta[shard_id]):
                video_id_to_shard[video_id] = (shard_id, video_idx)
        self.video_id_to_shard = video_id_to_shard

    def __getitem__(self, idx):
        video_id, video_id = super().__getitem__(idx)
        shard_id, shard_idx = self.video_id_to_shard[video_id]
        meta = (video_id, idx, shard_id, shard_idx)
        return meta, meta


class ShardedVideoProcessor(Processor):
    """
    mmaped shards of numpy video features.
    """

    def __init__(self, config):
        self.split = str(config.split)
        self.vfeat_dir = config.vfeat_dir

    def __call__(self, video_id):
        _, _, shard_id, video_idx = video_id
        if self.split == "train":
            shard = ShardedTensor.load(
                os.path.join(self.vfeat_dir, "train" + "_" + str(shard_id)),
                "r"
            )
        elif self.split == "valid":
            shard = ShardedTensor.load(
                os.path.join(self.vfeat_dir, "val" + "_" + str(shard_id)),
                "r"
            )
        elif self.split == "test":
            shard = ShardedTensor.load(
                os.path.join(self.vfeat_dir, "val" + "_" + str(shard_id)),
                "r"
            )
        else:
            raise ValueError("unknown split", self.split)
        feat = shard[video_idx]
        return feat


class ShardedTextProcessor(Processor):
    def __init__(self, config):
        self.tfeat_dir = str(config.tfeat_dir)
        self.split = str(config.split)

    def __call__(self, video_id):
        _, _, shard_id, shard_idx = video_id
        if self.split == "train":
            target_path = self.tfeat_dir + "train" + "_" + str(shard_id)
        elif self.split == "valid":
            target_path = self.tfeat_dir + "val" + "_" + str(shard_id)
        elif self.split == "test":
            target_path = self.tfeat_dir + "val" + "_" + str(shard_id)
        else:
            raise ValueError("unknown split", self.split)

        startend = ShardedTensor.load(
            target_path + ".startends", "r")[shard_idx]
        cap_ids = ShardedTensor.load(
            target_path + ".caps_ids", "r")[shard_idx]
        cap = []
        for clip_idx in range(len(cap_ids)):
            clip = cap_ids[clip_idx]
            cap.append(clip[clip != -1].tolist())
        start, end = startend[:, 0].tolist(), startend[:, 1].tolist()
        return {"start": start, "end": end, "cap": cap}


class FixedLenAligner(Aligner):
    """
    In the model we assume text is on the left (closer to BERT formulation)
    and video is on the right.
    We fix the total length of text + video.
    max_video_len is in number of secs.
    max_text_len is in number of tokens.

    special tokens formats:
    we use the format [CLS] [SEP] text tokens [SEP] [PAD] ...
    [CLS] will be splitted out into:
    [CLS] video tokens [SEP] text tokens [SEP] [PAD] ...
    token_type_ids will be generated by the model (for now).
    0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
    | first sequence | second sequence |
    so each sequence owns a [SEP] token for no-ops.
    """

    def __init__(self, config):
        super().__init__(config)
        self.text_clip_sampler = TextClipSamplingProcessor(
            self.max_len - self.max_video_len - 3
        )
        """
        decide subsampling:
        `config.subsampling` will change batch_size in trainer.
        `config.clip_per_video` (used by RetriTask) doesn't
        change batch_size in trainer.
        """
        subsampling = config.subsampling \
            if config.subsampling is not None else None
        if config.clip_per_video is not None:
            subsampling = config.clip_per_video
        self.subsampling = subsampling

    def _get_text_maxlen(self):
        # use max text len
        return self.text_clip_sampler.max_text_len

    def __call__(self, video_id, video_feature, text_feature):
        from transformers import default_data_collator
        video_idx = video_id[1]
        if self.subsampling is not None and self.subsampling >= 1:
            batch = []
            for _ in range(self.subsampling):
                centerclip_idx = random.randint(
                    0, len(text_feature["start"]) - 1)
                batch.append(
                    self.sampling(
                        video_idx,
                        video_feature,
                        text_feature,
                        centerclip_idx,
                        self._get_text_maxlen()
                    ))
            batch = self.batch_post_processing(batch, video_feature)
            batch = default_data_collator(batch)
        else:
            raise ValueError(
                "dataset.subsampling must be >= 1 for efficient video loading.")
            batch = self.sampling(video_idx, video_feature, text_feature)
            batch = self.batch_post_processing(batch, video_feature)

        batch["video_id"] = video_id if isinstance(video_id, str) \
            else video_id[0]
        # e2e: make sure frame ids is into tensor.
        assert torch.is_tensor(batch["vfeats"])
        return batch

    def sampling(
        self,
        video_idx,
        video_feature,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
    ):
        text_clip_indexs = self.text_clip_sampler(
            text_feature, centerclip_idx,
            sampled_max_text_len
        )
        if isinstance(video_feature, np.ndarray):
            video_len = len(video_feature)
        else:
            video_len = math.ceil(text_feature["end"][-1])

        video_end = min(
            math.ceil(text_feature["end"][text_clip_indexs[-1]]),
            video_len
        )
        video_start = max(
            min(
                math.floor(text_feature["start"][text_clip_indexs[0]]),
                video_end),
            0
        )

        video_clips = {"start": [video_start], "end": [video_end]}

        # tensorize.
        vfeats, vmasks = self._build_video_seq(
            video_feature, video_clips
        )
        caps, cmasks = self._build_text_seq(
            text_feature, text_clip_indexs
        )

        text_start = text_clip_indexs[0]
        text_end = text_clip_indexs[-1] + 1

        return {
            "caps": caps,
            "cmasks": cmasks,
            "vfeats": vfeats,
            "vmasks": vmasks,
            "video_start": video_start,
            "video_end": video_end,
            "text_start": text_start,
            "text_end": text_end,
        }


class VariedLenAligner(FixedLenAligner):
    def __init__(self, config):
        super().__init__(config)
        self.sampled_min_len = config.sampled_min_len
        self.sampled_max_len = config.sampled_max_len

    def _get_text_maxlen(self):
        return random.randint(self.sampled_min_len, self.sampled_max_len)


class StartClipAligner(VariedLenAligner):
    def sampling(
        self,
        video_idx,
        video_feature,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
    ):
        return super().sampling(
            video_idx, video_feature, text_feature, 0)


class OverlappedAligner(VariedLenAligner):
    """video clip and text clip has overlappings
    but may not be the same start/end."""
    def __init__(self, config):
        super().__init__(config)
        self.sampled_video_min_len = config.sampled_video_min_len
        self.sampled_video_max_len = config.sampled_video_max_len

        self.video_clip_sampler = VideoClipSamplingProcessor()

    def _get_video_maxlen(self):
        return random.randint(
            self.sampled_video_min_len, self.sampled_video_max_len)

    def sampling(
        self,
        video_idx,
        video_feature,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
    ):
        text_clip_indexs = self.text_clip_sampler(
            text_feature, centerclip_idx,
            sampled_max_text_len
        )
        if isinstance(video_feature, np.ndarray):
            video_len = len(video_feature)
        else:
            video_len = math.ceil(text_feature["end"][-1])
        low = math.floor(text_feature["start"][text_clip_indexs[0]])
        high = math.ceil(text_feature["end"][text_clip_indexs[-1]])
        if low < high:
            center = random.randint(low, high)
        else:
            center = int((low + high) // 2)
        center = max(0, min(video_feature.shape[0] - 1, center))

        assert 0 <= center < video_feature.shape[0]

        video_clips = self.video_clip_sampler(
            video_len, self._get_video_maxlen(), center
        )
        video_start = video_clips["start"][0]
        video_end = video_clips["end"][0]

        # tensorize.
        vfeats, vmasks = self._build_video_seq(
            video_feature, video_clips
        )
        caps, cmasks = self._build_text_seq(
            text_feature, text_clip_indexs
        )

        text_start = text_clip_indexs[0]
        text_end = text_clip_indexs[-1] + 1

        return {
            "caps": caps,
            "cmasks": cmasks,
            "vfeats": vfeats,
            "vmasks": vmasks,
            "video_start": video_start,
            "video_end": video_end,
            "text_start": text_start,
            "text_end": text_end,
        }


class MFMMLMAligner(FixedLenAligner):
    """
    `FixedLenAligner` with Masked Language Model and Masked Frame Model.
    """

    def __init__(self, config):
        super().__init__(config)
        keep_prob = config.keep_prob if config.keep_prob is not None else 1.0
        self.text_clip_sampler = TextClipSamplingProcessor(
            self.max_len - self.max_video_len - 3, keep_prob
        )
        self.sampled_min_len = config.sampled_min_len
        self.sampled_max_len = config.sampled_max_len
        self.masked_token_sampler = TextMaskingProcessor(config)
        self.mm_type = config.mm_type \
            if config.mm_type is not None else "full"
        self.attnmasker = MMAttentionMask2DProcessor() \
            if self.mm_type == "textgen" else None
        self.masked_frame_sampler = FrameMaskingProcessor(config)
        self.lazy_vfeat_mask = (
            False if config.lazy_vfeat_mask is None else config.lazy_vfeat_mask
        )
        self.mm_prob = config.mm_prob if config.mm_prob is not None else 0.

    def __call__(self, video_id, video_feature, text_feature):
        from transformers import default_data_collator
        if self.subsampling is not None and self.subsampling > 1:
            batch = []
            for _ in range(self.subsampling):
                centerclip_idx = random.randint(
                    0, len(text_feature["start"]) - 1)
                sampled_max_text_len = random.randint(
                    self.sampled_min_len, self.sampled_max_len
                )
                batch.append(
                    self.sampling(
                        video_id,
                        video_feature,
                        text_feature,
                        centerclip_idx,
                        sampled_max_text_len,
                    )
                )
            batch = self.batch_post_processing(batch, video_feature)
            batch = default_data_collator(batch)
        else:
            batch = self.sampling(video_id, video_feature, text_feature)
            batch = self.batch_post_processing(batch, video_feature)
        batch["video_id"] = video_id if isinstance(video_id, str) \
            else video_id[0]
        return batch

    def sampling(
        self,
        video_id,
        video_feature,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
    ):
        output = FixedLenAligner.sampling(self,
            video_id, video_feature, text_feature,
            centerclip_idx, sampled_max_text_len)

        masking_text, masking_video = None, None
        if random.random() < self.mm_prob:
            if random.random() > 0.5:
                masking_text, masking_video = self.mm_type, "no"
            else:
                masking_text, masking_video = "no", "full"
        video_feats = output["vfeats"] if not self.lazy_vfeat_mask else None
        video_label = self.masked_frame_sampler(
            output["vmasks"], masking_video, vfeats=video_feats)
        caps, text_label = self.masked_token_sampler(
            output["caps"], masking_text)

        output.update({
            "caps": caps,
            "video_label": video_label,
            "text_label": text_label,
        })

        if self.attnmasker is not None:
            attention_mask = self.attnmasker(
                output["vmasks"], output["cmasks"], masking_text)
            output.update({
                "attention_mask": attention_mask
            })
        return output


class FrameMaskingProcessor(Processor):
    def __init__(self, config):
        self.mfm_probability = 0.15
        if config.mfm_probability is not None:
            self.mfm_probability = config.mfm_probability

    def __call__(self, vmasks, modality_masking=None, vfeats=None):
        """
        We perform lazy masking to save data transfer time.
        It only generates video_labels by default and MFM model
        will do actualy masking.
        Return: `video_label` is a binary mask.
        """
        video_label = vmasks.clone()
        if modality_masking is not None:
            if modality_masking == "full":
                probability_matrix = torch.full(video_label.shape, 1.)
            elif modality_masking == "no":
                probability_matrix = torch.full(video_label.shape, 0.)
            elif modality_masking == "inverse":
                probability_matrix = torch.full(
                    video_label.shape, 1. - self.mfm_probability)
            else:
                raise ValueError("unknown modality masking.", modality_masking)
        else:
            probability_matrix = torch.full(
                video_label.shape, self.mfm_probability)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        # We only compute loss on masked tokens
        video_label[~masked_indices] = 0
        if vfeats is not None:
            vfeats[video_label, :] = 0.0
        return video_label


class TextGenerationProcessor(Processor):
    def __init__(self, tokenizer):
        self.bos_token_id = tokenizer.bos_token_id
        self.pad_token_id = tokenizer.pad_token_id

    def __call__(self, inputs):
        labels = inputs.clone()
        # [CLS] [SEP] for video
        labels[:2] = -100
        # keep [SEP] for text.
        pad_mask = labels == self.pad_token_id
        labels[pad_mask] = -100
        inputs[2:] = torch.cat([
            torch.LongTensor([self.bos_token_id]),
            inputs[2:-1]])
        inputs[pad_mask] = self.pad_token_id
        assert len(inputs) == len(labels)
        return inputs, labels


class TextMaskingProcessor(Processor):
    def __init__(self, config):
        """this function is borrowed from
        `transformers/data/data_collator.DataCollatorForLanguageModeling`"""
        self.mlm_probability = 0.15
        if config.mlm_probability is not None:
            self.mlm_probability = config.mlm_probability
        self.bert_name = config.bert_name
        # [CLS] is used as bos_token and [SEP] is used as eos_token.
        # https://huggingface.co/transformers/master/model_doc/bertgeneration.html
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.bert_name, bos_token="[CLS]", eos_token="[SEP]")
        self.textgen = TextGenerationProcessor(self.tokenizer)

    def __call__(
        self, inputs: torch.Tensor,
        modality_masking=None,
        special_tokens_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        expand modality_masking into
            None: traditional bert masking.
            "no": no masking.
            "full": all [MASK] token for generation.
            "gen": autoregressive generation.
        """
        """
        Prepare masked tokens inputs/labels for masked language modeling:
        80% MASK, 10% random, 10% original.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training
        # (with probability `self.mlm_probability`)
        if modality_masking is not None:
            if modality_masking == "full":
                probability_matrix = torch.full(labels.shape, 1.)
            elif modality_masking == "no":
                probability_matrix = torch.full(labels.shape, 0.)
            elif modality_masking.startswith("textgen"):
                # [CLS] [SEP] <s> ...
                inputs, labels = self.textgen(inputs)
                if "mask" not in modality_masking:
                    return inputs, labels
                inputs = self.mask_input(inputs, special_tokens_mask)
                return inputs, labels
            elif modality_masking == "mask":
                inputs = self.mask_input(inputs, special_tokens_mask)
                labels = torch.full(inputs.shape, -100)
                return inputs, labels
            elif modality_masking == "inverse":
                probability_matrix = torch.full(labels.shape, 1. - self.mlm_probability)
            else:
                raise ValueError("unknown modality masking.", modality_masking)
        else:
            probability_matrix = torch.full(labels.shape, self.mlm_probability)

        if special_tokens_mask is None:
            special_tokens_mask = self.get_special_tokens_mask(
                labels.tolist(), already_has_special_tokens=True
            )
            special_tokens_mask = torch.tensor(
                special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time,
        # we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = (
            torch.bernoulli(
                torch.full(labels.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.mask_token
        )

        # 10% of the time, we replace masked input tokens with random word
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(
            len(self.tokenizer), labels.shape, dtype=torch.long
        )
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input
        # tokens unchanged
        return inputs, labels

    def mask_input(self, inputs, special_tokens_mask=None):
        # the following is new with masked autoregressive.
        probability_matrix = torch.full(
            inputs.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = self.get_special_tokens_mask(
                inputs.tolist(), already_has_special_tokens=True
            )
            special_tokens_mask = torch.tensor(
                special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        indices_replaced = (
            torch.bernoulli(
                torch.full(inputs.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.mask_token
        )

        # 10% of the time, we replace masked input tokens with random word
        indices_random = (
            torch.bernoulli(torch.full(inputs.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(
            len(self.tokenizer), inputs.shape, dtype=torch.long
        )
        inputs[indices_random] = random_words[indices_random]
        return inputs

    def get_special_tokens_mask(
        self, token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Note: the version from transformers do not consider pad
        as special tokens.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if"
                    "the provided sequence of "
                    "ids is already formated with special tokens "
                    "for the model."
                )
            return list(map(lambda x: 1 if x in [
                self.tokenizer.sep_token_id,
                self.tokenizer.cls_token_id,
                self.tokenizer.pad_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]


class TextClipSamplingProcessor(Processor):
    def __init__(self, max_text_len, keep_prob=1.0):
        self.max_text_len = max_text_len
        self.max_video_len = 256  # always hold.
        self.keep_prob = keep_prob

    def __call__(
        self,
        text_feature,
        centerclip_idx=None,
        sampled_max_text_len=None,
        sampled_max_video_len=None,
    ):
        # Let's use all caps for now and see if 256 can cover all of them.
        if sampled_max_text_len is not None:
            max_text_len = sampled_max_text_len
        else:
            max_text_len = self.max_text_len
        if sampled_max_video_len is not None:
            max_video_len = sampled_max_video_len
        else:
            max_video_len = self.max_video_len

        t_num_clips = len(text_feature["start"])

        if centerclip_idx is None:
            centerclip_idx = random.randint(0, t_num_clips - 1)

        start_idx, end_idx = centerclip_idx, centerclip_idx + 1
        text_clip_indexs = deque()
        text_clip_indexs.append(start_idx)
        text_len = len(text_feature["cap"][start_idx])

        video_len = max(
            0,
            text_feature["end"][start_idx]
            - text_feature["start"][start_idx],
        )

        while (
            (start_idx > 0 or end_idx < t_num_clips)
            and text_len < max_text_len
            and video_len < max_video_len
        ):
            if random.random() > 0.5 and end_idx < t_num_clips:
                # skip the next one?
                if random.random() > self.keep_prob and (end_idx + 1) < t_num_clips:
                    end_idx = end_idx + 1
                text_clip_indexs.append(end_idx)
                text_len += len(text_feature["cap"][end_idx])
                end_idx += 1
            elif start_idx > 0:
                if random.random() > self.keep_prob and (start_idx - 1) > 0:
                    start_idx = start_idx - 1
                start_idx -= 1
                text_clip_indexs.insert(0, start_idx)
                text_len += len(text_feature["cap"][start_idx])
            else:
                if end_idx < t_num_clips:
                    if random.random() > self.keep_prob and (end_idx + 1) < t_num_clips:
                        end_idx = end_idx + 1
                    text_clip_indexs.append(end_idx)
                    text_len += len(text_feature["cap"][end_idx])
                    end_idx += 1
                else:
                    return text_clip_indexs
            video_len = max(
                0,
                text_feature["end"][text_clip_indexs[-1]]
                - text_feature["start"][text_clip_indexs[0]],
            )
        return text_clip_indexs


class VideoClipSamplingProcessor(Processor):
    def __call__(self, video_len, max_video_len, center):
        """
        `video_len`: length of the video.
        `max_video_len`: maximum video tokens allowd in a sequence.
        `center`: initial starting index.
        """
        assert center >= 0 and center < video_len
        t_clip_len = 0
        start, end = center, center
        while (start > 0 or end < video_len) and t_clip_len < max_video_len:
            # decide the direction to grow.
            if start <= 0:
                end += 1
            elif end >= video_len:
                start -= 1
            elif random.random() > 0.5:
                end += 1
            else:
                start -= 1
            t_clip_len += 1
        return {"start": [start], "end": [end]}


class How2MILNCEAligner(FixedLenAligner):
    """reference: `antoine77340/MIL-NCE_HowTo100M/video_loader.py`"""

    def __init__(self, config):
        super().__init__(config)
        self.num_candidates = 4
        self.min_time = 5.0
        self.num_sec = 3.2
        # self.num_sec = self.num_frames / float(self.fps)  num_frames=16 / fps = 5
        # self.num_frames = 16

    def sampling(
        self,
        video_id,
        video_feature,
        text_feature,
        centerclip_idx=None,  # will be ignored.
        sampled_max_text_len=None  # will be ignored.
    ):
        text, start, end = self._get_text(text_feature)
        video = self._get_video(video_feature, start, end)

        vfeats = torch.zeros((self.max_video_len, video_feature.shape[1]))
        vmasks = torch.zeros((self.max_video_len,), dtype=torch.bool)
        vfeats[: video.shape[0]] = torch.from_numpy(np.array(video))
        vmasks[: video.shape[0]] = 1

        caps, cmasks = [], []
        for words in text:
            cap, cmask = self._build_text_seq(text_feature, words)
            caps.append(cap)
            cmasks.append(cmask)
        caps = torch.stack(caps)
        cmasks = torch.stack(cmasks)
        # video of shape: (video_len)
        # text of shape (num_candidates, max_text_len)

        return {
            "caps": caps,
            "cmasks": cmasks,
            "vfeats": vfeats,
            "vmasks": vmasks,
            # "video_id": video_id,
        }

    def _get_video(self, video_feature, start, end):
        start_seek = random.randint(start, int(max(start, end - self.num_sec)))
        # duration = self.num_sec + 0.1
        return video_feature[start_seek : int(start_seek + self.num_sec)]

    def _get_text(self, cap):
        ind = random.randint(0, len(cap["start"]) - 1)
        if self.num_candidates == 1:
            words = [ind]
        else:
            words = []
            cap_start = self._find_nearest_candidates(cap, ind)
            for i in range(self.num_candidates):
                words.append([max(0, min(len(cap["cap"]) - 1, cap_start + i))])

        start, end = cap["start"][ind], cap["end"][ind]
        # TODO: May need to be improved for edge cases.
        # expand the min time.
        if end - start < self.min_time:
            diff = self.min_time - end + start
            start = max(0, start - diff / 2)
            end = start + self.min_time
        return words, int(start), int(end)

    def _find_nearest_candidates(self, caption, ind):
        """find the range of the clips."""
        start, end = ind, ind
        #diff = caption["end"][end] - caption["start"][start]
        n_candidate = 1
        while n_candidate < self.num_candidates:
            # the first clip
            if start == 0:
                return 0
            # we add () in the following condition to fix the bug.
            elif end == (len(caption["start"]) - 1):
                return start - (self.num_candidates - n_candidate)
            elif (caption["end"][end] - caption["start"][start - 1]) < (
                caption["end"][end + 1] - caption["start"][start]
            ):
                start -= 1
            else:
                end += 1
            n_candidate += 1
        return start


class PKLJSONStrTextProcessor(TextProcessor):
    """`caption.json` from howto100m are preprocessed as a
    dict `[video_id, json_str]`.
    Json parsing tokenization are conducted on-the-fly and cached into dict.
    """

    def __init__(self, config, max_clip_text_len=96):
        print("[Warning] PKLJSONStrTextProcessor is slow for num_workers > 0.")
        self.caption_pkl_path = str(config.caption_pkl_path)
        with open(self.caption_pkl_path, "rb") as fd:
            self.data = pickle.load(fd)
        self.max_clip_text_len = max_clip_text_len
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            str(config.bert_name), use_fast=config.use_fast
        )

    def __call__(self, video_id):
        caption = self.data[video_id]
        if isinstance(caption, str):
            import json
            caption = json.loads(caption)
            cap = []
            for clip_idx, text_clip in enumerate(caption["text"]):
                clip_ids = []
                if isinstance(text_clip, str):
                    clip_ids = self.tokenizer(
                        text_clip[: self.max_clip_text_len],
                        add_special_tokens=False
                    )["input_ids"]
                cap.append(clip_ids)
            caption["cap"] = cap
            caption.pop("text")  # save space.
            self.data[video_id] = caption
        return caption

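Editor's note: a compact, standalone sketch of the 80%/10%/10% corruption rule that TextMaskingProcessor above implements. The constants mask_token_id=103 and vocab_size=30522 are bert-base-uncased values assumed here only so the sketch runs without downloading a tokenizer; it is an illustration, not part of the diff.

import torch


def mlm_corrupt(input_ids, special_mask, mlm_probability=0.15,
                mask_token_id=103, vocab_size=30522):
    labels = input_ids.clone()
    prob = torch.full(labels.shape, mlm_probability)
    prob.masked_fill_(special_mask.bool(), value=0.0)   # never mask [CLS]/[SEP]/[PAD]
    masked = torch.bernoulli(prob).bool()
    labels[~masked] = -100                              # loss only on masked positions
    replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked
    input_ids[replaced] = mask_token_id                 # 80%: [MASK]
    randomized = (torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
                  & masked & ~replaced)                 # 10%: random token
    input_ids[randomized] = torch.randint(vocab_size, labels.shape)[randomized]
    return input_ids, labels                            # remaining 10%: unchanged


ids = torch.randint(1000, 2000, (1, 16))
special = torch.zeros_like(ids)
special[:, 0] = special[:, -1] = 1                      # pretend [CLS] ... [SEP]
corrupted, labels = mlm_corrupt(ids.clone(), special)
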
fairseq/examples/MMPT/mmpt/processors/models/s3dg.py
ADDED
@@ -0,0 +1,336 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Contains a PyTorch definition for Gated Separable 3D network (S3D-G)
with a text module for computing joint text-video embedding from raw text
and video input. The following code will enable you to load the HowTo100M
pretrained S3D Text-Video model from:
A. Miech, J.-B. Alayrac, L. Smaira, I. Laptev, J. Sivic and A. Zisserman,
End-to-End Learning of Visual Representations from Uncurated Instructional Videos.
https://arxiv.org/abs/1912.06430.

S3D-G was proposed by:
S. Xie, C. Sun, J. Huang, Z. Tu and K. Murphy,
Rethinking Spatiotemporal Feature Learning For Video Understanding.
https://arxiv.org/abs/1712.04851.
Tensorflow code: https://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py

The S3D architecture was slightly modified with a space to depth trick for TPU
optimization.
"""

import torch as th
import torch.nn.functional as F
import torch.nn as nn
import os
import numpy as np
import re


class InceptionBlock(nn.Module):
    def __init__(
        self,
        input_dim,
        num_outputs_0_0a,
        num_outputs_1_0a,
        num_outputs_1_0b,
        num_outputs_2_0a,
        num_outputs_2_0b,
        num_outputs_3_0b,
        gating=True,
    ):
        super(InceptionBlock, self).__init__()
        self.conv_b0 = STConv3D(input_dim, num_outputs_0_0a, [1, 1, 1])
        self.conv_b1_a = STConv3D(input_dim, num_outputs_1_0a, [1, 1, 1])
        self.conv_b1_b = STConv3D(
            num_outputs_1_0a, num_outputs_1_0b, [3, 3, 3], padding=1, separable=True
        )
        self.conv_b2_a = STConv3D(input_dim, num_outputs_2_0a, [1, 1, 1])
        self.conv_b2_b = STConv3D(
            num_outputs_2_0a, num_outputs_2_0b, [3, 3, 3], padding=1, separable=True
        )
        self.maxpool_b3 = th.nn.MaxPool3d((3, 3, 3), stride=1, padding=1)
        self.conv_b3_b = STConv3D(input_dim, num_outputs_3_0b, [1, 1, 1])
        self.gating = gating
        self.output_dim = (
            num_outputs_0_0a + num_outputs_1_0b + num_outputs_2_0b + num_outputs_3_0b
        )
        if gating:
            self.gating_b0 = SelfGating(num_outputs_0_0a)
            self.gating_b1 = SelfGating(num_outputs_1_0b)
            self.gating_b2 = SelfGating(num_outputs_2_0b)
            self.gating_b3 = SelfGating(num_outputs_3_0b)

    def forward(self, input):
        """Inception block
        """
        b0 = self.conv_b0(input)
        b1 = self.conv_b1_a(input)
        b1 = self.conv_b1_b(b1)
        b2 = self.conv_b2_a(input)
        b2 = self.conv_b2_b(b2)
        b3 = self.maxpool_b3(input)
        b3 = self.conv_b3_b(b3)
        if self.gating:
            b0 = self.gating_b0(b0)
            b1 = self.gating_b1(b1)
            b2 = self.gating_b2(b2)
            b3 = self.gating_b3(b3)
        return th.cat((b0, b1, b2, b3), dim=1)


class SelfGating(nn.Module):
    def __init__(self, input_dim):
        super(SelfGating, self).__init__()
        self.fc = nn.Linear(input_dim, input_dim)

    def forward(self, input_tensor):
        """Feature gating as used in S3D-G.
        """
        spatiotemporal_average = th.mean(input_tensor, dim=[2, 3, 4])
        weights = self.fc(spatiotemporal_average)
        weights = th.sigmoid(weights)
        return weights[:, :, None, None, None] * input_tensor


class STConv3D(nn.Module):
    def __init__(
        self, input_dim, output_dim, kernel_size, stride=1, padding=0, separable=False
    ):
        super(STConv3D, self).__init__()
        self.separable = separable
        self.relu = nn.ReLU(inplace=True)
        assert len(kernel_size) == 3
        if separable and kernel_size[0] != 1:
            spatial_kernel_size = [1, kernel_size[1], kernel_size[2]]
            temporal_kernel_size = [kernel_size[0], 1, 1]
            if isinstance(stride, list) and len(stride) == 3:
                spatial_stride = [1, stride[1], stride[2]]
                temporal_stride = [stride[0], 1, 1]
            else:
                spatial_stride = [1, stride, stride]
                temporal_stride = [stride, 1, 1]
            if isinstance(padding, list) and len(padding) == 3:
                spatial_padding = [0, padding[1], padding[2]]
                temporal_padding = [padding[0], 0, 0]
            else:
                spatial_padding = [0, padding, padding]
                temporal_padding = [padding, 0, 0]
        if separable:
            self.conv1 = nn.Conv3d(
                input_dim,
                output_dim,
                kernel_size=spatial_kernel_size,
                stride=spatial_stride,
                padding=spatial_padding,
                bias=False,
            )
            self.bn1 = nn.BatchNorm3d(output_dim)
            self.conv2 = nn.Conv3d(
                output_dim,
                output_dim,
                kernel_size=temporal_kernel_size,
                stride=temporal_stride,
                padding=temporal_padding,
                bias=False,
            )
            self.bn2 = nn.BatchNorm3d(output_dim)
        else:
            self.conv1 = nn.Conv3d(
                input_dim,
                output_dim,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            )
            self.bn1 = nn.BatchNorm3d(output_dim)

    def forward(self, input):
        out = self.relu(self.bn1(self.conv1(input)))
        if self.separable:
            out = self.relu(self.bn2(self.conv2(out)))
        return out


class MaxPool3dTFPadding(th.nn.Module):
    def __init__(self, kernel_size, stride=None, padding="SAME"):
        super(MaxPool3dTFPadding, self).__init__()
        if padding == "SAME":
            padding_shape = self._get_padding_shape(kernel_size, stride)
            self.padding_shape = padding_shape
            self.pad = th.nn.ConstantPad3d(padding_shape, 0)
self.pool = th.nn.MaxPool3d(kernel_size, stride, ceil_mode=True)
|
164 |
+
|
165 |
+
def _get_padding_shape(self, filter_shape, stride):
|
166 |
+
def _pad_top_bottom(filter_dim, stride_val):
|
167 |
+
pad_along = max(filter_dim - stride_val, 0)
|
168 |
+
pad_top = pad_along // 2
|
169 |
+
pad_bottom = pad_along - pad_top
|
170 |
+
return pad_top, pad_bottom
|
171 |
+
|
172 |
+
padding_shape = []
|
173 |
+
for filter_dim, stride_val in zip(filter_shape, stride):
|
174 |
+
pad_top, pad_bottom = _pad_top_bottom(filter_dim, stride_val)
|
175 |
+
padding_shape.append(pad_top)
|
176 |
+
padding_shape.append(pad_bottom)
|
177 |
+
depth_top = padding_shape.pop(0)
|
178 |
+
depth_bottom = padding_shape.pop(0)
|
179 |
+
padding_shape.append(depth_top)
|
180 |
+
padding_shape.append(depth_bottom)
|
181 |
+
return tuple(padding_shape)
|
182 |
+
|
183 |
+
def forward(self, inp):
|
184 |
+
inp = self.pad(inp)
|
185 |
+
out = self.pool(inp)
|
186 |
+
return out
|
187 |
+
|
188 |
+
|
189 |
+
class Sentence_Embedding(nn.Module):
|
190 |
+
def __init__(
|
191 |
+
self,
|
192 |
+
embd_dim,
|
193 |
+
num_embeddings=66250,
|
194 |
+
word_embedding_dim=300,
|
195 |
+
token_to_word_path="dict.npy",
|
196 |
+
max_words=16,
|
197 |
+
output_dim=2048,
|
198 |
+
):
|
199 |
+
super(Sentence_Embedding, self).__init__()
|
200 |
+
self.word_embd = nn.Embedding(num_embeddings, word_embedding_dim)
|
201 |
+
self.fc1 = nn.Linear(word_embedding_dim, output_dim)
|
202 |
+
self.fc2 = nn.Linear(output_dim, embd_dim)
|
203 |
+
self.word_to_token = {}
|
204 |
+
self.max_words = max_words
|
205 |
+
token_to_word = np.load(token_to_word_path)
|
206 |
+
for i, t in enumerate(token_to_word):
|
207 |
+
self.word_to_token[t] = i + 1
|
208 |
+
|
209 |
+
def _zero_pad_tensor_token(self, tensor, size):
|
210 |
+
if len(tensor) >= size:
|
211 |
+
return tensor[:size]
|
212 |
+
else:
|
213 |
+
zero = th.zeros(size - len(tensor)).long()
|
214 |
+
return th.cat((tensor, zero), dim=0)
|
215 |
+
|
216 |
+
def _split_text(self, sentence):
|
217 |
+
w = re.findall(r"[\w']+", str(sentence))
|
218 |
+
return w
|
219 |
+
|
220 |
+
def _words_to_token(self, words):
|
221 |
+
words = [
|
222 |
+
self.word_to_token[word] for word in words if word in self.word_to_token
|
223 |
+
]
|
224 |
+
if words:
|
225 |
+
we = self._zero_pad_tensor_token(th.LongTensor(words), self.max_words)
|
226 |
+
return we
|
227 |
+
else:
|
228 |
+
return th.zeros(self.max_words).long()
|
229 |
+
|
230 |
+
def _words_to_ids(self, x):
|
231 |
+
split_x = [self._words_to_token(self._split_text(sent.lower())) for sent in x]
|
232 |
+
return th.stack(split_x, dim=0)
|
233 |
+
|
234 |
+
def forward(self, x):
|
235 |
+
x = self._words_to_ids(x)
|
236 |
+
x = self.word_embd(x)
|
237 |
+
x = F.relu(self.fc1(x))
|
238 |
+
x = th.max(x, dim=1)[0]
|
239 |
+
x = self.fc2(x)
|
240 |
+
return {'text_embedding': x}
|
241 |
+
|
242 |
+
|
243 |
+
class S3D(nn.Module):
|
244 |
+
def __init__(self, dict_path, num_classes=512, gating=True, space_to_depth=True):
|
245 |
+
super(S3D, self).__init__()
|
246 |
+
self.num_classes = num_classes
|
247 |
+
self.gating = gating
|
248 |
+
self.space_to_depth = space_to_depth
|
249 |
+
if space_to_depth:
|
250 |
+
self.conv1 = STConv3D(
|
251 |
+
24, 64, [2, 4, 4], stride=1, padding=(1, 2, 2), separable=False
|
252 |
+
)
|
253 |
+
else:
|
254 |
+
self.conv1 = STConv3D(
|
255 |
+
3, 64, [3, 7, 7], stride=2, padding=(1, 3, 3), separable=False
|
256 |
+
)
|
257 |
+
self.conv_2b = STConv3D(64, 64, [1, 1, 1], separable=False)
|
258 |
+
self.conv_2c = STConv3D(64, 192, [3, 3, 3], padding=1, separable=True)
|
259 |
+
self.gating = SelfGating(192)
|
260 |
+
self.maxpool_2a = MaxPool3dTFPadding(
|
261 |
+
kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME"
|
262 |
+
)
|
263 |
+
self.maxpool_3a = MaxPool3dTFPadding(
|
264 |
+
kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME"
|
265 |
+
)
|
266 |
+
self.mixed_3b = InceptionBlock(192, 64, 96, 128, 16, 32, 32)
|
267 |
+
self.mixed_3c = InceptionBlock(
|
268 |
+
self.mixed_3b.output_dim, 128, 128, 192, 32, 96, 64
|
269 |
+
)
|
270 |
+
self.maxpool_4a = MaxPool3dTFPadding(
|
271 |
+
kernel_size=(3, 3, 3), stride=(2, 2, 2), padding="SAME"
|
272 |
+
)
|
273 |
+
self.mixed_4b = InceptionBlock(
|
274 |
+
self.mixed_3c.output_dim, 192, 96, 208, 16, 48, 64
|
275 |
+
)
|
276 |
+
self.mixed_4c = InceptionBlock(
|
277 |
+
self.mixed_4b.output_dim, 160, 112, 224, 24, 64, 64
|
278 |
+
)
|
279 |
+
self.mixed_4d = InceptionBlock(
|
280 |
+
self.mixed_4c.output_dim, 128, 128, 256, 24, 64, 64
|
281 |
+
)
|
282 |
+
self.mixed_4e = InceptionBlock(
|
283 |
+
self.mixed_4d.output_dim, 112, 144, 288, 32, 64, 64
|
284 |
+
)
|
285 |
+
self.mixed_4f = InceptionBlock(
|
286 |
+
self.mixed_4e.output_dim, 256, 160, 320, 32, 128, 128
|
287 |
+
)
|
288 |
+
self.maxpool_5a = self.maxPool3d_5a_2x2 = MaxPool3dTFPadding(
|
289 |
+
kernel_size=(2, 2, 2), stride=(2, 2, 2), padding="SAME"
|
290 |
+
)
|
291 |
+
self.mixed_5b = InceptionBlock(
|
292 |
+
self.mixed_4f.output_dim, 256, 160, 320, 32, 128, 128
|
293 |
+
)
|
294 |
+
self.mixed_5c = InceptionBlock(
|
295 |
+
self.mixed_5b.output_dim, 384, 192, 384, 48, 128, 128
|
296 |
+
)
|
297 |
+
self.fc = nn.Linear(self.mixed_5c.output_dim, num_classes)
|
298 |
+
self.text_module = Sentence_Embedding(num_classes,
|
299 |
+
token_to_word_path=dict_path)
|
300 |
+
|
301 |
+
def _space_to_depth(self, input):
|
302 |
+
"""3D space to depth trick for TPU optimization.
|
303 |
+
"""
|
304 |
+
B, C, T, H, W = input.shape
|
305 |
+
input = input.view(B, C, T // 2, 2, H // 2, 2, W // 2, 2)
|
306 |
+
input = input.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
307 |
+
input = input.contiguous().view(B, 8 * C, T // 2, H // 2, W // 2)
|
308 |
+
return input
|
309 |
+
|
310 |
+
def forward(self, inputs):
|
311 |
+
"""Defines the S3DG base architecture."""
|
312 |
+
if self.space_to_depth:
|
313 |
+
inputs = self._space_to_depth(inputs)
|
314 |
+
net = self.conv1(inputs)
|
315 |
+
if self.space_to_depth:
|
316 |
+
# we need to replicate 'SAME' tensorflow padding
|
317 |
+
net = net[:, :, 1:, 1:, 1:]
|
318 |
+
net = self.maxpool_2a(net)
|
319 |
+
net = self.conv_2b(net)
|
320 |
+
net = self.conv_2c(net)
|
321 |
+
if self.gating:
|
322 |
+
net = self.gating(net)
|
323 |
+
net = self.maxpool_3a(net)
|
324 |
+
net = self.mixed_3b(net)
|
325 |
+
net = self.mixed_3c(net)
|
326 |
+
net = self.maxpool_4a(net)
|
327 |
+
net = self.mixed_4b(net)
|
328 |
+
net = self.mixed_4c(net)
|
329 |
+
net = self.mixed_4d(net)
|
330 |
+
net = self.mixed_4e(net)
|
331 |
+
net = self.mixed_4f(net)
|
332 |
+
net = self.maxpool_5a(net)
|
333 |
+
net = self.mixed_5b(net)
|
334 |
+
net = self.mixed_5c(net)
|
335 |
+
net = th.mean(net, dim=[2, 3, 4])
|
336 |
+
return {'video_embedding': self.fc(net), 'mixed_5c': net}
|
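
The module docstring above says this file is meant to load the HowTo100M-pretrained S3D text-video model. The following is a minimal usage sketch, not part of the diff; the dictionary and weight file names (s3d_dict.npy, s3d_howto100m.pth) are assumptions based on the public MIL-NCE release.

import torch as th

from mmpt.processors.models.s3dg import S3D

# Assumed local copies of the released dictionary and checkpoint.
model = S3D("s3d_dict.npy", num_classes=512)
model.load_state_dict(th.load("s3d_howto100m.pth"), strict=False)
model.eval()

# 32 RGB frames at 224x224; the space-to-depth trick needs even T/H/W.
video = th.rand(1, 3, 32, 224, 224)
with th.no_grad():
    video_out = model(video)  # {'video_embedding', 'mixed_5c'}
    text_out = model.text_module(["someone slices an onion"])

# Dot product of the joint embeddings gives a text-video similarity score.
score = text_out["text_embedding"] @ video_out["video_embedding"].t()
print(video_out["video_embedding"].shape, score.shape)  # (1, 512), (1, 1)
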
fairseq/examples/MMPT/mmpt/processors/processor.py
ADDED
@@ -0,0 +1,274 @@
+# Copyright (c) Facebook, Inc. All Rights Reserved
+
+import numpy as np
+import os
+import torch
+
+
+class Processor(object):
+    """
+    A generic processor for video (codec, feature etc.) and text.
+    """
+
+    def __call__(self, **kwargs):
+        raise NotImplementedError
+
+
+class MetaProcessor(Processor):
+    """
+    A meta processor is expected to load the metadata of a dataset:
+        (e.g., video_ids, or captions).
+    You must implement the `__getitem__` (meta datasets are rather diverse.).
+    """
+
+    def __init__(self, config):
+        self.split = config.split
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        raise NotImplementedError
+
+    def _get_split_path(self, config):
+        splits = {
+            "train": config.train_path,
+            "valid": config.val_path,
+            "test": config.test_path,
+        }
+        if config.split is not None:
+            return splits[config.split]
+        return config.train_path
+
+
+class TextProcessor(Processor):
+    """
+    A generic Text processor: rename this as `withTokenizer`.
+    tokenize a string of text on-the-fly.
+    Warning: mostly used for end tasks.
+        (on-the-fly tokenization is slow for how2.)
+    TODO(huxu): move this class as a subclass.
+    """
+
+    def __init__(self, config):
+        self.bert_name = str(config.bert_name)
+        self.use_fast = config.use_fast
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.bert_name, use_fast=self.use_fast
+        )
+
+    def __call__(self, text_id):
+        caption = self.tokenizer(text_id, add_special_tokens=False)
+        return caption["input_ids"]
+
+
+class VideoProcessor(Processor):
+    """
+    A generic video processor: load a numpy video tokens by default.
+    """
+
+    def __init__(self, config):
+        self.vfeat_dir = config.vfeat_dir
+
+    def __call__(self, video_fn):
+        if isinstance(video_fn, tuple):
+            video_fn = video_fn[0]
+        assert isinstance(video_fn, str)
+        video_fn = os.path.join(self.vfeat_dir, video_fn + ".npy")
+        feat = np.load(video_fn)
+        return feat
+
+
+class Aligner(object):
+    """
+    An alignprocessor align video and text and output a dict of tensors (for a model).
+    """
+    def __init__(self, config):
+        """__init__ needs to be light weight for more workers/threads."""
+        self.split = config.split
+        self.max_video_len = config.max_video_len
+        self.max_len = config.max_len
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            str(config.bert_name), use_fast=config.use_fast
+        )
+        self.cls_token_id = tokenizer.cls_token_id
+        self.sep_token_id = tokenizer.sep_token_id
+        self.pad_token_id = tokenizer.pad_token_id
+        self.mask_token_id = tokenizer.mask_token_id
+
+    def __call__(self, video_id, video_feature, text_feature):
+        raise NotImplementedError
+
+    def _build_video_seq(self, video_feature, video_clips=None):
+        """
+        `video_feature`: available video tokens.
+        `video_clips`: video clip sequence to build.
+        """
+        if not isinstance(video_feature, np.ndarray):
+            raise ValueError(
+                "unsupported type of video_feature", type(video_feature)
+            )
+
+        if video_clips is None:
+            # this is borrowed from DSAligner
+            video_start = 0
+            video_end = min(len(video_feature), self.max_video_len)
+            # the whole sequence is a single clip.
+            video_clips = {"start": [video_start], "end": [video_end]}
+
+        vfeats = np.zeros(
+            (self.max_video_len, video_feature.shape[1]), dtype=np.float32
+        )
+        vmasks = torch.zeros((self.max_video_len,), dtype=torch.bool)
+        video_len = 0
+        for start, end in zip(video_clips["start"], video_clips["end"]):
+            clip_len = min(self.max_video_len - video_len, (end - start))
+            if clip_len > 0:
+                vfeats[video_len: video_len + clip_len] = video_feature[
+                    start: start + clip_len
+                ]
+                vmasks[video_len: video_len + clip_len] = 1
+                video_len += clip_len
+        vfeats = torch.from_numpy(vfeats)
+
+        return vfeats, vmasks
+
+    def _build_text_seq(self, text_feature, text_clip_indexs=None):
+        """
+        `text_feature`: all available clips.
+        `text_clip_indexes`: clip sequence to build.
+        """
+        if text_clip_indexs is None:
+            text_clip_indexs = [0]
+
+        full_caps = []
+        if isinstance(text_feature, dict):
+            for clip_idx in text_clip_indexs:
+                full_caps.extend(text_feature["cap"][clip_idx])
+        else:
+            full_caps = text_feature
+        max_text_len = self.max_len - self.max_video_len - 3
+        full_caps = full_caps[:max_text_len]
+        full_caps = (
+            [self.cls_token_id, self.sep_token_id] + full_caps + [self.sep_token_id]
+        )
+        text_pad_len = self.max_len - len(full_caps) - self.max_video_len
+        padded_full_caps = full_caps + [self.pad_token_id] * text_pad_len
+        caps = torch.LongTensor(padded_full_caps)
+        cmasks = torch.zeros((len(padded_full_caps),), dtype=torch.bool)
+        cmasks[: len(full_caps)] = 1
+
+        return caps, cmasks
+
+    def batch_post_processing(self, batch, video_feature):
+        return batch
+
+
+class MMAttentionMask2DProcessor(Processor):
+    """text generation requires 2d mask
+    that is harder to generate by GPU at this stage."""
+
+    def __call__(self, vmask, cmask, mtype):
+        if mtype == "textgen":
+            return self._build_textgeneration_mask(vmask, cmask)
+        elif mtype == "videogen":
+            return self._build_videogeneration_mask(vmask, cmask)
+        else:
+            return self._build_mm_mask(vmask, cmask)
+
+    def _build_mm_mask(self, vmask, cmask):
+        mask_1d = torch.cat([cmask[:1], vmask, cmask[1:]], dim=0)
+        return mask_1d[None, :].repeat(mask_1d.size(0), 1)
+
+    def _build_videogeneration_mask(self, vmask, cmask):
+        # cls_mask is only about text otherwise it will leak generation.
+        cls_text_mask = torch.cat([
+            # [CLS]
+            torch.ones(
+                (1,), dtype=torch.bool, device=cmask.device),
+            # video tokens and [SEP] for video.
+            torch.zeros(
+                (vmask.size(0) + 1,), dtype=torch.bool, device=cmask.device),
+            cmask[2:]
+        ], dim=0)
+
+        # concat horizontially.
+        video_len = int(vmask.sum())
+        video_masks = torch.cat([
+            # [CLS]
+            torch.ones(
+                (video_len, 1), dtype=torch.bool, device=cmask.device
+            ),
+            torch.tril(
+                torch.ones(
+                    (video_len, video_len),
+                    dtype=torch.bool, device=cmask.device)),
+            # video_padding
+            torch.zeros(
+                (video_len, vmask.size(0) - video_len),
+                dtype=torch.bool, device=cmask.device
+            ),
+            # [SEP] for video (unused).
+            torch.zeros(
+                (video_len, 1), dtype=torch.bool, device=cmask.device
+            ),
+            cmask[2:].unsqueeze(0).repeat(video_len, 1)
+        ], dim=1)
+
+        text_masks = cls_text_mask[None, :].repeat(
+            cmask.size(0) - 2, 1)
+        video_padding_masks = cls_text_mask[None, :].repeat(
+            vmask.size(0) - video_len, 1)
+
+        return torch.cat([
+            cls_text_mask[None, :],
+            video_masks,
+            video_padding_masks,
+            torch.cat([cmask[:1], vmask, cmask[1:]], dim=0)[None,:],
+            text_masks
+        ], dim=0)
+
+    def _build_textgeneration_mask(self, vmask, cmask):
+        # cls_mask is only about video otherwise it will leak generation.
+        cls_video_mask = torch.cat([
+            # [CLS]
+            torch.ones(
+                (1,), dtype=torch.bool, device=cmask.device),
+            vmask,
+            # [SEP]
+            torch.ones((1,), dtype=torch.bool, device=cmask.device),
+            torch.zeros(
+                (cmask.size(0)-2,), dtype=torch.bool, device=cmask.device)
+        ], dim=0)
+
+        # concat horizontially.
+        text_len = int(cmask[2:].sum())
+        text_masks = torch.cat([
+            # [CLS]
+            torch.ones(
+                (text_len, 1), dtype=torch.bool, device=cmask.device
+            ),
+            vmask.unsqueeze(0).repeat(text_len, 1),
+            # [SEP] for video.
+            torch.ones(
+                (text_len, 1), dtype=torch.bool, device=cmask.device
+            ),
+            torch.tril(
+                torch.ones(
+                    (text_len, text_len),
+                    dtype=torch.bool, device=cmask.device)),
+            # padding.
+            torch.zeros(
+                (text_len, cmask.size(0) - text_len - 2),
+                dtype=torch.bool, device=cmask.device
+            )
+        ], dim=1)
+
+        cls_video_masks = cls_video_mask[None, :].repeat(
+            vmask.size(0) + 2, 1)
+        text_padding_masks = cls_video_mask[None, :].repeat(
+            cmask.size(0) - text_len - 2, 1)
+        return torch.cat([
+            cls_video_masks, text_masks, text_padding_masks], dim=0)
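
MMAttentionMask2DProcessor builds one square attention mask over the [CLS] + video + [SEP] + text sequence. A small sketch of its output shapes, assuming the MMPT package is importable (e.g. run from fairseq/examples/MMPT):

import torch

from mmpt.processors.processor import MMAttentionMask2DProcessor

max_video_len = 32
vmask = torch.zeros(max_video_len, dtype=torch.bool)
vmask[:20] = 1                       # 20 real video tokens, rest padding

cmask = torch.zeros(16, dtype=torch.bool)
cmask[:10] = 1                       # [CLS], [SEP] and 8 real text tokens

proc = MMAttentionMask2DProcessor()
mm = proc(vmask, cmask, "mm")            # fully bidirectional mask
textgen = proc(vmask, cmask, "textgen")  # causal over text, full over video
print(mm.shape, textgen.shape)           # both (48, 48): 1 + 32 + 15 positions
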
fairseq/examples/MMPT/mmpt/tasks/fairseqmmtask.py
ADDED
@@ -0,0 +1,104 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+make a general fairseq task for MM pretraining.
+"""
+
+import random
+
+from fairseq.tasks import LegacyFairseqTask, register_task
+
+from .task import Task
+from .retritask import RetriTask
+from ..datasets import FairseqMMDataset
+from .. import utils
+
+
+@register_task("mmtask")
+class FairseqMMTask(LegacyFairseqTask):
+    @staticmethod
+    def add_args(parser):
+        # Add some command-line arguments for specifying where the data is
+        # located and the maximum supported input length.
+        parser.add_argument(
+            "taskconfig",
+            metavar="FILE",
+            help=("taskconfig to load all configurations " "outside fairseq parser."),
+        )
+
+    @classmethod
+    def setup_task(cls, args, **kwargs):
+        return FairseqMMTask(args)
+
+    def __init__(self, args):
+        super().__init__(args)
+        config = utils.load_config(args)
+        self.mmtask = Task.config_task(config)
+        self.mmtask.build_dataset()
+        self.mmtask.build_model()
+        self.mmtask.build_loss()
+
+    def load_dataset(self, split, **kwargs):
+        split_map = {
+            "train": self.mmtask.train_data,
+            "valid": self.mmtask.val_data,
+            "test": self.mmtask.test_data,
+        }
+        if split not in split_map:
+            raise ValueError("unknown split type.")
+        if split_map[split] is not None:
+            self.datasets[split] = FairseqMMDataset(split_map[split])
+
+    def get_batch_iterator(
+        self,
+        dataset,
+        max_tokens=None,
+        max_sentences=None,
+        max_positions=None,
+        ignore_invalid_inputs=False,
+        required_batch_size_multiple=1,
+        seed=1,
+        num_shards=1,
+        shard_id=0,
+        num_workers=0,
+        epoch=1,
+        data_buffer_size=0,
+        disable_iterator_cache=False,
+        skip_remainder_batch=False,
+        grouped_shuffling=False,
+        update_epoch_batch_itr=False,
+    ):
+        random.seed(epoch)
+        if dataset.mmdataset.split == "train" and isinstance(self.mmtask, RetriTask):
+            if epoch >= self.mmtask.config.retri_epoch:
+                if not hasattr(self.mmtask, "retri_dataloader"):
+                    self.mmtask.build_dataloader()
+                self.mmtask.retrive_candidates(epoch)
+
+        return super().get_batch_iterator(
+            dataset,
+            max_tokens,
+            max_sentences,
+            max_positions,
+            ignore_invalid_inputs,
+            required_batch_size_multiple,
+            seed,
+            num_shards,
+            shard_id,
+            num_workers,
+            epoch,
+            data_buffer_size,
+            disable_iterator_cache,
+            grouped_shuffling,
+            update_epoch_batch_itr,
+        )
+
+    @property
+    def source_dictionary(self):
+        return None
+
+    @property
+    def target_dictionary(self):
+        return None
fairseq/examples/MMPT/mmpt/tasks/milncetask.py
ADDED
@@ -0,0 +1,27 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from .task import Task
+
+
+class MILNCETask(Task):
+    def reshape_subsample(self, sample):
+        if (
+            hasattr(self.config.dataset, "subsampling")
+            and self.config.dataset.subsampling is not None
+            and self.config.dataset.subsampling > 1
+        ):
+            for key in sample:
+                if torch.is_tensor(sample[key]):
+                    tensor = self.flat_subsample(sample[key])
+                    if key in ["caps", "cmasks"]:
+                        size = tensor.size()
+                        batch_size = size[0] * size[1]
+                        expanded_size = (batch_size,) + size[2:]
+                        tensor = tensor.view(expanded_size)
+                    sample[key] = tensor
+        return sample
fairseq/examples/MMPT/mmpt/tasks/retritask.py
ADDED
@@ -0,0 +1,253 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import torch
+import pickle
+import random
+
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from ..processors import (
+    ShardedHow2MetaProcessor,
+    ShardedVideoProcessor,
+    ShardedTextProcessor,
+    VariedLenAligner,
+)
+
+from ..datasets import MMDataset
+from .task import Task
+from ..modules import vectorpool
+from ..evaluators.predictor import Predictor
+from ..utils import set_seed, get_local_rank, get_world_size
+
+
+class RetriTask(Task):
+    """abstract class for task with retrival."""
+
+    def reshape_subsample(self, sample):
+        for key in sample:
+            if torch.is_tensor(sample[key]):
+                sample[key] = self.flat_subsample(sample[key])
+        return sample
+
+    def flat_subsample(self, tensor):
+        if tensor.size(0) == 1:
+            tensor = tensor.squeeze(0)
+        return tensor
+
+    def build_dataloader(self):
+        """called by `get_batch_iterator` in fairseqmmtask. """
+        # TODO: hard-code dataloader for retri for now and configurable in .yaml.
+        # reuse the `train.lst`.
+        self.config.dataset.split = "train"
+        meta_processor = ShardedHow2MetaProcessor(self.config.dataset)
+        video_processor = ShardedVideoProcessor(self.config.dataset)
+        text_processor = ShardedTextProcessor(self.config.dataset)
+
+        aligner = VariedLenAligner(self.config.dataset)
+        aligner.subsampling = self.config.dataset.clip_per_video
+
+        self.retri_data = MMDataset(
+            meta_processor, video_processor, text_processor, aligner
+        )
+
+        retri_sampler = DistributedSampler(self.retri_data)
+        infer_scale = 16
+        batch_size = self.config.dataset.num_video_per_batch \
+            * infer_scale
+
+        self.retri_dataloader = DataLoader(
+            self.retri_data,
+            collate_fn=self.retri_data.collater,
+            batch_size=batch_size,
+            shuffle=False,
+            sampler=retri_sampler,
+            num_workers=self.config.fairseq.dataset.num_workers
+        )
+        return self.retri_dataloader
+
+    def retrive_candidates(self, epoch, dataloader=None):
+        if get_local_rank() == 0:
+            print("running retrieval model.")
+        out_dir = os.path.join(
+            self.config.fairseq.checkpoint.save_dir, "retri")
+        os.makedirs(out_dir, exist_ok=True)
+
+        if not os.path.isfile(
+            os.path.join(
+                out_dir, "batched_e" + str(epoch) + "_videos0.pkl")
+        ):
+            if dataloader is None:
+                dataloader = self.retri_dataloader
+
+            self.model.eval()
+            self.model.is_train = False
+
+            assert self.retri_data.meta_processor.data == \
+                self.train_data.meta_processor.data  # video_ids not mutated.
+
+            self._retri_predict(epoch, dataloader)
+
+            self.model.train()
+            self.model.is_train = True
+
+        torch.distributed.barrier()
+        output = self._retri_sync(epoch, out_dir)
+        torch.distributed.barrier()
+        self.train_data.meta_processor.set_candidates(output)
+        return output
+
+
+class VideoRetriTask(RetriTask):
+    """RetriTask on video level."""
+
+    def reshape_subsample(self, sample):
+        if (
+            hasattr(self.config.dataset, "clip_per_video")
+            and self.config.dataset.clip_per_video is not None
+            and self.config.dataset.clip_per_video > 1
+        ):
+            for key in sample:
+                if torch.is_tensor(sample[key]):
+                    sample[key] = self.flat_subsample(sample[key])
+        return sample
+
+    def flat_subsample(self, tensor):
+        if tensor.size(0) == 1:
+            tensor = tensor.squeeze(0)
+        return Task.flat_subsample(self, tensor)
+
+    def _retri_predict(self, epoch, dataloader):
+        set_seed(epoch)
+        # save for retrival.
+        predictor = VideoPredictor(self.config)
+        predictor.predict_loop(
+            self.model, dataloader)
+        set_seed(epoch)  # get the same text clips.
+        # retrival.
+        retri_predictor = VideoRetriPredictor(
+            self.config)
+        retri_predictor.predict_loop(
+            self.model, predictor.vecpool.retriver, epoch)
+        del predictor
+        del retri_predictor
+
+    def _retri_sync(self, epoch, out_dir):
+        # gpu do the same merge.
+        batched_videos = []
+        for local_rank in range(get_world_size()):
+            fn = os.path.join(
+                out_dir,
+                "batched_e" + str(epoch) + "_videos" + str(local_rank) + ".pkl")
+            with open(fn, "rb") as fr:
+                batched_videos.extend(pickle.load(fr))
+        print(
+            "[INFO] batched_videos",
+            len(batched_videos), len(batched_videos[0]))
+        return batched_videos
+
+
+class VideoPredictor(Predictor):
+    def __init__(self, config):
+        vectorpool_cls = getattr(vectorpool, config.vectorpool_cls)
+        self.vecpool = vectorpool_cls(config)
+
+    def predict_loop(
+        self,
+        model,
+        dataloader,
+        early_stop=-1,
+    ):
+        with torch.no_grad():
+            if get_local_rank() == 0:
+                dataloader = tqdm(dataloader)
+            for batch_idx, batch in enumerate(dataloader):
+                if batch_idx == early_stop:
+                    break
+                self(batch, model)
+        return self.finalize()
+
+    def __call__(self, sample, model, **kwargs):
+        param = next(model.parameters())
+        dtype = param.dtype
+        device = param.device
+        subsample = sample["vfeats"].size(1)
+        sample = self.to_ctx(sample, device, dtype)
+        for key in sample:
+            if torch.is_tensor(sample[key]):
+                size = sample[key].size()
+                if len(size) >= 2:
+                    batch_size = size[0] * size[1]
+                    expanded_size = (
+                        (batch_size,) + size[2:] if len(size) > 2
+                        else (batch_size,)
+                    )
+                    sample[key] = sample[key].view(expanded_size)
+
+        outputs = model(**sample)
+        sample.update(outputs)
+        self.vecpool(sample, subsample)
+
+    def finalize(self):
+        print("[INFO]", self.vecpool)
+        if not self.vecpool.retriver.db.is_trained:
+            self.vecpool.retriver.finalize_training()
+        return self.vecpool.retriver
+
+
+class VideoRetriPredictor(Predictor):
+    """
+    Online Retrieval Predictor for Clips (used by RetriTask).
+    TODO: merge this with VisPredictor?
+    """
+
+    def __init__(self, config):
+        self.pred_dir = os.path.join(
+            config.fairseq.checkpoint.save_dir,
+            "retri")
+        self.num_cands = config.num_cands
+        self.num_video_per_batch = config.dataset.num_video_per_batch
+
+    def predict_loop(
+        self,
+        model,
+        retriver,
+        epoch,
+        early_stop=-1
+    ):
+        # a fake loop that only try to recover video vector
+        # from video_id.
+        batched_videos = []
+        # obtain available video_ids.
+        video_ids = list(retriver.videoid_to_vectoridx.keys())
+
+        dataloader = random.sample(
+            video_ids,
+            len(video_ids) // self.num_video_per_batch
+        )
+
+        if get_local_rank() == 0:
+            dataloader = tqdm(dataloader)
+        for batch_idx, batch in enumerate(dataloader):
+            # batch is one video id.
+            if batch_idx == early_stop:
+                break
+            video_ids = retriver.search_by_video_ids(
+                [batch], self.num_cands)[0]
+            if len(video_ids) > self.num_video_per_batch:
+                # we moved the center to make cluster robust.
+                video_ids = random.sample(video_ids, self.num_video_per_batch)
+            batched_videos.append(video_ids)
+        return self.finalize(batched_videos, epoch)
+
+    def finalize(self, batched_videos, epoch):
+        fn = os.path.join(
+            self.pred_dir,
+            "batched_e" + str(epoch) + "_videos" + str(get_local_rank()) + ".pkl")
+        with open(fn, "wb") as fw:
+            pickle.dump(batched_videos, fw, pickle.HIGHEST_PROTOCOL)
+        return batched_videos
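
During retrieval-augmented training, each rank writes its candidate lists to batched_e{epoch}_videos{rank}.pkl and _retri_sync simply concatenates them. A standalone sketch of that exchange with fake data (paths and rank count are illustrative only):

import os
import pickle
import tempfile

out_dir = tempfile.mkdtemp()
epoch, world_size = 5, 2

# Pretend each rank produced two candidate batches (lists of video ids).
for rank in range(world_size):
    fake = [["vid%d_%d_%d" % (rank, i, j) for j in range(3)] for i in range(2)]
    fn = os.path.join(out_dir, "batched_e" + str(epoch) + "_videos" + str(rank) + ".pkl")
    with open(fn, "wb") as fw:
        pickle.dump(fake, fw, pickle.HIGHEST_PROTOCOL)

# Same merge as VideoRetriTask._retri_sync: extend across all ranks.
batched_videos = []
for rank in range(world_size):
    fn = os.path.join(out_dir, "batched_e" + str(epoch) + "_videos" + str(rank) + ".pkl")
    with open(fn, "rb") as fr:
        batched_videos.extend(pickle.load(fr))
print(len(batched_videos), len(batched_videos[0]))  # 4 3
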
fairseq/examples/MMPT/mmpt/tasks/task.py
ADDED
@@ -0,0 +1,184 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from .. import tasks
+from .. import models
+from .. import losses
+from ..datasets import MMDataset
+from .. import processors
+
+
+class Task(object):
+    """
+    A task refers to one generic training task (e.g., training one model).
+    """
+
+    @classmethod
+    def config_task(cls, config):
+        """
+        determine whether to load a hard-coded task or config from a generic one.
+        via if a task string is available in config.
+        """
+        if config.task is not None:
+            # TODO (huxu): expand the search scope.
+            task_cls = getattr(tasks, config.task)
+            return task_cls(config)
+        else:
+            return Task(config)
+
+    def __init__(self, config):
+        self.config = config
+        self.train_data = None
+        self.val_data = None
+        self.test_data = None
+
+        self.model = None
+        self.loss_fn = None
+        self.eval_fn = None
+
+    def build_dataset(self):
+        """TODO (huxu): move processor breakdown to MMDataset."""
+        """fill-in `self.train_data`, `self.val_data` and `self.test_data`."""
+
+        meta_processor_cls = getattr(
+            processors, self.config.dataset.meta_processor)
+        video_processor_cls = getattr(
+            processors, self.config.dataset.video_processor)
+        text_processor_cls = getattr(
+            processors, self.config.dataset.text_processor)
+        aligner_cls = getattr(
+            processors, self.config.dataset.aligner)
+
+        if self.config.dataset.train_path is not None:
+            self.config.dataset.split = "train"
+            # may be used by meta processor.
+            # meta_processor controls different dataset.
+            meta_processor = meta_processor_cls(self.config.dataset)
+            video_processor = video_processor_cls(self.config.dataset)
+            text_processor = text_processor_cls(self.config.dataset)
+            aligner = aligner_cls(self.config.dataset)
+            self.train_data = MMDataset(
+                meta_processor, video_processor, text_processor, aligner
+            )
+            print("train_len", len(self.train_data))
+            output = self.train_data[0]
+            self.train_data.print_example(output)
+        if self.config.dataset.val_path is not None:
+            self.config.dataset.split = "valid"
+            # may be used by meta processor.
+            meta_processor = meta_processor_cls(self.config.dataset)
+            video_processor = video_processor_cls(self.config.dataset)
+            text_processor = text_processor_cls(self.config.dataset)
+            aligner = aligner_cls(self.config.dataset)
+            self.val_data = MMDataset(
+                meta_processor, video_processor, text_processor, aligner
+            )
+            print("val_len", len(self.val_data))
+            output = self.val_data[0]
+            self.val_data.print_example(output)
+
+        if self.config.dataset.split == "test":
+            # the following is run via lauching fairseq-validate.
+            meta_processor = meta_processor_cls(self.config.dataset)
+            video_processor = video_processor_cls(self.config.dataset)
+            text_processor = text_processor_cls(self.config.dataset)
+
+            self.test_data = MMDataset(
+                meta_processor, video_processor, text_processor, aligner
+            )
+            print("test_len", len(self.test_data))
+            output = self.test_data[0]
+            self.test_data.print_example(output)
+
+    def build_model(self, checkpoint=None):
+        if self.model is None:
+            model_cls = getattr(models, self.config.model.model_cls)
+            self.model = model_cls(self.config)
+        if checkpoint is not None:
+            self.load_checkpoint(checkpoint)
+        return self.model
+
+    def load_checkpoint(self, checkpoint):
+        if self.model is None:
+            raise ValueError("model is not initialized.")
+        state_dict = torch.load(checkpoint)
+        state_dict = self._trim_state_dict(state_dict)
+        self.model.load_state_dict(state_dict, strict=False)
+        # if it's a fp16 model, turn it back.
+        if next(self.model.parameters()).dtype == torch.float16:
+            self.model = self.model.float()
+        return self.model
+
+    def _trim_state_dict(self, state_dict):
+        from collections import OrderedDict
+
+        if "state_dict" in state_dict:
+            state_dict = state_dict["state_dict"]
+        if "model" in state_dict:  # fairseq checkpoint format.
+            state_dict = state_dict["model"]
+        ret_state_dict = OrderedDict()
+        for (
+            key,
+            value,
+        ) in state_dict.items():
+            # remove fairseq wrapper since this is a task.
+            if key.startswith("mmmodel"):
+                key = key[len("mmmodel."):]
+            ret_state_dict[key] = value
+        return ret_state_dict
+
+    def build_loss(self):
+        if self.loss_fn is None and self.config.loss is not None:
+            loss_cls = getattr(losses, self.config.loss.loss_cls)
+            self.loss_fn = loss_cls()
+        return self.loss_fn
+
+    def flat_subsample(self, tensor):
+        size = tensor.size()
+        if len(size) >= 2:
+            batch_size = size[0] * size[1]
+            expanded_size = (
+                (batch_size,) + size[2:] if len(size) > 2
+                else (batch_size,)
+            )
+            tensor = tensor.view(expanded_size)
+        return tensor
+
+    def reshape_subsample(self, sample):
+        if (
+            hasattr(self.config.dataset, "subsampling")
+            and self.config.dataset.subsampling is not None
+            and self.config.dataset.subsampling > 1
+        ):
+            for key in sample:
+                if torch.is_tensor(sample[key]):
+                    sample[key] = self.flat_subsample(sample[key])
+        return sample
+
+    def __call__(self, model, sample):
+        loss = None
+        loss_scalar = float("inf")
+
+        sample = self.reshape_subsample(sample)
+        outputs = self.model(**sample)
+        sample.update(outputs)
+        if self.loss_fn is not None:
+            loss = self.loss_fn(**sample)
+            loss_scalar = loss.item()
+
+        batch_size = sample["caps"].size(0)
+        sample_size = 1
+        return {
+            "loss": loss,
+            "loss_scalar": loss_scalar,
+            "max_len": self.config.dataset.max_len,
+            "batch_size": batch_size,
+            "sample_size": sample_size,
+        }
+
+    def build_dataloader(self):
+        """only used for trainer that lacks building loaders."""
+        raise NotImplementedError
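
_trim_state_dict unwraps fairseq checkpoints and strips the mmmodel. wrapper prefix so weights load directly into the task model. A tiny sketch of that behaviour in isolation (the method does not touch the config, so a bare Task(None) is enough here), assuming the MMPT package is importable:

from collections import OrderedDict

from mmpt.tasks.task import Task

ckpt = {"model": OrderedDict([
    ("mmmodel.mm_encoder.weight", 1),   # fairseq-wrapped parameter
    ("other.bias", 2),                  # already unwrapped parameter
])}
trimmed = Task(None)._trim_state_dict(ckpt)
print(list(trimmed.keys()))  # ['mm_encoder.weight', 'other.bias']
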
fairseq/examples/MMPT/mmpt/tasks/vlmtask.py
ADDED
@@ -0,0 +1,27 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from .task import Task
+
+
+class VLMTask(Task):
+    """A VLM task for reproducibility.
+    the collator split subsamples into two sub-batches.
+    This should have no logic changes,
+    but changed the randomness in frame masking.
+    """
+
+    def flat_subsample(self, tensor):
+        size = tensor.size()
+        if len(size) >= 2:
+            batch_size = size[0] * (size[1] // 2)
+            expanded_size = (
+                (batch_size, 2) + size[2:] if len(size) > 2
+                else (batch_size, 2)
+            )
+            tensor = tensor.view(expanded_size)
+            tensor = torch.cat([tensor[:, 0], tensor[:, 1]], dim=0)
+        return tensor
+
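
VLMTask.flat_subsample regroups a (batch, subsample, ...) tensor into two concatenated sub-batches. The same reshaping on a toy tensor, as a sketch of the shape logic only (no MMPT import needed):

import torch

x = torch.arange(2 * 4 * 3).view(2, 4, 3)   # (batch=2, subsample=4, len=3)
b, s = x.size(0), x.size(1)
x = x.view(b * (s // 2), 2, 3)              # pair up consecutive sub-samples
x = torch.cat([x[:, 0], x[:, 1]], dim=0)    # first halves, then second halves
print(x.shape)                              # torch.Size([8, 3])
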
fairseq/examples/MMPT/mmpt/utils/__init__.py
ADDED
@@ -0,0 +1,68 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import random
+import numpy as np
+import torch
+
+from .shardedtensor import *
+from .load_config import *
+
+
+def set_seed(seed=43211):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    if torch.backends.cudnn.enabled:
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+
+
+def get_world_size():
+    if torch.distributed.is_initialized():
+        world_size = torch.distributed.get_world_size()
+    else:
+        world_size = 1
+    return world_size
+
+
+def get_local_rank():
+    return torch.distributed.get_rank() \
+        if torch.distributed.is_initialized() else 0
+
+
+def print_on_rank0(func):
+    local_rank = get_local_rank()
+    if local_rank == 0:
+        print("[INFO]", func)
+
+
+class RetriMeter(object):
+    """
+    Statistics on whether retrieval yields a better pair.
+    """
+    def __init__(self, freq=1024):
+        self.freq = freq
+        self.total = 0
+        self.replace = 0
+        self.updates = 0
+
+    def __call__(self, data):
+        if isinstance(data, np.ndarray):
+            self.replace += data.shape[0] - int((data[:, 0] == -1).sum())
+            self.total += data.shape[0]
+        elif torch.is_tensor(data):
+            self.replace += int(data.sum())
+            self.total += data.size(0)
+        else:
+            raise ValueError("unsupported RetriMeter data type.", type(data))
+
+        self.updates += 1
+        if get_local_rank() == 0 and self.updates % self.freq == 0:
+            print("[INFO]", self)
+
+    def __repr__(self):
+        return "RetriMeter (" + str(self.replace / self.total) \
+            + "/" + str(self.replace) + "/" + str(self.total) + ")"
fairseq/examples/MMPT/mmpt/utils/load_config.py
ADDED
@@ -0,0 +1,81 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import omegaconf
+from omegaconf import OmegaConf
+
+
+def load_config(args=None, config_file=None, overwrite_fairseq=False):
+    """TODO (huxu): move fairseq overwrite to another function."""
+    if args is not None:
+        config_file = args.taskconfig
+    config = recursive_config(config_file)
+
+    if config.dataset.subsampling is not None:
+        batch_size = config.fairseq.dataset.batch_size // config.dataset.subsampling
+        print(
+            "adjusting batch_size to {} due to subsampling {}.".format(
+                batch_size, config.dataset.subsampling
+            )
+        )
+        config.fairseq.dataset.batch_size = batch_size
+
+    is_test = config.dataset.split is not None and config.dataset.split == "test"
+    if not is_test:
+        if (
+            config.fairseq.checkpoint is None
+            or config.fairseq.checkpoint.save_dir is None
+        ):
+            raise ValueError("fairseq save_dir or save_path must be specified.")
+
+        save_dir = config.fairseq.checkpoint.save_dir
+        os.makedirs(save_dir, exist_ok=True)
+        if config.fairseq.common.tensorboard_logdir is not None:
+            tb_run_dir = suffix_rundir(
+                save_dir, config.fairseq.common.tensorboard_logdir
+            )
+            config.fairseq.common.tensorboard_logdir = tb_run_dir
+            print(
+                "update tensorboard_logdir as", config.fairseq.common.tensorboard_logdir
+            )
+        os.makedirs(save_dir, exist_ok=True)
+        OmegaConf.save(config=config, f=os.path.join(save_dir, "config.yaml"))
+
+    if overwrite_fairseq and config.fairseq is not None and args is not None:
+        # flatten fields.
+        for group in config.fairseq:
+            for field in config.fairseq[group]:
+                print("overwrite args." + field, "as", config.fairseq[group][field])
+                setattr(args, field, config.fairseq[group][field])
+    return config
+
+
+def recursive_config(config_path):
+    """allows for stacking of configs in any depth."""
+    config = OmegaConf.load(config_path)
+    if config.includes is not None:
+        includes = config.includes
+        config.pop("includes")
+        base_config = recursive_config(includes)
+        config = OmegaConf.merge(base_config, config)
+    return config
+
+
+def suffix_rundir(save_dir, run_dir):
+    max_id = -1
+    for search_dir in os.listdir(save_dir):
+        if search_dir.startswith(run_dir):
+            splits = search_dir.split("_")
+            cur_id = int(splits[1]) if len(splits) > 1 else 0
+            max_id = max(max_id, cur_id)
+    return os.path.join(save_dir, run_dir + "_" + str(max_id + 1))
+
+
+def overwrite_dir(config, replace, basedir):
+    for key in config:
+        if isinstance(config[key], str) and config[key].startswith(basedir):
+            config[key] = config[key].replace(basedir, replace)
+        if isinstance(config[key], omegaconf.dictconfig.DictConfig):
+            overwrite_dir(config[key], replace, basedir)
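
recursive_config lets a YAML file pull in a base file through an includes: key and then override it. A self-contained sketch (temporary file names are illustrative; it assumes the omegaconf version used by MMPT, where reading a missing key such as includes returns None):

import os
import tempfile

from mmpt.utils.load_config import recursive_config

d = tempfile.mkdtemp()
with open(os.path.join(d, "base.yaml"), "w") as f:
    f.write("dataset:\n  max_len: 96\n  max_video_len: 32\n")
with open(os.path.join(d, "child.yaml"), "w") as f:
    f.write("includes: {base}\ndataset:\n  max_len: 128\n".format(
        base=os.path.join(d, "base.yaml")))

# child.yaml overrides max_len but inherits max_video_len from base.yaml.
config = recursive_config(os.path.join(d, "child.yaml"))
print(config.dataset.max_len, config.dataset.max_video_len)  # 128 32
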
fairseq/examples/MMPT/mmpt_cli/localjob.py
ADDED
@@ -0,0 +1,117 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+
+from mmpt.utils import recursive_config
+
+
+class BaseJob(object):
+    def __init__(self, yaml_file, dryrun=False):
+        self.yaml_file = yaml_file
+        self.config = recursive_config(yaml_file)
+        self.dryrun = dryrun
+
+    def submit(self, **kwargs):
+        raise NotImplementedError
+
+    def _normalize_cmd(self, cmd_list):
+        cmd_list = list(cmd_list)
+        yaml_index = cmd_list.index("[yaml]")
+        cmd_list[yaml_index] = self.yaml_file
+        return cmd_list
+
+
+class LocalJob(BaseJob):
+
+    CMD_CONFIG = {
+        "local_single": [
+            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
+            "--task", "mmtask", "--arch", "mmarch",
+            "--criterion", "mmloss",
+        ],
+        "local_small": [
+            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
+            "--task", "mmtask", "--arch", "mmarch",
+            "--criterion", "mmloss",
+            "--distributed-world-size", "2"
+        ],
+        "local_big": [
+            "fairseq-train", "[yaml]", "--user-dir", "mmpt",
+            "--task", "mmtask", "--arch", "mmarch",
+            "--criterion", "mmloss",
+            "--distributed-world-size", "8"
+        ],
+        "local_predict": ["python", "mmpt_cli/predict.py", "[yaml]"],
+    }
+
+    def __init__(self, yaml_file, job_type=None, dryrun=False):
+        super().__init__(yaml_file, dryrun)
+        if job_type is None:
+            self.job_type = "local_single"
+            if self.config.task_type is not None:
+                self.job_type = self.config.task_type
+        else:
+            self.job_type = job_type
+        if self.job_type in ["local_single", "local_small"]:
+            if self.config.fairseq.dataset.batch_size > 32:
+                print("decreasing batch_size to 32 for local testing?")
+
+    def submit(self):
+        cmd_list = self._normalize_cmd(LocalJob.CMD_CONFIG[self.job_type])
+        if "predict" not in self.job_type:
+            # append fairseq args.
+            from mmpt.utils import load_config
+
+            config = load_config(config_file=self.yaml_file)
+            for field in config.fairseq:
+                for key in config.fairseq[field]:
+                    if key in ["fp16", "reset_optimizer", "reset_dataloader", "reset_meters"]:  # a list of binary flag.
+                        param = ["--" + key.replace("_", "-")]
+                    else:
+                        if key == "lr":
+                            value = str(config.fairseq[field][key][0])
+                        elif key == "adam_betas":
+                            value = "'"+str(config.fairseq[field][key])+"'"
+                        else:
+                            value = str(config.fairseq[field][key])
+                        param = [
+                            "--" + key.replace("_", "-"),
+                            value
+                        ]
+                    cmd_list.extend(param)
+
+        print("launching", " ".join(cmd_list))
+        if not self.dryrun:
+            os.system(" ".join(cmd_list))
+        return JobStatus("12345678")
+
+
+class JobStatus(object):
+    def __init__(self, job_id):
+        self.job_id = job_id
+
+    def __repr__(self):
+        return self.job_id
+
+    def __str__(self):
+        return self.job_id
+
+    def done(self):
+        return False
+
+    def running(self):
+        return False
+
+    def result(self):
+        if self.done():
+            return "{} is done.".format(self.job_id)
+        else:
+            return "{} is running.".format(self.job_id)
+
+    def stderr(self):
+        return self.result()
+
+    def stdout(self):
+        return self.result()
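
Because submit only shells out when dryrun is false, the flattened fairseq-train command for a project YAML can be inspected without launching anything. A sketch using one of the configs in this upload, run from the examples/MMPT directory (the explicit job_type and the exact flags printed depend on that YAML's contents, which are assumptions here):

from mmpt_cli.localjob import LocalJob

job = LocalJob("projects/retri/videoclip/how2.yaml",
               job_type="local_single", dryrun=True)
job.submit()  # prints "launching fairseq-train ..." without executing it
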
fairseq/examples/MMPT/mmpt_cli/predict.py
ADDED
@@ -0,0 +1,113 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import glob
import argparse
import pprint
import omegaconf

from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from mmpt.utils import load_config, set_seed
from mmpt.evaluators import Evaluator
from mmpt.evaluators import predictor as predictor_path
from mmpt.tasks import Task
from mmpt import processors
from mmpt.datasets import MMDataset


def get_dataloader(config):
    meta_processor_cls = getattr(processors, config.dataset.meta_processor)
    video_processor_cls = getattr(processors, config.dataset.video_processor)
    text_processor_cls = getattr(processors, config.dataset.text_processor)
    aligner_cls = getattr(processors, config.dataset.aligner)

    meta_processor = meta_processor_cls(config.dataset)
    video_processor = video_processor_cls(config.dataset)
    text_processor = text_processor_cls(config.dataset)
    aligner = aligner_cls(config.dataset)

    test_data = MMDataset(
        meta_processor,
        video_processor,
        text_processor,
        aligner,
    )
    print("test_len", len(test_data))
    output = test_data[0]
    test_data.print_example(output)

    test_dataloader = DataLoader(
        test_data,
        batch_size=config.fairseq.dataset.batch_size,
        shuffle=False,
        num_workers=6,
        collate_fn=test_data.collater,
    )
    return test_dataloader


def main(args):
    config = load_config(args)

    if isinstance(config, omegaconf.dictconfig.DictConfig):
        print(OmegaConf.to_yaml(config))
    else:
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(config)

    mmtask = Task.config_task(config)
    mmtask.build_model()

    test_dataloader = get_dataloader(config)
    checkpoint_search_path = os.path.dirname(config.eval.save_path)
    results = []

    prefix = os.path.basename(args.taskconfig)
    if prefix.startswith("test"):
        # loop over all checkpoints for datasets without a validation set.
        if "best" not in config.fairseq.common_eval.path:
            print("eval each epoch.")
            for checkpoint in glob.glob(checkpoint_search_path + "/checkpoint*"):
                model = mmtask.load_checkpoint(checkpoint)
                ckpt = os.path.basename(checkpoint)
                evaluator = Evaluator(config)
                output = evaluator.evaluate(
                    model, test_dataloader, ckpt + "_merged")
                results.append((checkpoint, output))
        # finally, use the checkpoint specified by the config.
        model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
        evaluator = Evaluator(config)
        output = evaluator.evaluate(model, test_dataloader)
        results.append((config.fairseq.common_eval.path, output))

        best_result = None
        best_metric = 0.
        for checkpoint, result in results:
            print(checkpoint)
            evaluator.metric.print_computed_metrics(result)
            best_score = evaluator.metric.best_metric(result)
            if best_score > best_metric:
                best_result = (checkpoint, result)
                best_metric = best_score
        print("best results:")
        print(best_result[0])
        evaluator.metric.print_computed_metrics(best_result[1])

    elif prefix.startswith("vis"):
        model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
        predictor_cls = getattr(predictor_path, config.predictor)
        predictor = predictor_cls(config)
        predictor.predict_loop(model, test_dataloader, mmtask, None)
    else:
        raise ValueError("unknown prefix of the config file", args.taskconfig)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("taskconfig", type=str)
    args = parser.parse_args()
    main(args)

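Configs whose task_type is local_predict (the test_*.yaml files below) are routed to this script by LocalJob rather than to fairseq-train. A minimal dry-run sketch, assuming the package layout makes mmpt_cli.localjob importable from fairseq/examples/MMPT, and that the Job base class and _normalize_cmd (not shown in this diff) store the YAML path, parsed config, and dryrun flag and substitute the "[yaml]" placeholder:

# Hypothetical dry run: print the predict command for a retrieval test config
# without executing it.
from mmpt_cli.localjob import LocalJob

job = LocalJob("projects/mtm/vlm/test_vtt.yaml", dryrun=True)
status = job.submit()  # prints: launching python mmpt_cli/predict.py projects/mtm/vlm/test_vtt.yaml
print(status)          # JobStatus only echoes the placeholder id "12345678"
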
fairseq/examples/MMPT/projects/mtm/mmfusionmtm.yaml
ADDED
@@ -0,0 +1,19 @@
includes: projects/mfmmlm.yaml
project_dir: mtm/mmfusionmtm
task_group:
  pretrain:
    task: VLMTask # reproducible
    dataset:
      aligner: MFMMLMAligner
    model:
      use_seg_emb: True # reproducible
      model_cls: MMFusionMTM
      mm_encoder_cls: MMBertForMFMMLM
    loss:
      loss_cls: MTM
  finetune:
    model:
      use_seg_emb: True # reproducible
  test:
    model:
      use_seg_emb: True # reproducible

fairseq/examples/MMPT/projects/mtm/vlm/coin.yaml
ADDED
@@ -0,0 +1,47 @@
dataset:
  video_processor: VideoProcessor
  bert_name: bert-base-uncased
  meta_processor: COINActionSegmentationMetaProcessor
  train_path: data/coin/COIN.json
  val_path: data/coin/COIN.json
  vfeat_dir: data/feat/feat_coin_s3d
  text_processor: COINActionSegmentationTextProcessor
  aligner: COINActionSegmentationAligner
  num_iso_layer: 12
  sliding_window: 8
  sliding_window_size: 32
  max_video_len: 32
  max_len: 96
fairseq:
  common:
    tensorboard_logdir: run
    log_interval: 1000
    fp16: true
  dataset:
    num_workers: 4
    batch_size: 1
  optimization:
    lr:
    - 5.0e-05
    clip_norm: 2.0
    optimizer: adam
    adam_betas: (0.9, 0.98)
    lr_scheduler: polynomial_decay
    total_num_update: 1000000
    warmup_updates: 122
    weight_decay: 0.0
    ddp_backend: no_c10d
    max_epoch: 8
  checkpoint:
    restore_file: runs/mtm/vlm/checkpoint_best.pt
    reset_optimizer: true
    reset_dataloader: true
    reset_meters: true
    save_dir: runs/mtm/vlm/coin
task_type: sweep_big
model:
  model_cls: MMFusionActionSegmentation
  mm_encoder_cls: MMBertForTokenClassification
  use_seg_emb: true
loss:
  loss_cls: CrossEntropy

fairseq/examples/MMPT/projects/mtm/vlm/how2.yaml
ADDED
@@ -0,0 +1,55 @@
dataset:
  video_processor: ShardedVideoProcessor
  bert_name: bert-base-uncased
  meta_processor: ShardedHow2MetaProcessor
  train_path: data/how2/how2_s3d_train.lst
  val_path: data/how2/how2_s3d_val.lst
  vfeat_dir: data/feat/feat_how2_s3d_shard_small
  text_processor: ShardedTextProcessor
  tfeat_dir: data/feat/feat_how2_s3d_shard_small/raw_caption_dedup.bert-base-uncased.
  aligner: MFMMLMAligner
  subsampling: 32
  sampled_min_len: 8
  sampled_max_len: 64
  max_video_len: 32
  max_len: 96
  lazy_vfeat_mask: true
  mfm_probability: 0.15
  mlm_probability: 0.15
  mm_prob: 0.5
fairseq:
  common:
    tensorboard_logdir: run
    log_interval: 1000
    fp16: true
  dataset:
    num_workers: 4
    batch_size: 256
  optimization:
    lr:
    - 5.0e-05
    clip_norm: 2.0
    optimizer: adam
    adam_betas: (0.9, 0.98)
    lr_scheduler: polynomial_decay
    total_num_update: 1000000
    warmup_updates: 1000
    weight_decay: 0.0
    ddp_backend: no_c10d
    max_epoch: 15
  checkpoint:
    save_dir: runs/mtm/vlm
    save_interval_updates: 1024
    keep_interval_updates: 2
    keep_last_epochs: 30
task_type: sweep_big
slurm_config: big
eval:
  save_path: runs/mtm/vlm
model:
  model_cls: MMFusionMTM
  mm_encoder_cls: MMBertForMFMMLM
  use_seg_emb: true
loss:
  loss_cls: MTM
task: VLMTask

fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask.yaml
ADDED
@@ -0,0 +1,38 @@
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: CrossTaskVideoProcessor
  aligner: CrossTaskAligner
  bert_name: bert-base-uncased
  meta_processor: CrossTaskMetaProcessor
  test_path: data/crosstask/crosstask_release/videos_val.csv
  train_csv_path: data/crosstask/crosstask_release/videos.csv
  val_path: data/crosstask/crosstask_release/videos_val.csv
  val_csv_path: data/crosstask/crosstask_release/videos_val.csv
  primary_path: data/crosstask/crosstask_release/tasks_primary.txt
  related_path: data/crosstask/crosstask_release/tasks_related.txt
  vfeat_dir: data/feat/feat_crosstask_s3d
  annotation_path: data/crosstask/crosstask_release/annotations
  n_train: 30
  text_processor: CrossTaskTextProcessor
  num_iso_layer: 12
  sliding_window: 16
  sliding_window_size: 32
  max_video_len: 32
  max_len: 96
fairseq:
  dataset:
    batch_size: 1
    valid_subset: test
    num_workers: 2
  common_eval:
    path: runs/mtm/vlm/crosstask/checkpoint_best.pt
model:
  model_cls: MMFusionActionLocalization
  mm_encoder_cls: MMBertForJoint
  use_seg_emb: true
eval:
  save_path: runs/mtm/vlm/crosstask/eval
metric: CrossTaskMetric
predictor: CrossTaskPredictor

fairseq/examples/MMPT/projects/mtm/vlm/test_vtt.yaml
ADDED
@@ -0,0 +1,29 @@
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: VideoProcessor
  aligner: DSAligner
  bert_name: bert-base-uncased
  meta_processor: MSRVTTMetaProcessor
  test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
  vfeat_dir: data/feat/feat_vtt_s3d
  text_processor: MSRVTTTextProcessor
  num_iso_layer: 12
  max_video_len: 32
  max_len: 96
fairseq:
  dataset:
    batch_size: 256
    valid_subset: test
    num_workers: 2
  common_eval:
    path: runs/mtm/vlm/vtt/checkpoint_last.pt
model:
  model_cls: MMFusionJoint
  mm_encoder_cls: MMBertForJoint
  use_seg_emb: true
eval:
  save_path: runs/mtm/vlm/vtt/eval
metric: RetrievalMetric
predictor: RetrievalPredictor

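predict.py reads a handful of fields from configs like the one above; a quick way to see the resolved values is to load the YAML directly with OmegaConf (predict.py itself goes through mmpt.utils.load_config, which may apply further normalization). A small sketch using field names taken from test_vtt.yaml:

# Inspect the fields that predict.py consumes from a test config.
from omegaconf import OmegaConf

config = OmegaConf.load("projects/mtm/vlm/test_vtt.yaml")
print(config.dataset.meta_processor)    # MSRVTTMetaProcessor
print(config.fairseq.common_eval.path)  # checkpoint evaluated last
print(config.eval.save_path)            # its parent dir is globbed for checkpoint* files
print(config.metric, config.predictor)  # RetrievalMetric RetrievalPredictor
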
fairseq/examples/MMPT/projects/mtm/vlm/test_vttqa.yaml
ADDED
@@ -0,0 +1,29 @@
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: VideoProcessor
  aligner: MSRVTTQAAligner
  bert_name: bert-base-uncased
  meta_processor: MSRVTTQAMetaProcessor
  test_path: data/msrvtt-qa/MSR_MC_test.csv
  vfeat_dir: data/feat/feat_vtt_s3d
  text_processor: MSRVTTQATextProcessor
  num_iso_layer: 12
  max_video_len: 32
  max_len: 96
fairseq:
  dataset:
    batch_size: 256
    valid_subset: test
    num_workers: 2
  common_eval:
    path: runs/mtm/vlm/vttqa/checkpoint_last.pt
model:
  model_cls: MMFusionJoint
  mm_encoder_cls: MMBertForJoint
  use_seg_emb: true
eval:
  save_path: runs/mtm/vlm/vttqa/eval
metric: QAMetric
predictor: QAPredictor

fairseq/examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml
ADDED
@@ -0,0 +1,32 @@
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: YoucookVideoProcessor
  aligner: DSNLGAligner
  bert_name: bert-base-uncased
  meta_processor: YoucookNLGMetaProcessor
  test_path: data/youcook/val_list.txt
  trainval_annotation: data/youcook/youcookii_annotations_trainval.json
  vfeat_dir: data/feat/feat_youcook_s3d
  text_processor: NLGTextProcessor
  max_video_len: 32
  max_len: 96
fairseq:
  dataset:
    batch_size: 256
    valid_subset: test
    num_workers: 2
  common_eval:
    path: runs/mtm/vlm/youcookcap/checkpoint_best.pt
model:
  model_cls: MMFusionNLG
  mm_encoder_cls: MMBertForNLG
  max_decode_length: 24
  use_seg_emb: true
eval:
  save_path: runs/mtm/vlm/youcookcap/eval
metric: NLGMetric
predictor: NLGPredictor
gen_param:
  num_beams: 5

fairseq/examples/MMPT/projects/mtm/vlm/vtt.yaml
ADDED
@@ -0,0 +1,49 @@
dataset:
  video_processor: VideoProcessor
  bert_name: bert-base-uncased
  meta_processor: MSRVTTMetaProcessor
  train_path: data/msrvtt/MSRVTT_train.csv
  jsfusion_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
  full_test_path: data/msrvtt/MSRVTT_FULL_test.csv
  dup: 20
  val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
  vfeat_dir: data/feat/feat_vtt_s3d
  text_processor: MSRVTTTextProcessor
  json_path: data/msrvtt/MSRVTT_data.json
  aligner: DSAligner
  num_iso_layer: 12
  max_video_len: 32
  max_len: 96
fairseq:
  common:
    tensorboard_logdir: run
    log_interval: 1000
    fp16: true
  dataset:
    num_workers: 4
    batch_size: 256
  optimization:
    lr:
    - 5.0e-05
    clip_norm: 2.0
    optimizer: adam
    adam_betas: (0.9, 0.98)
    lr_scheduler: polynomial_decay
    total_num_update: 1000000
    warmup_updates: 122
    weight_decay: 0.0
    ddp_backend: no_c10d
    max_epoch: 10
  checkpoint:
    restore_file: runs/mtm/vlm/checkpoint_best.pt
    reset_optimizer: true
    reset_dataloader: true
    reset_meters: true
    save_dir: runs/mtm/vlm/vtt
task_type: sweep_small
model:
  model_cls: MMFusionJoint
  mm_encoder_cls: MMBertForJoint
  use_seg_emb: true
loss:
  loss_cls: T2VContraLoss

fairseq/examples/MMPT/projects/mtm/vlm/vttqa.yaml
ADDED
@@ -0,0 +1,47 @@
dataset:
  video_processor: VideoProcessor
  bert_name: bert-base-uncased
  meta_processor: MSRVTTMetaProcessor
  train_path: data/msrvtt/MSRVTT_train.csv
  dup: 20
  val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
  vfeat_dir: data/feat/feat_vtt_s3d
  text_processor: MSRVTTTextProcessor
  json_path: data/msrvtt/MSRVTT_data.json
  aligner: DSAligner
  num_iso_layer: 12
  max_video_len: 32
  max_len: 96
fairseq:
  common:
    tensorboard_logdir: run
    log_interval: 1000
    fp16: true
  dataset:
    num_workers: 4
    batch_size: 128
  optimization:
    lr:
    - 5.0e-05
    clip_norm: 2.0
    optimizer: adam
    adam_betas: (0.9, 0.98)
    lr_scheduler: polynomial_decay
    total_num_update: 1000000
    warmup_updates: 122
    weight_decay: 0.0
    ddp_backend: no_c10d
    max_epoch: 5
  checkpoint:
    restore_file: runs/mtm/vlm/checkpoint_best.pt
    reset_optimizer: true
    reset_dataloader: true
    reset_meters: true
    save_dir: runs/mtm/vlm/vttqa
task_type: sweep_small
model:
  model_cls: MMFusionJoint
  mm_encoder_cls: MMBertForJoint
  use_seg_emb: true
loss:
  loss_cls: V2TContraLoss

fairseq/examples/MMPT/projects/mtm/vlm/youcook.yaml
ADDED
@@ -0,0 +1,47 @@
dataset:
  video_processor: YoucookVideoProcessor
  bert_name: bert-base-uncased
  meta_processor: YoucookMetaProcessor
  train_path: data/youcook/youcook_train.pkl
  val_path: data/youcook/youcook_val.pkl
  trainval_annotation: data/youcook/youcookii_annotations_trainval.json
  use_annotation_text: true
  vfeat_dir: data/feat/feat_youcook_s3d
  text_processor: TextProcessor
  aligner: DSAligner
  num_iso_layer: 12
  max_video_len: 32
  max_len: 96
fairseq:
  common:
    tensorboard_logdir: run
    log_interval: 1000
    fp16: true
  dataset:
    num_workers: 4
    batch_size: 128
  optimization:
    lr:
    - 5.0e-05
    clip_norm: 2.0
    optimizer: adam
    adam_betas: (0.9, 0.98)
    lr_scheduler: polynomial_decay
    total_num_update: 1000000
    warmup_updates: 122
    weight_decay: 0.0
    ddp_backend: no_c10d
    max_epoch: 10
  checkpoint:
    restore_file: runs/mtm/vlm/checkpoint_best.pt
    reset_optimizer: true
    reset_dataloader: true
    reset_meters: true
    save_dir: runs/mtm/vlm/youcook
task_type: sweep_small
model:
  model_cls: MMFusionJoint
  mm_encoder_cls: MMBertForJoint
  use_seg_emb: true
loss:
  loss_cls: T2VContraLoss

fairseq/examples/MMPT/projects/retri/videoclip.yaml
ADDED
@@ -0,0 +1,10 @@
includes: projects/retri/videoretri.yaml
project_dir: retri/videoclip
task_group:
  pretrain:
    model:
      model_cls: MMFusionSeparate
      mm_encoder_cls:
      video_encoder_cls: MMBertForEncoder
      text_encoder_cls: BertModel
      num_hidden_video_layers: 6

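videoclip.yaml above only specifies deltas on top of projects/retri/videoretri.yaml via the includes: key; the actual resolution is done by mmpt's config utilities, which are not part of this diff. A rough sketch of the layering idea under that assumption, where child keys override what they inherit:

# Assumed semantics of the `includes:` key: load the parent config first,
# then let the child's keys (e.g. task_group.pretrain.model) override it.
from omegaconf import OmegaConf

def load_with_includes(path):
    cfg = OmegaConf.load(path)
    if "includes" in cfg:
        parent = load_with_includes(cfg["includes"])
        del cfg["includes"]
        cfg = OmegaConf.merge(parent, cfg)
    return cfg

videoclip = load_with_includes("projects/retri/videoclip.yaml")
print(OmegaConf.to_yaml(videoclip.task_group.pretrain.model))
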
fairseq/examples/MMPT/projects/retri/videoclip/coin_videoclip.yaml
ADDED
@@ -0,0 +1,49 @@
dataset:
  video_processor: VideoProcessor
  bert_name: bert-base-uncased
  meta_processor: COINActionSegmentationMetaProcessor
  train_path: data/coin/COIN.json
  val_path: data/coin/COIN.json
  vfeat_dir: data/feat/feat_coin_s3d
  text_processor: COINActionSegmentationTextProcessor
  aligner: COINActionSegmentationAligner
  num_iso_layer: 12
  sliding_window: 8
  sliding_window_size: 32
  max_video_len: 32
  max_len: 96
fairseq:
  common:
    tensorboard_logdir: run
    log_interval: 1000
    fp16: true
  dataset:
    num_workers: 4
    batch_size: 1
  optimization:
    lr:
    - 5.0e-05
    clip_norm: 2.0
    optimizer: adam
    adam_betas: (0.9, 0.98)
    lr_scheduler: polynomial_decay
    total_num_update: 1000000
    warmup_updates: 122
    weight_decay: 0.0
    ddp_backend: no_c10d
    max_epoch: 8
  checkpoint:
    restore_file: runs/retri/videoclip/checkpoint_best.pt
    reset_optimizer: true
    reset_dataloader: true
    reset_meters: true
    save_dir: runs/retri/videoclip/coin
task_type: sweep_big
model:
  model_cls: MMFusionSeparateActionSegmentation
  mm_encoder_cls: null
  video_encoder_cls: MMBertForTokenClassification
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
loss:
  loss_cls: CrossEntropy

fairseq/examples/MMPT/projects/retri/videoclip/crosstask_videoclip.yaml
ADDED
@@ -0,0 +1,55 @@
dataset:
  video_processor: CrossTaskVideoProcessor
  bert_name: bert-base-uncased
  meta_processor: CrossTaskMetaProcessor
  train_path: data/crosstask/crosstask_release/videos.csv
  train_csv_path: data/crosstask/crosstask_release/videos.csv
  val_path: data/crosstask/crosstask_release/videos_val.csv
  val_csv_path: data/crosstask/crosstask_release/videos_val.csv
  primary_path: data/crosstask/crosstask_release/tasks_primary.txt
  related_path: data/crosstask/crosstask_release/tasks_related.txt
  vfeat_dir: data/feat/feat_crosstask_s3d
  annotation_path: data/crosstask/crosstask_release/annotations
  n_train: 30
  text_processor: CrossTaskTextProcessor
  aligner: CrossTaskAligner
  num_iso_layer: 12
  sliding_window: 16
  sliding_window_size: 32
  max_video_len: 32
  max_len: 96
fairseq:
  common:
    tensorboard_logdir: run
    log_interval: 1000
    fp16: true
  dataset:
    num_workers: 4
    batch_size: 1
  optimization:
    lr:
    - 5.0e-05
    clip_norm: 2.0
    optimizer: adam
    adam_betas: (0.9, 0.98)
    lr_scheduler: polynomial_decay
    total_num_update: 1000000
    warmup_updates: 122
    weight_decay: 0.0
    ddp_backend: no_c10d
    max_epoch: 5
  checkpoint:
    restore_file: runs/retri/videoclip/checkpoint_best.pt
    reset_optimizer: true
    reset_dataloader: true
    reset_meters: true
    save_dir: runs/retri/videoclip/crosstask
task_type: sweep_small
model:
  model_cls: MMFusionSeparateActionLocalization
  mm_encoder_cls: null
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
loss:
  loss_cls: BCE

fairseq/examples/MMPT/projects/retri/videoclip/how2.yaml
ADDED
@@ -0,0 +1,65 @@
dataset:
  video_processor: ShardedVideoRetriVideoProcessor
  bert_name: bert-base-uncased
  meta_processor: ShardedHow2VideoRetriMetaProcessor
  train_path: data/how2/how2_s3d_train.lst
  val_path: data/how2/how2_s3d_val.lst
  vfeat_dir: data/feat/feat_how2_s3d_shard_small
  text_processor: ShardedVideoRetriTextProcessor
  tfeat_dir: data/feat/feat_how2_s3d_shard_small/raw_caption_dedup.bert-base-uncased.
  aligner: VideoRetriOverlappedAligner
  subsampling: 1
  sampled_min_len: 8
  sampled_max_len: 64
  max_video_len: 32
  max_len: 96
  lazy_vfeat_mask: true
  mfm_probability: 0.15
  mlm_probability: 0.15
  mm_prob: 0.5
  sampled_video_min_len: 3
  sampled_video_max_len: 32
  num_video_per_batch: 32
  clip_per_video: 16
fairseq:
  common:
    tensorboard_logdir: run
    log_interval: 1000
    fp16: true
  dataset:
    num_workers: 4
    batch_size: 1
  optimization:
    lr:
    - 5.0e-05
    clip_norm: 2.0
    optimizer: adam
    adam_betas: (0.9, 0.98)
    lr_scheduler: polynomial_decay
    total_num_update: 1000000
    warmup_updates: 1000
    weight_decay: 0.0
    ddp_backend: no_c10d
    max_epoch: 25
  checkpoint:
    save_dir: runs/retri/videoclip
    save_interval_updates: 1024
    keep_interval_updates: 2
    keep_last_epochs: 30
task_type: sweep_big
slurm_config: big
eval:
  save_path: runs/retri/videoclip
model:
  model_cls: MMFusionSeparate
  mm_encoder_cls: null
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
loss:
  loss_cls: MMContraLoss
task: VideoRetriTask
retri_epoch: 1
vectorpool_cls: VideoVectorPool
retriever_cls: VectorRetriever
num_cands: 64

fairseq/examples/MMPT/projects/retri/videoclip/test_coin_videoclip.yaml
ADDED
@@ -0,0 +1,33 @@
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: VideoProcessor
  aligner: COINActionSegmentationAligner
  bert_name: bert-base-uncased
  test_path: data/coin/COIN.json
  meta_processor: COINActionSegmentationMetaProcessor
  vfeat_dir: data/feat/feat_coin_s3d
  text_processor: COINActionSegmentationTextProcessor
  num_iso_layer: 12
  sliding_window: 16
  sliding_window_size: 32
  max_video_len: 32
  max_len: 96
fairseq:
  dataset:
    batch_size: 1
    valid_subset: test
    num_workers: 2
  common_eval:
    path: runs/retri/videoclip/coin/checkpoint_best.pt
model:
  model_cls: MMFusionSeparateActionSegmentation
  mm_encoder_cls: null
  video_encoder_cls: MMBertForTokenClassification
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
eval:
  save_path: runs/retri/videoclip/coin/eval
metric: COINActionSegmentationMetric
predictor: COINPredictor

fairseq/examples/MMPT/projects/retri/videoclip/test_coin_zs.yaml
ADDED
@@ -0,0 +1,33 @@
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: VideoProcessor
  aligner: COINActionSegmentationAligner
  bert_name: bert-base-uncased
  test_path: data/coin/COIN.json
  meta_processor: COINActionSegmentationMetaProcessor
  vfeat_dir: data/feat/feat_coin_s3d
  text_processor: COINActionSegmentationTextProcessor
  num_iso_layer: 12
  sliding_window: 16
  sliding_window_size: 32
  max_video_len: 32
  max_len: 96
fairseq:
  dataset:
    batch_size: 1
    valid_subset: test
    num_workers: 2
  common_eval:
    path: runs/retri/videoclip/checkpoint_best.pt
model:
  model_cls: MMFusionSeparate
  mm_encoder_cls: null
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
eval:
  save_path: runs/retri/videoclip/coin_zs/eval
metric: COINActionSegmentationMetric
predictor: COINZSPredictor

fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_videoclip.yaml
ADDED
@@ -0,0 +1,40 @@
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: CrossTaskVideoProcessor
  aligner: CrossTaskAligner
  bert_name: bert-base-uncased
  meta_processor: CrossTaskMetaProcessor
  test_path: data/crosstask/crosstask_release/videos_val.csv
  train_csv_path: data/crosstask/crosstask_release/videos.csv
  val_path: data/crosstask/crosstask_release/videos_val.csv
  val_csv_path: data/crosstask/crosstask_release/videos_val.csv
  primary_path: data/crosstask/crosstask_release/tasks_primary.txt
  related_path: data/crosstask/crosstask_release/tasks_related.txt
  vfeat_dir: data/feat/feat_crosstask_s3d
  annotation_path: data/crosstask/crosstask_release/annotations
  n_train: 30
  text_processor: CrossTaskTextProcessor
  num_iso_layer: 12
  sliding_window: 16
  sliding_window_size: 32
  max_video_len: 32
  max_len: 96
fairseq:
  dataset:
    batch_size: 1
    valid_subset: test
    num_workers: 2
  common_eval:
    path: runs/retri/videoclip/crosstask/checkpoint_best.pt
model:
  model_cls: MMFusionSeparateActionLocalization
  mm_encoder_cls: null
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
eval:
  save_path: runs/retri/videoclip/crosstask/eval
metric: CrossTaskMetric
predictor: CrossTaskPredictor

fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_zs_videoclip.yaml
ADDED
@@ -0,0 +1,40 @@
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: CrossTaskVideoProcessor
  aligner: CrossTaskAligner
  bert_name: bert-base-uncased
  meta_processor: CrossTaskMetaProcessor
  test_path: data/crosstask/crosstask_release/videos_val.csv
  train_csv_path: data/crosstask/crosstask_release/videos.csv
  val_path: data/crosstask/crosstask_release/videos_val.csv
  val_csv_path: data/crosstask/crosstask_release/videos_val.csv
  primary_path: data/crosstask/crosstask_release/tasks_primary.txt
  related_path: data/crosstask/crosstask_release/tasks_related.txt
  vfeat_dir: data/feat/feat_crosstask_s3d
  annotation_path: data/crosstask/crosstask_release/annotations
  n_train: 30
  text_processor: CrossTaskTextProcessor
  num_iso_layer: 12
  sliding_window: 16
  sliding_window_size: 32
  max_video_len: 32
  max_len: 96
fairseq:
  dataset:
    batch_size: 1
    valid_subset: test
    num_workers: 2
  common_eval:
    path: runs/retri/videoclip/checkpoint_best.pt
model:
  model_cls: MMFusionSeparateActionLocalization
  mm_encoder_cls: null
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
eval:
  save_path: runs/retri/videoclip/crosstask_zs/eval
metric: CrossTaskMetric
predictor: CrossTaskPredictor

fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_videoclip.yaml
ADDED
@@ -0,0 +1,31 @@
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: VideoProcessor
  aligner: DSAligner
  bert_name: bert-base-uncased
  meta_processor: MSRVTTMetaProcessor
  test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
  vfeat_dir: data/feat/feat_vtt_s3d
  text_processor: MSRVTTTextProcessor
  num_iso_layer: 12
  max_video_len: 32
  max_len: 96
fairseq:
  dataset:
    batch_size: 256
    valid_subset: test
    num_workers: 2
  common_eval:
    path: runs/retri/videoclip/vtt/checkpoint_last.pt
model:
  model_cls: MMFusionSeparate
  mm_encoder_cls: null
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
eval:
  save_path: runs/retri/videoclip/vtt/eval
metric: RetrievalMetric
predictor: RetrievalPredictor

fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_zs.yaml
ADDED
@@ -0,0 +1,31 @@
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: VideoProcessor
  aligner: DSAligner
  bert_name: bert-base-uncased
  meta_processor: MSRVTTMetaProcessor
  test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
  vfeat_dir: data/feat/feat_vtt_s3d
  text_processor: MSRVTTTextProcessor
  num_iso_layer: 12
  max_video_len: 32
  max_len: 96
fairseq:
  dataset:
    batch_size: 256
    valid_subset: test
    num_workers: 2
  common_eval:
    path: runs/retri/videoclip/checkpoint_best.pt
model:
  model_cls: MMFusionSeparate
  mm_encoder_cls: null
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
eval:
  save_path: runs/retri/videoclip/vtt_zs/eval
metric: RetrievalMetric
predictor: RetrievalPredictor

fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_videoclip.yaml
ADDED
@@ -0,0 +1,31 @@
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: VideoProcessor
  aligner: MSRVTTQAAligner
  bert_name: bert-base-uncased
  meta_processor: MSRVTTQAMetaProcessor
  test_path: data/msrvtt-qa/MSR_MC_test.csv
  vfeat_dir: data/feat/feat_vtt_s3d
  text_processor: MSRVTTQATextProcessor
  num_iso_layer: 12
  max_video_len: 32
  max_len: 96
fairseq:
  dataset:
    batch_size: 256
    valid_subset: test
    num_workers: 2
  common_eval:
    path: runs/retri/videoclip/vttqa/checkpoint_last.pt
model:
  model_cls: MMFusionSeparate
  mm_encoder_cls: null
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
eval:
  save_path: runs/retri/videoclip/vttqa/eval
metric: QAMetric
predictor: QAPredictor

fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_zs.yaml
ADDED
@@ -0,0 +1,31 @@
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: VideoProcessor
  aligner: MSRVTTQAAligner
  bert_name: bert-base-uncased
  meta_processor: MSRVTTQAMetaProcessor
  test_path: data/msrvtt-qa/MSR_MC_test.csv
  vfeat_dir: data/feat/feat_vtt_s3d
  text_processor: MSRVTTQATextProcessor
  num_iso_layer: 12
  max_video_len: 32
  max_len: 96
fairseq:
  dataset:
    batch_size: 256
    valid_subset: test
    num_workers: 2
  common_eval:
    path: runs/retri/videoclip/checkpoint_best.pt
model:
  model_cls: MMFusionSeparate
  mm_encoder_cls: null
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
eval:
  save_path: runs/retri/videoclip/vttqa_zs/eval
metric: QAMetric
predictor: QAPredictor

fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_videoclip.yaml
ADDED
@@ -0,0 +1,33 @@
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: YoucookVideoProcessor
  aligner: DSAligner
  bert_name: bert-base-uncased
  meta_processor: YoucookMetaProcessor
  test_path: data/youcook/youcook_val.pkl
  trainval_annotation: data/youcook/youcookii_annotations_trainval.json
  use_annotation_text: true
  vfeat_dir: data/feat/feat_youcook_s3d
  text_processor: TextProcessor
  num_iso_layer: 12
  max_video_len: 32
  max_len: 96
fairseq:
  dataset:
    batch_size: 256
    valid_subset: test
    num_workers: 2
  common_eval:
    path: runs/retri/videoclip/youcook/checkpoint_last.pt
model:
  model_cls: MMFusionSeparate
  mm_encoder_cls: null
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
eval:
  save_path: runs/retri/videoclip/youcook/eval
metric: RetrievalMetric
predictor: RetrievalPredictor