PyTorch · ssl-aasist · custom_code
ash56 committed (verified)
Commit d28af7f
Parent(s): 29c9ba5

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. fairseq/docs/models.rst +104 -0
  2. fairseq/docs/tutorial_classifying_names.rst +415 -0
  3. fairseq/examples/MMPT/.gitignore +139 -0
  4. fairseq/examples/MMPT/README.md +166 -0
  5. fairseq/examples/MMPT/endtask.md +41 -0
  6. fairseq/examples/MMPT/mmpt/datasets/fairseqmmdataset.py +57 -0
  7. fairseq/examples/MMPT/mmpt/datasets/mmdataset.py +111 -0
  8. fairseq/examples/MMPT/mmpt/evaluators/__init__.py +13 -0
  9. fairseq/examples/MMPT/mmpt/evaluators/evaluator.py +54 -0
  10. fairseq/examples/MMPT/mmpt/models/transformermodel.py +734 -0
  11. fairseq/examples/MMPT/mmpt/modules/__init__.py +10 -0
  12. fairseq/examples/MMPT/mmpt/modules/mm.py +145 -0
  13. fairseq/examples/MMPT/mmpt/modules/retri.py +429 -0
  14. fairseq/examples/MMPT/mmpt/modules/vectorpool.py +246 -0
  15. fairseq/examples/MMPT/mmpt/processors/dedupprocessor.py +242 -0
  16. fairseq/examples/MMPT/mmpt/processors/how2processor.py +887 -0
  17. fairseq/examples/MMPT/mmpt/processors/models/s3dg.py +336 -0
  18. fairseq/examples/MMPT/mmpt/processors/processor.py +274 -0
  19. fairseq/examples/MMPT/mmpt/tasks/fairseqmmtask.py +104 -0
  20. fairseq/examples/MMPT/mmpt/tasks/milncetask.py +27 -0
  21. fairseq/examples/MMPT/mmpt/tasks/retritask.py +253 -0
  22. fairseq/examples/MMPT/mmpt/tasks/task.py +184 -0
  23. fairseq/examples/MMPT/mmpt/tasks/vlmtask.py +27 -0
  24. fairseq/examples/MMPT/mmpt/utils/__init__.py +68 -0
  25. fairseq/examples/MMPT/mmpt/utils/load_config.py +81 -0
  26. fairseq/examples/MMPT/mmpt_cli/localjob.py +117 -0
  27. fairseq/examples/MMPT/mmpt_cli/predict.py +113 -0
  28. fairseq/examples/MMPT/projects/mtm/mmfusionmtm.yaml +19 -0
  29. fairseq/examples/MMPT/projects/mtm/vlm/coin.yaml +47 -0
  30. fairseq/examples/MMPT/projects/mtm/vlm/how2.yaml +55 -0
  31. fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask.yaml +38 -0
  32. fairseq/examples/MMPT/projects/mtm/vlm/test_vtt.yaml +29 -0
  33. fairseq/examples/MMPT/projects/mtm/vlm/test_vttqa.yaml +29 -0
  34. fairseq/examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml +32 -0
  35. fairseq/examples/MMPT/projects/mtm/vlm/vtt.yaml +49 -0
  36. fairseq/examples/MMPT/projects/mtm/vlm/vttqa.yaml +47 -0
  37. fairseq/examples/MMPT/projects/mtm/vlm/youcook.yaml +47 -0
  38. fairseq/examples/MMPT/projects/retri/videoclip.yaml +10 -0
  39. fairseq/examples/MMPT/projects/retri/videoclip/coin_videoclip.yaml +49 -0
  40. fairseq/examples/MMPT/projects/retri/videoclip/crosstask_videoclip.yaml +55 -0
  41. fairseq/examples/MMPT/projects/retri/videoclip/how2.yaml +65 -0
  42. fairseq/examples/MMPT/projects/retri/videoclip/test_coin_videoclip.yaml +33 -0
  43. fairseq/examples/MMPT/projects/retri/videoclip/test_coin_zs.yaml +33 -0
  44. fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_videoclip.yaml +40 -0
  45. fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_zs_videoclip.yaml +40 -0
  46. fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_videoclip.yaml +31 -0
  47. fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_zs.yaml +31 -0
  48. fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_videoclip.yaml +31 -0
  49. fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_zs.yaml +31 -0
  50. fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_videoclip.yaml +33 -0
fairseq/docs/models.rst ADDED
@@ -0,0 +1,104 @@
1
+ .. role:: hidden
2
+ :class: hidden-section
3
+
4
+ .. module:: fairseq.models
5
+
6
+ .. _Models:
7
+
8
+ Models
9
+ ======
10
+
11
+ A Model defines the neural network's ``forward()`` method and encapsulates all
12
+ of the learnable parameters in the network. Each model also provides a set of
13
+ named *architectures* that define the precise network configuration (e.g.,
14
+ embedding dimension, number of layers, etc.).
15
+
16
+ Both the model type and architecture are selected via the ``--arch``
17
+ command-line argument. Once selected, a model may expose additional command-line
18
+ arguments for further configuration.
19
+
20
+ .. note::
21
+
22
+ All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends
23
+ :class:`torch.nn.Module`. Thus any fairseq Model can be used as a
24
+ stand-alone Module in other PyTorch code.
25
+
26
+
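As a quick illustration of the note above (a minimal sketch, not part of models.rst; the checkpoint path is hypothetical), a fairseq model loaded from a checkpoint behaves like any other `torch.nn.Module`:

```python
# Hedged sketch: use a fairseq Model as a plain torch.nn.Module.
# Assumes fairseq is installed and a trained checkpoint exists at the (hypothetical) path below.
from fairseq import checkpoint_utils

models, _cfg = checkpoint_utils.load_model_ensemble(["checkpoints/checkpoint_best.pt"])
model = models[0]  # a BaseFairseqModel, and therefore a torch.nn.Module
model.eval()       # standard nn.Module methods work as usual
num_params = sum(p.numel() for p in model.parameters())
print(f"loaded {type(model).__name__} with {num_params} parameters")
```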
27
+ Convolutional Neural Networks (CNN)
28
+ -----------------------------------
29
+
30
+ .. module:: fairseq.models.fconv
31
+ .. autoclass:: fairseq.models.fconv.FConvModel
32
+ :members:
33
+ .. autoclass:: fairseq.models.fconv.FConvEncoder
34
+ :members:
35
+ :undoc-members:
36
+ .. autoclass:: fairseq.models.fconv.FConvDecoder
37
+ :members:
38
+
39
+
40
+ Long Short-Term Memory (LSTM) networks
41
+ --------------------------------------
42
+
43
+ .. module:: fairseq.models.lstm
44
+ .. autoclass:: fairseq.models.lstm.LSTMModel
45
+ :members:
46
+ .. autoclass:: fairseq.models.lstm.LSTMEncoder
47
+ :members:
48
+ .. autoclass:: fairseq.models.lstm.LSTMDecoder
49
+ :members:
50
+
51
+
52
+ Transformer (self-attention) networks
53
+ -------------------------------------
54
+
55
+ .. module:: fairseq.models.transformer
56
+ .. autoclass:: fairseq.models.transformer.TransformerModel
57
+ :members:
58
+ .. autoclass:: fairseq.models.transformer.TransformerEncoder
59
+ :members:
60
+ .. autoclass:: fairseq.models.transformer.TransformerEncoderLayer
61
+ :members:
62
+ .. autoclass:: fairseq.models.transformer.TransformerDecoder
63
+ :members:
64
+ .. autoclass:: fairseq.models.transformer.TransformerDecoderLayer
65
+ :members:
66
+
67
+
68
+ Adding new models
69
+ -----------------
70
+
71
+ .. currentmodule:: fairseq.models
72
+ .. autofunction:: fairseq.models.register_model
73
+ .. autofunction:: fairseq.models.register_model_architecture
74
+ .. autoclass:: fairseq.models.BaseFairseqModel
75
+ :members:
76
+ :undoc-members:
77
+ .. autoclass:: fairseq.models.FairseqEncoderDecoderModel
78
+ :members:
79
+ :undoc-members:
80
+ .. autoclass:: fairseq.models.FairseqEncoderModel
81
+ :members:
82
+ :undoc-members:
83
+ .. autoclass:: fairseq.models.FairseqLanguageModel
84
+ :members:
85
+ :undoc-members:
86
+ .. autoclass:: fairseq.models.FairseqMultiModel
87
+ :members:
88
+ :undoc-members:
89
+ .. autoclass:: fairseq.models.FairseqEncoder
90
+ :members:
91
+ .. autoclass:: fairseq.models.CompositeEncoder
92
+ :members:
93
+ .. autoclass:: fairseq.models.FairseqDecoder
94
+ :members:
95
+
96
+
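The registration helpers documented above are shown end-to-end in the tutorial that follows in this commit; as a compressed, hedged sketch (`my_model` and `MyModel` are illustrative placeholders, not real fairseq models):

```python
# Sketch of the registration pattern; the names here are placeholders.
from fairseq.models import BaseFairseqModel, register_model, register_model_architecture


@register_model("my_model")
class MyModel(BaseFairseqModel):
    """A stub; a real model would implement build_model() and forward()."""


@register_model_architecture("my_model", "my_model_base")
def my_model_base(args):
    # Only fill in defaults that were not given on the command line.
    args.hidden_dim = getattr(args, "hidden_dim", 128)
```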
97
+ .. _Incremental decoding:
98
+
99
+ Incremental decoding
100
+ --------------------
101
+
102
+ .. autoclass:: fairseq.models.FairseqIncrementalDecoder
103
+ :members:
104
+ :undoc-members:
fairseq/docs/tutorial_classifying_names.rst ADDED
@@ -0,0 +1,415 @@
1
+ Tutorial: Classifying Names with a Character-Level RNN
2
+ ======================================================
3
+
4
+ In this tutorial we will extend fairseq to support *classification* tasks. In
5
+ particular we will re-implement the PyTorch tutorial for `Classifying Names with
6
+ a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`_
7
+ in fairseq. It is recommended to quickly skim that tutorial before beginning
8
+ this one.
9
+
10
+ This tutorial covers:
11
+
12
+ 1. **Preprocessing the data** to create dictionaries.
13
+ 2. **Registering a new Model** that encodes an input sentence with a simple RNN
14
+ and predicts the output label.
15
+ 3. **Registering a new Task** that loads our dictionaries and dataset.
16
+ 4. **Training the Model** using the existing command-line tools.
17
+ 5. **Writing an evaluation script** that imports fairseq and allows us to
18
+ interactively evaluate our model on new inputs.
19
+
20
+
21
+ 1. Preprocessing the data
22
+ -------------------------
23
+
24
+ The original tutorial provides raw data, but we'll work with a modified version
25
+ of the data that is already tokenized into characters and split into separate
26
+ train, valid and test sets.
27
+
28
+ Download and extract the data from here:
29
+ `tutorial_names.tar.gz <https://dl.fbaipublicfiles.com/fairseq/data/tutorial_names.tar.gz>`_
30
+
31
+ Once extracted, let's preprocess the data using the :ref:`fairseq-preprocess`
32
+ command-line tool to create the dictionaries. While this tool is primarily
33
+ intended for sequence-to-sequence problems, we're able to reuse it here by
34
+ treating the label as a "target" sequence of length 1. We'll also output the
35
+ preprocessed files in "raw" format using the ``--dataset-impl`` option to
36
+ enhance readability:
37
+
38
+ .. code-block:: console
39
+
40
+ > fairseq-preprocess \
41
+ --trainpref names/train --validpref names/valid --testpref names/test \
42
+ --source-lang input --target-lang label \
43
+ --destdir names-bin --dataset-impl raw
44
+
45
+ After running the above command you should see a new directory,
46
+ :file:`names-bin/`, containing the dictionaries for *inputs* and *labels*.
47
+
48
+
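As a quick sanity check (a minimal sketch, not part of the original tutorial), the generated dictionaries can be loaded directly with `fairseq.data.Dictionary`, the same class the Task in section 3 uses:

```python
# Hedged sketch: inspect the dictionaries produced by fairseq-preprocess.
import os

from fairseq.data import Dictionary

data_dir = "names-bin"
input_vocab = Dictionary.load(os.path.join(data_dir, "dict.input.txt"))
label_vocab = Dictionary.load(os.path.join(data_dir, "dict.label.txt"))
print("| [input] dictionary: {} types".format(len(input_vocab)))
print("| [label] dictionary: {} types".format(len(label_vocab)))
```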
49
+ 2. Registering a new Model
50
+ --------------------------
51
+
52
+ Next we'll register a new model in fairseq that will encode an input sentence
53
+ with a simple RNN and predict the output label. Compared to the original PyTorch
54
+ tutorial, our version will also work with batches of data and GPU Tensors.
55
+
56
+ First let's copy the simple RNN module implemented in the `PyTorch tutorial
57
+ <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#creating-the-network>`_.
58
+ Create a new file named :file:`fairseq/models/rnn_classifier.py` with the
59
+ following contents::
60
+
61
+ import torch
62
+ import torch.nn as nn
63
+
64
+ class RNN(nn.Module):
65
+
66
+ def __init__(self, input_size, hidden_size, output_size):
67
+ super(RNN, self).__init__()
68
+
69
+ self.hidden_size = hidden_size
70
+
71
+ self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
72
+ self.i2o = nn.Linear(input_size + hidden_size, output_size)
73
+ self.softmax = nn.LogSoftmax(dim=1)
74
+
75
+ def forward(self, input, hidden):
76
+ combined = torch.cat((input, hidden), 1)
77
+ hidden = self.i2h(combined)
78
+ output = self.i2o(combined)
79
+ output = self.softmax(output)
80
+ return output, hidden
81
+
82
+ def initHidden(self):
83
+ return torch.zeros(1, self.hidden_size)
84
+
85
+ We must also *register* this model with fairseq using the
86
+ :func:`~fairseq.models.register_model` function decorator. Once the model is
87
+ registered we'll be able to use it with the existing :ref:`Command-line Tools`.
88
+
89
+ All registered models must implement the :class:`~fairseq.models.BaseFairseqModel`
90
+ interface, so we'll create a small wrapper class in the same file and register
91
+ it in fairseq with the name ``'rnn_classifier'``::
92
+
93
+ from fairseq.models import BaseFairseqModel, register_model
94
+
95
+ # Note: the register_model "decorator" should immediately precede the
96
+ # definition of the Model class.
97
+
98
+ @register_model('rnn_classifier')
99
+ class FairseqRNNClassifier(BaseFairseqModel):
100
+
101
+ @staticmethod
102
+ def add_args(parser):
103
+ # Models can override this method to add new command-line arguments.
104
+ # Here we'll add a new command-line argument to configure the
105
+ # dimensionality of the hidden state.
106
+ parser.add_argument(
107
+ '--hidden-dim', type=int, metavar='N',
108
+ help='dimensionality of the hidden state',
109
+ )
110
+
111
+ @classmethod
112
+ def build_model(cls, args, task):
113
+ # Fairseq initializes models by calling the ``build_model()``
114
+ # function. This provides more flexibility, since the returned model
115
+ # instance can be of a different type than the one that was called.
116
+ # In this case we'll just return a FairseqRNNClassifier instance.
117
+
118
+ # Initialize our RNN module
119
+ rnn = RNN(
120
+ # We'll define the Task in the next section, but for now just
121
+ # notice that the task holds the dictionaries for the "source"
122
+ # (i.e., the input sentence) and "target" (i.e., the label).
123
+ input_size=len(task.source_dictionary),
124
+ hidden_size=args.hidden_dim,
125
+ output_size=len(task.target_dictionary),
126
+ )
127
+
128
+ # Return the wrapped version of the module
129
+ return FairseqRNNClassifier(
130
+ rnn=rnn,
131
+ input_vocab=task.source_dictionary,
132
+ )
133
+
134
+ def __init__(self, rnn, input_vocab):
135
+ super(FairseqRNNClassifier, self).__init__()
136
+
137
+ self.rnn = rnn
138
+ self.input_vocab = input_vocab
139
+
140
+ # The RNN module in the tutorial expects one-hot inputs, so we can
141
+ # precompute the identity matrix to help convert from indices to
142
+ # one-hot vectors. We register it as a buffer so that it is moved to
143
+ # the GPU when ``cuda()`` is called.
144
+ self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))
145
+
146
+ def forward(self, src_tokens, src_lengths):
147
+ # The inputs to the ``forward()`` function are determined by the
148
+ # Task, and in particular the ``'net_input'`` key in each
149
+ # mini-batch. We'll define the Task in the next section, but for
150
+ # now just know that *src_tokens* has shape `(batch, src_len)` and
151
+ # *src_lengths* has shape `(batch)`.
152
+ bsz, max_src_len = src_tokens.size()
153
+
154
+ # Initialize the RNN hidden state. Compared to the original PyTorch
155
+ # tutorial we'll also handle batched inputs and work on the GPU.
156
+ hidden = self.rnn.initHidden()
157
+ hidden = hidden.repeat(bsz, 1) # expand for batched inputs
158
+ hidden = hidden.to(src_tokens.device) # move to GPU
159
+
160
+ for i in range(max_src_len):
161
+ # WARNING: The inputs have padding, so we should mask those
162
+ # elements here so that padding doesn't affect the results.
163
+ # This is left as an exercise for the reader. The padding symbol
164
+ # is given by ``self.input_vocab.pad()`` and the unpadded length
165
+ # of each input is given by *src_lengths*.
166
+
167
+ # One-hot encode a batch of input characters.
168
+ input = self.one_hot_inputs[src_tokens[:, i].long()]
169
+
170
+ # Feed the input to our RNN.
171
+ output, hidden = self.rnn(input, hidden)
172
+
173
+ # Return the final output state for making a prediction
174
+ return output
175
+
176
+ Finally let's define a *named architecture* with the configuration for our
177
+ model. This is done with the :func:`~fairseq.models.register_model_architecture`
178
+ function decorator. Thereafter this named architecture can be used with the
179
+ ``--arch`` command-line argument, e.g., ``--arch pytorch_tutorial_rnn``::
180
+
181
+ from fairseq.models import register_model_architecture
182
+
183
+ # The first argument to ``register_model_architecture()`` should be the name
184
+ # of the model we registered above (i.e., 'rnn_classifier'). The function we
185
+ # register here should take a single argument *args* and modify it in-place
186
+ # to match the desired architecture.
187
+
188
+ @register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
189
+ def pytorch_tutorial_rnn(args):
190
+ # We use ``getattr()`` to prioritize arguments that are explicitly given
191
+ # on the command-line, so that the defaults defined below are only used
192
+ # when no other value has been specified.
193
+ args.hidden_dim = getattr(args, 'hidden_dim', 128)
194
+
195
+
196
+ 3. Registering a new Task
197
+ -------------------------
198
+
199
+ Now we'll register a new :class:`~fairseq.tasks.FairseqTask` that will load our
200
+ dictionaries and dataset. Tasks can also control how the data is batched into
201
+ mini-batches, but in this tutorial we'll reuse the batching provided by
202
+ :class:`fairseq.data.LanguagePairDataset`.
203
+
204
+ Create a new file named :file:`fairseq/tasks/simple_classification.py` with the
205
+ following contents::
206
+
207
+ import os
208
+ import torch
209
+
210
+ from fairseq.data import Dictionary, LanguagePairDataset
211
+ from fairseq.tasks import LegacyFairseqTask, register_task
212
+
213
+
214
+ @register_task('simple_classification')
215
+ class SimpleClassificationTask(LegacyFairseqTask):
216
+
217
+ @staticmethod
218
+ def add_args(parser):
219
+ # Add some command-line arguments for specifying where the data is
220
+ # located and the maximum supported input length.
221
+ parser.add_argument('data', metavar='FILE',
222
+ help='file prefix for data')
223
+ parser.add_argument('--max-positions', default=1024, type=int,
224
+ help='max input length')
225
+
226
+ @classmethod
227
+ def setup_task(cls, args, **kwargs):
228
+ # Here we can perform any setup required for the task. This may include
229
+ # loading Dictionaries, initializing shared Embedding layers, etc.
230
+ # In this case we'll just load the Dictionaries.
231
+ input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
232
+ label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
233
+ print('| [input] dictionary: {} types'.format(len(input_vocab)))
234
+ print('| [label] dictionary: {} types'.format(len(label_vocab)))
235
+
236
+ return SimpleClassificationTask(args, input_vocab, label_vocab)
237
+
238
+ def __init__(self, args, input_vocab, label_vocab):
239
+ super().__init__(args)
240
+ self.input_vocab = input_vocab
241
+ self.label_vocab = label_vocab
242
+
243
+ def load_dataset(self, split, **kwargs):
244
+ """Load a given dataset split (e.g., train, valid, test)."""
245
+
246
+ prefix = os.path.join(self.args.data, '{}.input-label'.format(split))
247
+
248
+ # Read input sentences.
249
+ sentences, lengths = [], []
250
+ with open(prefix + '.input', encoding='utf-8') as file:
251
+ for line in file:
252
+ sentence = line.strip()
253
+
254
+ # Tokenize the sentence, splitting on spaces
255
+ tokens = self.input_vocab.encode_line(
256
+ sentence, add_if_not_exist=False,
257
+ )
258
+
259
+ sentences.append(tokens)
260
+ lengths.append(tokens.numel())
261
+
262
+ # Read labels.
263
+ labels = []
264
+ with open(prefix + '.label', encoding='utf-8') as file:
265
+ for line in file:
266
+ label = line.strip()
267
+ labels.append(
268
+ # Convert label to a numeric ID.
269
+ torch.LongTensor([self.label_vocab.add_symbol(label)])
270
+ )
271
+
272
+ assert len(sentences) == len(labels)
273
+ print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))
274
+
275
+ # We reuse LanguagePairDataset since classification can be modeled as a
276
+ # sequence-to-sequence task where the target sequence has length 1.
277
+ self.datasets[split] = LanguagePairDataset(
278
+ src=sentences,
279
+ src_sizes=lengths,
280
+ src_dict=self.input_vocab,
281
+ tgt=labels,
282
+ tgt_sizes=torch.ones(len(labels)), # targets have length 1
283
+ tgt_dict=self.label_vocab,
284
+ left_pad_source=False,
285
+ # Since our target is a single class label, there's no need for
286
+ # teacher forcing. If we set this to ``True`` then our Model's
287
+ # ``forward()`` method would receive an additional argument called
288
+ # *prev_output_tokens* that would contain a shifted version of the
289
+ # target sequence.
290
+ input_feeding=False,
291
+ )
292
+
293
+ def max_positions(self):
294
+ """Return the max input length allowed by the task."""
295
+ # The source should be less than *args.max_positions* and the "target"
296
+ # has max length 1.
297
+ return (self.args.max_positions, 1)
298
+
299
+ @property
300
+ def source_dictionary(self):
301
+ """Return the source :class:`~fairseq.data.Dictionary`."""
302
+ return self.input_vocab
303
+
304
+ @property
305
+ def target_dictionary(self):
306
+ """Return the target :class:`~fairseq.data.Dictionary`."""
307
+ return self.label_vocab
308
+
309
+ # We could override this method if we wanted more control over how batches
310
+ # are constructed, but it's not necessary for this tutorial since we can
311
+ # reuse the batching provided by LanguagePairDataset.
312
+ #
313
+ # def get_batch_iterator(
314
+ # self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
315
+ # ignore_invalid_inputs=False, required_batch_size_multiple=1,
316
+ # seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
317
+ # data_buffer_size=0, disable_iterator_cache=False,
318
+ # ):
319
+ # (...)
320
+
321
+
322
+ 4. Training the Model
323
+ ---------------------
324
+
325
+ Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
326
+ command-line tool for this, making sure to specify our new Task (``--task
327
+ simple_classification``) and Model architecture (``--arch
328
+ pytorch_tutorial_rnn``):
329
+
330
+ .. note::
331
+
332
+ You can also configure the dimensionality of the hidden state by passing the
333
+ ``--hidden-dim`` argument to :ref:`fairseq-train`.
334
+
335
+ .. code-block:: console
336
+
337
+ > fairseq-train names-bin \
338
+ --task simple_classification \
339
+ --arch pytorch_tutorial_rnn \
340
+ --optimizer adam --lr 0.001 --lr-shrink 0.5 \
341
+ --max-tokens 1000
342
+ (...)
343
+ | epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21
344
+ | epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208
345
+ | done training in 31.6 seconds
346
+
347
+ The model files should appear in the :file:`checkpoints/` directory.
348
+
349
+
350
+ 5. Writing an evaluation script
351
+ -------------------------------
352
+
353
+ Finally we can write a short script to evaluate our model on new inputs. Create
354
+ a new file named :file:`eval_classifier.py` with the following contents::
355
+
356
+ from fairseq import checkpoint_utils, data, options, tasks
357
+
358
+ # Parse command-line arguments for generation
359
+ parser = options.get_generation_parser(default_task='simple_classification')
360
+ args = options.parse_args_and_arch(parser)
361
+
362
+ # Setup task
363
+ task = tasks.setup_task(args)
364
+
365
+ # Load model
366
+ print('| loading model from {}'.format(args.path))
367
+ models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
368
+ model = models[0]
369
+
370
+ while True:
371
+ sentence = input('\nInput: ')
372
+
373
+ # Tokenize into characters
374
+ chars = ' '.join(list(sentence.strip()))
375
+ tokens = task.source_dictionary.encode_line(
376
+ chars, add_if_not_exist=False,
377
+ )
378
+
379
+ # Build mini-batch to feed to the model
380
+ batch = data.language_pair_dataset.collate(
381
+ samples=[{'id': -1, 'source': tokens}], # bsz = 1
382
+ pad_idx=task.source_dictionary.pad(),
383
+ eos_idx=task.source_dictionary.eos(),
384
+ left_pad_source=False,
385
+ input_feeding=False,
386
+ )
387
+
388
+ # Feed batch to the model and get predictions
389
+ preds = model(**batch['net_input'])
390
+
391
+ # Print top 3 predictions and their log-probabilities
392
+ top_scores, top_labels = preds[0].topk(k=3)
393
+ for score, label_idx in zip(top_scores, top_labels):
394
+ label_name = task.target_dictionary.string([label_idx])
395
+ print('({:.2f})\t{}'.format(score, label_name))
396
+
397
+ Now we can evaluate our model interactively. Note that we have included the
398
+ original data path (:file:`names-bin/`) so that the dictionaries can be loaded:
399
+
400
+ .. code-block:: console
401
+
402
+ > python eval_classifier.py names-bin --path checkpoints/checkpoint_best.pt
403
+ | [input] dictionary: 64 types
404
+ | [label] dictionary: 24 types
405
+ | loading model from checkpoints/checkpoint_best.pt
406
+
407
+ Input: Satoshi
408
+ (-0.61) Japanese
409
+ (-1.20) Arabic
410
+ (-2.86) Italian
411
+
412
+ Input: Sinbad
413
+ (-0.30) Arabic
414
+ (-1.76) English
415
+ (-4.08) Russian
fairseq/examples/MMPT/.gitignore ADDED
@@ -0,0 +1,139 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+ runs
131
+ data
132
+ pretrained_models
133
+ projects/mmfusion_*
134
+ log_test
135
+ third-party
136
+ python_log
137
+ slurm_snapshot_code
138
+ lightning_logs
139
+ demos
fairseq/examples/MMPT/README.md ADDED
@@ -0,0 +1,166 @@
1
+ # VideoCLIP and VLM
2
+
3
+ You just found the toolkit for multimodal video understanding! It contains implementations of two recent multimodal video understanding papers, [VideoCLIP](https://arxiv.org/pdf/2109.14084.pdf) (EMNLP, 2021) and [VLM](https://aclanthology.org/2021.findings-acl.370.pdf) (ACL Findings, 2021), along with high-performance toolkit components that are typically lacking in existing codebases. The toolkit is designed around generic, performance-tuned components that can potentially be adapted to other frameworks (we initially use fairseq).
4
+
5
+ VideoCLIP is a contrastive learning model for zero-shot transfer to retrieval/classification/sequence labeling style tasks.
6
+
7
+ <img src="videoclip.png" width="350" class="center">
8
+
9
+ VLM is a masked-language-model-style pre-training approach that uses a single encoder with a masked modality model (MMM) for retrieval/generation/sequence-labeling-style tasks.
10
+
11
+ <img src="vlm.png" width="350" class="center">
12
+
13
+ ### News
14
+ [Oct. 2021] Initial release of implementation for the following papers:
15
+ [VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding](https://arxiv.org/pdf/2109.14084.pdf) (Xu et. al., EMNLP 2021)
16
+ [VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding](https://aclanthology.org/2021.findings-acl.370.pdf) (Xu et. al., ACL Findings 2021)
17
+
18
+
19
+ ### Installation
20
+ We aim to minimize the dependency of this repo on other packages.
21
+ We use fairseq as the main trainer (models/datasets have no dependency on fairseq; we will support other trainers in the future):
22
+ ```
23
+ git clone https://github.com/pytorch/fairseq
24
+ cd fairseq
25
+ pip install -e . # also optionally follow fairseq README for apex installation for fp16 training.
26
+ export MKL_THREADING_LAYER=GNU # fairseq may need this for numpy.
27
+ ```
28
+
29
+ Then install this toolkit:
30
+ ```
31
+ cd examples/MMPT # MMPT can be in any folder, not necessarily under fairseq/examples.
32
+ pip install -e .
33
+ ```
34
+
35
+ The code was developed with Python 3.8.8, PyTorch 1.8, CUDA 11.0, and fairseq 1.0.0a0+af0389f, and tested with Python 3.8.8, PyTorch 1.9, CUDA 11.0, and fairseq 1.0.0a0+8e7bc73 at the time of code release.
36
+ Most models require `transformers==3.4` for API compatibility (`pip install transformers==3.4`).
37
+ In addition, some downstream tasks may need `conda install pandas`.
38
+
39
+
40
+ ### Usage
41
+ #### Download Checkpoints
42
+ We use pre-trained [S3D](https://github.com/antoine77340/S3D_HowTo100M) for video feature extraction. Please place the models as `pretrained_models/s3d_dict.npy` and `pretrained_models/s3d_howto100m.pth`.
43
+
44
+ Download VideoCLIP checkpoint `https://dl.fbaipublicfiles.com/MMPT/retri/videoclip/checkpoint_best.pt` to `runs/retri/videoclip` or VLM checkpoint `https://dl.fbaipublicfiles.com/MMPT/mtm/vlm/checkpoint_best.pt` to `runs/mtm/vlm`.
45
+
46
+ #### Demo of Inference
47
+ Run `python locallaunch.py projects/retri/videoclip.yaml --dryrun` to generate all `.yaml` configs for VideoCLIP.
48
+
49
+ ```python
50
+ import torch
51
+
52
+ from mmpt.models import MMPTModel
53
+
54
+
55
+ model, tokenizer, aligner = MMPTModel.from_pretrained(
56
+ "projects/retri/videoclip/how2.yaml")
57
+
58
+ model.eval()
59
+
60
+
61
+ # B, T, FPS, H, W, C (VideoCLIP is trained on 30 fps of s3d)
62
+ video_frames = torch.randn(1, 2, 30, 224, 224, 3)
63
+ caps, cmasks = aligner._build_text_seq(
64
+ tokenizer("some text", add_special_tokens=False)["input_ids"]
65
+ )
66
+
67
+ caps, cmasks = caps[None, :], cmasks[None, :] # bsz=1
68
+
69
+ with torch.no_grad():
70
+ output = model(video_frames, caps, cmasks, return_score=True)
71
+ print(output["score"]) # dot-product
72
+ ```
73
+
74
+ #### Data Preparation
75
+ See [dataset](DATASET.md) for each dataset.
76
+
77
+ #### Global Config for Training Pipeline
78
+ We organize the global config file for a training/testing pipeline under `projects` (see a detailed [explanation](CONFIG.md)). For example, VideoCLIP is in `projects/retri/videoclip.yaml` and VLM is in `projects/mtm/vlm.yaml`.
79
+
80
+ We wrap all commands in `locallaunch.py` and `mmpt_cli/localjob.py`. You can inspect the concrete commands with `--dryrun` and then drop that flag for an actual run.
81
+
82
+ First, running `python locallaunch.py projects/retri/videoclip.yaml --dryrun` generates the configs for pre-training, zero-shot evaluation, fine-tuning, and testing of VideoCLIP under `projects/retri/videoclip`.
83
+
84
+ Each (training or evaluation) process is then configured by a concrete config file (we save all complex arguments, including fairseq args, into the concrete config file for reproducibility). For example, to run zero-shot evaluation on youcook:
85
+ ```
86
+ python locallaunch.py projects/retri/videoclip/test_youcook_zs.yaml --jobtype local_predict # zero-shot evaluation.
87
+ python locallaunch.py projects/retri/videoclip/youcook_videoclip.yaml --jobtype local_single --dryrun # fine-tuning: use --dryrun to check the commands and drop it for an actual run; local_small runs on two GPUs (as in the paper).
88
+ python locallaunch.py projects/retri/videoclip/test_youcook_videoclip.yaml --jobtype local_predict # testing on fine-tuned model.
89
+ ```
90
+
91
+ Pretraining can be run as:
92
+ ```
93
+ python locallaunch.py projects/retri/videoclip/how2.yaml --jobtype local_single --dryrun # check the commands, then drop --dryrun; the paper's runs used local_big with 8 GPUs.
94
+ ```
95
+ You may need to change `--jobtype`; check/extend `LocalJob` in `mmpt_cli/localjob.py` for multi-GPU/multi-node pre-training.
96
+
97
+ Detailed instructions for pretraining and fine-tuning can be found in the [pretraining instructions](pretraining.md) and [finetuning instructions](endtask.md).
98
+
99
+
100
+ ### Development
101
+ Several components of this toolkit can be re-used for future research (and also our ongoing research).
102
+
103
+ #### Framework Wrapper
104
+ We currently only support fairseq, but most components can easily be fitted into other frameworks such as Hugging Face. This repo is used as a fairseq `--user-dir` via fairseq wrappers. For example, `mmpt/tasks` includes a `FairseqMMTTask`, which wraps `mmpt/datasets` with `FairseqDataset`, `mmpt/models` with `FairseqModel`, and `mmpt/losses` with `FairseqCriterion`.
105
+
106
+ #### Processors
107
+ **Multi**modal research introduces the complexity of aligning modalities all the way from different input sources to the losses. Inspired by [MMF](https://github.com/facebookresearch/mmf), this toolkit leverages `mmpt/processors` to handle the various needs of data preprocessing and loading, **alleviating** the need for multiple `torch.utils.data.Dataset` classes (which can be tricky for ablation studies).
108
+ Processors can also be decoupled from `torch.utils.data.Dataset` for offline preprocessing instead of on-the-fly data preprocessing.
109
+
110
+ We decouple `mmpt.MMDataset` into three processors, `MetaProcessor`, `VideoProcessor`, and `TextProcessor`, plus an `Aligner`. They can be configured in the `dataset` field of a config file (e.g., see `projects/task/how2.yaml`).
111
+ `MetaProcessor` loads the metadata of a dataset, e.g., all video_ids of the how2 dataset.
112
+ `VideoProcessor` loads the video features of a dataset, for example S3D features for each second of a video.
113
+ `TextProcessor` loads the text (features), for example BERT pre-tokenized text clips for the how2 dataset (with `start`/`end` timestamps and `cap` for `token_ids`).
114
+ `Aligner` is the core class, specific to each baseline, that prepares a training example, e.g., by sampling a clip or masking tokens for MLM (a short sketch of how these pieces compose is shown below).
115
+
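How these pieces compose can be seen in `mmpt/datasets/mmdataset.py` (included later in this commit); conceptually, `MMDataset.__getitem__` does roughly the following (a paraphrased sketch, not the exact source):

```python
# Paraphrased sketch of how MMDataset composes the processors (see mmdataset.py below).
def __getitem__(self, idx):
    video_id, text_id = self.meta_processor[idx]      # MetaProcessor: meta data
    video_feature = self.video_processor(video_id)    # VideoProcessor: video features
    text_feature = self.text_processor(text_id)       # TextProcessor: text features
    # Aligner: combine video and text into one training example.
    return self.align_processor(video_id, video_feature, text_feature)
```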
116
+ #### Performance-tuned Components
117
+ To speed up pre-training, this toolkit stores sharded features in memory-mapped numpy arrays, backed by `ShardedTensor` in `mmpt/utils/shardedtensor.py` (adopted from the MARGE paper). This reduces the IO load for multi-GPU training by not loading all features of a video into memory each time, and `ShardedTensor` ensures features are stored in contiguous disk space for near-random access. This is used for both How2 video features and texts in `mmpt/processors/how2processor.py`.
118
+
119
+
120
+ ### Citation
121
+ If this codebase is useful for your work, please cite the following papers:
122
+
123
+ ```BibTeX
124
+ @inproceedings{xu-etal-2021-videoclip,
125
+ title = "{VideoCLIP}: Contrastive Pre-training for\\Zero-shot Video-Text Understanding",
126
+ author = "Xu, Hu and
127
+ Ghosh, Gargi and
128
+ Huang, Po-Yao and
129
+ Okhonko, Dmytro and
130
+ Aghajanyan, Armen and
131
+ Metze, Florian and
132
+ Zettlemoyer, Luke and
133
+ Feichtenhofer, Christoph",
134
+ booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP)",
135
+ month = nov,
136
+ year = "2021",
137
+ address = "Online",
138
+ publisher = "Association for Computational Linguistics",
139
+ }
140
+
141
+ @inproceedings{xu-etal-2021-vlm,
142
+ title = "{VLM}: Task-agnostic Video-Language Model Pre-training for Video Understanding",
143
+ author = "Xu, Hu and
144
+ Ghosh, Gargi and
145
+ Huang, Po-Yao and
146
+ Arora, Prahal and
147
+ Aminzadeh, Masoumeh and
148
+ Feichtenhofer, Christoph and
149
+ Metze, Florian and
150
+ Zettlemoyer, Luke",
151
+ booktitle = "Findings of the Association for Computational Linguistics: ACL-IJCNLP 2021",
152
+ month = aug,
153
+ year = "2021",
154
+ address = "Online",
155
+ publisher = "Association for Computational Linguistics",
156
+ url = "https://aclanthology.org/2021.findings-acl.370",
157
+ doi = "10.18653/v1/2021.findings-acl.370",
158
+ pages = "4227--4239",
159
+ }
160
+ ```
161
+
162
+ ### Bug Reports
163
+ This repo is in its initial stage; bug reports are welcome at [email protected]
164
+
165
+ ### Copyright
166
+ The majority of Multimodal Pre-training (MMPT) is licensed under CC-BY-NC; however, portions of the project are available under separate license terms: Evaluation Codes/Models: HowTo100M and HuggingFace Transformers are licensed under the Apache 2.0 license; COIN and NLG-eval are licensed under the MIT license; CrossTask is licensed under the BSD-3 license; DiDeMo is licensed under the BSD-2 license.
fairseq/examples/MMPT/endtask.md ADDED
@@ -0,0 +1,41 @@
1
+ # Zero-shot Transfer and Finetuning
2
+
3
+ (If you are new to the ideas of `mmpt.processors`, see [README](README.md) first.)
4
+ All finetuning datasets (specifically `processors`) are defined in `mmpt.processors.dsprocessor`.
5
+ Given the complexity of different types of finetuning tasks, each task may have its own meta/video/text/aligner processors and its own `mmpt/evaluators/{Predictor,Metric}`.
6
+
7
+ ### Tasks
8
+
9
+ Currently, we support 5 downstream datasets (`MSRVTT`, `Youcook`, `COIN`, `Crosstask` and `DiDeMo`) with the following tasks:
10
+ text-video retrieval: `MSRVTT`, `Youcook`, `DiDeMo`;
11
+ video captioning: `Youcook`;
12
+ video question answering: `MSRVTT-QA`.
13
+
14
+ To add your own dataset, specify the corresponding processors and configure them in the `dataset` field of a config file, such as `projects/task/vtt.yaml`.
15
+
16
+ ### Zero-shot Transfer (no Training)
17
+ Zero-shot transfer runs the pre-trained model (e.g., VideoCLIP) directly on test data. Configs matching the pattern `projects/task/*_zs_*.yaml` are dedicated to zero-shot transfer.
18
+
19
+ ### Fine-tuning
20
+
21
+ Training a downstream task is similar to pretraining, except that you may need to specify `restore_file` in `fairseq.checkpoint` and reset the optimizers; see `projects/task/ft.yaml`, which is included by `projects/task/vtt.yaml`.
22
+
23
+ We typically finetune on 2 GPUs (`local_small`).
24
+
25
+ ### Testing
26
+ For each finetuning dataset, you may need to specify a testing config, similar to `projects/task/test_vtt.yaml`.
27
+
28
+ We define `mmpt.evaluators.Predictor` for different types of prediction. For example, `MSRVTT` and `Youcook` are video-retrieval tasks and are expected to use `RetrievalPredictor`. You may need to define your own type of predictor and specify it in the `predictor` field of a testing config.
29
+
30
+ Each task may also have its own metric for evaluation. This can be created in `mmpt.evaluators.Metric` and specified in the `metric` field of a testing config.
31
+
32
+ Launching a test is as simple as training; just specify the path of a testing config:
33
+ ```python locallaunch.py projects/mfmmlm/test_vtt.yaml```
34
+ Testing will be launched locally by default since prediction is computationally less expensive.
35
+
36
+ ### Third-party Libraries
37
+ We list the following finetuning tasks that require third-party libraries.
38
+
39
+ Youcook captioning: `https://github.com/Maluuba/nlg-eval`
40
+
41
+ CrossTask: `https://github.com/DmZhukov/CrossTask`'s `dp` under `third-party/CrossTask` (`python setup.py build_ext --inplace`)
fairseq/examples/MMPT/mmpt/datasets/fairseqmmdataset.py ADDED
@@ -0,0 +1,57 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """
6
+ TODO (huxu): fairseq wrapper class for all datasets you define; mostly MMDataset.
7
+ """
8
+
9
+ from collections import OrderedDict
10
+
11
+ from torch.utils.data import Dataset
12
+ from torch.utils.data.dataloader import default_collate
13
+ from fairseq.data import FairseqDataset, data_utils
14
+
15
+
16
+ class FairseqMMDataset(FairseqDataset):
17
+ """
18
+ A wrapper class for MMDataset for fairseq.
19
+ """
20
+
21
+ def __init__(self, mmdataset):
22
+ if not isinstance(mmdataset, Dataset):
23
+ raise TypeError("mmdataset must be of type `torch.utils.data.Dataset`.")
24
+ self.mmdataset = mmdataset
25
+
26
+ def set_epoch(self, epoch, **unused):
27
+ super().set_epoch(epoch)
28
+ self.epoch = epoch
29
+
30
+ def __getitem__(self, idx):
31
+ with data_utils.numpy_seed(43211, self.epoch, idx):
32
+ return self.mmdataset[idx]
33
+
34
+ def __len__(self):
35
+ return len(self.mmdataset)
36
+
37
+ def collater(self, samples):
38
+ if hasattr(self.mmdataset, "collator"):
39
+ return self.mmdataset.collator(samples)
40
+ if len(samples) == 0:
41
+ return {}
42
+ if isinstance(samples[0], dict):
43
+ batch = OrderedDict()
44
+ for key in samples[0]:
45
+ if samples[0][key] is not None:
46
+ batch[key] = default_collate([sample[key] for sample in samples])
47
+ return batch
48
+ else:
49
+ return default_collate(samples)
50
+
51
+ def size(self, index):
52
+ """dummy implementation: we don't use --max-tokens"""
53
+ return 1
54
+
55
+ def num_tokens(self, index):
56
+ """dummy implementation: we don't use --max-tokens"""
57
+ return 1
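A brief usage sketch for the wrapper above (not from the repo; `my_mmdataset` is a placeholder for any `MMDataset`-style `torch.utils.data.Dataset`):

```python
# Hedged usage sketch for FairseqMMDataset; `my_mmdataset` is a placeholder dataset.
from mmpt.datasets.fairseqmmdataset import FairseqMMDataset

wrapped = FairseqMMDataset(my_mmdataset)
wrapped.set_epoch(1)     # must be called first: __getitem__ seeds numpy per (epoch, index)
sample = wrapped[0]      # delegates to my_mmdataset[0] under the fixed seed
batch = wrapped.collater([wrapped[0], wrapped[1]])
```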
fairseq/examples/MMPT/mmpt/datasets/mmdataset.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from collections import OrderedDict
9
+
10
+ from torch.utils.data import Dataset
11
+ from torch.utils.data.dataloader import default_collate
12
+
13
+ from ..utils import set_seed
14
+
15
+
16
+ class MMDataset(Dataset):
17
+ """
18
+ A generic multi-modal dataset.
19
+ Args:
20
+ `meta_processor`: a meta processor,
21
+ handling loading meta data and return video_id and text_id.
22
+ `video_processor`: a video processor,
23
+ handling e.g., decoding, loading .np files.
24
+ `text_processor`: a text processor,
25
+ handling e.g., tokenization.
26
+ `aligner`: combine the video and text feature
27
+ as one training example.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ meta_processor,
33
+ video_processor,
34
+ text_processor,
35
+ align_processor,
36
+ ):
37
+ self.split = meta_processor.split
38
+ self.meta_processor = meta_processor
39
+ self.video_processor = video_processor
40
+ self.text_processor = text_processor
41
+ self.align_processor = align_processor
42
+
43
+ def __len__(self):
44
+ return len(self.meta_processor)
45
+
46
+ def __getitem__(self, idx):
47
+ if self.split == "test":
48
+ set_seed(idx)
49
+ video_id, text_id = self.meta_processor[idx]
50
+ video_feature = self.video_processor(video_id)
51
+ text_feature = self.text_processor(text_id)
52
+ output = self.align_processor(video_id, video_feature, text_feature)
53
+ # TODO (huxu): the following is for debug purpose.
54
+ output.update({"idx": idx})
55
+ return output
56
+
57
+ def collater(self, samples):
58
+ """This collator is deprecated.
59
+ set self.collator = MMDataset.collater.
60
+ see collator in FairseqMMDataset.
61
+ """
62
+
63
+ if len(samples) == 0:
64
+ return {}
65
+ if isinstance(samples[0], dict):
66
+ batch = OrderedDict()
67
+ for key in samples[0]:
68
+ if samples[0][key] is not None:
69
+ batch[key] = default_collate(
70
+ [sample[key] for sample in samples])
71
+ # if torch.is_tensor(batch[key]):
72
+ # print(key, batch[key].size())
73
+ # else:
74
+ # print(key, len(batch[key]))
75
+ return batch
76
+ else:
77
+ return default_collate(samples)
78
+
79
+ def print_example(self, output):
80
+ print("[one example]", output["video_id"])
81
+ if (
82
+ hasattr(self.align_processor, "subsampling")
83
+ and self.align_processor.subsampling is not None
84
+ and self.align_processor.subsampling > 1
85
+ ):
86
+ for key in output:
87
+ if torch.is_tensor(output[key]):
88
+ output[key] = output[key][0]
89
+
90
+ # search tokenizer to translate ids back.
91
+ tokenizer = None
92
+ if hasattr(self.text_processor, "tokenizer"):
93
+ tokenizer = self.text_processor.tokenizer
94
+ elif hasattr(self.align_processor, "tokenizer"):
95
+ tokenizer = self.align_processor.tokenizer
96
+ if tokenizer is not None:
97
+ caps = output["caps"].tolist()
98
+ if isinstance(caps[0], list):
99
+ caps = caps[0]
100
+ print("caps", tokenizer.decode(caps))
101
+ print("caps", tokenizer.convert_ids_to_tokens(caps))
102
+
103
+ for key, value in output.items():
104
+ if torch.is_tensor(value):
105
+ if len(value.size()) >= 3: # attention_mask.
106
+ print(key, value.size())
107
+ print(key, "first", value[0, :, :])
108
+ print(key, "last", value[-1, :, :])
109
+ else:
110
+ print(key, value)
111
+ print("[end of one example]")
fairseq/examples/MMPT/mmpt/evaluators/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from .metric import *
6
+ from .evaluator import *
7
+
8
+
9
+ # experimental.
10
+ try:
11
+ from .expmetric import *
12
+ except ImportError:
13
+ pass
fairseq/examples/MMPT/mmpt/evaluators/evaluator.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import os
6
+ import glob
7
+ import numpy as np
8
+
9
+ from . import metric as metric_path
10
+ from . import predictor as predictor_path
11
+
12
+
13
+ class Evaluator(object):
14
+ """
15
+ perform evaluation on a single (downstream) task.
16
+ make this both offline and online.
17
+ TODO(huxu) saving evaluation results.
18
+ """
19
+
20
+ def __init__(self, config, eval_dataloader=None):
21
+ if config.metric is None:
22
+ raise ValueError("config.metric is", config.metric)
23
+ metric_cls = getattr(metric_path, config.metric)
24
+ self.metric = metric_cls(config)
25
+ if config.predictor is None:
26
+ raise ValueError("config.predictor is", config.predictor)
27
+ predictor_cls = getattr(predictor_path, config.predictor)
28
+ self.predictor = predictor_cls(config)
29
+ self.eval_dataloader = eval_dataloader
30
+
31
+ def __call__(self):
32
+ try:
33
+ print(self.predictor.pred_dir)
34
+ for pred_file in glob.glob(
35
+ self.predictor.pred_dir + "/*_merged.npy"):
36
+ outputs = np.load(pred_file)
37
+ results = self.metric.compute_metrics(outputs)
38
+ self.metric.print_computed_metrics(results)
39
+
40
+ outputs = np.load(os.path.join(
41
+ self.predictor.pred_dir, "merged.npy"))
42
+ results = self.metric.compute_metrics(outputs)
43
+ return {"results": results, "metric": self.metric}
44
+ except FileNotFoundError:
45
+ print("\n[missing]", self.predictor.pred_dir)
46
+ return {}
47
+
48
+ def evaluate(self, model, eval_dataloader=None, output_file="merged"):
49
+ if eval_dataloader is None:
50
+ eval_dataloader = self.eval_dataloader
51
+ outputs = self.predictor.predict_loop(
52
+ model, eval_dataloader, output_file)
53
+ results = self.metric.compute_metrics(**outputs)
54
+ return results
fairseq/examples/MMPT/mmpt/models/transformermodel.py ADDED
@@ -0,0 +1,734 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Copyright (c) Facebook, Inc. All Rights Reserved
17
+
18
+ import torch
19
+
20
+ from torch import nn
21
+
22
+ try:
23
+ from transformers.modeling_bert import (
24
+ BertPreTrainedModel,
25
+ BertModel,
26
+ BertEncoder,
27
+ BertPredictionHeadTransform,
28
+ )
29
+ except ImportError:
30
+ pass
31
+
32
+ from ..modules import VideoTokenMLP, MMBertEmbeddings
33
+
34
+
35
+ # --------------- fine-tuning models ---------------
36
+ class MMBertForJoint(BertPreTrainedModel):
37
+ """A BertModel with isolated attention mask to separate modality."""
38
+
39
+ def __init__(self, config):
40
+ super().__init__(config)
41
+ self.videomlp = VideoTokenMLP(config)
42
+ self.bert = MMBertModel(config)
43
+ self.init_weights()
44
+
45
+ def forward(
46
+ self,
47
+ input_ids=None,
48
+ input_video_embeds=None,
49
+ attention_mask=None,
50
+ token_type_ids=None,
51
+ position_ids=None,
52
+ head_mask=None,
53
+ inputs_embeds=None,
54
+ next_sentence_label=None,
55
+ output_attentions=None,
56
+ output_hidden_states=None,
57
+ return_dict=None,
58
+ separate_forward_split=None,
59
+ ):
60
+ return_dict = (
61
+ return_dict if return_dict is not None
62
+ else self.config.use_return_dict
63
+ )
64
+ video_tokens = self.videomlp(input_video_embeds)
65
+
66
+ outputs = self.bert(
67
+ input_ids,
68
+ video_tokens,
69
+ attention_mask=attention_mask,
70
+ token_type_ids=token_type_ids,
71
+ position_ids=position_ids,
72
+ head_mask=head_mask,
73
+ inputs_embeds=inputs_embeds,
74
+ output_attentions=output_attentions,
75
+ output_hidden_states=output_hidden_states,
76
+ return_dict=return_dict,
77
+ separate_forward_split=separate_forward_split,
78
+ )
79
+
80
+ return outputs
81
+
82
+
83
+ class MMBertForTokenClassification(BertPreTrainedModel):
84
+ """A BertModel similar to MMJointUni, with extra wrapper layer
85
+ to be fine-tuned from other pretrained MMFusion model."""
86
+
87
+ def __init__(self, config):
88
+ super().__init__(config)
89
+ self.videomlp = VideoTokenMLP(config)
90
+ self.bert = MMBertModel(config)
91
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
92
+ # TODO(huxu): 779 is the number of classes for COIN: move to config?
93
+ self.classifier = nn.Linear(config.hidden_size, 779)
94
+ self.init_weights()
95
+
96
+ def forward(
97
+ self,
98
+ input_ids=None,
99
+ input_video_embeds=None,
100
+ attention_mask=None,
101
+ token_type_ids=None,
102
+ position_ids=None,
103
+ head_mask=None,
104
+ inputs_embeds=None,
105
+ next_sentence_label=None,
106
+ output_attentions=None,
107
+ output_hidden_states=None,
108
+ return_dict=None,
109
+ separate_forward_split=None,
110
+ ):
111
+ return_dict = (
112
+ return_dict if return_dict is not None
113
+ else self.config.use_return_dict
114
+ )
115
+
116
+ video_tokens = self.videomlp(input_video_embeds)
117
+ outputs = self.bert(
118
+ input_ids,
119
+ video_tokens,
120
+ attention_mask=attention_mask,
121
+ token_type_ids=token_type_ids,
122
+ position_ids=position_ids,
123
+ head_mask=head_mask,
124
+ inputs_embeds=inputs_embeds,
125
+ output_attentions=output_attentions,
126
+ output_hidden_states=output_hidden_states,
127
+ return_dict=return_dict,
128
+ separate_forward_split=separate_forward_split,
129
+ )
130
+
131
+ return (self.classifier(outputs[0]),)
132
+
133
+
134
+ # ------------ pre-training models ----------------
135
+
136
+ class MMBertForEncoder(BertPreTrainedModel):
137
+ """A BertModel for Contrastive Learning."""
138
+ def __init__(self, config):
139
+ super().__init__(config)
140
+ self.videomlp = VideoTokenMLP(config)
141
+ self.bert = MMBertModel(config)
142
+ self.init_weights()
143
+
144
+ def forward(
145
+ self,
146
+ input_ids=None,
147
+ input_video_embeds=None,
148
+ attention_mask=None,
149
+ token_type_ids=None,
150
+ position_ids=None,
151
+ head_mask=None,
152
+ inputs_embeds=None,
153
+ output_attentions=None,
154
+ output_hidden_states=None,
155
+ return_dict=None,
156
+ ):
157
+ return_dict = (
158
+ return_dict if return_dict is not None
159
+ else self.config.use_return_dict
160
+ )
161
+ if input_video_embeds is not None:
162
+ video_tokens = self.videomlp(input_video_embeds)
163
+ else:
164
+ video_tokens = None
165
+
166
+ outputs = self.bert(
167
+ input_ids,
168
+ video_tokens,
169
+ attention_mask=attention_mask,
170
+ token_type_ids=token_type_ids,
171
+ position_ids=position_ids,
172
+ head_mask=head_mask,
173
+ inputs_embeds=inputs_embeds,
174
+ output_attentions=output_attentions,
175
+ output_hidden_states=output_hidden_states,
176
+ return_dict=return_dict,
177
+ )
178
+ return outputs
179
+
180
+
181
+ class MMBertForMFMMLM(BertPreTrainedModel):
182
+ """A BertModel with shared prediction head on MFM-MLM."""
183
+ def __init__(self, config):
184
+ super().__init__(config)
185
+ self.videomlp = VideoTokenMLP(config)
186
+ self.bert = MMBertModel(config)
187
+ self.cls = MFMMLMHead(config)
188
+ self.hidden_size = config.hidden_size
189
+ self.init_weights()
190
+
191
+ def get_output_embeddings(self):
192
+ return self.cls.predictions.decoder
193
+
194
+ def forward(
195
+ self,
196
+ input_ids=None,
197
+ input_video_embeds=None,
198
+ attention_mask=None,
199
+ token_type_ids=None,
200
+ position_ids=None,
201
+ head_mask=None,
202
+ inputs_embeds=None,
203
+ masked_frame_labels=None,
204
+ target_video_hidden_states=None,
205
+ non_masked_frame_mask=None,
206
+ masked_lm_labels=None,
207
+ output_attentions=None,
208
+ output_hidden_states=None,
209
+ return_dict=None,
210
+ ):
211
+ return_dict = (
212
+ return_dict if return_dict is not None
213
+ else self.config.use_return_dict
214
+ )
215
+ if input_video_embeds is not None:
216
+ video_tokens = self.videomlp(input_video_embeds)
217
+ else:
218
+ video_tokens = None
219
+
220
+ if target_video_hidden_states is not None:
221
+ target_video_hidden_states = self.videomlp(
222
+ target_video_hidden_states)
223
+
224
+ non_masked_frame_hidden_states = video_tokens.masked_select(
225
+ non_masked_frame_mask.unsqueeze(-1)
226
+ ).view(-1, self.hidden_size)
227
+
228
+ outputs = self.bert(
229
+ input_ids,
230
+ video_tokens,
231
+ attention_mask=attention_mask,
232
+ token_type_ids=token_type_ids,
233
+ position_ids=position_ids,
234
+ head_mask=head_mask,
235
+ inputs_embeds=inputs_embeds,
236
+ output_attentions=output_attentions,
237
+ output_hidden_states=output_hidden_states,
238
+ return_dict=return_dict,
239
+ )
240
+
241
+ sequence_output = outputs[0]
242
+
243
+ mfm_scores, prediction_scores = None, None
244
+ if masked_frame_labels is not None and masked_lm_labels is not None:
245
+ # split the sequence.
246
+ text_offset = masked_frame_labels.size(1) + 1 # [CLS]
247
+ video_sequence_output = sequence_output[
248
+ :, 1:text_offset
249
+ ] # drop [CLS] and stop before [SEP]; neither is in video_label.
250
+ text_sequence_output = torch.cat(
251
+ [sequence_output[:, :1], sequence_output[:, text_offset:]],
252
+ dim=1
253
+ )
254
+
255
+ hidden_size = video_sequence_output.size(-1)
256
+ selected_video_output = video_sequence_output.masked_select(
257
+ masked_frame_labels.unsqueeze(-1)
258
+ ).view(-1, hidden_size)
259
+
260
+ # only compute the selected tokens during training to speed up.
261
+ hidden_size = text_sequence_output.size(-1)
262
+ # masked_lm_labels = masked_lm_labels.reshape(-1)
263
+ labels_mask = masked_lm_labels != -100
264
+
265
+ selected_text_output = text_sequence_output.masked_select(
266
+ labels_mask.unsqueeze(-1)
267
+ ).view(-1, hidden_size)
268
+ mfm_scores, prediction_scores = self.cls(
269
+ selected_video_output,
270
+ target_video_hidden_states,
271
+ non_masked_frame_hidden_states,
272
+ selected_text_output,
273
+ )
274
+
275
+ output = (
276
+ mfm_scores,
277
+ prediction_scores,
278
+ ) + outputs
279
+ return output
280
+
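+ # Note on the split above: the packed sequence is
+ #   [CLS] v_1 ... v_M [SEP] caption [SEP]
+ # with M = masked_frame_labels.size(1), so text_offset = M + 1.
+ # sequence_output[:, 1:text_offset] recovers the M video positions, while
+ # [CLS] concatenated with sequence_output[:, text_offset:] reproduces the
+ # original input_ids layout ([CLS] [SEP] caption [SEP]) so that it lines up
+ # with masked_lm_labels.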
281
+
282
+ class BertMFMMLMPredictionHead(nn.Module):
283
+ def __init__(self, config):
284
+ super().__init__()
285
+ self.transform = BertPredictionHeadTransform(config)
286
+ # The output weights are the same as the input embeddings, but there is
287
+ # an output-only bias for each token.
288
+ self.decoder = nn.Linear(
289
+ config.hidden_size, config.vocab_size, bias=False)
290
+
291
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
292
+
293
+ # Need a link between the two variables so that the bias is correctly
294
+ # resized with `resize_token_embeddings`
295
+ self.decoder.bias = self.bias
296
+
297
+ def forward(
298
+ self,
299
+ video_hidden_states=None,
300
+ target_video_hidden_states=None,
301
+ non_masked_frame_hidden_states=None,
302
+ text_hidden_states=None,
303
+ ):
304
+ video_logits, text_logits = None, None
305
+ if video_hidden_states is not None:
306
+ video_hidden_states = self.transform(video_hidden_states)
307
+ non_masked_frame_logits = torch.mm(
308
+ video_hidden_states,
309
+ non_masked_frame_hidden_states.transpose(1, 0)
310
+ )
311
+ masked_frame_logits = torch.bmm(
312
+ video_hidden_states.unsqueeze(1),
313
+ target_video_hidden_states.unsqueeze(-1),
314
+ ).squeeze(-1)
315
+ video_logits = torch.cat(
316
+ [masked_frame_logits, non_masked_frame_logits], dim=1
317
+ )
318
+
319
+ if text_hidden_states is not None:
320
+ text_hidden_states = self.transform(text_hidden_states)
321
+ text_logits = self.decoder(text_hidden_states)
322
+ return video_logits, text_logits
323
+
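+ # Descriptive note on the MFM scoring above: each masked frame's hidden state
+ # is scored against its own target feature (torch.bmm -> one positive logit
+ # per frame) and against all non-masked frame features (torch.mm -> negatives),
+ # then concatenated so that column 0 is the positive class for a
+ # cross-entropy-style loss.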
324
+
325
+ class MFMMLMHead(nn.Module):
326
+ def __init__(self, config):
327
+ super().__init__()
328
+ self.predictions = BertMFMMLMPredictionHead(config)
329
+
330
+ def forward(
331
+ self,
332
+ video_hidden_states=None,
333
+ target_video_hidden_states=None,
334
+ non_masked_frame_hidden_states=None,
335
+ text_hidden_states=None,
336
+ ):
337
+ video_logits, text_logits = self.predictions(
338
+ video_hidden_states,
339
+ target_video_hidden_states,
340
+ non_masked_frame_hidden_states,
341
+ text_hidden_states,
342
+ )
343
+ return video_logits, text_logits
344
+
345
+
346
+ class MMBertForMTM(MMBertForMFMMLM):
347
+ def __init__(self, config):
348
+ BertPreTrainedModel.__init__(self, config)
349
+ self.videomlp = VideoTokenMLP(config)
350
+ self.bert = MMBertModel(config)
351
+ self.cls = MTMHead(config)
352
+ self.hidden_size = config.hidden_size
353
+ self.init_weights()
354
+
355
+
356
+ class BertMTMPredictionHead(nn.Module):
357
+ def __init__(self, config):
358
+ super().__init__()
359
+ self.transform = BertPredictionHeadTransform(config)
360
+ self.decoder = nn.Linear(
361
+ config.hidden_size, config.vocab_size, bias=False)
362
+
363
+ def forward(
364
+ self,
365
+ video_hidden_states=None,
366
+ target_video_hidden_states=None,
367
+ non_masked_frame_hidden_states=None,
368
+ text_hidden_states=None,
369
+ ):
370
+ non_masked_frame_hidden_states = non_masked_frame_hidden_states.transpose(1, 0)
371
+ video_logits, text_logits = None, None
372
+ if video_hidden_states is not None:
373
+ video_hidden_states = self.transform(video_hidden_states)
374
+
375
+ masked_frame_logits = torch.bmm(
376
+ video_hidden_states.unsqueeze(1),
377
+ target_video_hidden_states.unsqueeze(-1),
378
+ ).squeeze(-1)
379
+
380
+ non_masked_frame_logits = torch.mm(
381
+ video_hidden_states,
382
+ non_masked_frame_hidden_states
383
+ )
384
+ video_on_vocab_logits = self.decoder(video_hidden_states)
385
+ video_logits = torch.cat([
386
+ masked_frame_logits,
387
+ non_masked_frame_logits,
388
+ video_on_vocab_logits], dim=1)
389
+
390
+ if text_hidden_states is not None:
391
+ text_hidden_states = self.transform(text_hidden_states)
392
+ # text first, so the labels do not need to be shifted.
393
+ text_on_vocab_logits = self.decoder(text_hidden_states)
394
+ text_on_video_logits = torch.mm(
395
+ text_hidden_states,
396
+ non_masked_frame_hidden_states
397
+ )
398
+ text_logits = torch.cat([
399
+ text_on_vocab_logits,
400
+ text_on_video_logits
401
+ ], dim=1)
402
+
403
+ return video_logits, text_logits
404
+
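+ # Descriptive note: compared with BertMFMMLMPredictionHead, this MTM head
+ # scores each masked position against a joint label space. For video positions
+ # the logits are [own target frame | non-masked frames | text vocab]; for text
+ # positions they are [text vocab | non-masked frames], so a masked position
+ # can be "explained" by either modality.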
405
+
406
+ class MTMHead(nn.Module):
407
+ def __init__(self, config):
408
+ super().__init__()
409
+ self.predictions = BertMTMPredictionHead(config)
410
+
411
+ def forward(
412
+ self,
413
+ video_hidden_states=None,
414
+ target_video_hidden_states=None,
415
+ non_masked_frame_hidden_states=None,
416
+ text_hidden_states=None,
417
+ ):
418
+ video_logits, text_logits = self.predictions(
419
+ video_hidden_states,
420
+ target_video_hidden_states,
421
+ non_masked_frame_hidden_states,
422
+ text_hidden_states,
423
+ )
424
+ return video_logits, text_logits
425
+
426
+
427
+ class MMBertModel(BertModel):
428
+ """MMBertModel has MMBertEmbedding to support video tokens."""
429
+
430
+ def __init__(self, config, add_pooling_layer=True):
431
+ super().__init__(config)
432
+ # overwrite embedding
433
+ self.embeddings = MMBertEmbeddings(config)
434
+ self.encoder = MultiLayerAttentionMaskBertEncoder(config)
435
+ self.init_weights()
436
+
437
+ def forward(
438
+ self,
439
+ input_ids=None,
440
+ input_video_embeds=None,
441
+ attention_mask=None,
442
+ token_type_ids=None,
443
+ position_ids=None,
444
+ head_mask=None,
445
+ inputs_embeds=None,
446
+ encoder_hidden_states=None,
447
+ encoder_attention_mask=None,
448
+ output_attentions=None,
449
+ output_hidden_states=None,
450
+ return_dict=None,
451
+ separate_forward_split=None,
452
+ ):
453
+ output_attentions = (
454
+ output_attentions
455
+ if output_attentions is not None
456
+ else self.config.output_attentions
457
+ )
458
+ output_hidden_states = (
459
+ output_hidden_states
460
+ if output_hidden_states is not None
461
+ else self.config.output_hidden_states
462
+ )
463
+ return_dict = (
464
+ return_dict if return_dict is not None
465
+ else self.config.use_return_dict
466
+ )
467
+
468
+ if input_ids is not None and inputs_embeds is not None:
469
+ raise ValueError(
470
+ "You cannot specify both input_ids "
471
+ "and inputs_embeds at the same time"
472
+ )
473
+ elif input_ids is not None:
474
+ if input_video_embeds is not None:
475
+ input_shape = (
476
+ input_ids.size(0),
477
+ input_ids.size(1) + input_video_embeds.size(1),
478
+ )
479
+ else:
480
+ input_shape = (
481
+ input_ids.size(0),
482
+ input_ids.size(1),
483
+ )
484
+ elif inputs_embeds is not None:
485
+ if input_video_embeds is not None:
486
+ input_shape = (
487
+ inputs_embeds.size(0),
488
+ inputs_embeds.size(1) + input_video_embeds.size(1),
489
+ )
490
+ else:
491
+ input_shape = (
492
+ inputs_embeds.size(0),
492
+ inputs_embeds.size(1),
494
+ )
495
+ else:
496
+ raise ValueError(
497
+ "You have to specify either input_ids or inputs_embeds")
498
+
499
+ device = input_ids.device if input_ids is not None \
500
+ else inputs_embeds.device
501
+
502
+ if attention_mask is None:
503
+ attention_mask = torch.ones(input_shape, device=device)
504
+ if token_type_ids is None:
505
+ token_type_ids = torch.zeros(
506
+ input_shape, dtype=torch.long, device=device)
507
+
508
+ # We can provide a self-attention mask of dimensions
509
+ # [batch_size, from_seq_length, to_seq_length]
510
+ # ourselves in which case
511
+ # we just need to make it broadcastable to all heads.
512
+ extended_attention_mask: torch.Tensor = \
513
+ self.get_extended_attention_mask(
514
+ attention_mask, input_shape, device)
515
+
516
+ # If a 2D or 3D attention mask is provided for the cross-attention
517
+ # we need to make it broadcastable to
518
+ # [batch_size, num_heads, seq_length, seq_length]
519
+ if self.config.is_decoder and encoder_hidden_states is not None:
520
+ (
521
+ encoder_batch_size,
522
+ encoder_sequence_length,
523
+ _,
524
+ ) = encoder_hidden_states.size()
525
+ encoder_hidden_shape = (
526
+ encoder_batch_size, encoder_sequence_length)
527
+ if encoder_attention_mask is None:
528
+ encoder_attention_mask = torch.ones(
529
+ encoder_hidden_shape, device=device)
530
+ encoder_extended_attention_mask = self.invert_attention_mask(
531
+ encoder_attention_mask
532
+ )
533
+ else:
534
+ encoder_extended_attention_mask = None
535
+
536
+ # Prepare head mask if needed
537
+ # 1.0 in head_mask indicate we keep the head
538
+ # attention_probs has shape bsz x n_heads x N x N
539
+ # input head_mask has shape [num_heads] or
540
+ # [num_hidden_layers x num_heads]
541
+ # and head_mask is converted to shape
542
+ # [num_hidden_layers x batch x num_heads x seq_length x seq_length]
543
+
544
+ head_mask = self.get_head_mask(
545
+ head_mask, self.config.num_hidden_layers)
546
+
547
+ embedding_output = self.embeddings(
548
+ input_ids,
549
+ input_video_embeds,
550
+ position_ids=position_ids,
551
+ token_type_ids=token_type_ids,
552
+ inputs_embeds=inputs_embeds,
553
+ )
554
+
555
+ if separate_forward_split is not None:
556
+ split_embedding_output = \
557
+ embedding_output[:, :separate_forward_split]
558
+ split_extended_attention_mask = extended_attention_mask[
559
+ :, :, :, :separate_forward_split, :separate_forward_split
560
+ ]
561
+ split_encoder_outputs = self.encoder(
562
+ split_embedding_output,
563
+ attention_mask=split_extended_attention_mask,
564
+ head_mask=head_mask,
565
+ encoder_hidden_states=encoder_hidden_states,
566
+ encoder_attention_mask=encoder_extended_attention_mask,
567
+ output_attentions=output_attentions,
568
+ output_hidden_states=output_hidden_states,
569
+ return_dict=return_dict,
570
+ )
571
+ assert (
572
+ len(split_encoder_outputs) <= 2
573
+ ), "we do not support merge on attention for now."
574
+ encoder_outputs = []
575
+ encoder_outputs.append([split_encoder_outputs[0]])
576
+ if len(split_encoder_outputs) == 2:
577
+ encoder_outputs.append([])
578
+ for _all_hidden_states in split_encoder_outputs[1]:
579
+ encoder_outputs[-1].append([_all_hidden_states])
580
+
581
+ split_embedding_output = \
582
+ embedding_output[:, separate_forward_split:]
583
+ split_extended_attention_mask = extended_attention_mask[
584
+ :, :, :, separate_forward_split:, separate_forward_split:
585
+ ]
586
+
587
+ split_encoder_outputs = self.encoder(
588
+ split_embedding_output,
589
+ attention_mask=split_extended_attention_mask,
590
+ head_mask=head_mask,
591
+ encoder_hidden_states=encoder_hidden_states,
592
+ encoder_attention_mask=encoder_extended_attention_mask,
593
+ output_attentions=output_attentions,
594
+ output_hidden_states=output_hidden_states,
595
+ return_dict=return_dict,
596
+ )
597
+
598
+ assert (
599
+ len(split_encoder_outputs) <= 2
600
+ ), "we do not support merge on attention for now."
601
+ encoder_outputs[0].append(split_encoder_outputs[0])
602
+ encoder_outputs[0] = torch.cat(encoder_outputs[0], dim=1)
603
+ if len(split_encoder_outputs) == 2:
604
+ for layer_idx, _all_hidden_states in enumerate(
605
+ split_encoder_outputs[1]
606
+ ):
607
+ encoder_outputs[1][layer_idx].append(_all_hidden_states)
608
+ encoder_outputs[1][layer_idx] = torch.cat(
609
+ encoder_outputs[1][layer_idx], dim=1
610
+ )
611
+ encoder_outputs = tuple(encoder_outputs)
612
+ else:
613
+ encoder_outputs = self.encoder(
614
+ embedding_output,
615
+ attention_mask=extended_attention_mask,
616
+ head_mask=head_mask,
617
+ encoder_hidden_states=encoder_hidden_states,
618
+ encoder_attention_mask=encoder_extended_attention_mask,
619
+ output_attentions=output_attentions,
620
+ output_hidden_states=output_hidden_states,
621
+ return_dict=return_dict,
622
+ )
623
+
624
+ sequence_output = encoder_outputs[0]
625
+ pooled_output = (
626
+ self.pooler(sequence_output) if self.pooler is not None else None
627
+ )
628
+
629
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
630
+
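+ # Descriptive note on `separate_forward_split` above: when set, the embedded
+ # sequence is cut at that index and the two halves are passed through the
+ # encoder independently, each with the matching diagonal block of the extended
+ # attention mask, and the outputs are re-concatenated along the sequence
+ # dimension. No attention flows between the two halves.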
631
+ def get_extended_attention_mask(self, attention_mask, input_shape, device):
632
+ """This is borrowed from `modeling_utils.py` with the support of
633
+ multi-layer attention masks.
634
+ The second dim is expected to be number of layers.
635
+ See `MMAttentionMaskProcessor`.
636
+ Makes broadcastable attention and causal masks so that future
637
+ and masked tokens are ignored.
638
+
639
+ Arguments:
640
+ attention_mask (:obj:`torch.Tensor`):
641
+ Mask with ones indicating tokens to attend to,
642
+ zeros for tokens to ignore.
643
+ input_shape (:obj:`Tuple[int]`):
644
+ The shape of the input to the model.
645
+ device: (:obj:`torch.device`):
646
+ The device of the input to the model.
647
+
648
+ Returns:
649
+ :obj:`torch.Tensor` The extended attention mask, \
650
+ with the same dtype as :obj:`attention_mask.dtype`.
651
+ """
652
+ # We can provide a self-attention mask of dimensions
653
+ # [batch_size, from_seq_length, to_seq_length]
654
+ # ourselves in which case we just need to make it broadcastable
655
+ # to all heads.
656
+ if attention_mask.dim() == 4:
657
+ extended_attention_mask = attention_mask[:, :, None, :, :]
658
+ extended_attention_mask = extended_attention_mask.to(
659
+ dtype=self.dtype
660
+ ) # fp16 compatibility
661
+ extended_attention_mask = (1.0 - extended_attention_mask) \
662
+ * -10000.0
663
+ return extended_attention_mask
664
+ else:
665
+ return super().get_extended_attention_mask(
666
+ attention_mask, input_shape, device
667
+ )
668
+
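+ # Descriptive example: a per-layer mask of shape
+ #   [batch_size, num_layers, from_seq_length, to_seq_length]
+ # becomes [batch_size, num_layers, 1, from_seq_length, to_seq_length] here,
+ # and MultiLayerAttentionMaskBertEncoder below then selects
+ # attention_mask[:, i] for layer i.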
669
+
670
+ class MultiLayerAttentionMaskBertEncoder(BertEncoder):
671
+ """extend BertEncoder with the capability of
672
+ multiple layers of attention mask."""
673
+
674
+ def forward(
675
+ self,
676
+ hidden_states,
677
+ attention_mask=None,
678
+ head_mask=None,
679
+ encoder_hidden_states=None,
680
+ encoder_attention_mask=None,
681
+ output_attentions=False,
682
+ output_hidden_states=False,
683
+ return_dict=False,
684
+ ):
685
+ all_hidden_states = () if output_hidden_states else None
686
+ all_attentions = () if output_attentions else None
687
+ for i, layer_module in enumerate(self.layer):
688
+ if output_hidden_states:
689
+ all_hidden_states = all_hidden_states + (hidden_states,)
690
+ layer_head_mask = head_mask[i] if head_mask is not None else None
691
+
692
+ layer_attention_mask = (
693
+ attention_mask[:, i, :, :, :]
694
+ if attention_mask.dim() == 5
695
+ else attention_mask
696
+ )
697
+
698
+ if getattr(self.config, "gradient_checkpointing", False):
699
+
700
+ def create_custom_forward(module):
701
+ def custom_forward(*inputs):
702
+ return module(*inputs, output_attentions)
703
+
704
+ return custom_forward
705
+
706
+ layer_outputs = torch.utils.checkpoint.checkpoint(
707
+ create_custom_forward(layer_module),
708
+ hidden_states,
709
+ layer_attention_mask,
710
+ layer_head_mask,
711
+ encoder_hidden_states,
712
+ encoder_attention_mask,
713
+ )
714
+ else:
715
+ layer_outputs = layer_module(
716
+ hidden_states,
717
+ layer_attention_mask,
718
+ layer_head_mask,
719
+ encoder_hidden_states,
720
+ encoder_attention_mask,
721
+ output_attentions,
722
+ )
723
+ hidden_states = layer_outputs[0]
724
+ if output_attentions:
725
+ all_attentions = all_attentions + (layer_outputs[1],)
726
+
727
+ if output_hidden_states:
728
+ all_hidden_states = all_hidden_states + (hidden_states,)
729
+
730
+ return tuple(
731
+ v
732
+ for v in [hidden_states, all_hidden_states, all_attentions]
733
+ if v is not None
734
+ )
fairseq/examples/MMPT/mmpt/modules/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ from .mm import *
6
+
7
+ try:
8
+ from .expmm import *
9
+ except ImportError:
10
+ pass
fairseq/examples/MMPT/mmpt/modules/mm.py ADDED
@@ -0,0 +1,145 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Copyright (c) Facebook, Inc. All Rights Reserved
17
+
18
+
19
+ import torch
20
+
21
+ from torch import nn
22
+
23
+ try:
24
+ from transformers.modeling_bert import (
25
+ BertEmbeddings,
26
+ ACT2FN,
27
+ )
28
+ except ImportError:
29
+ pass
30
+
31
+
32
+ class VideoTokenMLP(nn.Module):
33
+ def __init__(self, config):
34
+ super().__init__()
35
+ input_dim = config.input_dim if hasattr(config, "input_dim") else 512
36
+ self.linear1 = nn.Linear(input_dim, config.hidden_size)
37
+ self.LayerNorm = nn.LayerNorm(config.hidden_size)
38
+ self.activation = ACT2FN[config.hidden_act]
39
+ self.linear2 = nn.Linear(config.hidden_size, config.hidden_size)
40
+
41
+ def forward(self, hidden_states):
42
+ hidden_states = self.linear1(hidden_states)
43
+ hidden_states = self.activation(hidden_states)
44
+ hidden_states = self.LayerNorm(hidden_states)
45
+ hidden_states = self.linear2(hidden_states)
46
+ return hidden_states
47
+
48
+
49
+ class MMBertEmbeddings(BertEmbeddings):
50
+ def __init__(self, config):
51
+ super().__init__(config)
52
+ self.max_video_len = config.max_video_len
53
+ if hasattr(config, "use_seg_emb") and config.use_seg_emb:
54
+ """the original VLM paper uses seg_embeddings for temporal space.
55
+ although not used it changed the randomness of initialization.
56
+ we keep it for reproducibility.
57
+ """
58
+ self.seg_embeddings = nn.Embedding(256, config.hidden_size)
59
+
60
+ def forward(
61
+ self,
62
+ input_ids,
63
+ input_video_embeds,
64
+ token_type_ids=None,
65
+ position_ids=None,
66
+ inputs_embeds=None,
67
+ ):
68
+ input_tensor = input_ids if input_ids is not None else inputs_embeds
69
+ if input_video_embeds is not None:
70
+ input_shape = (
71
+ input_tensor.size(0),
72
+ input_tensor.size(1) + input_video_embeds.size(1),
73
+ )
74
+ else:
75
+ input_shape = (input_tensor.size(0), input_tensor.size(1))
76
+
77
+ if position_ids is None:
78
+ """
79
+ Auto skip position embeddings for text only case.
80
+ use cases:
81
+ (1) action localization and segmentation:
82
+ feed in len-1 dummy video token needs text part to
83
+ skip input_video_embeds.size(1) for the right
84
+ position_ids for video [SEP] and rest text tokens.
85
+ (2) MMFusionShare for two forward passings:
86
+ in `forward_text`: input_video_embeds is None.
87
+ need to skip video [SEP] token.
88
+
89
+ # video_len + 1: [CLS] + video_embed
90
+ # self.max_video_len + 1: [SEP] for video.
91
+ # self.max_video_len + 2: [SEP] for video.
92
+ # self.max_video_len + input_ids.size(1): rest for text.
93
+ """
94
+ if input_video_embeds is not None:
95
+ video_len = input_video_embeds.size(1)
96
+ starting_offset = self.max_video_len + 1 # video [SEP]
97
+ ending_offset = self.max_video_len + input_ids.size(1)
98
+ else:
99
+ video_len = 0
100
+ starting_offset = self.max_video_len + 2 # first text token.
101
+ ending_offset = self.max_video_len + input_ids.size(1) + 1
102
+ position_ids = torch.cat([
103
+ self.position_ids[:, :video_len + 1],
104
+ self.position_ids[:, starting_offset:ending_offset]
105
+ ], dim=1)
106
+
107
+ if token_type_ids is None:
108
+ token_type_ids = torch.zeros(
109
+ input_shape, dtype=torch.long, device=self.position_ids.device
110
+ )
111
+
112
+ """
113
+ the format of input_ids is [CLS] [SEP] caption [SEP] padding.
114
+ the goal is to build [CLS] video tokens [SEP] caption [SEP] .
115
+ """
116
+ if inputs_embeds is None:
117
+ inputs_embeds = self.word_embeddings(input_ids)
118
+ if input_video_embeds is not None:
119
+ inputs_mm_embeds = torch.cat([
120
+ inputs_embeds[:, :1], input_video_embeds, inputs_embeds[:, 1:]
121
+ ], dim=1)
122
+ else:
123
+ # text only for `MMFusionShare`.
124
+ inputs_mm_embeds = inputs_embeds
125
+
126
+ position_embeddings = self.position_embeddings(position_ids)
127
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
128
+ embeddings = inputs_mm_embeds + position_embeddings
129
+ embeddings += token_type_embeddings
130
+
131
+ embeddings = self.LayerNorm(embeddings)
132
+ embeddings = self.dropout(embeddings)
133
+ return embeddings
134
+
135
+
136
+ class AlignHead(nn.Module):
137
+ """this will load pre-trained weights for NSP, which is desirable."""
138
+
139
+ def __init__(self, config):
140
+ super().__init__()
141
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
142
+
143
+ def forward(self, dropout_pooled_output):
144
+ logits = self.seq_relationship(dropout_pooled_output)
145
+ return logits
fairseq/examples/MMPT/mmpt/modules/retri.py ADDED
@@ -0,0 +1,429 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import os
6
+ import numpy as np
7
+ import pickle
8
+ import time
9
+
10
+ try:
11
+ import faiss
12
+ except ImportError:
13
+ pass
14
+
15
+ from collections import defaultdict
16
+
17
+ from ..utils import get_local_rank, print_on_rank0
18
+
19
+
20
+ class VectorRetriever(object):
21
+ """
22
+ How2 Video Retriver.
23
+ Reference usage of FAISS:
24
+ https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py
25
+ """
26
+
27
+ def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train):
28
+ if db_type == "flatl2":
29
+ quantizer = faiss.IndexFlatL2(hidden_size) # the other index
30
+ self.db = faiss.IndexIVFFlat(
31
+ quantizer, hidden_size, cent, faiss.METRIC_L2)
32
+ elif db_type == "pq":
33
+ self.db = faiss.index_factory(
34
+ hidden_size, f"IVF{cent}_HNSW32,PQ32"
35
+ )
36
+ else:
37
+ raise ValueError("unknown type of db", db_type)
38
+ self.train_thres = cent * examples_per_cent_to_train
39
+ self.train_cache = []
40
+ self.train_len = 0
41
+ self.videoid_to_vectoridx = {}
42
+ self.vectoridx_to_videoid = None
43
+ self.make_direct_maps_done = False
44
+
45
+ def make_direct_maps(self):
46
+ faiss.downcast_index(self.db).make_direct_map()
47
+
48
+ def __len__(self):
49
+ return self.db.ntotal
50
+
51
+ def save(self, out_dir):
52
+ faiss.write_index(
53
+ self.db,
54
+ os.path.join(out_dir, "faiss_idx")
55
+ )
56
+ with open(
57
+ os.path.join(
58
+ out_dir, "videoid_to_vectoridx.pkl"),
59
+ "wb") as fw:
60
+ pickle.dump(
61
+ self.videoid_to_vectoridx, fw,
62
+ protocol=pickle.HIGHEST_PROTOCOL
63
+ )
64
+
65
+ def load(self, out_dir):
66
+ fn = os.path.join(out_dir, "faiss_idx")
67
+ self.db = faiss.read_index(fn)
68
+ with open(
69
+ os.path.join(out_dir, "videoid_to_vectoridx.pkl"), "rb") as fr:
70
+ self.videoid_to_vectoridx = pickle.load(fr)
71
+
72
+ def add(self, hidden_states, video_ids, last=False):
73
+ assert len(hidden_states) == len(video_ids), "{}, {}".format(
74
+ str(len(hidden_states)), str(len(video_ids)))
75
+ assert len(hidden_states.shape) == 2
76
+ assert hidden_states.dtype == np.float32
77
+
78
+ valid_idx = []
79
+ for idx, video_id in enumerate(video_ids):
80
+ if video_id not in self.videoid_to_vectoridx:
81
+ valid_idx.append(idx)
82
+ self.videoid_to_vectoridx[video_id] = \
83
+ len(self.videoid_to_vectoridx)
84
+
85
+ hidden_states = hidden_states[valid_idx]
86
+ if not self.db.is_trained:
87
+ self.train_cache.append(hidden_states)
88
+ self.train_len += hidden_states.shape[0]
89
+ if self.train_len < self.train_thres:
90
+ return
91
+ self.finalize_training()
92
+ else:
93
+ self.db.add(hidden_states)
94
+
95
+ def finalize_training(self):
96
+ hidden_states = np.concatenate(self.train_cache, axis=0)
97
+ del self.train_cache
98
+ local_rank = get_local_rank()
99
+ if local_rank == 0:
100
+ start = time.time()
101
+ print("training db on", self.train_thres, "/", self.train_len)
102
+ self.db.train(hidden_states[:self.train_thres])
103
+ if local_rank == 0:
104
+ print("training db for", time.time() - start)
105
+ self.db.add(hidden_states)
106
+
107
+ def search(
108
+ self,
109
+ query_hidden_states,
110
+ orig_dist,
111
+ ):
112
+ if len(self.videoid_to_vectoridx) != self.db.ntotal:
113
+ raise ValueError(
114
+ "cannot search: size mismatch in-between index and db",
115
+ len(self.videoid_to_vectoridx),
116
+ self.db.ntotal
117
+ )
118
+
119
+ if self.vectoridx_to_videoid is None:
120
+ self.vectoridx_to_videoid = {
121
+ self.videoid_to_vectoridx[videoid]: videoid
122
+ for videoid in self.videoid_to_vectoridx
123
+ }
124
+ assert len(self.vectoridx_to_videoid) \
125
+ == len(self.videoid_to_vectoridx)
126
+
127
+ # MultilingualFaissDataset uses the following; not sure the purpose.
128
+ # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
129
+ queried_dist, index = self.db.search(query_hidden_states, 1)
130
+ queried_dist, index = queried_dist[:, 0], index[:, 0]
131
+
132
+ outputs = np.array(
133
+ [self.vectoridx_to_videoid[_index]
134
+ if _index != -1 else (-1, -1, -1) for _index in index],
135
+ dtype=np.int32)
136
+ outputs[queried_dist <= orig_dist] = -1
137
+ return outputs
138
+
139
+ def search_by_video_ids(
140
+ self,
141
+ video_ids,
142
+ retri_factor
143
+ ):
144
+ if len(self.videoid_to_vectoridx) != self.db.ntotal:
145
+ raise ValueError(
146
+ len(self.videoid_to_vectoridx),
147
+ self.db.ntotal
148
+ )
149
+
150
+ if not self.make_direct_maps_done:
151
+ self.make_direct_maps()
152
+
153
+ if self.vectoridx_to_videoid is None:
154
+ self.vectoridx_to_videoid = {
155
+ self.videoid_to_vectoridx[videoid]: videoid
156
+ for videoid in self.videoid_to_vectoridx
157
+ }
158
+ assert len(self.vectoridx_to_videoid) \
159
+ == len(self.videoid_to_vectoridx)
160
+
161
+ query_hidden_states = []
162
+ vector_ids = []
163
+ for video_id in video_ids:
164
+ vector_id = self.videoid_to_vectoridx[video_id]
165
+ vector_ids.append(vector_id)
166
+ query_hidden_state = self.db.reconstruct(vector_id)
167
+ query_hidden_states.append(query_hidden_state)
168
+ query_hidden_states = np.stack(query_hidden_states)
169
+
170
+ # MultilingualFaissDataset uses the following; not sure the reason.
171
+ # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
172
+ _, index = self.db.search(query_hidden_states, retri_factor)
173
+ outputs = []
174
+ for sample_idx, sample in enumerate(index):
175
+ # the first video_id is always the video itself.
176
+ cands = [video_ids[sample_idx]]
177
+ for vector_idx in sample:
178
+ if vector_idx >= 0 \
179
+ and vector_ids[sample_idx] != vector_idx:
180
+ cands.append(
181
+ self.vectoridx_to_videoid[vector_idx]
182
+ )
183
+ outputs.append(cands)
184
+ return outputs
185
+
186
+
187
+ class VectorRetrieverDM(VectorRetriever):
188
+ """
189
+ with direct map.
190
+ How2 Video Retriver.
191
+ Reference usage of FAISS:
192
+ https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py
193
+ """
194
+
195
+ def __init__(
196
+ self,
197
+ hidden_size,
198
+ cent,
199
+ db_type,
200
+ examples_per_cent_to_train
201
+ ):
202
+ super().__init__(
203
+ hidden_size, cent, db_type, examples_per_cent_to_train)
204
+ self.make_direct_maps_done = False
205
+
206
+ def make_direct_maps(self):
207
+ faiss.downcast_index(self.db).make_direct_map()
208
+ self.make_direct_maps_done = True
209
+
210
+ def search(
211
+ self,
212
+ query_hidden_states,
213
+ orig_dist,
214
+ ):
215
+ if len(self.videoid_to_vectoridx) != self.db.ntotal:
216
+ raise ValueError(
217
+ len(self.videoid_to_vectoridx),
218
+ self.db.ntotal
219
+ )
220
+
221
+ if not self.make_direct_maps_done:
222
+ self.make_direct_maps()
223
+ if self.vectoridx_to_videoid is None:
224
+ self.vectoridx_to_videoid = {
225
+ self.videoid_to_vectoridx[videoid]: videoid
226
+ for videoid in self.videoid_to_vectoridx
227
+ }
228
+ assert len(self.vectoridx_to_videoid) \
229
+ == len(self.videoid_to_vectoridx)
230
+
231
+ # MultilingualFaissDataset uses the following; not sure the reason.
232
+ # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
233
+ queried_dist, index = self.db.search(query_hidden_states, 1)
234
+ outputs = []
235
+ for sample_idx, sample in enumerate(index):
236
+ # and queried_dist[sample_idx] < thres \
237
+ if sample >= 0 \
238
+ and queried_dist[sample_idx] < orig_dist[sample_idx]:
239
+ outputs.append(self.vectoridx_to_videoid[sample])
240
+ else:
241
+ outputs.append(None)
242
+ return outputs
243
+
244
+ def search_by_video_ids(
245
+ self,
246
+ video_ids,
247
+ retri_factor=8
248
+ ):
249
+ if len(self.videoid_to_vectoridx) != self.db.ntotal:
250
+ raise ValueError(
251
+ len(self.videoid_to_vectoridx),
252
+ self.db.ntotal
253
+ )
254
+
255
+ if not self.make_direct_maps_done:
256
+ self.make_direct_maps()
257
+ if self.vectoridx_to_videoid is None:
258
+ self.vectoridx_to_videoid = {
259
+ self.videoid_to_vectoridx[videoid]: videoid
260
+ for videoid in self.videoid_to_vectoridx
261
+ }
262
+ assert len(self.vectoridx_to_videoid) \
263
+ == len(self.videoid_to_vectoridx)
264
+
265
+ query_hidden_states = []
266
+ vector_ids = []
267
+ for video_id in video_ids:
268
+ vector_id = self.videoid_to_vectoridx[video_id]
269
+ vector_ids.append(vector_id)
270
+ query_hidden_state = self.db.reconstruct(vector_id)
271
+ query_hidden_states.append(query_hidden_state)
272
+ query_hidden_states = np.stack(query_hidden_states)
273
+
274
+ # MultilingualFaissDataset uses the following; not sure the reason.
275
+ # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
276
+ _, index = self.db.search(query_hidden_states, retri_factor)
277
+ outputs = []
278
+ for sample_idx, sample in enumerate(index):
279
+ # the first video_id is always the video itself.
280
+ cands = [video_ids[sample_idx]]
281
+ for vector_idx in sample:
282
+ if vector_idx >= 0 \
283
+ and vector_ids[sample_idx] != vector_idx:
284
+ cands.append(
285
+ self.vectoridx_to_videoid[vector_idx]
286
+ )
287
+ outputs.append(cands)
288
+ return outputs
289
+
290
+
291
+ class MMVectorRetriever(VectorRetrieverDM):
292
+ """
293
+ multimodal vector retriver:
294
+ text retrieve video or video retrieve text.
295
+ """
296
+
297
+ def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train):
298
+ super().__init__(
299
+ hidden_size, cent, db_type, examples_per_cent_to_train)
300
+ video_db = self.db
301
+ super().__init__(
302
+ hidden_size, cent, db_type, examples_per_cent_to_train)
303
+ text_db = self.db
304
+ self.db = {"video": video_db, "text": text_db}
305
+ self.video_to_videoid = defaultdict(list)
306
+
307
+ def __len__(self):
308
+ assert self.db["video"].ntotal == self.db["text"].ntotal
309
+ return self.db["video"].ntotal
310
+
311
+ def make_direct_maps(self):
312
+ faiss.downcast_index(self.db["video"]).make_direct_map()
313
+ faiss.downcast_index(self.db["text"]).make_direct_map()
314
+
315
+ def save(self, out_dir):
316
+ faiss.write_index(
317
+ self.db["video"],
318
+ os.path.join(out_dir, "video_faiss_idx")
319
+ )
320
+ faiss.write_index(
321
+ self.db["text"],
322
+ os.path.join(out_dir, "text_faiss_idx")
323
+ )
324
+
325
+ with open(
326
+ os.path.join(
327
+ out_dir, "videoid_to_vectoridx.pkl"),
328
+ "wb") as fw:
329
+ pickle.dump(
330
+ self.videoid_to_vectoridx, fw,
331
+ protocol=pickle.HIGHEST_PROTOCOL
332
+ )
333
+
334
+ def load(self, out_dir):
335
+ fn = os.path.join(out_dir, "video_faiss_idx")
336
+ video_db = faiss.read_index(fn)
337
+ fn = os.path.join(out_dir, "text_faiss_idx")
338
+ text_db = faiss.read_index(fn)
339
+ self.db = {"video": video_db, "text": text_db}
340
+ with open(
341
+ os.path.join(out_dir, "videoid_to_vectoridx.pkl"), "rb") as fr:
342
+ self.videoid_to_vectoridx = pickle.load(fr)
343
+ self.video_to_videoid = defaultdict(list)
344
+
345
+ def add(self, hidden_states, video_ids):
346
+ """hidden_states is a pair `(video, text)`"""
347
+ assert len(hidden_states) == len(video_ids), "{}, {}".format(
348
+ str(len(hidden_states)), str(len(video_ids)))
349
+ assert len(hidden_states.shape) == 3
350
+ assert len(self.video_to_videoid) == 0
351
+
352
+ valid_idx = []
353
+ for idx, video_id in enumerate(video_ids):
354
+ if video_id not in self.videoid_to_vectoridx:
355
+ valid_idx.append(idx)
356
+ self.videoid_to_vectoridx[video_id] = \
357
+ len(self.videoid_to_vectoridx)
358
+
359
+ batch_size = hidden_states.shape[0]
360
+ hidden_states = hidden_states[valid_idx]
361
+
362
+ hidden_states = np.transpose(hidden_states, (1, 0, 2)).copy()
363
+ if not self.db["video"].is_trained:
364
+ self.train_cache.append(hidden_states)
365
+ train_len = batch_size * len(self.train_cache)
366
+ if train_len < self.train_thres:
367
+ return
368
+
369
+ hidden_states = np.concatenate(self.train_cache, axis=1)
370
+ del self.train_cache
371
+ self.db["video"].train(hidden_states[0, :self.train_thres])
372
+ self.db["text"].train(hidden_states[1, :self.train_thres])
373
+ self.db["video"].add(hidden_states[0])
374
+ self.db["text"].add(hidden_states[1])
375
+
376
+ def get_clips_by_video_id(self, video_id):
377
+ if not self.video_to_videoid:
378
+ for video_id, video_clip, text_clip in self.videoid_to_vectoridx:
379
+ self.video_to_videoid[video_id].append(
380
+ (video_id, video_clip, text_clip))
381
+ return self.video_to_videoid[video_id]
382
+
383
+ def search(
384
+ self,
385
+ video_ids,
386
+ target_modality,
387
+ retri_factor=8
388
+ ):
389
+ if len(self.videoid_to_vectoridx) != len(self):
390
+ raise ValueError(
391
+ len(self.videoid_to_vectoridx),
392
+ len(self)
393
+ )
394
+
395
+ if not self.make_direct_maps_done:
396
+ self.make_direct_maps()
397
+ if self.vectoridx_to_videoid is None:
398
+ self.vectoridx_to_videoid = {
399
+ self.videoid_to_vectoridx[videoid]: videoid
400
+ for videoid in self.videoid_to_vectoridx
401
+ }
402
+ assert len(self.vectoridx_to_videoid) \
403
+ == len(self.videoid_to_vectoridx)
404
+
405
+ src_modality = "text" if target_modality == "video" else "video"
406
+
407
+ query_hidden_states = []
408
+ vector_ids = []
409
+ for video_id in video_ids:
410
+ vector_id = self.videoid_to_vectoridx[video_id]
411
+ vector_ids.append(vector_id)
412
+ query_hidden_state = self.db[src_modality].reconstruct(vector_id)
413
+ query_hidden_states.append(query_hidden_state)
414
+ query_hidden_states = np.stack(query_hidden_states)
415
+
416
+ # MultilingualFaissDataset uses the following; not sure the reason.
417
+ # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
418
+ _, index = self.db[target_modality].search(
419
+ query_hidden_states, retri_factor)
420
+ outputs = []
421
+ for sample_idx, sample in enumerate(index):
422
+ cands = []
423
+ for vector_idx in sample:
424
+ if vector_idx >= 0:
425
+ cands.append(
426
+ self.vectoridx_to_videoid[vector_idx]
427
+ )
428
+ outputs.append(cands)
429
+ return outputs
fairseq/examples/MMPT/mmpt/modules/vectorpool.py ADDED
@@ -0,0 +1,246 @@
1
+ # Copyright (c) Facebook, Inc. All Rights Reserved
2
+
3
+ import torch
4
+ import os
5
+ import numpy as np
6
+ import pickle
7
+
8
+ from . import retri
9
+ from ..utils import get_local_rank
10
+
11
+
12
+ class VectorPool(object):
13
+ """
14
+ Base class of retrieval space.
15
+ """
16
+
17
+ def __init__(self, config):
18
+ from transformers import AutoConfig
19
+ self.hidden_size = AutoConfig.from_pretrained(
20
+ config.dataset.bert_name).hidden_size
21
+ self.retriever_cls = getattr(retri, config.retriever_cls)
22
+
23
+ def __call__(self, sample, **kwargs):
24
+ raise NotImplementedError
25
+
26
+ def build_retriver(
27
+ self,
28
+ retriever_cls=None,
29
+ hidden_size=None,
30
+ centroids=512,
31
+ db_type="flatl2",
32
+ examples_per_cent_to_train=48
33
+ ):
34
+
35
+ """merge results from multiple gpus and return a retriver.."""
36
+ self.retriver = retriever_cls(
37
+ hidden_size, centroids, db_type, examples_per_cent_to_train)
38
+ return self.retriver
39
+
40
+ def __repr__(self):
41
+ if hasattr(self, "retriver"):
42
+ retriver_name = str(len(self.retriver))
43
+ else:
44
+ retriver_name = "no retriver field yet"
45
+ return self.__class__.__name__ \
46
+ + "(" + retriver_name + ")"
47
+
48
+
49
+ class VideoVectorPool(VectorPool):
50
+ """
51
+ average the clips of a video as the video representation.
52
+ """
53
+ def __init__(self, config):
54
+ super().__init__(config)
55
+ self.build_retriver(self.retriever_cls, self.hidden_size)
56
+
57
+ def __call__(self, sample, subsampling, **kwargs):
58
+ hidden_states = (
59
+ sample["pooled_video"] + sample["pooled_text"]) / 2.
60
+ hidden_states = hidden_states.view(
61
+ -1, subsampling,
62
+ hidden_states.size(-1))
63
+ hidden_states = torch.mean(hidden_states, dim=1)
64
+ hidden_states = hidden_states.cpu().detach().numpy()
65
+ video_ids = []
66
+ for offset_idx, video_id in enumerate(sample["video_id"]):
67
+ if isinstance(video_id, tuple) and len(video_id) == 3:
68
+ # a sharded video_id.
69
+ video_id = video_id[0]
70
+ video_ids.append(video_id)
71
+ assert len(video_ids) == len(hidden_states)
72
+ self.retriver.add(
73
+ hidden_states.astype("float32"),
74
+ video_ids
75
+ )
76
+
77
+
78
+ class DistributedVectorPool(VectorPool):
79
+ """
80
+ supports syncing across multiple gpus/nodes.
81
+ """
82
+ def __init__(self, config):
83
+ super().__init__(config)
84
+ self.out_dir = os.path.join(
85
+ config.fairseq.checkpoint.save_dir,
86
+ "retri")
87
+ os.makedirs(self.out_dir, exist_ok=True)
88
+ self.hidden_states = []
89
+ self.video_ids = []
90
+
91
+ def build_retriver(
92
+ self,
93
+ retriever_cls=None,
94
+ hidden_size=None,
95
+ centroids=4096,
96
+ db_type="flatl2",
97
+ examples_per_cent_to_train=48
98
+ ):
99
+ if retriever_cls is None:
100
+ retriever_cls = self.retriever_cls
101
+ if hidden_size is None:
102
+ hidden_size = self.hidden_size
103
+ """merge results from multiple gpus and return a retriver.."""
104
+ if torch.distributed.is_initialized():
105
+ self.save()
106
+ # sync saving.
107
+ torch.distributed.barrier()
108
+ world_size = torch.distributed.get_world_size()
109
+ else:
110
+ world_size = 1
111
+ self.retriver = retriever_cls(
112
+ hidden_size, centroids, db_type, examples_per_cent_to_train)
113
+ # each gpu process has its own retriever.
114
+ for local_rank in range(world_size):
115
+ if get_local_rank() == 0:
116
+ print("load local_rank", local_rank)
117
+ hidden_states, video_ids = self.load(local_rank)
118
+ hidden_states = hidden_states.astype("float32")
119
+ self.retriver.add(hidden_states, video_ids)
120
+ return self.retriver
121
+
122
+ def load(self, local_rank):
123
+ hidden_states = np.load(
124
+ os.path.join(
125
+ self.out_dir,
126
+ "hidden_state" + str(local_rank) + ".npy"
127
+ )
128
+ )
129
+
130
+ with open(
131
+ os.path.join(
132
+ self.out_dir, "video_id" + str(local_rank) + ".pkl"),
133
+ "rb") as fr:
134
+ video_ids = pickle.load(fr)
135
+ return hidden_states, video_ids
136
+
137
+ def save(self):
138
+ hidden_states = np.vstack(self.hidden_states)
139
+ assert len(hidden_states) == len(self.video_ids), "{}, {}".format(
140
+ len(hidden_states),
141
+ len(self.video_ids)
142
+ )
143
+ local_rank = torch.distributed.get_rank() \
144
+ if torch.distributed.is_initialized() else 0
145
+
146
+ np.save(
147
+ os.path.join(
148
+ self.out_dir,
149
+ "hidden_state" + str(local_rank) + ".npy"),
150
+ hidden_states)
151
+
152
+ with open(
153
+ os.path.join(
154
+ self.out_dir,
155
+ "video_id" + str(local_rank) + ".pkl"),
156
+ "wb") as fw:
157
+ pickle.dump(
158
+ self.video_ids,
159
+ fw,
160
+ protocol=pickle.HIGHEST_PROTOCOL
161
+ )
162
+
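+ # Descriptive note: each rank writes its own shard as
+ # `hidden_state{rank}.npy` / `video_id{rank}.pkl` under <save_dir>/retri, and
+ # build_retriver() above barriers, then loads every rank's shard into a single
+ # retriever so all processes end up with the same index.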
163
+
164
+ class DistributedVideoVectorPool(DistributedVectorPool):
165
+ """
166
+ average the clips of a video as the video representation.
167
+ """
168
+ def __call__(self, sample, subsampling, **kwargs):
169
+ hidden_states = (
170
+ sample["pooled_video"] + sample["pooled_text"]) / 2.
171
+ hidden_states = hidden_states.view(
172
+ -1, subsampling,
173
+ hidden_states.size(-1))
174
+ hidden_states = torch.mean(hidden_states, dim=1)
175
+ hidden_states = hidden_states.cpu().detach().numpy()
176
+ video_ids = []
177
+ for offset_idx, video_id in enumerate(sample["video_id"]):
178
+ if isinstance(video_id, tuple) and len(video_id) == 3:
179
+ # a sharded video_id.
180
+ video_id = video_id[0]
181
+ video_ids.append(video_id)
182
+ assert len(video_ids) == len(hidden_states)
183
+ self.hidden_states.append(hidden_states)
184
+ self.video_ids.extend(video_ids)
185
+
186
+
187
+ # ------------ the following are deprecated --------------
188
+
189
+ class TextClipVectorPool(VectorPool):
190
+ def __init__(self, config):
191
+ from transformers import AutoConfig
192
+ hidden_size = AutoConfig.from_pretrained(
193
+ config.dataset.bert_name).hidden_size
194
+ retriever_cls = getattr(retri, config.retriever_cls)
195
+ self.build_retriver(retriever_cls, hidden_size)
196
+
197
+ def __call__(self, sample, **kwargs):
198
+ clip_meta = sample["clip_meta"].cpu()
199
+ assert torch.all(torch.le(clip_meta[:, 4], clip_meta[:, 5]))
200
+ text_meta = [tuple(item.tolist()) for item in clip_meta[:, 3:]]
201
+
202
+ if hasattr(self, "retriver"):
203
+ # build_retriver is called.
204
+ self.retriver.add(
205
+ sample["pooled_text"].cpu().numpy().astype("float32"),
206
+ text_meta
207
+ )
208
+ else:
209
+ raise NotImplementedError
210
+
211
+
212
+ class MMClipVectorPool(VectorPool):
213
+ """
214
+ Multimodal Clip-level vector pool.
215
+ """
216
+ def __init__(self, out_dir):
217
+ """use hidden_states to store `(video, text)`."""
218
+ """use video_ids to store `(video_id, start, end)`."""
219
+ super().__init__(out_dir)
220
+
221
+ def __call__(self, sample, **kwargs):
222
+ pooled_video = sample["pooled_video"].cpu().unsqueeze(1).numpy()
223
+ pooled_text = sample["pooled_text"].cpu().unsqueeze(1).numpy()
224
+
225
+ self.hidden_states.append(
226
+ np.concatenate([pooled_video, pooled_text], axis=1)
227
+ )
228
+
229
+ video_starts = sample["video_start"].cpu()
230
+ video_ends = sample["video_end"].cpu()
231
+ assert torch.all(torch.le(video_starts, video_ends))
232
+
233
+ text_starts = sample["text_start"].cpu()
234
+ text_ends = sample["text_end"].cpu()
235
+ assert torch.all(torch.le(text_starts, text_ends))
236
+ subsample_size = sample["pooled_video"].size(0) // len(sample["video_id"])
237
+ video_ids = [video_id for video_id in sample["video_id"]
238
+ for _ in range(subsample_size)
239
+ ]
240
+ for video_id, video_start, video_end, text_start, text_end in zip(
241
+ video_ids, video_starts, video_ends, text_starts, text_ends):
242
+ self.video_ids.append((
243
+ video_id,
244
+ (int(video_start), int(video_end)),
245
+ (int(text_start), int(text_end))
246
+ ))
fairseq/examples/MMPT/mmpt/processors/dedupprocessor.py ADDED
@@ -0,0 +1,242 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import random
7
+ import json
8
+ import pickle
9
+ from tqdm import tqdm
10
+ import os
11
+ import numpy as np
12
+
13
+
14
+ class CaptionDedupProcessor(object):
15
+ """remove overlapping of caption sentences(clip).
16
+ Some statistics:
17
+ caption:
18
+ {'t_clip_len': 246.6448431320854,
19
+ 'video_len': 281.09174795676245,
20
+ 'clip_tps': 0.8841283727427481,
21
+ 'video_tps': 0.7821156477732097,
22
+ 'min_clip_len': 0.0,
23
+ 'max_clip_len': 398.3,
24
+ 'mean_clip_len': 3.196580003006861,
25
+ 'num_clip': 77.15897706301081}
26
+
27
+ raw_caption:
28
+ {'t_clip_len': 238.95908778424115,
29
+ 'video_len': 267.5914859862507,
30
+ 'clip_tps': 2.4941363624267963,
31
+ 'video_tps': 2.258989769647173,
32
+ 'min_clip_len': 0.0,
33
+ 'max_clip_len': 398.3,
34
+ 'mean_clip_len': 3.0537954186814265,
35
+ 'num_clip': 78.24986779481756}
36
+ """
37
+
38
+ def __init__(self, pkl_file):
39
+ with open(pkl_file, "rb") as fd:
40
+ self.data = pickle.load(fd)
41
+ self.stat = {
42
+ "t_clip_len": [],
43
+ "video_len": [],
44
+ "clip_tps": [],
45
+ "video_tps": [],
46
+ "clip_len": [],
47
+ }
48
+
49
+ def __call__(self):
50
+ for idx, video_id in enumerate(tqdm(self.data)):
51
+ caption = json.loads(self.data[video_id])
52
+ caption = self._dedup(caption)
53
+ if idx < 4096: # for the first 4096 examples, compute the statistics.
54
+ self.save_stat(video_id, caption)
55
+ self.data[video_id] = json.dumps(caption)
56
+ self.print_stat()
57
+
58
+ def single(self, video_id):
59
+ caption = json.loads(self.data[video_id])
60
+ for clip_idx, (start, end, text) in enumerate(
61
+ zip(caption["start"], caption["end"], caption["text"])
62
+ ):
63
+ print(start, end, text)
64
+ print("@" * 100)
65
+ caption = self._dedup(caption)
66
+ for clip_idx, (start, end, text) in enumerate(
67
+ zip(caption["start"], caption["end"], caption["text"])
68
+ ):
69
+ print(start, end, text)
70
+ print("#" * 100)
71
+ self.save_stat(video_id, caption)
72
+ self.print_stat()
73
+
74
+ def finalize(self, tgt_fn):
75
+ with open(tgt_fn, "wb") as fw:
76
+ pickle.dump(self.data, fw, pickle.HIGHEST_PROTOCOL)
77
+
78
+ def save_stat(self, video_id, caption):
79
+ video_fn = os.path.join(
80
+ "data/feat/feat_how2_s3d", video_id + ".npy"
81
+ )
82
+ if os.path.isfile(video_fn):
83
+ with open(video_fn, "rb", 1) as fr: # 24 is the buffer size. buffered
84
+ version = np.lib.format.read_magic(fr)
85
+ shape, fortran, dtype = np.lib.format._read_array_header(fr, version)
86
+ video_len = shape[0]
87
+
88
+ t_clip_len = 0.0
89
+ t_tokens = 0
90
+ for idx, (start, end, text) in enumerate(
91
+ zip(caption["start"], caption["end"], caption["text"])
92
+ ):
93
+ clip_len = (
94
+ (end - max(caption["end"][idx - 1], start))
95
+ if idx > 0
96
+ else end - start
97
+ )
98
+ t_clip_len += clip_len
99
+ t_tokens += len(text.split(" "))
100
+ self.stat["clip_len"].append(clip_len)
101
+ self.stat["t_clip_len"].append(t_clip_len)
102
+ self.stat["video_len"].append(video_len)
103
+ self.stat["clip_tps"].append(t_tokens / t_clip_len)
104
+ self.stat["video_tps"].append(t_tokens / video_len)
105
+
106
+ def print_stat(self):
107
+ result = {
108
+ "t_clip_len": np.mean(self.stat["t_clip_len"]),
109
+ "video_len": np.mean(self.stat["video_len"]),
110
+ "clip_tps": np.mean(self.stat["clip_tps"]),
111
+ "video_tps": np.mean(self.stat["video_tps"]),
112
+ "min_clip_len": min(self.stat["clip_len"]),
113
+ "max_clip_len": max(self.stat["clip_len"]),
114
+ "mean_clip_len": np.mean(self.stat["clip_len"]),
115
+ "num_clip": len(self.stat["clip_len"]) / len(self.stat["video_tps"]),
116
+ }
117
+ print(result)
118
+
119
+ def _dedup(self, caption):
120
+ def random_merge(end_idx, start, end, text, starts, ends, texts):
121
+ if random.random() > 0.5:
122
+ # print(clip_idx, "[PARTIAL INTO PREV]", end_idx)
123
+ # overlapped part goes to the end of previous.
124
+ ends[-1] = max(ends[-1], start) # ?
125
+ rest_text = text[end_idx:].strip()
126
+ if rest_text:
127
+ starts.append(max(ends[-1], start))
128
+ ends.append(max(end, starts[-1]))
129
+ texts.append(rest_text)
130
+ else: # goes to the beginning of the current.
131
+ # strip the previous.
132
+ left_text = texts[-1][:-end_idx].strip()
133
+ if left_text:
134
+ # print(clip_idx, "[PREV PARTIAL INTO CUR]", end_idx)
135
+ ends[-1] = min(ends[-1], start)
136
+ texts[-1] = left_text
137
+ else:
138
+ # print(clip_idx, "[PREV LEFT NOTHING ALL INTO CUR]", end_idx)
139
+ starts.pop(-1)
140
+ ends.pop(-1)
141
+ texts.pop(-1)
142
+ starts.append(start)
143
+ ends.append(end)
144
+ texts.append(text)
145
+
146
+ starts, ends, texts = [], [], []
147
+ for clip_idx, (start, end, text) in enumerate(
148
+ zip(caption["start"], caption["end"], caption["text"])
149
+ ):
150
+ if not isinstance(text, str):
151
+ continue
152
+ text = text.replace("\n", " ").strip()
153
+ if len(text) == 0:
154
+ continue
155
+ starts.append(start)
156
+ ends.append(end)
157
+ texts.append(text)
158
+ break
159
+
160
+ for clip_idx, (start, end, text) in enumerate(
161
+ zip(
162
+ caption["start"][clip_idx + 1:],
163
+ caption["end"][clip_idx + 1:],
164
+ caption["text"][clip_idx + 1:],
165
+ )
166
+ ):
167
+ if not isinstance(text, str):
168
+ continue
169
+ text = text.replace("\n", " ").strip()
170
+ if len(text) == 0:
171
+ continue
172
+
173
+ # print(clip_idx, texts[-5:])
174
+ # print(clip_idx, start, end, text)
175
+ if texts[-1].endswith(text): # subset of prev caption -> merge
176
+ # print(clip_idx, "[MERGE INTO PREV]")
177
+ ends[-1] = max(ends[-1], end)
178
+ elif text.startswith(texts[-1]): # superset of prev caption -> merge
179
+ # print(clip_idx, "[PREV MERGE INTO CUR]")
180
+ texts[-1] = text
181
+ starts[-1] = min(starts[-1], start)
182
+ ends[-1] = max(ends[-1], end)
183
+ else: # overlapping or non-overlapping.
184
+ for end_idx in range(1, len(text) + 1):
185
+ if texts[-1].endswith(text[:end_idx]):
186
+ random_merge(end_idx, start, end, text, starts, ends, texts)
187
+ break
188
+ else:
189
+ starts.append(start)
190
+ ends.append(end)
191
+ texts.append(text)
192
+
193
+ assert (ends[-1] + 0.001) >= starts[-1] and len(
194
+ texts[-1]
195
+ ) > 0, "{} {} {} <- {} {} {}, {} {} {}".format(
196
+ str(starts[-1]),
197
+ str(ends[-1]),
198
+ texts[-1],
199
+ caption["start"][clip_idx - 1],
200
+ caption["end"][clip_idx - 1],
201
+ caption["text"][clip_idx - 1],
202
+ str(start),
203
+ str(end),
204
+ text,
205
+ )
206
+
207
+ return {"start": starts, "end": ends, "text": texts}
208
+
209
+
210
+ if __name__ == "__main__":
211
+ import argparse
212
+
213
+ parser = argparse.ArgumentParser(description="dedup how2 caption")
214
+ parser.add_argument('--how2dir', default="data/how2")
215
+ args = parser.parse_args()
216
+
217
+ raw_caption_json = os.path.join(args.how2dir, "raw_caption.json")
218
+ raw_caption_pickle = os.path.join(args.how2dir, "raw_caption.pkl")
219
+ raw_caption_dedup_pickle = os.path.join(args.how2dir, "raw_caption_dedup.pkl")
220
+
221
+ def convert_to_pickle(src_fn, tgt_fn):
222
+ with open(src_fn) as fd:
223
+ captions = json.load(fd)
224
+
225
+ for video_id in captions:
226
+ captions[video_id] = json.dumps(captions[video_id])
227
+
228
+ with open(tgt_fn, "wb") as fw:
229
+ pickle.dump(captions, fw, pickle.HIGHEST_PROTOCOL)
230
+
231
+ if not os.path.isfile(raw_caption_pickle):
232
+ convert_to_pickle(raw_caption_json, raw_caption_pickle)
233
+
234
+ deduper = CaptionDedupProcessor(raw_caption_pickle)
235
+ deduper()
236
+ deduper.finalize(raw_caption_dedup_pickle)
237
+
238
+ """
239
+ # demo
240
+ deduper = CaptionDedupProcessor("data/how2/raw_caption.pkl")
241
+ deduper.single("HfIeQ9pzL5U")
242
+ """
fairseq/examples/MMPT/mmpt/processors/how2processor.py ADDED
@@ -0,0 +1,887 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Copyright (c) Facebook, Inc. All Rights Reserved
17
+
18
+
19
+ import torch
20
+ import math
21
+ import pickle
22
+ import random
23
+ import os
24
+ import numpy as np
25
+
26
+ from collections import deque
27
+ from typing import Optional, Tuple, List
28
+ from .processor import (
29
+ Processor,
30
+ MetaProcessor,
31
+ TextProcessor,
32
+ Aligner,
33
+ MMAttentionMask2DProcessor
34
+ )
35
+
36
+ from ..utils import ShardedTensor
37
+
38
+
39
+ class How2MetaProcessor(MetaProcessor):
40
+ def __init__(self, config):
41
+ super().__init__(config)
42
+ path = self._get_split_path(config)
43
+ with open(path) as fd:
44
+ self.data = [line.strip() for line in fd]
45
+
46
+ def __getitem__(self, idx):
47
+ video_id = self.data[idx]
48
+ return video_id, video_id
49
+
50
+
51
+ class ShardedHow2MetaProcessor(How2MetaProcessor):
52
+ def __init__(self, config):
53
+ super().__init__(config)
54
+ self.split = str(config.split)
55
+ self.vfeat_dir = config.vfeat_dir
56
+ self._init_shard()
57
+
58
+ def _init_shard(self):
59
+ if self.split == "train":
60
+ meta_fn = os.path.join(self.vfeat_dir, "train" + "_meta.pkl")
61
+ with open(meta_fn, "rb") as fr:
62
+ meta = pickle.load(fr)
63
+ elif self.split == "valid":
64
+ meta_fn = os.path.join(self.vfeat_dir, "val" + "_meta.pkl")
65
+ with open(meta_fn, "rb") as fr:
66
+ meta = pickle.load(fr)
67
+ elif self.split == "test":
68
+ print("use how2 val as test.")
69
+ meta_fn = os.path.join(self.vfeat_dir, "val" + "_meta.pkl")
70
+ with open(meta_fn, "rb") as fr:
71
+ meta = pickle.load(fr)
72
+ else:
73
+ raise ValueError("unsupported for MetaProcessor:", self.split)
74
+ video_id_to_shard = {}
75
+ for shard_id in meta:
76
+ for video_idx, video_id in enumerate(meta[shard_id]):
77
+ video_id_to_shard[video_id] = (shard_id, video_idx)
78
+ self.video_id_to_shard = video_id_to_shard
79
+
80
+ def __getitem__(self, idx):
81
+ video_id, video_id = super().__getitem__(idx)
82
+ shard_id, shard_idx = self.video_id_to_shard[video_id]
83
+ meta = (video_id, idx, shard_id, shard_idx)
84
+ return meta, meta
85
+
86
+
87
+ class ShardedVideoProcessor(Processor):
88
+ """
89
+ mmaped shards of numpy video features.
90
+ """
91
+
92
+ def __init__(self, config):
93
+ self.split = str(config.split)
94
+ self.vfeat_dir = config.vfeat_dir
95
+
96
+ def __call__(self, video_id):
97
+ _, _, shard_id, video_idx = video_id
98
+ if self.split == "train":
99
+ shard = ShardedTensor.load(
100
+ os.path.join(self.vfeat_dir, "train" + "_" + str(shard_id)),
101
+ "r"
102
+ )
103
+ elif self.split == "valid":
104
+ shard = ShardedTensor.load(
105
+ os.path.join(self.vfeat_dir, "val" + "_" + str(shard_id)),
106
+ "r"
107
+ )
108
+ elif self.split == "test":
109
+ shard = ShardedTensor.load(
110
+ os.path.join(self.vfeat_dir, "val" + "_" + str(shard_id)),
111
+ "r"
112
+ )
113
+ else:
114
+ raise ValueError("unknown split", self.split)
115
+ feat = shard[video_idx]
116
+ return feat
117
+
118
+
119
+ class ShardedTextProcessor(Processor):
120
+ def __init__(self, config):
121
+ self.tfeat_dir = str(config.tfeat_dir)
122
+ self.split = str(config.split)
123
+
124
+ def __call__(self, video_id):
125
+ _, _, shard_id, shard_idx = video_id
126
+ if self.split == "train":
127
+ target_path = self.tfeat_dir + "train" + "_" + str(shard_id)
128
+ elif self.split == "valid":
129
+ target_path = self.tfeat_dir + "val" + "_" + str(shard_id)
130
+ elif self.split == "test":
131
+ target_path = self.tfeat_dir + "val" + "_" + str(shard_id)
132
+ else:
133
+ raise ValueError("unknown split", self.split)
134
+
135
+ startend = ShardedTensor.load(
136
+ target_path + ".startends", "r")[shard_idx]
137
+ cap_ids = ShardedTensor.load(
138
+ target_path + ".caps_ids", "r")[shard_idx]
139
+ cap = []
140
+ for clip_idx in range(len(cap_ids)):
141
+ clip = cap_ids[clip_idx]
142
+ cap.append(clip[clip != -1].tolist())
143
+ start, end = startend[:, 0].tolist(), startend[:, 1].tolist()
144
+ return {"start": start, "end": end, "cap": cap}
145
+
146
+
147
+ class FixedLenAligner(Aligner):
148
+ """
149
+ In the model we assume text is on the left (closer to BERT formulation)
150
+ and video is on the right.
151
+ We fix the total length of text + video.
152
+ max_video_len is in number of secs.
153
+ max_text_len is in number of tokens.
154
+
155
+ special tokens formats:
156
+ we use the format [CLS] [SEP] text tokens [SEP] [PAD] ...
157
+ [CLS] will be split out into:
158
+ [CLS] video tokens [SEP] text tokens [SEP] [PAD] ...
159
+ token_type_ids will be generated by the model (for now).
160
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
161
+ | first sequence | second sequence |
162
+ so each sequence owns a [SEP] token for no-ops.
163
+ """
164
+
165
+ def __init__(self, config):
166
+ super().__init__(config)
167
+ self.text_clip_sampler = TextClipSamplingProcessor(
168
+ self.max_len - self.max_video_len - 3
169
+ )
170
+ """
171
+ decide subsampling:
172
+ `config.subsampling` will change batch_size in trainer.
173
+ `config.clip_per_video` (used by RetriTask) doesn't
174
+ change batch_size in trainer.
175
+ """
176
+ subsampling = config.subsampling \
177
+ if config.subsampling is not None else None
178
+ if config.clip_per_video is not None:
179
+ subsampling = config.clip_per_video
180
+ self.subsampling = subsampling
181
+
182
+ def _get_text_maxlen(self):
183
+ # use max text len
184
+ return self.text_clip_sampler.max_text_len
185
+
186
+ def __call__(self, video_id, video_feature, text_feature):
187
+ from transformers import default_data_collator
188
+ video_idx = video_id[1]
189
+ if self.subsampling is not None and self.subsampling >= 1:
190
+ batch = []
191
+ for _ in range(self.subsampling):
192
+ centerclip_idx = random.randint(
193
+ 0, len(text_feature["start"]) - 1)
194
+ batch.append(
195
+ self.sampling(
196
+ video_idx,
197
+ video_feature,
198
+ text_feature,
199
+ centerclip_idx,
200
+ self._get_text_maxlen()
201
+ ))
202
+ batch = self.batch_post_processing(batch, video_feature)
203
+ batch = default_data_collator(batch)
204
+ else:
205
+ raise ValueError(
206
+ "dataset.subsampling must be >= 1 for efficient video loading.")
207
+ batch = self.sampling(video_idx, video_feature, text_feature)
208
+ batch = self.batch_post_processing(batch, video_feature)
209
+
210
+ batch["video_id"] = video_id if isinstance(video_id, str) \
211
+ else video_id[0]
212
+ # e2e: make sure frame ids is into tensor.
213
+ assert torch.is_tensor(batch["vfeats"])
214
+ return batch
215
+
216
+ def sampling(
217
+ self,
218
+ video_idx,
219
+ video_feature,
220
+ text_feature,
221
+ centerclip_idx=None,
222
+ sampled_max_text_len=None,
223
+ ):
224
+ text_clip_indexs = self.text_clip_sampler(
225
+ text_feature, centerclip_idx,
226
+ sampled_max_text_len
227
+ )
228
+ if isinstance(video_feature, np.ndarray):
229
+ video_len = len(video_feature)
230
+ else:
231
+ video_len = math.ceil(text_feature["end"][-1])
232
+
233
+ video_end = min(
234
+ math.ceil(text_feature["end"][text_clip_indexs[-1]]),
235
+ video_len
236
+ )
237
+ video_start = max(
238
+ min(
239
+ math.floor(text_feature["start"][text_clip_indexs[0]]),
240
+ video_end),
241
+ 0
242
+ )
243
+
244
+ video_clips = {"start": [video_start], "end": [video_end]}
245
+
246
+ # tensorize.
247
+ vfeats, vmasks = self._build_video_seq(
248
+ video_feature, video_clips
249
+ )
250
+ caps, cmasks = self._build_text_seq(
251
+ text_feature, text_clip_indexs
252
+ )
253
+
254
+ text_start = text_clip_indexs[0]
255
+ text_end = text_clip_indexs[-1] + 1
256
+
257
+ return {
258
+ "caps": caps,
259
+ "cmasks": cmasks,
260
+ "vfeats": vfeats,
261
+ "vmasks": vmasks,
262
+ "video_start": video_start,
263
+ "video_end": video_end,
264
+ "text_start": text_start,
265
+ "text_end": text_end,
266
+ }
267
+
268
+
269
+ class VariedLenAligner(FixedLenAligner):
270
+ def __init__(self, config):
271
+ super().__init__(config)
272
+ self.sampled_min_len = config.sampled_min_len
273
+ self.sampled_max_len = config.sampled_max_len
274
+
275
+ def _get_text_maxlen(self):
276
+ return random.randint(self.sampled_min_len, self.sampled_max_len)
277
+
278
+
279
+ class StartClipAligner(VariedLenAligner):
280
+ def sampling(
281
+ self,
282
+ video_idx,
283
+ video_feature,
284
+ text_feature,
285
+ centerclip_idx=None,
286
+ sampled_max_text_len=None,
287
+ ):
288
+ return super().sampling(
289
+ video_idx, video_feature, text_feature, 0)
290
+
291
+
292
+ class OverlappedAligner(VariedLenAligner):
293
+ """video clip and text clip has overlappings
294
+ but may not be the same start/end."""
295
+ def __init__(self, config):
296
+ super().__init__(config)
297
+ self.sampled_video_min_len = config.sampled_video_min_len
298
+ self.sampled_video_max_len = config.sampled_video_max_len
299
+
300
+ self.video_clip_sampler = VideoClipSamplingProcessor()
301
+
302
+ def _get_video_maxlen(self):
303
+ return random.randint(
304
+ self.sampled_video_min_len, self.sampled_video_max_len)
305
+
306
+ def sampling(
307
+ self,
308
+ video_idx,
309
+ video_feature,
310
+ text_feature,
311
+ centerclip_idx=None,
312
+ sampled_max_text_len=None,
313
+ ):
314
+ text_clip_indexs = self.text_clip_sampler(
315
+ text_feature, centerclip_idx,
316
+ sampled_max_text_len
317
+ )
318
+ if isinstance(video_feature, np.ndarray):
319
+ video_len = len(video_feature)
320
+ else:
321
+ video_len = math.ceil(text_feature["end"][-1])
322
+ low = math.floor(text_feature["start"][text_clip_indexs[0]])
323
+ high = math.ceil(text_feature["end"][text_clip_indexs[-1]])
324
+ if low < high:
325
+ center = random.randint(low, high)
326
+ else:
327
+ center = int((low + high) // 2)
328
+ center = max(0, min(video_feature.shape[0] - 1, center))
329
+
330
+ assert 0 <= center < video_feature.shape[0]
331
+
332
+ video_clips = self.video_clip_sampler(
333
+ video_len, self._get_video_maxlen(), center
334
+ )
335
+ video_start = video_clips["start"][0]
336
+ video_end = video_clips["end"][0]
337
+
338
+ # tensorize.
339
+ vfeats, vmasks = self._build_video_seq(
340
+ video_feature, video_clips
341
+ )
342
+ caps, cmasks = self._build_text_seq(
343
+ text_feature, text_clip_indexs
344
+ )
345
+
346
+ text_start = text_clip_indexs[0]
347
+ text_end = text_clip_indexs[-1] + 1
348
+
349
+ return {
350
+ "caps": caps,
351
+ "cmasks": cmasks,
352
+ "vfeats": vfeats,
353
+ "vmasks": vmasks,
354
+ "video_start": video_start,
355
+ "video_end": video_end,
356
+ "text_start": text_start,
357
+ "text_end": text_end,
358
+ }
359
+
360
+
361
+ class MFMMLMAligner(FixedLenAligner):
362
+ """
363
+ `FixedLenAligner` with Masked Language Model and Masked Frame Model.
364
+ """
365
+
366
+ def __init__(self, config):
367
+ super().__init__(config)
368
+ keep_prob = config.keep_prob if config.keep_prob is not None else 1.0
369
+ self.text_clip_sampler = TextClipSamplingProcessor(
370
+ self.max_len - self.max_video_len - 3, keep_prob
371
+ )
372
+ self.sampled_min_len = config.sampled_min_len
373
+ self.sampled_max_len = config.sampled_max_len
374
+ self.masked_token_sampler = TextMaskingProcessor(config)
375
+ self.mm_type = config.mm_type \
376
+ if config.mm_type is not None else "full"
377
+ self.attnmasker = MMAttentionMask2DProcessor() \
378
+ if self.mm_type == "textgen" else None
379
+ self.masked_frame_sampler = FrameMaskingProcessor(config)
380
+ self.lazy_vfeat_mask = (
381
+ False if config.lazy_vfeat_mask is None else config.lazy_vfeat_mask
382
+ )
383
+ self.mm_prob = config.mm_prob if config.mm_prob is not None else 0.
384
+
385
+ def __call__(self, video_id, video_feature, text_feature):
386
+ from transformers import default_data_collator
387
+ if self.subsampling is not None and self.subsampling > 1:
388
+ batch = []
389
+ for _ in range(self.subsampling):
390
+ centerclip_idx = random.randint(
391
+ 0, len(text_feature["start"]) - 1)
392
+ sampled_max_text_len = random.randint(
393
+ self.sampled_min_len, self.sampled_max_len
394
+ )
395
+ batch.append(
396
+ self.sampling(
397
+ video_id,
398
+ video_feature,
399
+ text_feature,
400
+ centerclip_idx,
401
+ sampled_max_text_len,
402
+ )
403
+ )
404
+ batch = self.batch_post_processing(batch, video_feature)
405
+ batch = default_data_collator(batch)
406
+ else:
407
+ batch = self.sampling(video_id, video_feature, text_feature)
408
+ batch = self.batch_post_processing(batch, video_feature)
409
+ batch["video_id"] = video_id if isinstance(video_id, str) \
410
+ else video_id[0]
411
+ return batch
412
+
413
+ def sampling(
414
+ self,
415
+ video_id,
416
+ video_feature,
417
+ text_feature,
418
+ centerclip_idx=None,
419
+ sampled_max_text_len=None,
420
+ ):
421
+ output = FixedLenAligner.sampling(self,
422
+ video_id, video_feature, text_feature,
423
+ centerclip_idx, sampled_max_text_len)
424
+
425
+ masking_text, masking_video = None, None
426
+ if random.random() < self.mm_prob:
427
+ if random.random() > 0.5:
428
+ masking_text, masking_video = self.mm_type, "no"
429
+ else:
430
+ masking_text, masking_video = "no", "full"
431
+ video_feats = output["vfeats"] if not self.lazy_vfeat_mask else None
432
+ video_label = self.masked_frame_sampler(
433
+ output["vmasks"], masking_video, vfeats=video_feats)
434
+ caps, text_label = self.masked_token_sampler(
435
+ output["caps"], masking_text)
436
+
437
+ output.update({
438
+ "caps": caps,
439
+ "video_label": video_label,
440
+ "text_label": text_label,
441
+ })
442
+
443
+ if self.attnmasker is not None:
444
+ attention_mask = self.attnmasker(
445
+ output["vmasks"], output["cmasks"], masking_text)
446
+ output.update({
447
+ "attention_mask": attention_mask
448
+ })
449
+ return output
450
+
451
+
452
+ class FrameMaskingProcessor(Processor):
453
+ def __init__(self, config):
454
+ self.mfm_probability = 0.15
455
+ if config.mfm_probability is not None:
456
+ self.mfm_probability = config.mfm_probability
457
+
458
+ def __call__(self, vmasks, modality_masking=None, vfeats=None):
459
+ """
460
+ We perform lazy masking to save data transfer time.
461
+ It only generates video_labels by default and the MFM model
462
+ will do the actual masking.
463
+ Return: `video_label` is a binary mask.
464
+ """
465
+ video_label = vmasks.clone()
466
+ if modality_masking is not None:
467
+ if modality_masking == "full":
468
+ probability_matrix = torch.full(video_label.shape, 1.)
469
+ elif modality_masking == "no":
470
+ probability_matrix = torch.full(video_label.shape, 0.)
471
+ elif modality_masking == "inverse":
472
+ probability_matrix = torch.full(
473
+ video_label.shape, 1. - self.mfm_probability)
474
+ else:
475
+ raise ValueError("unknown modality masking.", modality_masking)
476
+ else:
477
+ probability_matrix = torch.full(
478
+ video_label.shape, self.mfm_probability)
479
+ masked_indices = torch.bernoulli(probability_matrix).bool()
480
+ # We only compute loss on masked tokens
481
+ video_label[~masked_indices] = 0
482
+ if vfeats is not None:
483
+ vfeats[video_label, :] = 0.0
484
+ return video_label
485
+
486
+
487
+ class TextGenerationProcessor(Processor):
488
+ def __init__(self, tokenizer):
489
+ self.bos_token_id = tokenizer.bos_token_id
490
+ self.pad_token_id = tokenizer.pad_token_id
491
+
492
+ def __call__(self, inputs):
493
+ labels = inputs.clone()
494
+ # [CLS] [SEP] for video
495
+ labels[:2] = -100
496
+ # keep [SEP] for text.
497
+ pad_mask = labels == self.pad_token_id
498
+ labels[pad_mask] = -100
499
+ inputs[2:] = torch.cat([
500
+ torch.LongTensor([self.bos_token_id]),
501
+ inputs[2:-1]])
502
+ inputs[pad_mask] = self.pad_token_id
503
+ assert len(inputs) == len(labels)
504
+ return inputs, labels
505
+
506
+
507
+ class TextMaskingProcessor(Processor):
508
+ def __init__(self, config):
509
+ """this function is borrowed from
510
+ `transformers/data/data_collator.DataCollatorForLanguageModeling`"""
511
+ self.mlm_probability = 0.15
512
+ if config.mlm_probability is not None:
513
+ self.mlm_probability = config.mlm_probability
514
+ self.bert_name = config.bert_name
515
+ # [CLS] is used as bos_token and [SEP] is used as eos_token.
516
+ # https://huggingface.co/transformers/master/model_doc/bertgeneration.html
517
+ from transformers import AutoTokenizer
518
+ self.tokenizer = AutoTokenizer.from_pretrained(
519
+ self.bert_name, bos_token="[CLS]", eos_token="[SEP]")
520
+ self.textgen = TextGenerationProcessor(self.tokenizer)
521
+
522
+ def __call__(
523
+ self, inputs: torch.Tensor,
524
+ modality_masking=None,
525
+ special_tokens_mask: Optional[torch.Tensor] = None
526
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
527
+ """
528
+ expand modality_masking into
529
+ None: traditional bert masking.
530
+ "no": no masking.
531
+ "full": all [MASK] token for generation.
532
+ "gen": autoregressive generation.
533
+ """
534
+ """
535
+ Prepare masked tokens inputs/labels for masked language modeling:
536
+ 80% MASK, 10% random, 10% original.
537
+ """
538
+ labels = inputs.clone()
539
+ # We sample a few tokens in each sequence for MLM training
540
+ # (with probability `self.mlm_probability`)
541
+ if modality_masking is not None:
542
+ if modality_masking == "full":
543
+ probability_matrix = torch.full(labels.shape, 1.)
544
+ elif modality_masking == "no":
545
+ probability_matrix = torch.full(labels.shape, 0.)
546
+ elif modality_masking.startswith("textgen"):
547
+ # [CLS] [SEP] <s> ...
548
+ inputs, labels = self.textgen(inputs)
549
+ if "mask" not in modality_masking:
550
+ return inputs, labels
551
+ inputs = self.mask_input(inputs, special_tokens_mask)
552
+ return inputs, labels
553
+ elif modality_masking == "mask":
554
+ inputs = self.mask_input(inputs, special_tokens_mask)
555
+ labels = torch.full(inputs.shape, -100)
556
+ return inputs, labels
557
+ elif modality_masking == "inverse":
558
+ probability_matrix = torch.full(labels.shape, 1. - self.mlm_probability)
559
+ else:
560
+ raise ValueError("unknown modality masking.", modality_masking)
561
+ else:
562
+ probability_matrix = torch.full(labels.shape, self.mlm_probability)
563
+
564
+ if special_tokens_mask is None:
565
+ special_tokens_mask = self.get_special_tokens_mask(
566
+ labels.tolist(), already_has_special_tokens=True
567
+ )
568
+ special_tokens_mask = torch.tensor(
569
+ special_tokens_mask, dtype=torch.bool)
570
+ else:
571
+ special_tokens_mask = special_tokens_mask.bool()
572
+
573
+ probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
574
+ masked_indices = torch.bernoulli(probability_matrix).bool()
575
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
576
+
577
+ # 80% of the time,
578
+ # we replace masked input tokens with tokenizer.mask_token ([MASK])
579
+ indices_replaced = (
580
+ torch.bernoulli(
581
+ torch.full(labels.shape, 0.8)).bool() & masked_indices
582
+ )
583
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
584
+ self.tokenizer.mask_token
585
+ )
586
+
587
+ # 10% of the time, we replace masked input tokens with random word
588
+ indices_random = (
589
+ torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
590
+ & masked_indices
591
+ & ~indices_replaced
592
+ )
593
+ random_words = torch.randint(
594
+ len(self.tokenizer), labels.shape, dtype=torch.long
595
+ )
596
+ inputs[indices_random] = random_words[indices_random]
597
+
598
+ # The rest of the time (10% of the time) we keep the masked input
599
+ # tokens unchanged
600
+ return inputs, labels
601
+
602
+ def mask_input(self, inputs, special_tokens_mask=None):
603
+ # the following is new with masked autoregressive.
604
+ probability_matrix = torch.full(
605
+ inputs.shape, self.mlm_probability)
606
+ if special_tokens_mask is None:
607
+ special_tokens_mask = self.get_special_tokens_mask(
608
+ inputs.tolist(), already_has_special_tokens=True
609
+ )
610
+ special_tokens_mask = torch.tensor(
611
+ special_tokens_mask, dtype=torch.bool)
612
+ else:
613
+ special_tokens_mask = special_tokens_mask.bool()
614
+ probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
615
+ masked_indices = torch.bernoulli(probability_matrix).bool()
616
+ indices_replaced = (
617
+ torch.bernoulli(
618
+ torch.full(inputs.shape, 0.8)).bool() & masked_indices
619
+ )
620
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
621
+ self.tokenizer.mask_token
622
+ )
623
+
624
+ # 10% of the time, we replace masked input tokens with random word
625
+ indices_random = (
626
+ torch.bernoulli(torch.full(inputs.shape, 0.5)).bool()
627
+ & masked_indices
628
+ & ~indices_replaced
629
+ )
630
+ random_words = torch.randint(
631
+ len(self.tokenizer), inputs.shape, dtype=torch.long
632
+ )
633
+ inputs[indices_random] = random_words[indices_random]
634
+ return inputs
635
+
636
+ def get_special_tokens_mask(
637
+ self, token_ids_0: List[int],
638
+ token_ids_1: Optional[List[int]] = None,
639
+ already_has_special_tokens: bool = False
640
+ ) -> List[int]:
641
+ """
642
+ Note: the version from transformers does not consider pad
643
+ as special tokens.
644
+ """
645
+
646
+ if already_has_special_tokens:
647
+ if token_ids_1 is not None:
648
+ raise ValueError(
649
+ "You should not supply a second sequence if"
650
+ "the provided sequence of "
651
+ "ids is already formated with special tokens "
652
+ "for the model."
653
+ )
654
+ return list(map(lambda x: 1 if x in [
655
+ self.tokenizer.sep_token_id,
656
+ self.tokenizer.cls_token_id,
657
+ self.tokenizer.pad_token_id] else 0, token_ids_0))
658
+
659
+ if token_ids_1 is not None:
660
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
661
+ return [1] + ([0] * len(token_ids_0)) + [1]
662
+
663
+
664
+ class TextClipSamplingProcessor(Processor):
665
+ def __init__(self, max_text_len, keep_prob=1.0):
666
+ self.max_text_len = max_text_len
667
+ self.max_video_len = 256 # always hold.
668
+ self.keep_prob = keep_prob
669
+
670
+ def __call__(
671
+ self,
672
+ text_feature,
673
+ centerclip_idx=None,
674
+ sampled_max_text_len=None,
675
+ sampled_max_video_len=None,
676
+ ):
677
+ # Let's use all caps for now and see if 256 can cover all of them.
678
+ if sampled_max_text_len is not None:
679
+ max_text_len = sampled_max_text_len
680
+ else:
681
+ max_text_len = self.max_text_len
682
+ if sampled_max_video_len is not None:
683
+ max_video_len = sampled_max_video_len
684
+ else:
685
+ max_video_len = self.max_video_len
686
+
687
+ t_num_clips = len(text_feature["start"])
688
+
689
+ if centerclip_idx is None:
690
+ centerclip_idx = random.randint(0, t_num_clips - 1)
691
+
692
+ start_idx, end_idx = centerclip_idx, centerclip_idx + 1
693
+ text_clip_indexs = deque()
694
+ text_clip_indexs.append(start_idx)
695
+ text_len = len(text_feature["cap"][start_idx])
696
+
697
+ video_len = max(
698
+ 0,
699
+ text_feature["end"][start_idx]
700
+ - text_feature["start"][start_idx],
701
+ )
702
+
703
+ while (
704
+ (start_idx > 0 or end_idx < t_num_clips)
705
+ and text_len < max_text_len
706
+ and video_len < max_video_len
707
+ ):
708
+ if random.random() > 0.5 and end_idx < t_num_clips:
709
+ # skip the next one?
710
+ if random.random() > self.keep_prob and (end_idx + 1) < t_num_clips:
711
+ end_idx = end_idx + 1
712
+ text_clip_indexs.append(end_idx)
713
+ text_len += len(text_feature["cap"][end_idx])
714
+ end_idx += 1
715
+ elif start_idx > 0:
716
+ if random.random() > self.keep_prob and (start_idx - 1) > 0:
717
+ start_idx = start_idx - 1
718
+ start_idx -= 1
719
+ text_clip_indexs.insert(0, start_idx)
720
+ text_len += len(text_feature["cap"][start_idx])
721
+ else:
722
+ if end_idx < t_num_clips:
723
+ if random.random() > self.keep_prob and (end_idx + 1) < t_num_clips:
724
+ end_idx = end_idx + 1
725
+ text_clip_indexs.append(end_idx)
726
+ text_len += len(text_feature["cap"][end_idx])
727
+ end_idx += 1
728
+ else:
729
+ return text_clip_indexs
730
+ video_len = max(
731
+ 0,
732
+ text_feature["end"][text_clip_indexs[-1]]
733
+ - text_feature["start"][text_clip_indexs[0]],
734
+ )
735
+ return text_clip_indexs
736
+
737
+
738
+ class VideoClipSamplingProcessor(Processor):
739
+ def __call__(self, video_len, max_video_len, center):
740
+ """
741
+ `video_len`: length of the video.
742
+ `max_video_len`: maximum video tokens allowed in a sequence.
743
+ `center`: initial starting index.
744
+ """
745
+ assert center >= 0 and center < video_len
746
+ t_clip_len = 0
747
+ start, end = center, center
748
+ while (start > 0 or end < video_len) and t_clip_len < max_video_len:
749
+ # decide the direction to grow.
750
+ if start <= 0:
751
+ end += 1
752
+ elif end >= video_len:
753
+ start -= 1
754
+ elif random.random() > 0.5:
755
+ end += 1
756
+ else:
757
+ start -= 1
758
+ t_clip_len += 1
759
+ return {"start": [start], "end": [end]}
760
+
761
+
762
+ class How2MILNCEAligner(FixedLenAligner):
763
+ """reference: `antoine77340/MIL-NCE_HowTo100M/video_loader.py`"""
764
+
765
+ def __init__(self, config):
766
+ super().__init__(config)
767
+ self.num_candidates = 4
768
+ self.min_time = 5.0
769
+ self.num_sec = 3.2
770
+ # self.num_sec = self.num_frames / float(self.fps) num_frames=16 / fps = 5
771
+ # self.num_frames = 16
772
+
773
+ def sampling(
774
+ self,
775
+ video_id,
776
+ video_feature,
777
+ text_feature,
778
+ centerclip_idx=None, # will be ignored.
779
+ sampled_max_text_len=None # will be ignored.
780
+ ):
781
+ text, start, end = self._get_text(text_feature)
782
+ video = self._get_video(video_feature, start, end)
783
+
784
+ vfeats = torch.zeros((self.max_video_len, video_feature.shape[1]))
785
+ vmasks = torch.zeros((self.max_video_len,), dtype=torch.bool)
786
+ vfeats[: video.shape[0]] = torch.from_numpy(np.array(video))
787
+ vmasks[: video.shape[0]] = 1
788
+
789
+ caps, cmasks = [], []
790
+ for words in text:
791
+ cap, cmask = self._build_text_seq(text_feature, words)
792
+ caps.append(cap)
793
+ cmasks.append(cmask)
794
+ caps = torch.stack(caps)
795
+ cmasks = torch.stack(cmasks)
796
+ # video of shape: (video_len)
797
+ # text of shape (num_candidates, max_text_len)
798
+
799
+ return {
800
+ "caps": caps,
801
+ "cmasks": cmasks,
802
+ "vfeats": vfeats,
803
+ "vmasks": vmasks,
804
+ # "video_id": video_id,
805
+ }
806
+
807
+ def _get_video(self, video_feature, start, end):
808
+ start_seek = random.randint(start, int(max(start, end - self.num_sec)))
809
+ # duration = self.num_sec + 0.1
810
+ return video_feature[start_seek : int(start_seek + self.num_sec)]
811
+
812
+ def _get_text(self, cap):
813
+ ind = random.randint(0, len(cap["start"]) - 1)
814
+ if self.num_candidates == 1:
815
+ words = [ind]
816
+ else:
817
+ words = []
818
+ cap_start = self._find_nearest_candidates(cap, ind)
819
+ for i in range(self.num_candidates):
820
+ words.append([max(0, min(len(cap["cap"]) - 1, cap_start + i))])
821
+
822
+ start, end = cap["start"][ind], cap["end"][ind]
823
+ # TODO: May need to be improved for edge cases.
824
+ # expand the min time.
825
+ if end - start < self.min_time:
826
+ diff = self.min_time - end + start
827
+ start = max(0, start - diff / 2)
828
+ end = start + self.min_time
829
+ return words, int(start), int(end)
830
+
831
+ def _find_nearest_candidates(self, caption, ind):
832
+ """find the range of the clips."""
833
+ start, end = ind, ind
834
+ #diff = caption["end"][end] - caption["start"][start]
835
+ n_candidate = 1
836
+ while n_candidate < self.num_candidates:
837
+ # the first clip
838
+ if start == 0:
839
+ return 0
840
+ # we add () in the following condition to fix the bug.
841
+ elif end == (len(caption["start"]) - 1):
842
+ return start - (self.num_candidates - n_candidate)
843
+ elif (caption["end"][end] - caption["start"][start - 1]) < (
844
+ caption["end"][end + 1] - caption["start"][start]
845
+ ):
846
+ start -= 1
847
+ else:
848
+ end += 1
849
+ n_candidate += 1
850
+ return start
851
+
852
+
853
+ class PKLJSONStrTextProcessor(TextProcessor):
854
+ """`caption.json` from howto100m are preprocessed as a
855
+ dict `[video_id, json_str]`.
856
+ Json parsing tokenization are conducted on-the-fly and cached into dict.
857
+ """
858
+
859
+ def __init__(self, config, max_clip_text_len=96):
860
+ print("[Warning] PKLJSONStrTextProcessor is slow for num_workers > 0.")
861
+ self.caption_pkl_path = str(config.caption_pkl_path)
862
+ with open(self.caption_pkl_path, "rb") as fd:
863
+ self.data = pickle.load(fd)
864
+ self.max_clip_text_len = max_clip_text_len
865
+ from transformers import AutoTokenizer
866
+ self.tokenizer = AutoTokenizer.from_pretrained(
867
+ str(config.bert_name), use_fast=config.use_fast
868
+ )
869
+
870
+ def __call__(self, video_id):
871
+ caption = self.data[video_id]
872
+ if isinstance(caption, str):
873
+ import json
874
+ caption = json.loads(caption)
875
+ cap = []
876
+ for clip_idx, text_clip in enumerate(caption["text"]):
877
+ clip_ids = []
878
+ if isinstance(text_clip, str):
879
+ clip_ids = self.tokenizer(
880
+ text_clip[: self.max_clip_text_len],
881
+ add_special_tokens=False
882
+ )["input_ids"]
883
+ cap.append(clip_ids)
884
+ caption["cap"] = cap
885
+ caption.pop("text") # save space.
886
+ self.data[video_id] = caption
887
+ return caption
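The 80%/10%/10% recipe in `TextMaskingProcessor.__call__` above is easiest to see on a toy tensor. The sketch below mirrors that logic with made-up token ids (vocab size 30522 and [MASK] id 103 are assumptions standing in for the real tokenizer), so it runs without downloading `bert_name`:

import torch

mlm_probability = 0.15
mask_token_id, vocab_size = 103, 30522          # assumed ids, not read from a tokenizer
inputs = torch.randint(1000, 2000, (16,))
labels = inputs.clone()

masked = torch.bernoulli(torch.full(labels.shape, mlm_probability)).bool()
labels[~masked] = -100                           # loss is only computed on masked positions

replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked
inputs[replaced] = mask_token_id                 # 80% of masked positions become [MASK]

random_pos = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked & ~replaced
inputs[random_pos] = torch.randint(vocab_size, labels.shape)[random_pos]  # 10% random id
# the remaining ~10% of masked positions keep their original token.
print(inputs)
print(labels)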
fairseq/examples/MMPT/mmpt/processors/models/s3dg.py ADDED
@@ -0,0 +1,336 @@
1
+ # This source code is licensed under the MIT license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+
4
+ """Contains a PyTorch definition for Gated Separable 3D network (S3D-G)
5
+ with a text module for computing joint text-video embedding from raw text
6
+ and video input. The following code will enable you to load the HowTo100M
7
+ pretrained S3D Text-Video model from:
8
+ A. Miech, J.-B. Alayrac, L. Smaira, I. Laptev, J. Sivic and A. Zisserman,
9
+ End-to-End Learning of Visual Representations from Uncurated Instructional Videos.
10
+ https://arxiv.org/abs/1912.06430.
11
+
12
+ S3D-G was proposed by:
13
+ S. Xie, C. Sun, J. Huang, Z. Tu and K. Murphy,
14
+ Rethinking Spatiotemporal Feature Learning For Video Understanding.
15
+ https://arxiv.org/abs/1712.04851.
16
+ Tensorflow code: https://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py
17
+
18
+ The S3D architecture was slightly modified with a space to depth trick for TPU
19
+ optimization.
20
+ """
21
+
22
+ import torch as th
23
+ import torch.nn.functional as F
24
+ import torch.nn as nn
25
+ import os
26
+ import numpy as np
27
+ import re
28
+
29
+
30
+ class InceptionBlock(nn.Module):
31
+ def __init__(
32
+ self,
33
+ input_dim,
34
+ num_outputs_0_0a,
35
+ num_outputs_1_0a,
36
+ num_outputs_1_0b,
37
+ num_outputs_2_0a,
38
+ num_outputs_2_0b,
39
+ num_outputs_3_0b,
40
+ gating=True,
41
+ ):
42
+ super(InceptionBlock, self).__init__()
43
+ self.conv_b0 = STConv3D(input_dim, num_outputs_0_0a, [1, 1, 1])
44
+ self.conv_b1_a = STConv3D(input_dim, num_outputs_1_0a, [1, 1, 1])
45
+ self.conv_b1_b = STConv3D(
46
+ num_outputs_1_0a, num_outputs_1_0b, [3, 3, 3], padding=1, separable=True
47
+ )
48
+ self.conv_b2_a = STConv3D(input_dim, num_outputs_2_0a, [1, 1, 1])
49
+ self.conv_b2_b = STConv3D(
50
+ num_outputs_2_0a, num_outputs_2_0b, [3, 3, 3], padding=1, separable=True
51
+ )
52
+ self.maxpool_b3 = th.nn.MaxPool3d((3, 3, 3), stride=1, padding=1)
53
+ self.conv_b3_b = STConv3D(input_dim, num_outputs_3_0b, [1, 1, 1])
54
+ self.gating = gating
55
+ self.output_dim = (
56
+ num_outputs_0_0a + num_outputs_1_0b + num_outputs_2_0b + num_outputs_3_0b
57
+ )
58
+ if gating:
59
+ self.gating_b0 = SelfGating(num_outputs_0_0a)
60
+ self.gating_b1 = SelfGating(num_outputs_1_0b)
61
+ self.gating_b2 = SelfGating(num_outputs_2_0b)
62
+ self.gating_b3 = SelfGating(num_outputs_3_0b)
63
+
64
+ def forward(self, input):
65
+ """Inception block
66
+ """
67
+ b0 = self.conv_b0(input)
68
+ b1 = self.conv_b1_a(input)
69
+ b1 = self.conv_b1_b(b1)
70
+ b2 = self.conv_b2_a(input)
71
+ b2 = self.conv_b2_b(b2)
72
+ b3 = self.maxpool_b3(input)
73
+ b3 = self.conv_b3_b(b3)
74
+ if self.gating:
75
+ b0 = self.gating_b0(b0)
76
+ b1 = self.gating_b1(b1)
77
+ b2 = self.gating_b2(b2)
78
+ b3 = self.gating_b3(b3)
79
+ return th.cat((b0, b1, b2, b3), dim=1)
80
+
81
+
82
+ class SelfGating(nn.Module):
83
+ def __init__(self, input_dim):
84
+ super(SelfGating, self).__init__()
85
+ self.fc = nn.Linear(input_dim, input_dim)
86
+
87
+ def forward(self, input_tensor):
88
+ """Feature gating as used in S3D-G.
89
+ """
90
+ spatiotemporal_average = th.mean(input_tensor, dim=[2, 3, 4])
91
+ weights = self.fc(spatiotemporal_average)
92
+ weights = th.sigmoid(weights)
93
+ return weights[:, :, None, None, None] * input_tensor
94
+
95
+
96
+ class STConv3D(nn.Module):
97
+ def __init__(
98
+ self, input_dim, output_dim, kernel_size, stride=1, padding=0, separable=False
99
+ ):
100
+ super(STConv3D, self).__init__()
101
+ self.separable = separable
102
+ self.relu = nn.ReLU(inplace=True)
103
+ assert len(kernel_size) == 3
104
+ if separable and kernel_size[0] != 1:
105
+ spatial_kernel_size = [1, kernel_size[1], kernel_size[2]]
106
+ temporal_kernel_size = [kernel_size[0], 1, 1]
107
+ if isinstance(stride, list) and len(stride) == 3:
108
+ spatial_stride = [1, stride[1], stride[2]]
109
+ temporal_stride = [stride[0], 1, 1]
110
+ else:
111
+ spatial_stride = [1, stride, stride]
112
+ temporal_stride = [stride, 1, 1]
113
+ if isinstance(padding, list) and len(padding) == 3:
114
+ spatial_padding = [0, padding[1], padding[2]]
115
+ temporal_padding = [padding[0], 0, 0]
116
+ else:
117
+ spatial_padding = [0, padding, padding]
118
+ temporal_padding = [padding, 0, 0]
119
+ if separable:
120
+ self.conv1 = nn.Conv3d(
121
+ input_dim,
122
+ output_dim,
123
+ kernel_size=spatial_kernel_size,
124
+ stride=spatial_stride,
125
+ padding=spatial_padding,
126
+ bias=False,
127
+ )
128
+ self.bn1 = nn.BatchNorm3d(output_dim)
129
+ self.conv2 = nn.Conv3d(
130
+ output_dim,
131
+ output_dim,
132
+ kernel_size=temporal_kernel_size,
133
+ stride=temporal_stride,
134
+ padding=temporal_padding,
135
+ bias=False,
136
+ )
137
+ self.bn2 = nn.BatchNorm3d(output_dim)
138
+ else:
139
+ self.conv1 = nn.Conv3d(
140
+ input_dim,
141
+ output_dim,
142
+ kernel_size=kernel_size,
143
+ stride=stride,
144
+ padding=padding,
145
+ bias=False,
146
+ )
147
+ self.bn1 = nn.BatchNorm3d(output_dim)
148
+
149
+ def forward(self, input):
150
+ out = self.relu(self.bn1(self.conv1(input)))
151
+ if self.separable:
152
+ out = self.relu(self.bn2(self.conv2(out)))
153
+ return out
154
+
155
+
156
+ class MaxPool3dTFPadding(th.nn.Module):
157
+ def __init__(self, kernel_size, stride=None, padding="SAME"):
158
+ super(MaxPool3dTFPadding, self).__init__()
159
+ if padding == "SAME":
160
+ padding_shape = self._get_padding_shape(kernel_size, stride)
161
+ self.padding_shape = padding_shape
162
+ self.pad = th.nn.ConstantPad3d(padding_shape, 0)
163
+ self.pool = th.nn.MaxPool3d(kernel_size, stride, ceil_mode=True)
164
+
165
+ def _get_padding_shape(self, filter_shape, stride):
166
+ def _pad_top_bottom(filter_dim, stride_val):
167
+ pad_along = max(filter_dim - stride_val, 0)
168
+ pad_top = pad_along // 2
169
+ pad_bottom = pad_along - pad_top
170
+ return pad_top, pad_bottom
171
+
172
+ padding_shape = []
173
+ for filter_dim, stride_val in zip(filter_shape, stride):
174
+ pad_top, pad_bottom = _pad_top_bottom(filter_dim, stride_val)
175
+ padding_shape.append(pad_top)
176
+ padding_shape.append(pad_bottom)
177
+ depth_top = padding_shape.pop(0)
178
+ depth_bottom = padding_shape.pop(0)
179
+ padding_shape.append(depth_top)
180
+ padding_shape.append(depth_bottom)
181
+ return tuple(padding_shape)
182
+
183
+ def forward(self, inp):
184
+ inp = self.pad(inp)
185
+ out = self.pool(inp)
186
+ return out
187
+
188
+
189
+ class Sentence_Embedding(nn.Module):
190
+ def __init__(
191
+ self,
192
+ embd_dim,
193
+ num_embeddings=66250,
194
+ word_embedding_dim=300,
195
+ token_to_word_path="dict.npy",
196
+ max_words=16,
197
+ output_dim=2048,
198
+ ):
199
+ super(Sentence_Embedding, self).__init__()
200
+ self.word_embd = nn.Embedding(num_embeddings, word_embedding_dim)
201
+ self.fc1 = nn.Linear(word_embedding_dim, output_dim)
202
+ self.fc2 = nn.Linear(output_dim, embd_dim)
203
+ self.word_to_token = {}
204
+ self.max_words = max_words
205
+ token_to_word = np.load(token_to_word_path)
206
+ for i, t in enumerate(token_to_word):
207
+ self.word_to_token[t] = i + 1
208
+
209
+ def _zero_pad_tensor_token(self, tensor, size):
210
+ if len(tensor) >= size:
211
+ return tensor[:size]
212
+ else:
213
+ zero = th.zeros(size - len(tensor)).long()
214
+ return th.cat((tensor, zero), dim=0)
215
+
216
+ def _split_text(self, sentence):
217
+ w = re.findall(r"[\w']+", str(sentence))
218
+ return w
219
+
220
+ def _words_to_token(self, words):
221
+ words = [
222
+ self.word_to_token[word] for word in words if word in self.word_to_token
223
+ ]
224
+ if words:
225
+ we = self._zero_pad_tensor_token(th.LongTensor(words), self.max_words)
226
+ return we
227
+ else:
228
+ return th.zeros(self.max_words).long()
229
+
230
+ def _words_to_ids(self, x):
231
+ split_x = [self._words_to_token(self._split_text(sent.lower())) for sent in x]
232
+ return th.stack(split_x, dim=0)
233
+
234
+ def forward(self, x):
235
+ x = self._words_to_ids(x)
236
+ x = self.word_embd(x)
237
+ x = F.relu(self.fc1(x))
238
+ x = th.max(x, dim=1)[0]
239
+ x = self.fc2(x)
240
+ return {'text_embedding': x}
241
+
242
+
243
+ class S3D(nn.Module):
244
+ def __init__(self, dict_path, num_classes=512, gating=True, space_to_depth=True):
245
+ super(S3D, self).__init__()
246
+ self.num_classes = num_classes
247
+ self.gating = gating
248
+ self.space_to_depth = space_to_depth
249
+ if space_to_depth:
250
+ self.conv1 = STConv3D(
251
+ 24, 64, [2, 4, 4], stride=1, padding=(1, 2, 2), separable=False
252
+ )
253
+ else:
254
+ self.conv1 = STConv3D(
255
+ 3, 64, [3, 7, 7], stride=2, padding=(1, 3, 3), separable=False
256
+ )
257
+ self.conv_2b = STConv3D(64, 64, [1, 1, 1], separable=False)
258
+ self.conv_2c = STConv3D(64, 192, [3, 3, 3], padding=1, separable=True)
259
+ self.gating = SelfGating(192)
260
+ self.maxpool_2a = MaxPool3dTFPadding(
261
+ kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME"
262
+ )
263
+ self.maxpool_3a = MaxPool3dTFPadding(
264
+ kernel_size=(1, 3, 3), stride=(1, 2, 2), padding="SAME"
265
+ )
266
+ self.mixed_3b = InceptionBlock(192, 64, 96, 128, 16, 32, 32)
267
+ self.mixed_3c = InceptionBlock(
268
+ self.mixed_3b.output_dim, 128, 128, 192, 32, 96, 64
269
+ )
270
+ self.maxpool_4a = MaxPool3dTFPadding(
271
+ kernel_size=(3, 3, 3), stride=(2, 2, 2), padding="SAME"
272
+ )
273
+ self.mixed_4b = InceptionBlock(
274
+ self.mixed_3c.output_dim, 192, 96, 208, 16, 48, 64
275
+ )
276
+ self.mixed_4c = InceptionBlock(
277
+ self.mixed_4b.output_dim, 160, 112, 224, 24, 64, 64
278
+ )
279
+ self.mixed_4d = InceptionBlock(
280
+ self.mixed_4c.output_dim, 128, 128, 256, 24, 64, 64
281
+ )
282
+ self.mixed_4e = InceptionBlock(
283
+ self.mixed_4d.output_dim, 112, 144, 288, 32, 64, 64
284
+ )
285
+ self.mixed_4f = InceptionBlock(
286
+ self.mixed_4e.output_dim, 256, 160, 320, 32, 128, 128
287
+ )
288
+ self.maxpool_5a = self.maxPool3d_5a_2x2 = MaxPool3dTFPadding(
289
+ kernel_size=(2, 2, 2), stride=(2, 2, 2), padding="SAME"
290
+ )
291
+ self.mixed_5b = InceptionBlock(
292
+ self.mixed_4f.output_dim, 256, 160, 320, 32, 128, 128
293
+ )
294
+ self.mixed_5c = InceptionBlock(
295
+ self.mixed_5b.output_dim, 384, 192, 384, 48, 128, 128
296
+ )
297
+ self.fc = nn.Linear(self.mixed_5c.output_dim, num_classes)
298
+ self.text_module = Sentence_Embedding(num_classes,
299
+ token_to_word_path=dict_path)
300
+
301
+ def _space_to_depth(self, input):
302
+ """3D space to depth trick for TPU optimization.
303
+ """
304
+ B, C, T, H, W = input.shape
305
+ input = input.view(B, C, T // 2, 2, H // 2, 2, W // 2, 2)
306
+ input = input.permute(0, 3, 5, 7, 1, 2, 4, 6)
307
+ input = input.contiguous().view(B, 8 * C, T // 2, H // 2, W // 2)
308
+ return input
309
+
310
+ def forward(self, inputs):
311
+ """Defines the S3DG base architecture."""
312
+ if self.space_to_depth:
313
+ inputs = self._space_to_depth(inputs)
314
+ net = self.conv1(inputs)
315
+ if self.space_to_depth:
316
+ # we need to replicate 'SAME' tensorflow padding
317
+ net = net[:, :, 1:, 1:, 1:]
318
+ net = self.maxpool_2a(net)
319
+ net = self.conv_2b(net)
320
+ net = self.conv_2c(net)
321
+ if self.gating:
322
+ net = self.gating(net)
323
+ net = self.maxpool_3a(net)
324
+ net = self.mixed_3b(net)
325
+ net = self.mixed_3c(net)
326
+ net = self.maxpool_4a(net)
327
+ net = self.mixed_4b(net)
328
+ net = self.mixed_4c(net)
329
+ net = self.mixed_4d(net)
330
+ net = self.mixed_4e(net)
331
+ net = self.mixed_4f(net)
332
+ net = self.maxpool_5a(net)
333
+ net = self.mixed_5b(net)
334
+ net = self.mixed_5c(net)
335
+ net = th.mean(net, dim=[2, 3, 4])
336
+ return {'video_embedding': self.fc(net), 'mixed_5c': net}
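A minimal usage sketch for the S3D text-video model above (the checkpoint and dictionary paths, the input shape, and the state-dict layout of the released HowTo100M weights are assumptions, not part of this diff):

import torch as th

# Hypothetical paths to the released HowTo100M S3D weights and word dictionary.
model = S3D("pretrained_models/s3d_dict.npy", num_classes=512)
state_dict = th.load("pretrained_models/s3d_howto100m.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

with th.no_grad():
    video = th.rand(1, 3, 32, 224, 224)                    # B x C x T x H x W in [0, 1]
    out = model(video)                                     # video branch
    text = model.text_module(["someone slices an onion"])  # text branch
print(out["video_embedding"].shape, text["text_embedding"].shape)  # (1, 512) each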
fairseq/examples/MMPT/mmpt/processors/processor.py ADDED
@@ -0,0 +1,274 @@
1
+ # Copyright (c) Facebook, Inc. All Rights Reserved
2
+
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+
7
+
8
+ class Processor(object):
9
+ """
10
+ A generic processor for video (codec, feature etc.) and text.
11
+ """
12
+
13
+ def __call__(self, **kwargs):
14
+ raise NotImplementedError
15
+
16
+
17
+ class MetaProcessor(Processor):
18
+ """
19
+ A meta processor is expected to load the metadata of a dataset:
20
+ (e.g., video_ids, or captions).
21
+ You must implement `__getitem__` (meta datasets are rather diverse).
22
+ """
23
+
24
+ def __init__(self, config):
25
+ self.split = config.split
26
+
27
+ def __len__(self):
28
+ return len(self.data)
29
+
30
+ def __getitem__(self, idx):
31
+ raise NotImplementedError
32
+
33
+ def _get_split_path(self, config):
34
+ splits = {
35
+ "train": config.train_path,
36
+ "valid": config.val_path,
37
+ "test": config.test_path,
38
+ }
39
+ if config.split is not None:
40
+ return splits[config.split]
41
+ return config.train_path
42
+
43
+
44
+ class TextProcessor(Processor):
45
+ """
46
+ A generic Text processor: rename this as `withTokenizer`.
47
+ tokenize a string of text on-the-fly.
48
+ Warning: mostly used for end tasks.
49
+ (on-the-fly tokenization is slow for how2.)
50
+ TODO(huxu): move this class as a subclass.
51
+ """
52
+
53
+ def __init__(self, config):
54
+ self.bert_name = str(config.bert_name)
55
+ self.use_fast = config.use_fast
56
+ from transformers import AutoTokenizer
57
+ self.tokenizer = AutoTokenizer.from_pretrained(
58
+ self.bert_name, use_fast=self.use_fast
59
+ )
60
+
61
+ def __call__(self, text_id):
62
+ caption = self.tokenizer(text_id, add_special_tokens=False)
63
+ return caption["input_ids"]
64
+
65
+
66
+ class VideoProcessor(Processor):
67
+ """
68
+ A generic video processor: load a numpy video tokens by default.
69
+ """
70
+
71
+ def __init__(self, config):
72
+ self.vfeat_dir = config.vfeat_dir
73
+
74
+ def __call__(self, video_fn):
75
+ if isinstance(video_fn, tuple):
76
+ video_fn = video_fn[0]
77
+ assert isinstance(video_fn, str)
78
+ video_fn = os.path.join(self.vfeat_dir, video_fn + ".npy")
79
+ feat = np.load(video_fn)
80
+ return feat
81
+
82
+
83
+ class Aligner(object):
84
+ """
85
+ An align processor aligns video and text and outputs a dict of tensors (for a model).
86
+ """
87
+ def __init__(self, config):
88
+ """__init__ needs to be light weight for more workers/threads."""
89
+ self.split = config.split
90
+ self.max_video_len = config.max_video_len
91
+ self.max_len = config.max_len
92
+ from transformers import AutoTokenizer
93
+ tokenizer = AutoTokenizer.from_pretrained(
94
+ str(config.bert_name), use_fast=config.use_fast
95
+ )
96
+ self.cls_token_id = tokenizer.cls_token_id
97
+ self.sep_token_id = tokenizer.sep_token_id
98
+ self.pad_token_id = tokenizer.pad_token_id
99
+ self.mask_token_id = tokenizer.mask_token_id
100
+
101
+ def __call__(self, video_id, video_feature, text_feature):
102
+ raise NotImplementedError
103
+
104
+ def _build_video_seq(self, video_feature, video_clips=None):
105
+ """
106
+ `video_feature`: available video tokens.
107
+ `video_clips`: video clip sequence to build.
108
+ """
109
+ if not isinstance(video_feature, np.ndarray):
110
+ raise ValueError(
111
+ "unsupported type of video_feature", type(video_feature)
112
+ )
113
+
114
+ if video_clips is None:
115
+ # this is borrowed from DSAligner
116
+ video_start = 0
117
+ video_end = min(len(video_feature), self.max_video_len)
118
+ # the whole sequence is a single clip.
119
+ video_clips = {"start": [video_start], "end": [video_end]}
120
+
121
+ vfeats = np.zeros(
122
+ (self.max_video_len, video_feature.shape[1]), dtype=np.float32
123
+ )
124
+ vmasks = torch.zeros((self.max_video_len,), dtype=torch.bool)
125
+ video_len = 0
126
+ for start, end in zip(video_clips["start"], video_clips["end"]):
127
+ clip_len = min(self.max_video_len - video_len, (end - start))
128
+ if clip_len > 0:
129
+ vfeats[video_len: video_len + clip_len] = video_feature[
130
+ start: start + clip_len
131
+ ]
132
+ vmasks[video_len: video_len + clip_len] = 1
133
+ video_len += clip_len
134
+ vfeats = torch.from_numpy(vfeats)
135
+
136
+ return vfeats, vmasks
137
+
138
+ def _build_text_seq(self, text_feature, text_clip_indexs=None):
139
+ """
140
+ `text_feature`: all available clips.
141
+ `text_clip_indexes`: clip sequence to build.
142
+ """
143
+ if text_clip_indexs is None:
144
+ text_clip_indexs = [0]
145
+
146
+ full_caps = []
147
+ if isinstance(text_feature, dict):
148
+ for clip_idx in text_clip_indexs:
149
+ full_caps.extend(text_feature["cap"][clip_idx])
150
+ else:
151
+ full_caps = text_feature
152
+ max_text_len = self.max_len - self.max_video_len - 3
153
+ full_caps = full_caps[:max_text_len]
154
+ full_caps = (
155
+ [self.cls_token_id, self.sep_token_id] + full_caps + [self.sep_token_id]
156
+ )
157
+ text_pad_len = self.max_len - len(full_caps) - self.max_video_len
158
+ padded_full_caps = full_caps + [self.pad_token_id] * text_pad_len
159
+ caps = torch.LongTensor(padded_full_caps)
160
+ cmasks = torch.zeros((len(padded_full_caps),), dtype=torch.bool)
161
+ cmasks[: len(full_caps)] = 1
162
+
163
+ return caps, cmasks
164
+
165
+ def batch_post_processing(self, batch, video_feature):
166
+ return batch
167
+
168
+
169
+ class MMAttentionMask2DProcessor(Processor):
170
+ """text generation requires 2d mask
171
+ that is harder to generate by GPU at this stage."""
172
+
173
+ def __call__(self, vmask, cmask, mtype):
174
+ if mtype == "textgen":
175
+ return self._build_textgeneration_mask(vmask, cmask)
176
+ elif mtype == "videogen":
177
+ return self._build_videogeneration_mask(vmask, cmask)
178
+ else:
179
+ return self._build_mm_mask(vmask, cmask)
180
+
181
+ def _build_mm_mask(self, vmask, cmask):
182
+ mask_1d = torch.cat([cmask[:1], vmask, cmask[1:]], dim=0)
183
+ return mask_1d[None, :].repeat(mask_1d.size(0), 1)
184
+
185
+ def _build_videogeneration_mask(self, vmask, cmask):
186
+ # cls_mask is only about text otherwise it will leak generation.
187
+ cls_text_mask = torch.cat([
188
+ # [CLS]
189
+ torch.ones(
190
+ (1,), dtype=torch.bool, device=cmask.device),
191
+ # video tokens and [SEP] for video.
192
+ torch.zeros(
193
+ (vmask.size(0) + 1,), dtype=torch.bool, device=cmask.device),
194
+ cmask[2:]
195
+ ], dim=0)
196
+
197
+ # concat horizontally.
198
+ video_len = int(vmask.sum())
199
+ video_masks = torch.cat([
200
+ # [CLS]
201
+ torch.ones(
202
+ (video_len, 1), dtype=torch.bool, device=cmask.device
203
+ ),
204
+ torch.tril(
205
+ torch.ones(
206
+ (video_len, video_len),
207
+ dtype=torch.bool, device=cmask.device)),
208
+ # video_padding
209
+ torch.zeros(
210
+ (video_len, vmask.size(0) - video_len),
211
+ dtype=torch.bool, device=cmask.device
212
+ ),
213
+ # [SEP] for video (unused).
214
+ torch.zeros(
215
+ (video_len, 1), dtype=torch.bool, device=cmask.device
216
+ ),
217
+ cmask[2:].unsqueeze(0).repeat(video_len, 1)
218
+ ], dim=1)
219
+
220
+ text_masks = cls_text_mask[None, :].repeat(
221
+ cmask.size(0) - 2, 1)
222
+ video_padding_masks = cls_text_mask[None, :].repeat(
223
+ vmask.size(0) - video_len, 1)
224
+
225
+ return torch.cat([
226
+ cls_text_mask[None, :],
227
+ video_masks,
228
+ video_padding_masks,
229
+ torch.cat([cmask[:1], vmask, cmask[1:]], dim=0)[None,:],
230
+ text_masks
231
+ ], dim=0)
232
+
233
+ def _build_textgeneration_mask(self, vmask, cmask):
234
+ # cls_mask is only about video otherwise it will leak generation.
235
+ cls_video_mask = torch.cat([
236
+ # [CLS]
237
+ torch.ones(
238
+ (1,), dtype=torch.bool, device=cmask.device),
239
+ vmask,
240
+ # [SEP]
241
+ torch.ones((1,), dtype=torch.bool, device=cmask.device),
242
+ torch.zeros(
243
+ (cmask.size(0)-2,), dtype=torch.bool, device=cmask.device)
244
+ ], dim=0)
245
+
246
+ # concat horizontally.
247
+ text_len = int(cmask[2:].sum())
248
+ text_masks = torch.cat([
249
+ # [CLS]
250
+ torch.ones(
251
+ (text_len, 1), dtype=torch.bool, device=cmask.device
252
+ ),
253
+ vmask.unsqueeze(0).repeat(text_len, 1),
254
+ # [SEP] for video.
255
+ torch.ones(
256
+ (text_len, 1), dtype=torch.bool, device=cmask.device
257
+ ),
258
+ torch.tril(
259
+ torch.ones(
260
+ (text_len, text_len),
261
+ dtype=torch.bool, device=cmask.device)),
262
+ # padding.
263
+ torch.zeros(
264
+ (text_len, cmask.size(0) - text_len - 2),
265
+ dtype=torch.bool, device=cmask.device
266
+ )
267
+ ], dim=1)
268
+
269
+ cls_video_masks = cls_video_mask[None, :].repeat(
270
+ vmask.size(0) + 2, 1)
271
+ text_padding_masks = cls_video_mask[None, :].repeat(
272
+ cmask.size(0) - text_len - 2, 1)
273
+ return torch.cat([
274
+ cls_video_masks, text_masks, text_padding_masks], dim=0)
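The shapes produced by `MMAttentionMask2DProcessor` are easier to see on tiny masks. A sketch under assumptions (the import path mirrors the package layout of this diff; the mask sizes are made up):

import torch

from mmpt.processors.processor import MMAttentionMask2DProcessor  # assumed import path

proc = MMAttentionMask2DProcessor()
vmask = torch.tensor([1, 1, 1, 0], dtype=torch.bool)     # 3 real video tokens + 1 pad
cmask = torch.tensor([1, 1, 1, 1, 0], dtype=torch.bool)  # [CLS] [SEP] text [SEP] + 1 pad

full_mask = proc(vmask, cmask, "full")        # bidirectional over the joint sequence
textgen_mask = proc(vmask, cmask, "textgen")  # causal over text, full over video
print(full_mask.shape, textgen_mask.shape)    # both (9, 9): len(vmask) + len(cmask)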
fairseq/examples/MMPT/mmpt/tasks/fairseqmmtask.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """
6
+ make a general fairseq task for MM pretraining.
7
+ """
8
+
9
+ import random
10
+
11
+ from fairseq.tasks import LegacyFairseqTask, register_task
12
+
13
+ from .task import Task
14
+ from .retritask import RetriTask
15
+ from ..datasets import FairseqMMDataset
16
+ from .. import utils
17
+
18
+
19
+ @register_task("mmtask")
20
+ class FairseqMMTask(LegacyFairseqTask):
21
+ @staticmethod
22
+ def add_args(parser):
23
+ # Add some command-line arguments for specifying where the data is
24
+ # located and the maximum supported input length.
25
+ parser.add_argument(
26
+ "taskconfig",
27
+ metavar="FILE",
28
+ help=("taskconfig to load all configurations" "outside fairseq parser."),
29
+ )
30
+
31
+ @classmethod
32
+ def setup_task(cls, args, **kwargs):
33
+ return FairseqMMTask(args)
34
+
35
+ def __init__(self, args):
36
+ super().__init__(args)
37
+ config = utils.load_config(args)
38
+ self.mmtask = Task.config_task(config)
39
+ self.mmtask.build_dataset()
40
+ self.mmtask.build_model()
41
+ self.mmtask.build_loss()
42
+
43
+ def load_dataset(self, split, **kwargs):
44
+ split_map = {
45
+ "train": self.mmtask.train_data,
46
+ "valid": self.mmtask.val_data,
47
+ "test": self.mmtask.test_data,
48
+ }
49
+ if split not in split_map:
50
+ raise ValueError("unknown split type.")
51
+ if split_map[split] is not None:
52
+ self.datasets[split] = FairseqMMDataset(split_map[split])
53
+
54
+ def get_batch_iterator(
55
+ self,
56
+ dataset,
57
+ max_tokens=None,
58
+ max_sentences=None,
59
+ max_positions=None,
60
+ ignore_invalid_inputs=False,
61
+ required_batch_size_multiple=1,
62
+ seed=1,
63
+ num_shards=1,
64
+ shard_id=0,
65
+ num_workers=0,
66
+ epoch=1,
67
+ data_buffer_size=0,
68
+ disable_iterator_cache=False,
69
+ skip_remainder_batch=False,
70
+ grouped_shuffling=False,
71
+ update_epoch_batch_itr=False,
72
+ ):
73
+ random.seed(epoch)
74
+ if dataset.mmdataset.split == "train" and isinstance(self.mmtask, RetriTask):
75
+ if epoch >= self.mmtask.config.retri_epoch:
76
+ if not hasattr(self.mmtask, "retri_dataloader"):
77
+ self.mmtask.build_dataloader()
78
+ self.mmtask.retrive_candidates(epoch)
79
+
80
+ return super().get_batch_iterator(
81
+ dataset,
82
+ max_tokens,
83
+ max_sentences,
84
+ max_positions,
85
+ ignore_invalid_inputs,
86
+ required_batch_size_multiple,
87
+ seed,
88
+ num_shards,
89
+ shard_id,
90
+ num_workers,
91
+ epoch,
92
+ data_buffer_size,
93
+ disable_iterator_cache,
94
+ grouped_shuffling,
95
+ update_epoch_batch_itr,
96
+ )
97
+
98
+ @property
99
+ def source_dictionary(self):
100
+ return None
101
+
102
+ @property
103
+ def target_dictionary(self):
104
+ return None
fairseq/examples/MMPT/mmpt/tasks/milncetask.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from .task import Task
9
+
10
+
11
+ class MILNCETask(Task):
12
+ def reshape_subsample(self, sample):
13
+ if (
14
+ hasattr(self.config.dataset, "subsampling")
15
+ and self.config.dataset.subsampling is not None
16
+ and self.config.dataset.subsampling > 1
17
+ ):
18
+ for key in sample:
19
+ if torch.is_tensor(sample[key]):
20
+ tensor = self.flat_subsample(sample[key])
21
+ if key in ["caps", "cmasks"]:
22
+ size = tensor.size()
23
+ batch_size = size[0] * size[1]
24
+ expanded_size = (batch_size,) + size[2:]
25
+ tensor = tensor.view(expanded_size)
26
+ sample[key] = tensor
27
+ return sample
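The extra reshape that `MILNCETask.reshape_subsample` applies to `caps`/`cmasks` merges the two leading dimensions so that each candidate caption becomes its own row. A toy sketch (the (clips, candidates, length) layout is an assumption about what `flat_subsample` returns):

import torch

caps = torch.zeros(8, 4, 96, dtype=torch.long)     # (clips, num_candidates, max_text_len)
size = caps.size()
caps = caps.view((size[0] * size[1],) + size[2:])  # same view trick as reshape_subsample
print(caps.shape)                                  # torch.Size([32, 96])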
fairseq/examples/MMPT/mmpt/tasks/retritask.py ADDED
@@ -0,0 +1,253 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import os
6
+ import torch
7
+ import pickle
8
+ import random
9
+
10
+ from tqdm import tqdm
11
+ from torch.utils.data import DataLoader
12
+ from torch.utils.data.distributed import DistributedSampler
13
+
14
+ from ..processors import (
15
+ ShardedHow2MetaProcessor,
16
+ ShardedVideoProcessor,
17
+ ShardedTextProcessor,
18
+ VariedLenAligner,
19
+ )
20
+
21
+ from ..datasets import MMDataset
22
+ from .task import Task
23
+ from ..modules import vectorpool
24
+ from ..evaluators.predictor import Predictor
25
+ from ..utils import set_seed, get_local_rank, get_world_size
26
+
27
+
28
+ class RetriTask(Task):
29
+ """abstract class for task with retrival."""
30
+
31
+ def reshape_subsample(self, sample):
32
+ for key in sample:
33
+ if torch.is_tensor(sample[key]):
34
+ sample[key] = self.flat_subsample(sample[key])
35
+ return sample
36
+
37
+ def flat_subsample(self, tensor):
38
+ if tensor.size(0) == 1:
39
+ tensor = tensor.squeeze(0)
40
+ return tensor
41
+
42
+ def build_dataloader(self):
43
+ """called by `get_batch_iterator` in fairseqmmtask. """
44
+ # TODO: hard-code dataloader for retri for now and configurable in .yaml.
45
+ # reuse the `train.lst`.
46
+ self.config.dataset.split = "train"
47
+ meta_processor = ShardedHow2MetaProcessor(self.config.dataset)
48
+ video_processor = ShardedVideoProcessor(self.config.dataset)
49
+ text_processor = ShardedTextProcessor(self.config.dataset)
50
+
51
+ aligner = VariedLenAligner(self.config.dataset)
52
+ aligner.subsampling = self.config.dataset.clip_per_video
53
+
54
+ self.retri_data = MMDataset(
55
+ meta_processor, video_processor, text_processor, aligner
56
+ )
57
+
58
+ retri_sampler = DistributedSampler(self.retri_data)
59
+ infer_scale = 16
60
+ batch_size = self.config.dataset.num_video_per_batch \
61
+ * infer_scale
62
+
63
+ self.retri_dataloader = DataLoader(
64
+ self.retri_data,
65
+ collate_fn=self.retri_data.collater,
66
+ batch_size=batch_size,
67
+ shuffle=False,
68
+ sampler=retri_sampler,
69
+ num_workers=self.config.fairseq.dataset.num_workers
70
+ )
71
+ return self.retri_dataloader
72
+
73
+ def retrive_candidates(self, epoch, dataloader=None):
74
+ if get_local_rank() == 0:
75
+ print("running retrieval model.")
76
+ out_dir = os.path.join(
77
+ self.config.fairseq.checkpoint.save_dir, "retri")
78
+ os.makedirs(out_dir, exist_ok=True)
79
+
80
+ if not os.path.isfile(
81
+ os.path.join(
82
+ out_dir, "batched_e" + str(epoch) + "_videos0.pkl")
83
+ ):
84
+ if dataloader is None:
85
+ dataloader = self.retri_dataloader
86
+
87
+ self.model.eval()
88
+ self.model.is_train = False
89
+
90
+ assert self.retri_data.meta_processor.data == \
91
+ self.train_data.meta_processor.data # video_ids not mutated.
92
+
93
+ self._retri_predict(epoch, dataloader)
94
+
95
+ self.model.train()
96
+ self.model.is_train = True
97
+
98
+ torch.distributed.barrier()
99
+ output = self._retri_sync(epoch, out_dir)
100
+ torch.distributed.barrier()
101
+ self.train_data.meta_processor.set_candidates(output)
102
+ return output
103
+
104
+
105
+ class VideoRetriTask(RetriTask):
106
+ """RetriTask on video level."""
107
+
108
+ def reshape_subsample(self, sample):
109
+ if (
110
+ hasattr(self.config.dataset, "clip_per_video")
111
+ and self.config.dataset.clip_per_video is not None
112
+ and self.config.dataset.clip_per_video > 1
113
+ ):
114
+ for key in sample:
115
+ if torch.is_tensor(sample[key]):
116
+ sample[key] = self.flat_subsample(sample[key])
117
+ return sample
118
+
119
+ def flat_subsample(self, tensor):
120
+ if tensor.size(0) == 1:
121
+ tensor = tensor.squeeze(0)
122
+ return Task.flat_subsample(self, tensor)
123
+
124
+ def _retri_predict(self, epoch, dataloader):
125
+ set_seed(epoch)
126
+ # save vectors for retrieval.
127
+ predictor = VideoPredictor(self.config)
128
+ predictor.predict_loop(
129
+ self.model, dataloader)
130
+ set_seed(epoch) # get the same text clips.
131
+ # retrieval.
132
+ retri_predictor = VideoRetriPredictor(
133
+ self.config)
134
+ retri_predictor.predict_loop(
135
+ self.model, predictor.vecpool.retriver, epoch)
136
+ del predictor
137
+ del retri_predictor
138
+
139
+ def _retri_sync(self, epoch, out_dir):
140
+ # every GPU does the same merge.
141
+ batched_videos = []
142
+ for local_rank in range(get_world_size()):
143
+ fn = os.path.join(
144
+ out_dir,
145
+ "batched_e" + str(epoch) + "_videos" + str(local_rank) + ".pkl")
146
+ with open(fn, "rb") as fr:
147
+ batched_videos.extend(pickle.load(fr))
148
+ print(
149
+ "[INFO] batched_videos",
150
+ len(batched_videos), len(batched_videos[0]))
151
+ return batched_videos
152
+
153
+
154
+ class VideoPredictor(Predictor):
155
+ def __init__(self, config):
156
+ vectorpool_cls = getattr(vectorpool, config.vectorpool_cls)
157
+ self.vecpool = vectorpool_cls(config)
158
+
159
+ def predict_loop(
160
+ self,
161
+ model,
162
+ dataloader,
163
+ early_stop=-1,
164
+ ):
165
+ with torch.no_grad():
166
+ if get_local_rank() == 0:
167
+ dataloader = tqdm(dataloader)
168
+ for batch_idx, batch in enumerate(dataloader):
169
+ if batch_idx == early_stop:
170
+ break
171
+ self(batch, model)
172
+ return self.finalize()
173
+
174
+ def __call__(self, sample, model, **kwargs):
175
+ param = next(model.parameters())
176
+ dtype = param.dtype
177
+ device = param.device
178
+ subsample = sample["vfeats"].size(1)
179
+ sample = self.to_ctx(sample, device, dtype)
180
+ for key in sample:
181
+ if torch.is_tensor(sample[key]):
182
+ size = sample[key].size()
183
+ if len(size) >= 2:
184
+ batch_size = size[0] * size[1]
185
+ expanded_size = (
186
+ (batch_size,) + size[2:] if len(size) > 2
187
+ else (batch_size,)
188
+ )
189
+ sample[key] = sample[key].view(expanded_size)
190
+
191
+ outputs = model(**sample)
192
+ sample.update(outputs)
193
+ self.vecpool(sample, subsample)
194
+
195
+ def finalize(self):
196
+ print("[INFO]", self.vecpool)
197
+ if not self.vecpool.retriver.db.is_trained:
198
+ self.vecpool.retriver.finalize_training()
199
+ return self.vecpool.retriver
200
+
201
+
202
+ class VideoRetriPredictor(Predictor):
203
+ """
204
+ Online Retrieval Predictor for Clips (used by RetriTask).
205
+ TODO: merge this with VisPredictor?
206
+ """
207
+
208
+ def __init__(self, config):
209
+ self.pred_dir = os.path.join(
210
+ config.fairseq.checkpoint.save_dir,
211
+ "retri")
212
+ self.num_cands = config.num_cands
213
+ self.num_video_per_batch = config.dataset.num_video_per_batch
214
+
215
+ def predict_loop(
216
+ self,
217
+ model,
218
+ retriver,
219
+ epoch,
220
+ early_stop=-1
221
+ ):
222
+ # a fake loop that only tries to recover video vectors
+ # from video_ids.
224
+ batched_videos = []
225
+ # obtain available video_ids.
226
+ video_ids = list(retriver.videoid_to_vectoridx.keys())
227
+
228
+ dataloader = random.sample(
229
+ video_ids,
230
+ len(video_ids) // self.num_video_per_batch
231
+ )
232
+
233
+ if get_local_rank() == 0:
234
+ dataloader = tqdm(dataloader)
235
+ for batch_idx, batch in enumerate(dataloader):
236
+ # batch is one video id.
237
+ if batch_idx == early_stop:
238
+ break
239
+ video_ids = retriver.search_by_video_ids(
240
+ [batch], self.num_cands)[0]
241
+ if len(video_ids) > self.num_video_per_batch:
242
+ # we moved the center to make the cluster robust.
243
+ video_ids = random.sample(video_ids, self.num_video_per_batch)
244
+ batched_videos.append(video_ids)
245
+ return self.finalize(batched_videos, epoch)
246
+
247
+ def finalize(self, batched_videos, epoch):
248
+ fn = os.path.join(
249
+ self.pred_dir,
250
+ "batched_e" + str(epoch) + "_videos" + str(get_local_rank()) + ".pkl")
251
+ with open(fn, "wb") as fw:
252
+ pickle.dump(batched_videos, fw, pickle.HIGHEST_PROTOCOL)
253
+ return batched_videos
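The per-epoch candidate exchange above boils down to "each rank writes a pickle, every rank reads them all back". Here is a self-contained sketch of that file contract (paths and video ids are illustrative; the real files live under `<save_dir>/retri` and are produced by `VideoRetriPredictor.finalize` and merged by `_retri_sync`):

    import os
    import pickle
    import tempfile

    out_dir = tempfile.mkdtemp()
    epoch, world_size = 1, 2

    # what each rank dumps in finalize(): a list of candidate video-id batches.
    for rank in range(world_size):
        batches = [["vid%d_%d" % (rank, i) for i in range(4)]]
        fn = os.path.join(out_dir, "batched_e%d_videos%d.pkl" % (epoch, rank))
        with open(fn, "wb") as fw:
            pickle.dump(batches, fw, pickle.HIGHEST_PROTOCOL)

    # what _retri_sync() does on every rank: concatenate all per-rank lists,
    # then hand the merged batches to the meta processor via set_candidates().
    batched_videos = []
    for rank in range(world_size):
        fn = os.path.join(out_dir, "batched_e%d_videos%d.pkl" % (epoch, rank))
        with open(fn, "rb") as fr:
            batched_videos.extend(pickle.load(fr))
    print(len(batched_videos), len(batched_videos[0]))   # 2 4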
fairseq/examples/MMPT/mmpt/tasks/task.py ADDED
@@ -0,0 +1,184 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import torch
6
+
7
+ from .. import tasks
8
+ from .. import models
9
+ from .. import losses
10
+ from ..datasets import MMDataset
11
+ from .. import processors
12
+
13
+
14
+ class Task(object):
15
+ """
16
+ A task refers to one generic training task (e.g., training one model).
17
+ """
18
+
19
+ @classmethod
20
+ def config_task(cls, config):
21
+ """
22
+ determine whether to load a hard-coded task or config from a generic one.
23
+ via if a task string is available in config.
24
+ """
25
+ if config.task is not None:
26
+ # TODO (huxu): expand the search scope.
27
+ task_cls = getattr(tasks, config.task)
28
+ return task_cls(config)
29
+ else:
30
+ return Task(config)
31
+
32
+ def __init__(self, config):
33
+ self.config = config
34
+ self.train_data = None
35
+ self.val_data = None
36
+ self.test_data = None
37
+
38
+ self.model = None
39
+ self.loss_fn = None
40
+ self.eval_fn = None
41
+
42
+ def build_dataset(self):
43
+ """TODO (huxu): move processor breakdown to MMDataset."""
44
+ """fill-in `self.train_data`, `self.val_data` and `self.test_data`."""
45
+
46
+ meta_processor_cls = getattr(
47
+ processors, self.config.dataset.meta_processor)
48
+ video_processor_cls = getattr(
49
+ processors, self.config.dataset.video_processor)
50
+ text_processor_cls = getattr(
51
+ processors, self.config.dataset.text_processor)
52
+ aligner_cls = getattr(
53
+ processors, self.config.dataset.aligner)
54
+
55
+ if self.config.dataset.train_path is not None:
56
+ self.config.dataset.split = "train"
57
+ # may be used by meta processor.
58
+ # meta_processor controls different dataset.
59
+ meta_processor = meta_processor_cls(self.config.dataset)
60
+ video_processor = video_processor_cls(self.config.dataset)
61
+ text_processor = text_processor_cls(self.config.dataset)
62
+ aligner = aligner_cls(self.config.dataset)
63
+ self.train_data = MMDataset(
64
+ meta_processor, video_processor, text_processor, aligner
65
+ )
66
+ print("train_len", len(self.train_data))
67
+ output = self.train_data[0]
68
+ self.train_data.print_example(output)
69
+ if self.config.dataset.val_path is not None:
70
+ self.config.dataset.split = "valid"
71
+ # may be used by meta processor.
72
+ meta_processor = meta_processor_cls(self.config.dataset)
73
+ video_processor = video_processor_cls(self.config.dataset)
74
+ text_processor = text_processor_cls(self.config.dataset)
75
+ aligner = aligner_cls(self.config.dataset)
76
+ self.val_data = MMDataset(
77
+ meta_processor, video_processor, text_processor, aligner
78
+ )
79
+ print("val_len", len(self.val_data))
80
+ output = self.val_data[0]
81
+ self.val_data.print_example(output)
82
+
83
+ if self.config.dataset.split == "test":
84
+ # the following is run via lauching fairseq-validate.
85
+ meta_processor = meta_processor_cls(self.config.dataset)
86
+ video_processor = video_processor_cls(self.config.dataset)
87
+ text_processor = text_processor_cls(self.config.dataset)
88
+
89
+ self.test_data = MMDataset(
90
+ meta_processor, video_processor, text_processor, aligner
91
+ )
92
+ print("test_len", len(self.test_data))
93
+ output = self.test_data[0]
94
+ self.test_data.print_example(output)
95
+
96
+ def build_model(self, checkpoint=None):
97
+ if self.model is None:
98
+ model_cls = getattr(models, self.config.model.model_cls)
99
+ self.model = model_cls(self.config)
100
+ if checkpoint is not None:
101
+ self.load_checkpoint(checkpoint)
102
+ return self.model
103
+
104
+ def load_checkpoint(self, checkpoint):
105
+ if self.model is None:
106
+ raise ValueError("model is not initialized.")
107
+ state_dict = torch.load(checkpoint)
108
+ state_dict = self._trim_state_dict(state_dict)
109
+ self.model.load_state_dict(state_dict, strict=False)
110
+ # if it's a fp16 model, turn it back.
111
+ if next(self.model.parameters()).dtype == torch.float16:
112
+ self.model = self.model.float()
113
+ return self.model
114
+
115
+ def _trim_state_dict(self, state_dict):
116
+ from collections import OrderedDict
117
+
118
+ if "state_dict" in state_dict:
119
+ state_dict = state_dict["state_dict"]
120
+ if "model" in state_dict: # fairseq checkpoint format.
121
+ state_dict = state_dict["model"]
122
+ ret_state_dict = OrderedDict()
123
+ for (
124
+ key,
125
+ value,
126
+ ) in state_dict.items():
127
+ # remove fairseq wrapper since this is a task.
128
+ if key.startswith("mmmodel"):
129
+ key = key[len("mmmodel."):]
130
+ ret_state_dict[key] = value
131
+ return ret_state_dict
132
+
133
+ def build_loss(self):
134
+ if self.loss_fn is None and self.config.loss is not None:
135
+ loss_cls = getattr(losses, self.config.loss.loss_cls)
136
+ self.loss_fn = loss_cls()
137
+ return self.loss_fn
138
+
139
+ def flat_subsample(self, tensor):
140
+ size = tensor.size()
141
+ if len(size) >= 2:
142
+ batch_size = size[0] * size[1]
143
+ expanded_size = (
144
+ (batch_size,) + size[2:] if len(size) > 2
145
+ else (batch_size,)
146
+ )
147
+ tensor = tensor.view(expanded_size)
148
+ return tensor
149
+
150
+ def reshape_subsample(self, sample):
151
+ if (
152
+ hasattr(self.config.dataset, "subsampling")
153
+ and self.config.dataset.subsampling is not None
154
+ and self.config.dataset.subsampling > 1
155
+ ):
156
+ for key in sample:
157
+ if torch.is_tensor(sample[key]):
158
+ sample[key] = self.flat_subsample(sample[key])
159
+ return sample
160
+
161
+ def __call__(self, model, sample):
162
+ loss = None
163
+ loss_scalar = float("inf")
164
+
165
+ sample = self.reshape_subsample(sample)
166
+ outputs = self.model(**sample)
167
+ sample.update(outputs)
168
+ if self.loss_fn is not None:
169
+ loss = self.loss_fn(**sample)
170
+ loss_scalar = loss.item()
171
+
172
+ batch_size = sample["caps"].size(0)
173
+ sample_size = 1
174
+ return {
175
+ "loss": loss,
176
+ "loss_scalar": loss_scalar,
177
+ "max_len": self.config.dataset.max_len,
178
+ "batch_size": batch_size,
179
+ "sample_size": sample_size,
180
+ }
181
+
182
+ def build_dataloader(self):
183
+ """only used for trainer that lacks building loaders."""
184
+ raise NotImplementedError
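The subsampling helpers above are easiest to see on a concrete tensor: the aligner emits `subsampling` clips per video, and `flat_subsample`/`reshape_subsample` fold that extra dimension into the batch before the model sees it. A self-contained sketch (tensor names and sizes are illustrative):

    import torch

    def flat_subsample(tensor):
        # mirrors Task.flat_subsample: (bsz, subsample, ...) -> (bsz * subsample, ...)
        size = tensor.size()
        if len(size) >= 2:
            batch_size = size[0] * size[1]
            expanded_size = (
                (batch_size,) + size[2:] if len(size) > 2 else (batch_size,)
            )
            tensor = tensor.view(expanded_size)
        return tensor

    caps = torch.zeros(8, 4, 96, dtype=torch.long)   # (batch, subsampling, max_len)
    vfeats = torch.zeros(8, 4, 32, 512)              # (batch, subsampling, max_video_len, dim)
    print(flat_subsample(caps).shape)    # torch.Size([32, 96])
    print(flat_subsample(vfeats).shape)  # torch.Size([32, 32, 512])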
fairseq/examples/MMPT/mmpt/tasks/vlmtask.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ import torch
+
+ from .task import Task
+
+
+ class VLMTask(Task):
+     """A VLM task for reproducibility.
+     The collator splits subsamples into two sub-batches.
+     This should involve no logic change, but it does change
+     the randomness of frame masking.
+     """
+
+     def flat_subsample(self, tensor):
+         size = tensor.size()
+         if len(size) >= 2:
+             batch_size = size[0] * (size[1] // 2)
+             expanded_size = (
+                 (batch_size, 2) + size[2:] if len(size) > 2
+                 else (batch_size, 2)
+             )
+             tensor = tensor.view(expanded_size)
+             tensor = torch.cat([tensor[:, 0], tensor[:, 1]], dim=0)
+         return tensor
fairseq/examples/MMPT/mmpt/utils/__init__.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import random
6
+ import numpy as np
7
+ import torch
8
+
9
+ from .shardedtensor import *
10
+ from .load_config import *
11
+
12
+
13
+ def set_seed(seed=43211):
14
+ random.seed(seed)
15
+ np.random.seed(seed)
16
+ torch.manual_seed(seed)
17
+ torch.cuda.manual_seed_all(seed)
18
+ if torch.backends.cudnn.enabled:
19
+ torch.backends.cudnn.benchmark = False
20
+ torch.backends.cudnn.deterministic = True
21
+
22
+
23
+ def get_world_size():
24
+ if torch.distributed.is_initialized():
25
+ world_size = torch.distributed.get_world_size()
26
+ else:
27
+ world_size = 1
28
+ return world_size
29
+
30
+
31
+ def get_local_rank():
32
+ return torch.distributed.get_rank() \
33
+ if torch.distributed.is_initialized() else 0
34
+
35
+
36
+ def print_on_rank0(func):
37
+ local_rank = get_local_rank()
38
+ if local_rank == 0:
39
+ print("[INFO]", func)
40
+
41
+
42
+ class RetriMeter(object):
43
+ """
44
+ Statistics on whether retrieval yields a better pair.
45
+ """
46
+ def __init__(self, freq=1024):
47
+ self.freq = freq
48
+ self.total = 0
49
+ self.replace = 0
50
+ self.updates = 0
51
+
52
+ def __call__(self, data):
53
+ if isinstance(data, np.ndarray):
54
+ self.replace += data.shape[0] - int((data[:, 0] == -1).sum())
55
+ self.total += data.shape[0]
56
+ elif torch.is_tensor(data):
57
+ self.replace += int(data.sum())
58
+ self.total += data.size(0)
59
+ else:
60
+ raise ValueError("unsupported RetriMeter data type.", type(data))
61
+
62
+ self.updates += 1
63
+ if get_local_rank() == 0 and self.updates % self.freq == 0:
64
+ print("[INFO]", self)
65
+
66
+ def __repr__(self):
67
+ return "RetriMeter (" + str(self.replace / self.total) \
68
+ + "/" + str(self.replace) + "/" + str(self.total) + ")"
fairseq/examples/MMPT/mmpt/utils/load_config.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import os
6
+ import omegaconf
7
+ from omegaconf import OmegaConf
8
+
9
+
10
+ def load_config(args=None, config_file=None, overwrite_fairseq=False):
11
+ """TODO (huxu): move fairseq overwrite to another function."""
12
+ if args is not None:
13
+ config_file = args.taskconfig
14
+ config = recursive_config(config_file)
15
+
16
+ if config.dataset.subsampling is not None:
17
+ batch_size = config.fairseq.dataset.batch_size // config.dataset.subsampling
18
+ print(
19
+ "adjusting batch_size to {} due to subsampling {}.".format(
20
+ batch_size, config.dataset.subsampling
21
+ )
22
+ )
23
+ config.fairseq.dataset.batch_size = batch_size
24
+
25
+ is_test = config.dataset.split is not None and config.dataset.split == "test"
26
+ if not is_test:
27
+ if (
28
+ config.fairseq.checkpoint is None
29
+ or config.fairseq.checkpoint.save_dir is None
30
+ ):
31
+ raise ValueError("fairseq save_dir or save_path must be specified.")
32
+
33
+ save_dir = config.fairseq.checkpoint.save_dir
34
+ os.makedirs(save_dir, exist_ok=True)
35
+ if config.fairseq.common.tensorboard_logdir is not None:
36
+ tb_run_dir = suffix_rundir(
37
+ save_dir, config.fairseq.common.tensorboard_logdir
38
+ )
39
+ config.fairseq.common.tensorboard_logdir = tb_run_dir
40
+ print(
41
+ "update tensorboard_logdir as", config.fairseq.common.tensorboard_logdir
42
+ )
43
+ os.makedirs(save_dir, exist_ok=True)
44
+ OmegaConf.save(config=config, f=os.path.join(save_dir, "config.yaml"))
45
+
46
+ if overwrite_fairseq and config.fairseq is not None and args is not None:
47
+ # flatten fields.
48
+ for group in config.fairseq:
49
+ for field in config.fairseq[group]:
50
+ print("overwrite args." + field, "as", config.fairseq[group][field])
51
+ setattr(args, field, config.fairseq[group][field])
52
+ return config
53
+
54
+
55
+ def recursive_config(config_path):
56
+ """allows for stacking of configs in any depth."""
57
+ config = OmegaConf.load(config_path)
58
+ if config.includes is not None:
59
+ includes = config.includes
60
+ config.pop("includes")
61
+ base_config = recursive_config(includes)
62
+ config = OmegaConf.merge(base_config, config)
63
+ return config
64
+
65
+
66
+ def suffix_rundir(save_dir, run_dir):
67
+ max_id = -1
68
+ for search_dir in os.listdir(save_dir):
69
+ if search_dir.startswith(run_dir):
70
+ splits = search_dir.split("_")
71
+ cur_id = int(splits[1]) if len(splits) > 1 else 0
72
+ max_id = max(max_id, cur_id)
73
+ return os.path.join(save_dir, run_dir + "_" + str(max_id + 1))
74
+
75
+
76
+ def overwrite_dir(config, replace, basedir):
77
+ for key in config:
78
+ if isinstance(config[key], str) and config[key].startswith(basedir):
79
+ config[key] = config[key].replace(basedir, replace)
80
+ if isinstance(config[key], omegaconf.dictconfig.DictConfig):
81
+ overwrite_dir(config[key], replace, basedir)
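The `includes:` mechanism above is what keeps the project configs below short: a child YAML names a base file, and `recursive_config` merges them with the child's keys taking precedence. A minimal sketch (the path is one of the configs added in this commit; it assumes the referenced base YAMLs exist in the checkout and the omegaconf version MMPT expects):

    from omegaconf import OmegaConf
    from mmpt.utils import recursive_config, overwrite_dir

    config = recursive_config("projects/retri/videoclip.yaml")
    print(OmegaConf.to_yaml(config.task_group.pretrain.model))

    # overwrite_dir rewrites every string value that starts with `basedir`,
    # e.g. to point "data/..." paths at another mount (path is illustrative).
    overwrite_dir(config.task_group.pretrain, "/mnt/howto100m/data", "data")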
fairseq/examples/MMPT/mmpt_cli/localjob.py ADDED
@@ -0,0 +1,117 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import os
6
+
7
+ from mmpt.utils import recursive_config
8
+
9
+
10
+ class BaseJob(object):
11
+ def __init__(self, yaml_file, dryrun=False):
12
+ self.yaml_file = yaml_file
13
+ self.config = recursive_config(yaml_file)
14
+ self.dryrun = dryrun
15
+
16
+ def submit(self, **kwargs):
17
+ raise NotImplementedError
18
+
19
+ def _normalize_cmd(self, cmd_list):
20
+ cmd_list = list(cmd_list)
21
+ yaml_index = cmd_list.index("[yaml]")
22
+ cmd_list[yaml_index] = self.yaml_file
23
+ return cmd_list
24
+
25
+
26
+ class LocalJob(BaseJob):
27
+
28
+ CMD_CONFIG = {
29
+ "local_single": [
30
+ "fairseq-train", "[yaml]", "--user-dir", "mmpt",
31
+ "--task", "mmtask", "--arch", "mmarch",
32
+ "--criterion", "mmloss",
33
+ ],
34
+ "local_small": [
35
+ "fairseq-train", "[yaml]", "--user-dir", "mmpt",
36
+ "--task", "mmtask", "--arch", "mmarch",
37
+ "--criterion", "mmloss",
38
+ "--distributed-world-size", "2"
39
+ ],
40
+ "local_big": [
41
+ "fairseq-train", "[yaml]", "--user-dir", "mmpt",
42
+ "--task", "mmtask", "--arch", "mmarch",
43
+ "--criterion", "mmloss",
44
+ "--distributed-world-size", "8"
45
+ ],
46
+ "local_predict": ["python", "mmpt_cli/predict.py", "[yaml]"],
47
+ }
48
+
49
+ def __init__(self, yaml_file, job_type=None, dryrun=False):
50
+ super().__init__(yaml_file, dryrun)
51
+ if job_type is None:
52
+ self.job_type = "local_single"
53
+ if self.config.task_type is not None:
54
+ self.job_type = self.config.task_type
55
+ else:
56
+ self.job_type = job_type
57
+ if self.job_type in ["local_single", "local_small"]:
58
+ if self.config.fairseq.dataset.batch_size > 32:
59
+ print("decreasing batch_size to 32 for local testing?")
60
+
61
+ def submit(self):
62
+ cmd_list = self._normalize_cmd(LocalJob.CMD_CONFIG[self.job_type])
63
+ if "predict" not in self.job_type:
64
+ # append fairseq args.
65
+ from mmpt.utils import load_config
66
+
67
+ config = load_config(config_file=self.yaml_file)
68
+ for field in config.fairseq:
69
+ for key in config.fairseq[field]:
70
+ if key in ["fp16", "reset_optimizer", "reset_dataloader", "reset_meters"]: # a list of binary flag.
71
+ param = ["--" + key.replace("_", "-")]
72
+ else:
73
+ if key == "lr":
74
+ value = str(config.fairseq[field][key][0])
75
+ elif key == "adam_betas":
76
+ value = "'"+str(config.fairseq[field][key])+"'"
77
+ else:
78
+ value = str(config.fairseq[field][key])
79
+ param = [
80
+ "--" + key.replace("_", "-"),
81
+ value
82
+ ]
83
+ cmd_list.extend(param)
84
+
85
+ print("launching", " ".join(cmd_list))
86
+ if not self.dryrun:
87
+ os.system(" ".join(cmd_list))
88
+ return JobStatus("12345678")
89
+
90
+
91
+ class JobStatus(object):
92
+ def __init__(self, job_id):
93
+ self.job_id = job_id
94
+
95
+ def __repr__(self):
96
+ return self.job_id
97
+
98
+ def __str__(self):
99
+ return self.job_id
100
+
101
+ def done(self):
102
+ return False
103
+
104
+ def running(self):
105
+ return False
106
+
107
+ def result(self):
108
+ if self.done():
109
+ return "{} is done.".format(self.job_id)
110
+ else:
111
+ return "{} is running.".format(self.job_id)
112
+
113
+ def stderr(self):
114
+ return self.result()
115
+
116
+ def stdout(self):
117
+ return self.result()
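`LocalJob` is the piece that turns one of the YAML files below into a full fairseq-train command line. A dry-run sketch, assuming the MMPT example directory is the working directory (dryrun=True prints the assembled command without executing it, though `load_config` will still create the checkpoint save_dir):

    from mmpt_cli.localjob import LocalJob

    job = LocalJob(
        "projects/retri/videoclip/how2.yaml",
        job_type="local_single",
        dryrun=True,
    )
    status = job.submit()    # prints: launching fairseq-train <yaml> --user-dir mmpt ...
    print(status.result())   # "12345678 is running."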
fairseq/examples/MMPT/mmpt_cli/predict.py ADDED
@@ -0,0 +1,113 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import os
6
+ import glob
7
+ import argparse
8
+ import pprint
9
+ import omegaconf
10
+
11
+ from omegaconf import OmegaConf
12
+ from torch.utils.data import DataLoader
13
+
14
+ from mmpt.utils import load_config, set_seed
15
+ from mmpt.evaluators import Evaluator
16
+ from mmpt.evaluators import predictor as predictor_path
17
+ from mmpt.tasks import Task
18
+ from mmpt import processors
19
+ from mmpt.datasets import MMDataset
20
+
21
+
22
+ def get_dataloader(config):
23
+ meta_processor_cls = getattr(processors, config.dataset.meta_processor)
24
+ video_processor_cls = getattr(processors, config.dataset.video_processor)
25
+ text_processor_cls = getattr(processors, config.dataset.text_processor)
26
+ aligner_cls = getattr(processors, config.dataset.aligner)
27
+
28
+ meta_processor = meta_processor_cls(config.dataset)
29
+ video_processor = video_processor_cls(config.dataset)
30
+ text_processor = text_processor_cls(config.dataset)
31
+ aligner = aligner_cls(config.dataset)
32
+
33
+ test_data = MMDataset(
34
+ meta_processor,
35
+ video_processor,
36
+ text_processor,
37
+ aligner,
38
+ )
39
+ print("test_len", len(test_data))
40
+ output = test_data[0]
41
+ test_data.print_example(output)
42
+
43
+ test_dataloader = DataLoader(
44
+ test_data,
45
+ batch_size=config.fairseq.dataset.batch_size,
46
+ shuffle=False,
47
+ num_workers=6,
48
+ collate_fn=test_data.collater,
49
+ )
50
+ return test_dataloader
51
+
52
+
53
+ def main(args):
54
+ config = load_config(args)
55
+
56
+ if isinstance(config, omegaconf.dictconfig.DictConfig):
57
+ print(OmegaConf.to_yaml(config))
58
+ else:
59
+ pp = pprint.PrettyPrinter(indent=4)
60
+ pp.pprint(config)
61
+
62
+ mmtask = Task.config_task(config)
63
+ mmtask.build_model()
64
+
65
+ test_dataloader = get_dataloader(config)
66
+ checkpoint_search_path = os.path.dirname(config.eval.save_path)
67
+ results = []
68
+
69
+ prefix = os.path.basename(args.taskconfig)
70
+ if prefix.startswith("test"):
71
+ # loop all checkpoint for datasets without validation set.
72
+ if "best" not in config.fairseq.common_eval.path:
73
+ print("eval each epoch.")
74
+ for checkpoint in glob.glob(checkpoint_search_path + "/checkpoint*"):
75
+ model = mmtask.load_checkpoint(checkpoint)
76
+ ckpt = os.path.basename(checkpoint)
77
+ evaluator = Evaluator(config)
78
+ output = evaluator.evaluate(
79
+ model, test_dataloader, ckpt + "_merged")
80
+ results.append((checkpoint, output))
81
+ # use the one specified by the config lastly.
82
+ model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
83
+ evaluator = Evaluator(config)
84
+ output = evaluator.evaluate(model, test_dataloader)
85
+ results.append((config.fairseq.common_eval.path, output))
86
+
87
+ best_result = None
88
+ best_metric = 0.
89
+ for checkpoint, result in results:
90
+ print(checkpoint)
91
+ evaluator.metric.print_computed_metrics(result)
92
+ best_score = evaluator.metric.best_metric(result)
93
+ if best_score > best_metric:
94
+ best_result = (checkpoint, result)
95
+ best_metric = best_score
96
+ print("best results:")
97
+ print(best_result[0])
98
+ evaluator.metric.print_computed_metrics(best_result[1])
99
+
100
+ elif prefix.startswith("vis"):
101
+ model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
102
+ predictor_cls = getattr(predictor_path, config.predictor)
103
+ predictor = predictor_cls(config)
104
+ predictor.predict_loop(model, test_dataloader, mmtask, None)
105
+ else:
106
+ raise ValueError("unknown prefix of the config file", args.taskconfig)
107
+
108
+
109
+ if __name__ == "__main__":
110
+ parser = argparse.ArgumentParser()
111
+ parser.add_argument("taskconfig", type=str)
112
+ args = parser.parse_args()
113
+ main(args)
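predict.py dispatches on the config filename: a `test*.yaml` sweeps the checkpoints next to `eval.save_path` (unless `common_eval.path` already points at a "best" checkpoint) and reports the best metric, while a `vis*.yaml` runs a single predictor pass. Invoking it programmatically is equivalent to the `local_predict` entry in LocalJob.CMD_CONFIG; this sketch assumes the corresponding features and checkpoints under `data/` and `runs/` exist:

    import argparse
    from mmpt_cli.predict import main

    args = argparse.Namespace(
        taskconfig="projects/retri/videoclip/test_vtt_videoclip.yaml")
    main(args)   # same as: python mmpt_cli/predict.py projects/retri/videoclip/test_vtt_videoclip.yaml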
fairseq/examples/MMPT/projects/mtm/mmfusionmtm.yaml ADDED
@@ -0,0 +1,19 @@
+ includes: projects/mfmmlm.yaml
+ project_dir: mtm/mmfusionmtm
+ task_group:
+   pretrain:
+     task: VLMTask # reproducible
+     dataset:
+       aligner: MFMMLMAligner
+     model:
+       use_seg_emb: True # reproducible
+       model_cls: MMFusionMTM
+       mm_encoder_cls: MMBertForMFMMLM
+     loss:
+       loss_cls: MTM
+   finetune:
+     model:
+       use_seg_emb: True # reproducible
+   test:
+     model:
+       use_seg_emb: True # reproducible
fairseq/examples/MMPT/projects/mtm/vlm/coin.yaml ADDED
@@ -0,0 +1,47 @@
1
+ dataset:
2
+ video_processor: VideoProcessor
3
+ bert_name: bert-base-uncased
4
+ meta_processor: COINActionSegmentationMetaProcessor
5
+ train_path: data/coin/COIN.json
6
+ val_path: data/coin/COIN.json
7
+ vfeat_dir: data/feat/feat_coin_s3d
8
+ text_processor: COINActionSegmentationTextProcessor
9
+ aligner: COINActionSegmentationAligner
10
+ num_iso_layer: 12
11
+ sliding_window: 8
12
+ sliding_window_size: 32
13
+ max_video_len: 32
14
+ max_len: 96
15
+ fairseq:
16
+ common:
17
+ tensorboard_logdir: run
18
+ log_interval: 1000
19
+ fp16: true
20
+ dataset:
21
+ num_workers: 4
22
+ batch_size: 1
23
+ optimization:
24
+ lr:
25
+ - 5.0e-05
26
+ clip_norm: 2.0
27
+ optimizer: adam
28
+ adam_betas: (0.9, 0.98)
29
+ lr_scheduler: polynomial_decay
30
+ total_num_update: 1000000
31
+ warmup_updates: 122
32
+ weight_decay: 0.0
33
+ ddp_backend: no_c10d
34
+ max_epoch: 8
35
+ checkpoint:
36
+ restore_file: runs/mtm/vlm/checkpoint_best.pt
37
+ reset_optimizer: true
38
+ reset_dataloader: true
39
+ reset_meters: true
40
+ save_dir: runs/mtm/vlm/coin
41
+ task_type: sweep_big
42
+ model:
43
+ model_cls: MMFusionActionSegmentation
44
+ mm_encoder_cls: MMBertForTokenClassification
45
+ use_seg_emb: true
46
+ loss:
47
+ loss_cls: CrossEntropy
fairseq/examples/MMPT/projects/mtm/vlm/how2.yaml ADDED
@@ -0,0 +1,55 @@
1
+ dataset:
2
+ video_processor: ShardedVideoProcessor
3
+ bert_name: bert-base-uncased
4
+ meta_processor: ShardedHow2MetaProcessor
5
+ train_path: data/how2/how2_s3d_train.lst
6
+ val_path: data/how2/how2_s3d_val.lst
7
+ vfeat_dir: data/feat/feat_how2_s3d_shard_small
8
+ text_processor: ShardedTextProcessor
9
+ tfeat_dir: data/feat/feat_how2_s3d_shard_small/raw_caption_dedup.bert-base-uncased.
10
+ aligner: MFMMLMAligner
11
+ subsampling: 32
12
+ sampled_min_len: 8
13
+ sampled_max_len: 64
14
+ max_video_len: 32
15
+ max_len: 96
16
+ lazy_vfeat_mask: true
17
+ mfm_probability: 0.15
18
+ mlm_probability: 0.15
19
+ mm_prob: 0.5
20
+ fairseq:
21
+ common:
22
+ tensorboard_logdir: run
23
+ log_interval: 1000
24
+ fp16: true
25
+ dataset:
26
+ num_workers: 4
27
+ batch_size: 256
28
+ optimization:
29
+ lr:
30
+ - 5.0e-05
31
+ clip_norm: 2.0
32
+ optimizer: adam
33
+ adam_betas: (0.9, 0.98)
34
+ lr_scheduler: polynomial_decay
35
+ total_num_update: 1000000
36
+ warmup_updates: 1000
37
+ weight_decay: 0.0
38
+ ddp_backend: no_c10d
39
+ max_epoch: 15
40
+ checkpoint:
41
+ save_dir: runs/mtm/vlm
42
+ save_interval_updates: 1024
43
+ keep_interval_updates: 2
44
+ keep_last_epochs: 30
45
+ task_type: sweep_big
46
+ slurm_config: big
47
+ eval:
48
+ save_path: runs/mtm/vlm
49
+ model:
50
+ model_cls: MMFusionMTM
51
+ mm_encoder_cls: MMBertForMFMMLM
52
+ use_seg_emb: true
53
+ loss:
54
+ loss_cls: MTM
55
+ task: VLMTask
fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask.yaml ADDED
@@ -0,0 +1,38 @@
1
+ slurm_config: big
2
+ task_type: local_predict
3
+ dataset:
4
+ split: test
5
+ video_processor: CrossTaskVideoProcessor
6
+ aligner: CrossTaskAligner
7
+ bert_name: bert-base-uncased
8
+ meta_processor: CrossTaskMetaProcessor
9
+ test_path: data/crosstask/crosstask_release/videos_val.csv
10
+ train_csv_path: data/crosstask/crosstask_release/videos.csv
11
+ val_path: data/crosstask/crosstask_release/videos_val.csv
12
+ val_csv_path: data/crosstask/crosstask_release/videos_val.csv
13
+ primary_path: data/crosstask/crosstask_release/tasks_primary.txt
14
+ related_path: data/crosstask/crosstask_release/tasks_related.txt
15
+ vfeat_dir: data/feat/feat_crosstask_s3d
16
+ annotation_path: data/crosstask/crosstask_release/annotations
17
+ n_train: 30
18
+ text_processor: CrossTaskTextProcessor
19
+ num_iso_layer: 12
20
+ sliding_window: 16
21
+ sliding_window_size: 32
22
+ max_video_len: 32
23
+ max_len: 96
24
+ fairseq:
25
+ dataset:
26
+ batch_size: 1
27
+ valid_subset: test
28
+ num_workers: 2
29
+ common_eval:
30
+ path: runs/mtm/vlm/crosstask/checkpoint_best.pt
31
+ model:
32
+ model_cls: MMFusionActionLocalization
33
+ mm_encoder_cls: MMBertForJoint
34
+ use_seg_emb: true
35
+ eval:
36
+ save_path: runs/mtm/vlm/crosstask/eval
37
+ metric: CrossTaskMetric
38
+ predictor: CrossTaskPredictor
fairseq/examples/MMPT/projects/mtm/vlm/test_vtt.yaml ADDED
@@ -0,0 +1,29 @@
1
+ slurm_config: big
2
+ task_type: local_predict
3
+ dataset:
4
+ split: test
5
+ video_processor: VideoProcessor
6
+ aligner: DSAligner
7
+ bert_name: bert-base-uncased
8
+ meta_processor: MSRVTTMetaProcessor
9
+ test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
10
+ vfeat_dir: data/feat/feat_vtt_s3d
11
+ text_processor: MSRVTTTextProcessor
12
+ num_iso_layer: 12
13
+ max_video_len: 32
14
+ max_len: 96
15
+ fairseq:
16
+ dataset:
17
+ batch_size: 256
18
+ valid_subset: test
19
+ num_workers: 2
20
+ common_eval:
21
+ path: runs/mtm/vlm/vtt/checkpoint_last.pt
22
+ model:
23
+ model_cls: MMFusionJoint
24
+ mm_encoder_cls: MMBertForJoint
25
+ use_seg_emb: true
26
+ eval:
27
+ save_path: runs/mtm/vlm/vtt/eval
28
+ metric: RetrievalMetric
29
+ predictor: RetrievalPredictor
fairseq/examples/MMPT/projects/mtm/vlm/test_vttqa.yaml ADDED
@@ -0,0 +1,29 @@
1
+ slurm_config: big
2
+ task_type: local_predict
3
+ dataset:
4
+ split: test
5
+ video_processor: VideoProcessor
6
+ aligner: MSRVTTQAAligner
7
+ bert_name: bert-base-uncased
8
+ meta_processor: MSRVTTQAMetaProcessor
9
+ test_path: data/msrvtt-qa/MSR_MC_test.csv
10
+ vfeat_dir: data/feat/feat_vtt_s3d
11
+ text_processor: MSRVTTQATextProcessor
12
+ num_iso_layer: 12
13
+ max_video_len: 32
14
+ max_len: 96
15
+ fairseq:
16
+ dataset:
17
+ batch_size: 256
18
+ valid_subset: test
19
+ num_workers: 2
20
+ common_eval:
21
+ path: runs/mtm/vlm/vttqa/checkpoint_last.pt
22
+ model:
23
+ model_cls: MMFusionJoint
24
+ mm_encoder_cls: MMBertForJoint
25
+ use_seg_emb: true
26
+ eval:
27
+ save_path: runs/mtm/vlm/vttqa/eval
28
+ metric: QAMetric
29
+ predictor: QAPredictor
fairseq/examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml ADDED
@@ -0,0 +1,32 @@
1
+ slurm_config: big
2
+ task_type: local_predict
3
+ dataset:
4
+ split: test
5
+ video_processor: YoucookVideoProcessor
6
+ aligner: DSNLGAligner
7
+ bert_name: bert-base-uncased
8
+ meta_processor: YoucookNLGMetaProcessor
9
+ test_path: data/youcook/val_list.txt
10
+ trainval_annotation: data/youcook/youcookii_annotations_trainval.json
11
+ vfeat_dir: data/feat/feat_youcook_s3d
12
+ text_processor: NLGTextProcessor
13
+ max_video_len: 32
14
+ max_len: 96
15
+ fairseq:
16
+ dataset:
17
+ batch_size: 256
18
+ valid_subset: test
19
+ num_workers: 2
20
+ common_eval:
21
+ path: runs/mtm/vlm/youcookcap/checkpoint_best.pt
22
+ model:
23
+ model_cls: MMFusionNLG
24
+ mm_encoder_cls: MMBertForNLG
25
+ max_decode_length: 24
26
+ use_seg_emb: true
27
+ eval:
28
+ save_path: runs/mtm/vlm/youcookcap/eval
29
+ metric: NLGMetric
30
+ predictor: NLGPredictor
31
+ gen_param:
32
+ num_beams: 5
fairseq/examples/MMPT/projects/mtm/vlm/vtt.yaml ADDED
@@ -0,0 +1,49 @@
1
+ dataset:
2
+ video_processor: VideoProcessor
3
+ bert_name: bert-base-uncased
4
+ meta_processor: MSRVTTMetaProcessor
5
+ train_path: data/msrvtt/MSRVTT_train.csv
6
+ jsfusion_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
7
+ full_test_path: data/msrvtt/MSRVTT_FULL_test.csv
8
+ dup: 20
9
+ val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
10
+ vfeat_dir: data/feat/feat_vtt_s3d
11
+ text_processor: MSRVTTTextProcessor
12
+ json_path: data/msrvtt/MSRVTT_data.json
13
+ aligner: DSAligner
14
+ num_iso_layer: 12
15
+ max_video_len: 32
16
+ max_len: 96
17
+ fairseq:
18
+ common:
19
+ tensorboard_logdir: run
20
+ log_interval: 1000
21
+ fp16: true
22
+ dataset:
23
+ num_workers: 4
24
+ batch_size: 256
25
+ optimization:
26
+ lr:
27
+ - 5.0e-05
28
+ clip_norm: 2.0
29
+ optimizer: adam
30
+ adam_betas: (0.9, 0.98)
31
+ lr_scheduler: polynomial_decay
32
+ total_num_update: 1000000
33
+ warmup_updates: 122
34
+ weight_decay: 0.0
35
+ ddp_backend: no_c10d
36
+ max_epoch: 10
37
+ checkpoint:
38
+ restore_file: runs/mtm/vlm/checkpoint_best.pt
39
+ reset_optimizer: true
40
+ reset_dataloader: true
41
+ reset_meters: true
42
+ save_dir: runs/mtm/vlm/vtt
43
+ task_type: sweep_small
44
+ model:
45
+ model_cls: MMFusionJoint
46
+ mm_encoder_cls: MMBertForJoint
47
+ use_seg_emb: true
48
+ loss:
49
+ loss_cls: T2VContraLoss
fairseq/examples/MMPT/projects/mtm/vlm/vttqa.yaml ADDED
@@ -0,0 +1,47 @@
1
+ dataset:
2
+ video_processor: VideoProcessor
3
+ bert_name: bert-base-uncased
4
+ meta_processor: MSRVTTMetaProcessor
5
+ train_path: data/msrvtt/MSRVTT_train.csv
6
+ dup: 20
7
+ val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
8
+ vfeat_dir: data/feat/feat_vtt_s3d
9
+ text_processor: MSRVTTTextProcessor
10
+ json_path: data/msrvtt/MSRVTT_data.json
11
+ aligner: DSAligner
12
+ num_iso_layer: 12
13
+ max_video_len: 32
14
+ max_len: 96
15
+ fairseq:
16
+ common:
17
+ tensorboard_logdir: run
18
+ log_interval: 1000
19
+ fp16: true
20
+ dataset:
21
+ num_workers: 4
22
+ batch_size: 128
23
+ optimization:
24
+ lr:
25
+ - 5.0e-05
26
+ clip_norm: 2.0
27
+ optimizer: adam
28
+ adam_betas: (0.9, 0.98)
29
+ lr_scheduler: polynomial_decay
30
+ total_num_update: 1000000
31
+ warmup_updates: 122
32
+ weight_decay: 0.0
33
+ ddp_backend: no_c10d
34
+ max_epoch: 5
35
+ checkpoint:
36
+ restore_file: runs/mtm/vlm/checkpoint_best.pt
37
+ reset_optimizer: true
38
+ reset_dataloader: true
39
+ reset_meters: true
40
+ save_dir: runs/mtm/vlm/vttqa
41
+ task_type: sweep_small
42
+ model:
43
+ model_cls: MMFusionJoint
44
+ mm_encoder_cls: MMBertForJoint
45
+ use_seg_emb: true
46
+ loss:
47
+ loss_cls: V2TContraLoss
fairseq/examples/MMPT/projects/mtm/vlm/youcook.yaml ADDED
@@ -0,0 +1,47 @@
1
+ dataset:
2
+ video_processor: YoucookVideoProcessor
3
+ bert_name: bert-base-uncased
4
+ meta_processor: YoucookMetaProcessor
5
+ train_path: data/youcook/youcook_train.pkl
6
+ val_path: data/youcook/youcook_val.pkl
7
+ trainval_annotation: data/youcook/youcookii_annotations_trainval.json
8
+ use_annotation_text: true
9
+ vfeat_dir: data/feat/feat_youcook_s3d
10
+ text_processor: TextProcessor
11
+ aligner: DSAligner
12
+ num_iso_layer: 12
13
+ max_video_len: 32
14
+ max_len: 96
15
+ fairseq:
16
+ common:
17
+ tensorboard_logdir: run
18
+ log_interval: 1000
19
+ fp16: true
20
+ dataset:
21
+ num_workers: 4
22
+ batch_size: 128
23
+ optimization:
24
+ lr:
25
+ - 5.0e-05
26
+ clip_norm: 2.0
27
+ optimizer: adam
28
+ adam_betas: (0.9, 0.98)
29
+ lr_scheduler: polynomial_decay
30
+ total_num_update: 1000000
31
+ warmup_updates: 122
32
+ weight_decay: 0.0
33
+ ddp_backend: no_c10d
34
+ max_epoch: 10
35
+ checkpoint:
36
+ restore_file: runs/mtm/vlm/checkpoint_best.pt
37
+ reset_optimizer: true
38
+ reset_dataloader: true
39
+ reset_meters: true
40
+ save_dir: runs/mtm/vlm/youcook
41
+ task_type: sweep_small
42
+ model:
43
+ model_cls: MMFusionJoint
44
+ mm_encoder_cls: MMBertForJoint
45
+ use_seg_emb: true
46
+ loss:
47
+ loss_cls: T2VContraLoss
fairseq/examples/MMPT/projects/retri/videoclip.yaml ADDED
@@ -0,0 +1,10 @@
+ includes: projects/retri/videoretri.yaml
+ project_dir: retri/videoclip
+ task_group:
+   pretrain:
+     model:
+       model_cls: MMFusionSeparate
+       mm_encoder_cls:
+       video_encoder_cls: MMBertForEncoder
+       text_encoder_cls: BertModel
+       num_hidden_video_layers: 6
fairseq/examples/MMPT/projects/retri/videoclip/coin_videoclip.yaml ADDED
@@ -0,0 +1,49 @@
1
+ dataset:
2
+ video_processor: VideoProcessor
3
+ bert_name: bert-base-uncased
4
+ meta_processor: COINActionSegmentationMetaProcessor
5
+ train_path: data/coin/COIN.json
6
+ val_path: data/coin/COIN.json
7
+ vfeat_dir: data/feat/feat_coin_s3d
8
+ text_processor: COINActionSegmentationTextProcessor
9
+ aligner: COINActionSegmentationAligner
10
+ num_iso_layer: 12
11
+ sliding_window: 8
12
+ sliding_window_size: 32
13
+ max_video_len: 32
14
+ max_len: 96
15
+ fairseq:
16
+ common:
17
+ tensorboard_logdir: run
18
+ log_interval: 1000
19
+ fp16: true
20
+ dataset:
21
+ num_workers: 4
22
+ batch_size: 1
23
+ optimization:
24
+ lr:
25
+ - 5.0e-05
26
+ clip_norm: 2.0
27
+ optimizer: adam
28
+ adam_betas: (0.9, 0.98)
29
+ lr_scheduler: polynomial_decay
30
+ total_num_update: 1000000
31
+ warmup_updates: 122
32
+ weight_decay: 0.0
33
+ ddp_backend: no_c10d
34
+ max_epoch: 8
35
+ checkpoint:
36
+ restore_file: runs/retri/videoclip/checkpoint_best.pt
37
+ reset_optimizer: true
38
+ reset_dataloader: true
39
+ reset_meters: true
40
+ save_dir: runs/retri/videoclip/coin
41
+ task_type: sweep_big
42
+ model:
43
+ model_cls: MMFusionSeparateActionSegmentation
44
+ mm_encoder_cls: null
45
+ video_encoder_cls: MMBertForTokenClassification
46
+ text_encoder_cls: BertModel
47
+ num_hidden_video_layers: 6
48
+ loss:
49
+ loss_cls: CrossEntropy
fairseq/examples/MMPT/projects/retri/videoclip/crosstask_videoclip.yaml ADDED
@@ -0,0 +1,55 @@
1
+ dataset:
2
+ video_processor: CrossTaskVideoProcessor
3
+ bert_name: bert-base-uncased
4
+ meta_processor: CrossTaskMetaProcessor
5
+ train_path: data/crosstask/crosstask_release/videos.csv
6
+ train_csv_path: data/crosstask/crosstask_release/videos.csv
7
+ val_path: data/crosstask/crosstask_release/videos_val.csv
8
+ val_csv_path: data/crosstask/crosstask_release/videos_val.csv
9
+ primary_path: data/crosstask/crosstask_release/tasks_primary.txt
10
+ related_path: data/crosstask/crosstask_release/tasks_related.txt
11
+ vfeat_dir: data/feat/feat_crosstask_s3d
12
+ annotation_path: data/crosstask/crosstask_release/annotations
13
+ n_train: 30
14
+ text_processor: CrossTaskTextProcessor
15
+ aligner: CrossTaskAligner
16
+ num_iso_layer: 12
17
+ sliding_window: 16
18
+ sliding_window_size: 32
19
+ max_video_len: 32
20
+ max_len: 96
21
+ fairseq:
22
+ common:
23
+ tensorboard_logdir: run
24
+ log_interval: 1000
25
+ fp16: true
26
+ dataset:
27
+ num_workers: 4
28
+ batch_size: 1
29
+ optimization:
30
+ lr:
31
+ - 5.0e-05
32
+ clip_norm: 2.0
33
+ optimizer: adam
34
+ adam_betas: (0.9, 0.98)
35
+ lr_scheduler: polynomial_decay
36
+ total_num_update: 1000000
37
+ warmup_updates: 122
38
+ weight_decay: 0.0
39
+ ddp_backend: no_c10d
40
+ max_epoch: 5
41
+ checkpoint:
42
+ restore_file: runs/retri/videoclip/checkpoint_best.pt
43
+ reset_optimizer: true
44
+ reset_dataloader: true
45
+ reset_meters: true
46
+ save_dir: runs/retri/videoclip/crosstask
47
+ task_type: sweep_small
48
+ model:
49
+ model_cls: MMFusionSeparateActionLocalization
50
+ mm_encoder_cls: null
51
+ video_encoder_cls: MMBertForEncoder
52
+ text_encoder_cls: BertModel
53
+ num_hidden_video_layers: 6
54
+ loss:
55
+ loss_cls: BCE
fairseq/examples/MMPT/projects/retri/videoclip/how2.yaml ADDED
@@ -0,0 +1,65 @@
1
+ dataset:
2
+ video_processor: ShardedVideoRetriVideoProcessor
3
+ bert_name: bert-base-uncased
4
+ meta_processor: ShardedHow2VideoRetriMetaProcessor
5
+ train_path: data/how2/how2_s3d_train.lst
6
+ val_path: data/how2/how2_s3d_val.lst
7
+ vfeat_dir: data/feat/feat_how2_s3d_shard_small
8
+ text_processor: ShardedVideoRetriTextProcessor
9
+ tfeat_dir: data/feat/feat_how2_s3d_shard_small/raw_caption_dedup.bert-base-uncased.
10
+ aligner: VideoRetriOverlappedAligner
11
+ subsampling: 1
12
+ sampled_min_len: 8
13
+ sampled_max_len: 64
14
+ max_video_len: 32
15
+ max_len: 96
16
+ lazy_vfeat_mask: true
17
+ mfm_probability: 0.15
18
+ mlm_probability: 0.15
19
+ mm_prob: 0.5
20
+ sampled_video_min_len: 3
21
+ sampled_video_max_len: 32
22
+ num_video_per_batch: 32
23
+ clip_per_video: 16
24
+ fairseq:
25
+ common:
26
+ tensorboard_logdir: run
27
+ log_interval: 1000
28
+ fp16: true
29
+ dataset:
30
+ num_workers: 4
31
+ batch_size: 1
32
+ optimization:
33
+ lr:
34
+ - 5.0e-05
35
+ clip_norm: 2.0
36
+ optimizer: adam
37
+ adam_betas: (0.9, 0.98)
38
+ lr_scheduler: polynomial_decay
39
+ total_num_update: 1000000
40
+ warmup_updates: 1000
41
+ weight_decay: 0.0
42
+ ddp_backend: no_c10d
43
+ max_epoch: 25
44
+ checkpoint:
45
+ save_dir: runs/retri/videoclip
46
+ save_interval_updates: 1024
47
+ keep_interval_updates: 2
48
+ keep_last_epochs: 30
49
+ task_type: sweep_big
50
+ slurm_config: big
51
+ eval:
52
+ save_path: runs/retri/videoclip
53
+ model:
54
+ model_cls: MMFusionSeparate
55
+ mm_encoder_cls: null
56
+ video_encoder_cls: MMBertForEncoder
57
+ text_encoder_cls: BertModel
58
+ num_hidden_video_layers: 6
59
+ loss:
60
+ loss_cls: MMContraLoss
61
+ task: VideoRetriTask
62
+ retri_epoch: 1
63
+ vectorpool_cls: VideoVectorPool
64
+ retriever_cls: VectorRetriever
65
+ num_cands: 64
fairseq/examples/MMPT/projects/retri/videoclip/test_coin_videoclip.yaml ADDED
@@ -0,0 +1,33 @@
1
+ slurm_config: big
2
+ task_type: local_predict
3
+ dataset:
4
+ split: test
5
+ video_processor: VideoProcessor
6
+ aligner: COINActionSegmentationAligner
7
+ bert_name: bert-base-uncased
8
+ test_path: data/coin/COIN.json
9
+ meta_processor: COINActionSegmentationMetaProcessor
10
+ vfeat_dir: data/feat/feat_coin_s3d
11
+ text_processor: COINActionSegmentationTextProcessor
12
+ num_iso_layer: 12
13
+ sliding_window: 16
14
+ sliding_window_size: 32
15
+ max_video_len: 32
16
+ max_len: 96
17
+ fairseq:
18
+ dataset:
19
+ batch_size: 1
20
+ valid_subset: test
21
+ num_workers: 2
22
+ common_eval:
23
+ path: runs/retri/videoclip/coin/checkpoint_best.pt
24
+ model:
25
+ model_cls: MMFusionSeparateActionSegmentation
26
+ mm_encoder_cls: null
27
+ video_encoder_cls: MMBertForTokenClassification
28
+ text_encoder_cls: BertModel
29
+ num_hidden_video_layers: 6
30
+ eval:
31
+ save_path: runs/retri/videoclip/coin/eval
32
+ metric: COINActionSegmentationMetric
33
+ predictor: COINPredictor
fairseq/examples/MMPT/projects/retri/videoclip/test_coin_zs.yaml ADDED
@@ -0,0 +1,33 @@
1
+ slurm_config: big
2
+ task_type: local_predict
3
+ dataset:
4
+ split: test
5
+ video_processor: VideoProcessor
6
+ aligner: COINActionSegmentationAligner
7
+ bert_name: bert-base-uncased
8
+ test_path: data/coin/COIN.json
9
+ meta_processor: COINActionSegmentationMetaProcessor
10
+ vfeat_dir: data/feat/feat_coin_s3d
11
+ text_processor: COINActionSegmentationTextProcessor
12
+ num_iso_layer: 12
13
+ sliding_window: 16
14
+ sliding_window_size: 32
15
+ max_video_len: 32
16
+ max_len: 96
17
+ fairseq:
18
+ dataset:
19
+ batch_size: 1
20
+ valid_subset: test
21
+ num_workers: 2
22
+ common_eval:
23
+ path: runs/retri/videoclip/checkpoint_best.pt
24
+ model:
25
+ model_cls: MMFusionSeparate
26
+ mm_encoder_cls: null
27
+ video_encoder_cls: MMBertForEncoder
28
+ text_encoder_cls: BertModel
29
+ num_hidden_video_layers: 6
30
+ eval:
31
+ save_path: runs/retri/videoclip/coin_zs/eval
32
+ metric: COINActionSegmentationMetric
33
+ predictor: COINZSPredictor
fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_videoclip.yaml ADDED
@@ -0,0 +1,40 @@
1
+ slurm_config: big
2
+ task_type: local_predict
3
+ dataset:
4
+ split: test
5
+ video_processor: CrossTaskVideoProcessor
6
+ aligner: CrossTaskAligner
7
+ bert_name: bert-base-uncased
8
+ meta_processor: CrossTaskMetaProcessor
9
+ test_path: data/crosstask/crosstask_release/videos_val.csv
10
+ train_csv_path: data/crosstask/crosstask_release/videos.csv
11
+ val_path: data/crosstask/crosstask_release/videos_val.csv
12
+ val_csv_path: data/crosstask/crosstask_release/videos_val.csv
13
+ primary_path: data/crosstask/crosstask_release/tasks_primary.txt
14
+ related_path: data/crosstask/crosstask_release/tasks_related.txt
15
+ vfeat_dir: data/feat/feat_crosstask_s3d
16
+ annotation_path: data/crosstask/crosstask_release/annotations
17
+ n_train: 30
18
+ text_processor: CrossTaskTextProcessor
19
+ num_iso_layer: 12
20
+ sliding_window: 16
21
+ sliding_window_size: 32
22
+ max_video_len: 32
23
+ max_len: 96
24
+ fairseq:
25
+ dataset:
26
+ batch_size: 1
27
+ valid_subset: test
28
+ num_workers: 2
29
+ common_eval:
30
+ path: runs/retri/videoclip/crosstask/checkpoint_best.pt
31
+ model:
32
+ model_cls: MMFusionSeparateActionLocalization
33
+ mm_encoder_cls: null
34
+ video_encoder_cls: MMBertForEncoder
35
+ text_encoder_cls: BertModel
36
+ num_hidden_video_layers: 6
37
+ eval:
38
+ save_path: runs/retri/videoclip/crosstask/eval
39
+ metric: CrossTaskMetric
40
+ predictor: CrossTaskPredictor
fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_zs_videoclip.yaml ADDED
@@ -0,0 +1,40 @@
1
+ slurm_config: big
2
+ task_type: local_predict
3
+ dataset:
4
+ split: test
5
+ video_processor: CrossTaskVideoProcessor
6
+ aligner: CrossTaskAligner
7
+ bert_name: bert-base-uncased
8
+ meta_processor: CrossTaskMetaProcessor
9
+ test_path: data/crosstask/crosstask_release/videos_val.csv
10
+ train_csv_path: data/crosstask/crosstask_release/videos.csv
11
+ val_path: data/crosstask/crosstask_release/videos_val.csv
12
+ val_csv_path: data/crosstask/crosstask_release/videos_val.csv
13
+ primary_path: data/crosstask/crosstask_release/tasks_primary.txt
14
+ related_path: data/crosstask/crosstask_release/tasks_related.txt
15
+ vfeat_dir: data/feat/feat_crosstask_s3d
16
+ annotation_path: data/crosstask/crosstask_release/annotations
17
+ n_train: 30
18
+ text_processor: CrossTaskTextProcessor
19
+ num_iso_layer: 12
20
+ sliding_window: 16
21
+ sliding_window_size: 32
22
+ max_video_len: 32
23
+ max_len: 96
24
+ fairseq:
25
+ dataset:
26
+ batch_size: 1
27
+ valid_subset: test
28
+ num_workers: 2
29
+ common_eval:
30
+ path: runs/retri/videoclip/checkpoint_best.pt
31
+ model:
32
+ model_cls: MMFusionSeparateActionLocalization
33
+ mm_encoder_cls: null
34
+ video_encoder_cls: MMBertForEncoder
35
+ text_encoder_cls: BertModel
36
+ num_hidden_video_layers: 6
37
+ eval:
38
+ save_path: runs/retri/videoclip/crosstask_zs/eval
39
+ metric: CrossTaskMetric
40
+ predictor: CrossTaskPredictor
fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_videoclip.yaml ADDED
@@ -0,0 +1,31 @@
1
+ slurm_config: big
2
+ task_type: local_predict
3
+ dataset:
4
+ split: test
5
+ video_processor: VideoProcessor
6
+ aligner: DSAligner
7
+ bert_name: bert-base-uncased
8
+ meta_processor: MSRVTTMetaProcessor
9
+ test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
10
+ vfeat_dir: data/feat/feat_vtt_s3d
11
+ text_processor: MSRVTTTextProcessor
12
+ num_iso_layer: 12
13
+ max_video_len: 32
14
+ max_len: 96
15
+ fairseq:
16
+ dataset:
17
+ batch_size: 256
18
+ valid_subset: test
19
+ num_workers: 2
20
+ common_eval:
21
+ path: runs/retri/videoclip/vtt/checkpoint_last.pt
22
+ model:
23
+ model_cls: MMFusionSeparate
24
+ mm_encoder_cls: null
25
+ video_encoder_cls: MMBertForEncoder
26
+ text_encoder_cls: BertModel
27
+ num_hidden_video_layers: 6
28
+ eval:
29
+ save_path: runs/retri/videoclip/vtt/eval
30
+ metric: RetrievalMetric
31
+ predictor: RetrievalPredictor
fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_zs.yaml ADDED
@@ -0,0 +1,31 @@
1
+ slurm_config: big
2
+ task_type: local_predict
3
+ dataset:
4
+ split: test
5
+ video_processor: VideoProcessor
6
+ aligner: DSAligner
7
+ bert_name: bert-base-uncased
8
+ meta_processor: MSRVTTMetaProcessor
9
+ test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
10
+ vfeat_dir: data/feat/feat_vtt_s3d
11
+ text_processor: MSRVTTTextProcessor
12
+ num_iso_layer: 12
13
+ max_video_len: 32
14
+ max_len: 96
15
+ fairseq:
16
+ dataset:
17
+ batch_size: 256
18
+ valid_subset: test
19
+ num_workers: 2
20
+ common_eval:
21
+ path: runs/retri/videoclip/checkpoint_best.pt
22
+ model:
23
+ model_cls: MMFusionSeparate
24
+ mm_encoder_cls: null
25
+ video_encoder_cls: MMBertForEncoder
26
+ text_encoder_cls: BertModel
27
+ num_hidden_video_layers: 6
28
+ eval:
29
+ save_path: runs/retri/videoclip/vtt_zs/eval
30
+ metric: RetrievalMetric
31
+ predictor: RetrievalPredictor
fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_videoclip.yaml ADDED
@@ -0,0 +1,31 @@
1
+ slurm_config: big
2
+ task_type: local_predict
3
+ dataset:
4
+ split: test
5
+ video_processor: VideoProcessor
6
+ aligner: MSRVTTQAAligner
7
+ bert_name: bert-base-uncased
8
+ meta_processor: MSRVTTQAMetaProcessor
9
+ test_path: data/msrvtt-qa/MSR_MC_test.csv
10
+ vfeat_dir: data/feat/feat_vtt_s3d
11
+ text_processor: MSRVTTQATextProcessor
12
+ num_iso_layer: 12
13
+ max_video_len: 32
14
+ max_len: 96
15
+ fairseq:
16
+ dataset:
17
+ batch_size: 256
18
+ valid_subset: test
19
+ num_workers: 2
20
+ common_eval:
21
+ path: runs/retri/videoclip/vttqa/checkpoint_last.pt
22
+ model:
23
+ model_cls: MMFusionSeparate
24
+ mm_encoder_cls: null
25
+ video_encoder_cls: MMBertForEncoder
26
+ text_encoder_cls: BertModel
27
+ num_hidden_video_layers: 6
28
+ eval:
29
+ save_path: runs/retri/videoclip/vttqa/eval
30
+ metric: QAMetric
31
+ predictor: QAPredictor
fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_zs.yaml ADDED
@@ -0,0 +1,31 @@
1
+ slurm_config: big
2
+ task_type: local_predict
3
+ dataset:
4
+ split: test
5
+ video_processor: VideoProcessor
6
+ aligner: MSRVTTQAAligner
7
+ bert_name: bert-base-uncased
8
+ meta_processor: MSRVTTQAMetaProcessor
9
+ test_path: data/msrvtt-qa/MSR_MC_test.csv
10
+ vfeat_dir: data/feat/feat_vtt_s3d
11
+ text_processor: MSRVTTQATextProcessor
12
+ num_iso_layer: 12
13
+ max_video_len: 32
14
+ max_len: 96
15
+ fairseq:
16
+ dataset:
17
+ batch_size: 256
18
+ valid_subset: test
19
+ num_workers: 2
20
+ common_eval:
21
+ path: runs/retri/videoclip/checkpoint_best.pt
22
+ model:
23
+ model_cls: MMFusionSeparate
24
+ mm_encoder_cls: null
25
+ video_encoder_cls: MMBertForEncoder
26
+ text_encoder_cls: BertModel
27
+ num_hidden_video_layers: 6
28
+ eval:
29
+ save_path: runs/retri/videoclip/vttqa_zs/eval
30
+ metric: QAMetric
31
+ predictor: QAPredictor
fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_videoclip.yaml ADDED
@@ -0,0 +1,33 @@
1
+ slurm_config: big
2
+ task_type: local_predict
3
+ dataset:
4
+ split: test
5
+ video_processor: YoucookVideoProcessor
6
+ aligner: DSAligner
7
+ bert_name: bert-base-uncased
8
+ meta_processor: YoucookMetaProcessor
9
+ test_path: data/youcook/youcook_val.pkl
10
+ trainval_annotation: data/youcook/youcookii_annotations_trainval.json
11
+ use_annotation_text: true
12
+ vfeat_dir: data/feat/feat_youcook_s3d
13
+ text_processor: TextProcessor
14
+ num_iso_layer: 12
15
+ max_video_len: 32
16
+ max_len: 96
17
+ fairseq:
18
+ dataset:
19
+ batch_size: 256
20
+ valid_subset: test
21
+ num_workers: 2
22
+ common_eval:
23
+ path: runs/retri/videoclip/youcook/checkpoint_last.pt
24
+ model:
25
+ model_cls: MMFusionSeparate
26
+ mm_encoder_cls: null
27
+ video_encoder_cls: MMBertForEncoder
28
+ text_encoder_cls: BertModel
29
+ num_hidden_video_layers: 6
30
+ eval:
31
+ save_path: runs/retri/videoclip/youcook/eval
32
+ metric: RetrievalMetric
33
+ predictor: RetrievalPredictor