ssl-aasist · PyTorch · custom_code

Commit 9043f3c (verified) · 1 Parent(s): d72889e
ash56 committed: Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.

Files changed (50)
  1. fairseq/fairseq.egg-info/PKG-INFO +283 -0
  2. fairseq/fairseq.egg-info/SOURCES.txt +1546 -0
  3. fairseq/fairseq.egg-info/entry_points.txt +9 -0
  4. fairseq/fairseq.egg-info/requires.txt +22 -0
  5. fairseq/fairseq.egg-info/top_level.txt +4 -0
  6. fairseq/fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc +0 -0
  7. fairseq/fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc +0 -0
  8. fairseq/fairseq/__pycache__/ngram_repeat_block.cpython-310.pyc +0 -0
  9. fairseq/fairseq/__pycache__/pdb.cpython-310.pyc +0 -0
  10. fairseq/fairseq_cli/__init__.py +0 -0
  11. fairseq/fairseq_cli/eval_lm.py +347 -0
  12. fairseq/fairseq_cli/generate.py +417 -0
  13. fairseq/fairseq_cli/hydra_train.py +91 -0
  14. fairseq/fairseq_cli/hydra_validate.py +188 -0
  15. fairseq/fairseq_cli/interactive.py +317 -0
  16. fairseq/fairseq_cli/preprocess.py +393 -0
  17. fairseq/fairseq_cli/score.py +102 -0
  18. fairseq/fairseq_cli/train.py +581 -0
  19. fairseq/fairseq_cli/validate.py +153 -0
  20. fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/__init__.py +3 -0
  21. fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/config.py +23 -0
  22. fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/launcher.py +121 -0
  23. fairseq/hydra_plugins/dependency_submitit_launcher/setup.py +29 -0
  24. fairseq/scripts/__init__.py +0 -0
  25. fairseq/scripts/average_checkpoints.py +176 -0
  26. fairseq/scripts/build_sym_alignment.py +97 -0
  27. fairseq/scripts/check_installation.py +36 -0
  28. fairseq/scripts/compare_namespaces.py +46 -0
  29. fairseq/scripts/compound_split_bleu.sh +20 -0
  30. fairseq/scripts/constraints/extract.py +90 -0
  31. fairseq/scripts/constraints/validate.py +34 -0
  32. fairseq/scripts/convert_dictionary.lua +34 -0
  33. fairseq/scripts/convert_model.lua +108 -0
  34. fairseq/scripts/count_docs.py +58 -0
  35. fairseq/scripts/read_binarized.py +48 -0
  36. fairseq/scripts/rm_pt.py +141 -0
  37. fairseq/scripts/sacrebleu.sh +27 -0
  38. fairseq/scripts/shard_docs.py +54 -0
  39. fairseq/scripts/split_train_valid_docs.py +86 -0
  40. fairseq/scripts/spm_decode.py +53 -0
  41. fairseq/scripts/spm_encode.py +119 -0
  42. fairseq/scripts/spm_train.py +16 -0
  43. fairseq/scripts/test_fsdp.sh +24 -0
  44. fairseq/tests/__init__.py +0 -0
  45. fairseq/tests/tasks/test_masked_lm.py +78 -0
  46. fairseq/tests/tasks/test_span_masked_lm.py +106 -0
  47. fairseq/tests/test_activation_checkpointing.py +79 -0
  48. fairseq/tests/test_amp_optimizer.py +75 -0
  49. fairseq/tests/test_average_checkpoints.py +134 -0
  50. fairseq/tests/test_backtranslation_dataset.py +123 -0
fairseq/fairseq.egg-info/PKG-INFO ADDED
@@ -0,0 +1,283 @@
+ Metadata-Version: 2.2
+ Name: fairseq
+ Version: 0.12.2
+ Summary: Facebook AI Research Sequence-to-Sequence Toolkit
+ Home-page: https://github.com/pytorch/fairseq
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3.6
+ Classifier: Programming Language :: Python :: 3.7
+ Classifier: Programming Language :: Python :: 3.8
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: cffi
+ Requires-Dist: cython
+ Requires-Dist: hydra-core<1.1,>=1.0.7
+ Requires-Dist: omegaconf<2.1
+ Requires-Dist: numpy>=1.21.3
+ Requires-Dist: regex
+ Requires-Dist: sacrebleu>=1.4.12
+ Requires-Dist: torch>=1.13
+ Requires-Dist: tqdm
+ Requires-Dist: bitarray
+ Requires-Dist: torchaudio>=0.8.0
+ Requires-Dist: scikit-learn
+ Requires-Dist: packaging
+ Provides-Extra: dev
+ Requires-Dist: flake8; extra == "dev"
+ Requires-Dist: pytest; extra == "dev"
+ Requires-Dist: black==22.3.0; extra == "dev"
+ Provides-Extra: docs
+ Requires-Dist: sphinx; extra == "docs"
+ Requires-Dist: sphinx-argparse; extra == "docs"
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: provides-extra
+ Dynamic: requires-dist
+ Dynamic: summary
+
+ <p align="center">
+ <img src="docs/fairseq_logo.png" width="150">
+ <br />
+ <br />
+ <a href="https://opensource.fb.com/support-ukraine"><img alt="Support Ukraine" src="https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB" /></a>
+ <a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
+ <a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
+ <a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
+ <a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a>
+ <a href="https://app.circleci.com/pipelines/github/facebookresearch/fairseq/"><img alt="CicleCI Status" src="https://circleci.com/gh/facebookresearch/fairseq.svg?style=shield" /></a>
+ </p>
+
+ --------------------------------------------------------------------------------
+
+ Fairseq(-py) is a sequence modeling toolkit that allows researchers and
+ developers to train custom models for translation, summarization, language
+ modeling and other text generation tasks.
+
+ We provide reference implementations of various sequence modeling papers:
+
+ <details><summary>List of implemented papers</summary><p>
+
+ * **Convolutional Neural Networks (CNN)**
+ + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
+ + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
+ + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
+ + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
+ + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
+ * **LightConv and DynamicConv models**
+ + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
+ * **Long Short-Term Memory (LSTM) networks**
+ + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
+ * **Transformer (self-attention) networks**
+ + Attention Is All You Need (Vaswani et al., 2017)
+ + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
+ + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
+ + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
+ + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
+ + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
+ + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
+ + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
+ + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
+ + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+ + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md )
+ + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
+ + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
+ + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
+ + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
+ + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
+ + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
+ + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
+ + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
+ + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
+ + [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430)
+ + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027)
+ + [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084)
+ + [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition (Xu et al., 2021)](https://arxiv.org/abs/2109.11680)
+ + [VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding (Xu et. al., 2021)](https://arxiv.org/pdf/2109.14084.pdf)
+ + [VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding (Xu et. al., 2021)](https://aclanthology.org/2021.findings-acl.370.pdf)
+ + [NormFormer: Improved Transformer Pretraining with Extra Normalization (Shleifer et. al, 2021)](examples/normformer/README.md)
+ * **Non-autoregressive Transformers**
+ + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
+ + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018)
+ + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019)
+ + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
+ + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
+ * **Finetuning**
+ + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md)
+
+ </p></details>
+
+ ### What's New:
+ * May 2023 [Released models for Scaling Speech Technology to 1,000+ Languages (Pratap, et al., 2023)](examples/mms/README.md)
+ * June 2022 [Released code for wav2vec-U 2.0 from Towards End-to-end Unsupervised Speech Recognition (Liu, et al., 2022)](examples/wav2vec/unsupervised/README.md)
+ * May 2022 [Integration with xFormers](https://github.com/facebookresearch/xformers)
+ * December 2021 [Released Direct speech-to-speech translation code](examples/speech_to_speech/README.md)
+ * October 2021 [Released VideoCLIP and VLM models](examples/MMPT/README.md)
+ * October 2021 [Released multilingual finetuned XLSR-53 model](examples/wav2vec/README.md)
+ * September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
+ * July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
+ * July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
+ * June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
+ * May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
+ * March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
+ * February 2021 [Added LASER training code](examples/laser/README.md)
+ * December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
+ * December 2020: [GottBERT model and code released](examples/gottbert/README.md)
+ * November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
+ * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
+ * November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
+ * October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
+ * October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
+ * October 2020: [Added CRISS models and code](examples/criss/README.md)
+
+ <details><summary>Previous updates</summary><p>
+
+ * September 2020: [Added Linformer code](examples/linformer/README.md)
+ * September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
+ * August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
+ * August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
+ * July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
+ * May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
+ * April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
+ * April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
+ * April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
+ * March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
+ * February 2020: [mBART model and code released](examples/mbart/README.md)
+ * February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
+ * December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
+ * November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
+ * November 2019: [CamemBERT model and code released](examples/camembert/README.md)
+ * November 2019: [BART model and code released](examples/bart/README.md)
+ * November 2019: [XLM-R models and code released](examples/xlmr/README.md)
+ * September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
+ * August 2019: [WMT'19 models released](examples/wmt19/README.md)
+ * July 2019: fairseq relicensed under MIT license
+ * July 2019: [RoBERTa models and code released](examples/roberta/README.md)
+ * June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
+
+ </p></details>
+
+ ### Features:
+
+ * multi-GPU training on one machine or across multiple machines (data and model parallel)
+ * fast generation on both CPU and GPU with multiple search algorithms implemented:
+ + beam search
+ + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
+ + sampling (unconstrained, top-k and top-p/nucleus)
+ + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
+ * [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU
+ * [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
+ * [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
+ * [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
+ * [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
+ * [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
+
+ We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
+ with a convenient `torch.hub` interface:
+
+ ``` python
+ en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
+ en2de.translate('Hello world', beam=5)
+ # 'Hallo Welt'
+ ```
+
+ See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
+ and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
+
+ # Requirements and Installation
+
+ * [PyTorch](http://pytorch.org/) version >= 1.10.0
+ * Python version >= 3.8
+ * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
+ * **To install fairseq** and develop locally:
+
+ ``` bash
+ git clone https://github.com/pytorch/fairseq
+ cd fairseq
+ pip install --editable ./
+
+ # on MacOS:
+ # CFLAGS="-stdlib=libc++" pip install --editable ./
+
+ # to install the latest stable release (0.10.x)
+ # pip install fairseq
+ ```
+
+ * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
+
+ ``` bash
+ git clone https://github.com/NVIDIA/apex
+ cd apex
+ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
+ --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
+ --global-option="--fast_multihead_attn" ./
+ ```
+
+ * **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
+ * If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
+ as command line options to `nvidia-docker run` .
+
+ # Getting Started
+
+ The [full documentation](https://fairseq.readthedocs.io/) contains instructions
+ for getting started, training new models and extending fairseq with new model
+ types and tasks.
+
+ # Pre-trained models and examples
+
+ We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
+ as well as example training and evaluation commands.
+
+ * [Translation](examples/translation/README.md): convolutional and transformer models are available
+ * [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
+
+ We also have more detailed READMEs to reproduce results from specific papers:
+
+ * [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale (Babu et al., 2021)](examples/wav2vec/xlsr/README.md)
+ * [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
+ * [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
+ * [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
+ * [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
+ * [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
+ * [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md)
+ * [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
+ * [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
+ * [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
+ * [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
+ * [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
+ * [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
+ * [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
+ * [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
+ * [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
+ * [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
+ * [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
+ * [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
+ * [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
+ * [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
+
+ # Join the fairseq community
+
+ * Twitter: https://twitter.com/fairseq
+ * Facebook page: https://www.facebook.com/groups/fairseq.users
+ * Google group: https://groups.google.com/forum/#!forum/fairseq-users
+
+ # License
+
+ fairseq(-py) is MIT-licensed.
+ The license applies to the pre-trained models as well.
+
+ # Citation
+
+ Please cite as:
+
+ ``` bibtex
+ @inproceedings{ott2019fairseq,
+ title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
+ author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
+ booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
+ year = {2019},
+ }
+ ```
fairseq/fairseq.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,1546 @@
+ LICENSE
+ MANIFEST.in
+ README.md
+ pyproject.toml
+ setup.cfg
+ setup.py
+ examples/operators/alignment_train_cpu.cpp
+ examples/operators/alignment_train_cuda.cpp
+ examples/operators/alignment_train_kernel.cu
+ fairseq/__init__.py
+ fairseq/binarizer.py
+ fairseq/checkpoint_utils.py
+ fairseq/file_chunker_utils.py
+ fairseq/file_io.py
+ fairseq/file_utils.py
+ fairseq/hub_utils.py
+ fairseq/incremental_decoding_utils.py
+ fairseq/iterative_refinement_generator.py
+ fairseq/nan_detector.py
+ fairseq/ngram_repeat_block.py
+ fairseq/options.py
+ fairseq/pdb.py
+ fairseq/quantization_utils.py
+ fairseq/registry.py
+ fairseq/search.py
+ fairseq/sequence_generator.py
+ fairseq/sequence_scorer.py
+ fairseq/speech_generator.py
+ fairseq/token_generation_constraints.py
+ fairseq/tokenizer.py
+ fairseq/trainer.py
+ fairseq/utils.py
+ fairseq/version.py
+ fairseq/version.txt
+ fairseq.egg-info/PKG-INFO
+ fairseq.egg-info/SOURCES.txt
+ fairseq.egg-info/dependency_links.txt
+ fairseq.egg-info/entry_points.txt
+ fairseq.egg-info/not-zip-safe
+ fairseq.egg-info/requires.txt
+ fairseq.egg-info/top_level.txt
+ fairseq/benchmark/__init__.py
+ fairseq/benchmark/benchmark_multihead_attention.py
+ fairseq/benchmark/dummy_dataset.py
+ fairseq/benchmark/dummy_lm.py
+ fairseq/benchmark/dummy_masked_lm.py
+ fairseq/benchmark/dummy_model.py
+ fairseq/benchmark/dummy_mt.py
+ fairseq/clib/cuda/ngram_repeat_block_cuda.cpp
+ fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu
+ fairseq/clib/libbase/balanced_assignment.cpp
+ fairseq/clib/libbleu/libbleu.cpp
+ fairseq/clib/libbleu/module.cpp
+ fairseq/clib/libnat/edit_dist.cpp
+ fairseq/clib/libnat_cuda/binding.cpp
+ fairseq/clib/libnat_cuda/edit_dist.cu
+ fairseq/config/__init__.py
+ fairseq/config/config.yaml
+ fairseq/config/fb_run_config/slurm.yaml
+ fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml
+ fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml
+ fairseq/config/model/transformer_lm/transformer_lm_big.yaml
+ fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml
+ fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml
+ fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml
+ fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml
+ fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml
+ fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml
+ fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml
+ fairseq/config/model/wav2vec2/wav2vec2_base.yaml
+ fairseq/config/model/wav2vec2/wav2vec2_large.yaml
+ fairseq/criterions/__init__.py
+ fairseq/criterions/adaptive_loss.py
+ fairseq/criterions/composite_loss.py
+ fairseq/criterions/cross_entropy.py
+ fairseq/criterions/ctc.py
+ fairseq/criterions/fairseq_criterion.py
+ fairseq/criterions/fastspeech2_loss.py
+ fairseq/criterions/hubert_criterion.py
+ fairseq/criterions/label_smoothed_cross_entropy.py
+ fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py
+ fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
+ fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py
+ fairseq/criterions/label_smoothed_cross_entropy_with_rdrop.py
+ fairseq/criterions/legacy_masked_lm.py
+ fairseq/criterions/masked_lm.py
+ fairseq/criterions/model_criterion.py
+ fairseq/criterions/nat_loss.py
+ fairseq/criterions/sentence_prediction.py
+ fairseq/criterions/sentence_prediction_adapters.py
+ fairseq/criterions/sentence_ranking.py
+ fairseq/criterions/speech_dlm_criterion.py
+ fairseq/criterions/speech_to_speech_criterion.py
+ fairseq/criterions/speech_ulm_criterion.py
+ fairseq/criterions/tacotron2_loss.py
+ fairseq/criterions/wav2vec_criterion.py
+ fairseq/data/__init__.py
+ fairseq/data/add_class_target_dataset.py
+ fairseq/data/add_target_dataset.py
+ fairseq/data/append_token_dataset.py
+ fairseq/data/backtranslation_dataset.py
+ fairseq/data/base_wrapper_dataset.py
+ fairseq/data/bucket_pad_length_dataset.py
+ fairseq/data/codedataset.py
+ fairseq/data/colorize_dataset.py
+ fairseq/data/concat_dataset.py
+ fairseq/data/concat_sentences_dataset.py
+ fairseq/data/data_utils.py
+ fairseq/data/data_utils_fast.pyx
+ fairseq/data/denoising_dataset.py
+ fairseq/data/dictionary.py
+ fairseq/data/fairseq_dataset.py
+ fairseq/data/fasta_dataset.py
+ fairseq/data/id_dataset.py
+ fairseq/data/indexed_dataset.py
+ fairseq/data/iterators.py
+ fairseq/data/language_pair_dataset.py
+ fairseq/data/list_dataset.py
+ fairseq/data/lm_context_window_dataset.py
+ fairseq/data/lru_cache_dataset.py
+ fairseq/data/mask_tokens_dataset.py
+ fairseq/data/monolingual_dataset.py
+ fairseq/data/multi_corpus_dataset.py
+ fairseq/data/multi_corpus_sampled_dataset.py
+ fairseq/data/nested_dictionary_dataset.py
+ fairseq/data/noising.py
+ fairseq/data/num_samples_dataset.py
+ fairseq/data/numel_dataset.py
+ fairseq/data/offset_tokens_dataset.py
+ fairseq/data/pad_dataset.py
+ fairseq/data/padding_mask_dataset.py
+ fairseq/data/plasma_utils.py
+ fairseq/data/prepend_dataset.py
+ fairseq/data/prepend_token_dataset.py
+ fairseq/data/raw_label_dataset.py
+ fairseq/data/replace_dataset.py
+ fairseq/data/resampling_dataset.py
+ fairseq/data/roll_dataset.py
+ fairseq/data/round_robin_zip_datasets.py
+ fairseq/data/shorten_dataset.py
+ fairseq/data/sort_dataset.py
+ fairseq/data/span_mask_tokens_dataset.py
+ fairseq/data/speech_dlm_dataset.py
+ fairseq/data/strip_token_dataset.py
+ fairseq/data/subsample_dataset.py
+ fairseq/data/text_compressor.py
+ fairseq/data/token_block_dataset.py
+ fairseq/data/token_block_utils_fast.pyx
+ fairseq/data/transform_eos_concat_langpair_dataset.py
+ fairseq/data/transform_eos_dataset.py
+ fairseq/data/transform_eos_lang_pair_dataset.py
+ fairseq/data/audio/__init__.py
+ fairseq/data/audio/audio_utils.py
+ fairseq/data/audio/data_cfg.py
+ fairseq/data/audio/frm_text_to_speech_dataset.py
+ fairseq/data/audio/hubert_dataset.py
+ fairseq/data/audio/multi_modality_dataset.py
+ fairseq/data/audio/raw_audio_dataset.py
+ fairseq/data/audio/speech_to_speech_dataset.py
+ fairseq/data/audio/speech_to_text_dataset.py
+ fairseq/data/audio/speech_to_text_joint_dataset.py
+ fairseq/data/audio/text_to_speech_dataset.py
+ fairseq/data/audio/dataset_transforms/__init__.py
+ fairseq/data/audio/dataset_transforms/concataugment.py
+ fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py
+ fairseq/data/audio/feature_transforms/__init__.py
+ fairseq/data/audio/feature_transforms/delta_deltas.py
+ fairseq/data/audio/feature_transforms/global_cmvn.py
+ fairseq/data/audio/feature_transforms/specaugment.py
+ fairseq/data/audio/feature_transforms/utterance_cmvn.py
+ fairseq/data/audio/waveform_transforms/__init__.py
+ fairseq/data/audio/waveform_transforms/noiseaugment.py
+ fairseq/data/encoders/__init__.py
+ fairseq/data/encoders/byte_bpe.py
+ fairseq/data/encoders/byte_utils.py
+ fairseq/data/encoders/bytes.py
+ fairseq/data/encoders/characters.py
+ fairseq/data/encoders/fastbpe.py
+ fairseq/data/encoders/gpt2_bpe.py
+ fairseq/data/encoders/gpt2_bpe_utils.py
+ fairseq/data/encoders/hf_bert_bpe.py
+ fairseq/data/encoders/hf_byte_bpe.py
+ fairseq/data/encoders/moses_tokenizer.py
+ fairseq/data/encoders/nltk_tokenizer.py
+ fairseq/data/encoders/sentencepiece_bpe.py
+ fairseq/data/encoders/space_tokenizer.py
+ fairseq/data/encoders/subword_nmt_bpe.py
+ fairseq/data/encoders/utils.py
+ fairseq/data/huffman/__init__.py
+ fairseq/data/huffman/huffman_coder.py
+ fairseq/data/huffman/huffman_mmap_indexed_dataset.py
+ fairseq/data/legacy/__init__.py
+ fairseq/data/legacy/block_pair_dataset.py
+ fairseq/data/legacy/masked_lm_dataset.py
+ fairseq/data/legacy/masked_lm_dictionary.py
+ fairseq/data/multilingual/__init__.py
+ fairseq/data/multilingual/multilingual_data_manager.py
+ fairseq/data/multilingual/multilingual_utils.py
+ fairseq/data/multilingual/sampled_multi_dataset.py
+ fairseq/data/multilingual/sampled_multi_epoch_dataset.py
+ fairseq/data/multilingual/sampling_method.py
+ fairseq/dataclass/__init__.py
+ fairseq/dataclass/configs.py
+ fairseq/dataclass/constants.py
+ fairseq/dataclass/initialize.py
+ fairseq/dataclass/utils.py
+ fairseq/distributed/__init__.py
+ fairseq/distributed/distributed_timeout_wrapper.py
+ fairseq/distributed/fully_sharded_data_parallel.py
+ fairseq/distributed/legacy_distributed_data_parallel.py
+ fairseq/distributed/module_proxy_wrapper.py
+ fairseq/distributed/tpu_distributed_data_parallel.py
+ fairseq/distributed/utils.py
+ fairseq/examples/.gitignore
+ fairseq/examples/__init__.py
+ fairseq/examples/MMPT/.gitignore
+ fairseq/examples/MMPT/CONFIG.md
+ fairseq/examples/MMPT/DATASET.md
+ fairseq/examples/MMPT/README.md
+ fairseq/examples/MMPT/endtask.md
+ fairseq/examples/MMPT/locallaunch.py
+ fairseq/examples/MMPT/pretraining.md
+ fairseq/examples/MMPT/setup.py
+ fairseq/examples/MMPT/videoclip.png
+ fairseq/examples/MMPT/vlm.png
+ fairseq/examples/MMPT/mmpt/__init__.py
+ fairseq/examples/MMPT/mmpt/datasets/__init__.py
+ fairseq/examples/MMPT/mmpt/datasets/fairseqmmdataset.py
+ fairseq/examples/MMPT/mmpt/datasets/mmdataset.py
+ fairseq/examples/MMPT/mmpt/evaluators/__init__.py
+ fairseq/examples/MMPT/mmpt/evaluators/evaluator.py
+ fairseq/examples/MMPT/mmpt/evaluators/metric.py
+ fairseq/examples/MMPT/mmpt/evaluators/predictor.py
+ fairseq/examples/MMPT/mmpt/losses/__init__.py
+ fairseq/examples/MMPT/mmpt/losses/fairseqmmloss.py
+ fairseq/examples/MMPT/mmpt/losses/loss.py
+ fairseq/examples/MMPT/mmpt/losses/nce.py
+ fairseq/examples/MMPT/mmpt/models/__init__.py
+ fairseq/examples/MMPT/mmpt/models/fairseqmmmodel.py
+ fairseq/examples/MMPT/mmpt/models/mmfusion.py
+ fairseq/examples/MMPT/mmpt/models/mmfusionnlg.py
+ fairseq/examples/MMPT/mmpt/models/transformermodel.py
+ fairseq/examples/MMPT/mmpt/modules/__init__.py
+ fairseq/examples/MMPT/mmpt/modules/mm.py
+ fairseq/examples/MMPT/mmpt/modules/retri.py
+ fairseq/examples/MMPT/mmpt/modules/vectorpool.py
+ fairseq/examples/MMPT/mmpt/processors/__init__.py
+ fairseq/examples/MMPT/mmpt/processors/dedupprocessor.py
+ fairseq/examples/MMPT/mmpt/processors/dsprocessor.py
+ fairseq/examples/MMPT/mmpt/processors/how2processor.py
+ fairseq/examples/MMPT/mmpt/processors/how2retriprocessor.py
+ fairseq/examples/MMPT/mmpt/processors/processor.py
+ fairseq/examples/MMPT/mmpt/processors/models/s3dg.py
+ fairseq/examples/MMPT/mmpt/tasks/__init__.py
+ fairseq/examples/MMPT/mmpt/tasks/fairseqmmtask.py
+ fairseq/examples/MMPT/mmpt/tasks/milncetask.py
+ fairseq/examples/MMPT/mmpt/tasks/retritask.py
+ fairseq/examples/MMPT/mmpt/tasks/task.py
+ fairseq/examples/MMPT/mmpt/tasks/vlmtask.py
+ fairseq/examples/MMPT/mmpt/utils/__init__.py
+ fairseq/examples/MMPT/mmpt/utils/load_config.py
+ fairseq/examples/MMPT/mmpt/utils/shardedtensor.py
+ fairseq/examples/MMPT/mmpt_cli/localjob.py
+ fairseq/examples/MMPT/mmpt_cli/predict.py
+ fairseq/examples/MMPT/projects/mfmmlm.yaml
+ fairseq/examples/MMPT/projects/mtm/mmfusionmtm.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm/coin.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm/crosstask.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm/how2.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm/test_coin.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm/test_crosstask_zs.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm/test_vtt.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm/test_vttqa.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm/test_youcook.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm/test_youcookcap.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm/vtt.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm/vttqa.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm/youcook.yaml
+ fairseq/examples/MMPT/projects/mtm/vlm/youcookcap.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip.yaml
+ fairseq/examples/MMPT/projects/retri/videoretri.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/coin_videoclip.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/crosstask_videoclip.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/how2.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/test_coin_videoclip.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/test_coin_zs.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_videoclip.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/test_crosstask_zs_videoclip.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/test_didemo_zs.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_videoclip.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/test_vtt_zs.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_videoclip.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/test_vttqa_zs.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_videoclip.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/test_youcook_zs.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/vtt_videoclip.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/vttqa_videoclip.yaml
+ fairseq/examples/MMPT/projects/retri/videoclip/youcook_videoclip.yaml
+ fairseq/examples/MMPT/projects/task/coin.yaml
+ fairseq/examples/MMPT/projects/task/coin_videoclip.yaml
+ fairseq/examples/MMPT/projects/task/crosstask.yaml
+ fairseq/examples/MMPT/projects/task/crosstask_videoclip.yaml
+ fairseq/examples/MMPT/projects/task/default.yaml
+ fairseq/examples/MMPT/projects/task/ft.yaml
+ fairseq/examples/MMPT/projects/task/how2.yaml
+ fairseq/examples/MMPT/projects/task/test.yaml
+ fairseq/examples/MMPT/projects/task/test_coin.yaml
+ fairseq/examples/MMPT/projects/task/test_coin_videoclip.yaml
+ fairseq/examples/MMPT/projects/task/test_coin_zs.yaml
+ fairseq/examples/MMPT/projects/task/test_crosstask.yaml
+ fairseq/examples/MMPT/projects/task/test_crosstask_videoclip.yaml
+ fairseq/examples/MMPT/projects/task/test_crosstask_zs.yaml
+ fairseq/examples/MMPT/projects/task/test_crosstask_zs_videoclip.yaml
+ fairseq/examples/MMPT/projects/task/test_didemo_zs.yaml
+ fairseq/examples/MMPT/projects/task/test_vtt.yaml
+ fairseq/examples/MMPT/projects/task/test_vtt_videoclip.yaml
+ fairseq/examples/MMPT/projects/task/test_vtt_zs.yaml
+ fairseq/examples/MMPT/projects/task/test_vttqa.yaml
+ fairseq/examples/MMPT/projects/task/test_vttqa_videoclip.yaml
+ fairseq/examples/MMPT/projects/task/test_vttqa_zs.yaml
+ fairseq/examples/MMPT/projects/task/test_youcook.yaml
+ fairseq/examples/MMPT/projects/task/test_youcook_videoclip.yaml
+ fairseq/examples/MMPT/projects/task/test_youcook_zs.yaml
+ fairseq/examples/MMPT/projects/task/test_youcookcap.yaml
+ fairseq/examples/MMPT/projects/task/vtt.yaml
+ fairseq/examples/MMPT/projects/task/vtt_videoclip.yaml
+ fairseq/examples/MMPT/projects/task/vttqa.yaml
+ fairseq/examples/MMPT/projects/task/vttqa_videoclip.yaml
+ fairseq/examples/MMPT/projects/task/youcook.yaml
+ fairseq/examples/MMPT/projects/task/youcook_videoclip.yaml
+ fairseq/examples/MMPT/projects/task/youcookcap.yaml
+ fairseq/examples/MMPT/scripts/text_token_extractor/pretokenization.py
+ fairseq/examples/MMPT/scripts/text_token_extractor/configs/bert-base-uncased.yaml
+ fairseq/examples/MMPT/scripts/video_feature_extractor/extract.py
+ fairseq/examples/MMPT/scripts/video_feature_extractor/model.py
+ fairseq/examples/MMPT/scripts/video_feature_extractor/pathbuilder.py
+ fairseq/examples/MMPT/scripts/video_feature_extractor/preprocessing.py
+ fairseq/examples/MMPT/scripts/video_feature_extractor/random_sequence_shuffler.py
+ fairseq/examples/MMPT/scripts/video_feature_extractor/shard_feature.py
+ fairseq/examples/MMPT/scripts/video_feature_extractor/videoreader.py
+ fairseq/examples/MMPT/scripts/video_feature_extractor/how2/s3d.sh
+ fairseq/examples/adaptive_span/README.md
+ fairseq/examples/adaptive_span/__init__.py
+ fairseq/examples/adaptive_span/adagrad_with_grad_clip.py
+ fairseq/examples/adaptive_span/adaptive_span_attention.py
+ fairseq/examples/adaptive_span/adaptive_span_loss.py
+ fairseq/examples/adaptive_span/adaptive_span_model.py
+ fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py
+ fairseq/examples/adaptive_span/truncated_bptt_lm_task.py
+ fairseq/examples/attention_head_selection/README.md
+ fairseq/examples/attention_head_selection/src/__init__.py
+ fairseq/examples/attention_head_selection/src/speech_to_text_head_selection.py
+ fairseq/examples/attention_head_selection/src/data/__init__.py
+ fairseq/examples/attention_head_selection/src/data/speech_to_text_dataset_with_domain.py
+ fairseq/examples/attention_head_selection/src/loss/__init__.py
+ fairseq/examples/attention_head_selection/src/loss/attention_head_selection.py
+ fairseq/examples/attention_head_selection/src/models/__init__.py
+ fairseq/examples/attention_head_selection/src/models/head_selection_s2t_transformer.py
+ fairseq/examples/attention_head_selection/src/models/head_selection_transformer.py
+ fairseq/examples/attention_head_selection/src/modules/__init__.py
+ fairseq/examples/attention_head_selection/src/modules/attn_head_selector.py
+ fairseq/examples/attention_head_selection/src/modules/head_selection_transformer_layer.py
+ fairseq/examples/attention_head_selection/src/modules/multihead_attention_selection.py
+ fairseq/examples/attention_head_selection/src/modules/multihead_functional.py
+ fairseq/examples/audio_nlp/nlu/README.md
+ fairseq/examples/audio_nlp/nlu/create_dict_stop.sh
+ fairseq/examples/audio_nlp/nlu/generate_manifests.py
+ fairseq/examples/audio_nlp/nlu/configs/nlu_finetuning.yaml
+ fairseq/examples/backtranslation/README.md
+ fairseq/examples/backtranslation/deduplicate_lines.py
+ fairseq/examples/backtranslation/extract_bt_data.py
+ fairseq/examples/backtranslation/prepare-de-monolingual.sh
+ fairseq/examples/backtranslation/prepare-wmt18en2de.sh
+ fairseq/examples/backtranslation/sacrebleu.sh
+ fairseq/examples/backtranslation/tokenized_bleu.sh
+ fairseq/examples/bart/README.glue.md
+ fairseq/examples/bart/README.md
+ fairseq/examples/bart/README.summarization.md
+ fairseq/examples/bart/summarize.py
+ fairseq/examples/byte_level_bpe/README.md
+ fairseq/examples/byte_level_bpe/get_bitext.py
+ fairseq/examples/byte_level_bpe/get_data.sh
+ fairseq/examples/byte_level_bpe/gru_transformer.py
+ fairseq/examples/camembert/README.md
+ fairseq/examples/constrained_decoding/README.md
+ fairseq/examples/constrained_decoding/normalize.py
+ fairseq/examples/constrained_decoding/tok.py
+ fairseq/examples/conv_seq2seq/README.md
+ fairseq/examples/criss/README.md
+ fairseq/examples/criss/download_and_preprocess_flores_test.sh
+ fairseq/examples/criss/download_and_preprocess_tatoeba.sh
+ fairseq/examples/criss/save_encoder.py
+ fairseq/examples/criss/mining/mine.py
+ fairseq/examples/criss/mining/mine_example.sh
+ fairseq/examples/criss/sentence_retrieval/encoder_analysis.py
+ fairseq/examples/criss/sentence_retrieval/sentence_retrieval_tatoeba.sh
+ fairseq/examples/criss/unsupervised_mt/eval.sh
+ fairseq/examples/cross_lingual_language_model/README.md
+ fairseq/examples/data2vec/README.md
+ fairseq/examples/data2vec/__init__.py
+ fairseq/examples/data2vec/fb_convert_beit_cp.py
+ fairseq/examples/data2vec/config/audio/classification/base_classification.yaml
+ fairseq/examples/data2vec/config/audio/classification/run_config/slurm_1.yaml
+ fairseq/examples/data2vec/config/audio/classification/run_config/slurm_1g.yaml
+ fairseq/examples/data2vec/config/audio/classification/run_config/slurm_2.yaml
+ fairseq/examples/data2vec/config/audio/pretraining/audioset.yaml
+ fairseq/examples/data2vec/config/audio/pretraining/base_librispeech.yaml
+ fairseq/examples/data2vec/config/audio/pretraining/run_config/local.yaml
+ fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_1.yaml
+ fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_1_aws.yaml
+ fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_2.yaml
+ fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_2_aws.yaml
+ fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_3.yaml
+ fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_4.yaml
+ fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_4_aws.yaml
+ fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_6_aws.yaml
+ fairseq/examples/data2vec/config/audio/pretraining/run_config/slurm_8_aws.yaml
+ fairseq/examples/data2vec/config/text/pretraining/base.yaml
+ fairseq/examples/data2vec/config/text/pretraining/run_config/local.yaml
+ fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_1_aws.yaml
+ fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_2.yaml
+ fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_2_aws.yaml
+ fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_3.yaml
+ fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_4.yaml
+ fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_4_aws.yaml
+ fairseq/examples/data2vec/config/text/pretraining/run_config/slurm_8_aws.yaml
+ fairseq/examples/data2vec/config/v2/base_audio_only_task.yaml
+ fairseq/examples/data2vec/config/v2/base_images_only_task.yaml
+ fairseq/examples/data2vec/config/v2/base_text_only_task.yaml
+ fairseq/examples/data2vec/config/v2/huge_images14_only_task.yaml
+ fairseq/examples/data2vec/config/v2/huge_images_only_task.yaml
+ fairseq/examples/data2vec/config/v2/large_audio_only_task.yaml
+ fairseq/examples/data2vec/config/v2/large_images_only_task.yaml
+ fairseq/examples/data2vec/config/v2/large_text_only_task.yaml
+ fairseq/examples/data2vec/config/v2/large_text_only_task_pgrp_1M.yaml
+ fairseq/examples/data2vec/config/v2/run_config/local.yaml
+ fairseq/examples/data2vec/config/v2/run_config/slurm_1.yaml
+ fairseq/examples/data2vec/config/v2/run_config/slurm_1_aws.yaml
+ fairseq/examples/data2vec/config/v2/run_config/slurm_2.yaml
+ fairseq/examples/data2vec/config/v2/run_config/slurm_2_aws.yaml
+ fairseq/examples/data2vec/config/v2/run_config/slurm_3.yaml
+ fairseq/examples/data2vec/config/v2/run_config/slurm_4.yaml
+ fairseq/examples/data2vec/config/v2/run_config/slurm_4_aws.yaml
+ fairseq/examples/data2vec/config/v2/run_config/slurm_6_aws.yaml
+ fairseq/examples/data2vec/config/v2/run_config/slurm_8.yaml
+ fairseq/examples/data2vec/config/v2/run_config/slurm_8_aws.yaml
+ fairseq/examples/data2vec/config/v2/text_finetuning/cola.yaml
+ fairseq/examples/data2vec/config/v2/text_finetuning/mnli.yaml
+ fairseq/examples/data2vec/config/v2/text_finetuning/mrpc.yaml
+ fairseq/examples/data2vec/config/v2/text_finetuning/qnli.yaml
+ fairseq/examples/data2vec/config/v2/text_finetuning/qqp.yaml
+ fairseq/examples/data2vec/config/v2/text_finetuning/rte.yaml
+ fairseq/examples/data2vec/config/v2/text_finetuning/sst_2.yaml
+ fairseq/examples/data2vec/config/v2/text_finetuning/sts_b.yaml
+ fairseq/examples/data2vec/config/v2/text_finetuning/run_config/local.yaml
+ fairseq/examples/data2vec/config/vision/finetuning/imagenet.yaml
+ fairseq/examples/data2vec/config/vision/finetuning/mae_imagenet_clean.yaml
+ fairseq/examples/data2vec/config/vision/finetuning/mae_imagenet_huge_clean.yaml
+ fairseq/examples/data2vec/config/vision/finetuning/mae_imagenet_large_clean.yaml
+ fairseq/examples/data2vec/config/vision/finetuning/run_config/local.yaml
+ fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_1.yaml
+ fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml
+ fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml
+ fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml
+ fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml
+ fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_4.yaml
+ fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml
+ fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml
+ fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml
+ fairseq/examples/data2vec/config/vision/pretraining/base_imagenet.yaml
+ fairseq/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml
+ fairseq/examples/data2vec/config/vision/pretraining/base_mae_imagenet.yaml
+ fairseq/examples/data2vec/config/vision/pretraining/run_config/local.yaml
+ fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml
+ fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_1_aws.yaml
+ fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_2.yaml
+ fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml
+ fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml
+ fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml
+ fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml
+ fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml
+ fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml
+ fairseq/examples/data2vec/data/__init__.py
+ fairseq/examples/data2vec/data/add_class_target_dataset.py
+ fairseq/examples/data2vec/data/image_dataset.py
+ fairseq/examples/data2vec/data/mae_finetuning_image_dataset.py
+ fairseq/examples/data2vec/data/mae_image_dataset.py
+ fairseq/examples/data2vec/data/modality.py
+ fairseq/examples/data2vec/data/path_dataset.py
+ fairseq/examples/data2vec/models/__init__.py
+ fairseq/examples/data2vec/models/audio_classification.py
+ fairseq/examples/data2vec/models/data2vec2.py
+ fairseq/examples/data2vec/models/data2vec_audio.py
+ fairseq/examples/data2vec/models/data2vec_image_classification.py
+ fairseq/examples/data2vec/models/data2vec_text.py
+ fairseq/examples/data2vec/models/data2vec_text_classification.py
+ fairseq/examples/data2vec/models/data2vec_vision.py
+ fairseq/examples/data2vec/models/mae.py
+ fairseq/examples/data2vec/models/mae_image_classification.py
+ fairseq/examples/data2vec/models/utils.py
+ fairseq/examples/data2vec/models/modalities/__init__.py
+ fairseq/examples/data2vec/models/modalities/audio.py
+ fairseq/examples/data2vec/models/modalities/base.py
+ fairseq/examples/data2vec/models/modalities/images.py
+ fairseq/examples/data2vec/models/modalities/modules.py
+ fairseq/examples/data2vec/models/modalities/text.py
+ fairseq/examples/data2vec/scripts/convert_audioset_labels.py
+ fairseq/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh
+ fairseq/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh
+ fairseq/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh
+ fairseq/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh
+ fairseq/examples/data2vec/scripts/text/finetune_all_fair.sh
+ fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws.sh
+ fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh
+ fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh
+ fairseq/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh
+ fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep.sh
+ fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws.sh
+ fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_local_lr.sh
+ fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr.sh
+ fairseq/examples/data2vec/scripts/text/finetune_all_fair_nodep_aws_lr_nopos.sh
+ fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_aws_local_lr.sh
+ fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_local_lr.sh
+ fairseq/examples/data2vec/scripts/text/finetune_all_large_fair_nodep_aws_local_lr.sh
+ fairseq/examples/data2vec/scripts/text/finetune_sst2_qnli_sweep_fair_nodep.sh
+ fairseq/examples/data2vec/scripts/text/glue.py
+ fairseq/examples/data2vec/scripts/text/glue_lr.py
+ fairseq/examples/data2vec/scripts/text/unprocess_data.py
+ fairseq/examples/data2vec/scripts/text/valids.py
+ fairseq/examples/data2vec/tasks/__init__.py
+ fairseq/examples/data2vec/tasks/audio_classification.py
+ fairseq/examples/data2vec/tasks/image_classification.py
+ fairseq/examples/data2vec/tasks/image_pretraining.py
+ fairseq/examples/data2vec/tasks/mae_image_classification.py
+ fairseq/examples/data2vec/tasks/mae_image_pretraining.py
+ fairseq/examples/data2vec/tasks/multimodal.py
+ fairseq/examples/discriminative_reranking_nmt/README.md
+ fairseq/examples/discriminative_reranking_nmt/__init__.py
+ fairseq/examples/discriminative_reranking_nmt/drnmt_rerank.py
+ fairseq/examples/discriminative_reranking_nmt/config/deen.yaml
+ fairseq/examples/discriminative_reranking_nmt/criterions/__init__.py
+ fairseq/examples/discriminative_reranking_nmt/criterions/discriminative_reranking_criterion.py
+ fairseq/examples/discriminative_reranking_nmt/models/__init__.py
+ fairseq/examples/discriminative_reranking_nmt/models/discriminative_reranking_model.py
+ fairseq/examples/discriminative_reranking_nmt/scripts/prep_data.py
+ fairseq/examples/discriminative_reranking_nmt/tasks/__init__.py
+ fairseq/examples/discriminative_reranking_nmt/tasks/discriminative_reranking_task.py
+ fairseq/examples/emotion_conversion/README.md
+ fairseq/examples/emotion_conversion/requirements.txt
+ fairseq/examples/emotion_conversion/synthesize.py
+ fairseq/examples/emotion_conversion/emotion_models/__init__.py
+ fairseq/examples/emotion_conversion/emotion_models/duration_predictor.py
+ fairseq/examples/emotion_conversion/emotion_models/duration_predictor.yaml
+ fairseq/examples/emotion_conversion/emotion_models/pitch_predictor.py
+ fairseq/examples/emotion_conversion/emotion_models/pitch_predictor.yaml
+ fairseq/examples/emotion_conversion/emotion_models/utils.py
+ fairseq/examples/emotion_conversion/fairseq_models/__init__.py
+ fairseq/examples/emotion_conversion/preprocess/__init__.py
+ fairseq/examples/emotion_conversion/preprocess/build_hifigan_manifest.py
+ fairseq/examples/emotion_conversion/preprocess/build_translation_manifests.py
+ fairseq/examples/emotion_conversion/preprocess/create_core_manifest.py
+ fairseq/examples/emotion_conversion/preprocess/extract_f0.py
+ fairseq/examples/emotion_conversion/preprocess/process_km.py
+ fairseq/examples/emotion_conversion/preprocess/split_emov_km_tsv_by_uttid.py
+ fairseq/examples/emotion_conversion/preprocess/split_km.py
+ fairseq/examples/emotion_conversion/preprocess/split_km_tsv.py
+ fairseq/examples/fast_noisy_channel/README.md
+ fairseq/examples/fast_noisy_channel/__init__.py
+ fairseq/examples/fast_noisy_channel/noisy_channel_beam_search.py
+ fairseq/examples/fast_noisy_channel/noisy_channel_sequence_generator.py
+ fairseq/examples/fast_noisy_channel/noisy_channel_translation.py
+ fairseq/examples/flores101/README.md
+ fairseq/examples/flores101/flores_logo.png
+ fairseq/examples/fully_sharded_data_parallel/README.md
+ fairseq/examples/gottbert/README.md
+ fairseq/examples/hubert/README.md
+ fairseq/examples/hubert/measure_teacher_quality.py
+ fairseq/examples/hubert/update_ckpt.py
+ fairseq/examples/hubert/config/decode/infer_fsqlm.yaml
+ fairseq/examples/hubert/config/decode/infer_kenlm.yaml
+ fairseq/examples/hubert/config/decode/infer_viterbi.yaml
+ fairseq/examples/hubert/config/decode/ax_sweep/ngram.yaml
+ fairseq/examples/hubert/config/decode/ax_sweep/transformer.yaml
586
+ fairseq/examples/hubert/config/decode/run/submitit_slurm.yaml
587
+ fairseq/examples/hubert/config/decode/run/submitit_slurm_8gpu.yaml
588
+ fairseq/examples/hubert/config/finetune/base_10h.yaml
589
+ fairseq/examples/hubert/config/finetune/ckpt/it1.yaml
590
+ fairseq/examples/hubert/config/finetune/lm/ls_4gram.yaml
591
+ fairseq/examples/hubert/config/finetune/run/submitit_reg.yaml
592
+ fairseq/examples/hubert/config/pretrain/hubert_base_librispeech.yaml
593
+ fairseq/examples/hubert/config/pretrain/hubert_large_librivox.yaml
594
+ fairseq/examples/hubert/config/pretrain/hubert_xlarge_librivox.yaml
595
+ fairseq/examples/hubert/config/pretrain/data/iter1.yaml
596
+ fairseq/examples/hubert/config/pretrain/data/iter2.yaml
597
+ fairseq/examples/hubert/config/pretrain/run/submitit_reg.yaml
598
+ fairseq/examples/hubert/simple_kmeans/README.md
599
+ fairseq/examples/hubert/simple_kmeans/dump_hubert_feature.py
600
+ fairseq/examples/hubert/simple_kmeans/dump_hubert_feature_s2t.py
601
+ fairseq/examples/hubert/simple_kmeans/dump_km_label.py
602
+ fairseq/examples/hubert/simple_kmeans/dump_mfcc_feature.py
603
+ fairseq/examples/hubert/simple_kmeans/dump_w2v2_feature.py
604
+ fairseq/examples/hubert/simple_kmeans/feature_utils.py
605
+ fairseq/examples/hubert/simple_kmeans/learn_kmeans.py
606
+ fairseq/examples/hubert/tests/6313-76958-0021.flac
607
+ fairseq/examples/hubert/tests/sample.base.L9.km500.km
608
+ fairseq/examples/hubert/tests/sample.base.L9.len
609
+ fairseq/examples/hubert/tests/sample.base.L9.npy
610
+ fairseq/examples/hubert/tests/sample.large.L20.len
611
+ fairseq/examples/hubert/tests/sample.large.L20.npy
612
+ fairseq/examples/hubert/tests/sample.large.hypo.word
613
+ fairseq/examples/hubert/tests/sample.xlarge.L30.len
614
+ fairseq/examples/hubert/tests/sample.xlarge.L30.npy
615
+ fairseq/examples/hubert/tests/sample.xlarge.hypo.word
616
+ fairseq/examples/hubert/tests/test_feature_and_unit.sh
617
+ fairseq/examples/hubert/tests/test_finetuned_asr.sh
618
+ fairseq/examples/joint_alignment_translation/README.md
619
+ fairseq/examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh
620
+ fairseq/examples/language_model/README.adaptive_inputs.md
621
+ fairseq/examples/language_model/README.conv.md
622
+ fairseq/examples/language_model/README.md
623
+ fairseq/examples/language_model/prepare-wikitext-103.sh
624
+ fairseq/examples/laser/README.md
625
+ fairseq/examples/laser/laser_src/__init__.py
626
+ fairseq/examples/laser/laser_src/laser_lstm.py
627
+ fairseq/examples/laser/laser_src/laser_task.py
628
+ fairseq/examples/laser/laser_src/laser_transformer.py
629
+ fairseq/examples/laser/laser_src/multitask_data_utils.py
630
+ fairseq/examples/latent_depth/README.md
631
+ fairseq/examples/latent_depth/latent_depth_src/__init__.py
632
+ fairseq/examples/latent_depth/latent_depth_src/multilingual_translation_latent_depth.py
633
+ fairseq/examples/latent_depth/latent_depth_src/loss/__init__.py
634
+ fairseq/examples/latent_depth/latent_depth_src/loss/latent_depth.py
635
+ fairseq/examples/latent_depth/latent_depth_src/models/__init__.py
636
+ fairseq/examples/latent_depth/latent_depth_src/models/latent_multilingual_transformer.py
637
+ fairseq/examples/latent_depth/latent_depth_src/models/latent_transformer.py
638
+ fairseq/examples/latent_depth/latent_depth_src/modules/__init__.py
639
+ fairseq/examples/latent_depth/latent_depth_src/modules/latent_layers.py
640
+ fairseq/examples/layerdrop/README.md
641
+ fairseq/examples/linformer/README.md
642
+ fairseq/examples/linformer/linformer_src/__init__.py
643
+ fairseq/examples/linformer/linformer_src/models/__init__.py
644
+ fairseq/examples/linformer/linformer_src/models/linformer_roberta.py
645
+ fairseq/examples/linformer/linformer_src/modules/__init__.py
646
+ fairseq/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py
647
+ fairseq/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py
648
+ fairseq/examples/linformer/linformer_src/modules/multihead_linear_attention.py
649
+ fairseq/examples/m2m_100/README.md
650
+ fairseq/examples/m2m_100/install_dependecies.sh
651
+ fairseq/examples/m2m_100/tok.sh
652
+ fairseq/examples/m2m_100/process_data/clean_histogram.py
653
+ fairseq/examples/m2m_100/process_data/dedup_data.py
654
+ fairseq/examples/m2m_100/process_data/remove_too_much_punc.py
655
+ fairseq/examples/m2m_100/tokenizers/README.md
656
+ fairseq/examples/m2m_100/tokenizers/seg_ja.sh
657
+ fairseq/examples/m2m_100/tokenizers/seg_ko.sh
658
+ fairseq/examples/m2m_100/tokenizers/tokenize_indic.py
659
+ fairseq/examples/m2m_100/tokenizers/tokenize_thai.py
660
+ fairseq/examples/m2m_100/tokenizers/tokenize_zh.py
661
+ fairseq/examples/m2m_100/tokenizers/tokenizer_ar.sh
662
+ fairseq/examples/m2m_100/tokenizers/thirdparty/.gitignore
663
+ fairseq/examples/mbart/README.md
664
+ fairseq/examples/megatron_11b/README.md
665
+ fairseq/examples/megatron_11b/detok.py
666
+ fairseq/examples/mms/MODEL_CARD.md
667
+ fairseq/examples/mms/README.md
668
+ fairseq/examples/mms/asr/config/infer_common.yaml
669
+ fairseq/examples/mms/asr/infer/example_infer_adapter.sh
670
+ fairseq/examples/mms/asr/infer/mms_infer.py
671
+ fairseq/examples/mms/asr/tutorial/MMS_ASR_Inference_Colab.ipynb
672
+ fairseq/examples/mms/data_prep/README.md
673
+ fairseq/examples/mms/data_prep/align_and_segment.py
674
+ fairseq/examples/mms/data_prep/align_utils.py
675
+ fairseq/examples/mms/data_prep/norm_config.py
676
+ fairseq/examples/mms/data_prep/punctuations.lst
677
+ fairseq/examples/mms/data_prep/text_normalization.py
678
+ fairseq/examples/mms/lid/infer.py
679
+ fairseq/examples/mms/lid/tutorial/MMS_LID_Inference_Colab.ipynb
680
+ fairseq/examples/mms/lid_rerank/README.md
681
+ fairseq/examples/mms/lid_rerank/cer_langs.txt
682
+ fairseq/examples/mms/lid_rerank/requirements.txt
683
+ fairseq/examples/mms/lid_rerank/mala/infer.py
684
+ fairseq/examples/mms/lid_rerank/mms/make_parallel_single_runs.py
685
+ fairseq/examples/mms/lid_rerank/mms/merge_by_lang.py
686
+ fairseq/examples/mms/lid_rerank/mms/prep_wav_list.py
687
+ fairseq/examples/mms/lid_rerank/mms/run_single_lang.py
688
+ fairseq/examples/mms/lid_rerank/mms/split_by_lang.py
689
+ fairseq/examples/mms/lid_rerank/mms-zs/falign.py
690
+ fairseq/examples/mms/lid_rerank/mms-zs/lib.py
691
+ fairseq/examples/mms/lid_rerank/mms-zs/uromanize.py
692
+ fairseq/examples/mms/lid_rerank/nllb/infer.py
693
+ fairseq/examples/mms/lid_rerank/rerank/rerank.py
694
+ fairseq/examples/mms/lid_rerank/rerank/tune_coefficients.py
695
+ fairseq/examples/mms/lid_rerank/whisper/infer_asr.py
696
+ fairseq/examples/mms/lid_rerank/whisper/infer_lid.py
697
+ fairseq/examples/mms/lid_rerank/whisper/lid_mapping.txt
698
+ fairseq/examples/mms/misc/get_sample_size.py
699
+ fairseq/examples/mms/tts/infer.py
700
+ fairseq/examples/mms/tts/tutorial/MMS_TTS_Inference_Colab.ipynb
701
+ fairseq/examples/mms/zero_shot/README.md
702
+ fairseq/examples/moe_lm/README.md
703
+ fairseq/examples/moe_lm/data_card.md
704
+ fairseq/examples/moe_lm/model_card.md
705
+ fairseq/examples/mr_hubert/README.md
706
+ fairseq/examples/mr_hubert/decode.sh
707
+ fairseq/examples/mr_hubert/finetune.sh
708
+ fairseq/examples/mr_hubert/train.sh
709
+ fairseq/examples/mr_hubert/config/decode/infer.yaml
710
+ fairseq/examples/mr_hubert/config/decode/infer_lm.yaml
711
+ fairseq/examples/mr_hubert/config/decode/run/submitit_slurm.yaml
712
+ fairseq/examples/mr_hubert/config/decode/run/submitit_slurm_8gpu.yaml
713
+ fairseq/examples/mr_hubert/config/finetune/base_100h.yaml
714
+ fairseq/examples/mr_hubert/config/finetune/base_100h_large.yaml
715
+ fairseq/examples/mr_hubert/config/finetune/base_10h.yaml
716
+ fairseq/examples/mr_hubert/config/finetune/base_10h_large.yaml
717
+ fairseq/examples/mr_hubert/config/finetune/base_1h.yaml
718
+ fairseq/examples/mr_hubert/config/finetune/base_1h_large.yaml
719
+ fairseq/examples/mr_hubert/config/pretrain/mrhubert_base_librispeech.yaml
720
+ fairseq/examples/mr_hubert/config/pretrain/mrhubert_large_librilight.yaml
721
+ fairseq/examples/mr_hubert/config/pretrain/run/submitit_reg.yaml
722
+ fairseq/examples/mr_hubert/simple_kmeans/README.md
723
+ fairseq/examples/mr_hubert/simple_kmeans/dump_hubert_feature.py
724
+ fairseq/examples/mr_hubert/simple_kmeans/dump_hubert_feature_s2t.py
725
+ fairseq/examples/mr_hubert/simple_kmeans/dump_km_label.py
726
+ fairseq/examples/mr_hubert/simple_kmeans/dump_mfcc_feature.py
727
+ fairseq/examples/mr_hubert/simple_kmeans/dump_w2v2_feature.py
728
+ fairseq/examples/mr_hubert/simple_kmeans/feature_utils.py
729
+ fairseq/examples/mr_hubert/simple_kmeans/learn_kmeans.py
730
+ fairseq/examples/multilingual/ML50_langs.txt
731
+ fairseq/examples/multilingual/README.md
732
+ fairseq/examples/multilingual/finetune_multilingual_model.sh
733
+ fairseq/examples/multilingual/multilingual_fairseq_gen.sh
734
+ fairseq/examples/multilingual/train_multilingual_model.sh
735
+ fairseq/examples/multilingual/data_scripts/README.md
736
+ fairseq/examples/multilingual/data_scripts/binarize.py
737
+ fairseq/examples/multilingual/data_scripts/check_iswlt_test_data.py
738
+ fairseq/examples/multilingual/data_scripts/check_self_overlaps.py
739
+ fairseq/examples/multilingual/data_scripts/check_valid_test_overlaps.py
740
+ fairseq/examples/multilingual/data_scripts/dedup_all.py
741
+ fairseq/examples/multilingual/data_scripts/download_ML50_v1.sh
742
+ fairseq/examples/multilingual/data_scripts/download_af_xh.sh
743
+ fairseq/examples/multilingual/data_scripts/download_flores_data.sh
744
+ fairseq/examples/multilingual/data_scripts/download_iitb.sh
745
+ fairseq/examples/multilingual/data_scripts/download_iwslt_and_extract.sh
746
+ fairseq/examples/multilingual/data_scripts/download_lotus.sh
747
+ fairseq/examples/multilingual/data_scripts/download_ted_and_extract.py
748
+ fairseq/examples/multilingual/data_scripts/download_wat19_my.sh
749
+ fairseq/examples/multilingual/data_scripts/download_wmt19_and_before.py
750
+ fairseq/examples/multilingual/data_scripts/download_wmt20.sh
751
+ fairseq/examples/multilingual/data_scripts/preprocess_ML50_v1.sh
752
+ fairseq/examples/multilingual/data_scripts/remove_valid_test_in_train.py
753
+ fairseq/examples/multilingual/data_scripts/requirement.txt
754
+ fairseq/examples/multilingual/data_scripts/utils/dedup.py
755
+ fairseq/examples/multilingual/data_scripts/utils/fasttext_multi_filter.py
756
+ fairseq/examples/multilingual/data_scripts/utils/strip_sgm.sh
757
+ fairseq/examples/noisychannel/README.md
758
+ fairseq/examples/noisychannel/__init__.py
759
+ fairseq/examples/noisychannel/rerank.py
760
+ fairseq/examples/noisychannel/rerank_generate.py
761
+ fairseq/examples/noisychannel/rerank_options.py
762
+ fairseq/examples/noisychannel/rerank_score_bw.py
763
+ fairseq/examples/noisychannel/rerank_score_lm.py
764
+ fairseq/examples/noisychannel/rerank_tune.py
765
+ fairseq/examples/noisychannel/rerank_utils.py
766
+ fairseq/examples/nonautoregressive_translation/README.md
767
+ fairseq/examples/nonautoregressive_translation/scripts.md
768
+ fairseq/examples/normformer/README.md
769
+ fairseq/examples/normformer/train_lm.sh
770
+ fairseq/examples/operators/alignment_train_cpu.cpp
771
+ fairseq/examples/operators/alignment_train_cuda.cpp
772
+ fairseq/examples/operators/alignment_train_cuda.h
773
+ fairseq/examples/operators/alignment_train_kernel.cu
774
+ fairseq/examples/operators/utils.h
775
+ fairseq/examples/paraphraser/README.md
776
+ fairseq/examples/paraphraser/paraphrase.py
777
+ fairseq/examples/pay_less_attention_paper/README.md
778
+ fairseq/examples/pointer_generator/README.md
779
+ fairseq/examples/pointer_generator/README.xsum.md
780
+ fairseq/examples/pointer_generator/postprocess.py
781
+ fairseq/examples/pointer_generator/preprocess.py
782
+ fairseq/examples/pointer_generator/pointer_generator_src/__init__.py
783
+ fairseq/examples/pointer_generator/pointer_generator_src/transformer_pg.py
784
+ fairseq/examples/quant_noise/README.md
785
+ fairseq/examples/quant_noise/transformer_quantization_config.yaml
786
+ fairseq/examples/roberta/README.custom_classification.md
787
+ fairseq/examples/roberta/README.glue.md
788
+ fairseq/examples/roberta/README.md
789
+ fairseq/examples/roberta/README.pretraining.md
790
+ fairseq/examples/roberta/README.race.md
791
+ fairseq/examples/roberta/multiprocessing_bpe_encoder.py
792
+ fairseq/examples/roberta/preprocess_GLUE_tasks.sh
793
+ fairseq/examples/roberta/preprocess_RACE.py
794
+ fairseq/examples/roberta/preprocess_RACE.sh
795
+ fairseq/examples/roberta/commonsense_qa/README.md
796
+ fairseq/examples/roberta/commonsense_qa/__init__.py
797
+ fairseq/examples/roberta/commonsense_qa/commonsense_qa_task.py
798
+ fairseq/examples/roberta/commonsense_qa/download_cqa_data.sh
799
+ fairseq/examples/roberta/config/finetuning/cola.yaml
800
+ fairseq/examples/roberta/config/finetuning/mnli.yaml
801
+ fairseq/examples/roberta/config/finetuning/mrpc.yaml
802
+ fairseq/examples/roberta/config/finetuning/qnli.yaml
803
+ fairseq/examples/roberta/config/finetuning/qqp.yaml
804
+ fairseq/examples/roberta/config/finetuning/rte.yaml
805
+ fairseq/examples/roberta/config/finetuning/sst_2.yaml
806
+ fairseq/examples/roberta/config/finetuning/sts_b.yaml
807
+ fairseq/examples/roberta/config/finetuning/run_config/local.yaml
808
+ fairseq/examples/roberta/config/finetuning/run_config/slurm_1g.yaml
809
+ fairseq/examples/roberta/config/finetuning/run_config/slurm_1g_aws.yaml
810
+ fairseq/examples/roberta/config/pretraining/base.yaml
811
+ fairseq/examples/roberta/config/pretraining/run_config/local.yaml
812
+ fairseq/examples/roberta/config/pretraining/run_config/slurm_2.yaml
813
+ fairseq/examples/roberta/config/pretraining/run_config/slurm_2_aws.yaml
814
+ fairseq/examples/roberta/config/pretraining/run_config/slurm_3.yaml
815
+ fairseq/examples/roberta/config/pretraining/run_config/slurm_4.yaml
816
+ fairseq/examples/roberta/fb_multilingual/README.multilingual.pretraining.md
817
+ fairseq/examples/roberta/wsc/README.md
818
+ fairseq/examples/roberta/wsc/__init__.py
819
+ fairseq/examples/roberta/wsc/wsc_criterion.py
820
+ fairseq/examples/roberta/wsc/wsc_task.py
821
+ fairseq/examples/roberta/wsc/wsc_utils.py
822
+ fairseq/examples/rxf/README.md
823
+ fairseq/examples/rxf/__init__.py
824
+ fairseq/examples/rxf/rxf_src/__init__.py
825
+ fairseq/examples/rxf/rxf_src/label_smoothed_cross_entropy_r3f.py
826
+ fairseq/examples/rxf/rxf_src/sentence_prediction_r3f.py
827
+ fairseq/examples/scaling_nmt/README.md
828
+ fairseq/examples/shuffled_word_order/README.finetuning.md
829
+ fairseq/examples/shuffled_word_order/README.md
830
+ fairseq/examples/simultaneous_translation/README.md
831
+ fairseq/examples/simultaneous_translation/__init__.py
832
+ fairseq/examples/simultaneous_translation/docs/ende-mma.md
833
+ fairseq/examples/simultaneous_translation/docs/enja-waitk.md
834
+ fairseq/examples/simultaneous_translation/eval/agents/simul_t2t_enja.py
835
+ fairseq/examples/simultaneous_translation/models/__init__.py
836
+ fairseq/examples/simultaneous_translation/models/convtransformer_simul_trans.py
837
+ fairseq/examples/simultaneous_translation/models/transformer_monotonic_attention.py
838
+ fairseq/examples/simultaneous_translation/modules/__init__.py
839
+ fairseq/examples/simultaneous_translation/modules/fixed_pre_decision.py
840
+ fairseq/examples/simultaneous_translation/modules/monotonic_multihead_attention.py
841
+ fairseq/examples/simultaneous_translation/modules/monotonic_transformer_layer.py
842
+ fairseq/examples/simultaneous_translation/tests/test_alignment_train.py
843
+ fairseq/examples/simultaneous_translation/tests/test_text_models.py
844
+ fairseq/examples/simultaneous_translation/utils/__init__.py
845
+ fairseq/examples/simultaneous_translation/utils/functions.py
846
+ fairseq/examples/simultaneous_translation/utils/monotonic_attention.py
847
+ fairseq/examples/simultaneous_translation/utils/p_choose_strategy.py
848
+ fairseq/examples/speech_recognition/README.md
849
+ fairseq/examples/speech_recognition/__init__.py
850
+ fairseq/examples/speech_recognition/infer.py
851
+ fairseq/examples/speech_recognition/w2l_decoder.py
852
+ fairseq/examples/speech_recognition/criterions/ASG_loss.py
853
+ fairseq/examples/speech_recognition/criterions/__init__.py
854
+ fairseq/examples/speech_recognition/criterions/cross_entropy_acc.py
855
+ fairseq/examples/speech_recognition/data/__init__.py
856
+ fairseq/examples/speech_recognition/data/asr_dataset.py
857
+ fairseq/examples/speech_recognition/data/collaters.py
858
+ fairseq/examples/speech_recognition/data/data_utils.py
859
+ fairseq/examples/speech_recognition/data/replabels.py
860
+ fairseq/examples/speech_recognition/datasets/asr_prep_json.py
861
+ fairseq/examples/speech_recognition/datasets/prepare-librispeech.sh
862
+ fairseq/examples/speech_recognition/kaldi/__init__.py
863
+ fairseq/examples/speech_recognition/kaldi/add-self-loop-simple.cc
864
+ fairseq/examples/speech_recognition/kaldi/kaldi_decoder.py
865
+ fairseq/examples/speech_recognition/kaldi/kaldi_initializer.py
866
+ fairseq/examples/speech_recognition/kaldi/config/kaldi_initializer.yaml
867
+ fairseq/examples/speech_recognition/models/__init__.py
868
+ fairseq/examples/speech_recognition/models/vggtransformer.py
869
+ fairseq/examples/speech_recognition/models/w2l_conv_glu_enc.py
870
+ fairseq/examples/speech_recognition/new/README.md
871
+ fairseq/examples/speech_recognition/new/__init__.py
872
+ fairseq/examples/speech_recognition/new/infer.py
873
+ fairseq/examples/speech_recognition/new/conf/infer.yaml
874
+ fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax.yaml
875
+ fairseq/examples/speech_recognition/new/conf/hydra/sweeper/ax_sil.yaml
876
+ fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_1.yaml
877
+ fairseq/examples/speech_recognition/new/conf/run_config/fb_slurm_2g.yaml
878
+ fairseq/examples/speech_recognition/new/decoders/__init__.py
879
+ fairseq/examples/speech_recognition/new/decoders/base_decoder.py
880
+ fairseq/examples/speech_recognition/new/decoders/decoder.py
881
+ fairseq/examples/speech_recognition/new/decoders/decoder_config.py
882
+ fairseq/examples/speech_recognition/new/decoders/flashlight_decoder.py
883
+ fairseq/examples/speech_recognition/new/decoders/viterbi_decoder.py
884
+ fairseq/examples/speech_recognition/tasks/__init__.py
885
+ fairseq/examples/speech_recognition/tasks/speech_recognition.py
886
+ fairseq/examples/speech_recognition/utils/wer_utils.py
887
+ fairseq/examples/speech_synthesis/README.md
888
+ fairseq/examples/speech_synthesis/__init__.py
889
+ fairseq/examples/speech_synthesis/data_utils.py
890
+ fairseq/examples/speech_synthesis/generate_waveform.py
891
+ fairseq/examples/speech_synthesis/utils.py
892
+ fairseq/examples/speech_synthesis/docs/common_voice_example.md
893
+ fairseq/examples/speech_synthesis/docs/ljspeech_example.md
894
+ fairseq/examples/speech_synthesis/docs/vctk_example.md
895
+ fairseq/examples/speech_synthesis/evaluation/__init__.py
896
+ fairseq/examples/speech_synthesis/evaluation/eval_asr.py
897
+ fairseq/examples/speech_synthesis/evaluation/eval_f0.py
898
+ fairseq/examples/speech_synthesis/evaluation/eval_sp.py
899
+ fairseq/examples/speech_synthesis/evaluation/get_eval_manifest.py
900
+ fairseq/examples/speech_synthesis/preprocessing/__init__.py
901
+ fairseq/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py
902
+ fairseq/examples/speech_synthesis/preprocessing/get_common_voice_audio_manifest.py
903
+ fairseq/examples/speech_synthesis/preprocessing/get_feature_manifest.py
904
+ fairseq/examples/speech_synthesis/preprocessing/get_ljspeech_audio_manifest.py
905
+ fairseq/examples/speech_synthesis/preprocessing/get_speaker_embedding.py
906
+ fairseq/examples/speech_synthesis/preprocessing/get_vctk_audio_manifest.py
907
+ fairseq/examples/speech_synthesis/preprocessing/denoiser/__init__.py
908
+ fairseq/examples/speech_synthesis/preprocessing/denoiser/demucs.py
909
+ fairseq/examples/speech_synthesis/preprocessing/denoiser/pretrained.py
910
+ fairseq/examples/speech_synthesis/preprocessing/denoiser/resample.py
911
+ fairseq/examples/speech_synthesis/preprocessing/denoiser/utils.py
912
+ fairseq/examples/speech_synthesis/preprocessing/speaker_embedder/__init__.py
913
+ fairseq/examples/speech_synthesis/preprocessing/vad/__init__.py
914
+ fairseq/examples/speech_text_joint_to_text/README.md
915
+ fairseq/examples/speech_text_joint_to_text/__init__.py
916
+ fairseq/examples/speech_text_joint_to_text/configs/mustc_noise.list
917
+ fairseq/examples/speech_text_joint_to_text/criterions/__init__.py
918
+ fairseq/examples/speech_text_joint_to_text/criterions/multi_modality_compound.py
919
+ fairseq/examples/speech_text_joint_to_text/criterions/multi_modality_cross_entropy.py
920
+ fairseq/examples/speech_text_joint_to_text/criterions/text_guide_cross_entropy_acc.py
921
+ fairseq/examples/speech_text_joint_to_text/data/pair_denoising_dataset.py
922
+ fairseq/examples/speech_text_joint_to_text/docs/ende-mustc.md
923
+ fairseq/examples/speech_text_joint_to_text/docs/iwslt2021.md
924
+ fairseq/examples/speech_text_joint_to_text/docs/pre-training.md
925
+ fairseq/examples/speech_text_joint_to_text/models/__init__.py
926
+ fairseq/examples/speech_text_joint_to_text/models/joint_speech_text_pretrain_transformer.py
927
+ fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputtransformer.py
928
+ fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputwavtransformer.py
929
+ fairseq/examples/speech_text_joint_to_text/models/s2t_dualinputxmtransformer.py
930
+ fairseq/examples/speech_text_joint_to_text/scripts/convert_model.py
931
+ fairseq/examples/speech_text_joint_to_text/scripts/g2p_encode.py
932
+ fairseq/examples/speech_text_joint_to_text/tasks/__init__.py
933
+ fairseq/examples/speech_text_joint_to_text/tasks/pair_denoising.py
934
+ fairseq/examples/speech_text_joint_to_text/tasks/speech_text_denoise_pretrain.py
935
+ fairseq/examples/speech_text_joint_to_text/tasks/speech_text_joint.py
936
+ fairseq/examples/speech_to_speech/README.md
937
+ fairseq/examples/speech_to_speech/__init__.py
938
+ fairseq/examples/speech_to_speech/generate_waveform_from_code.py
939
+ fairseq/examples/speech_to_speech/asr_bleu/README.md
940
+ fairseq/examples/speech_to_speech/asr_bleu/__init__.py
941
+ fairseq/examples/speech_to_speech/asr_bleu/asr_model_cfgs.json
942
+ fairseq/examples/speech_to_speech/asr_bleu/compute_asr_bleu.py
943
+ fairseq/examples/speech_to_speech/asr_bleu/requirements.txt
944
+ fairseq/examples/speech_to_speech/asr_bleu/utils.py
945
+ fairseq/examples/speech_to_speech/benchmarking/README.md
946
+ fairseq/examples/speech_to_speech/benchmarking/core.py
947
+ fairseq/examples/speech_to_speech/benchmarking/data_utils.py
948
+ fairseq/examples/speech_to_speech/benchmarking/get_metrics.py
949
+ fairseq/examples/speech_to_speech/benchmarking/configs/2StageS2ST.yaml
950
+ fairseq/examples/speech_to_speech/benchmarking/configs/3StageS2ST.yaml
951
+ fairseq/examples/speech_to_speech/benchmarking/configs/DirectS2U.yaml
952
+ fairseq/examples/speech_to_speech/benchmarking/configs/S2T.yaml
953
+ fairseq/examples/speech_to_speech/docs/data_augmentation.md
954
+ fairseq/examples/speech_to_speech/docs/direct_s2st_discrete_units.md
955
+ fairseq/examples/speech_to_speech/docs/enhanced_direct_s2st_discrete_units.md
956
+ fairseq/examples/speech_to_speech/docs/textless_s2st_real_data.md
957
+ fairseq/examples/speech_to_speech/preprocessing/__init__.py
958
+ fairseq/examples/speech_to_speech/preprocessing/data_utils.py
959
+ fairseq/examples/speech_to_speech/preprocessing/prep_s2spect_data.py
960
+ fairseq/examples/speech_to_speech/preprocessing/prep_s2ut_data.py
961
+ fairseq/examples/speech_to_speech/preprocessing/prep_sn_data.py
962
+ fairseq/examples/speech_to_speech/preprocessing/prep_sn_output_data.py
963
+ fairseq/examples/speech_to_speech/unity/__init__.py
964
+ fairseq/examples/speech_to_speech/unity/sequence_generator.py
965
+ fairseq/examples/speech_to_speech/unity/sequence_generator_multi_decoder.py
966
+ fairseq/examples/speech_to_text/README.md
967
+ fairseq/examples/speech_to_text/data_utils.py
968
+ fairseq/examples/speech_to_text/prep_covost_data.py
969
+ fairseq/examples/speech_to_text/prep_librispeech_data.py
970
+ fairseq/examples/speech_to_text/prep_mtedx_data.py
971
+ fairseq/examples/speech_to_text/prep_mustc_data.py
972
+ fairseq/examples/speech_to_text/seg_mustc_data.py
973
+ fairseq/examples/speech_to_text/docs/covost_example.md
974
+ fairseq/examples/speech_to_text/docs/librispeech_example.md
975
+ fairseq/examples/speech_to_text/docs/mtedx_example.md
976
+ fairseq/examples/speech_to_text/docs/mustc_example.md
977
+ fairseq/examples/speech_to_text/docs/simulst_mustc_example.md
978
+ fairseq/examples/speech_to_text/simultaneous_translation/agents/fairseq_simul_st_agent.py
979
+ fairseq/examples/stories/README.md
980
+ fairseq/examples/textless_nlp/dgslm/README.md
981
+ fairseq/examples/textless_nlp/dgslm/create_code_file.py
982
+ fairseq/examples/textless_nlp/dgslm/dgslm_utils.py
983
+ fairseq/examples/textless_nlp/dgslm/sample_speech_dlm.py
984
+ fairseq/examples/textless_nlp/dgslm/hubert_fisher/README.md
985
+ fairseq/examples/textless_nlp/dgslm/vocoder_hifigan/README.md
986
+ fairseq/examples/textless_nlp/dgslm/vocoder_hifigan/generate_stereo_waveform.py
987
+ fairseq/examples/textless_nlp/gslm/README.md
988
+ fairseq/examples/textless_nlp/gslm/metrics/README.md
989
+ fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/README.md
990
+ fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py
991
+ fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/README.md
992
+ fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/continuation_eval.py
993
+ fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/ppx.py
994
+ fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py
995
+ fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/bleu_utils.py
996
+ fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/cut_as.py
997
+ fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/misc/dict.ltr.txt
998
+ fairseq/examples/textless_nlp/gslm/speech2unit/README.md
999
+ fairseq/examples/textless_nlp/gslm/speech2unit/__init__.py
1000
+ fairseq/examples/textless_nlp/gslm/speech2unit/clustering/__init__.py
1001
+ fairseq/examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py
1002
+ fairseq/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py
1003
+ fairseq/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py
1004
+ fairseq/examples/textless_nlp/gslm/speech2unit/clustering/utils.py
1005
+ fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py
1006
+ fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py
1007
+ fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py
1008
+ fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py
1009
+ fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py
1010
+ fairseq/examples/textless_nlp/gslm/tools/README.md
1011
+ fairseq/examples/textless_nlp/gslm/tools/resynthesize_speech.py
1012
+ fairseq/examples/textless_nlp/gslm/ulm/README.md
1013
+ fairseq/examples/textless_nlp/gslm/ulm/sample.py
1014
+ fairseq/examples/textless_nlp/gslm/unit2speech/README.md
1015
+ fairseq/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py
1016
+ fairseq/examples/textless_nlp/gslm/unit2speech/glow.py
1017
+ fairseq/examples/textless_nlp/gslm/unit2speech/multiproc.py
1018
+ fairseq/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py
1019
+ fairseq/examples/textless_nlp/gslm/unit2speech/tts_data.py
1020
+ fairseq/examples/textless_nlp/gslm/unit2speech/utils.py
1021
+ fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/__init__.py
1022
+ fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py
1023
+ fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py
1024
+ fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py
1025
+ fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py
1026
+ fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py
1027
+ fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py
1028
+ fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py
1029
+ fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py
1030
+ fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py
1031
+ fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/utils.py
1032
+ fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/waveglow_denoiser.py
1033
+ fairseq/examples/textless_nlp/pgslm/README.md
1034
+ fairseq/examples/textless_nlp/pgslm/data_utils.py
1035
+ fairseq/examples/textless_nlp/pgslm/generate_waveform.py
1036
+ fairseq/examples/textless_nlp/pgslm/inference_dataset.py
1037
+ fairseq/examples/textless_nlp/pgslm/naive_decoder.py
1038
+ fairseq/examples/textless_nlp/pgslm/prepare_dataset.py
1039
+ fairseq/examples/textless_nlp/pgslm/preprocess_f0.py
1040
+ fairseq/examples/textless_nlp/pgslm/quantize_f0.py
1041
+ fairseq/examples/textless_nlp/pgslm/truncated_laplace.py
1042
+ fairseq/examples/textless_nlp/pgslm/eval/__init__.py
1043
+ fairseq/examples/textless_nlp/pgslm/eval/cont_metrics.py
1044
+ fairseq/examples/textless_nlp/pgslm/sample/__init__.py
1045
+ fairseq/examples/textless_nlp/pgslm/sample/sample.py
1046
+ fairseq/examples/textless_nlp/pgslm/scripts/join_units_manifest.py
1047
+ fairseq/examples/textless_nlp/pgslm/scripts/prepare_data.sh
1048
+ fairseq/examples/textless_nlp/pgslm/scripts/prepare_f0_quantization.sh
1049
+ fairseq/examples/textless_nlp/speech-resynth/README.md
1050
+ fairseq/examples/textless_nlp/speech-resynth/img/fig.png
1051
+ fairseq/examples/translation/README.md
1052
+ fairseq/examples/translation/prepare-iwslt14.sh
1053
+ fairseq/examples/translation/prepare-iwslt17-multilingual.sh
1054
+ fairseq/examples/translation/prepare-wmt14en2de.sh
1055
+ fairseq/examples/translation/prepare-wmt14en2fr.sh
1056
+ fairseq/examples/translation_moe/README.md
1057
+ fairseq/examples/translation_moe/score.py
1058
+ fairseq/examples/translation_moe/translation_moe_src/__init__.py
1059
+ fairseq/examples/translation_moe/translation_moe_src/logsumexp_moe.py
1060
+ fairseq/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py
1061
+ fairseq/examples/translation_moe/translation_moe_src/translation_moe.py
1062
+ fairseq/examples/truncated_bptt/README.md
1063
+ fairseq/examples/truncated_bptt/__init__.py
1064
+ fairseq/examples/truncated_bptt/transformer_xl_model.py
1065
+ fairseq/examples/truncated_bptt/truncated_bptt_lm_task.py
1066
+ fairseq/examples/unsupervised_quality_estimation/README.md
1067
+ fairseq/examples/unsupervised_quality_estimation/aggregate_scores.py
1068
+ fairseq/examples/unsupervised_quality_estimation/meteor.py
1069
+ fairseq/examples/unsupervised_quality_estimation/repeat_lines.py
1070
+ fairseq/examples/wav2vec/README.md
1071
+ fairseq/examples/wav2vec/__init__.py
1072
+ fairseq/examples/wav2vec/libri_labels.py
1073
+ fairseq/examples/wav2vec/vq-wav2vec_featurize.py
1074
+ fairseq/examples/wav2vec/wav2vec_featurize.py
1075
+ fairseq/examples/wav2vec/wav2vec_manifest.py
1076
+ fairseq/examples/wav2vec/config/finetuning/base_100h.yaml
1077
+ fairseq/examples/wav2vec/config/finetuning/base_10h.yaml
1078
+ fairseq/examples/wav2vec/config/finetuning/base_10m.yaml
1079
+ fairseq/examples/wav2vec/config/finetuning/base_1h.yaml
1080
+ fairseq/examples/wav2vec/config/finetuning/base_960h.yaml
1081
+ fairseq/examples/wav2vec/config/finetuning/vox_100h.yaml
1082
+ fairseq/examples/wav2vec/config/finetuning/vox_100h_2.yaml
1083
+ fairseq/examples/wav2vec/config/finetuning/vox_100h_2_aws.yaml
1084
+ fairseq/examples/wav2vec/config/finetuning/vox_100h_3.yaml
1085
+ fairseq/examples/wav2vec/config/finetuning/vox_10h.yaml
1086
+ fairseq/examples/wav2vec/config/finetuning/vox_10h_2.yaml
1087
+ fairseq/examples/wav2vec/config/finetuning/vox_10h_2_aws.yaml
1088
+ fairseq/examples/wav2vec/config/finetuning/vox_10h_aws.yaml
1089
+ fairseq/examples/wav2vec/config/finetuning/vox_10h_aws_v100.yaml
1090
+ fairseq/examples/wav2vec/config/finetuning/vox_10m.yaml
1091
+ fairseq/examples/wav2vec/config/finetuning/vox_10m_2.yaml
1092
+ fairseq/examples/wav2vec/config/finetuning/vox_10m_2_aws.yaml
1093
+ fairseq/examples/wav2vec/config/finetuning/vox_10m_3.yaml
1094
+ fairseq/examples/wav2vec/config/finetuning/vox_1h.yaml
1095
+ fairseq/examples/wav2vec/config/finetuning/vox_1h_2.yaml
1096
+ fairseq/examples/wav2vec/config/finetuning/vox_1h_2_aws.yaml
1097
+ fairseq/examples/wav2vec/config/finetuning/vox_1h_3.yaml
1098
+ fairseq/examples/wav2vec/config/finetuning/vox_1h_4.yaml
1099
+ fairseq/examples/wav2vec/config/finetuning/vox_1h_aws.yaml
1100
+ fairseq/examples/wav2vec/config/finetuning/vox_960h.yaml
1101
+ fairseq/examples/wav2vec/config/finetuning/vox_960h_2.yaml
1102
+ fairseq/examples/wav2vec/config/finetuning/vox_960h_2_aws.yaml
1103
+ fairseq/examples/wav2vec/config/finetuning/vox_960h_3.yaml
1104
+ fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1.yaml
1105
+ fairseq/examples/wav2vec/config/finetuning/run_config/slurm_16.yaml
1106
+ fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1_aws.yaml
1107
+ fairseq/examples/wav2vec/config/finetuning/run_config/slurm_1_old.yaml
1108
+ fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2.yaml
1109
+ fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2_aws.yaml
1110
+ fairseq/examples/wav2vec/config/finetuning/run_config/slurm_2g.yaml
1111
+ fairseq/examples/wav2vec/config/finetuning/run_config/slurm_3.yaml
1112
+ fairseq/examples/wav2vec/config/finetuning/run_config/slurm_4g.yaml
1113
+ fairseq/examples/wav2vec/config/finetuning/run_config/slurm_4g_aws.yaml
1114
+ fairseq/examples/wav2vec/config/finetuning/run_config/slurm_8.yaml
1115
+ fairseq/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml
1116
+ fairseq/examples/wav2vec/config/pretraining/wav2vec2_conformer_base_librispeech.yaml
1117
+ fairseq/examples/wav2vec/config/pretraining/wav2vec2_conformer_large_librivox.yaml
1118
+ fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml
1119
+ fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu-pod.yaml
1120
+ fairseq/examples/wav2vec/config/pretraining/wav2vec2_large_librivox_tpu.yaml
1121
+ fairseq/examples/wav2vec/scripts/binarize_manifest.sh
1122
+ fairseq/examples/wav2vec/unsupervised/README.md
1123
+ fairseq/examples/wav2vec/unsupervised/__init__.py
1124
+ fairseq/examples/wav2vec/unsupervised/w2vu_generate.py
1125
+ fairseq/examples/wav2vec/unsupervised/config/finetuning/w2v_finetune.yaml
1126
+ fairseq/examples/wav2vec/unsupervised/config/gan/w2vu.yaml
1127
+ fairseq/examples/wav2vec/unsupervised/config/gan/w2vu2.yaml
1128
+ fairseq/examples/wav2vec/unsupervised/config/generate/viterbi.yaml
1129
+ fairseq/examples/wav2vec/unsupervised/config/timit_matched/test.uid
1130
+ fairseq/examples/wav2vec/unsupervised/config/timit_matched/train.uid
1131
+ fairseq/examples/wav2vec/unsupervised/config/timit_matched/train_text.uid
1132
+ fairseq/examples/wav2vec/unsupervised/config/timit_matched/valid.uid
1133
+ fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/test.uid
1134
+ fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/train.uid
1135
+ fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/train_text.uid
1136
+ fairseq/examples/wav2vec/unsupervised/config/timit_unmatched/valid.uid
1137
+ fairseq/examples/wav2vec/unsupervised/data/__init__.py
1138
+ fairseq/examples/wav2vec/unsupervised/data/extracted_features_dataset.py
1139
+ fairseq/examples/wav2vec/unsupervised/data/random_input_dataset.py
1140
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/README.md
1141
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/cmd.sh
1142
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_phone.sh
1143
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step1.sh
1144
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/decode_word_step2.sh
1145
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/path.sh
1146
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/train.sh
1147
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/copy_aligned_text.py
1148
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/decode.sh
1149
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_data_from_w2v.py
1150
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang.sh
1151
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lang_word.sh
1152
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/prepare_lm.sh
1153
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/score.sh
1154
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/show_wer.sh
1155
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/train_subset_lgbeam.sh
1156
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select.py
1157
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode.sh
1158
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/local/unsup_select_decode_word.sh
1159
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_deltas.sh
1160
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_lda_mllt.sh
1161
+ fairseq/examples/wav2vec/unsupervised/kaldi_self_train/st/steps_gan/train_sat.sh
1162
+ fairseq/examples/wav2vec/unsupervised/models/__init__.py
1163
+ fairseq/examples/wav2vec/unsupervised/models/wav2vec_u.py
1164
+ fairseq/examples/wav2vec/unsupervised/scripts/apply_pca.py
1165
+ fairseq/examples/wav2vec/unsupervised/scripts/copy_labels.py
1166
+ fairseq/examples/wav2vec/unsupervised/scripts/filter_lexicon.py
1167
+ fairseq/examples/wav2vec/unsupervised/scripts/filter_tsv.py
1168
+ fairseq/examples/wav2vec/unsupervised/scripts/g2p_wrd_to_phn.py
1169
+ fairseq/examples/wav2vec/unsupervised/scripts/ltr_to_wrd.py
1170
+ fairseq/examples/wav2vec/unsupervised/scripts/mean_pool.py
1171
+ fairseq/examples/wav2vec/unsupervised/scripts/merge_clusters.py
1172
+ fairseq/examples/wav2vec/unsupervised/scripts/normalize_and_filter_text.py
1173
+ fairseq/examples/wav2vec/unsupervised/scripts/normalize_text.py
1174
+ fairseq/examples/wav2vec/unsupervised/scripts/pca.py
1175
+ fairseq/examples/wav2vec/unsupervised/scripts/phonemize_with_sil.py
1176
+ fairseq/examples/wav2vec/unsupervised/scripts/prepare_audio.sh
1177
+ fairseq/examples/wav2vec/unsupervised/scripts/prepare_audio_v2.sh
1178
+ fairseq/examples/wav2vec/unsupervised/scripts/prepare_text.sh
1179
+ fairseq/examples/wav2vec/unsupervised/scripts/prepare_timit.sh
1180
+ fairseq/examples/wav2vec/unsupervised/scripts/remove_silence.py
1181
+ fairseq/examples/wav2vec/unsupervised/scripts/vads.py
1182
+ fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_apply_cluster_faiss.py
1183
+ fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_cluster_faiss.py
1184
+ fairseq/examples/wav2vec/unsupervised/scripts/wav2vec_extract_features.py
1185
+ fairseq/examples/wav2vec/unsupervised/scripts/wer.py
1186
+ fairseq/examples/wav2vec/unsupervised/scripts/wrd_to_ltr.py
1187
+ fairseq/examples/wav2vec/unsupervised/tasks/__init__.py
1188
+ fairseq/examples/wav2vec/unsupervised/tasks/unpaired_audio_text.py
1189
+ fairseq/examples/wav2vec/xlsr/README.md
1190
+ fairseq/examples/wav2vec/xlsr/config/finetune.yaml
1191
+ fairseq/examples/wav2vec/xlsr/scripts/eval_speaker_clf_task.py
1192
+ fairseq/examples/wav2vec/xlsr/scripts/gen_audio_embedding.py
1193
+ fairseq/examples/wmt19/README.md
1194
+ fairseq/examples/wmt20/README.md
1195
+ fairseq/examples/wmt21/README.md
1196
+ fairseq/examples/wmt21/eval.sh
1197
+ fairseq/examples/wmt21/scripts/normalize-punctuation.perl
1198
+ fairseq/examples/wmt21/scripts/replace-unicode-punctuation.perl
1199
+ fairseq/examples/womens_bios/README.md
1200
+ fairseq/examples/womens_bios/query_occupations_from_wikidata.py
1201
+ fairseq/examples/xformers/README.md
1202
+ fairseq/examples/xglm/README.md
1203
+ fairseq/examples/xglm/XStoryCloze.md
1204
+ fairseq/examples/xglm/model_card.md
1205
+ fairseq/examples/xlmr/README.md
1206
+ fairseq/examples/xmod/README.md
1207
+ fairseq/examples/xmod/preprocess_nli.py
1208
+ fairseq/logging/__init__.py
1209
+ fairseq/logging/meters.py
1210
+ fairseq/logging/metrics.py
1211
+ fairseq/logging/progress_bar.py
1212
+ fairseq/model_parallel/__init__.py
1213
+ fairseq/model_parallel/megatron_trainer.py
1214
+ fairseq/model_parallel/criterions/__init__.py
1215
+ fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py
1216
+ fairseq/model_parallel/models/__init__.py
1217
+ fairseq/model_parallel/models/transformer.py
1218
+ fairseq/model_parallel/models/transformer_lm.py
1219
+ fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py
1220
+ fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py
1221
+ fairseq/model_parallel/models/pipeline_parallel_transformer/model.py
1222
+ fairseq/model_parallel/models/roberta/__init__.py
1223
+ fairseq/model_parallel/models/roberta/model.py
1224
+ fairseq/model_parallel/modules/__init__.py
1225
+ fairseq/model_parallel/modules/multihead_attention.py
1226
+ fairseq/model_parallel/modules/transformer_layer.py
1227
+ fairseq/models/__init__.py
1228
+ fairseq/models/composite_encoder.py
1229
+ fairseq/models/distributed_fairseq_model.py
1230
+ fairseq/models/fairseq_decoder.py
1231
+ fairseq/models/fairseq_encoder.py
1232
+ fairseq/models/fairseq_incremental_decoder.py
1233
+ fairseq/models/fairseq_model.py
1234
+ fairseq/models/fconv.py
1235
+ fairseq/models/fconv_lm.py
1236
+ fairseq/models/fconv_self_att.py
1237
+ fairseq/models/lightconv.py
1238
+ fairseq/models/lightconv_lm.py
1239
+ fairseq/models/lstm.py
1240
+ fairseq/models/lstm_lm.py
1241
+ fairseq/models/masked_lm.py
1242
+ fairseq/models/model_utils.py
1243
+ fairseq/models/multilingual_transformer.py
1244
+ fairseq/models/transformer_align.py
1245
+ fairseq/models/transformer_from_pretrained_xlm.py
1246
+ fairseq/models/transformer_lm.py
1247
+ fairseq/models/transformer_ulm.py
1248
+ fairseq/models/bart/__init__.py
1249
+ fairseq/models/bart/hub_interface.py
1250
+ fairseq/models/bart/model.py
1251
+ fairseq/models/ema/__init__.py
1252
+ fairseq/models/ema/ema.py
1253
+ fairseq/models/hubert/__init__.py
1254
+ fairseq/models/hubert/hubert.py
1255
+ fairseq/models/hubert/hubert_asr.py
1256
+ fairseq/models/huggingface/__init__.py
1257
+ fairseq/models/huggingface/hf_gpt2.py
1258
+ fairseq/models/multires_hubert/__init__.py
1259
+ fairseq/models/multires_hubert/multires_hubert.py
1260
+ fairseq/models/multires_hubert/multires_hubert_asr.py
1261
+ fairseq/models/nat/__init__.py
1262
+ fairseq/models/nat/cmlm_transformer.py
1263
+ fairseq/models/nat/fairseq_nat_model.py
1264
+ fairseq/models/nat/insertion_transformer.py
1265
+ fairseq/models/nat/iterative_nonautoregressive_transformer.py
1266
+ fairseq/models/nat/levenshtein_transformer.py
1267
+ fairseq/models/nat/levenshtein_utils.py
1268
+ fairseq/models/nat/nat_crf_transformer.py
1269
+ fairseq/models/nat/nonautoregressive_ensembles.py
1270
+ fairseq/models/nat/nonautoregressive_transformer.py
1271
+ fairseq/models/roberta/__init__.py
1272
+ fairseq/models/roberta/alignment_utils.py
1273
+ fairseq/models/roberta/enc_dec.py
1274
+ fairseq/models/roberta/hub_interface.py
1275
+ fairseq/models/roberta/model.py
1276
+ fairseq/models/roberta/model_camembert.py
1277
+ fairseq/models/roberta/model_gottbert.py
1278
+ fairseq/models/roberta/model_xlmr.py
1279
+ fairseq/models/speech_dlm/__init__.py
1280
+ fairseq/models/speech_dlm/hub_interface.py
1281
+ fairseq/models/speech_dlm/speech_dlm.py
1282
+ fairseq/models/speech_dlm/modules/__init__.py
1283
+ fairseq/models/speech_dlm/modules/speech_dlm_decoder.py
1284
+ fairseq/models/speech_dlm/modules/speech_dlm_decoder_layer.py
1285
+ fairseq/models/speech_dlm/sequence_generator/__init__.py
1286
+ fairseq/models/speech_dlm/sequence_generator/multichannel_search.py
1287
+ fairseq/models/speech_dlm/sequence_generator/multichannel_sequence_generator.py
1288
+ fairseq/models/speech_to_speech/__init__.py
1289
+ fairseq/models/speech_to_speech/s2s_conformer.py
1290
+ fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py
1291
+ fairseq/models/speech_to_speech/s2s_conformer_unity.py
1292
+ fairseq/models/speech_to_speech/s2s_transformer.py
1293
+ fairseq/models/speech_to_speech/modules/__init__.py
1294
+ fairseq/models/speech_to_speech/modules/ctc_decoder.py
1295
+ fairseq/models/speech_to_speech/modules/stacked_embedding.py
1296
+ fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py
1297
+ fairseq/models/speech_to_speech/modules/transformer_encoder.py
1298
+ fairseq/models/speech_to_text/__init__.py
1299
+ fairseq/models/speech_to_text/berard.py
1300
+ fairseq/models/speech_to_text/convtransformer.py
1301
+ fairseq/models/speech_to_text/hub_interface.py
1302
+ fairseq/models/speech_to_text/multi_modality_model.py
1303
+ fairseq/models/speech_to_text/s2t_conformer.py
1304
+ fairseq/models/speech_to_text/s2t_transformer.py
1305
+ fairseq/models/speech_to_text/s2t_wav_transformer.py
1306
+ fairseq/models/speech_to_text/utils.py
1307
+ fairseq/models/speech_to_text/xm_transformer.py
1308
+ fairseq/models/speech_to_text/xm_transformer_unity.py
1309
+ fairseq/models/speech_to_text/modules/__init__.py
1310
+ fairseq/models/speech_to_text/modules/augmented_memory_attention.py
1311
+ fairseq/models/speech_to_text/modules/convolution.py
1312
+ fairseq/models/speech_to_text/modules/emformer.py
1313
+ fairseq/models/text_to_speech/__init__.py
1314
+ fairseq/models/text_to_speech/codehifigan.py
1315
+ fairseq/models/text_to_speech/fastspeech2.py
1316
+ fairseq/models/text_to_speech/hifigan.py
1317
+ fairseq/models/text_to_speech/hub_interface.py
1318
+ fairseq/models/text_to_speech/tacotron2.py
1319
+ fairseq/models/text_to_speech/tts_transformer.py
1320
+ fairseq/models/text_to_speech/vocoder.py
1321
+ fairseq/models/transformer/__init__.py
1322
+ fairseq/models/transformer/transformer_base.py
1323
+ fairseq/models/transformer/transformer_config.py
1324
+ fairseq/models/transformer/transformer_decoder.py
1325
+ fairseq/models/transformer/transformer_decoder_aug.py
1326
+ fairseq/models/transformer/transformer_encoder.py
1327
+ fairseq/models/transformer/transformer_legacy.py
1328
+ fairseq/models/wav2vec/__init__.py
1329
+ fairseq/models/wav2vec/utils.py
1330
+ fairseq/models/wav2vec/wav2vec.py
1331
+ fairseq/models/wav2vec/wav2vec2.py
1332
+ fairseq/models/wav2vec/wav2vec2_asr.py
1333
+ fairseq/models/wav2vec/wav2vec2_classification.py
1334
+ fairseq/models/wav2vec/wav2vec2_laser.py
1335
+ fairseq/models/xmod/__init__.py
1336
+ fairseq/models/xmod/hub_interface.py
1337
+ fairseq/models/xmod/model.py
1338
+ fairseq/models/xmod/transformer_layer_xmod.py
1339
+ fairseq/modules/__init__.py
1340
+ fairseq/modules/adaptive_input.py
1341
+ fairseq/modules/adaptive_softmax.py
1342
+ fairseq/modules/base_layer.py
1343
+ fairseq/modules/beamable_mm.py
1344
+ fairseq/modules/character_token_embedder.py
1345
+ fairseq/modules/checkpoint_activations.py
1346
+ fairseq/modules/conformer_layer.py
1347
+ fairseq/modules/conv_tbc.py
1348
+ fairseq/modules/cross_entropy.py
1349
+ fairseq/modules/downsampled_multihead_attention.py
1350
+ fairseq/modules/dynamic_convolution.py
1351
+ fairseq/modules/dynamic_crf_layer.py
1352
+ fairseq/modules/ema_module.py
1353
+ fairseq/modules/espnet_multihead_attention.py
1354
+ fairseq/modules/fairseq_dropout.py
1355
+ fairseq/modules/fp32_batch_norm.py
1356
+ fairseq/modules/fp32_group_norm.py
1357
+ fairseq/modules/fp32_instance_norm.py
1358
+ fairseq/modules/gelu.py
1359
+ fairseq/modules/grad_multiply.py
1360
+ fairseq/modules/gumbel_vector_quantizer.py
1361
+ fairseq/modules/kmeans_attention.py
1362
+ fairseq/modules/kmeans_vector_quantizer.py
1363
+ fairseq/modules/layer_drop.py
1364
+ fairseq/modules/layer_norm.py
1365
+ fairseq/modules/learned_positional_embedding.py
1366
+ fairseq/modules/lightweight_convolution.py
1367
+ fairseq/modules/linearized_convolution.py
1368
+ fairseq/modules/location_attention.py
1369
+ fairseq/modules/lstm_cell_with_zoneout.py
1370
+ fairseq/modules/multihead_attention.py
1371
+ fairseq/modules/positional_embedding.py
1372
+ fairseq/modules/positional_encoding.py
1373
+ fairseq/modules/quant_noise.py
1374
+ fairseq/modules/rotary_positional_embedding.py
1375
+ fairseq/modules/same_pad.py
1376
+ fairseq/modules/scalar_bias.py
1377
+ fairseq/modules/sinusoidal_positional_embedding.py
1378
+ fairseq/modules/sparse_multihead_attention.py
1379
+ fairseq/modules/sparse_transformer_sentence_encoder.py
1380
+ fairseq/modules/sparse_transformer_sentence_encoder_layer.py
1381
+ fairseq/modules/transformer_layer.py
1382
+ fairseq/modules/transformer_layer_aug.py
1383
+ fairseq/modules/transformer_sentence_encoder.py
1384
+ fairseq/modules/transformer_sentence_encoder_layer.py
1385
+ fairseq/modules/transpose_last.py
1386
+ fairseq/modules/unfold.py
1387
+ fairseq/modules/vggblock.py
1388
+ fairseq/modules/dynamicconv_layer/__init__.py
1389
+ fairseq/modules/dynamicconv_layer/cuda_function_gen.py
1390
+ fairseq/modules/dynamicconv_layer/dynamicconv_layer.py
1391
+ fairseq/modules/dynamicconv_layer/setup.py
1392
+ fairseq/modules/lightconv_layer/__init__.py
1393
+ fairseq/modules/lightconv_layer/cuda_function_gen.py
1394
+ fairseq/modules/lightconv_layer/lightconv_layer.py
1395
+ fairseq/modules/lightconv_layer/setup.py
1396
+ fairseq/modules/quantization/__init__.py
1397
+ fairseq/modules/quantization/quantization_options.py
1398
+ fairseq/modules/quantization/pq/__init__.py
1399
+ fairseq/modules/quantization/pq/em.py
1400
+ fairseq/modules/quantization/pq/pq.py
1401
+ fairseq/modules/quantization/pq/utils.py
1402
+ fairseq/modules/quantization/pq/modules/__init__.py
1403
+ fairseq/modules/quantization/pq/modules/qconv.py
1404
+ fairseq/modules/quantization/pq/modules/qemb.py
1405
+ fairseq/modules/quantization/pq/modules/qlinear.py
1406
+ fairseq/modules/quantization/scalar/__init__.py
1407
+ fairseq/modules/quantization/scalar/ops.py
1408
+ fairseq/modules/quantization/scalar/utils.py
1409
+ fairseq/modules/quantization/scalar/modules/__init__.py
1410
+ fairseq/modules/quantization/scalar/modules/qact.py
1411
+ fairseq/modules/quantization/scalar/modules/qconv.py
1412
+ fairseq/modules/quantization/scalar/modules/qemb.py
1413
+ fairseq/modules/quantization/scalar/modules/qlinear.py
1414
+ fairseq/optim/__init__.py
1415
+ fairseq/optim/adadelta.py
1416
+ fairseq/optim/adafactor.py
1417
+ fairseq/optim/adagrad.py
1418
+ fairseq/optim/adam.py
1419
+ fairseq/optim/adamax.py
1420
+ fairseq/optim/amp_optimizer.py
1421
+ fairseq/optim/bmuf.py
1422
+ fairseq/optim/composite.py
1423
+ fairseq/optim/cpu_adam.py
1424
+ fairseq/optim/dynamic_loss_scaler.py
1425
+ fairseq/optim/fairseq_optimizer.py
1426
+ fairseq/optim/fp16_optimizer.py
1427
+ fairseq/optim/fused_adam.py
1428
+ fairseq/optim/fused_lamb.py
1429
+ fairseq/optim/nag.py
1430
+ fairseq/optim/sgd.py
1431
+ fairseq/optim/shard.py
1432
+ fairseq/optim/lr_scheduler/__init__.py
1433
+ fairseq/optim/lr_scheduler/cosine_lr_scheduler.py
1434
+ fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py
1435
+ fairseq/optim/lr_scheduler/fixed_schedule.py
1436
+ fairseq/optim/lr_scheduler/inverse_square_root_schedule.py
1437
+ fairseq/optim/lr_scheduler/manual_lr_scheduler.py
1438
+ fairseq/optim/lr_scheduler/pass_through.py
1439
+ fairseq/optim/lr_scheduler/polynomial_decay_schedule.py
1440
+ fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py
1441
+ fairseq/optim/lr_scheduler/step_lr_scheduler.py
1442
+ fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py
1443
+ fairseq/optim/lr_scheduler/triangular_lr_scheduler.py
1444
+ fairseq/scoring/__init__.py
1445
+ fairseq/scoring/bertscore.py
1446
+ fairseq/scoring/bleu.py
1447
+ fairseq/scoring/chrf.py
1448
+ fairseq/scoring/meteor.py
1449
+ fairseq/scoring/tokenizer.py
1450
+ fairseq/scoring/wer.py
1451
+ fairseq/tasks/__init__.py
1452
+ fairseq/tasks/audio_classification.py
1453
+ fairseq/tasks/audio_finetuning.py
1454
+ fairseq/tasks/audio_pretraining.py
1455
+ fairseq/tasks/cross_lingual_lm.py
1456
+ fairseq/tasks/denoising.py
1457
+ fairseq/tasks/fairseq_task.py
1458
+ fairseq/tasks/frm_text_to_speech.py
1459
+ fairseq/tasks/hubert_pretraining.py
1460
+ fairseq/tasks/language_modeling.py
1461
+ fairseq/tasks/legacy_masked_lm.py
1462
+ fairseq/tasks/masked_lm.py
1463
+ fairseq/tasks/multilingual_denoising.py
1464
+ fairseq/tasks/multilingual_language_modeling.py
1465
+ fairseq/tasks/multilingual_masked_lm.py
1466
+ fairseq/tasks/multilingual_translation.py
1467
+ fairseq/tasks/multires_hubert_pretraining.py
1468
+ fairseq/tasks/nlu_finetuning.py
1469
+ fairseq/tasks/online_backtranslation.py
1470
+ fairseq/tasks/semisupervised_translation.py
1471
+ fairseq/tasks/sentence_prediction.py
1472
+ fairseq/tasks/sentence_prediction_adapters.py
1473
+ fairseq/tasks/sentence_ranking.py
1474
+ fairseq/tasks/simultaneous_translation.py
1475
+ fairseq/tasks/span_masked_lm.py
1476
+ fairseq/tasks/speech_dlm_task.py
1477
+ fairseq/tasks/speech_to_speech.py
1478
+ fairseq/tasks/speech_to_text.py
1479
+ fairseq/tasks/speech_ulm_task.py
1480
+ fairseq/tasks/text_to_speech.py
1481
+ fairseq/tasks/translation.py
1482
+ fairseq/tasks/translation_from_pretrained_bart.py
1483
+ fairseq/tasks/translation_from_pretrained_xlm.py
1484
+ fairseq/tasks/translation_lev.py
1485
+ fairseq/tasks/translation_multi_simple_epoch.py
1486
+ fairseq_cli/__init__.py
1487
+ fairseq_cli/eval_lm.py
1488
+ fairseq_cli/generate.py
1489
+ fairseq_cli/hydra_train.py
1490
+ fairseq_cli/hydra_validate.py
1491
+ fairseq_cli/interactive.py
1492
+ fairseq_cli/preprocess.py
1493
+ fairseq_cli/score.py
1494
+ fairseq_cli/train.py
1495
+ fairseq_cli/validate.py
1496
+ tests/test_activation_checkpointing.py
1497
+ tests/test_amp_optimizer.py
1498
+ tests/test_average_checkpoints.py
1499
+ tests/test_backtranslation_dataset.py
1500
+ tests/test_binaries.py
1501
+ tests/test_binarizer.py
1502
+ tests/test_character_token_embedder.py
1503
+ tests/test_checkpoint_utils.py
1504
+ tests/test_checkpoint_utils_for_task_level_attributes.py
1505
+ tests/test_concat_dataset.py
1506
+ tests/test_constraints.py
1507
+ tests/test_convtbc.py
1508
+ tests/test_data_utils.py
1509
+ tests/test_dataclass_utils.py
1510
+ tests/test_dataset.py
1511
+ tests/test_dictionary.py
1512
+ tests/test_ema.py
1513
+ tests/test_espnet_multihead_attention.py
1514
+ tests/test_export.py
1515
+ tests/test_file_chunker_utils.py
1516
+ tests/test_file_io.py
1517
+ tests/test_fp16_optimizer.py
1518
+ tests/test_hf_hub.py
1519
+ tests/test_huffman.py
1520
+ tests/test_inference_dropout.py
1521
+ tests/test_iopath.py
1522
+ tests/test_iterators.py
1523
+ tests/test_label_smoothing.py
1524
+ tests/test_lm_context_window.py
1525
+ tests/test_lstm_jitable.py
1526
+ tests/test_memory_efficient_fp16.py
1527
+ tests/test_metrics.py
1528
+ tests/test_multi_corpus_dataset.py
1529
+ tests/test_multi_corpus_sampled_dataset.py
1530
+ tests/test_multihead_attention.py
1531
+ tests/test_noising.py
1532
+ tests/test_online_backtranslation.py
1533
+ tests/test_plasma_utils.py
1534
+ tests/test_positional_encoding.py
1535
+ tests/test_reproducibility.py
1536
+ tests/test_resampling_dataset.py
1537
+ tests/test_roberta.py
1538
+ tests/test_rotary_positional_embedding.py
1539
+ tests/test_sequence_generator.py
1540
+ tests/test_sequence_scorer.py
1541
+ tests/test_sparse_multihead_attention.py
1542
+ tests/test_token_block_dataset.py
1543
+ tests/test_train.py
1544
+ tests/test_transformer.py
1545
+ tests/test_utils.py
1546
+ tests/test_valid_subset_checks.py
fairseq/fairseq.egg-info/entry_points.txt ADDED
@@ -0,0 +1,9 @@
+ [console_scripts]
+ fairseq-eval-lm = fairseq_cli.eval_lm:cli_main
+ fairseq-generate = fairseq_cli.generate:cli_main
+ fairseq-hydra-train = fairseq_cli.hydra_train:cli_main
+ fairseq-interactive = fairseq_cli.interactive:cli_main
+ fairseq-preprocess = fairseq_cli.preprocess:cli_main
+ fairseq-score = fairseq_cli.score:cli_main
+ fairseq-train = fairseq_cli.train:cli_main
+ fairseq-validate = fairseq_cli.validate:cli_main
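
Each console_scripts entry above becomes a thin wrapper that imports the listed cli_main and calls it, so running a fairseq-* command from a shell and calling the function from Python are equivalent. A minimal sketch (the data directory and checkpoint path are placeholders, not files from this commit):

import sys
from fairseq_cli.validate import cli_main

# Same as running: fairseq-validate /path/to/data-bin --path /path/to/checkpoint.pt
sys.argv = [
    "fairseq-validate",
    "/path/to/data-bin",                   # placeholder dataset directory
    "--path", "/path/to/checkpoint.pt",    # placeholder checkpoint
]
cli_main()
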
fairseq/fairseq.egg-info/requires.txt ADDED
@@ -0,0 +1,22 @@
+ cffi
+ cython
+ hydra-core<1.1,>=1.0.7
+ omegaconf<2.1
+ numpy>=1.21.3
+ regex
+ sacrebleu>=1.4.12
+ torch>=1.13
+ tqdm
+ bitarray
+ torchaudio>=0.8.0
+ scikit-learn
+ packaging
+
+ [dev]
+ flake8
+ pytest
+ black==22.3.0
+
+ [docs]
+ sphinx
+ sphinx-argparse
fairseq/fairseq.egg-info/top_level.txt ADDED
@@ -0,0 +1,4 @@
+ alignment_train_cpu_binding
+ alignment_train_cuda_binding
+ fairseq
+ fairseq_cli
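
Of the four top-level modules, alignment_train_cpu_binding and alignment_train_cuda_binding appear to be compiled extension bindings rather than Python packages in this tree, so they are only importable in environments where those optional extensions were built. A quick, hedged way to check which of them exist in the current install:

import importlib.util

for name in ("alignment_train_cpu_binding", "alignment_train_cuda_binding"):
    spec = importlib.util.find_spec(name)
    print(name, "available" if spec is not None else "not built in this install")
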
fairseq/fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc ADDED
Binary file (2.27 kB).
fairseq/fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc ADDED
Binary file (8.77 kB).
fairseq/fairseq/__pycache__/ngram_repeat_block.cpython-310.pyc ADDED
Binary file (3.84 kB).
fairseq/fairseq/__pycache__/pdb.cpython-310.pyc ADDED
Binary file (1.37 kB).
fairseq/fairseq_cli/__init__.py ADDED
File without changes
fairseq/fairseq_cli/eval_lm.py ADDED
@@ -0,0 +1,347 @@
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Evaluate the perplexity of a trained language model.
9
+ """
10
+
11
+ import logging
12
+ import math
13
+ import os
14
+ import sys
15
+ from argparse import Namespace
16
+ from typing import Iterable, List, Optional
17
+
18
+ import torch
19
+ from omegaconf import DictConfig
20
+
21
+ import fairseq
22
+ from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
23
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
24
+ from fairseq.logging import progress_bar
25
+ from fairseq.logging.meters import StopwatchMeter
26
+ from fairseq.sequence_scorer import SequenceScorer
27
+
28
+ logging.basicConfig(
29
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
30
+ datefmt="%Y-%m-%d %H:%M:%S",
31
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
32
+ stream=sys.stdout,
33
+ )
34
+ logger = logging.getLogger("fairseq_cli.eval_lm")
35
+
36
+
37
+ def eval_lm(
38
+ models: List[fairseq.models.FairseqModel],
39
+ source_dictionary: fairseq.data.Dictionary,
40
+ batch_iterator: Iterable,
41
+ post_process: Optional[str] = None,
42
+ output_word_probs: bool = False,
43
+ output_word_stats: bool = False,
44
+ target_dictionary: Optional[fairseq.data.Dictionary] = None,
45
+ softmax_batch: int = 0,
46
+ remove_bos_token: bool = False,
47
+ device: Optional[torch.device] = None,
48
+ ):
49
+ """
50
+ Args:
51
+ models (List[~fairseq.models.FairseqModel]): list of models to
52
+ evaluate. Models are essentially `nn.Module` instances, but
53
+ must be compatible with fairseq's `SequenceScorer`.
54
+ source_dictionary (~fairseq.data.Dictionary): dictionary for
55
+ applying any relevant post processing or outputing word
56
+ probs/stats.
57
+ batch_iterator (Iterable): yield batches of data
58
+ post_process (Optional[str]): post-process text by removing BPE,
59
+ letter segmentation, etc. Valid options can be found in
60
+ fairseq.data.utils.post_process, although not all options
61
+ are implemented here.
62
+ output_word_probs (Optional[bool]): output words and their
63
+ predicted log probabilities
64
+ output_word_stats (Optional[bool]): output word statistics such
65
+ as word count and average probability
66
+ target_dictionary (Optional[~fairseq.data.Dictionary]): output
67
+ dictionary (defaults to *source_dictionary*)
68
+ softmax_batch (Optional[bool]): if BxT is more than this, will
69
+ batch the softmax over vocab to this amount of tokens, in
70
+ order to fit into GPU memory
71
+ remove_bos_token (Optional[bool]): if True, confirm that the
72
+ first token is the beginning-of-sentence symbol (according
73
+ to the relevant dictionary) and remove it from the output
74
+ device (Optional[torch.device]): device to use for evaluation
75
+ (defaults to device of first model parameter)
76
+ """
77
+ if target_dictionary is None:
78
+ target_dictionary = source_dictionary
79
+ if device is None:
80
+ device = next(models[0].parameters()).device
81
+
82
+ gen_timer = StopwatchMeter()
83
+ scorer = SequenceScorer(target_dictionary, softmax_batch)
84
+
85
+ score_sum = 0.0
86
+ count = 0
87
+
88
+ if post_process is not None:
89
+ if post_process in {"subword_nmt", "@@ "}:
90
+ bpe_cont = post_process.rstrip()
91
+ bpe_toks = {
92
+ i
93
+ for i in range(len(source_dictionary))
94
+ if source_dictionary[i].endswith(bpe_cont)
95
+ }
96
+ else:
97
+ raise NotImplementedError(
98
+ f"--post-process={post_process} is not implemented"
99
+ )
100
+ bpe_len = len(bpe_cont)
101
+ else:
102
+ bpe_toks = None
103
+ bpe_len = 0
104
+
105
+ word_stats = dict()
106
+
107
+ for sample in batch_iterator:
108
+ if "net_input" not in sample:
109
+ continue
110
+
111
+ sample = utils.move_to_cuda(sample, device=device)
112
+
113
+ gen_timer.start()
114
+ hypos = scorer.generate(models, sample)
115
+ gen_timer.stop(sample["ntokens"])
116
+
117
+ for i, hypos_i in enumerate(hypos):
118
+ hypo = hypos_i[0]
119
+ sample_id = sample["id"][i]
120
+
121
+ tokens = hypo["tokens"]
122
+ tgt_len = tokens.numel()
123
+ pos_scores = hypo["positional_scores"].float()
124
+
125
+ if remove_bos_token:
126
+ assert hypo["tokens"][0].item() == target_dictionary.bos()
127
+ tokens = tokens[1:]
128
+ pos_scores = pos_scores[1:]
129
+
130
+ skipped_toks = 0
131
+ if bpe_toks is not None:
132
+ for i in range(tgt_len - 1):
133
+ if tokens[i].item() in bpe_toks:
134
+ skipped_toks += 1
135
+ pos_scores[i + 1] += pos_scores[i]
136
+ pos_scores[i] = 0
137
+
138
+ inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf"))
139
+ if inf_scores.any():
140
+ logger.info(
141
+ "skipping tokens with inf scores:",
142
+ target_dictionary.string(tokens[inf_scores.nonzero()]),
143
+ )
144
+ pos_scores = pos_scores[(~inf_scores).nonzero()]
145
+ score_sum += pos_scores.sum().cpu()
146
+ count += pos_scores.numel() - skipped_toks
147
+
148
+ if output_word_probs or output_word_stats:
149
+ w = ""
150
+ word_prob = []
151
+ is_bpe = False
152
+ for i in range(len(tokens)):
153
+ w_ind = tokens[i].item()
154
+ w += source_dictionary[w_ind]
155
+ if bpe_toks is not None and w_ind in bpe_toks:
156
+ w = w[:-bpe_len]
157
+ is_bpe = True
158
+ else:
159
+ word_prob.append((w, pos_scores[i].item()))
160
+
161
+ next_prob = None
162
+ ind = i + 1
163
+ while ind < len(tokens):
164
+ if pos_scores[ind].item() != 0:
165
+ next_prob = pos_scores[ind]
166
+ break
167
+ ind += 1
168
+
169
+ word_stats.setdefault(w, WordStat(w, is_bpe)).add(
170
+ pos_scores[i].item(), next_prob
171
+ )
172
+ is_bpe = False
173
+ w = ""
174
+ if output_word_probs:
175
+ logger.info(
176
+ str(int(sample_id))
177
+ + " "
178
+ + (
179
+ "\t".join(
180
+ "{} [{:2f}]".format(x[0], x[1]) for x in word_prob
181
+ )
182
+ )
183
+ )
184
+
185
+ avg_nll_loss = (
186
+ -score_sum / count / math.log(2) if count > 0 else 0
187
+ ) # convert to base 2
188
+ logger.info(
189
+ "Evaluated {:,} tokens in {:.1f}s ({:.2f} tokens/s)".format(
190
+ gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0
191
+ )
192
+ )
193
+
194
+ if output_word_stats:
195
+ for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
196
+ logger.info(ws)
197
+
198
+ return {
199
+ "loss": avg_nll_loss,
200
+ "perplexity": 2**avg_nll_loss,
201
+ }
202
+
203
+
204
+ class WordStat(object):
205
+ def __init__(self, word, is_bpe):
206
+ self.word = word
207
+ self.is_bpe = is_bpe
208
+ self.log_prob = 0
209
+ self.next_word_prob = 0
210
+ self.count = 0
211
+ self.missing_next_words = 0
212
+
213
+ def add(self, log_prob, next_word_prob):
214
+ """increments counters for the sum of log probs of current word and next
215
+ word (given context ending at current word). Since the next word might be at the end of the example,
216
+ or it might be not counted because it is not an ending subword unit,
217
+ also keeps track of how many of those we have seen"""
218
+ if next_word_prob is not None:
219
+ self.next_word_prob += next_word_prob
220
+ else:
221
+ self.missing_next_words += 1
222
+ self.log_prob += log_prob
223
+ self.count += 1
224
+
225
+ def __str__(self):
226
+ return "{}\t{}\t{}\t{}\t{}\t{}".format(
227
+ self.word,
228
+ self.count,
229
+ self.log_prob,
230
+ self.is_bpe,
231
+ self.next_word_prob,
232
+ self.count - self.missing_next_words,
233
+ )
234
+
235
+
236
+ def main(cfg: DictConfig, **unused_kwargs):
237
+ if isinstance(cfg, Namespace):
238
+ cfg = convert_namespace_to_omegaconf(cfg)
239
+
240
+ utils.import_user_module(cfg.common)
241
+
242
+ logger.info(cfg)
243
+
244
+ if cfg.eval_lm.context_window > 0:
245
+ # reduce tokens per sample by the required context window size
246
+ cfg.task.tokens_per_sample -= cfg.eval_lm.context_window
247
+
248
+ # Initialize the task using the current *cfg*
249
+ task = tasks.setup_task(cfg.task)
250
+
251
+ # Load ensemble
252
+ logger.info("loading model(s) from {}".format(cfg.common_eval.path))
253
+ models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
254
+ [cfg.common_eval.path],
255
+ arg_overrides=eval(cfg.common_eval.model_overrides),
256
+ suffix=cfg.checkpoint.checkpoint_suffix,
257
+ strict=(cfg.checkpoint.checkpoint_shard_count == 1),
258
+ num_shards=cfg.checkpoint.checkpoint_shard_count,
259
+ task=task,
260
+ )
261
+
262
+ use_fp16 = cfg.common.fp16
263
+ use_cuda = torch.cuda.is_available() and not cfg.common.cpu
264
+ if use_cuda:
265
+ torch.cuda.set_device(cfg.distributed_training.device_id)
266
+
267
+ # Optimize ensemble for generation and set the source and dest dicts on the model
268
+ # (required by scorer)
269
+ for model in models:
270
+ if use_fp16:
271
+ model.half()
272
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
273
+ model.cuda()
274
+ model.prepare_for_inference_(cfg)
275
+
276
+ assert len(models) > 0
277
+
278
+ logger.info(
279
+ "num. model params: {:,}".format(sum(p.numel() for p in models[0].parameters()))
280
+ )
281
+
282
+ # Load dataset splits
283
+ task.load_dataset(cfg.dataset.gen_subset)
284
+ dataset = task.dataset(cfg.dataset.gen_subset)
285
+ logger.info(
286
+ "{} {} {:,} examples".format(
287
+ cfg.task.data, cfg.dataset.gen_subset, len(dataset)
288
+ )
289
+ )
290
+
291
+ itr = task.eval_lm_dataloader(
292
+ dataset=dataset,
293
+ max_tokens=cfg.dataset.max_tokens or 36000,
294
+ batch_size=cfg.dataset.batch_size,
295
+ max_positions=utils.resolve_max_positions(
296
+ *[model.max_positions() for model in models]
297
+ ),
298
+ num_shards=max(
299
+ cfg.dataset.num_shards,
300
+ cfg.distributed_training.distributed_world_size,
301
+ ),
302
+ shard_id=max(
303
+ cfg.dataset.shard_id,
304
+ cfg.distributed_training.distributed_rank,
305
+ ),
306
+ num_workers=cfg.dataset.num_workers,
307
+ data_buffer_size=cfg.dataset.data_buffer_size,
308
+ context_window=cfg.eval_lm.context_window,
309
+ )
310
+
311
+ itr = progress_bar.progress_bar(
312
+ itr,
313
+ log_format=cfg.common.log_format,
314
+ log_interval=cfg.common.log_interval,
315
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
316
+ )
317
+
318
+ results = eval_lm(
319
+ models=models,
320
+ source_dictionary=task.source_dictionary,
321
+ batch_iterator=itr,
322
+ post_process=cfg.common_eval.post_process,
323
+ output_word_probs=cfg.eval_lm.output_word_probs,
324
+ output_word_stats=cfg.eval_lm.output_word_stats,
325
+ target_dictionary=task.target_dictionary,
326
+ softmax_batch=cfg.eval_lm.softmax_batch,
327
+ remove_bos_token=getattr(cfg.task, "add_bos_token", False),
328
+ )
329
+
330
+ logger.info(
331
+ "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format(
332
+ results["loss"], results["perplexity"]
333
+ )
334
+ )
335
+
336
+ return results
337
+
338
+
339
+ def cli_main():
340
+ parser = options.get_eval_lm_parser()
341
+ args = options.parse_args_and_arch(parser)
342
+
343
+ distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)
344
+
345
+
346
+ if __name__ == "__main__":
347
+ cli_main()
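
eval_lm() above is also usable as a library function; main() only wires it up to the parsed config. A minimal sketch of a programmatic call, with the checkpoint path, split name, and max_tokens as placeholders and the remaining eval_lm_dataloader arguments left at their defaults (a hedged sketch, not the documented API surface):

from fairseq import checkpoint_utils
from fairseq_cli.eval_lm import eval_lm

# Load an ensemble together with the task it was trained with (placeholder path).
models, _args, task = checkpoint_utils.load_model_ensemble_and_task(
    ["/path/to/checkpoint_best.pt"]
)
for model in models:
    model.eval()  # optionally model.cuda() when a GPU is available

task.load_dataset("valid")
itr = task.eval_lm_dataloader(dataset=task.dataset("valid"), max_tokens=36000)

results = eval_lm(
    models=models,
    source_dictionary=task.source_dictionary,
    batch_iterator=itr,
    target_dictionary=task.target_dictionary,
)
# "loss" is the average negative log-likelihood in base 2, so perplexity is 2 ** loss.
print(results["loss"], results["perplexity"])
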
fairseq/fairseq_cli/generate.py ADDED
@@ -0,0 +1,417 @@
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Translate pre-processed data with a trained model.
8
+ """
9
+
10
+ import ast
11
+ import logging
12
+ import math
13
+ import os
14
+ import sys
15
+ from argparse import Namespace
16
+ from itertools import chain
17
+
18
+ import numpy as np
19
+ import torch
20
+ from omegaconf import DictConfig
21
+
22
+ from fairseq import checkpoint_utils, options, scoring, tasks, utils
23
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
24
+ from fairseq.logging import progress_bar
25
+ from fairseq.logging.meters import StopwatchMeter, TimeMeter
26
+
27
+
28
+ def main(cfg: DictConfig):
29
+
30
+ if isinstance(cfg, Namespace):
31
+ cfg = convert_namespace_to_omegaconf(cfg)
32
+
33
+ assert cfg.common_eval.path is not None, "--path required for generation!"
34
+ assert (
35
+ not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
36
+ ), "--sampling requires --nbest to be equal to --beam"
37
+ assert (
38
+ cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw"
39
+ ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"
40
+
41
+ if cfg.common_eval.results_path is not None:
42
+ os.makedirs(cfg.common_eval.results_path, exist_ok=True)
43
+ output_path = os.path.join(
44
+ cfg.common_eval.results_path,
45
+ "generate-{}.txt".format(cfg.dataset.gen_subset),
46
+ )
47
+ with open(output_path, "w", buffering=1, encoding="utf-8") as h:
48
+ return _main(cfg, h)
49
+ else:
50
+ return _main(cfg, sys.stdout)
51
+
52
+
53
+ def get_symbols_to_strip_from_output(generator):
54
+ if hasattr(generator, "symbols_to_strip_from_output"):
55
+ return generator.symbols_to_strip_from_output
56
+ else:
57
+ return {generator.eos}
58
+
59
+
60
+ def _main(cfg: DictConfig, output_file):
61
+ logging.basicConfig(
62
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
63
+ datefmt="%Y-%m-%d %H:%M:%S",
64
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
65
+ stream=output_file,
66
+ )
67
+ logger = logging.getLogger("fairseq_cli.generate")
68
+
69
+ utils.import_user_module(cfg.common)
70
+
71
+ if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
72
+ cfg.dataset.max_tokens = 12000
73
+ logger.info(cfg)
74
+
75
+ # Fix seed for stochastic decoding
76
+ if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
77
+ np.random.seed(cfg.common.seed)
78
+ utils.set_torch_seed(cfg.common.seed)
79
+
80
+ use_cuda = torch.cuda.is_available() and not cfg.common.cpu
81
+
82
+ # Load dataset splits
83
+ task = tasks.setup_task(cfg.task)
84
+
85
+ # Set dictionaries
86
+ try:
87
+ src_dict = getattr(task, "source_dictionary", None)
88
+ except NotImplementedError:
89
+ src_dict = None
90
+ tgt_dict = task.target_dictionary
91
+
92
+ overrides = ast.literal_eval(cfg.common_eval.model_overrides)
93
+
94
+ # Load ensemble
95
+ logger.info("loading model(s) from {}".format(cfg.common_eval.path))
96
+ models, saved_cfg = checkpoint_utils.load_model_ensemble(
97
+ utils.split_paths(cfg.common_eval.path),
98
+ arg_overrides=overrides,
99
+ task=task,
100
+ suffix=cfg.checkpoint.checkpoint_suffix,
101
+ strict=(cfg.checkpoint.checkpoint_shard_count == 1),
102
+ num_shards=cfg.checkpoint.checkpoint_shard_count,
103
+ )
104
+
105
+ # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
106
+ task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)
107
+
108
+ if cfg.generation.lm_path is not None:
109
+ overrides["data"] = cfg.task.data
110
+
111
+ try:
112
+ lms, _ = checkpoint_utils.load_model_ensemble(
113
+ [cfg.generation.lm_path], arg_overrides=overrides, task=None
114
+ )
115
+ except:
116
+ logger.warning(
117
+ f"Failed to load language model! Please make sure that the language model dict is the same "
118
+ f"as target dict and is located in the data dir ({cfg.task.data})"
119
+ )
120
+ raise
121
+
122
+ assert len(lms) == 1
123
+ else:
124
+ lms = [None]
125
+
126
+ # Optimize ensemble for generation
127
+ for model in chain(models, lms):
128
+ if model is None:
129
+ continue
130
+ if cfg.common.fp16:
131
+ model.half()
132
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
133
+ model.cuda()
134
+ model.prepare_for_inference_(cfg)
135
+
136
+ # Load alignment dictionary for unknown word replacement
137
+ # (None if no unknown word replacement, empty if no path to align dictionary)
138
+ align_dict = utils.load_align_dict(cfg.generation.replace_unk)
139
+
140
+ # Load dataset (possibly sharded)
141
+ itr = task.get_batch_iterator(
142
+ dataset=task.dataset(cfg.dataset.gen_subset),
143
+ max_tokens=cfg.dataset.max_tokens,
144
+ max_sentences=cfg.dataset.batch_size,
145
+ max_positions=utils.resolve_max_positions(
146
+ task.max_positions(), *[m.max_positions() for m in models]
147
+ ),
148
+ ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
149
+ required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
150
+ seed=cfg.common.seed,
151
+ num_shards=cfg.distributed_training.distributed_world_size,
152
+ shard_id=cfg.distributed_training.distributed_rank,
153
+ num_workers=cfg.dataset.num_workers,
154
+ data_buffer_size=cfg.dataset.data_buffer_size,
155
+ ).next_epoch_itr(shuffle=False)
156
+ progress = progress_bar.progress_bar(
157
+ itr,
158
+ log_format=cfg.common.log_format,
159
+ log_interval=cfg.common.log_interval,
160
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
161
+ )
162
+
163
+ # Initialize generator
164
+ gen_timer = StopwatchMeter()
165
+
166
+ extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
167
+ generator = task.build_generator(
168
+ models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs
169
+ )
170
+
171
+ # Handle tokenization and BPE
172
+ tokenizer = task.build_tokenizer(cfg.tokenizer)
173
+ bpe = task.build_bpe(cfg.bpe)
174
+
175
+ def decode_fn(x):
176
+ if bpe is not None:
177
+ x = bpe.decode(x)
178
+ if tokenizer is not None:
179
+ x = tokenizer.decode(x)
180
+ return x
181
+
182
+ scorer = scoring.build_scorer(cfg.scoring, tgt_dict)
183
+
184
+ num_sentences = 0
185
+ has_target = True
186
+ wps_meter = TimeMeter()
187
+ for sample in progress:
188
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
189
+ if "net_input" not in sample:
190
+ continue
191
+
192
+ prefix_tokens = None
193
+ if cfg.generation.prefix_size > 0:
194
+ prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]
195
+
196
+ constraints = None
197
+ if "constraints" in sample:
198
+ constraints = sample["constraints"]
199
+
200
+ gen_timer.start()
201
+ hypos = task.inference_step(
202
+ generator,
203
+ models,
204
+ sample,
205
+ prefix_tokens=prefix_tokens,
206
+ constraints=constraints,
207
+ )
208
+ num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
209
+ gen_timer.stop(num_generated_tokens)
210
+
211
+ for i, sample_id in enumerate(sample["id"].tolist()):
212
+ has_target = sample["target"] is not None
213
+
214
+ # Remove padding
215
+ if "src_tokens" in sample["net_input"]:
216
+ src_tokens = utils.strip_pad(
217
+ sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
218
+ )
219
+ else:
220
+ src_tokens = None
221
+
222
+ target_tokens = None
223
+ if has_target:
224
+ target_tokens = (
225
+ utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
226
+ )
227
+
228
+ # Either retrieve the original sentences or regenerate them from tokens.
229
+ if align_dict is not None:
230
+ src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
231
+ sample_id
232
+ )
233
+ target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
234
+ sample_id
235
+ )
236
+ else:
237
+ if src_dict is not None:
238
+ src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
239
+ else:
240
+ src_str = ""
241
+ if has_target:
242
+ target_str = tgt_dict.string(
243
+ target_tokens,
244
+ cfg.common_eval.post_process,
245
+ escape_unk=True,
246
+ extra_symbols_to_ignore=get_symbols_to_strip_from_output(
247
+ generator
248
+ ),
249
+ )
250
+
251
+ src_str = decode_fn(src_str)
252
+ if has_target:
253
+ target_str = decode_fn(target_str)
254
+
255
+ if not cfg.common_eval.quiet:
256
+ if src_dict is not None:
257
+ print("S-{}\t{}".format(sample_id, src_str), file=output_file)
258
+ if has_target:
259
+ print("T-{}\t{}".format(sample_id, target_str), file=output_file)
260
+
261
+ # Process top predictions
262
+ for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]):
263
+ hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
264
+ hypo_tokens=hypo["tokens"].int().cpu(),
265
+ src_str=src_str,
266
+ alignment=hypo["alignment"],
267
+ align_dict=align_dict,
268
+ tgt_dict=tgt_dict,
269
+ remove_bpe=cfg.common_eval.post_process,
270
+ extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
271
+ )
272
+ detok_hypo_str = decode_fn(hypo_str)
273
+ if not cfg.common_eval.quiet:
274
+ score = hypo["score"] / math.log(2) # convert to base 2
275
+ # original hypothesis (after tokenization and BPE)
276
+ print(
277
+ "H-{}\t{}\t{}".format(sample_id, score, hypo_str),
278
+ file=output_file,
279
+ )
280
+ # detokenized hypothesis
281
+ print(
282
+ "D-{}\t{}\t{}".format(sample_id, score, detok_hypo_str),
283
+ file=output_file,
284
+ )
285
+ print(
286
+ "P-{}\t{}".format(
287
+ sample_id,
288
+ " ".join(
289
+ map(
290
+ lambda x: "{:.4f}".format(x),
291
+ # convert from base e to base 2
292
+ hypo["positional_scores"]
293
+ .div_(math.log(2))
294
+ .tolist(),
295
+ )
296
+ ),
297
+ ),
298
+ file=output_file,
299
+ )
300
+
301
+ if cfg.generation.print_alignment == "hard":
302
+ print(
303
+ "A-{}\t{}".format(
304
+ sample_id,
305
+ " ".join(
306
+ [
307
+ "{}-{}".format(src_idx, tgt_idx)
308
+ for src_idx, tgt_idx in alignment
309
+ ]
310
+ ),
311
+ ),
312
+ file=output_file,
313
+ )
314
+ if cfg.generation.print_alignment == "soft":
315
+ print(
316
+ "A-{}\t{}".format(
317
+ sample_id,
318
+ " ".join(
319
+ [",".join(src_probs) for src_probs in alignment]
320
+ ),
321
+ ),
322
+ file=output_file,
323
+ )
324
+
325
+ if cfg.generation.print_step:
326
+ print(
327
+ "I-{}\t{}".format(sample_id, hypo["steps"]),
328
+ file=output_file,
329
+ )
330
+
331
+ if cfg.generation.retain_iter_history:
332
+ for step, h in enumerate(hypo["history"]):
333
+ _, h_str, _ = utils.post_process_prediction(
334
+ hypo_tokens=h["tokens"].int().cpu(),
335
+ src_str=src_str,
336
+ alignment=None,
337
+ align_dict=None,
338
+ tgt_dict=tgt_dict,
339
+ remove_bpe=None,
340
+ )
341
+ print(
342
+ "E-{}_{}\t{}".format(sample_id, step, h_str),
343
+ file=output_file,
344
+ )
345
+
346
+ # Score only the top hypothesis
347
+ if has_target and j == 0:
348
+ if (
349
+ align_dict is not None
350
+ or cfg.common_eval.post_process is not None
351
+ ):
352
+ # Convert back to tokens for evaluation with unk replacement and/or without BPE
353
+ target_tokens = tgt_dict.encode_line(
354
+ target_str, add_if_not_exist=True
355
+ )
356
+ hypo_tokens = tgt_dict.encode_line(
357
+ detok_hypo_str, add_if_not_exist=True
358
+ )
359
+ if hasattr(scorer, "add_string"):
360
+ scorer.add_string(target_str, detok_hypo_str)
361
+ else:
362
+ scorer.add(target_tokens, hypo_tokens)
363
+
364
+ wps_meter.update(num_generated_tokens)
365
+ progress.log({"wps": round(wps_meter.avg)})
366
+ num_sentences += (
367
+ sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
368
+ )
369
+
370
+ logger.info("NOTE: hypothesis and token scores are output in base 2")
371
+ logger.info(
372
+ "Translated {:,} sentences ({:,} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format(
373
+ num_sentences,
374
+ gen_timer.n,
375
+ gen_timer.sum,
376
+ num_sentences / gen_timer.sum,
377
+ 1.0 / gen_timer.avg,
378
+ )
379
+ )
380
+ if has_target:
381
+ if cfg.bpe and not cfg.generation.sacrebleu:
382
+ if cfg.common_eval.post_process:
383
+ logger.warning(
384
+ "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
385
+ )
386
+ else:
387
+ logger.warning(
388
+ "If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization"
389
+ )
390
+ # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
391
+ print(
392
+ "Generate {} with beam={}: {}".format(
393
+ cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
394
+ ),
395
+ file=output_file,
396
+ )
397
+
398
+ return scorer
399
+
400
+
401
+ def cli_main():
402
+ parser = options.get_generation_parser()
403
+ # TODO: replace this workaround with refactoring of `AudioPretraining`
404
+ parser.add_argument(
405
+ "--arch",
406
+ "-a",
407
+ metavar="ARCH",
408
+ default="wav2vec2",
409
+ help="Model architecture. For constructing tasks that rely on "
410
+ "model args (e.g. `AudioPretraining`)",
411
+ )
412
+ args = options.parse_args_and_arch(parser)
413
+ main(args)
414
+
415
+
416
+ if __name__ == "__main__":
417
+ cli_main()
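
The script writes one tagged record per line (S- source, T- target, H-/D- tokenized and detokenized hypotheses, P- positional scores), either to stdout or to generate-<subset>.txt under --results-path. A small, hedged sketch for pulling the detokenized hypotheses back out of such a file:

def read_detok_hypotheses(path):
    """Collect the top detokenized hypothesis per sample id from a
    generate-*.txt file written by the script above."""
    hypos = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.startswith("D-"):
                tag, score, text = line.rstrip("\n").split("\t", 2)
                sample_id = int(tag[2:])
                # with --nbest > 1, the first D- line per id is the best hypothesis
                hypos.setdefault(sample_id, (float(score), text))
    return hypos

Note that, as logged by the script, both the hypothesis score and the positional scores are reported in base 2.
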
fairseq/fairseq_cli/hydra_train.py ADDED
@@ -0,0 +1,91 @@
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import os
9
+
10
+ import hydra
11
+ import torch
12
+ from hydra.core.hydra_config import HydraConfig
13
+ from omegaconf import OmegaConf, open_dict
14
+
15
+ from fairseq import distributed_utils, metrics
16
+ from fairseq.dataclass.configs import FairseqConfig
17
+ from fairseq.dataclass.initialize import add_defaults, hydra_init
18
+ from fairseq.dataclass.utils import omegaconf_no_object_check
19
+ from fairseq.utils import reset_logging
20
+ from fairseq_cli.train import main as pre_main
21
+
22
+ logger = logging.getLogger("fairseq_cli.hydra_train")
23
+
24
+
25
+ @hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
26
+ def hydra_main(cfg: FairseqConfig) -> float:
27
+ _hydra_main(cfg)
28
+
29
+
30
+ def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
31
+ add_defaults(cfg)
32
+
33
+ if cfg.common.reset_logging:
34
+ reset_logging() # Hydra hijacks logging, fix that
35
+ else:
36
+ # check if directly called or called through hydra_main
37
+ if HydraConfig.initialized():
38
+ with open_dict(cfg):
39
+ # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126)
40
+ cfg.job_logging_cfg = OmegaConf.to_container(
41
+ HydraConfig.get().job_logging, resolve=True
42
+ )
43
+
44
+ with omegaconf_no_object_check():
45
+ cfg = OmegaConf.create(
46
+ OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
47
+ )
48
+ OmegaConf.set_struct(cfg, True)
49
+
50
+ try:
51
+ if cfg.common.profile:
52
+ with torch.cuda.profiler.profile():
53
+ with torch.autograd.profiler.emit_nvtx():
54
+ distributed_utils.call_main(cfg, pre_main, **kwargs)
55
+ else:
56
+ distributed_utils.call_main(cfg, pre_main, **kwargs)
57
+ except BaseException as e:
58
+ if not cfg.common.suppress_crashes:
59
+ raise
60
+ else:
61
+ logger.error("Crashed! " + str(e))
62
+
63
+ # get best val and return - useful for sweepers
64
+ try:
65
+ best_val = metrics.get_smoothed_value(
66
+ "valid", cfg.checkpoint.best_checkpoint_metric
67
+ )
68
+ except:
69
+ best_val = None
70
+
71
+ if best_val is None:
72
+ best_val = float("inf")
73
+
74
+ return best_val
75
+
76
+
77
+ def cli_main():
78
+ try:
79
+ from hydra._internal.utils import get_args
80
+
81
+ cfg_name = get_args().config_name or "config"
82
+ except:
83
+ logger.warning("Failed to get config name from hydra args")
84
+ cfg_name = "config"
85
+
86
+ hydra_init(cfg_name)
87
+ hydra_main()
88
+
89
+
90
+ if __name__ == "__main__":
91
+ cli_main()
fairseq/fairseq_cli/hydra_validate.py ADDED
@@ -0,0 +1,188 @@
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import os
9
+ import sys
10
+ from itertools import chain
11
+
12
+ import torch
13
+ from hydra.core.hydra_config import HydraConfig
14
+ from omegaconf import OmegaConf, open_dict
15
+ import hydra
16
+
17
+ from fairseq import checkpoint_utils, distributed_utils, utils
18
+ from fairseq.dataclass.configs import FairseqConfig
19
+ from fairseq.dataclass.initialize import add_defaults, hydra_init
20
+ from fairseq.dataclass.utils import omegaconf_no_object_check
21
+ from fairseq.distributed import utils as distributed_utils
22
+ from fairseq.logging import metrics, progress_bar
23
+ from fairseq.utils import reset_logging
24
+
25
+ logging.basicConfig(
26
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
27
+ datefmt="%Y-%m-%d %H:%M:%S",
28
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
29
+ stream=sys.stdout,
30
+ )
31
+ logger = logging.getLogger("fairseq_cli.validate")
32
+
33
+
34
+ @hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
35
+ def hydra_main(cfg: FairseqConfig) -> float:
36
+ return _hydra_main(cfg)
37
+
38
+
39
+ def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
40
+ add_defaults(cfg)
41
+
42
+ if cfg.common.reset_logging:
43
+ reset_logging() # Hydra hijacks logging, fix that
44
+ else:
45
+ # check if directly called or called through hydra_main
46
+ if HydraConfig.initialized():
47
+ with open_dict(cfg):
48
+ # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126)
49
+ cfg.job_logging_cfg = OmegaConf.to_container(
50
+ HydraConfig.get().job_logging, resolve=True
51
+ )
52
+
53
+ with omegaconf_no_object_check():
54
+ cfg = OmegaConf.create(
55
+ OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)
56
+ )
57
+ OmegaConf.set_struct(cfg, True)
58
+
59
+ assert (
60
+ cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
61
+ ), "Must specify batch size either with --max-tokens or --batch-size"
62
+
63
+ distributed_utils.call_main(cfg, validate, **kwargs)
64
+
65
+
66
+ def validate(cfg):
67
+ utils.import_user_module(cfg.common)
68
+
69
+ use_fp16 = cfg.common.fp16
70
+ use_cuda = torch.cuda.is_available() and not cfg.common.cpu
71
+
72
+ if use_cuda:
73
+ torch.cuda.set_device(cfg.distributed_training.device_id)
74
+
75
+ if cfg.distributed_training.distributed_world_size > 1:
76
+ data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
77
+ data_parallel_rank = distributed_utils.get_data_parallel_rank()
78
+ else:
79
+ data_parallel_world_size = 1
80
+ data_parallel_rank = 0
81
+
82
+ overrides = {"task": {"data": cfg.task.data}}
83
+
84
+ # Load ensemble
85
+ logger.info("loading model(s) from {}".format(cfg.common_eval.path))
86
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
87
+ [cfg.common_eval.path],
88
+ arg_overrides=overrides,
89
+ suffix=cfg.checkpoint.checkpoint_suffix,
90
+ )
91
+ model = models[0]
92
+
93
+ # Move models to GPU
94
+ for model in models:
95
+ model.eval()
96
+ if use_fp16:
97
+ model.half()
98
+ if use_cuda:
99
+ model.cuda()
100
+
101
+ # Print args
102
+ logger.info(saved_cfg)
103
+
104
+ # Build criterion
105
+ criterion = task.build_criterion(saved_cfg.criterion, from_checkpoint=True)
106
+ criterion.eval()
107
+
108
+ for subset in cfg.dataset.valid_subset.split(","):
109
+ try:
110
+ task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task)
111
+ dataset = task.dataset(subset)
112
+ except KeyError:
113
+ raise Exception("Cannot find dataset: " + subset)
114
+
115
+ # Initialize data iterator
116
+ itr = task.get_batch_iterator(
117
+ dataset=dataset,
118
+ max_tokens=cfg.dataset.max_tokens,
119
+ max_sentences=cfg.dataset.batch_size,
120
+ max_positions=utils.resolve_max_positions(
121
+ task.max_positions(),
122
+ *[m.max_positions() for m in models],
123
+ ),
124
+ ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
125
+ required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
126
+ seed=cfg.common.seed,
127
+ num_shards=data_parallel_world_size,
128
+ shard_id=data_parallel_rank,
129
+ num_workers=cfg.dataset.num_workers,
130
+ data_buffer_size=cfg.dataset.data_buffer_size,
131
+ ).next_epoch_itr(shuffle=False)
132
+ progress = progress_bar.progress_bar(
133
+ itr,
134
+ log_format=cfg.common.log_format,
135
+ log_interval=cfg.common.log_interval,
136
+ prefix=f"valid on '{subset}' subset",
137
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
138
+ )
139
+
140
+ def apply_half(t):
141
+ if t.dtype is torch.float32:
142
+ return t.to(dtype=torch.half)
143
+ return t
144
+
145
+ log_outputs = []
146
+ for i, sample in enumerate(progress):
147
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
148
+
149
+ if use_fp16:
150
+ sample = utils.apply_to_sample(apply_half, sample)
151
+
152
+ _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
153
+ with metrics.aggregate() as agg:
154
+ task.reduce_metrics([log_output], criterion)
155
+ progress.log(agg.get_smoothed_values(), step=i)
156
+ # progress.log(log_output, step=i) from vision
157
+ log_outputs.append(log_output)
158
+
159
+ if data_parallel_world_size > 1:
160
+ log_outputs = distributed_utils.all_gather_list(
161
+ log_outputs,
162
+ max_size=cfg.common.all_gather_list_size,
163
+ group=distributed_utils.get_data_parallel_group(),
164
+ )
165
+ log_outputs = list(chain.from_iterable(log_outputs))
166
+
167
+ with metrics.aggregate() as agg:
168
+ task.reduce_metrics(log_outputs, criterion)
169
+ log_output = agg.get_smoothed_values()
170
+
171
+ progress.print(log_output, tag=subset, step=i)
172
+
173
+
174
+ def cli_main():
175
+ try:
176
+ from hydra._internal.utils import get_args
177
+
178
+ cfg_name = get_args().config_name or "config"
179
+ except:
180
+ logger.warning("Failed to get config name from hydra args")
181
+ cfg_name = "config"
182
+
183
+ hydra_init(cfg_name)
184
+ hydra_main()
185
+
186
+
187
+ if __name__ == "__main__":
188
+ cli_main()
fairseq/fairseq_cli/interactive.py ADDED
@@ -0,0 +1,317 @@
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Translate raw text with a trained model. Batches data on-the-fly.
8
+ """
9
+
10
+ import ast
11
+ import fileinput
12
+ import logging
13
+ import math
14
+ import os
15
+ import sys
16
+ import time
17
+ from argparse import Namespace
18
+ from collections import namedtuple
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
24
+ from fairseq.dataclass.configs import FairseqConfig
25
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
26
+ from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
27
+ from fairseq_cli.generate import get_symbols_to_strip_from_output
28
+
29
+ logging.basicConfig(
30
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
31
+ datefmt="%Y-%m-%d %H:%M:%S",
32
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
33
+ stream=sys.stdout,
34
+ )
35
+ logger = logging.getLogger("fairseq_cli.interactive")
36
+
37
+
38
+ Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
39
+ Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")
40
+
41
+
42
+ def buffered_read(input, buffer_size):
43
+ buffer = []
44
+ with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as h:
45
+ for src_str in h:
46
+ buffer.append(src_str.strip())
47
+ if len(buffer) >= buffer_size:
48
+ yield buffer
49
+ buffer = []
50
+
51
+ if len(buffer) > 0:
52
+ yield buffer
53
+
54
+
55
+ def make_batches(lines, cfg, task, max_positions, encode_fn):
56
+ def encode_fn_target(x):
57
+ return encode_fn(x)
58
+
59
+ if cfg.generation.constraints:
60
+ # Strip (tab-delimited) contraints, if present, from input lines,
61
+ # store them in batch_constraints
62
+ batch_constraints = [list() for _ in lines]
63
+ for i, line in enumerate(lines):
64
+ if "\t" in line:
65
+ lines[i], *batch_constraints[i] = line.split("\t")
66
+
67
+ # Convert each List[str] to List[Tensor]
68
+ for i, constraint_list in enumerate(batch_constraints):
69
+ batch_constraints[i] = [
70
+ task.target_dictionary.encode_line(
71
+ encode_fn_target(constraint),
72
+ append_eos=False,
73
+ add_if_not_exist=False,
74
+ )
75
+ for constraint in constraint_list
76
+ ]
77
+
78
+ if cfg.generation.constraints:
79
+ constraints_tensor = pack_constraints(batch_constraints)
80
+ else:
81
+ constraints_tensor = None
82
+
83
+ tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)
84
+
85
+ itr = task.get_batch_iterator(
86
+ dataset=task.build_dataset_for_inference(
87
+ tokens, lengths, constraints=constraints_tensor
88
+ ),
89
+ max_tokens=cfg.dataset.max_tokens,
90
+ max_sentences=cfg.dataset.batch_size,
91
+ max_positions=max_positions,
92
+ ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
93
+ ).next_epoch_itr(shuffle=False)
94
+ for batch in itr:
95
+ ids = batch["id"]
96
+ src_tokens = batch["net_input"]["src_tokens"]
97
+ src_lengths = batch["net_input"]["src_lengths"]
98
+ constraints = batch.get("constraints", None)
99
+
100
+ yield Batch(
101
+ ids=ids,
102
+ src_tokens=src_tokens,
103
+ src_lengths=src_lengths,
104
+ constraints=constraints,
105
+ )
106
+
107
+
108
+ def main(cfg: FairseqConfig):
109
+ if isinstance(cfg, Namespace):
110
+ cfg = convert_namespace_to_omegaconf(cfg)
111
+
112
+ start_time = time.time()
113
+ total_translate_time = 0
114
+
115
+ utils.import_user_module(cfg.common)
116
+
117
+ if cfg.interactive.buffer_size < 1:
118
+ cfg.interactive.buffer_size = 1
119
+ if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
120
+ cfg.dataset.batch_size = 1
121
+
122
+ assert (
123
+ not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
124
+ ), "--sampling requires --nbest to be equal to --beam"
125
+ assert (
126
+ not cfg.dataset.batch_size
127
+ or cfg.dataset.batch_size <= cfg.interactive.buffer_size
128
+ ), "--batch-size cannot be larger than --buffer-size"
129
+
130
+ logger.info(cfg)
131
+
132
+ # Fix seed for stochastic decoding
133
+ if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
134
+ np.random.seed(cfg.common.seed)
135
+ utils.set_torch_seed(cfg.common.seed)
136
+
137
+ use_cuda = torch.cuda.is_available() and not cfg.common.cpu
138
+
139
+ # Setup task, e.g., translation
140
+ task = tasks.setup_task(cfg.task)
141
+
142
+ # Load ensemble
143
+ overrides = ast.literal_eval(cfg.common_eval.model_overrides)
144
+ logger.info("loading model(s) from {}".format(cfg.common_eval.path))
145
+ models, _model_args = checkpoint_utils.load_model_ensemble(
146
+ utils.split_paths(cfg.common_eval.path),
147
+ arg_overrides=overrides,
148
+ task=task,
149
+ suffix=cfg.checkpoint.checkpoint_suffix,
150
+ strict=(cfg.checkpoint.checkpoint_shard_count == 1),
151
+ num_shards=cfg.checkpoint.checkpoint_shard_count,
152
+ )
153
+
154
+ # Set dictionaries
155
+ src_dict = task.source_dictionary
156
+ tgt_dict = task.target_dictionary
157
+
158
+ # Optimize ensemble for generation
159
+ for model in models:
160
+ if model is None:
161
+ continue
162
+ if cfg.common.fp16:
163
+ model.half()
164
+ if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
165
+ model.cuda()
166
+ model.prepare_for_inference_(cfg)
167
+
168
+ # Initialize generator
169
+ generator = task.build_generator(models, cfg.generation)
170
+
171
+ # Handle tokenization and BPE
172
+ tokenizer = task.build_tokenizer(cfg.tokenizer)
173
+ bpe = task.build_bpe(cfg.bpe)
174
+
175
+ def encode_fn(x):
176
+ if tokenizer is not None:
177
+ x = tokenizer.encode(x)
178
+ if bpe is not None:
179
+ x = bpe.encode(x)
180
+ return x
181
+
182
+ def decode_fn(x):
183
+ if bpe is not None:
184
+ x = bpe.decode(x)
185
+ if tokenizer is not None:
186
+ x = tokenizer.decode(x)
187
+ return x
188
+
189
+ # Load alignment dictionary for unknown word replacement
190
+ # (None if no unknown word replacement, empty if no path to align dictionary)
191
+ align_dict = utils.load_align_dict(cfg.generation.replace_unk)
192
+
193
+ max_positions = utils.resolve_max_positions(
194
+ task.max_positions(), *[model.max_positions() for model in models]
195
+ )
196
+
197
+ if cfg.generation.constraints:
198
+ logger.warning(
199
+ "NOTE: Constrained decoding currently assumes a shared subword vocabulary."
200
+ )
201
+
202
+ if cfg.interactive.buffer_size > 1:
203
+ logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size)
204
+ logger.info("NOTE: hypothesis and token scores are output in base 2")
205
+ logger.info("Type the input sentence and press return:")
206
+ start_id = 0
207
+ for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size):
208
+ results = []
209
+ for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
210
+ bsz = batch.src_tokens.size(0)
211
+ src_tokens = batch.src_tokens
212
+ src_lengths = batch.src_lengths
213
+ constraints = batch.constraints
214
+ if use_cuda:
215
+ src_tokens = src_tokens.cuda()
216
+ src_lengths = src_lengths.cuda()
217
+ if constraints is not None:
218
+ constraints = constraints.cuda()
219
+
220
+ sample = {
221
+ "net_input": {
222
+ "src_tokens": src_tokens,
223
+ "src_lengths": src_lengths,
224
+ },
225
+ }
226
+ translate_start_time = time.time()
227
+ translations = task.inference_step(
228
+ generator, models, sample, constraints=constraints
229
+ )
230
+ translate_time = time.time() - translate_start_time
231
+ total_translate_time += translate_time
232
+ list_constraints = [[] for _ in range(bsz)]
233
+ if cfg.generation.constraints:
234
+ list_constraints = [unpack_constraints(c) for c in constraints]
235
+ for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
236
+ src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
237
+ constraints = list_constraints[i]
238
+ results.append(
239
+ (
240
+ start_id + id,
241
+ src_tokens_i,
242
+ hypos,
243
+ {
244
+ "constraints": constraints,
245
+ "time": translate_time / len(translations),
246
+ },
247
+ )
248
+ )
249
+
250
+ # sort output to match input order
251
+ for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
252
+ src_str = ""
253
+ if src_dict is not None:
254
+ src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
255
+ print("S-{}\t{}".format(id_, src_str))
256
+ print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
257
+ for constraint in info["constraints"]:
258
+ print(
259
+ "C-{}\t{}".format(
260
+ id_,
261
+ tgt_dict.string(constraint, cfg.common_eval.post_process),
262
+ )
263
+ )
264
+
265
+ # Process top predictions
266
+ for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]:
267
+ hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
268
+ hypo_tokens=hypo["tokens"].int().cpu(),
269
+ src_str=src_str,
270
+ alignment=hypo["alignment"],
271
+ align_dict=align_dict,
272
+ tgt_dict=tgt_dict,
273
+ remove_bpe=cfg.common_eval.post_process,
274
+ extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
275
+ )
276
+ detok_hypo_str = decode_fn(hypo_str)
277
+ score = hypo["score"] / math.log(2) # convert to base 2
278
+ # original hypothesis (after tokenization and BPE)
279
+ print("H-{}\t{}\t{}".format(id_, score, hypo_str))
280
+ # detokenized hypothesis
281
+ print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str))
282
+ print(
283
+ "P-{}\t{}".format(
284
+ id_,
285
+ " ".join(
286
+ map(
287
+ lambda x: "{:.4f}".format(x),
288
+ # convert from base e to base 2
289
+ hypo["positional_scores"].div_(math.log(2)).tolist(),
290
+ )
291
+ ),
292
+ )
293
+ )
294
+ if cfg.generation.print_alignment:
295
+ alignment_str = " ".join(
296
+ ["{}-{}".format(src, tgt) for src, tgt in alignment]
297
+ )
298
+ print("A-{}\t{}".format(id_, alignment_str))
299
+
300
+ # update running id_ counter
301
+ start_id += len(inputs)
302
+
303
+ logger.info(
304
+ "Total time: {:.3f} seconds; translation time: {:.3f}".format(
305
+ time.time() - start_time, total_translate_time
306
+ )
307
+ )
308
+
309
+
310
+ def cli_main():
311
+ parser = options.get_interactive_generation_parser()
312
+ args = options.parse_args_and_arch(parser)
313
+ distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)
314
+
315
+
316
+ if __name__ == "__main__":
317
+ cli_main()
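
When --constraints is enabled, make_batches() above splits every input line on tabs: the first field is the source sentence and each following field is a target-side constraint. A hedged sketch for preparing such an input file (the sentences and file name are placeholders):

rows = [
    ("Maschinelles Lernen ist großartig .", ["machine learning"]),
    ("Das ist ein Haus .", []),  # no constraints for this line
]
with open("interactive_input.txt", "w", encoding="utf-8") as f:
    for src, constraints in rows:
        f.write("\t".join([src] + constraints) + "\n")

The resulting file can then be passed via --input together with --constraints; lines without constraints simply contain no tab.
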
fairseq/fairseq_cli/preprocess.py ADDED
@@ -0,0 +1,393 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Data pre-processing: build vocabularies and binarize training data.
8
+ """
9
+
10
+ import logging
11
+ import os
12
+ import shutil
13
+ import sys
14
+ import typing as tp
15
+ from argparse import Namespace
16
+ from itertools import zip_longest
17
+
18
+ from fairseq import options, tasks, utils
19
+ from fairseq.binarizer import (
20
+ AlignmentDatasetBinarizer,
21
+ FileBinarizer,
22
+ VocabularyDatasetBinarizer,
23
+ )
24
+ from fairseq.data import Dictionary
25
+
26
+ logging.basicConfig(
27
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
28
+ datefmt="%Y-%m-%d %H:%M:%S",
29
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
30
+ stream=sys.stdout,
31
+ )
32
+ logger = logging.getLogger("fairseq_cli.preprocess")
33
+
34
+ #####################################################################
35
+ # file name tools
36
+ #####################################################################
37
+
38
+
39
+ def _train_path(lang, trainpref):
40
+ return "{}{}".format(trainpref, ("." + lang) if lang else "")
41
+
42
+
43
+ def _file_name(prefix, lang):
44
+ fname = prefix
45
+ if lang is not None:
46
+ fname += ".{lang}".format(lang=lang)
47
+ return fname
48
+
49
+
50
+ def _dest_path(prefix, lang, destdir):
51
+ return os.path.join(destdir, _file_name(prefix, lang))
52
+
53
+
54
+ def _dict_path(lang, destdir):
55
+ return _dest_path("dict", lang, destdir) + ".txt"
56
+
57
+
58
+ def dataset_dest_prefix(args, output_prefix, lang):
59
+ base = os.path.join(args.destdir, output_prefix)
60
+ if lang is not None:
61
+ lang_part = f".{args.source_lang}-{args.target_lang}.{lang}"
62
+ elif args.only_source:
63
+ lang_part = ""
64
+ else:
65
+ lang_part = f".{args.source_lang}-{args.target_lang}"
66
+
67
+ return "{}{}".format(base, lang_part)
68
+
69
+
70
+ def dataset_dest_file(args, output_prefix, lang, extension):
71
+ return "{}.{}".format(dataset_dest_prefix(args, output_prefix, lang), extension)
72
+
73
+
74
+ #####################################################################
75
+ # dictionary tools
76
+ #####################################################################
77
+
78
+
79
+ def _build_dictionary(
80
+ filenames,
81
+ task,
82
+ args,
83
+ src=False,
84
+ tgt=False,
85
+ ):
86
+ assert src ^ tgt
87
+ return task.build_dictionary(
88
+ filenames,
89
+ workers=args.workers,
90
+ threshold=args.thresholdsrc if src else args.thresholdtgt,
91
+ nwords=args.nwordssrc if src else args.nwordstgt,
92
+ padding_factor=args.padding_factor,
93
+ )
94
+
95
+
96
+ #####################################################################
97
+ # bin file creation logic
98
+ #####################################################################
99
+
100
+
101
+ def _make_binary_dataset(
102
+ vocab: Dictionary,
103
+ input_prefix: str,
104
+ output_prefix: str,
105
+ lang: tp.Optional[str],
106
+ num_workers: int,
107
+ args: Namespace,
108
+ ):
109
+ logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))
110
+
111
+ binarizer = VocabularyDatasetBinarizer(
112
+ vocab,
113
+ append_eos=True,
114
+ )
115
+
116
+ input_file = "{}{}".format(input_prefix, ("." + lang) if lang is not None else "")
117
+ full_output_prefix = dataset_dest_prefix(args, output_prefix, lang)
118
+
119
+ final_summary = FileBinarizer.multiprocess_dataset(
120
+ input_file,
121
+ args.dataset_impl,
122
+ binarizer,
123
+ full_output_prefix,
124
+ vocab_size=len(vocab),
125
+ num_workers=num_workers,
126
+ )
127
+
128
+ logger.info(f"[{lang}] {input_file}: {final_summary} (by {vocab.unk_word})")
129
+
130
+
131
+ def _make_binary_alignment_dataset(
132
+ input_prefix: str, output_prefix: str, num_workers: int, args: Namespace
133
+ ):
134
+
135
+ binarizer = AlignmentDatasetBinarizer(utils.parse_alignment)
136
+
137
+ input_file = input_prefix
138
+ full_output_prefix = dataset_dest_prefix(args, output_prefix, lang=None)
139
+
140
+ final_summary = FileBinarizer.multiprocess_dataset(
141
+ input_file,
142
+ args.dataset_impl,
143
+ binarizer,
144
+ full_output_prefix,
145
+ vocab_size=None,
146
+ num_workers=num_workers,
147
+ )
148
+
149
+ logger.info(
150
+ "[alignments] {}: parsed {} alignments".format(
151
+ input_file, final_summary.num_seq
152
+ )
153
+ )
154
+
155
+
156
+ #####################################################################
157
+ # routing logic
158
+ #####################################################################
159
+
160
+
161
+ def _make_dataset(
162
+ vocab: Dictionary,
163
+ input_prefix: str,
164
+ output_prefix: str,
165
+ lang: tp.Optional[str],
166
+ args: Namespace,
167
+ num_workers: int,
168
+ ):
169
+ if args.dataset_impl == "raw":
170
+ # Copy original text file to destination folder
171
+ output_text_file = _dest_path(
172
+ output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
173
+ lang,
174
+ args.destdir,
175
+ )
176
+ shutil.copyfile(_file_name(input_prefix, lang), output_text_file)
177
+ else:
178
+ _make_binary_dataset(
179
+ vocab, input_prefix, output_prefix, lang, num_workers, args
180
+ )
181
+
182
+
183
+ def _make_all(lang, vocab, args):
184
+ if args.trainpref:
185
+ _make_dataset(
186
+ vocab, args.trainpref, "train", lang, args=args, num_workers=args.workers
187
+ )
188
+ if args.validpref:
189
+ for k, validpref in enumerate(args.validpref.split(",")):
190
+ outprefix = "valid{}".format(k) if k > 0 else "valid"
191
+ _make_dataset(
192
+ vocab, validpref, outprefix, lang, args=args, num_workers=args.workers
193
+ )
194
+ if args.testpref:
195
+ for k, testpref in enumerate(args.testpref.split(",")):
196
+ outprefix = "test{}".format(k) if k > 0 else "test"
197
+ _make_dataset(
198
+ vocab, testpref, outprefix, lang, args=args, num_workers=args.workers
199
+ )
200
+
201
+
202
+ def _make_all_alignments(args):
203
+ if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix):
204
+ _make_binary_alignment_dataset(
205
+ args.trainpref + "." + args.align_suffix,
206
+ "train.align",
207
+ num_workers=args.workers,
208
+ args=args,
209
+ )
210
+ if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix):
211
+ _make_binary_alignment_dataset(
212
+ args.validpref + "." + args.align_suffix,
213
+ "valid.align",
214
+ num_workers=args.workers,
215
+ args=args,
216
+ )
217
+ if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix):
218
+ _make_binary_alignment_dataset(
219
+ args.testpref + "." + args.align_suffix,
220
+ "test.align",
221
+ num_workers=args.workers,
222
+ args=args,
223
+ )
224
+
225
+
226
+ #####################################################################
227
+ # align
228
+ #####################################################################
229
+
230
+
231
+ def _align_files(args, src_dict, tgt_dict):
232
+ assert args.trainpref, "--trainpref must be set if --alignfile is specified"
233
+ src_file_name = _train_path(args.source_lang, args.trainpref)
234
+ tgt_file_name = _train_path(args.target_lang, args.trainpref)
235
+ freq_map = {}
236
+ with open(args.alignfile, "r", encoding="utf-8") as align_file:
237
+ with open(src_file_name, "r", encoding="utf-8") as src_file:
238
+ with open(tgt_file_name, "r", encoding="utf-8") as tgt_file:
239
+ for a, s, t in zip_longest(align_file, src_file, tgt_file):
240
+ si = src_dict.encode_line(s, add_if_not_exist=False)
241
+ ti = tgt_dict.encode_line(t, add_if_not_exist=False)
242
+ ai = list(map(lambda x: tuple(x.split("-")), a.split()))
243
+ for sai, tai in ai:
244
+ srcidx = si[int(sai)]
245
+ tgtidx = ti[int(tai)]
246
+ if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
247
+ assert srcidx != src_dict.pad()
248
+ assert srcidx != src_dict.eos()
249
+ assert tgtidx != tgt_dict.pad()
250
+ assert tgtidx != tgt_dict.eos()
251
+ if srcidx not in freq_map:
252
+ freq_map[srcidx] = {}
253
+ if tgtidx not in freq_map[srcidx]:
254
+ freq_map[srcidx][tgtidx] = 1
255
+ else:
256
+ freq_map[srcidx][tgtidx] += 1
257
+ align_dict = {}
258
+ for srcidx in freq_map.keys():
259
+ align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)
260
+ with open(
261
+ os.path.join(
262
+ args.destdir,
263
+ "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
264
+ ),
265
+ "w",
266
+ encoding="utf-8",
267
+ ) as f:
268
+ for k, v in align_dict.items():
269
+ print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
270
+
271
+
272
+ #####################################################################
273
+ # MAIN
274
+ #####################################################################
275
+
276
+
277
+ def main(args):
278
+ # setup some basic things
279
+ utils.import_user_module(args)
280
+
281
+ os.makedirs(args.destdir, exist_ok=True)
282
+
283
+ logger.addHandler(
284
+ logging.FileHandler(
285
+ filename=os.path.join(args.destdir, "preprocess.log"),
286
+ )
287
+ )
288
+ logger.info(args)
289
+
290
+ assert (
291
+ args.dataset_impl != "huffman"
292
+ ), "preprocessing.py doesn't support Huffman yet, use HuffmanCodeBuilder directly."
293
+
294
+ # build dictionaries
295
+
296
+ target = not args.only_source
297
+
298
+ if not args.srcdict and os.path.exists(_dict_path(args.source_lang, args.destdir)):
299
+ raise FileExistsError(_dict_path(args.source_lang, args.destdir))
300
+
301
+ if (
302
+ target
303
+ and not args.tgtdict
304
+ and os.path.exists(_dict_path(args.target_lang, args.destdir))
305
+ ):
306
+ raise FileExistsError(_dict_path(args.target_lang, args.destdir))
307
+
308
+ task = tasks.get_task(args.task)
309
+
310
+ if args.joined_dictionary:
311
+ assert (
312
+ not args.srcdict or not args.tgtdict
313
+ ), "cannot use both --srcdict and --tgtdict with --joined-dictionary"
314
+
315
+ if args.srcdict:
316
+ src_dict = task.load_dictionary(args.srcdict)
317
+ elif args.tgtdict:
318
+ src_dict = task.load_dictionary(args.tgtdict)
319
+ else:
320
+ assert (
321
+ args.trainpref
322
+ ), "--trainpref must be set if --srcdict is not specified"
323
+ src_dict = _build_dictionary(
324
+ {
325
+ _train_path(lang, args.trainpref)
326
+ for lang in [args.source_lang, args.target_lang]
327
+ },
328
+ task=task,
329
+ args=args,
330
+ src=True,
331
+ )
332
+ tgt_dict = src_dict
333
+ else:
334
+ if args.srcdict:
335
+ src_dict = task.load_dictionary(args.srcdict)
336
+ else:
337
+ assert (
338
+ args.trainpref
339
+ ), "--trainpref must be set if --srcdict is not specified"
340
+ src_dict = _build_dictionary(
341
+ [_train_path(args.source_lang, args.trainpref)],
342
+ task=task,
343
+ args=args,
344
+ src=True,
345
+ )
346
+
347
+ if target:
348
+ if args.tgtdict:
349
+ tgt_dict = task.load_dictionary(args.tgtdict)
350
+ else:
351
+ assert (
352
+ args.trainpref
353
+ ), "--trainpref must be set if --tgtdict is not specified"
354
+ tgt_dict = _build_dictionary(
355
+ [_train_path(args.target_lang, args.trainpref)],
356
+ task=task,
357
+ args=args,
358
+ tgt=True,
359
+ )
360
+ else:
361
+ tgt_dict = None
362
+
363
+ # save dictionaries
364
+
365
+ src_dict.save(_dict_path(args.source_lang, args.destdir))
366
+ if target and tgt_dict is not None:
367
+ tgt_dict.save(_dict_path(args.target_lang, args.destdir))
368
+
369
+ if args.dict_only:
370
+ return
371
+
372
+ _make_all(args.source_lang, src_dict, args)
373
+ if target:
374
+ _make_all(args.target_lang, tgt_dict, args)
375
+
376
+ # align the datasets if needed
377
+ if args.align_suffix:
378
+ _make_all_alignments(args)
379
+
380
+ logger.info("Wrote preprocessed data to {}".format(args.destdir))
381
+
382
+ if args.alignfile:
383
+ _align_files(args, src_dict=src_dict, tgt_dict=tgt_dict)
384
+
385
+
386
+ def cli_main():
387
+ parser = options.get_preprocessing_parser()
388
+ args = parser.parse_args()
389
+ main(args)
390
+
391
+
392
+ if __name__ == "__main__":
393
+ cli_main()
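
Usage sketch for the preprocessing script above, via the fairseq-preprocess console entry point; the language pair, file prefixes, and destination directory below are placeholders rather than values from this commit:

    fairseq-preprocess --source-lang de --target-lang en \
        --trainpref data/train --validpref data/valid --testpref data/test \
        --destdir data-bin/de-en --workers 8 --joined-dictionary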
fairseq/fairseq_cli/score.py ADDED
@@ -0,0 +1,102 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ BLEU scoring of generated translations against reference translations.
8
+ """
9
+
10
+ import argparse
11
+ import os
12
+ import sys
13
+
14
+ from fairseq.data import dictionary
15
+ from fairseq.scoring import bleu
16
+
17
+
18
+ def get_parser():
19
+ parser = argparse.ArgumentParser(
20
+ description="Command-line script for BLEU scoring."
21
+ )
22
+ # fmt: off
23
+ parser.add_argument('-s', '--sys', default='-', help='system output')
24
+ parser.add_argument('-r', '--ref', required=True, help='references')
25
+ parser.add_argument('-o', '--order', default=4, metavar='N',
26
+ type=int, help='consider ngrams up to this order')
27
+ parser.add_argument('--ignore-case', action='store_true',
28
+ help='case-insensitive scoring')
29
+ parser.add_argument('--sacrebleu', action='store_true',
30
+ help='score with sacrebleu')
31
+ parser.add_argument('--sentence-bleu', action='store_true',
32
+ help='report sentence-level BLEUs (i.e., with +1 smoothing)')
33
+ # fmt: on
34
+ return parser
35
+
36
+
37
+ def cli_main():
38
+ parser = get_parser()
39
+ args = parser.parse_args()
40
+ print(args)
41
+
42
+ assert args.sys == "-" or os.path.exists(
43
+ args.sys
44
+ ), "System output file {} does not exist".format(args.sys)
45
+ assert os.path.exists(args.ref), "Reference file {} does not exist".format(args.ref)
46
+
47
+ dict = dictionary.Dictionary()
48
+
49
+ def readlines(fd):
50
+ for line in fd.readlines():
51
+ if args.ignore_case:
52
+ yield line.lower()
53
+ else:
54
+ yield line
55
+
56
+ if args.sacrebleu:
57
+ import sacrebleu
58
+
59
+ def score(fdsys):
60
+ with open(args.ref) as fdref:
61
+ print(sacrebleu.corpus_bleu(fdsys, [fdref]).format())
62
+
63
+ elif args.sentence_bleu:
64
+
65
+ def score(fdsys):
66
+ with open(args.ref) as fdref:
67
+ scorer = bleu.Scorer(bleu.BleuConfig(pad=dict.pad(), eos=dict.eos(), unk=dict.unk()))
68
+ for i, (sys_tok, ref_tok) in enumerate(
69
+ zip(readlines(fdsys), readlines(fdref))
70
+ ):
71
+ scorer.reset(one_init=True)
72
+ sys_tok = dict.encode_line(sys_tok)
73
+ ref_tok = dict.encode_line(ref_tok)
74
+ scorer.add(ref_tok, sys_tok)
75
+ print(i, scorer.result_string(args.order))
76
+
77
+ else:
78
+
79
+ def score(fdsys):
80
+ with open(args.ref) as fdref:
81
+ scorer = bleu.Scorer(
82
+ bleu.BleuConfig(
83
+ pad=dict.pad(),
84
+ eos=dict.eos(),
85
+ unk=dict.unk(),
86
+ )
87
+ )
88
+ for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
89
+ sys_tok = dict.encode_line(sys_tok)
90
+ ref_tok = dict.encode_line(ref_tok)
91
+ scorer.add(ref_tok, sys_tok)
92
+ print(scorer.result_string(args.order))
93
+
94
+ if args.sys == "-":
95
+ score(sys.stdin)
96
+ else:
97
+ with open(args.sys, "r") as f:
98
+ score(f)
99
+
100
+
101
+ if __name__ == "__main__":
102
+ cli_main()
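
Usage sketch for the scorer above, assuming generation output has already been split into hypothesis and reference files (the file names are placeholders):

    fairseq-score --sys gen.out.sys --ref gen.out.ref
    fairseq-score --sys gen.out.sys --ref gen.out.ref --sacrebleu --ignore-case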
fairseq/fairseq_cli/train.py ADDED
@@ -0,0 +1,581 @@
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Train a new model on one or across multiple GPUs.
8
+ """
9
+
10
+ import argparse
11
+ import logging
12
+ import math
13
+ import os
14
+ import sys
15
+ from typing import Any, Callable, Dict, List, Optional, Tuple
16
+
17
+ # We need to setup root logger before importing any fairseq libraries.
18
+ logging.basicConfig(
19
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
20
+ datefmt="%Y-%m-%d %H:%M:%S",
21
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
22
+ stream=sys.stdout,
23
+ )
24
+ logger = logging.getLogger("fairseq_cli.train")
25
+
26
+ import numpy as np
27
+ import torch
28
+ from omegaconf import DictConfig, OmegaConf
29
+
30
+ from fairseq import checkpoint_utils, options, quantization_utils, tasks, utils
31
+ from fairseq.data import data_utils, iterators
32
+ from fairseq.data.plasma_utils import PlasmaStore
33
+ from fairseq.dataclass.configs import FairseqConfig
34
+ from fairseq.dataclass.initialize import add_defaults
35
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
36
+ from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap
37
+ from fairseq.distributed import utils as distributed_utils
38
+ from fairseq.file_io import PathManager
39
+ from fairseq.logging import meters, metrics, progress_bar
40
+ from fairseq.model_parallel.megatron_trainer import MegatronTrainer
41
+ from fairseq.trainer import Trainer
42
+
43
+
44
+ def main(cfg: FairseqConfig) -> None:
45
+ if isinstance(cfg, argparse.Namespace):
46
+ cfg = convert_namespace_to_omegaconf(cfg)
47
+
48
+ utils.import_user_module(cfg.common)
49
+ add_defaults(cfg)
50
+
51
+ if (
52
+ distributed_utils.is_master(cfg.distributed_training)
53
+ and "job_logging_cfg" in cfg
54
+ ):
55
+ # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
56
+ logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))
57
+
58
+ assert (
59
+ cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
60
+ ), "Must specify batch size either with --max-tokens or --batch-size"
61
+ metrics.reset()
62
+
63
+ if cfg.common.log_file is not None:
64
+ handler = logging.FileHandler(filename=cfg.common.log_file)
65
+ logger.addHandler(handler)
66
+
67
+ np.random.seed(cfg.common.seed)
68
+ utils.set_torch_seed(cfg.common.seed)
69
+
70
+ if distributed_utils.is_master(cfg.distributed_training):
71
+ checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)
72
+
73
+ # Print args
74
+ logger.info(cfg)
75
+
76
+ if cfg.checkpoint.write_checkpoints_asynchronously:
77
+ try:
78
+ import iopath # noqa: F401
79
+ except ImportError:
80
+ logging.exception(
81
+ "Asynchronous checkpoint writing is specified but iopath is "
82
+ "not installed: `pip install iopath`"
83
+ )
84
+ return
85
+
86
+ # Setup task, e.g., translation, language modeling, etc.
87
+ task = tasks.setup_task(cfg.task)
88
+
89
+ assert cfg.criterion, "Please specify criterion to train a model"
90
+
91
+ # Build model and criterion
92
+ if cfg.distributed_training.ddp_backend == "fully_sharded":
93
+ with fsdp_enable_wrap(cfg.distributed_training):
94
+ model = fsdp_wrap(task.build_model(cfg.model))
95
+ else:
96
+ model = task.build_model(cfg.model)
97
+ criterion = task.build_criterion(cfg.criterion)
98
+ logger.info(model)
99
+ logger.info("task: {}".format(task.__class__.__name__))
100
+ logger.info("model: {}".format(model.__class__.__name__))
101
+ logger.info("criterion: {}".format(criterion.__class__.__name__))
102
+ logger.info(
103
+ "num. shared model params: {:,} (num. trained: {:,})".format(
104
+ sum(
105
+ p.numel() for p in model.parameters() if not getattr(p, "expert", False)
106
+ ),
107
+ sum(
108
+ p.numel()
109
+ for p in model.parameters()
110
+ if not getattr(p, "expert", False) and p.requires_grad
111
+ ),
112
+ )
113
+ )
114
+
115
+ logger.info(
116
+ "num. expert model params: {} (num. trained: {})".format(
117
+ sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
118
+ sum(
119
+ p.numel()
120
+ for p in model.parameters()
121
+ if getattr(p, "expert", False) and p.requires_grad
122
+ ),
123
+ )
124
+ )
125
+
126
+ # Load valid dataset (we load training data below, based on the latest checkpoint)
127
+ # We load the valid dataset AFTER building the model
128
+ if not cfg.dataset.disable_validation:
129
+ data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
130
+ if cfg.dataset.combine_valid_subsets:
131
+ task.load_dataset("valid", combine=True, epoch=1)
132
+ else:
133
+ for valid_sub_split in cfg.dataset.valid_subset.split(","):
134
+ task.load_dataset(valid_sub_split, combine=False, epoch=1)
135
+
136
+ # (optionally) Configure quantization
137
+ if cfg.common.quantization_config_path is not None:
138
+ quantizer = quantization_utils.Quantizer(
139
+ config_path=cfg.common.quantization_config_path,
140
+ max_epoch=cfg.optimization.max_epoch,
141
+ max_update=cfg.optimization.max_update,
142
+ )
143
+ else:
144
+ quantizer = None
145
+
146
+ # Build trainer
147
+ if cfg.common.model_parallel_size == 1:
148
+ trainer = Trainer(cfg, task, model, criterion, quantizer)
149
+ else:
150
+ trainer = MegatronTrainer(cfg, task, model, criterion)
151
+ logger.info(
152
+ "training on {} devices (GPUs/TPUs)".format(
153
+ cfg.distributed_training.distributed_world_size
154
+ )
155
+ )
156
+ logger.info(
157
+ "max tokens per device = {} and max sentences per device = {}".format(
158
+ cfg.dataset.max_tokens,
159
+ cfg.dataset.batch_size,
160
+ )
161
+ )
162
+
163
+ # Load the latest checkpoint if one is available and restore the
164
+ # corresponding train iterator
165
+ extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
166
+ cfg.checkpoint,
167
+ trainer,
168
+ # don't cache epoch iterators for sharded datasets
169
+ disable_iterator_cache=task.has_sharded_data("train"),
170
+ )
171
+ if cfg.common.tpu:
172
+ import torch_xla.core.xla_model as xm
173
+
174
+ xm.rendezvous("load_checkpoint") # wait for all workers
175
+
176
+ max_epoch = cfg.optimization.max_epoch or math.inf
177
+ lr = trainer.get_lr()
178
+
179
+ # TODO: a dry run on validation set to pin the memory
180
+ valid_subsets = cfg.dataset.valid_subset.split(",")
181
+ if not cfg.dataset.disable_validation:
182
+ for subset in valid_subsets:
183
+ logger.info('begin dry-run validation on "{}" subset'.format(subset))
184
+ itr = trainer.get_valid_iterator(subset).next_epoch_itr(
185
+ shuffle=False, set_dataset_epoch=False # use a fixed valid set
186
+ )
187
+ if cfg.common.tpu:
188
+ itr = utils.tpu_data_loader(itr)
189
+ for _ in itr:
190
+ pass
191
+ # TODO: end of dry run section
192
+
193
+ train_meter = meters.StopwatchMeter()
194
+ train_meter.start()
195
+ while epoch_itr.next_epoch_idx <= max_epoch:
196
+ if lr <= cfg.optimization.stop_min_lr:
197
+ logger.info(
198
+ f"stopping training because current learning rate ({lr}) is smaller "
199
+ "than or equal to minimum learning rate "
200
+ f"(--stop-min-lr={cfg.optimization.stop_min_lr})"
201
+ )
202
+ break
203
+
204
+ # train for one epoch
205
+ valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
206
+ if should_stop:
207
+ break
208
+
209
+ # only use first validation loss to update the learning rate
210
+ lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
211
+
212
+ epoch_itr = trainer.get_train_iterator(
213
+ epoch_itr.next_epoch_idx,
214
+ # sharded data: get train iterator for next epoch
215
+ load_dataset=task.has_sharded_data("train"),
216
+ # don't cache epoch iterators for sharded datasets
217
+ disable_iterator_cache=task.has_sharded_data("train"),
218
+ )
219
+ train_meter.stop()
220
+ logger.info("done training in {:.1f} seconds".format(train_meter.sum))
221
+
222
+ # ioPath implementation to wait for all asynchronous file writes to complete.
223
+ if cfg.checkpoint.write_checkpoints_asynchronously:
224
+ logger.info(
225
+ "ioPath PathManager waiting for all asynchronous checkpoint "
226
+ "writes to finish."
227
+ )
228
+ PathManager.async_close()
229
+ logger.info("ioPath PathManager finished waiting.")
230
+
231
+
232
+ def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool:
233
+ # skip check if no validation was done in the current epoch
234
+ if valid_loss is None:
235
+ return False
236
+ if cfg.checkpoint.patience <= 0:
237
+ return False
238
+
239
+ def is_better(a, b):
240
+ return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b
241
+
242
+ prev_best = getattr(should_stop_early, "best", None)
243
+ if prev_best is None or is_better(valid_loss, prev_best):
244
+ should_stop_early.best = valid_loss
245
+ should_stop_early.num_runs = 0
246
+ return False
247
+ else:
248
+ should_stop_early.num_runs += 1
249
+ if should_stop_early.num_runs >= cfg.checkpoint.patience:
250
+ logger.info(
251
+ "early stop since valid performance hasn't improved for last {} runs".format(
252
+ cfg.checkpoint.patience
253
+ )
254
+ )
255
+ return True
256
+ else:
257
+ return False
258
+
259
+
260
+ @metrics.aggregate("train")
261
+ def train(
262
+ cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
263
+ ) -> Tuple[List[Optional[float]], bool]:
264
+ """Train the model for one epoch and return validation losses."""
265
+ # Initialize data iterator
266
+ itr = epoch_itr.next_epoch_itr(
267
+ fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
268
+ shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
269
+ )
270
+ update_freq = (
271
+ cfg.optimization.update_freq[epoch_itr.epoch - 1]
272
+ if epoch_itr.epoch <= len(cfg.optimization.update_freq)
273
+ else cfg.optimization.update_freq[-1]
274
+ )
275
+ itr = iterators.GroupedIterator(
276
+ itr,
277
+ update_freq,
278
+ skip_remainder_batch=cfg.optimization.skip_remainder_batch,
279
+ )
280
+ if cfg.common.tpu:
281
+ itr = utils.tpu_data_loader(itr)
282
+ progress = progress_bar.progress_bar(
283
+ itr,
284
+ log_format=cfg.common.log_format,
285
+ log_file=cfg.common.log_file,
286
+ log_interval=cfg.common.log_interval,
287
+ epoch=epoch_itr.epoch,
288
+ aim_repo=(
289
+ cfg.common.aim_repo
290
+ if distributed_utils.is_master(cfg.distributed_training)
291
+ else None
292
+ ),
293
+ aim_run_hash=(
294
+ cfg.common.aim_run_hash
295
+ if distributed_utils.is_master(cfg.distributed_training)
296
+ else None
297
+ ),
298
+ aim_param_checkpoint_dir=cfg.checkpoint.save_dir,
299
+ tensorboard_logdir=(
300
+ cfg.common.tensorboard_logdir
301
+ if distributed_utils.is_master(cfg.distributed_training)
302
+ else None
303
+ ),
304
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
305
+ wandb_project=(
306
+ cfg.common.wandb_project
307
+ if distributed_utils.is_master(cfg.distributed_training)
308
+ else None
309
+ ),
310
+ wandb_run_name=os.environ.get(
311
+ "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
312
+ ),
313
+ azureml_logging=(
314
+ cfg.common.azureml_logging
315
+ if distributed_utils.is_master(cfg.distributed_training)
316
+ else False
317
+ ),
318
+ )
319
+ progress.update_config(_flatten_config(cfg))
320
+
321
+ trainer.begin_epoch(epoch_itr.epoch)
322
+
323
+ valid_subsets = cfg.dataset.valid_subset.split(",")
324
+ should_stop = False
325
+ num_updates = trainer.get_num_updates()
326
+ logger.info("Start iterating over samples")
327
+ for i, samples in enumerate(progress):
328
+ with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
329
+ "train_step-%d" % i
330
+ ):
331
+ log_output = trainer.train_step(samples)
332
+
333
+ if log_output is not None: # not OOM, overflow, ...
334
+ # log mid-epoch stats
335
+ num_updates = trainer.get_num_updates()
336
+ if num_updates % cfg.common.log_interval == 0:
337
+ stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
338
+ progress.log(stats, tag="train_inner", step=num_updates)
339
+
340
+ # reset mid-epoch stats after each log interval
341
+ # the end-of-epoch stats will still be preserved
342
+ metrics.reset_meters("train_inner")
343
+
344
+ end_of_epoch = not itr.has_next()
345
+ valid_losses, should_stop = validate_and_save(
346
+ cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
347
+ )
348
+
349
+ if should_stop:
350
+ break
351
+
352
+ # log end-of-epoch stats
353
+ logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
354
+ stats = get_training_stats(metrics.get_smoothed_values("train"))
355
+ progress.print(stats, tag="train", step=num_updates)
356
+
357
+ # reset epoch-level meters
358
+ metrics.reset_meters("train")
359
+ return valid_losses, should_stop
360
+
361
+
362
+ def _flatten_config(cfg: DictConfig):
363
+ config = OmegaConf.to_container(cfg)
364
+ # remove any legacy Namespaces and replace with a single "args"
365
+ namespace = None
366
+ for k, v in list(config.items()):
367
+ if isinstance(v, argparse.Namespace):
368
+ namespace = v
369
+ del config[k]
370
+ if namespace is not None:
371
+ config["args"] = vars(namespace)
372
+ return config
373
+
374
+
375
+ def validate_and_save(
376
+ cfg: DictConfig,
377
+ trainer: Trainer,
378
+ task: tasks.FairseqTask,
379
+ epoch_itr,
380
+ valid_subsets: List[str],
381
+ end_of_epoch: bool,
382
+ ) -> Tuple[List[Optional[float]], bool]:
383
+ num_updates = trainer.get_num_updates()
384
+ max_update = cfg.optimization.max_update or math.inf
385
+
386
+ # Stopping conditions (and an additional one based on validation loss later
387
+ # on)
388
+ should_stop = False
389
+ if num_updates >= max_update:
390
+ should_stop = True
391
+ logger.info(
392
+ f"Stopping training due to "
393
+ f"num_updates: {num_updates} >= max_update: {max_update}"
394
+ )
395
+
396
+ training_time_hours = trainer.cumulative_training_time() / (60 * 60)
397
+ if (
398
+ cfg.optimization.stop_time_hours > 0
399
+ and training_time_hours > cfg.optimization.stop_time_hours
400
+ ):
401
+ should_stop = True
402
+ logger.info(
403
+ f"Stopping training due to "
404
+ f"cumulative_training_time: {training_time_hours} > "
405
+ f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)"
406
+ )
407
+
408
+ do_save = (
409
+ (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
410
+ or should_stop
411
+ or (
412
+ cfg.checkpoint.save_interval_updates > 0
413
+ and num_updates > 0
414
+ and num_updates % cfg.checkpoint.save_interval_updates == 0
415
+ and num_updates >= cfg.dataset.validate_after_updates
416
+ )
417
+ )
418
+ do_validate = (
419
+ (
420
+ (not end_of_epoch and do_save) # validate during mid-epoch saves
421
+ or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
422
+ or should_stop
423
+ or (
424
+ cfg.dataset.validate_interval_updates > 0
425
+ and num_updates > 0
426
+ and num_updates % cfg.dataset.validate_interval_updates == 0
427
+ )
428
+ )
429
+ and not cfg.dataset.disable_validation
430
+ and num_updates >= cfg.dataset.validate_after_updates
431
+ )
432
+
433
+ # Validate
434
+ valid_losses = [None]
435
+ if do_validate:
436
+ valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)
437
+
438
+ should_stop |= should_stop_early(cfg, valid_losses[0])
439
+
440
+ # Save checkpoint
441
+ if do_save or should_stop:
442
+ cp_path = checkpoint_utils.save_checkpoint(
443
+ cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
444
+ )
445
+ if cp_path is not None and hasattr(task, "post_save"):
446
+ task.post_save(cp_path, num_updates)
447
+
448
+ return valid_losses, should_stop
449
+
450
+
451
+ def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
452
+ stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
453
+ return stats
454
+
455
+
456
+ def validate(
457
+ cfg: DictConfig,
458
+ trainer: Trainer,
459
+ task: tasks.FairseqTask,
460
+ epoch_itr,
461
+ subsets: List[str],
462
+ ) -> List[Optional[float]]:
463
+ """Evaluate the model on the validation set(s) and return the losses."""
464
+
465
+ if cfg.dataset.fixed_validation_seed is not None:
466
+ # set fixed seed for every validation
467
+ utils.set_torch_seed(cfg.dataset.fixed_validation_seed)
468
+
469
+ trainer.begin_valid_epoch(epoch_itr.epoch)
470
+ valid_losses = []
471
+ for subset_idx, subset in enumerate(subsets):
472
+ logger.info('begin validation on "{}" subset'.format(subset))
473
+
474
+ # Initialize data iterator
475
+ itr = trainer.get_valid_iterator(subset).next_epoch_itr(
476
+ shuffle=False, set_dataset_epoch=False # use a fixed valid set
477
+ )
478
+ if cfg.common.tpu:
479
+ itr = utils.tpu_data_loader(itr)
480
+ progress = progress_bar.progress_bar(
481
+ itr,
482
+ log_format=cfg.common.log_format,
483
+ log_interval=cfg.common.log_interval,
484
+ epoch=epoch_itr.epoch,
485
+ prefix=f"valid on '{subset}' subset",
486
+ aim_repo=(
487
+ cfg.common.aim_repo
488
+ if distributed_utils.is_master(cfg.distributed_training)
489
+ else None
490
+ ),
491
+ aim_run_hash=(
492
+ cfg.common.aim_run_hash
493
+ if distributed_utils.is_master(cfg.distributed_training)
494
+ else None
495
+ ),
496
+ aim_param_checkpoint_dir=cfg.checkpoint.save_dir,
497
+ tensorboard_logdir=(
498
+ cfg.common.tensorboard_logdir
499
+ if distributed_utils.is_master(cfg.distributed_training)
500
+ else None
501
+ ),
502
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
503
+ wandb_project=(
504
+ cfg.common.wandb_project
505
+ if distributed_utils.is_master(cfg.distributed_training)
506
+ else None
507
+ ),
508
+ wandb_run_name=os.environ.get(
509
+ "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
510
+ ),
511
+ )
512
+
513
+ # create a new root metrics aggregator so validation metrics
514
+ # don't pollute other aggregators (e.g., train meters)
515
+ with metrics.aggregate(new_root=True) as agg:
516
+ for i, sample in enumerate(progress):
517
+ if (
518
+ cfg.dataset.max_valid_steps is not None
519
+ and i > cfg.dataset.max_valid_steps
520
+ ):
521
+ break
522
+ trainer.valid_step(sample)
523
+
524
+ # log validation stats
525
+ # only tracking the best metric on the 1st validation subset
526
+ tracking_best = subset_idx == 0
527
+ stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values(), tracking_best)
528
+
529
+ if hasattr(task, "post_validate"):
530
+ task.post_validate(trainer.get_model(), stats, agg)
531
+
532
+ progress.print(stats, tag=subset, step=trainer.get_num_updates())
533
+
534
+ valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
535
+ return valid_losses
536
+
537
+
538
+ def get_valid_stats(
539
+ cfg: DictConfig,
540
+ trainer: Trainer,
541
+ stats: Dict[str, Any],
542
+ tracking_best: bool,
543
+ ) -> Dict[str, Any]:
544
+ stats["num_updates"] = trainer.get_num_updates()
545
+ if tracking_best and hasattr(checkpoint_utils.save_checkpoint, "best"):
546
+ key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
547
+ best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
548
+ stats[key] = best_function(
549
+ checkpoint_utils.save_checkpoint.best,
550
+ stats[cfg.checkpoint.best_checkpoint_metric],
551
+ )
552
+ return stats
553
+
554
+
555
+ def cli_main(
556
+ modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
557
+ ) -> None:
558
+ parser = options.get_training_parser()
559
+ args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
560
+
561
+ cfg = convert_namespace_to_omegaconf(args)
562
+
563
+ if cfg.common.use_plasma_view:
564
+ server = PlasmaStore(path=cfg.common.plasma_path)
565
+ logger.info(
566
+ f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}"
567
+ )
568
+
569
+ if args.profile:
570
+ with torch.cuda.profiler.profile():
571
+ with torch.autograd.profiler.emit_nvtx():
572
+ distributed_utils.call_main(cfg, main)
573
+ else:
574
+ distributed_utils.call_main(cfg, main)
575
+
576
+ # if cfg.common.use_plasma_view:
577
+ # server.server.kill()
578
+
579
+
580
+ if __name__ == "__main__":
581
+ cli_main()
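
Minimal training invocation sketch for the script above, via the fairseq-train entry point; the data directory, architecture, and hyper-parameters are placeholders, and either --max-tokens or --batch-size must be supplied, as asserted in main():

    fairseq-train data-bin/de-en \
        --task translation --arch transformer \
        --optimizer adam --lr 0.0005 --max-tokens 4096 \
        --save-dir checkpoints/de-en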
fairseq/fairseq_cli/validate.py ADDED
@@ -0,0 +1,153 @@
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import os
9
+ import sys
10
+ from argparse import Namespace
11
+ from itertools import chain
12
+
13
+ import torch
14
+ from omegaconf import DictConfig
15
+
16
+ from fairseq import checkpoint_utils, distributed_utils, options, utils
17
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
18
+ from fairseq.logging import metrics, progress_bar
19
+ from fairseq.utils import reset_logging
20
+
21
+ logging.basicConfig(
22
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
23
+ datefmt="%Y-%m-%d %H:%M:%S",
24
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
25
+ stream=sys.stdout,
26
+ )
27
+ logger = logging.getLogger("fairseq_cli.validate")
28
+
29
+
30
+ def main(cfg: DictConfig, override_args=None):
31
+ if isinstance(cfg, Namespace):
32
+ cfg = convert_namespace_to_omegaconf(cfg)
33
+
34
+ utils.import_user_module(cfg.common)
35
+
36
+ reset_logging()
37
+
38
+ assert (
39
+ cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
40
+ ), "Must specify batch size either with --max-tokens or --batch-size"
41
+
42
+ use_fp16 = cfg.common.fp16
43
+ use_cuda = torch.cuda.is_available() and not cfg.common.cpu
44
+
45
+ if use_cuda:
46
+ torch.cuda.set_device(cfg.distributed_training.device_id)
47
+
48
+ if cfg.distributed_training.distributed_world_size > 1:
49
+ data_parallel_world_size = distributed_utils.get_data_parallel_world_size()
50
+ data_parallel_rank = distributed_utils.get_data_parallel_rank()
51
+ else:
52
+ data_parallel_world_size = 1
53
+ data_parallel_rank = 0
54
+
55
+ if override_args is not None:
56
+ overrides = vars(override_args)
57
+ overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
58
+ else:
59
+ overrides = None
60
+
61
+ # Load ensemble
62
+ logger.info("loading model(s) from {}".format(cfg.common_eval.path))
63
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
64
+ [cfg.common_eval.path],
65
+ arg_overrides=overrides,
66
+ suffix=cfg.checkpoint.checkpoint_suffix,
67
+ )
68
+ model = models[0]
69
+
70
+ # Move models to GPU
71
+ for model in models:
72
+ model.eval()
73
+ if use_fp16:
74
+ model.half()
75
+ if use_cuda:
76
+ model.cuda()
77
+
78
+ # Print args
79
+ logger.info(saved_cfg)
80
+
81
+ # Build criterion
82
+ criterion = task.build_criterion(saved_cfg.criterion)
83
+ criterion.eval()
84
+
85
+ for subset in cfg.dataset.valid_subset.split(","):
86
+ try:
87
+ task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task)
88
+ dataset = task.dataset(subset)
89
+ except KeyError:
90
+ raise Exception("Cannot find dataset: " + subset)
91
+
92
+ # Initialize data iterator
93
+ itr = task.get_batch_iterator(
94
+ dataset=dataset,
95
+ max_tokens=cfg.dataset.max_tokens,
96
+ max_sentences=cfg.dataset.batch_size,
97
+ max_positions=utils.resolve_max_positions(
98
+ task.max_positions(),
99
+ *[m.max_positions() for m in models],
100
+ ),
101
+ ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
102
+ required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
103
+ seed=cfg.common.seed,
104
+ num_shards=data_parallel_world_size,
105
+ shard_id=data_parallel_rank,
106
+ num_workers=cfg.dataset.num_workers,
107
+ data_buffer_size=cfg.dataset.data_buffer_size,
108
+ ).next_epoch_itr(shuffle=False)
109
+ progress = progress_bar.progress_bar(
110
+ itr,
111
+ log_format=cfg.common.log_format,
112
+ log_interval=cfg.common.log_interval,
113
+ prefix=f"valid on '{subset}' subset",
114
+ default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
115
+ )
116
+
117
+ log_outputs = []
118
+ for i, sample in enumerate(progress):
119
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
120
+ _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
121
+ progress.log(log_output, step=i)
122
+ log_outputs.append(log_output)
123
+
124
+ if data_parallel_world_size > 1:
125
+ log_outputs = distributed_utils.all_gather_list(
126
+ log_outputs,
127
+ max_size=cfg.common.all_gather_list_size,
128
+ group=distributed_utils.get_data_parallel_group(),
129
+ )
130
+ log_outputs = list(chain.from_iterable(log_outputs))
131
+
132
+ with metrics.aggregate() as agg:
133
+ task.reduce_metrics(log_outputs, criterion)
134
+ log_output = agg.get_smoothed_values()
135
+
136
+ progress.print(log_output, tag=subset, step=i)
137
+
138
+
139
+ def cli_main():
140
+ parser = options.get_validation_parser()
141
+ args = options.parse_args_and_arch(parser)
142
+
143
+ # only override args that are explicitly given on the command line
144
+ override_parser = options.get_validation_parser()
145
+ override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True)
146
+
147
+ distributed_utils.call_main(
148
+ convert_namespace_to_omegaconf(args), main, override_args=override_args
149
+ )
150
+
151
+
152
+ if __name__ == "__main__":
153
+ cli_main()
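
Usage sketch for the validation script above; the checkpoint path and data directory are placeholders, and --path is the checkpoint read through cfg.common_eval.path:

    fairseq-validate data-bin/de-en \
        --path checkpoints/de-en/checkpoint_best.pt \
        --task translation --valid-subset valid --max-tokens 4096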
fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+
3
+ __version__ = "0.1"
fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/config.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ from dataclasses import dataclass, field
3
+
4
+ from hydra.core.config_store import ConfigStore
5
+
6
+ from hydra_plugins.hydra_submitit_launcher.config import SlurmQueueConf
7
+
8
+
9
+ @dataclass
10
+ class DependencySubmititConf(SlurmQueueConf):
11
+ """Slurm configuration overrides and specific parameters"""
12
+
13
+ _target_: str = (
14
+ "hydra_plugins.dependency_submitit_launcher.launcher.DependencySubmititLauncher"
15
+ )
16
+
17
+
18
+ ConfigStore.instance().store(
19
+ group="hydra/launcher",
20
+ name="dependency_submitit_slurm",
21
+ node=DependencySubmititConf(),
22
+ provider="dependency_submitit_slurm",
23
+ )
fairseq/hydra_plugins/dependency_submitit_launcher/hydra_plugins/dependency_submitit_launcher/launcher.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ import logging
3
+ import os
4
+ import subprocess
5
+ from pathlib import Path
6
+ from typing import Any, List, Sequence
7
+
8
+ from hydra.core.singleton import Singleton
9
+ from hydra.core.utils import JobReturn, filter_overrides
10
+ from omegaconf import OmegaConf
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
+ from .config import DependencySubmititConf
15
+ from hydra_plugins.hydra_submitit_launcher.submitit_launcher import BaseSubmititLauncher
16
+
17
+
18
+ class DependencySubmititLauncher(BaseSubmititLauncher):
19
+ _EXECUTOR = "slurm"
20
+
21
+ def launch(
22
+ self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int
23
+ ) -> Sequence[JobReturn]:
24
+
25
+ # lazy import to ensure plugin discovery remains fast
26
+ import submitit
27
+
28
+ assert self.config is not None
29
+
30
+ num_jobs = len(job_overrides)
31
+ assert num_jobs > 0
32
+
33
+ next_script = None
34
+
35
+ for jo in job_overrides:
36
+ if next_script is None:
37
+ for item in jo:
38
+ if "next_script=" in item:
39
+ next_script = item
40
+ break
41
+ assert (
42
+ next_script is not None
43
+ ), "job overrides must contain +next_script=path/to/next/script"
44
+ jo.remove(next_script)
45
+
46
+ idx = next_script.find("=")
47
+ next_script = next_script[idx + 1 :]
48
+
49
+ params = self.params
50
+ # build executor
51
+ init_params = {"folder": self.params["submitit_folder"]}
52
+ specific_init_keys = {"max_num_timeout"}
53
+
54
+ init_params.update(
55
+ **{
56
+ f"{self._EXECUTOR}_{x}": y
57
+ for x, y in params.items()
58
+ if x in specific_init_keys
59
+ }
60
+ )
61
+ init_keys = specific_init_keys | {"submitit_folder"}
62
+ executor = submitit.AutoExecutor(cluster=self._EXECUTOR, **init_params)
63
+
64
+ # specify resources/parameters
65
+ baseparams = set(OmegaConf.structured(DependencySubmititConf).keys())
66
+ params = {
67
+ x if x in baseparams else f"{self._EXECUTOR}_{x}": y
68
+ for x, y in params.items()
69
+ if x not in init_keys
70
+ }
71
+ executor.update_parameters(**params)
72
+
73
+ log.info(
74
+ f"Submitit '{self._EXECUTOR}' sweep output dir : "
75
+ f"{self.config.hydra.sweep.dir}"
76
+ )
77
+ sweep_dir = Path(str(self.config.hydra.sweep.dir))
78
+ sweep_dir.mkdir(parents=True, exist_ok=True)
79
+ if "mode" in self.config.hydra.sweep:
80
+ mode = int(str(self.config.hydra.sweep.mode), 8)
81
+ os.chmod(sweep_dir, mode=mode)
82
+
83
+ job_params: List[Any] = []
84
+ for idx, overrides in enumerate(job_overrides):
85
+ idx = initial_job_idx + idx
86
+ lst = " ".join(filter_overrides(overrides))
87
+ log.info(f"\t#{idx} : {lst}")
88
+ job_params.append(
89
+ (
90
+ list(overrides),
91
+ "hydra.sweep.dir",
92
+ idx,
93
+ f"job_id_for_{idx}",
94
+ Singleton.get_state(),
95
+ )
96
+ )
97
+
98
+ jobs = executor.map_array(self, *zip(*job_params))
99
+
100
+ for j, jp in zip(jobs, job_params):
101
+ job_id = str(j.job_id)
102
+ task_id = "0" if "_" not in job_id else job_id.split("_")[1]
103
+ sweep_config = self.config_loader.load_sweep_config(self.config, jp[0])
104
+ dir = sweep_config.hydra.sweep.dir
105
+
106
+ dir = (
107
+ dir.replace("[", "")
108
+ .replace("]", "")
109
+ .replace("{", "")
110
+ .replace("}", "")
111
+ .replace(",", "_")
112
+ .replace("'", "")
113
+ .replace('"', "")
114
+ )
115
+
116
+ subprocess.call(
117
+ [next_script, job_id, task_id, dir],
118
+ shell=False,
119
+ )
120
+
121
+ return [j.results()[0] for j in jobs]
fairseq/hydra_plugins/dependency_submitit_launcher/setup.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ # type: ignore
3
+ from pathlib import Path
4
+
5
+ from read_version import read_version
6
+ from setuptools import find_namespace_packages, setup
7
+
8
+ setup(
9
+ name="dependency-submitit-launcher",
10
+ version=read_version("hydra_plugins/dependency_submitit_launcher", "__init__.py"),
11
+ author="Alexei Baevski",
12
+ author_email="[email protected]",
13
+ description="Dependency-supporting Submitit Launcher for Hydra apps",
14
+ packages=find_namespace_packages(include=["hydra_plugins.*"]),
15
+ classifiers=[
16
+ "License :: OSI Approved :: MIT License",
17
+ "Programming Language :: Python :: 3.7",
18
+ "Programming Language :: Python :: 3.8",
19
+ "Programming Language :: Python :: 3.9",
20
+ "Operating System :: MacOS",
21
+ "Operating System :: POSIX :: Linux",
22
+ "Development Status :: 4 - Beta",
23
+ ],
24
+ install_requires=[
25
+ "hydra-core>=1.0.4",
26
+ "submitit>=1.0.0",
27
+ ],
28
+ include_package_data=True,
29
+ )
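
Installation sketch for the launcher plugin above, assuming the shell is inside the plugin directory; the launcher is then selected with the config name registered in config.py and given the +next_script override that launcher.py requires (the script path is a placeholder):

    pip install -e .
    # per-run selection:
    #   hydra/launcher=dependency_submitit_slurm +next_script=/path/to/next_stage.sh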
fairseq/scripts/__init__.py ADDED
File without changes
fairseq/scripts/average_checkpoints.py ADDED
@@ -0,0 +1,176 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import collections
9
+ import os
10
+ import re
11
+
12
+ import torch
13
+
14
+ from fairseq.file_io import PathManager
15
+
16
+
17
+ def average_checkpoints(inputs):
18
+ """Loads checkpoints from inputs and returns a model with averaged weights.
19
+
20
+ Args:
21
+ inputs: An iterable of string paths of checkpoints to load from.
22
+
23
+ Returns:
24
+ A dict of string keys mapping to various values. The 'model' key
25
+ from the returned dict should correspond to an OrderedDict mapping
26
+ string parameter names to torch Tensors.
27
+ """
28
+ params_dict = collections.OrderedDict()
29
+ params_keys = None
30
+ new_state = None
31
+ num_models = len(inputs)
32
+
33
+ for fpath in inputs:
34
+ with PathManager.open(fpath, "rb") as f:
35
+ state = torch.load(
36
+ f,
37
+ map_location=(
38
+ lambda s, _: torch.serialization.default_restore_location(s, "cpu")
39
+ ),
40
+ )
41
+ # Copies over the settings from the first checkpoint
42
+ if new_state is None:
43
+ new_state = state
44
+
45
+ model_params = state["model"]
46
+
47
+ model_params_keys = list(model_params.keys())
48
+ if params_keys is None:
49
+ params_keys = model_params_keys
50
+ elif params_keys != model_params_keys:
51
+ raise KeyError(
52
+ "For checkpoint {}, expected list of params: {}, "
53
+ "but found: {}".format(f, params_keys, model_params_keys)
54
+ )
55
+
56
+ for k in params_keys:
57
+ p = model_params[k]
58
+ if isinstance(p, torch.HalfTensor):
59
+ p = p.float()
60
+ if k not in params_dict:
61
+ params_dict[k] = p.clone()
62
+ # NOTE: clone() is needed in case of p is a shared parameter
63
+ else:
64
+ params_dict[k] += p
65
+
66
+ averaged_params = collections.OrderedDict()
67
+ for k, v in params_dict.items():
68
+ averaged_params[k] = v
69
+ if averaged_params[k].is_floating_point():
70
+ averaged_params[k].div_(num_models)
71
+ else:
72
+ averaged_params[k] //= num_models
73
+ new_state["model"] = averaged_params
74
+ return new_state
75
+
76
+
77
+ def last_n_checkpoints(paths, n, update_based, upper_bound=None):
78
+ assert len(paths) == 1
79
+ path = paths[0]
80
+ if update_based:
81
+ pt_regexp = re.compile(r"checkpoint_\d+_(\d+)\.pt")
82
+ else:
83
+ pt_regexp = re.compile(r"checkpoint(\d+)\.pt")
84
+ files = PathManager.ls(path)
85
+
86
+ entries = []
87
+ for f in files:
88
+ m = pt_regexp.fullmatch(f)
89
+ if m is not None:
90
+ sort_key = int(m.group(1))
91
+ if upper_bound is None or sort_key <= upper_bound:
92
+ entries.append((sort_key, m.group(0)))
93
+ if len(entries) < n:
94
+ raise Exception(
95
+ "Found {} checkpoint files but need at least {}", len(entries), n
96
+ )
97
+ return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
98
+
99
+
100
+ def main():
101
+ parser = argparse.ArgumentParser(
102
+ description="Tool to average the params of input checkpoints to "
103
+ "produce a new checkpoint",
104
+ )
105
+ # fmt: off
106
+ parser.add_argument('--inputs', required=True, nargs='+',
107
+ help='Input checkpoint file paths.')
108
+ parser.add_argument('--output', required=True, metavar='FILE',
109
+ help='Write the new checkpoint containing the averaged weights to this path.')
110
+ num_group = parser.add_mutually_exclusive_group()
111
+ num_group.add_argument('--num-epoch-checkpoints', type=int,
112
+ help='if set, will try to find checkpoints with names checkpoint_xx.pt in the '
113
+ 'path specified by input, and average last this many of them.')
114
+ num_group.add_argument('--num-update-checkpoints', type=int,
115
+ help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by'
116
+ ' input, and average last this many of them.')
117
+ num_group.add_argument('--num-best-checkpoints', type=int, default=0,
118
+ help='if set, will try to find checkpoints with names checkpoint_best_ee_xx.pt in the path specified by'
119
+ ' input, and average last this many of them.')
120
+ parser.add_argument('--checkpoint-upper-bound', type=int,
121
+ help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use, '
122
+ 'when using --num-update-checkpoints, this will set an upper bound on which update to use'
123
+ 'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be'
124
+ ' averaged.'
125
+ 'e.g., with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 40500-50000 would'
126
+ ' be averaged assuming --save-interval-updates 500'
127
+ )
128
+ # fmt: on
129
+ args = parser.parse_args()
130
+ print(args)
131
+
132
+ num = None
133
+ is_update_based = False
134
+ if args.num_update_checkpoints is not None:
135
+ num = args.num_update_checkpoints
136
+ is_update_based = True
137
+ elif args.num_epoch_checkpoints is not None:
138
+ num = args.num_epoch_checkpoints
139
+
140
+ assert args.checkpoint_upper_bound is None or (
141
+ args.num_epoch_checkpoints is not None
142
+ or args.num_update_checkpoints is not None
143
+ ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints"
144
+ assert (
145
+ args.num_epoch_checkpoints is None or args.num_update_checkpoints is None
146
+ ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints"
147
+
148
+ if num is not None:
149
+ args.inputs = last_n_checkpoints(
150
+ args.inputs,
151
+ num,
152
+ is_update_based,
153
+ upper_bound=args.checkpoint_upper_bound,
154
+ )
155
+ print("averaging checkpoints: ", args.inputs)
156
+
157
+ if args.num_best_checkpoints > 0:
158
+ args.inputs = list(
159
+ sorted(
160
+ args.inputs,
161
+ key=lambda x: float(
162
+ os.path.basename(x).split("_")[-1].replace(".pt", "")
163
+ ),
164
+ )
165
+ )
166
+ args.inputs = args.inputs[: args.num_best_checkpoints]
167
+ for path in args.inputs:
168
+ print(os.path.basename(path))
169
+ new_state = average_checkpoints(args.inputs)
170
+ with PathManager.open(args.output, "wb") as f:
171
+ torch.save(new_state, f)
172
+ print("Finished writing averaged checkpoint to {}".format(args.output))
173
+
174
+
175
+ if __name__ == "__main__":
176
+ main()
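
Usage sketch matching the argparse options above; the checkpoint directory and output path are placeholders:

    python scripts/average_checkpoints.py \
        --inputs checkpoints/ --num-epoch-checkpoints 5 \
        --checkpoint-upper-bound 50 --output checkpoints/avg_last5.pt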
fairseq/scripts/build_sym_alignment.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """
6
+ Use this script in order to build symmetric alignments for your translation
7
+ dataset.
8
+ This script depends on fast_align and mosesdecoder tools. You will need to
9
+ build those before running the script.
10
+ fast_align:
11
+ github: http://github.com/clab/fast_align
12
+ instructions: follow the instructions in README.md
13
+ mosesdecoder:
14
+ github: http://github.com/moses-smt/mosesdecoder
15
+ instructions: http://www.statmt.org/moses/?n=Development.GetStarted
16
+ The script produces the following files under --output_dir:
17
+ text.joined - concatenation of lines from the source_file and the
18
+ target_file.
19
+ align.forward - forward pass of fast_align.
20
+ align.backward - backward pass of fast_align.
21
+ aligned.sym_heuristic - symmetrized alignment.
22
+ """
23
+
24
+ import argparse
25
+ import os
26
+ from itertools import zip_longest
27
+
28
+
29
+ def main():
30
+ parser = argparse.ArgumentParser(description="symmetric alignment builer")
31
+ # fmt: off
32
+ parser.add_argument('--fast_align_dir',
33
+ help='path to fast_align build directory')
34
+ parser.add_argument('--mosesdecoder_dir',
35
+ help='path to mosesdecoder root directory')
36
+ parser.add_argument('--sym_heuristic',
37
+ help='heuristic to use for symmetrization',
38
+ default='grow-diag-final-and')
39
+ parser.add_argument('--source_file',
40
+ help='path to a file with sentences '
41
+ 'in the source language')
42
+ parser.add_argument('--target_file',
43
+ help='path to a file with sentences '
44
+ 'in the target language')
45
+ parser.add_argument('--output_dir',
46
+ help='output directory')
47
+ # fmt: on
48
+ args = parser.parse_args()
49
+
50
+ fast_align_bin = os.path.join(args.fast_align_dir, "fast_align")
51
+ symal_bin = os.path.join(args.mosesdecoder_dir, "bin", "symal")
52
+ sym_fast_align_bin = os.path.join(
53
+ args.mosesdecoder_dir, "scripts", "ems", "support", "symmetrize-fast-align.perl"
54
+ )
55
+
56
+ # create joined file
57
+ joined_file = os.path.join(args.output_dir, "text.joined")
58
+ with open(args.source_file, "r", encoding="utf-8") as src, open(
59
+ args.target_file, "r", encoding="utf-8"
60
+ ) as tgt:
61
+ with open(joined_file, "w", encoding="utf-8") as joined:
62
+ for s, t in zip_longest(src, tgt):
63
+ print("{} ||| {}".format(s.strip(), t.strip()), file=joined)
64
+
65
+ bwd_align_file = os.path.join(args.output_dir, "align.backward")
66
+
67
+ # run forward alignment
68
+ fwd_align_file = os.path.join(args.output_dir, "align.forward")
69
+ fwd_fast_align_cmd = "{FASTALIGN} -i {JOINED} -d -o -v > {FWD}".format(
70
+ FASTALIGN=fast_align_bin, JOINED=joined_file, FWD=fwd_align_file
71
+ )
72
+ assert os.system(fwd_fast_align_cmd) == 0
73
+
74
+ # run backward alignment
75
+ bwd_align_file = os.path.join(args.output_dir, "align.backward")
76
+ bwd_fast_align_cmd = "{FASTALIGN} -i {JOINED} -d -o -v -r > {BWD}".format(
77
+ FASTALIGN=fast_align_bin, JOINED=joined_file, BWD=bwd_align_file
78
+ )
79
+ assert os.system(bwd_fast_align_cmd) == 0
80
+
81
+ # run symmetrization
82
+ sym_out_file = os.path.join(args.output_dir, "aligned")
83
+ sym_cmd = "{SYMFASTALIGN} {FWD} {BWD} {SRC} {TGT} {OUT} {HEURISTIC} {SYMAL}".format(
84
+ SYMFASTALIGN=sym_fast_align_bin,
85
+ FWD=fwd_align_file,
86
+ BWD=bwd_align_file,
87
+ SRC=args.source_file,
88
+ TGT=args.target_file,
89
+ OUT=sym_out_file,
90
+ HEURISTIC=args.sym_heuristic,
91
+ SYMAL=symal_bin,
92
+ )
93
+ assert os.system(sym_cmd) == 0
94
+
95
+
96
+ if __name__ == "__main__":
97
+ main()
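
Usage sketch following the docstring above; all paths are placeholders, and fast_align and mosesdecoder must be built beforehand:

    python scripts/build_sym_alignment.py \
        --fast_align_dir ~/fast_align/build --mosesdecoder_dir ~/mosesdecoder \
        --source_file corpus.de --target_file corpus.en --output_dir align_out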
fairseq/scripts/check_installation.py ADDED
@@ -0,0 +1,36 @@
1
+ from pathlib import Path
2
+ import os
3
+
4
+ cwd = Path(".").resolve()
5
+ print("running 'check_installation.py' from:", cwd)
6
+
7
+ # Old versions of numpy/torch can prevent loading the .so files
8
+ import torch
9
+
10
+ print("torch:", torch.__version__)
11
+ import numpy
12
+
13
+ print("numpy:", numpy.__version__)
14
+
15
+ import fairseq
16
+
17
+ print("Fairseq installed at:", fairseq.__file__)
18
+ import fairseq.criterions
19
+ import fairseq.dataclass.configs
20
+
21
+ import _imp
22
+
23
+ print("Should load following .so suffixes:", _imp.extension_suffixes())
24
+
25
+ so_files = list(Path(fairseq.__file__).parent.glob("*.so"))
26
+ so_files.extend(Path(fairseq.__file__).parent.glob("data/*.so"))
27
+ print("Found following .so files:")
28
+ for so_file in so_files:
29
+ print(f"- {so_file}")
30
+
31
+ from fairseq import libbleu
32
+
33
+ print("Found libbleu at", libbleu.__file__)
34
+ from fairseq.data import data_utils_fast
35
+
36
+ print("Found data_utils_fast at", data_utils_fast.__file__)
fairseq/scripts/compare_namespaces.py ADDED
@@ -0,0 +1,46 @@
1
+ #!/usr/bin/env python
2
+ """Helper script to compare two argparse.Namespace objects."""
3
+
4
+ from argparse import Namespace # noqa
5
+
6
+
7
+ def main():
8
+
9
+ ns1 = eval(input("Namespace 1: "))
10
+ ns2 = eval(input("Namespace 2: "))
11
+
12
+ def keys(ns):
13
+ ks = set()
14
+ for k in dir(ns):
15
+ if not k.startswith("_"):
16
+ ks.add(k)
17
+ return ks
18
+
19
+ k1 = keys(ns1)
20
+ k2 = keys(ns2)
21
+
22
+ def print_keys(ks, ns1, ns2=None):
23
+ for k in ks:
24
+ if ns2 is None:
25
+ print("{}\t{}".format(k, getattr(ns1, k, None)))
26
+ else:
27
+ print(
28
+ "{}\t{}\t{}".format(k, getattr(ns1, k, None), getattr(ns2, k, None))
29
+ )
30
+
31
+ print("Keys unique to namespace 1:")
32
+ print_keys(k1 - k2, ns1)
33
+ print()
34
+
35
+ print("Keys unique to namespace 2:")
36
+ print_keys(k2 - k1, ns2)
37
+ print()
38
+
39
+ print("Overlapping keys with different values:")
40
+ ks = [k for k in k1 & k2 if getattr(ns1, k, "None") != getattr(ns2, k, "None")]
41
+ print_keys(ks, ns1, ns2)
42
+ print()
43
+
44
+
45
+ if __name__ == "__main__":
46
+ main()
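Note that the script reads Namespace reprs via eval(input(...)), so it expects something like repr(Namespace(...)) pasted at each prompt. A small illustration of suitable input (the attribute values are hypothetical):

    from argparse import Namespace

    ns1 = Namespace(lr=0.1, optimizer="adam")
    ns2 = Namespace(lr=0.3, optimizer="adam", warmup_updates=4000)
    # Paste these reprs at the "Namespace 1:" / "Namespace 2:" prompts.
    print(repr(ns1))
    print(repr(ns2))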
fairseq/scripts/compound_split_bleu.sh ADDED
@@ -0,0 +1,20 @@
1
+ #!/bin/bash
2
+
3
+ if [ $# -ne 1 ]; then
4
+ echo "usage: $0 GENERATE_PY_OUTPUT"
5
+ exit 1
6
+ fi
7
+
8
+ GEN=$1
9
+
10
+ SYS=$GEN.sys
11
+ REF=$GEN.ref
12
+
13
+ if [ $(tail -n 1 $GEN | grep BLEU | wc -l) -ne 1 ]; then
14
+ echo "not done generating"
15
+ exit
16
+ fi
17
+
18
+ grep ^H $GEN | awk -F '\t' '{print $NF}' | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $SYS
19
+ grep ^T $GEN | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > $REF
20
+ fairseq-score --sys $SYS --ref $REF
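The two perl one-liners above split hyphenated compounds before scoring. The same substitution in Python, for readers who want to see what the ##AT## markers look like (the sample sentence is made up):

    import re

    line = "state-of-the-art results"
    print(re.sub(r"(\S)-(\S)", r"\1 ##AT##-##AT## \2", line))
    # -> state ##AT##-##AT## of ##AT##-##AT## the ##AT##-##AT## art results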
fairseq/scripts/constraints/extract.py ADDED
@@ -0,0 +1,90 @@
1
+ #!/usr/bin/env python3
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ """Extracts random constraints from reference files."""
9
+
10
+ import argparse
11
+ import random
12
+ import sys
13
+
14
+
15
+ def get_phrase(words, index, length):
16
+ assert index < len(words) - length + 1
17
+ phr = " ".join(words[index : index + length])
18
+ for i in range(index, index + length):
19
+ words.pop(index)
20
+ return phr
21
+
22
+
23
+ def main(args):
24
+
25
+ if args.seed:
26
+ random.seed(args.seed)
27
+
28
+ for line in sys.stdin:
29
+ constraints = []
30
+
31
+ def add_constraint(constraint):
32
+ constraints.append(constraint)
33
+
34
+ source = line.rstrip()
35
+ if "\t" in line:
36
+ source, target = line.split("\t")
37
+ if args.add_sos:
38
+ target = f"<s> {target}"
39
+ if args.add_eos:
40
+ target = f"{target} </s>"
41
+
42
+ if len(target.split()) >= args.len:
43
+ words = [target]
44
+
45
+ num = args.number
46
+
47
+ choices = {}
48
+ for i in range(num):
49
+ if len(words) == 0:
50
+ break
51
+ segmentno = random.choice(range(len(words)))
52
+ segment = words.pop(segmentno)
53
+ tokens = segment.split()
54
+ phrase_index = random.choice(range(len(tokens)))
55
+ choice = " ".join(
56
+ tokens[phrase_index : min(len(tokens), phrase_index + args.len)]
57
+ )
58
+ for j in range(
59
+ phrase_index, min(len(tokens), phrase_index + args.len)
60
+ ):
61
+ tokens.pop(phrase_index)
62
+ if phrase_index > 0:
63
+ words.append(" ".join(tokens[0:phrase_index]))
64
+ if phrase_index + 1 < len(tokens):
65
+ words.append(" ".join(tokens[phrase_index:]))
66
+ choices[target.find(choice)] = choice
67
+
68
+ # mask out with spaces
69
+ target = target.replace(choice, " " * len(choice), 1)
70
+
71
+ for key in sorted(choices.keys()):
72
+ add_constraint(choices[key])
73
+
74
+ print(source, *constraints, sep="\t")
75
+
76
+
77
+ if __name__ == "__main__":
78
+ parser = argparse.ArgumentParser()
79
+ parser.add_argument("--number", "-n", type=int, default=1, help="number of phrases")
80
+ parser.add_argument("--len", "-l", type=int, default=1, help="phrase length")
81
+ parser.add_argument(
82
+ "--add-sos", default=False, action="store_true", help="add <s> token"
83
+ )
84
+ parser.add_argument(
85
+ "--add-eos", default=False, action="store_true", help="add </s> token"
86
+ )
87
+ parser.add_argument("--seed", "-s", default=0, type=int)
88
+ args = parser.parse_args()
89
+
90
+ main(args)
fairseq/scripts/constraints/validate.py ADDED
@@ -0,0 +1,34 @@
1
+ #!/usr/bin/env python3
2
+ #
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import sys
9
+
10
+
11
+ """Reads in a fairseq output file, and verifies that the constraints
12
+ (C- lines) are present in the output (the first H- line). Assumes that
13
+ constraints are listed prior to the first hypothesis.
14
+ """
15
+
16
+ constraints = []
17
+ found = 0
18
+ total = 0
19
+ for line in sys.stdin:
20
+ if line.startswith("C-"):
21
+ constraints.append(line.rstrip().split("\t")[1])
22
+ elif line.startswith("H-"):
23
+ text = line.split("\t")[2]
24
+
25
+ for constraint in constraints:
26
+ total += 1
27
+ if constraint in text:
28
+ found += 1
29
+ else:
30
+ print(f"No {constraint} in {text}", file=sys.stderr)
31
+
32
+ constraints = []
33
+
34
+ print(f"Found {found} / {total} = {100 * found / total:.1f}%")
fairseq/scripts/convert_dictionary.lua ADDED
@@ -0,0 +1,34 @@
1
+ -- Copyright (c) Facebook, Inc. and its affiliates.
2
+ --
3
+ -- This source code is licensed under the MIT license found in the
4
+ -- LICENSE file in the root directory of this source tree.
5
+ --
6
+ -- Usage: convert_dictionary.lua <dict.th7>
7
+ require 'fairseq'
8
+ require 'torch'
9
+ require 'paths'
10
+
11
+ if #arg < 1 then
12
+ print('usage: convert_dictionary.lua <dict.th7>')
13
+ os.exit(1)
14
+ end
15
+ if not paths.filep(arg[1]) then
16
+ print('error: file does not exit: ' .. arg[1])
17
+ os.exit(1)
18
+ end
19
+
20
+ dict = torch.load(arg[1])
21
+ dst = paths.basename(arg[1]):gsub('.th7', '.txt')
22
+ assert(dst:match('.txt$'))
23
+
24
+ f = io.open(dst, 'w')
25
+ for idx, symbol in ipairs(dict.index_to_symbol) do
26
+ if idx > dict.cutoff then
27
+ break
28
+ end
29
+ f:write(symbol)
30
+ f:write(' ')
31
+ f:write(dict.index_to_freq[idx])
32
+ f:write('\n')
33
+ end
34
+ f:close()
fairseq/scripts/convert_model.lua ADDED
@@ -0,0 +1,108 @@
1
+ -- Copyright (c) Facebook, Inc. and its affiliates.
2
+ --
3
+ -- This source code is licensed under the MIT license found in the
4
+ -- LICENSE file in the root directory of this source tree.
5
+ --
6
+ -- Usage: convert_model.lua <model_epoch1.th7>
7
+ require 'torch'
8
+ local fairseq = require 'fairseq'
9
+
10
+ model = torch.load(arg[1])
11
+
12
+ function find_weight_norm(container, module)
13
+ for _, wn in ipairs(container:listModules()) do
14
+ if torch.type(wn) == 'nn.WeightNorm' and wn.modules[1] == module then
15
+ return wn
16
+ end
17
+ end
18
+ end
19
+
20
+ function push_state(dict, key, module)
21
+ if torch.type(module) == 'nn.Linear' then
22
+ local wn = find_weight_norm(model.module, module)
23
+ assert(wn)
24
+ dict[key .. '.weight_v'] = wn.v:float()
25
+ dict[key .. '.weight_g'] = wn.g:float()
26
+ elseif torch.type(module) == 'nn.TemporalConvolutionTBC' then
27
+ local wn = find_weight_norm(model.module, module)
28
+ assert(wn)
29
+ local v = wn.v:float():view(wn.viewOut):transpose(2, 3)
30
+ dict[key .. '.weight_v'] = v
31
+ dict[key .. '.weight_g'] = wn.g:float():view(module.weight:size(3), 1, 1)
32
+ else
33
+ dict[key .. '.weight'] = module.weight:float()
34
+ end
35
+ if module.bias then
36
+ dict[key .. '.bias'] = module.bias:float()
37
+ end
38
+ end
39
+
40
+ encoder_dict = {}
41
+ decoder_dict = {}
42
+ combined_dict = {}
43
+
44
+ function encoder_state(encoder)
45
+ luts = encoder:findModules('nn.LookupTable')
46
+ push_state(encoder_dict, 'embed_tokens', luts[1])
47
+ push_state(encoder_dict, 'embed_positions', luts[2])
48
+
49
+ fcs = encoder:findModules('nn.Linear')
50
+ assert(#fcs >= 2)
51
+ local nInputPlane = fcs[1].weight:size(1)
52
+ push_state(encoder_dict, 'fc1', table.remove(fcs, 1))
53
+ push_state(encoder_dict, 'fc2', table.remove(fcs, #fcs))
54
+
55
+ for i, module in ipairs(encoder:findModules('nn.TemporalConvolutionTBC')) do
56
+ push_state(encoder_dict, 'convolutions.' .. tostring(i - 1), module)
57
+ if nInputPlane ~= module.weight:size(3) / 2 then
58
+ push_state(encoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1))
59
+ end
60
+ nInputPlane = module.weight:size(3) / 2
61
+ end
62
+ assert(#fcs == 0)
63
+ end
64
+
65
+ function decoder_state(decoder)
66
+ luts = decoder:findModules('nn.LookupTable')
67
+ push_state(decoder_dict, 'embed_tokens', luts[1])
68
+ push_state(decoder_dict, 'embed_positions', luts[2])
69
+
70
+ fcs = decoder:findModules('nn.Linear')
71
+ local nInputPlane = fcs[1].weight:size(1)
72
+ push_state(decoder_dict, 'fc1', table.remove(fcs, 1))
73
+ push_state(decoder_dict, 'fc2', fcs[#fcs - 1])
74
+ push_state(decoder_dict, 'fc3', fcs[#fcs])
75
+
76
+ table.remove(fcs, #fcs)
77
+ table.remove(fcs, #fcs)
78
+
79
+ for i, module in ipairs(decoder:findModules('nn.TemporalConvolutionTBC')) do
80
+ if nInputPlane ~= module.weight:size(3) / 2 then
81
+ push_state(decoder_dict, 'projections.' .. tostring(i - 1), table.remove(fcs, 1))
82
+ end
83
+ nInputPlane = module.weight:size(3) / 2
84
+
85
+ local prefix = 'attention.' .. tostring(i - 1)
86
+ push_state(decoder_dict, prefix .. '.in_projection', table.remove(fcs, 1))
87
+ push_state(decoder_dict, prefix .. '.out_projection', table.remove(fcs, 1))
88
+ push_state(decoder_dict, 'convolutions.' .. tostring(i - 1), module)
89
+ end
90
+ assert(#fcs == 0)
91
+ end
92
+
93
+
94
+ _encoder = model.module.modules[2]
95
+ _decoder = model.module.modules[3]
96
+
97
+ encoder_state(_encoder)
98
+ decoder_state(_decoder)
99
+
100
+ for k, v in pairs(encoder_dict) do
101
+ combined_dict['encoder.' .. k] = v
102
+ end
103
+ for k, v in pairs(decoder_dict) do
104
+ combined_dict['decoder.' .. k] = v
105
+ end
106
+
107
+
108
+ torch.save('state_dict.t7', combined_dict)
fairseq/scripts/count_docs.py ADDED
@@ -0,0 +1,58 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Count the number of documents and average number of lines and tokens per
8
+ document in a large file. Documents should be separated by a single empty line.
9
+ """
10
+
11
+ import argparse
12
+ import gzip
13
+ import sys
14
+
15
+ import numpy as np
16
+
17
+
18
+ def main():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("input")
21
+ parser.add_argument("--gzip", action="store_true")
22
+ args = parser.parse_args()
23
+
24
+ def gopen():
25
+ if args.gzip:
26
+ return gzip.open(args.input, "r")
27
+ else:
28
+ return open(args.input, "r", encoding="utf-8")
29
+
30
+ num_lines = []
31
+ num_toks = []
32
+ with gopen() as h:
33
+ num_docs = 1
34
+ num_lines_in_doc = 0
35
+ num_toks_in_doc = 0
36
+ for i, line in enumerate(h):
37
+ if len(line.strip()) == 0: # empty line indicates new document
38
+ num_docs += 1
39
+ num_lines.append(num_lines_in_doc)
40
+ num_toks.append(num_toks_in_doc)
41
+ num_lines_in_doc = 0
42
+ num_toks_in_doc = 0
43
+ else:
44
+ num_lines_in_doc += 1
45
+ num_toks_in_doc += len(line.rstrip().split())
46
+ if i % 1000000 == 0:
47
+ print(i, file=sys.stderr, end="", flush=True)
48
+ elif i % 100000 == 0:
49
+ print(".", file=sys.stderr, end="", flush=True)
50
+ print(file=sys.stderr, flush=True)
51
+
52
+ print("found {} docs".format(num_docs))
53
+ print("average num lines per doc: {}".format(np.mean(num_lines)))
54
+ print("average num toks per doc: {}".format(np.mean(num_toks)))
55
+
56
+
57
+ if __name__ == "__main__":
58
+ main()
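The expected input format is one document per blank-line-separated block. A tiny illustration of that convention (the sample text is hypothetical):

    text = "doc one, line one\ndoc one, line two\n\ndoc two, line one\n"
    docs = [d for d in text.split("\n\n") if d.strip()]
    print(len(docs))                   # -> 2 documents
    print(len(docs[0].splitlines()))   # -> 2 lines in the first document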
fairseq/scripts/read_binarized.py ADDED
@@ -0,0 +1,48 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+
9
+ from fairseq.data import Dictionary, data_utils, indexed_dataset
10
+
11
+
12
+ def get_parser():
13
+ parser = argparse.ArgumentParser(
14
+ description="writes text from binarized file to stdout"
15
+ )
16
+ # fmt: off
17
+ parser.add_argument('--dataset-impl', help='dataset implementation',
18
+ choices=indexed_dataset.get_available_dataset_impl())
19
+ parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None)
20
+ parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read')
21
+ # fmt: on
22
+
23
+ return parser
24
+
25
+
26
+ def main():
27
+ parser = get_parser()
28
+ args = parser.parse_args()
29
+
30
+ dictionary = Dictionary.load(args.dict) if args.dict is not None else None
31
+ dataset = data_utils.load_indexed_dataset(
32
+ args.input,
33
+ dictionary,
34
+ dataset_impl=args.dataset_impl,
35
+ default="lazy",
36
+ )
37
+
38
+ for tensor_line in dataset:
39
+ if dictionary is None:
40
+ line = " ".join([str(int(x)) for x in tensor_line])
41
+ else:
42
+ line = dictionary.string(tensor_line)
43
+
44
+ print(line)
45
+
46
+
47
+ if __name__ == "__main__":
48
+ main()
fairseq/scripts/rm_pt.py ADDED
@@ -0,0 +1,141 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import os
9
+ import re
10
+ import shutil
11
+ import sys
12
+
13
+
14
+ pt_regexp = re.compile(r"checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt")
15
+ pt_regexp_epoch_based = re.compile(r"checkpoint(\d+)\.pt")
16
+ pt_regexp_update_based = re.compile(r"checkpoint_\d+_(\d+)\.pt")
17
+
18
+
19
+ def parse_checkpoints(files):
20
+ entries = []
21
+ for f in files:
22
+ m = pt_regexp_epoch_based.fullmatch(f)
23
+ if m is not None:
24
+ entries.append((int(m.group(1)), m.group(0)))
25
+ else:
26
+ m = pt_regexp_update_based.fullmatch(f)
27
+ if m is not None:
28
+ entries.append((int(m.group(1)), m.group(0)))
29
+ return entries
30
+
31
+
32
+ def last_n_checkpoints(files, n):
33
+ entries = parse_checkpoints(files)
34
+ return [x[1] for x in sorted(entries, reverse=True)[:n]]
35
+
36
+
37
+ def every_n_checkpoints(files, n):
38
+ entries = parse_checkpoints(files)
39
+ return [x[1] for x in sorted(sorted(entries)[::-n])]
40
+
41
+
42
+ def main():
43
+ parser = argparse.ArgumentParser(
44
+ description=(
45
+ "Recursively delete checkpoint files from `root_dir`, "
46
+ "but preserve checkpoint_best.pt and checkpoint_last.pt"
47
+ )
48
+ )
49
+ parser.add_argument("root_dirs", nargs="*")
50
+ parser.add_argument(
51
+ "--save-last", type=int, default=0, help="number of last checkpoints to save"
52
+ )
53
+ parser.add_argument(
54
+ "--save-every", type=int, default=0, help="interval of checkpoints to save"
55
+ )
56
+ parser.add_argument(
57
+ "--preserve-test",
58
+ action="store_true",
59
+ help="preserve checkpoints in dirs that start with test_ prefix (default: delete them)",
60
+ )
61
+ parser.add_argument(
62
+ "--delete-best", action="store_true", help="delete checkpoint_best.pt"
63
+ )
64
+ parser.add_argument(
65
+ "--delete-last", action="store_true", help="delete checkpoint_last.pt"
66
+ )
67
+ parser.add_argument(
68
+ "--no-dereference", action="store_true", help="don't dereference symlinks"
69
+ )
70
+ args = parser.parse_args()
71
+
72
+ files_to_desymlink = []
73
+ files_to_preserve = []
74
+ files_to_delete = []
75
+ for root_dir in args.root_dirs:
76
+ for root, _subdirs, files in os.walk(root_dir):
77
+ if args.save_last > 0:
78
+ to_save = last_n_checkpoints(files, args.save_last)
79
+ else:
80
+ to_save = []
81
+ if args.save_every > 0:
82
+ to_save += every_n_checkpoints(files, args.save_every)
83
+ for file in files:
84
+ if not pt_regexp.fullmatch(file):
85
+ continue
86
+ full_path = os.path.join(root, file)
87
+ if (
88
+ not os.path.basename(root).startswith("test_") or args.preserve_test
89
+ ) and (
90
+ (file == "checkpoint_last.pt" and not args.delete_last)
91
+ or (file == "checkpoint_best.pt" and not args.delete_best)
92
+ or file in to_save
93
+ ):
94
+ if os.path.islink(full_path) and not args.no_dereference:
95
+ files_to_desymlink.append(full_path)
96
+ else:
97
+ files_to_preserve.append(full_path)
98
+ else:
99
+ files_to_delete.append(full_path)
100
+
101
+ if len(files_to_desymlink) == 0 and len(files_to_delete) == 0:
102
+ print("Nothing to do.")
103
+ sys.exit(0)
104
+
105
+ files_to_desymlink = sorted(files_to_desymlink)
106
+ files_to_preserve = sorted(files_to_preserve)
107
+ files_to_delete = sorted(files_to_delete)
108
+
109
+ print("Operations to perform (in order):")
110
+ if len(files_to_desymlink) > 0:
111
+ for file in files_to_desymlink:
112
+ print(" - preserve (and dereference symlink): " + file)
113
+ if len(files_to_preserve) > 0:
114
+ for file in files_to_preserve:
115
+ print(" - preserve: " + file)
116
+ if len(files_to_delete) > 0:
117
+ for file in files_to_delete:
118
+ print(" - delete: " + file)
119
+ while True:
120
+ resp = input("Continue? (Y/N): ")
121
+ if resp.strip().lower() == "y":
122
+ break
123
+ elif resp.strip().lower() == "n":
124
+ sys.exit(0)
125
+
126
+ print("Executing...")
127
+ if len(files_to_desymlink) > 0:
128
+ for file in files_to_desymlink:
129
+ realpath = os.path.realpath(file)
130
+ print("rm " + file)
131
+ os.remove(file)
132
+ print("cp {} {}".format(realpath, file))
133
+ shutil.copyfile(realpath, file)
134
+ if len(files_to_delete) > 0:
135
+ for file in files_to_delete:
136
+ print("rm " + file)
137
+ os.remove(file)
138
+
139
+
140
+ if __name__ == "__main__":
141
+ main()
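Deletion candidates are selected purely by filename. A quick check of which names the patterns above treat as checkpoints (the filenames are examples):

    import re

    pt_regexp = re.compile(r"checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt")
    for name in ["checkpoint12.pt", "checkpoint_3_45000.pt", "checkpoint_best.pt", "model.pt"]:
        print(name, "->", bool(pt_regexp.fullmatch(name)))
    # checkpoint12.pt, checkpoint_3_45000.pt and checkpoint_best.pt match; model.pt does not.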
fairseq/scripts/sacrebleu.sh ADDED
@@ -0,0 +1,27 @@
1
+ #!/bin/bash
2
+
3
+ if [ $# -ne 4 ]; then
4
+ echo "usage: $0 TESTSET SRCLANG TGTLANG GEN"
5
+ exit 1
6
+ fi
7
+
8
+ TESTSET=$1
9
+ SRCLANG=$2
10
+ TGTLANG=$3
11
+
12
+ GEN=$4
13
+
14
+ if ! command -v sacremoses &> /dev/null
15
+ then
16
+ echo "sacremoses could not be found, please install with: pip install sacremoses"
17
+ exit
18
+ fi
19
+
20
+ grep ^H $GEN \
21
+ | sed 's/^H\-//' \
22
+ | sort -n -k 1 \
23
+ | cut -f 3 \
24
+ | sacremoses detokenize \
25
+ > $GEN.sorted.detok
26
+
27
+ sacrebleu --test-set $TESTSET --language-pair "${SRCLANG}-${TGTLANG}" < $GEN.sorted.detok
fairseq/scripts/shard_docs.py ADDED
@@ -0,0 +1,54 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Split a large file into shards while respecting document boundaries. Documents
8
+ should be separated by a single empty line.
9
+ """
10
+
11
+ import argparse
12
+ import contextlib
13
+
14
+
15
+ def main():
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument("input")
18
+ parser.add_argument("--num-shards", type=int)
19
+ args = parser.parse_args()
20
+
21
+ assert args.num_shards is not None and args.num_shards > 1
22
+
23
+ with open(args.input, "r", encoding="utf-8") as h:
24
+ with contextlib.ExitStack() as stack:
25
+ outputs = [
26
+ stack.enter_context(
27
+ open(args.input + ".shard" + str(i), "w", encoding="utf-8")
28
+ )
29
+ for i in range(args.num_shards)
30
+ ]
31
+
32
+ doc = []
33
+ first_doc = [True] * args.num_shards
34
+
35
+ def output_doc(i):
36
+ if not first_doc[i]:
37
+ outputs[i].write("\n")
38
+ first_doc[i] = False
39
+ for line in doc:
40
+ outputs[i].write(line)
41
+ doc.clear()
42
+
43
+ num_docs = 0
44
+ for line in h:
45
+ if line.strip() == "": # empty line indicates new document
46
+ output_doc(num_docs % args.num_shards)
47
+ num_docs += 1
48
+ else:
49
+ doc.append(line)
50
+ output_doc(num_docs % args.num_shards)
51
+
52
+
53
+ if __name__ == "__main__":
54
+ main()
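Documents are distributed round-robin over the shards, so shard sizes stay roughly balanced regardless of how long individual documents are. The assignment rule in isolation:

    num_shards = 3
    for doc_id in range(7):
        print("doc", doc_id, "-> shard", doc_id % num_shards)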
fairseq/scripts/split_train_valid_docs.py ADDED
@@ -0,0 +1,86 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """
7
+ Split a large file into a train and valid set while respecting document
8
+ boundaries. Documents should be separated by a single empty line.
9
+ """
10
+
11
+ import argparse
12
+ import random
13
+ import sys
14
+
15
+
16
+ def main():
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("input")
19
+ parser.add_argument("sample_output", help="train output file")
20
+ parser.add_argument("remainder_output", help="valid output file")
21
+ parser.add_argument("-k", type=int, help="remainder size")
22
+ parser.add_argument(
23
+ "--lines", action="store_true", help="split lines instead of docs"
24
+ )
25
+ args = parser.parse_args()
26
+
27
+ assert args.k is not None
28
+
29
+ sample = []
30
+ remainder = []
31
+ num_docs = [0]
32
+
33
+ def update_sample(doc):
34
+ if len(sample) < args.k:
35
+ sample.append(doc.copy())
36
+ else:
37
+ i = num_docs[0]
38
+ j = random.randrange(i + 1)
39
+ if j < args.k:
40
+ remainder.append(sample[j])
41
+ sample[j] = doc.copy()
42
+ else:
43
+ remainder.append(doc.copy())
44
+ num_docs[0] += 1
45
+ doc.clear()
46
+
47
+ with open(args.input, "r", encoding="utf-8") as h:
48
+ doc = []
49
+ for i, line in enumerate(h):
50
+ if line.strip() == "": # empty line indicates new document
51
+ update_sample(doc)
52
+ else:
53
+ doc.append(line)
54
+ if args.lines:
55
+ update_sample(doc)
56
+ if i % 1000000 == 0:
57
+ print(i, file=sys.stderr, end="", flush=True)
58
+ elif i % 100000 == 0:
59
+ print(".", file=sys.stderr, end="", flush=True)
60
+ if len(doc) > 0:
61
+ update_sample(doc)
62
+ print(file=sys.stderr, flush=True)
63
+
64
+ assert len(sample) == args.k
65
+
66
+ with open(args.sample_output, "w", encoding="utf-8") as out:
67
+ first = True
68
+ for doc in sample:
69
+ if not first and not args.lines:
70
+ out.write("\n")
71
+ first = False
72
+ for line in doc:
73
+ out.write(line)
74
+
75
+ with open(args.remainder_output, "w", encoding="utf-8") as out:
76
+ first = True
77
+ for doc in remainder:
78
+ if not first and not args.lines:
79
+ out.write("\n")
80
+ first = False
81
+ for line in doc:
82
+ out.write(line)
83
+
84
+
85
+ if __name__ == "__main__":
86
+ main()
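update_sample() above is a reservoir sample of size k over the document stream: every document ends up in either the sample or the remainder, and each document has the same probability of landing in the sample. A standalone sketch of the same idea:

    import random

    def reservoir_split(stream, k, seed=42):
        random.seed(seed)
        sample, remainder = [], []
        for i, item in enumerate(stream):
            if len(sample) < k:
                sample.append(item)
            else:
                j = random.randrange(i + 1)
                if j < k:
                    remainder.append(sample[j])
                    sample[j] = item
                else:
                    remainder.append(item)
        return sample, remainder

    print(reservoir_split(range(10), k=3))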
fairseq/scripts/spm_decode.py ADDED
@@ -0,0 +1,53 @@
1
+ #!/usr/bin/env python
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ from __future__ import absolute_import, division, print_function, unicode_literals
9
+
10
+ import argparse
11
+
12
+ import sentencepiece as spm
13
+
14
+
15
+ def main():
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument(
18
+ "--model", required=True, help="sentencepiece model to use for decoding"
19
+ )
20
+ parser.add_argument("--input", required=True, help="input file to decode")
21
+ parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
22
+ args = parser.parse_args()
23
+
24
+ sp = spm.SentencePieceProcessor()
25
+ sp.Load(args.model)
26
+
27
+ if args.input_format == "piece":
28
+
29
+ def decode(input):
30
+ return "".join(sp.DecodePieces(input))
31
+
32
+ elif args.input_format == "id":
33
+
34
+ def decode(input):
35
+ return "".join(sp.DecodeIds(input))
36
+
37
+ else:
38
+ raise NotImplementedError
39
+
40
+ def tok2int(tok):
41
+ # remap reference-side <unk> (represented as <<unk>>) to 0
42
+ return int(tok) if tok != "<<unk>>" else 0
43
+
44
+ with open(args.input, "r", encoding="utf-8") as h:
45
+ for line in h:
46
+ if args.input_format == "id":
47
+ print(decode(list(map(tok2int, line.rstrip().split()))))
48
+ elif args.input_format == "piece":
49
+ print(decode(line.rstrip().split()))
50
+
51
+
52
+ if __name__ == "__main__":
53
+ main()
fairseq/scripts/spm_encode.py ADDED
@@ -0,0 +1,119 @@
1
+ #!/usr/bin/env python
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ from __future__ import absolute_import, division, print_function, unicode_literals
9
+
10
+ import argparse
11
+ import contextlib
12
+ import sys
13
+
14
+ import sentencepiece as spm
15
+
16
+
17
+ def main():
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument(
20
+ "--model", required=True, help="sentencepiece model to use for encoding"
21
+ )
22
+ parser.add_argument(
23
+ "--inputs", nargs="+", default=["-"], help="input files to filter/encode"
24
+ )
25
+ parser.add_argument(
26
+ "--outputs", nargs="+", default=["-"], help="path to save encoded outputs"
27
+ )
28
+ parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
29
+ parser.add_argument(
30
+ "--min-len",
31
+ type=int,
32
+ metavar="N",
33
+ help="filter sentence pairs with fewer than N tokens",
34
+ )
35
+ parser.add_argument(
36
+ "--max-len",
37
+ type=int,
38
+ metavar="N",
39
+ help="filter sentence pairs with more than N tokens",
40
+ )
41
+ args = parser.parse_args()
42
+
43
+ assert len(args.inputs) == len(
44
+ args.outputs
45
+ ), "number of input and output paths should match"
46
+
47
+ sp = spm.SentencePieceProcessor()
48
+ sp.Load(args.model)
49
+
50
+ if args.output_format == "piece":
51
+
52
+ def encode(input):
53
+ return sp.EncodeAsPieces(input)
54
+
55
+ elif args.output_format == "id":
56
+
57
+ def encode(input):
58
+ return list(map(str, sp.EncodeAsIds(input)))
59
+
60
+ else:
61
+ raise NotImplementedError
62
+
63
+ if args.min_len is not None or args.max_len is not None:
64
+
65
+ def valid(line):
66
+ return (args.min_len is None or len(line) >= args.min_len) and (
67
+ args.max_len is None or len(line) <= args.max_len
68
+ )
69
+
70
+ else:
71
+
72
+ def valid(lines):
73
+ return True
74
+
75
+ with contextlib.ExitStack() as stack:
76
+ inputs = [
77
+ stack.enter_context(open(input, "r", encoding="utf-8"))
78
+ if input != "-"
79
+ else sys.stdin
80
+ for input in args.inputs
81
+ ]
82
+ outputs = [
83
+ stack.enter_context(open(output, "w", encoding="utf-8"))
84
+ if output != "-"
85
+ else sys.stdout
86
+ for output in args.outputs
87
+ ]
88
+
89
+ stats = {
90
+ "num_empty": 0,
91
+ "num_filtered": 0,
92
+ }
93
+
94
+ def encode_line(line):
95
+ line = line.strip()
96
+ if len(line) > 0:
97
+ line = encode(line)
98
+ if valid(line):
99
+ return line
100
+ else:
101
+ stats["num_filtered"] += 1
102
+ else:
103
+ stats["num_empty"] += 1
104
+ return None
105
+
106
+ for i, lines in enumerate(zip(*inputs), start=1):
107
+ enc_lines = list(map(encode_line, lines))
108
+ if not any(enc_line is None for enc_line in enc_lines):
109
+ for enc_line, output_h in zip(enc_lines, outputs):
110
+ print(" ".join(enc_line), file=output_h)
111
+ if i % 10000 == 0:
112
+ print("processed {} lines".format(i), file=sys.stderr)
113
+
114
+ print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
115
+ print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)
116
+
117
+
118
+ if __name__ == "__main__":
119
+ main()
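Both spm_encode.py and spm_decode.py above are thin wrappers around the sentencepiece Python API. A minimal round trip, assuming "spm.model" is a trained model file on disk:

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.Load("spm.model")  # hypothetical path to a trained model

    pieces = sp.EncodeAsPieces("hello world")
    print(pieces)                   # subword pieces
    print(sp.DecodePieces(pieces))  # -> "hello world"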
fairseq/scripts/spm_train.py ADDED
@@ -0,0 +1,16 @@
1
+ #!/usr/bin/env python
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ from __future__ import absolute_import, division, print_function, unicode_literals
9
+
10
+ import sys
11
+
12
+ import sentencepiece as spm
13
+
14
+
15
+ if __name__ == "__main__":
16
+ spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
fairseq/scripts/test_fsdp.sh ADDED
@@ -0,0 +1,24 @@
1
+ #!/usr/bin/env bash
2
+ rm -rf fsdp_dummy
3
+ mkdir -p fsdp_dummy
4
+ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \
5
+ --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
6
+ --cpu-offload --checkpoint-activations \
7
+ --task language_modeling --tokens-per-sample 256 --batch-size 8 \
8
+ --arch transformer_lm_gpt2_tiny \
9
+ --optimizer cpu_adam --adam-betas "(0.9,0.98)" \
10
+ --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
11
+ --max-update 5 --log-format json --log-interval 1 \
12
+ --save-interval-updates 5 --save-dir fsdp_dummy --disable-validation \
13
+ --restore-file x.pt "$@"
14
+
15
+ # Now we try to load the checkpoint
16
+ CUDA_VISIBLE_DEVICES=0,1 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \
17
+ --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
18
+ --cpu-offload --checkpoint-activations \
19
+ --task language_modeling --tokens-per-sample 256 --batch-size 8 \
20
+ --arch transformer_lm_gpt2_tiny \
21
+ --optimizer cpu_adam --adam-betas "(0.9,0.98)" \
22
+ --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
23
+ --max-update 2 --log-format json --log-interval 1 \
24
+ --save-interval-updates 2 --save-dir fsdp_dummy
fairseq/tests/__init__.py ADDED
File without changes
fairseq/tests/tasks/test_masked_lm.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import unittest
8
+ from tempfile import TemporaryDirectory
9
+
10
+ from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
11
+ from fairseq.tasks.masked_lm import MaskedLMConfig, MaskedLMTask
12
+ from tests.utils import build_vocab, make_data
13
+
14
+
15
+ class TestMaskedLM(unittest.TestCase):
16
+ def test_masks_tokens(self):
17
+ with TemporaryDirectory() as dirname:
18
+
19
+ # prep input file
20
+ raw_file = os.path.join(dirname, "raw")
21
+ data = make_data(out_file=raw_file)
22
+ vocab = build_vocab(data)
23
+
24
+ # binarize
25
+ binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
26
+ split = "train"
27
+ bin_file = os.path.join(dirname, split)
28
+ FileBinarizer.multiprocess_dataset(
29
+ input_file=raw_file,
30
+ binarizer=binarizer,
31
+ dataset_impl="mmap",
32
+ vocab_size=len(vocab),
33
+ output_prefix=bin_file,
34
+ )
35
+
36
+ # setup task
37
+ cfg = MaskedLMConfig(
38
+ data=dirname,
39
+ seed=42,
40
+ mask_prob=0.5, # increasing the odds of masking
41
+ random_token_prob=0, # avoiding random tokens for exact match
42
+ leave_unmasked_prob=0, # always masking for exact match
43
+ )
44
+ task = MaskedLMTask(cfg, binarizer.dict)
45
+
46
+ original_dataset = task._load_dataset_split(bin_file, 1, False)
47
+
48
+ # load datasets
49
+ task.load_dataset(split)
50
+ masked_dataset = task.dataset(split)
51
+
52
+ mask_index = task.source_dictionary.index("<mask>")
53
+ iterator = task.get_batch_iterator(
54
+ dataset=masked_dataset,
55
+ max_tokens=65_536,
56
+ max_positions=4_096,
57
+ ).next_epoch_itr(shuffle=False)
58
+ for batch in iterator:
59
+ for sample in range(len(batch)):
60
+ net_input = batch["net_input"]
61
+ masked_src_tokens = net_input["src_tokens"][sample]
62
+ masked_src_length = net_input["src_lengths"][sample]
63
+ masked_tgt_tokens = batch["target"][sample]
64
+
65
+ sample_id = batch["id"][sample]
66
+ original_tokens = original_dataset[sample_id]
67
+ original_tokens = original_tokens.masked_select(
68
+ masked_src_tokens[:masked_src_length] == mask_index
69
+ )
70
+ masked_tokens = masked_tgt_tokens.masked_select(
71
+ masked_tgt_tokens != task.source_dictionary.pad()
72
+ )
73
+
74
+ assert masked_tokens.equal(original_tokens)
75
+
76
+
77
+ if __name__ == "__main__":
78
+ unittest.main()
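The core assertion compares, via masked_select, the original tokens at masked positions against the non-pad targets. The tensor operation it relies on, in isolation:

    import torch

    tokens = torch.tensor([5, 6, 7, 8])
    mask = torch.tensor([True, False, True, False])
    print(tokens.masked_select(mask))  # tensor([5, 7])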
fairseq/tests/tasks/test_span_masked_lm.py ADDED
@@ -0,0 +1,106 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import unittest
8
+ from tempfile import TemporaryDirectory
9
+
10
+ from fairseq import options
11
+ from fairseq.binarizer import FileBinarizer, VocabularyDatasetBinarizer
12
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
13
+ from fairseq.tasks.span_masked_lm import SpanMaskedLMTask
14
+ from tests.utils import build_vocab, make_data
15
+
16
+
17
+ class TestSpanMaskedLM(unittest.TestCase):
18
+ def test_masks_token_spans(self):
19
+ with TemporaryDirectory() as dirname:
20
+
21
+ # prep input file
22
+ raw_file = os.path.join(dirname, "raw")
23
+ data = make_data(out_file=raw_file)
24
+ vocab = build_vocab(data)
25
+
26
+ # binarize
27
+ binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False)
28
+ split = "train"
29
+ bin_file = os.path.join(dirname, split)
30
+ dataset_impl = "mmap"
31
+
32
+ FileBinarizer.multiprocess_dataset(
33
+ input_file=raw_file,
34
+ binarizer=binarizer,
35
+ dataset_impl=dataset_impl,
36
+ vocab_size=len(vocab),
37
+ output_prefix=bin_file,
38
+ )
39
+
40
+ # adding sentinel tokens
41
+ for i in range(100):
42
+ vocab.add_symbol(f"<extra_id_{i}>")
43
+
44
+ # setup task
45
+ train_args = options.parse_args_and_arch(
46
+ options.get_training_parser(),
47
+ [
48
+ "--task",
49
+ "span_masked_lm",
50
+ "--arch",
51
+ "bart_base",
52
+ "--seed",
53
+ "42",
54
+ dirname,
55
+ ],
56
+ )
57
+ cfg = convert_namespace_to_omegaconf(train_args)
58
+ task = SpanMaskedLMTask(cfg.task, binarizer.dict)
59
+
60
+ # load datasets
61
+ original_dataset = task._load_dataset_split(bin_file, 1, False)
62
+ task.load_dataset(split)
63
+ masked_dataset = task.dataset(split)
64
+
65
+ iterator = task.get_batch_iterator(
66
+ dataset=masked_dataset,
67
+ max_tokens=65_536,
68
+ max_positions=4_096,
69
+ ).next_epoch_itr(shuffle=False)
70
+ num_tokens = len(vocab)
71
+ for batch in iterator:
72
+ for sample in range(len(batch)):
73
+ sample_id = batch["id"][sample]
74
+ original_tokens = original_dataset[sample_id]
75
+ masked_src_tokens = batch["net_input"]["src_tokens"][sample]
76
+ masked_src_length = batch["net_input"]["src_lengths"][sample]
77
+ masked_tgt_tokens = batch["target"][sample]
78
+
79
+ original_offset = 0
80
+ masked_tgt_offset = 0
81
+ extra_id_token = len(vocab) - 1
82
+ for masked_src_token in masked_src_tokens[:masked_src_length]:
83
+ if masked_src_token == extra_id_token:
84
+ assert (
85
+ masked_src_token == masked_tgt_tokens[masked_tgt_offset]
86
+ )
87
+ extra_id_token -= 1
88
+ masked_tgt_offset += 1
89
+ while (
90
+ original_offset < len(original_tokens)
91
+ and masked_tgt_tokens[masked_tgt_offset]
92
+ != extra_id_token
93
+ ):
94
+ assert (
95
+ original_tokens[original_offset]
96
+ == masked_tgt_tokens[masked_tgt_offset]
97
+ )
98
+ original_offset += 1
99
+ masked_tgt_offset += 1
100
+ else:
101
+ assert original_tokens[original_offset] == masked_src_token
102
+ original_offset += 1
103
+
104
+
105
+ if __name__ == "__main__":
106
+ unittest.main()
fairseq/tests/test_activation_checkpointing.py ADDED
@@ -0,0 +1,79 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from fairseq.modules.checkpoint_activations import checkpoint_wrapper
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+
14
+ class Model(nn.Module):
15
+ def __init__(
16
+ self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs
17
+ ):
18
+ super().__init__()
19
+ torch.manual_seed(0)
20
+ self.use_pytorch_checkpoint = use_pytorch_checkpoint
21
+ self.ffn = nn.Sequential(
22
+ nn.Linear(32, 128),
23
+ # add a Dropout layer to test RNG save/restore
24
+ nn.Dropout(p=0.5),
25
+ nn.Linear(128, 32),
26
+ )
27
+ if use_fairseq_checkpoint:
28
+ self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
29
+ self.out = nn.Linear(32, 1)
30
+
31
+ def forward(self, x):
32
+ if self.use_pytorch_checkpoint:
33
+ x = checkpoint(self.ffn, x)
34
+ else:
35
+ x = self.ffn(x)
36
+ return self.out(x)
37
+
38
+
39
+ class TestActivationCheckpointing(unittest.TestCase):
40
+ def _test_checkpoint_wrapper(self, device, log_memory_usage=False):
41
+ def get_loss_and_gnorm(model):
42
+ torch.manual_seed(1)
43
+ input = torch.rand(2, 16, 32).requires_grad_(True).to(device)
44
+ model.zero_grad()
45
+ loss = model(input).sum()
46
+ loss.backward()
47
+ gnorm = torch.norm(
48
+ torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()])
49
+ )
50
+ return {"loss": loss, "gnorm": gnorm}
51
+
52
+ model = Model().to(device)
53
+ no_cpt = get_loss_and_gnorm(model)
54
+
55
+ model = Model(use_pytorch_checkpoint=True).to(device)
56
+ pyt_cpt = get_loss_and_gnorm(model)
57
+ torch.testing.assert_allclose(no_cpt["loss"], pyt_cpt["loss"])
58
+ torch.testing.assert_allclose(no_cpt["gnorm"], pyt_cpt["gnorm"])
59
+
60
+ model = Model(use_fairseq_checkpoint=True).to(device)
61
+ fairseq_cpt = get_loss_and_gnorm(model)
62
+ torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"])
63
+ torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"])
64
+
65
+ model = Model(use_fairseq_checkpoint=True, offload_to_cpu=True).to(device)
66
+ fairseq_cpt_offload = get_loss_and_gnorm(model)
67
+ torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt_offload["loss"])
68
+ torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt_offload["gnorm"])
69
+
70
+ def test_checkpoint_wrapper_cpu(self):
71
+ self._test_checkpoint_wrapper(device=torch.device("cpu"))
72
+
73
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
74
+ def test_checkpoint_wrapper_cuda(self):
75
+ self._test_checkpoint_wrapper(device=torch.device("cuda"))
76
+
77
+
78
+ if __name__ == "__main__":
79
+ unittest.main()
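For context, torch.utils.checkpoint trades memory for compute: activations inside the wrapped module are not stored during forward and are recomputed during backward. A minimal sketch outside the test harness:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    ffn = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 32))
    x = torch.rand(2, 16, 32, requires_grad=True)
    y = checkpoint(ffn, x)  # ffn's activations are recomputed in backward
    y.sum().backward()
    print(x.grad.shape)  # torch.Size([2, 16, 32])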
fairseq/tests/test_amp_optimizer.py ADDED
@@ -0,0 +1,75 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import copy
8
+ import unittest
9
+
10
+ import torch
11
+ from torch.cuda.amp import GradScaler, autocast
12
+
13
+ from fairseq.optim import build_optimizer
14
+
15
+
16
+ @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
17
+ class TestGradientScalingAMP(unittest.TestCase):
18
+ def setUp(self):
19
+ self.x = torch.tensor([2.0]).cuda().half()
20
+ weight = 3.0
21
+ bias = 5.0
22
+ self.error = 1.0
23
+ self.target = torch.tensor([self.x * weight + bias + self.error]).cuda()
24
+ self.loss_fn = torch.nn.L1Loss()
25
+
26
+ self.model = torch.nn.Linear(1, 1)
27
+ self.model.weight.data = torch.tensor([[weight]])
28
+ self.model.bias.data = torch.tensor([bias])
29
+ self.model.cuda()
30
+ self.params = list(self.model.parameters())
31
+
32
+ self.namespace_dls = argparse.Namespace(
33
+ optimizer="adam",
34
+ lr=[0.1],
35
+ adam_betas="(0.9, 0.999)",
36
+ adam_eps=1e-8,
37
+ weight_decay=0.0,
38
+ threshold_loss_scale=1,
39
+ min_loss_scale=1e-4,
40
+ )
41
+ self.scaler = GradScaler(
42
+ init_scale=1,
43
+ growth_interval=1,
44
+ )
45
+
46
+ def run_iter(self, model, params, optimizer):
47
+ optimizer.zero_grad()
48
+ with autocast():
49
+ y = model(self.x)
50
+ loss = self.loss_fn(y, self.target)
51
+ self.scaler.scale(loss).backward()
52
+ self.assertEqual(loss, torch.tensor(1.0, device="cuda:0", dtype=torch.float16))
53
+
54
+ self.scaler.unscale_(optimizer)
55
+ grad_norm = optimizer.clip_grad_norm(0)
56
+ self.assertAlmostEqual(grad_norm.item(), 2.2361, 4)
57
+
58
+ self.scaler.step(optimizer)
59
+ self.scaler.update()
60
+ self.assertEqual(
61
+ model.weight,
62
+ torch.tensor([[3.1]], device="cuda:0", requires_grad=True),
63
+ )
64
+ self.assertEqual(
65
+ model.bias,
66
+ torch.tensor([5.1], device="cuda:0", requires_grad=True),
67
+ )
68
+ self.assertEqual(self.scaler.get_scale(), 2.0)
69
+
70
+ def test_automatic_mixed_precision(self):
71
+ model = copy.deepcopy(self.model)
72
+ params = list(model.parameters())
73
+ optimizer = build_optimizer(self.namespace_dls, params)
74
+
75
+ self.run_iter(model, params, optimizer)
fairseq/tests/test_average_checkpoints.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import collections
7
+ import os
8
+ import shutil
9
+ import tempfile
10
+ import unittest
11
+
12
+ import numpy as np
13
+ import torch
14
+ from scripts.average_checkpoints import average_checkpoints
15
+ from torch import nn
16
+
17
+
18
+ class ModelWithSharedParameter(nn.Module):
19
+ def __init__(self):
20
+ super(ModelWithSharedParameter, self).__init__()
21
+ self.embedding = nn.Embedding(1000, 200)
22
+ self.FC1 = nn.Linear(200, 200)
23
+ self.FC2 = nn.Linear(200, 200)
24
+ # tie weight in FC2 to FC1
25
+ self.FC2.weight = nn.Parameter(self.FC1.weight)
26
+ self.FC2.bias = nn.Parameter(self.FC1.bias)
27
+
28
+ self.relu = nn.ReLU()
29
+
30
+ def forward(self, input):
31
+ return self.FC2(self.ReLU(self.FC1(input))) + self.FC1(input)
32
+
33
+
34
+ class TestAverageCheckpoints(unittest.TestCase):
35
+ def test_average_checkpoints(self):
36
+ params_0 = collections.OrderedDict(
37
+ [
38
+ ("a", torch.DoubleTensor([100.0])),
39
+ ("b", torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
40
+ ("c", torch.IntTensor([7, 8, 9])),
41
+ ]
42
+ )
43
+ params_1 = collections.OrderedDict(
44
+ [
45
+ ("a", torch.DoubleTensor([1.0])),
46
+ ("b", torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
47
+ ("c", torch.IntTensor([2, 2, 2])),
48
+ ]
49
+ )
50
+ params_avg = collections.OrderedDict(
51
+ [
52
+ ("a", torch.DoubleTensor([50.5])),
53
+ ("b", torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
54
+ # We expect truncation for integer division
55
+ ("c", torch.IntTensor([4, 5, 5])),
56
+ ]
57
+ )
58
+
59
+ fd_0, path_0 = tempfile.mkstemp()
60
+ fd_1, path_1 = tempfile.mkstemp()
61
+ torch.save(collections.OrderedDict([("model", params_0)]), path_0)
62
+ torch.save(collections.OrderedDict([("model", params_1)]), path_1)
63
+
64
+ output = average_checkpoints([path_0, path_1])["model"]
65
+
66
+ os.close(fd_0)
67
+ os.remove(path_0)
68
+ os.close(fd_1)
69
+ os.remove(path_1)
70
+
71
+ for (k_expected, v_expected), (k_out, v_out) in zip(
72
+ params_avg.items(), output.items()
73
+ ):
74
+ self.assertEqual(
75
+ k_expected,
76
+ k_out,
77
+ "Key mismatch - expected {} but found {}. "
78
+ "(Expected list of keys: {} vs actual list of keys: {})".format(
79
+ k_expected, k_out, params_avg.keys(), output.keys()
80
+ ),
81
+ )
82
+ np.testing.assert_allclose(
83
+ v_expected.numpy(),
84
+ v_out.numpy(),
85
+ err_msg="Tensor value mismatch for key {}".format(k_expected),
86
+ )
87
+
88
+ def test_average_checkpoints_with_shared_parameters(self):
89
+ def _construct_model_with_shared_parameters(path, value):
90
+ m = ModelWithSharedParameter()
91
+ nn.init.constant_(m.FC1.weight, value)
92
+ torch.save({"model": m.state_dict()}, path)
93
+ return m
94
+
95
+ tmpdir = tempfile.mkdtemp()
96
+ paths = []
97
+ path = os.path.join(tmpdir, "m1.pt")
98
+ m1 = _construct_model_with_shared_parameters(path, 1.0)
99
+ paths.append(path)
100
+
101
+ path = os.path.join(tmpdir, "m2.pt")
102
+ m2 = _construct_model_with_shared_parameters(path, 2.0)
103
+ paths.append(path)
104
+
105
+ path = os.path.join(tmpdir, "m3.pt")
106
+ m3 = _construct_model_with_shared_parameters(path, 3.0)
107
+ paths.append(path)
108
+
109
+ new_model = average_checkpoints(paths)
110
+ self.assertTrue(
111
+ torch.equal(
112
+ new_model["model"]["embedding.weight"],
113
+ (m1.embedding.weight + m2.embedding.weight + m3.embedding.weight) / 3.0,
114
+ )
115
+ )
116
+
117
+ self.assertTrue(
118
+ torch.equal(
119
+ new_model["model"]["FC1.weight"],
120
+ (m1.FC1.weight + m2.FC1.weight + m3.FC1.weight) / 3.0,
121
+ )
122
+ )
123
+
124
+ self.assertTrue(
125
+ torch.equal(
126
+ new_model["model"]["FC2.weight"],
127
+ (m1.FC2.weight + m2.FC2.weight + m3.FC2.weight) / 3.0,
128
+ )
129
+ )
130
+ shutil.rmtree(tmpdir)
131
+
132
+
133
+ if __name__ == "__main__":
134
+ unittest.main()
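The expected values in test_average_checkpoints encode the averaging semantics: floating-point tensors are averaged exactly, while integer tensors truncate. The arithmetic behind the expected "c" entry:

    import torch

    c0 = torch.tensor([7, 8, 9], dtype=torch.int32)
    c1 = torch.tensor([2, 2, 2], dtype=torch.int32)
    print((c0 + c1) // 2)  # tensor([4, 5, 5]) -- matches params_avg["c"] above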
fairseq/tests/test_backtranslation_dataset.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import unittest
7
+
8
+ import tests.utils as test_utils
9
+ import torch
10
+ from fairseq.data import (
11
+ BacktranslationDataset,
12
+ LanguagePairDataset,
13
+ TransformEosDataset,
14
+ )
15
+ from fairseq.sequence_generator import SequenceGenerator
16
+
17
+
18
+ class TestBacktranslationDataset(unittest.TestCase):
19
+ def setUp(self):
20
+ (
21
+ self.tgt_dict,
22
+ self.w1,
23
+ self.w2,
24
+ self.src_tokens,
25
+ self.src_lengths,
26
+ self.model,
27
+ ) = test_utils.sequence_generator_setup()
28
+
29
+ dummy_src_samples = self.src_tokens
30
+
31
+ self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)
32
+ self.cuda = torch.cuda.is_available()
33
+
34
+ def _backtranslation_dataset_helper(
35
+ self,
36
+ remove_eos_from_input_src,
37
+ remove_eos_from_output_src,
38
+ ):
39
+ tgt_dataset = LanguagePairDataset(
40
+ src=self.tgt_dataset,
41
+ src_sizes=self.tgt_dataset.sizes,
42
+ src_dict=self.tgt_dict,
43
+ tgt=None,
44
+ tgt_sizes=None,
45
+ tgt_dict=None,
46
+ )
47
+
48
+ generator = SequenceGenerator(
49
+ [self.model],
50
+ tgt_dict=self.tgt_dict,
51
+ max_len_a=0,
52
+ max_len_b=200,
53
+ beam_size=2,
54
+ unk_penalty=0,
55
+ )
56
+
57
+ backtranslation_dataset = BacktranslationDataset(
58
+ tgt_dataset=TransformEosDataset(
59
+ dataset=tgt_dataset,
60
+ eos=self.tgt_dict.eos(),
61
+ # remove eos from the input src
62
+ remove_eos_from_src=remove_eos_from_input_src,
63
+ ),
64
+ src_dict=self.tgt_dict,
65
+ backtranslation_fn=(
66
+ lambda sample: generator.generate([self.model], sample)
67
+ ),
68
+ output_collater=TransformEosDataset(
69
+ dataset=tgt_dataset,
70
+ eos=self.tgt_dict.eos(),
71
+ # if we remove eos from the input src, then we need to add it
72
+ # back to the output tgt
73
+ append_eos_to_tgt=remove_eos_from_input_src,
74
+ remove_eos_from_src=remove_eos_from_output_src,
75
+ ).collater,
76
+ cuda=self.cuda,
77
+ )
78
+ dataloader = torch.utils.data.DataLoader(
79
+ backtranslation_dataset,
80
+ batch_size=2,
81
+ collate_fn=backtranslation_dataset.collater,
82
+ )
83
+ backtranslation_batch_result = next(iter(dataloader))
84
+
85
+ eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2
86
+
87
+ # Note that we sort by src_lengths and add left padding, so actually
88
+ # ids will look like: [1, 0]
89
+ expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
90
+ if remove_eos_from_output_src:
91
+ expected_src = expected_src[:, :-1]
92
+ expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
93
+ generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
94
+ tgt_tokens = backtranslation_batch_result["target"]
95
+
96
+ self.assertTensorEqual(expected_src, generated_src)
97
+ self.assertTensorEqual(expected_tgt, tgt_tokens)
98
+
99
+ def test_backtranslation_dataset_no_eos_in_output_src(self):
100
+ self._backtranslation_dataset_helper(
101
+ remove_eos_from_input_src=False,
102
+ remove_eos_from_output_src=True,
103
+ )
104
+
105
+ def test_backtranslation_dataset_with_eos_in_output_src(self):
106
+ self._backtranslation_dataset_helper(
107
+ remove_eos_from_input_src=False,
108
+ remove_eos_from_output_src=False,
109
+ )
110
+
111
+ def test_backtranslation_dataset_no_eos_in_input_src(self):
112
+ self._backtranslation_dataset_helper(
113
+ remove_eos_from_input_src=True,
114
+ remove_eos_from_output_src=False,
115
+ )
116
+
117
+ def assertTensorEqual(self, t1, t2):
118
+ self.assertEqual(t1.size(), t2.size(), "size mismatch")
119
+ self.assertEqual(t1.ne(t2).long().sum(), 0)
120
+
121
+
122
+ if __name__ == "__main__":
123
+ unittest.main()